In [1]:
import sys
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.distributions.categorical import Categorical
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

torch.manual_seed(1)

import numpy as np
import pickle as pkl
np.set_printoptions(threshold=sys.maxsize)

from tqdm.notebook import tqdm
import networkx as nx

from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem

%matplotlib notebook
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

import plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.figure_factory as ff

from IPython.display import SVG

from collections import defaultdict
import multiprocessing as mp
from statistics import stdev, mean
from pathlib import Path
import pickle as pkl
from sklearn.preprocessing import MinMaxScaler



In [2]:
NUM_CPU = 5

In [3]:
def save(thing, path):
    with open(path, 'wb') as fp:
        pkl.dump(thing, fp)
    print(f'saved [{type(thing)}] to {path}')

# Data

In [4]:
!ls data/lipinski

active_lipinski.smi             druglike_lipinski_1k.smi
all_lipinski.smi                druglike_lipinski_50k.test.smi
[1m[36mby_num_atoms[m[m                    druglike_lipinski_50k.train.smi
druglike_lipinski.smi           gcpn_smiles.500.smi
druglike_lipinski_100k.smi      lstm_smiles.500.smi
druglike_lipinski_10k.smi       very_active_lipinski.smi


In [5]:
data_file = Path('./data/lipinski/druglike_lipinski_100k.smi')

In [6]:
GO_TOKEN = 'G'
END_TOKEN = '\n'
PAD_TOKEN = 'A'

In [7]:
ALL_SMILES = []
with data_file.open() as fp:
    for line in fp.readlines():
        ALL_SMILES += [GO_TOKEN + line]
NUM_TOTAL = len(ALL_SMILES)

In [8]:
TEST_SPLIT = 0.15
split = int(NUM_TOTAL*TEST_SPLIT)
TRAIN_SMILES = ALL_SMILES[split:]
TEST_SMILES = ALL_SMILES[:split]

In [9]:
len(TRAIN_SMILES), len(TEST_SMILES)

(85000, 15000)

In [10]:
alphabet = list(set(''.join(ALL_SMILES+[PAD_TOKEN])))
NUM_SYM = len(alphabet)

In [11]:
SYM_TO_ID = {s:i for i,s in enumerate(alphabet)}
ID_TO_SYM = {i:s for s,i in SYM_TO_ID.items()}

In [12]:
SYM_TO_ID

{'l': 0,
 'B': 1,
 '\\': 2,
 '6': 3,
 'P': 4,
 'O': 5,
 '[': 6,
 'p': 7,
 '7': 8,
 '8': 9,
 '(': 10,
 '=': 11,
 '3': 12,
 '#': 13,
 '+': 14,
 'n': 15,
 's': 16,
 'o': 17,
 'C': 18,
 '/': 19,
 '@': 20,
 '5': 21,
 'A': 22,
 '1': 23,
 'H': 24,
 'r': 25,
 '-': 26,
 '\n': 27,
 'I': 28,
 ']': 29,
 'S': 30,
 ')': 31,
 'F': 32,
 '4': 33,
 'c': 34,
 'N': 35,
 'G': 36,
 '2': 37}

In [13]:
# MAX_ATOMS = max(map(lambda s: Chem.MolFromSmiles(s).GetNumAtoms(), D.keys()))
MAX_SYM = max(map(len, ALL_SMILES)) # GO_TOKEN and END_TOKEN already considered 

In [14]:
MAX_SYM, NUM_TOTAL

(103, 100000)

In [15]:
def encode(smiles):
    """
    encode(simles): 
        - takes in a variable length smiles string (up to MAX_SYM) 
          and outputs a fixed size vector (MAX_SYM by NUM_SYM)
        
    """
    x = np.zeros((MAX_SYM, NUM_SYM))
    x_n = len(smiles)
    for i, sym in enumerate(smiles):
        x[i, SYM_TO_ID[sym]] = 1
    x[x_n:, SYM_TO_ID[PAD_TOKEN]] = 1
    return x, x_n

In [16]:
def decode(x):
    assert x.shape[1] == NUM_SYM
    smiles = ''
    for i in range(x.shape[0]):
        topi = np.argmax(x[i, :])
        smiles += ID_TO_SYM[topi]
    return smiles
def decode_short(x):
    s = decode(x)
    return  s[:s.find(END_TOKEN)]

In [17]:
decode(encode(ALL_SMILES[0])[0])

'GCc1onc(NC(=O)c2ccc(Cl)cc2Cl)c1Br\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA'

In [18]:
decode_short(encode(ALL_SMILES[0])[0])

'GCc1onc(NC(=O)c2ccc(Cl)cc2Cl)c1Br'

In [19]:
class SMILESDataset(torch.utils.data.Dataset):
    
    def __init__(self, smiles):
        self.all_smiles = smiles
        self.size = len(smiles)

    def __len__(self):
        return self.size
    
    def __getitem__(self, i):
        """
        __getitem__(self, i):
            - return 
        """
        x_i, x_n = encode(self.all_smiles[i])
        
        y_i = x_i[1:].copy()
        x_i = x_i[:-1].copy()
        
        return x_i, y_i, x_n
    

In [20]:
train_dataset = SMILESDataset(TRAIN_SMILES)
test_dataset = SMILESDataset(TEST_SMILES)

In [21]:
xi, yi, xlen = train_dataset[4]

In [22]:
decode(xi), decode(yi)

('GCC(C)Oc1ccc(cc1)c2cc(NCC(O)CO)c3ccccc3n2\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA',
 'CC(C)Oc1ccc(cc1)c2cc(NCC(O)CO)c3ccccc3n2\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA')

In [23]:
xi, yi, xlen = test_dataset[4]

In [24]:
decode(xi), decode(yi)

('GFC(F)(F)Oc1cccc(NC(=O)c2oc(Br)cc2)c1\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA',
 'FC(F)(F)Oc1cccc(NC(=O)c2oc(Br)cc2)c1\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA')

In [25]:
BATCH_SIZE = 52

In [26]:
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_CPU, shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=len(TEST_SMILES), num_workers=NUM_CPU, shuffle=True)

# Model

In [27]:
# plain LSTM model
class GeneratorLSTM(nn.Module):
    def __init__(self, input_size, output_size, num_layers, hidden_size, embedding_size):
        super().__init__()
        
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        
        self.lstm = nn.LSTM(input_size=embedding_size, 
                            hidden_size=hidden_size, 
                            num_layers=num_layers, 
                            batch_first=True,
                            dropout=0.15)
        
        self.input_module = nn.Sequential(nn.Linear(input_size, 256),
                                         nn.ReLU(),
                                         nn.Dropout(0.10),
                                         nn.Linear(256, 512),
                                         nn.ReLU(),
                                          nn.Dropout(0.10),
                                         nn.Linear(512, embedding_size),
                                         nn.ReLU())
        
        self.output_module = nn.Sequential(nn.Linear(hidden_size, 256),
                                           nn.Dropout(0.10),
                                           nn.ReLU(),
                                           nn.Linear(256, output_size))

        self.hidden = None 

    def init_hidden(self, batch_size):
        return (Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)),
                Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size)))

    def forward(self, input_raw, pack=False, input_lens=None):
        """
        forward(self, input_raw, state):
             - input_raw = (bs, x_{i,t})
             - state = (ht, ct)
        """
        input_ = self.input_module(input_raw)
        
        if pack:
            packed_input = pack_padded_sequence(input_, input_lens, batch_first=True)
            input_, self.hidden = self.lstm(packed_input, self.hidden)
            input_ = pad_packed_sequence(input_, batch_first=True)[0]
            
        input_ = self.output_module(input_)
        return input_

In [28]:
model = GeneratorLSTM(input_size = NUM_SYM, 
                      output_size = NUM_SYM, 
                      num_layers = 3, 
                      hidden_size = 512, 
                      embedding_size = 512)

In [29]:
model_optim = optim.Adam(list(model.parameters()), lr=3e-3)

In [30]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
        
model.apply(init_weights)

GeneratorLSTM(
  (lstm): LSTM(512, 512, num_layers=3, batch_first=True, dropout=0.15)
  (input_module): Sequential(
    (0): Linear(in_features=38, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU()
  )
  (output_module): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): Dropout(p=0.1, inplace=False)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=38, bias=True)
  )
)

In [31]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'model has [{count_parameters(model):,}] trainable params')

model has [6,849,062] trainable params


## Train

In [32]:
def train_one_epoch(model, model_optim, dataloader, 
                    test_interval):
    
    criterion = nn.NLLLoss()
    loss_history = []
    total = len(dataloader)
    
    test_loss_history = defaultdict(list)
    new_smiles = defaultdict(list)
    
    
    for batch_idx, batch in tqdm(enumerate(dataloader), total=total):
        
        model.train()
        model.zero_grad()
    
        x_batch, y_batch, batch_lens = batch
        
        batch_size = x_batch.size(0)
        max_len = int(max(batch_lens).item())
        
        x_batch = x_batch[:, 0:max_len, :]
        y_batch = y_batch[:, 0:max_len, :]

        # sort input
        batch_len_sorted, sort_index = torch.sort(batch_lens, 0, descending=True)
        batch_len_sorted = batch_len_sorted.numpy().tolist()
        
        x_batch = torch.index_select(x_batch, 0, sort_index)
        y_batch = torch.index_select(y_batch, 0, sort_index)

        x_batch = Variable(x_batch.float())
        y_batch = Variable(y_batch.float())

        # init state
        model.hidden = model.init_hidden(batch_size=x_batch.size(0))
        
        y_pred = model(x_batch, pack=True, input_lens=batch_len_sorted)
        
        y_pred = F.log_softmax(y_pred.view(-1, NUM_SYM), dim=-1)
        _, y_batch = y_batch.topk(1, dim=-1)
        y_batch = y_batch.view(-1)
        loss = criterion(y_pred, y_batch)
        
        loss.backward()
        model_optim.step()
        
        if batch_idx % test_interval == 0:
            new_smiles[f'[batch num {batch_idx}] topk'] = generate(model)
            new_smiles[f'[batch num {batch_idx}] softmax sample'] = generate(model, sample_f=softmax_temp_sample)

            test_loss_history[batch_idx] = test(model, testloader)
            print(f'[batch num: {batch_idx}] sampled')
            
            path = Path(f'./data/lipinski/results/lstm-1-epoch-{batch_idx}-batch-{loss.data.item():.4f}-loss')
            path.mkdir(exist_ok=True)
            
            torch.save(model.state_dict(), str(path / 'model_dict.torch'))
            save(new_smiles, str(path / 'generated_smiles.dict' ))
            save(test_loss_history, str(path / 'test_loss.dict' ))

        
        print(f'[batch num: {batch_idx}] loss: {loss.data.item():.4f}')
                
    return loss_history, test_loss_history, new_smiles

## Testing and Generation

In [33]:
def softmax_temp_sample(y_t, temperature = 1.0):
    prediction_vector = F.softmax(y_t / temperature, dim=-1)
    x_index_t = torch.multinomial(prediction_vector, 1)[:, 0]
    return x_index_t

In [34]:
def topk_sample(y_t):
    _, pred_idx =  y_t.topk(1, dim=-1)
    return pred_idx

In [35]:
def generate(model, test_samples=5, sample_f=topk_sample, v=True):
    model.eval()
    
    x = torch.zeros(test_samples, MAX_SYM, NUM_SYM)
    
    x[:, 0, SYM_TO_ID[GO_TOKEN]] = 1
    
    for i in range(MAX_SYM-1):
        pred = model(x)
        pred_idx = softmax_temp_sample(pred[:,i,:])
        
        temp = torch.zeros(test_samples, MAX_SYM, NUM_SYM)
        temp[:, i+1, pred_idx] = 1  
        x.add_(temp)
    
    if v: print('\n',10*'-' + 'GENERATED SMILES STRINGS' + 10*'-')
    smiles = []
    
    for j in range(test_samples):
        s = decode_short(x[j].numpy())
        smiles += [s]
        
        if v: print(s)
            
    return smiles
        
    

In [36]:
def test(model, testloader):
    model.eval()
    
    criterion = nn.NLLLoss()
    total_loss = 0
    
    for batch in testloader:
        
        x_batch, y_batch, batch_lens = batch
        
        batch_size = x_batch.size(0)
        max_len = int(max(batch_lens).item())
        
        x_batch = x_batch[:, 0:max_len, :]
        y_batch = y_batch[:, 0:max_len, :]

        # sort input
        batch_len_sorted, sort_index = torch.sort(batch_lens, 0, descending=True)
        batch_len_sorted = batch_len_sorted.numpy().tolist()
        
        x_batch = torch.index_select(x_batch, 0, sort_index)
        y_batch = torch.index_select(y_batch, 0, sort_index)

        x_batch = Variable(x_batch.float())
        y_batch = Variable(y_batch.float())

        # init state
        model.hidden = model.init_hidden(batch_size=x_batch.size(0))
        
        y_pred = model(x_batch, pack=True, input_lens=batch_len_sorted)
        
        y_pred = y_pred.view(-1, NUM_SYM)
        _, y_batch = y_batch.topk(1, dim=-1)
        y_batch = y_batch.view(-1)
        
        total_loss += criterion(y_pred, y_batch)
                
    return total_loss / len(test_dataset)
    

## Results

In [None]:
train_history, train_history, new_smiles = train_one_epoch(model, model_optim, dataloader, test_interval=200)

HBox(children=(IntProgress(value=0, max=1635), HTML(value='')))


 ----------GENERATED SMILES STRINGS----------
G667\pp/O38#s\p\OBOsB\[sO\pr#B\\p\=n\P6O=7OOOP8n[P\6BBln(O6(l8l7O\p#ppn[\Ol\\Bln[OOlBBOp[PlB6B7O8B(BBO
G667\pp/O38#s\p\OBOsB\[sO\pr#B\\p\=n\P6O=7OOOP8n[P\6BBln(O6(l8l7O\p#ppn[\Ol\\Bln[OOlBBOp[PlB6B7O8B(BBO
G667\pp/O38#s\p\OBOsB\[sO\pr#B\\p\=n\P6O=7OOOP8n[P\6BBln(O6(l8l7O\p#ppn[\Ol\\Bln[OOlBBOp[PlB6B7O8B(BBO
G667\pp/O38#s\p\OBOsB\[sO\pr#B\\p\=n\P6O=7OOOP8n[P\6BBln(O6(l8l7O\p#ppn[\Ol\\Bln[OOlBBOp[PlB6B7O8B(BBO
G667\pp/O38#s\p\OBOsB\[sO\pr#B\\p\=n\P6O=7OOOP8n[P\6BBln(O6(l8l7O\p#ppn[\Ol\\Bln[OOlBBOp[PlB6B7O8B(BBO

 ----------GENERATED SMILES STRINGS----------
G=OB((68lBP=n@l[B8=BPO\BOC6[l7P=OOBP\l\PBlBl(@n\l35s6BpB3Pl6lP6llP3=\B[8\67\\p88P7\Bl6Pl666BB=B6lBP\P/
G=OB((68lBP=n@l[B8=BPO\BOC6[l7P=OOBP\l\PBlBl(@n\l35s6BpB3Pl6lP6llP3=\B[8\67\\p88P7\Bl6Pl666BB=B6lBP\P/
G=OB((68lBP=n@l[B8=BPO\BOC6[l7P=OOBP\l\PBlBl(@n\l35s6BpB3Pl6lP6llP3=\B[8\67\\p88P7\Bl6Pl666BB=B6lBP\P/
G=OB((68lBP=n@l[B8=BPO\BOC6[l7P=OOBP\l\PBlBl(@n\l35s6BpB3Pl6lP6llP3=\B[8\67\\p88P7

In [None]:
def plot_history(history, l=-1):
    if l < 0:
        l = len(history)
    fig = px.line(x=np.arange(l), y=history[:l], labels={'x':'batch number', 'y':'binary cross-entropy loss'})
    return fig

In [None]:
plot_history(history)