# Load functions, set variables, create new functions for dataset and model

In [8]:
import pandas as pd
import os

import sys
sys.path.append('../../../../icml18-jtnn')
sys.path.append('../../../../icml18-jtnn/jtnn')
from tqdm import tqdm

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random, sys
from optparse import OptionParser
from collections import deque

from jtnn import *
import rdkit

from jtnn_enc import JTNNEncoder

In [10]:


VOCAB_PATH = '../../../../icml18-jtnn/data/zinc/vocab.txt'
MODEL_PATH = '../../../../icml18-jtnn/molvae'
DATASET_PATH = '../../../data/3_final_data/split_data'

RAW_PATH = '../../../data/raw/baselines/jtree'

SMILES_COLUMN = 'smiles'
VALUE_COLUMN = 'logP'

In [11]:
vocab = [x.strip("\r\n ") for x in open(VOCAB_PATH)] 
vocab = Vocab(vocab)

In [25]:
batch_size = int(200)
hidden_size = int(450)
latent_size = int(56)
depth = int(3)
beta = float(0)
lr = float(1e-5)
stereo = True if int(1) == 1 else False



In [13]:
def set_batch_nodeID(mol_batch, vocab):
    tot = 0
    for mol_tree in mol_batch:
        for node in mol_tree.nodes:
            node.idx = tot
            node.wid = vocab.get_index(node.smiles)
            tot += 1

In [14]:
class JTPredict(nn.Module):

    def __init__(self, vocab, hidden_size, latent_size, depth, stereo=True):
        super(JTPredict, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = JTNNEncoder(vocab, hidden_size, self.embedding)
        self.mpn = MPN(hidden_size, depth)
        
        self.output_size = 1

        self.T_mean = nn.Linear(hidden_size, latent_size / 2)
        self.T_var = nn.Linear(hidden_size, latent_size / 2)
        self.G_mean = nn.Linear(hidden_size, latent_size / 2)
        self.G_var = nn.Linear(hidden_size, latent_size / 2)
        
        self.use_stereo = stereo
        if stereo:
            self.stereo_loss = nn.CrossEntropyLoss(size_average=False)
    
    def encode(self, mol_batch):
        set_batch_nodeID(mol_batch, self.vocab)
        root_batch = [mol_tree.nodes[0] for mol_tree in mol_batch]
        tree_mess,tree_vec = self.jtnn(root_batch)

        smiles_batch = [mol_tree.smiles for mol_tree in mol_batch]
        mol_vec = self.mpn(mol2graph(smiles_batch))
        return tree_mess, tree_vec, mol_vec

    def encode_latent_mean(self, smiles_list):
        print(smiles_list)
        mol_batch = [MolTree(s) for s in smiles_list]
#         print(mol_batch)
        for mol_tree in mol_batch:
            mol_tree.recover()

        _, tree_vec, mol_vec = self.encode(mol_batch)
        tree_mean = self.T_mean(tree_vec)
        mol_mean = self.G_mean(mol_vec)
        return torch.cat([tree_mean,mol_mean], dim=1)
    
    def create_ffn(self, ffn_num_layers = 3, ffn_hidden_size = 50):
        """
        Creates the feed-forward layers for the model.
        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.
        """
        dropout = nn.Dropout(0.5)
        activation = nn.ReLU()
        
        first_linear_dim = self.latent_size

        # Create FFN layers
        if ffn_num_layers == 1:
            ffn = [
                dropout,
                nn.Linear(first_linear_dim, self.output_size)
            ]
        else:
            ffn = [
                dropout,
                nn.Linear(first_linear_dim, ffn_hidden_size)
            ]
            for _ in range(ffn_num_layers - 2):
                ffn.extend([
                    activation,
                    dropout,
                    nn.Linear(ffn_hidden_size, ffn_hidden_size),
                ])
            ffn.extend([
                activation,
                dropout,
                nn.Linear(ffn_hidden_size, self.output_size),
            ])

        # Create FFN model
        self.ffn = nn.Sequential(*ffn)

    def forward(self, mol_batch, beta=0):
        batch_size = len(mol_batch)

        _, tree_vec, mol_vec = self.encode(mol_batch)
        
        tree_mean = self.T_mean(tree_vec)
        mol_mean = self.G_mean(mol_vec)
        
        feature_vec =  torch.cat([tree_mean, mol_mean], dim=1)
        
        return self.ffn(feature_vec)

    


In [55]:
from torch.utils.data import Dataset
from mol_tree import MolTree
import numpy as np

# global SMILES_TO_MOLTREE
SMILES_TO_MOLTREE = {}


class MoleculeDataset(Dataset):

    def __init__(self, data_file, raw_path, SMILES_COLUMN = 'smiles', TARGET_COLUMN = 'logP'):
        global SMILES_TO_MOLTREE
        
        self.data = pd.read_csv(data_file)
        data_options = ['train','val','test']
        
        for option in data_options:
            if option in data_file:
                broken_smiles  = [x.strip("\r\n ") for x in open(os.path.join(raw_path,option+'_errs.txt'))] 
                
        self.data = self.data[~self.data[SMILES_COLUMN].isin(broken_smiles)]    [:10]
        self.SMILES_COLUMN = SMILES_COLUMN
        self.TARGET_COLUMN = TARGET_COLUMN
        
        for i in range(len(self.data)):
            if self.data.iloc[i][SMILES_COLUMN] in SMILES_TO_MOLTREE:
                mol_tree = SMILES_TO_MOLTREE[self.data.iloc[i][SMILES_COLUMN]]
            else:
                mol_tree = MolTree(self.data.iloc[i][SMILES_COLUMN])
                SMILES_TO_MOLTREE[self.data.iloc[i][SMILES_COLUMN]] = mol_tree
                mol_tree.recover()
                mol_tree.assemble()
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        global SMILES_TO_MOLTREE
        smiles = self.data.iloc[idx][self.SMILES_COLUMN]
        target = self.data.iloc[idx][self.TARGET_COLUMN]
        if smiles in SMILES_TO_MOLTREE.keys():
            mol_tree = SMILES_TO_MOLTREE[smiles]
        else:
            mol_tree = MolTree(smiles)
            SMILES_TO_MOLTREE[smiles] = mol_tree
        return mol_tree, target        
    

# Create model and load data

In [30]:
model = JTPredict(vocab, hidden_size, latent_size, depth, stereo=stereo)

In [31]:
model.create_ffn()

In [32]:
for param in model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)


## Load pretrained weights

In [54]:
from jtnn_vae import JTNNVAE

model_VAE = JTNNVAE(vocab, hidden_size, latent_size, depth, stereo=stereo)
model_VAE.load_state_dict(torch.load(os.path.join(MODEL_PATH, 'MPNVAE-h450-L56-d3-beta0.001/model.iter-4')))

In [47]:
model.jtnn = model_VAE.jtnn
model.mpn = model_VAE.mpn
model.embedding = model_VAE.embedding
model.T_mean = model_VAE.T_mean
model.T_var = model_VAE.T_var
model.G_mean = model_VAE.G_mean
model.G_var = model_VAE.G_var

In [50]:
del model_VAE

In [51]:
for name, param in model.named_parameters():
    print name, param.data

embedding.weight tensor([[ 0.0594,  0.0186, -0.0397,  ...,  0.0985,  0.1281, -0.0242],
        [ 0.0607, -0.2044, -0.0376,  ..., -0.0056,  0.0935, -0.0021],
        [ 0.0100, -0.0274, -0.0751,  ...,  0.0867, -0.0365,  0.0391],
        ...,
        [ 0.0686,  0.0343,  0.0617,  ..., -0.0690,  0.0751,  0.0362],
        [ 0.0369,  0.0062,  0.0455,  ..., -0.0077, -0.0310, -0.0112],
        [ 0.1684,  0.0122, -0.0144,  ...,  0.0881,  0.0752,  0.0060]])
jtnn.W_z.weight tensor([[-0.0506,  0.0644,  0.2186,  ..., -0.1192, -0.0397, -0.0724],
        [-0.3613,  0.0557,  0.3665,  ...,  0.0265, -0.0641,  0.2011],
        [ 0.2609,  0.0762,  0.1762,  ...,  0.2432, -0.1284, -0.0840],
        ...,
        [ 0.1006, -0.0124,  0.2878,  ...,  0.0095, -0.2124,  0.0286],
        [-0.1054, -0.0414,  0.3038,  ...,  0.0535, -0.0193, -0.1250],
        [ 0.1979,  0.1238,  0.3254,  ...,  0.0086, -0.0346,  0.4131]])
jtnn.W_z.bias tensor([-0.3559, -0.4492, -0.1669, -0.2501, -0.1203, -0.6630, -0.4290, -0.5225,
     

In [56]:
model = model.cuda()
print "Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,)

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
scheduler.step()

dataset = MoleculeDataset(os.path.join(DATASET_PATH, 'logp_wo_averaging_validation.csv'), RAW_PATH)


Model #Params: 2474K


# Train function draft

In [57]:
MAX_EPOCH = 10
PRINT_ITER = 20
criterion = nn.MSELoss()

In [58]:
for epoch in xrange(MAX_EPOCH):
    dataloader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=lambda x: x, num_workers=1, drop_last=True)

    print(epoch)

    for it,  batch in enumerate(dataloader):
        X = []
        y = []
        for elem in batch:
            X.append(elem[0])
            y.append(elem[1])
        y = torch.Tensor(y).cuda()
# #         batch, target = batch
# #         for mol_tree in batch:
# #             for node in mol_tree.nodes:
# #                 if node.label not in node.cands:
# #                     node.cands.append(node.label)
# #                     node.cand_mols.append(node.label_mol)

# #         try:
        model.zero_grad()
        pred = model(X)
        loss = criterion(y, pred)
        loss.backward()
        optimizer.step()
        
        print(loss)
#         print('sdfsd')
#         except Exception as e:
#             with open('broken_smiles.txt', 'a') as f:
#                 smiles = [elem.smiles for elem in batch]
#                 f.write('\n'.join(smiles))
#                 f.write('\n')
#                 f.write(e.args[0])
#                 f.write('\n\n\n\n')
#             print(e)
#             continue

        



0
tensor(21.4522, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(33.3940, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(4.6014, device='cuda:0', grad_fn=<MeanBackward1>)
1
tensor(61.3403, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(142.5458, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(50.6553, device='cuda:0', grad_fn=<MeanBackward1>)
2
tensor(22.7041, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(29.3614, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(42.1143, device='cuda:0', grad_fn=<MeanBackward1>)
3
tensor(30.2984, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(33.3286, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(56.2206, device='cuda:0', grad_fn=<MeanBackward1>)
4
tensor(43.3054, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(23.1365, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(23.1205, device='cuda:0', grad_fn=<MeanBackward1>)
5
tensor(74.3726, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(23.8956, device='cuda:0', grad_fn=<MeanBackward1>)
te

In [79]:
SMILES_TO_MOLTREE

{'CC(=O)OCCC(C)C': <mol_tree.MolTree at 0x7efbabb36d90>,
 'CC(=O)c1ccc2cc(C(C)C(=O)OCC(=O)N(C)C)ccc2c1': <mol_tree.MolTree at 0x7efbaaa1d750>,
 'CC(C)(C)S(=O)(=O)CC(Cc1ccccc1)C(=O)NC(Cc1c[nH]cn1)C(=O)NC(CC1CCCCC1)C(O)C(O)C1CC1': <mol_tree.MolTree at 0x7efbaaa7b090>,
 'CN=C(NC#N)NCCSCc1nccs1': <mol_tree.MolTree at 0x7efbabb68750>,
 'Clc1ccc(-c2nc3cccnc3[nH]2)cc1': <mol_tree.MolTree at 0x7efbaaa7bfd0>,
 'N#CN=[N+]([O-])c1ccc(Br)cc1': <mol_tree.MolTree at 0x7efbabb685d0>,
 'N=c1nc(-c2ccccc2Br)[nH]c(=N)[nH]1': <mol_tree.MolTree at 0x7efbafced490>,
 'O=C1CCc2ccccc2N1': <mol_tree.MolTree at 0x7efbaaa7bc90>,
 'O=C1c2ccccc2C(=O)N1SC(Cl)(Cl)Cl': <mol_tree.MolTree at 0x7efbabb85690>,
 'O=c1ccc(=O)[nH][nH]1': <mol_tree.MolTree at 0x7efbaaa7b050>}

In [62]:
dataset[0]

('N#CN=[N+]([O-])c1ccc(Br)cc1', {})


(<mol_tree.MolTree at 0x7efbabb61f10>, 2.5)

In [63]:
dataset[1]

('O=C1c2ccccc2C(=O)N1SC(Cl)(Cl)Cl', {'N#CN=[N+]([O-])c1ccc(Br)cc1': <mol_tree.MolTree object at 0x7efbabb61f10>})


(<mol_tree.MolTree at 0x7efbaa980350>, 2.85)

In [64]:
dataset[2]

('CC(C)(C)S(=O)(=O)CC(Cc1ccccc1)C(=O)NC(Cc1c[nH]cn1)C(=O)NC(CC1CCCCC1)C(O)C(O)C1CC1', {'N#CN=[N+]([O-])c1ccc(Br)cc1': <mol_tree.MolTree object at 0x7efbabb61f10>, 'O=C1c2ccccc2C(=O)N1SC(Cl)(Cl)Cl': <mol_tree.MolTree object at 0x7efbaa980350>})


(<mol_tree.MolTree at 0x7efbaa980510>, 2.75)

In [66]:
dataset[0]

('N#CN=[N+]([O-])c1ccc(Br)cc1', {'N#CN=[N+]([O-])c1ccc(Br)cc1': <mol_tree.MolTree object at 0x7efbabb61f10>, 'CC(C)(C)S(=O)(=O)CC(Cc1ccccc1)C(=O)NC(Cc1c[nH]cn1)C(=O)NC(CC1CCCCC1)C(O)C(O)C1CC1': <mol_tree.MolTree object at 0x7efbaa980510>, 'O=C1c2ccccc2C(=O)N1SC(Cl)(Cl)Cl': <mol_tree.MolTree object at 0x7efbaa980350>})


(<mol_tree.MolTree at 0x7efbabb61f10>, 2.5)