In [1]:
import sys
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.distributions.categorical import Categorical
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR

import numpy as np
import pickle as pkl

from tqdm.notebook import tqdm
import networkx as nx

from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import QED

%matplotlib notebook
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

import plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.figure_factory as ff

from IPython.display import SVG

from collections import defaultdict
import multiprocessing as mp
from statistics import stdev, mean
from pathlib import Path
import pickle as pkl
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

In [2]:
NUM_CPU = 3
SEED = 3287450

In [3]:
torch.manual_seed(SEED)
torch.cuda.empty_cache()
sns.set(style='white', context='notebook', rc={'figure.figsize':(14,10)})
np.set_printoptions(threshold=sys.maxsize)

In [4]:
def save(thing, path):
    with open(path, 'wb') as fp:
        pkl.dump(thing, fp)
    print(f'saved [{type(thing)}] to {path}')

# Data

In [5]:
!ls ./data

lipinski


In [6]:
def load_smiles(path, test_split=0.15):
    all_smiles = []
    with open(path) as fp:
        for line in fp.readlines():
            all_smiles += [GO_TOKEN + line]
    return 
        

In [7]:
data_file = Path('./data/lipinski/druglike_lipinski_100k.smi')

In [8]:
GO_TOKEN = 'G'
END_TOKEN = '\n'
PAD_TOKEN = 'A'

In [9]:
ALL_SMILES = []
with data_file.open() as fp:
    for line in fp.readlines():
        ALL_SMILES += [GO_TOKEN + line]
NUM_TOTAL = len(ALL_SMILES)

In [10]:
TEST_SPLIT = 0.15
split = int(NUM_TOTAL*TEST_SPLIT)
TRAIN_SMILES = ALL_SMILES[split:]
TEST_SMILES = ALL_SMILES[:split]

In [11]:
len(TRAIN_SMILES), len(TEST_SMILES)

(85000, 15000)

In [12]:
alphabet = list(set(''.join(ALL_SMILES+[PAD_TOKEN])))
NUM_SYM = len(alphabet)

In [13]:
SYM_TO_ID = {s:i for i,s in enumerate(alphabet)}
ID_TO_SYM = {i:s for s,i in SYM_TO_ID.items()}

In [14]:
SYM_TO_ID

{')': 0,
 'C': 1,
 ']': 2,
 'A': 3,
 'F': 4,
 'l': 5,
 '#': 6,
 'O': 7,
 'o': 8,
 '1': 9,
 '3': 10,
 '-': 11,
 'N': 12,
 '4': 13,
 's': 14,
 '/': 15,
 'S': 16,
 'P': 17,
 '(': 18,
 'c': 19,
 '\n': 20,
 'I': 21,
 '+': 22,
 '6': 23,
 '7': 24,
 'B': 25,
 'p': 26,
 'r': 27,
 '2': 28,
 '=': 29,
 '8': 30,
 '@': 31,
 'H': 32,
 'n': 33,
 'G': 34,
 '[': 35,
 '\\': 36,
 '5': 37}

In [15]:
# MAX_ATOMS = max(map(lambda s: Chem.MolFromSmiles(s).GetNumAtoms(), D.keys()))
MAX_SYM = max(map(len, ALL_SMILES)) # GO_TOKEN and END_TOKEN already considered 

In [16]:
MAX_SYM, NUM_TOTAL

(103, 100000)

In [17]:
def encode(smiles):
    """
    encode(simles): 
        - takes in a variable length smiles string (up to MAX_SYM) 
          and outputs a fixed size vector (MAX_SYM by NUM_SYM)
        
    """
    x = np.zeros((MAX_SYM, NUM_SYM))
    x_n = len(smiles)
    for i, sym in enumerate(smiles):
        x[i, SYM_TO_ID[sym]] = 1
    x[x_n:, SYM_TO_ID[PAD_TOKEN]] = 1
    return x, x_n

In [18]:
def decode(x):
    assert x.shape[1] == NUM_SYM
    smiles = ''
    for i in range(x.shape[0]):
        topi = np.argmax(x[i, :])
        smiles += ID_TO_SYM[topi]
    return smiles if smiles[0] != 'G' else smiles[1:]

def decode_short(x):
    s = decode(x)
    return  s[:s.find(END_TOKEN)]

In [19]:
decode(encode(ALL_SMILES[0])[0])

'Cc1onc(NC(=O)c2ccc(Cl)cc2Cl)c1Br\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA'

In [20]:
decode_short(encode(ALL_SMILES[0])[0])

'Cc1onc(NC(=O)c2ccc(Cl)cc2Cl)c1Br'

In [21]:
class SMILESDataset(torch.utils.data.Dataset):
    
    def __init__(self, smiles):
        self.all_smiles = smiles
        self.size = len(smiles)

    def __len__(self):
        return self.size
    
    def __getitem__(self, i):
        """
        __getitem__(self, i):
            - return 
        """
        x_i, x_n = encode(self.all_smiles[i])
        
        y_i = x_i[1:].copy()
        x_i = x_i[:-1].copy()
        
        return x_i, y_i, x_n
    

In [22]:
train_dataset = SMILESDataset(TRAIN_SMILES)
test_dataset = SMILESDataset(TEST_SMILES)

In [23]:
xi, yi, xlen = train_dataset[4]

In [24]:
decode(xi), decode(yi)

('CC(C)Oc1ccc(cc1)c2cc(NCC(O)CO)c3ccccc3n2\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA',
 'CC(C)Oc1ccc(cc1)c2cc(NCC(O)CO)c3ccccc3n2\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA')

In [25]:
xi, yi, xlen = test_dataset[4]

In [26]:
decode(xi), decode(yi)

('FC(F)(F)Oc1cccc(NC(=O)c2oc(Br)cc2)c1\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA',
 'FC(F)(F)Oc1cccc(NC(=O)c2oc(Br)cc2)c1\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA')

In [27]:
BATCH_SIZE = 20

In [28]:
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_CPU, shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=10, num_workers=NUM_CPU, shuffle=True)

In [29]:
def count_valid(smiles):
    """
    count_valid(smiles):
        - returns % of valid smiles
    """
    invalid = 0
    
    for m in map(Chem.MolFromSmiles, smiles):
        invalid += m is None
        
    return 1 - (invalid/len(smiles))

# Model

In [30]:
# plain LSTM model
class GeneratorLSTM(nn.Module):
    def __init__(self, input_size, output_size, num_layers, hidden_size, embedding_size):
        super().__init__()
        
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        
        self.lstm = nn.LSTM(input_size=embedding_size, 
                            hidden_size=hidden_size, 
                            num_layers=num_layers, 
                            batch_first=True,
                            dropout=0.15)
        
        self.input_module = nn.Sequential(nn.Linear(input_size, 256),
                                         nn.ReLU(),
                                         nn.Dropout(0.10),
                                         nn.Linear(256, 512),
                                         nn.ReLU(),
                                         nn.Dropout(0.10),
                                         nn.Linear(512, embedding_size),
                                         nn.ReLU())
        
        self.output_module = nn.Sequential(nn.Linear(hidden_size, 256),
                                           nn.Dropout(0.10),
                                           nn.ReLU(),
                                           nn.Linear(256, output_size))

        self.hidden = None 

    def init_hidden(self, batch_size, cuda=True):
        ht = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size))
        ct = Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size))
        
        if cuda:
            return ht.cuda(), ct.cuda()
        else:
            return ht, ct

    def forward(self, input_raw, pack=False, input_lens=None):
        """
        forward(self, input_raw, state):
             - input_raw = (bs, x_{i,t})
             - state = (ht, ct)
        """
        input_ = self.input_module(input_raw)
        
        if pack:
            input_ = pack_padded_sequence(input_, input_lens, batch_first=True)
            
        input_, self.hidden = self.lstm(input_, self.hidden)
        
        if pack: 
            input_ = pad_packed_sequence(input_, batch_first=True)[0]

            
        input_ = self.output_module(input_)
        return input_

In [31]:
model = GeneratorLSTM(input_size = NUM_SYM, 
                      output_size = NUM_SYM, 
                      num_layers = 3, 
                      hidden_size = 512, 
                      embedding_size = 512)

In [32]:
model.cuda()

GeneratorLSTM(
  (lstm): LSTM(512, 512, num_layers=3, batch_first=True, dropout=0.15)
  (input_module): Sequential(
    (0): Linear(in_features=38, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU()
  )
  (output_module): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): Dropout(p=0.1, inplace=False)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=38, bias=True)
  )
)

In [33]:
model_optim = optim.Adam(list(model.parameters()), lr=3e-3/2)

In [34]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
        
model.apply(init_weights)

GeneratorLSTM(
  (lstm): LSTM(512, 512, num_layers=3, batch_first=True, dropout=0.15)
  (input_module): Sequential(
    (0): Linear(in_features=38, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU()
  )
  (output_module): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): Dropout(p=0.1, inplace=False)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=38, bias=True)
  )
)

In [35]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'model has [{count_parameters(model):,}] trainable params')

model has [6,849,062] trainable params


## Train

In [36]:
def train_one_epoch(model, model_optim, dataloader, 
                    test_interval):
    
    criterion = nn.NLLLoss()
    loss_history = []
    total = len(dataloader)
    
    test_loss_history = defaultdict(list)
    new_smiles = defaultdict(list)
    
    
    for batch_idx, batch in tqdm(enumerate(dataloader), total=total):
        
        if batch_idx % test_interval == 0:
            new_smiles[f'[batch num {batch_idx}] topk'] = generate(model)
            new_smiles[f'[batch num {batch_idx}] softmax sample'] = generate(model, sample_f=softmax_temp_sample)
            
            smiles_valid_batch = generate(model, test_samples=100, sample_f=softmax_temp_sample, v=False)
            validity = count_valid(smiles_valid_batch)
            
            test_loss_history[batch_idx] = test(model, testloader)
            # we can test later
            
            print(f'[batch num: {batch_idx}] sampled; [{validity:.3f}%] valid smiles sampled')
            
            if batch_idx == 0:
                print_loss = 'nan'
            else:
                print_loss = f'{loss.data.item():.4f}'

            
            path = Path(f'./data/lipinski/results/lstm-1-epoch-{batch_idx}-batch-{print_loss}-loss')
            path.mkdir(exist_ok=True)
            
            torch.save(model.state_dict(), str(path / 'model_dict.torch'))
            save(new_smiles, str(path / 'generated_smiles.dict' ))
            save(test_loss_history, str(path / 'test_loss.dict' ))
        
        model.train()
        model.zero_grad()
    
        x_batch, y_batch, batch_lens = batch
        
        batch_size = x_batch.size(0)
        max_len = int(max(batch_lens).item())
        
        x_batch = x_batch[:, 0:max_len, :]
        y_batch = y_batch[:, 0:max_len, :]

        # sort input
        batch_len_sorted, sort_index = torch.sort(batch_lens, 0, descending=True)
        batch_len_sorted = batch_len_sorted.numpy().tolist()
        
        x_batch = torch.index_select(x_batch, 0, sort_index)
        y_batch = torch.index_select(y_batch, 0, sort_index)

        x_batch = Variable(x_batch.float()).cuda()
        y_batch = Variable(y_batch.float()).cuda()

        # init state
        model.hidden = model.init_hidden(batch_size=x_batch.size(0))
        try:
            y_pred = model(x_batch, pack=True, input_lens=batch_len_sorted)
        except Exception as e:
            print(f'[ERROR] got exception {e}')
            print(f'[ERROR] skipping batch...')
            continue
        
        y_pred = F.log_softmax(y_pred.view(-1, NUM_SYM), dim=-1)
        _, y_batch = y_batch.topk(1, dim=-1)
        y_batch = y_batch.view(-1)
        loss = criterion(y_pred, y_batch)
        
        loss.backward()
        model_optim.step()

        
        print(f'[batch num: {batch_idx}] loss: {loss.data.item():.4f}')
                
    return loss_history, test_loss_history, new_smiles

In [37]:
P1_SMILES = (TRAIN_SMILES, TEST_SMILES)
P2_SMILES = 

SyntaxError: invalid syntax (<ipython-input-37-b4a44cde6c10>, line 2)

In [None]:
def transfer_learning(model, phases):
    """
    transfer_learning(model, phases):
        - model: pytorch model
        - phases: [Datasets] for datapaths
    """
    
    

## Testing and Generation

In [None]:
def softmax_temp_sample(y_t, temperature = 1.0):
    prediction_vector = F.softmax(y_t / temperature, dim=-1)
    x_index_t = torch.multinomial(prediction_vector, 1)[:, 0]
    return x_index_t

In [None]:
def topk_sample(y_t):
    _, pred_idx =  y_t.topk(1, dim=-1)
    return pred_idx

In [None]:
def generate(model, test_samples=5, sample_f=topk_sample, v=True):
    with torch.no_grad():
        model.eval()

        x = torch.zeros(test_samples, MAX_SYM, NUM_SYM).cuda()

        x[:, 0, SYM_TO_ID[GO_TOKEN]] = 1
        
        for i in range(MAX_SYM-1):
            model.hidden = model.init_hidden(batch_size=test_samples)
            
            pred = model(x, pack=True, input_lens=np.ones(test_samples)*(i+1))
            pred_idx = sample_f(pred[:,i,:])
            temp = torch.zeros(test_samples, MAX_SYM, NUM_SYM).cuda()
            for j in range(test_samples):
                temp[j, i+1, pred_idx[j]] = 1  
            x.add_(temp)

        if v: print('\n',10*'-' + 'GENERATED SMILES STRINGS' + 10*'-')
        smiles = []

        for j in range(test_samples):
            s = decode_short(x[j].cpu().numpy())
            smiles += [s]

            if v: print(s)

        return smiles
        
    

In [None]:
def test(model, testloader):
    with torch.no_grad():

        model.eval()

        criterion = nn.NLLLoss()
        total_loss = 0

        for batch in tqdm(testloader,total=len(testloader)):

            x_batch, y_batch, batch_lens = batch

            batch_size = x_batch.size(0)
            max_len = int(max(batch_lens).item())

            x_batch = x_batch[:, 0:max_len, :]
            y_batch = y_batch[:, 0:max_len, :]

            # sort input
            batch_len_sorted, sort_index = torch.sort(batch_lens, 0, descending=True)
            batch_len_sorted = batch_len_sorted.numpy().tolist()

            x_batch = torch.index_select(x_batch, 0, sort_index)
            y_batch = torch.index_select(y_batch, 0, sort_index)

            x_batch = Variable(x_batch.float()).cuda()
            y_batch = Variable(y_batch.float()).cuda()

            # init state
            model.hidden = model.init_hidden(batch_size=x_batch.size(0))

            y_pred = model(x_batch, pack=True, input_lens=batch_len_sorted)

            y_pred = y_pred.view(-1, NUM_SYM)
            _, y_batch = y_batch.topk(1, dim=-1)
            y_batch = y_batch.view(-1)

            total_loss += criterion(y_pred, y_batch)

        return total_loss / len(test_dataset)
    

## Results

In [None]:
train_history, train_history, new_smiles = train_one_epoch(model, model_optim, dataloader, test_interval=200)

In [None]:
def plot_history(history, l=-1):
    if l < 0:
        l = len(history)
    fig = px.line(x=np.arange(l), y=history[:l], labels={'x':'batch number', 'y':'binary cross-entropy loss'})
    return fig

In [None]:
plot_history(history)

In [None]:
newsmiles = generate(model, test_samples=500, sample_f=softmax_temp_sample, v=False)

In [None]:
count_valid(newsmiles)

In [None]:
validsmiles = [s for s in newsmiles if Chem.MolFromSmiles(s)]

In [None]:
validsmiles

In [None]:
newmols = [Chem.MolFromSmiles(s) for s in validsmiles]

In [None]:
def get_qed_mol(m):
    properties = ['MW', 'ALOGP', 'HBA', 'HBD', 'PSA', 'ROTB', 'AROM', 'ALERTS']
    mol_prop = QED.properties(m)
    mol_prop_num = [getattr(mol_prop, attr) for attr in properties]
    return mol_prop_num


In [None]:
druglike_mols = [Chem.MolFromSmiles(s[1:]) for s in TEST_SMILES]

active_mols = []
with open('./data/lipinski/active_lipinski.smi') as fp:
    for line in fp.readlines():
        active_mols += [Chem.MolFromSmiles(line[:-1])]

very_active_mols = []
with open('./data/lipinski/very_active_lipinski.smi') as fp:
    for line in fp.readlines():
        very_active_mols += [Chem.MolFromSmiles(line[:-1])]

In [None]:
ALL_QED_VALUES = []

for mol in tqdm(druglike_mols):
    values = ['DRUGLIKE_LIP'] + get_qed_mol(mol)
    ALL_QED_VALUES += [values]

for mol in tqdm(active_mols):
    values = ['ACTIVE_LIP'] + get_qed_mol(mol)
    ALL_QED_VALUES += [values]
    
for mol in tqdm(very_active_mols):
    values = ['VERY_ACTIVE_LIP'] + get_qed_mol(mol)
    ALL_QED_VALUES += [values]


In [None]:
mol_df = pd.DataFrame(data=ALL_QED_VALUES,columns=['TYPE', 'MW', 'ALOGP', 'HBA', 'HBD', 'PSA', 'ROTB', 'AROM', 'ALERTS' ])

In [None]:
len(mol_df)

In [None]:
mol_df.to_pickle('./data/lipinski/qed_df.pickle')

In [None]:
qed_data = mol_df.values[:,1:]

In [None]:
qed_data = qed_data.astype(np.float64)

In [None]:
import umap
reducer = umap.UMAP(n_components=3)

reducer.fit(qed_data)
embedding = reducer.transform(qed_data)

In [None]:
embedding.shape

In [None]:
emb_df = pd.DataFrame(data=embedding,columns=['COMP_1','COMP_2','COMP_3'])

In [None]:
emb_df['TYPE'] = mol_df['TYPE']

In [None]:
emb_df.head()

In [None]:
fig = px.scatter_3d(emb_df, x='COMP_1', y='COMP_2', z='COMP_3',
                    color='TYPE')
fig.show()