In [36]:
import pandas as pd
import os

import sys
sys.path.append('../../../../../icml18-jtnn')
sys.path.append('../../../../../icml18-jtnn/jtnn')
from tqdm import tqdm

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random, sys
from optparse import OptionParser
from collections import deque

from jtnn import *
import rdkit

from jtnn_enc import JTNNEncoder

In [30]:
VOCAB_PATH = '../../../../../icml18-jtnn/data/zinc_our/vocab.txt'

DATASET_PATH = '../../../data/3_final_data/split_data'

In [7]:
vocab = [x.strip("\r\n ") for x in open(VOCAB_PATH)] 
vocab = Vocab(vocab)

In [8]:
batch_size = int(200)
hidden_size = int(100)
latent_size = int(56)
depth = int(3)
beta = float(0)
lr = float(1e-5)
stereo = True if int(1) == 1 else False



In [107]:
def set_batch_nodeID(mol_batch, vocab):
    tot = 0
    for mol_tree in mol_batch:
        for node in mol_tree.nodes:
            node.idx = tot
            node.wid = vocab.get_index(node.smiles)
            tot += 1

In [116]:
class JTPredict(nn.Module):

    def __init__(self, vocab, hidden_size, latent_size, depth, stereo=True):
        super(JTPredict, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        self.jtnn = JTNNEncoder(vocab, hidden_size, self.embedding)
        self.mpn = MPN(hidden_size, depth)
        
        self.output_size = 1

        self.T_mean = nn.Linear(hidden_size, latent_size / 2)
        self.T_var = nn.Linear(hidden_size, latent_size / 2)
        self.G_mean = nn.Linear(hidden_size, latent_size / 2)
        self.G_var = nn.Linear(hidden_size, latent_size / 2)
        
        self.use_stereo = stereo
        if stereo:
            self.stereo_loss = nn.CrossEntropyLoss(size_average=False)
    
    def encode(self, mol_batch):
        set_batch_nodeID(mol_batch, self.vocab)
        root_batch = [mol_tree.nodes[0] for mol_tree in mol_batch]
        tree_mess,tree_vec = self.jtnn(root_batch)

        smiles_batch = [mol_tree.smiles for mol_tree in mol_batch]
        mol_vec = self.mpn(mol2graph(smiles_batch))
        return tree_mess, tree_vec, mol_vec

    def encode_latent_mean(self, smiles_list):
        print(smiles_list)
        mol_batch = [MolTree(s) for s in smiles_list]
#         print(mol_batch)
        for mol_tree in mol_batch:
            mol_tree.recover()

        _, tree_vec, mol_vec = self.encode(mol_batch)
        tree_mean = self.T_mean(tree_vec)
        mol_mean = self.G_mean(mol_vec)
        return torch.cat([tree_mean,mol_mean], dim=1)
    
    def create_ffn(self, ffn_num_layers = 3, ffn_hidden_size = 50):
        """
        Creates the feed-forward layers for the model.
        :param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.
        """
        dropout = nn.Dropout(0.5)
        activation = nn.ReLU()
        
        first_linear_dim = self.latent_size

        # Create FFN layers
        if ffn_num_layers == 1:
            ffn = [
                dropout,
                nn.Linear(first_linear_dim, self.output_size)
            ]
        else:
            ffn = [
                dropout,
                nn.Linear(first_linear_dim, ffn_hidden_size)
            ]
            for _ in range(ffn_num_layers - 2):
                ffn.extend([
                    activation,
                    dropout,
                    nn.Linear(ffn_hidden_size, ffn_hidden_size),
                ])
            ffn.extend([
                activation,
                dropout,
                nn.Linear(ffn_hidden_size, self.output_size),
            ])

        # Create FFN model
        self.ffn = nn.Sequential(*ffn)

    def forward(self, mol_batch, beta=0):
        batch_size = len(mol_batch)

        _, tree_vec, mol_vec = self.encode(mol_batch)
        
        tree_mean = self.T_mean(tree_vec)
        mol_mean = self.G_mean(mol_vec)
        
        feature_vec =  torch.cat([tree_mean, mol_mean], dim=1)
        
        return self.ffn(feature_vec)

    


In [117]:
model = JTPredict(vocab, hidden_size, latent_size, depth, stereo=stereo)

In [118]:
model.create_ffn()

In [119]:
for param in model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)


In [120]:
from torch.utils.data import Dataset
from mol_tree import MolTree
import numpy as np

class MoleculeDataset(Dataset):

    def __init__(self, data_file, SMILES_COLUMN = 'smiles', TARGET_COLUMN = 'logP'):
        with open(data_file) as f:
            self.data = pd.read_csv(data_file)[:10]
            
        self.SMILES_COLUMN = SMILES_COLUMN
        self.TARGET_COLUMN = TARGET_COLUMN
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        smiles = self.data.iloc[idx][self.SMILES_COLUMN]
        target = self.data.iloc[idx][self.TARGET_COLUMN]
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        mol_tree.assemble()
        return mol_tree, target

In [121]:
model = model.cuda()
print "Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,)

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
scheduler.step()

dataset = MoleculeDataset(os.path.join(DATASET_PATH, 'logp_wo_averaging_validation.csv'))


Model #Params: 354K


In [122]:
MAX_EPOCH = 10
PRINT_ITER = 20
criterion = nn.MSELoss()

In [125]:
for epoch in xrange(MAX_EPOCH):
    dataloader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=lambda x: x, num_workers=4, drop_last=True)

    print(epoch)

    for it,  batch in enumerate(dataloader):
        X = []
        y = []
        for elem in batch:
            X.append(elem[0])
            y.append(elem[1])
        y = torch.Tensor(y).cuda()
#         batch, target = batch
#         for mol_tree in batch:
#             for node in mol_tree.nodes:
#                 if node.label not in node.cands:
#                     node.cands.append(node.label)
#                     node.cand_mols.append(node.label_mol)

#         try:
        model.zero_grad()
        pred = model(X)
        loss = criterion(y, pred)
        loss.backward()
        optimizer.step()
        
        print(loss)
        print('sdfsd')
#         except Exception as e:
#             with open('broken_smiles.txt', 'a') as f:
#                 smiles = [elem.smiles for elem in batch]
#                 f.write('\n'.join(smiles))
#                 f.write('\n')
#                 f.write(e.args[0])
#                 f.write('\n\n\n\n')
#             print(e)
#             continue

        



0
tensor(3.0210, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
tensor(5.6895, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
tensor(3.9407, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
1
tensor(4.2203, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
tensor(4.4869, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
tensor(3.0060, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
2
tensor(2.7947, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
tensor(1.4723, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
tensor(4.8524, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
3
tensor(5.0123, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
tensor(3.1371, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
tensor(3.7929, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
4
tensor(3.0206, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
tensor(3.3031, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
tensor(4.8036, device='cuda:0', grad_fn=<MeanBackward1>)
sdfsd
5
tensor(2.3912, device='cuda:0', grad_fn=<Me