In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch 
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.distributions.categorical import Categorical

from sklearn.model_selection import train_test_split
from tqdm import tqdm, trange
from datetime import datetime

In [2]:
# --- Paths -------------------------------------------------------------------
# Training corpus that is read line by line further down.
TRAIN_FILE = '/kaggle/input/smiles/smiles_train.txt'
# One timestamp shared by every artefact of this run. Calling datetime.now()
# separately for each path could straddle a second boundary and give the model
# file and the submission file different suffixes.
RUN_TIMESTAMP = datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
MODEL_FILE_OUT = f"/kaggle/working/model-{RUN_TIMESTAMP}.pth"
# Checkpoint loaded when TRAIN_FROM_SCRATCH is False (plain string: nothing to interpolate).
MODEL_FILE_IN = '/kaggle/input/model1/model-2022_05_02_02_21_14.pth'
SUBMISSION_FILE = f"/kaggle/working/submission-{RUN_TIMESTAMP}.txt"

# --- Run switches --------------------------------------------------------------
# True: train and save to MODEL_FILE_OUT. False: load MODEL_FILE_IN instead.
TRAIN_FROM_SCRATCH = False
# Single-character start-of-sequence / end-of-sequence markers added to every string.
SOS = 'A'
EOS = 'Z'
# Random state for the train/validation split.
SEED = 2022
# Fraction of the corpus held out for validation.
TEST_SIZE = 0.1

In [3]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to all three generators.

    Returns:
        None.
    """
    # Local import: the notebook later rebinds the global name `random` to a
    # bool, so the stdlib module is kept out of the notebook namespace.
    import random

    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU generator and the generators of all devices.
    torch.manual_seed(seed)

## Exploration

In [4]:
# Read the whole training corpus into memory. readlines() keeps each line's
# trailing newline (when it has one), so the entries are not stripped yet.
with open(TRAIN_FILE) as f:
    smiles = f.readlines()

In [5]:
# First visual inspection: show the first 30 raw lines (newlines included).
print(smiles[:30])

In [6]:
# Number of training samples (one per line of TRAIN_FILE).
print(f'Number of smiles strings ... {len(smiles)}')

In [7]:
# Distribution of SMILES string lengths. The strings are still the raw lines,
# so each length counts the trailing newline where one is present.
smiles_lengths = [len(s) for s in smiles]
plt.hist(smiles_lengths)

In [8]:
# Does every SMILES string end with '\n'? This counts the lines that contain
# no newline at all; a result of 0 means every line is newline-terminated.
sum('\n' not in s for s in smiles)

In [9]:
def get_alphabet(string_list, add_eos=False, add_sos=False):
    """Collect the sorted set of characters that occur in ``string_list``.

    Args:
        string_list: Iterable of strings whose characters form the alphabet.
        add_eos: If true, the newline character is replaced in place by the
            module-level ``EOS`` marker (the list is not re-sorted afterwards).
        add_sos: If true, the module-level ``SOS`` marker is put at the front.

    Returns:
        List of single characters; a character's position is its index.
    """
    symbols = sorted(set(''.join(string_list)))
    if add_eos:
        # Swap the newline for the EOS marker at the same position.
        symbols = [EOS if symbol == '\n' else symbol for symbol in symbols]
    if add_sos:
        symbols = [SOS, *symbols]
    return symbols

# Raw alphabet of the corpus, before any SOS/EOS handling.
alphabet = get_alphabet(smiles)
print(alphabet)

## Preparation

In [10]:
# Alphabet as used for modelling: newline replaced by EOS, SOS put in front.
alphabet = get_alphabet(smiles, add_eos=True, add_sos=True)
print(alphabet)

In [11]:
def prepare_smiles(smiles, sos_token, eos_token):
    """Wrap one raw line with sequence markers.

    Every newline in ``smiles`` is replaced by ``eos_token`` and ``sos_token``
    is put in front of the result.
    """
    body = smiles.replace('\n', eos_token)
    return sos_token + body

# Add the SOS/EOS markers to every raw line and show one example.
smiles_eos = [prepare_smiles(s, SOS, EOS) for s in smiles]
print(smiles_eos[0])

In [12]:
# Hold out TEST_SIZE of the marked-up strings for validation; random_state
# fixes the split.
smiles_train, smiles_val = train_test_split(smiles_eos, test_size=TEST_SIZE, random_state=SEED)
print(f'Training set size ... {len(smiles_train)}')
print(f'Validation set size ... {len(smiles_val)}')

In [13]:
class Encoder:
    """Two-way mapping between characters and their positions in an alphabet.

    Args:
        alphabet: Sequence of characters; a character's position is its index.
            The lookup table is built once here, so ``alphabet`` should not be
            mutated afterwards.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet
        # char -> index table for O(1) lookups. setdefault keeps the FIRST
        # position of a repeated character, which is what list.index returns.
        self._index_of = {}
        for index, char in enumerate(alphabet):
            self._index_of.setdefault(char, index)

    def smiles2indices(self, smiles):
        """Return the indices of the characters of ``smiles``.

        Characters that are not part of the alphabet are dropped silently.
        """
        lookup = self._index_of
        return [lookup[char] for char in smiles if char in lookup]

    def indices2smiles(self, indices):
        """Return the string spelled by ``indices``.

        Indices outside ``[0, len(alphabet))`` are dropped silently. Negative
        values are rejected as well instead of wrapping around to the end of
        the alphabet.
        """
        size = len(self.alphabet)
        return ''.join(self.alphabet[i] for i in indices if 0 <= i < size)

    def __call__(self, x):
        """Encode ``x`` if it is a string, otherwise decode it as indices."""
        # isinstance also accepts str subclasses (e.g. numpy.str_).
        if isinstance(x, str):
            return self.smiles2indices(x)
        return self.indices2smiles(x)

In [14]:
class SmilesDataset(Dataset):
    """Dataset that serves index-encoded strings together with their length.

    Args:
        smiles: Indexable collection of strings.
        encoder: Object providing ``smiles2indices(str) -> list``.
    """

    def __init__(self, smiles, encoder):
        super().__init__()
        self.encoder = encoder
        self.smiles = smiles

    def __len__(self):
        """Number of strings in the dataset."""
        return len(self.smiles)

    def __getitem__(self, idx):
        """Return ``(indices, length)`` for the string at position ``idx``."""
        encoded = self.encoder.smiles2indices(self.smiles[idx])
        return encoded, len(encoded)

In [15]:
class BidirectionalSmilesCollate:
    """Collate ``(indices, length)`` samples into padded forward/reverse batches.

    Args:
        sos_index: Index used to pad the reversed sequences. ``None`` (default)
            falls back to the module-level ``SOS_IND`` at call time.
        eos_index: Index used to pad the forward sequences. ``None`` (default)
            falls back to the module-level ``EOS_IND`` at call time.
    """

    def __init__(self, sos_index=None, eos_index=None):
        self.sos_index = sos_index
        self.eos_index = eos_index

    def __call__(self, batch):
        """Build the batch tensors.

        Args:
            batch: List of ``(indices, length)`` pairs.

        Returns:
            Tuple ``(x_for, x_rev, t)`` of ``torch.long`` tensors:
            ``x_for`` holds each sequence right-padded with the EOS index,
            ``x_rev`` holds each reversed sequence left-padded with the SOS
            index, and ``t`` holds the unpadded lengths.
        """
        # Resolve the fallbacks lazily: the globals are defined in a later cell.
        sos = SOS_IND if self.sos_index is None else self.sos_index
        eos = EOS_IND if self.eos_index is None else self.eos_index

        lengths = [t for (_, t) in batch]
        max_length = max(lengths)
        # Fill two pre-padded int64 arrays instead of handing torch.tensor a
        # list of ndarrays, which is the slow conversion path.
        x_for = np.full((len(batch), max_length), eos, dtype=np.int64)
        x_rev = np.full((len(batch), max_length), sos, dtype=np.int64)
        for row, (x, t) in enumerate(batch):
            x_for[row, :t] = x
            x_rev[row, max_length - t:] = list(reversed(x))
        return (
            torch.from_numpy(x_for),
            torch.from_numpy(x_rev),
            torch.tensor(lengths, dtype=torch.long)
        )

## Model

In [16]:
class Embedding(nn.Module):
    """Thin wrapper around ``nn.Embedding`` that also records its two sizes.

    Args:
        alphabet_size: Number of distinct token indices.
        embedding_dim: Size of each embedding vector.
    """

    def __init__(self, alphabet_size, embedding_dim):
        super().__init__()
        self.alphabet_size, self.embedding_dim = alphabet_size, embedding_dim
        self.embedding = nn.Embedding(num_embeddings=alphabet_size, embedding_dim=embedding_dim)

    def forward(self, x):
        """Look up the embedding vectors for the token indices in ``x``."""
        return self.embedding(x)

In [17]:
class EncoderLSTM(nn.Module):
    """Encoder with a forward and a reverse LSTM that outputs Gaussian parameters.

    Args:
        embedding: Module mapping token indices to vectors; must expose
            ``embedding_dim``.
        hidden_dim: Hidden size of each of the two LSTMs.
        latent_dim: Size of the ``mu`` / ``log_var`` outputs.
    """

    def __init__(self, embedding, hidden_dim, latent_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.embedding = embedding
        self.forward_lstm = nn.LSTM(embedding.embedding_dim, hidden_dim, batch_first=True)
        self.reverse_lstm = nn.LSTM(embedding.embedding_dim, hidden_dim, batch_first=True)
        self.mu = nn.Linear(2*hidden_dim, latent_dim)
        self.log_var = nn.Linear(2*hidden_dim, latent_dim)

    def forward(self, x_for, x_rev, t):
        """Encode a batch into ``(mu, log_var)``.

        Args:
            x_for: ``(batch, max_length)`` token indices, padded on the RIGHT.
            x_rev: ``(batch, max_length)`` reversed token indices, padded on
                the LEFT (the layout produced by ``BidirectionalSmilesCollate``).
            t: ``(batch,)`` unpadded sequence lengths.

        Returns:
            Tuple ``(mu, log_var)``, each of shape ``(batch, latent_dim)``.
        """
        embeddings_for = self.embedding(x_for)
        embeddings_rev = self.embedding(x_rev)
        outputs_for, _ = self.forward_lstm(embeddings_for)
        outputs_rev, _ = self.reverse_lstm(embeddings_rev)

        # Forward direction: padding sits on the right, so the last real token
        # of sample b is at position t[b] - 1.
        rows = torch.arange(x_for.shape[0], device=x_for.device)
        features_for = outputs_for[rows, t - 1]
        # Reverse direction: padding sits on the LEFT, so only the final
        # timestep has consumed the whole sequence. Reading position t - 1
        # here (as was done before) returns, for short sequences, a state that
        # has seen nothing but padding.
        # NOTE(review): a checkpoint trained with the previous behaviour will
        # encode differently after this fix; retrain before relying on the encoder.
        features_rev = outputs_rev[:, -1, :]

        features = torch.cat((features_for, features_rev), dim=1)
        mu = self.mu(features)
        log_var = self.log_var(features)
        return mu, log_var

In [18]:
class DecoderLSTM(nn.Module):
    """Single-step LSTM decoder conditioned on a latent code.

    Args:
        embedding: Module mapping token indices to vectors; must expose
            ``embedding_dim`` and ``alphabet_size``.
        hidden_dim: Hidden size of the LSTM.
        latent_dim: Size of the latent code concatenated to every input step.
    """

    def __init__(self, embedding, hidden_dim, latent_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = embedding
        self.alphabet_size = embedding.alphabet_size
        self.lstm = nn.LSTM(embedding.embedding_dim + latent_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim, self.alphabet_size)

    def forward(self, latent_code, last_token, hidden, cell):
        """Advance the decoder by one token.

        Args:
            latent_code: ``(batch, latent_dim)`` latent vectors.
            last_token: ``(batch,)`` indices of the previously emitted tokens.
            hidden: ``(1, batch, hidden_dim)`` LSTM hidden state.
            cell: ``(1, batch, hidden_dim)`` LSTM cell state.

        Returns:
            Tuple ``(logits, hidden, cell)`` where ``logits`` has shape
            ``(batch, alphabet_size)``.
        """
        token_embedding = self.embedding(last_token)
        # The latent code is fed at every step next to the token embedding;
        # unsqueeze adds the length-1 time axis expected with batch_first=True.
        step_input = torch.cat((token_embedding, latent_code), dim=1).unsqueeze(dim=1)
        output, (hidden, cell) = self.lstm(step_input, (hidden, cell))
        logits = self.linear(output).squeeze(dim=1)
        return logits, hidden, cell

    def init_hidden(self, N, sos_index=None):
        """Create the initial decoder inputs for a batch of ``N`` sequences.

        Args:
            N: Batch size.
            sos_index: Index of the start token. ``None`` (default) falls back
                to the module-level ``SOS_IND``.

        Returns:
            Tuple ``(last_token, last_hidden, last_cell)``: a ``(N,)`` long
            tensor filled with the start index and two zero tensors of shape
            ``(1, N, hidden_dim)``. All three are created on the device of this
            module's parameters, so they can be fed to ``forward`` directly.
        """
        if sos_index is None:
            sos_index = SOS_IND
        reference = self.linear.weight
        last_token = torch.full((N,), sos_index, dtype=torch.long, device=reference.device)
        last_hidden = torch.zeros((1, N, self.hidden_dim), dtype=reference.dtype, device=reference.device)
        last_cell = torch.zeros((1, N, self.hidden_dim), dtype=reference.dtype, device=reference.device)
        return last_token, last_hidden, last_cell

In [19]:
class LSTMVAE(nn.Module):
    """Variational autoencoder with an LSTM encoder and an LSTM decoder.

    Encoder and decoder share one ``Embedding`` instance.

    Args:
        alphabet_size: Number of distinct token indices.
        embedding_dim: Size of the token embeddings.
        hidden_dim: Hidden size of the encoder and decoder LSTMs.
        latent_dim: Size of the latent code.
        beta: Weight of the KL term in ``neg_elbo``.
    """

    def __init__(self, alphabet_size, embedding_dim, hidden_dim, latent_dim, beta):
        super().__init__()
        self.alphabet_size = alphabet_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.beta = beta
        self.embedding = Embedding(alphabet_size, embedding_dim)
        self.encoder = EncoderLSTM(self.embedding, hidden_dim, latent_dim)
        self.decoder = DecoderLSTM(self.embedding, hidden_dim, latent_dim)

    def forward(self, x):
        """Not implemented: returns ``None``. Use ``encode``/``decode``/``neg_elbo``."""
        pass

    def encode(self, x_for, x_rev, t):
        """Encode a batch and draw one latent sample per sequence.

        Returns:
            Tuple ``(z, mu, log_var)``.
        """
        mu, log_var = self.encoder(x_for, x_rev, t)
        std = torch.exp(log_var / 2)
        z = self.draw_z(mu, std)
        return z, mu, log_var

    @torch.no_grad()
    def decode(self, z, max_length, random=False):
        """Generate ``max_length`` tokens for every latent code in ``z``.

        Args:
            z: ``(batch, latent_dim)`` latent codes; moved to the model's device.
            max_length: Number of decoding steps (generation does not stop at EOS).
            random: If true, sample each token from the predicted distribution;
                otherwise take the arg-max.

        Returns:
            ``(batch, max_length)`` tensor of token indices stored in the
            default floating dtype (callers cast with ``.long()``), located on
            the model's device.
        """
        device = self.device
        z = z.to(device)
        batch_size = z.shape[0]
        results = torch.zeros((batch_size, max_length), device=device)
        # Move the initial state next to the parameters so decoding also works
        # when the model does not live on the CPU.
        last_token, last_hidden, last_cell = (
            state.to(device) for state in self.decoder.init_hidden(batch_size)
        )
        for i in range(max_length):
            logits, last_hidden, last_cell = self.decoder(z, last_token, last_hidden, last_cell)
            if random:
                last_token = Categorical(logits=logits).sample()
            else:
                last_token = torch.argmax(logits, dim=1)
            results[:, i] = last_token
        return results

    def draw_z(self, mu, std):
        """Draw a reparameterised sample from ``Normal(mu, std)``."""
        q = torch.distributions.Normal(mu, std)
        return q.rsample()

    def kl_divergence(self, mu, log_var):
        """Mean over the batch of ``KL(N(mu, exp(log_var)) || N(0, I))``.

        The divergence is summed over the latent dimension only and then
        averaged over the batch. Summing over every element (as was done
        before) made the trailing ``.mean()`` a no-op and let the KL term grow
        with the batch size, while the reconstruction term is a batch mean.
        """
        kl_per_sample = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=-1)
        return kl_per_sample.mean()

    def neg_elbo(self, x_for, x_rev, t, metric):
        """Negative ELBO of a batch, using teacher forcing.

        Args:
            x_for: ``(batch, max_length)`` right-padded token indices.
            x_rev: ``(batch, max_length)`` reversed, left-padded token indices.
            t: ``(batch,)`` unpadded lengths.
            metric: Callable ``(logits, targets) -> (batch,)`` per-sample loss,
                i.e. it must not reduce over the batch.

        Returns:
            Scalar tensor: the length-normalised reconstruction loss averaged
            over the batch, plus ``beta`` times the KL divergence.
        """
        z, mu, log_var = self.encode(x_for, x_rev, t)
        device = self.device
        last_token, last_hidden, last_cell = (
            state.to(device) for state in self.decoder.init_hidden(x_for.shape[0])
        )
        # reconstruction loss, accumulated per sample
        neg_elbo = 0
        for i in range(int(t.max())):
            logits, last_hidden, last_cell = self.decoder(z, last_token, last_hidden, last_cell)
            targets = x_for[:, i]
            # Positions at or beyond a sample's length are padding: mask them out.
            neg_elbo = neg_elbo + metric(logits, targets) * (i < t).long()
            # Teacher forcing: the next input is the true token, not the prediction.
            last_token = targets
        neg_elbo = (neg_elbo / t).mean()
        # kl-divergence
        return neg_elbo + self.beta * self.kl_divergence(mu, log_var)

    @property
    def device(self):
        """Device of this model's first parameter."""
        # Must read self, not a notebook-level variable: a model created by
        # load_model() is a different object from whatever `lstm_vae` names.
        return next(self.parameters()).device

In [20]:
@torch.enable_grad()
def train(model, dataloader, metric, optimizer, max_batches=100000, verbose=1000):
    """Run one training pass over ``dataloader``.

    Args:
        model: Model exposing ``neg_elbo(x_for, x_rev, t, metric)`` and ``device``.
        dataloader: Iterable of ``(x_for, x_rev, t)`` batches.
        metric: Per-sample loss handed through to ``model.neg_elbo``.
        optimizer: Optimizer stepping the model's parameters.
        max_batches: Stop after this many batches.
        verbose: Print the mean of the last ``verbose`` losses every
            ``verbose`` batches.

    Returns:
        List with the loss value of every processed batch.
    """
    model.train()
    target_device = model.device
    losses = []
    for step, (x_for, x_rev, t) in enumerate(tqdm(dataloader)):
        x_for = x_for.to(target_device)
        x_rev = x_rev.to(target_device)
        t = t.to(target_device)
        loss = model.neg_elbo(x_for, x_rev, t, metric)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if step % verbose == 0 and step > 0:
            recent_mean = np.mean(losses[-verbose:])
            print(f'Training Loss ... {recent_mean}')
        if step + 1 == max_batches:
            break
    return losses

@torch.no_grad()
def validate(model, dataloader, metric, max_batches=100000):
    """Evaluate the model on ``dataloader`` without updating it.

    Puts the model into eval mode (and leaves it there).

    Args:
        model: Model exposing ``neg_elbo(x_for, x_rev, t, metric)`` and ``device``.
        dataloader: Iterable of ``(x_for, x_rev, t)`` batches.
        metric: Per-sample loss handed through to ``model.neg_elbo``.
        max_batches: Stop after this many batches.

    Returns:
        List with the loss value of every processed batch.
    """
    model.eval()
    target_device = model.device
    losses = []
    for step, (x_for, x_rev, t) in enumerate(tqdm(dataloader)):
        x_for = x_for.to(target_device)
        x_rev = x_rev.to(target_device)
        t = t.to(target_device)
        batch_loss = model.neg_elbo(x_for, x_rev, t, metric)
        losses.append(batch_loss.item())
        if step + 1 == max_batches:
            break
    return losses

## Training

In [21]:
# optimization hyperparameters
# number of passes of the training loop
epochs = 10
batch_size = 256
# whether the training DataLoader shuffles
shuffle = True
learning_rate = 0.001
# StepLR: multiply the learning rate by `gamma` every `step_size` scheduler steps
step_size = 5
gamma = 0.1
# cap on the number of batches processed by one train() call
max_batches = 5000
# train() prints a running mean of the loss every `verbose` batches
verbose = 1000

# model hyperparameters
embedding_dim = 8
hidden_dim = 128
latent_dim = 64
# weight of the KL term in the negative ELBO
beta = 1

In [22]:
# Seed torch's global generator before anything random happens, so that the
# weight initialisation and the DataLoader shuffling order are reproducible.
torch.manual_seed(SEED)

alphabet = get_alphabet(smiles, add_eos=True, add_sos=True)
alphabet_size = len(alphabet)

# Character <-> index mapping and the indices of the two marker tokens, which
# the collate function and the decoder read as globals.
encoder = Encoder(alphabet)
SOS_IND = encoder(SOS)[0]
EOS_IND = encoder(EOS)[0]

train_dataset = SmilesDataset(smiles_train, encoder)
val_dataset = SmilesDataset(smiles_val, encoder)
# Keyword arguments: passing `shuffle` positionally relies on the exact
# parameter order of DataLoader and is easy to misread.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    collate_fn=BidirectionalSmilesCollate(),
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=BidirectionalSmilesCollate(),
)

lstm_vae = LSTMVAE(alphabet_size, embedding_dim, hidden_dim, latent_dim, beta)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lstm_vae.to(device)
optimizer = Adam(lstm_vae.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
# reduction='none' keeps one loss value per sample; neg_elbo masks and
# normalises them itself.
metric = nn.CrossEntropyLoss(reduction='none')

In [23]:
# Loss histories. The validation history starts with the per-batch losses of
# the freshly initialised (untrained) model as a baseline.
train_errors = []
val_errors = validate(lstm_vae, val_dataloader, metric)

print(f'Initial Validation Loss ... {np.mean(val_errors)}')
plt.plot(range(len(val_errors)), val_errors)
plt.title('Validation Errors')
plt.show()

In [24]:
# Training only runs when requested; otherwise a saved checkpoint is loaded
# in a later cell.
if TRAIN_FROM_SCRATCH:
    
    for e in range(epochs):
        # one (possibly truncated, see max_batches) pass over the training data
        terrors = train(lstm_vae, train_dataloader, metric, optimizer, max_batches, verbose)
        train_errors.extend(terrors)
        print(f'Epoch {e} ... Train Loss {np.mean(terrors)}')

        verrors = validate(lstm_vae, val_dataloader, metric)
        val_errors.extend(verrors)
        print(f'Epoch {e} ... Validation Loss {np.mean(verrors)}')

        # advance the learning-rate schedule once per epoch
        scheduler.step()

        # per-batch loss curves accumulated over all epochs so far
        plt.figure(1)
        plt.subplot(121)
        plt.plot(range(len(train_errors)), train_errors)
        plt.title('Training Errors')
        plt.subplot(122)
        plt.plot(range(len(val_errors)), val_errors)
        plt.title('Validation Errors')
        plt.show()

In [25]:
def save_model(model, model_path):
    """Write the model's ``state_dict`` (parameters only) to ``model_path``."""
    weights = model.state_dict()
    torch.save(weights, model_path)

def load_model(model_path):
    """Rebuild an ``LSTMVAE`` and load its parameters from ``model_path``.

    The architecture is taken from the notebook-level hyperparameters
    (``alphabet_size``, ``embedding_dim``, ``hidden_dim``, ``latent_dim``,
    ``beta``), so they must match the ones the checkpoint was trained with.

    Args:
        model_path: Path of a file written by ``save_model``.

    Returns:
        The model in eval mode, on the CPU.
    """
    model = LSTMVAE(alphabet_size, embedding_dim, hidden_dim, latent_dim, beta)
    # map_location='cpu': a checkpoint saved from a GPU model otherwise tries
    # to restore its tensors onto CUDA and fails on a CPU-only machine. The
    # freshly built model lives on the CPU anyway.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [26]:
# Either persist the model that was just trained, or replace `lstm_vae` by the
# stored checkpoint and report its validation loss.
if TRAIN_FROM_SCRATCH:
    save_model(lstm_vae, MODEL_FILE_OUT)
else:
    lstm_vae = load_model(MODEL_FILE_IN)
    val_errors = validate(lstm_vae, val_dataloader, metric)
    print(f'Initial Validation Loss ... {np.mean(val_errors)}')

## Evaluation

In [27]:
# Reconstruction check on one validation batch, done on the CPU.
lstm_vae.to(torch.device('cpu'))

with torch.no_grad():
    x_for, x_rev, t = next(iter(val_dataloader))
    z, mu, log_var = lstm_vae.encode(x_for, x_rev, t)
    # random=True: tokens are sampled, so the output differs from run to run
    reconstructions = lstm_vae.decode(z, max_length=100, random=True).long()

# For the first few samples print: the padded input, its reconstruction (both
# still containing the SOS/EOS markers), the posterior mean and the posterior
# variance.
# NOTE(review): the bound uses batch_size, not the actual size of this batch;
# a validation set with fewer than 8 samples would raise an IndexError.
for i in range(min(8, batch_size)):
    print(encoder(x_for[i, :].tolist()))
    print(encoder(reconstructions[i, :].tolist()))
    print(mu[i, :])
    print(torch.exp(log_var)[i, :])
    print(100*'-')

## Generation

In [28]:
def clean_smiles(smiles, sos_token=None, eos_token=None):
    """Strip the sequence markers from a decoded string.

    Every occurrence of the start marker is removed, then the string is cut
    at the first end marker (the marker itself and everything after it are
    dropped). A string without an end marker is returned in full.

    Args:
        smiles: Decoded string, possibly containing marker characters.
        sos_token: Start marker. ``None`` (default) uses the module-level
            ``SOS`` instead of a hard-coded character, so the function stays
            in sync with the configuration.
        eos_token: End marker. ``None`` (default) uses the module-level ``EOS``.

    Returns:
        The cleaned string.
    """
    if sos_token is None:
        sos_token = SOS
    if eos_token is None:
        eos_token = EOS
    without_sos = smiles.replace(sos_token, '')
    # partition returns the whole string as head when the marker is absent.
    head, _, _ = without_sos.partition(eos_token)
    return head

In [29]:
def generate_smiles(model, n_samples, max_length, random=True):
    """Generate strings by decoding latent codes drawn from a standard normal.

    Uses the notebook-level ``encoder`` to turn indices back into characters
    and ``clean_smiles`` to strip the markers.

    Args:
        model: Model exposing ``latent_dim``, ``parameters()`` and
            ``decode(z, max_length, random)``.
        n_samples: Number of strings to generate.
        max_length: Number of decoding steps per string.
        random: Passed to ``model.decode``; sample tokens if true, arg-max
            otherwise.

    Returns:
        List of ``n_samples`` cleaned strings.
    """
    # Take the latent size and the device from the model itself rather than
    # from notebook globals, so any compatible model can be passed in.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        Z = torch.randn(n_samples, model.latent_dim, device=model_device)
        results = model.decode(Z, max_length, random)
    rows = results.long().cpu().tolist()
    return [clean_smiles(encoder(indices)) for indices in rows]

In [30]:
def write_output(smiles, out_file):
    """Write the strings to ``out_file``, one per line (overwrites the file)."""
    payload = ''.join(entry + '\n' for entry in smiles)
    with open(out_file, 'w') as handle:
        handle.write(payload)

In [31]:
# Generation settings: `times` batches of `n_samples` strings each.
n_samples = 100
times = 100
# number of decoding steps per generated string
max_length = 110
# sample tokens instead of taking the arg-max
# NOTE(review): this global shares its name with the stdlib `random` module.
random = True

In [32]:
# Generate in batches of n_samples and collect everything in one list.
generated_smiles = []
for _ in trange(times):
    new_smiles = generate_smiles(lstm_vae, n_samples, max_length, random)
    generated_smiles.extend(new_smiles)

In [33]:
# Preview the first ten generated strings.
generated_smiles[:10]

In [34]:
# Write all generated strings to the submission file, one per line.
write_output(generated_smiles, SUBMISSION_FILE)