In [1]:
import os 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch.nn.functional as F
import torch.utils.data
import gzip
import pandas
import h5py
from __future__ import print_function
import argparse
import h5py
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn import model_selection

from DeepPurpose import utils
from DeepPurpose import DTI as models
from DeepPurpose.utils import *
from DeepPurpose.dataset import *

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random, sys
import numpy as np
import argparse
from collections import deque
# import cPickle as pickle
import pickle
import gc
# from fast_jtnn import *
from hgraph import *
import rdkit

from cnn_finetune import make_model
import torch.nn as nn

from albumentations.pytorch import ToTensorV2
import albumentations as A
from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, RandomContrast, RandomBrightness, Flip, OneOf, Compose,
    RandomCrop, Normalize, Resize
)
import cv2
from tqdm import tqdm

lg = rdkit.RDLogger.logger() 
lg.setLevel(rdkit.RDLogger.CRITICAL)

In [2]:
def one_hot_array(i, n):
    return map(int, [ix == i for ix in xrange(n)])

def one_hot_index(vec, charset):
    return map(charset.index, vec)

def from_one_hot_array(vec):
    oh = np.where(vec == 1)
    if oh[0].shape == (0, ):
        return None
    return int(oh[0][0])

def decode_smiles_from_indexes(vec, charset):
    return b"".join(map(lambda x: charset[x], vec)).strip()

def decode_smiles_from_indexes_new (vec, enc_drug):
    b = np.zeros((len(vec), 63))
    b[np.arange(len(vec)),vec] = 1
    return ("".join(enc_drug.inverse_transform(b).reshape(-1))).strip('?')

def load_dataset(filename, split = True):
    h5f = h5py.File(filename, 'r')
    if split:
        data_train = h5f['data_train'][:]
    else:
        data_train = None
    data_test = h5f['data_test'][:]
    charset =  h5f['charset'][:]
    h5f.close()
    if split:
        return (data_train, data_test, charset)
    else:
        return (data_test, charset)


In [3]:
class ARGS:
    def __init__(self):
        self.train = './train_processed_bms_ver2/mol/'
        self.vocab = 'data/chembl/vocab.txt'
        self.atom_vocab = common_atom_vocab
        self.save_dir = 'vae_model/'
        self.load_model = None
        self.seed = 7
        
        self.rnn_type = 'LSTM'
        self.hidden_size = 250
        self.embed_size = 250
        self.batch_size = 32
        self.latent_size = 32
        self.depthT = 15
        self.depthG = 15
        self.diterT = 1
        self.diterG = 3
        self.dropout = 0.0
        
        self.lr = 1e-3
        self.clip_norm = 5.0
#         self.beta = 0.0
        self.step_beta = 0.001
        self.max_beta = 1.0
        self.warmup = 10000
        self.kl_anneal_iter = 2000
        
        self.epoch = 100
        self.anneal_rate = 0.9
        self.anneal_iter = 25000
        self.kl_anneal_iter = 2000
        self.print_iter = 50
        self.save_iter = 5000
        
        
args = ARGS()

## Model

In [4]:
if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    device = torch.device('cuda:0')
    
print (device)

cuda:0


In [5]:
NUM_DICT = 63
MAX_LEN = 100

class MolecularVAE(nn.Module):
    def __init__(self):
        super(MolecularVAE, self).__init__()

        # The input filter dim should be 35
        #  corresponds to the size of CHARSET
        self.conv1d1 = nn.Conv1d(NUM_DICT, 9, kernel_size=9)  
        self.conv1d2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv1d3 = nn.Conv1d(9, 10, kernel_size=11)
        self.fc0 = nn.Linear(740, 435)
        self.fc11 = nn.Linear(435, 292)
        self.fc12 = nn.Linear(435, 292)

        self.fc2 = nn.Linear(292, 292)
        self.gru = nn.GRU(292, 501, 3, batch_first=True)
        self.fc3 = nn.Linear(501, NUM_DICT)

    def encode(self, x):
        h = F.relu(self.conv1d1(x))
        h = F.relu(self.conv1d2(h))
        h = F.relu(self.conv1d3(h))
        h = h.view(h.size(0), -1)
        h = F.selu(self.fc0(h))
        return self.fc11(h), self.fc12(h)

    def reparametrize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = 1e-2 * torch.randn_like(std)
            w = eps.mul(std).add_(mu)
            return w
        else:
            return mu
        
    def reparametrize_ver2(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = 1e-2 * torch.randn_like(std)
        w = eps.mul(std).add_(mu)
        return w

    def decode(self, z):
        z = F.selu(self.fc2(z))
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, MAX_LEN, 1)
        out, h = self.gru(z)
        out_reshape = out.contiguous().view(-1, out.size(-1))
        y0 = F.softmax(self.fc3(out_reshape), dim=1)
        y = y0.contiguous().view(out.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar

In [6]:
b = [bytes(i, 'utf-8') for i in smiles_char]
charset = np.array(b)
from sklearn.preprocessing import OneHotEncoder
enc_drug = OneHotEncoder().fit(np.array(smiles_char).reshape(-1, 1))

In [7]:
vocab = [x.strip("\r\n ").split() for x in open(args.vocab)] 
args.vocab = PairVocab(vocab)

In [8]:
def vae_loss(x_decoded_mean, x, z_mean, z_logvar):
    xent_loss = F.binary_cross_entropy(x_decoded_mean, x, size_average=False)
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return xent_loss + kl_loss

# data_train, data_test, charset = load_dataset('./data/processed.h5')

In [9]:
torch.manual_seed(42)

epochs = 40
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_vae = MolecularVAE().to(device)
model_vae.load_state_dict(torch.load('/home/quang/working/denuvo/DeepPurpose/model_weight/molecular_vae.pt'))
model_vae.train()
for param in model_vae.parameters():
    param.requires_grad = False

In [10]:
model_K = make_model('inception_v4', num_classes=10, pretrained=True)

for param in model_K.parameters():
    param.requires_grad = False

model_K._classifier.weight.requires_grad = True
model_K._classifier.bias.requires_grad = True


class InceptionV4Bottom(nn.Module):
            def __init__(self, original_model):
                super(InceptionV4Bottom, self).__init__()
                self.features = nn.Sequential(
                    # stop at conv4
                    *list(original_model.children())[:-1]
                )
#                 self.features = nn.Sequential(
#                     # stop at conv4
#                     *list(original_model.children())
#                 )
                dim1 = 512
                dim2 = 292
                self.num_ftrs = original_model._classifier.in_features
                self.classifier1 = nn.Sequential(nn.Linear(self.num_ftrs, dim1), nn.BatchNorm1d(dim1), nn.ReLU(), 
                                                nn.Linear(dim1, dim2))
                self.classifier2 = nn.Sequential(nn.Linear(self.num_ftrs, dim1), nn.BatchNorm1d(dim1), nn.ReLU(), 
                                                nn.Linear(dim1, dim2))

            def forward(self, x1):
                x1 = self.features(x1)
#                 print (x1.shape)
                x1 = x1.view(-1, self.num_ftrs)
#                 x2 = self.dense1(x2)
#                 print (x2.shape, x1.shape)
#                 x = torch.cat((x1, x2), dim=1)
                o1 = self.classifier1(x1)
                o2 = self.classifier1(x1)
                return o1, o2

model_cnn = InceptionV4Bottom(model_K)

In [11]:
def unfreeze(model, ct_c = 7):
    for param in model.parameters():
        param.requires_grad = True

    ct = 0
    for child in model.features[0].children():
        ct += 1
        if ct < ct_c:
            for param in child.parameters():
                param.requires_grad = False
                
# unfreeze(model_cnn,5)

In [12]:
model_cnn = model_cnn.to(device)

# Observe that all parameters are being optimized
optimizer = optim.Adam(model_cnn.parameters(), lr=0.001)

In [13]:
IMG_SIZE = 224
transform = A.Compose([
    A.Resize(int(IMG_SIZE), int(IMG_SIZE)),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

In [14]:
from sklearn.utils import shuffle as skutils 
import cv2

smiles_char = ['?', '#', '%', ')', '(', '+', '-', '.', '1', '0', '3', '2', '5', '4',
       '7', '6', '9', '8', '=', 'A', 'C', 'B', 'E', 'D', 'G', 'F', 'I',
       'H', 'K', 'M', 'L', 'O', 'N', 'P', 'S', 'R', 'U', 'T', 'W', 'V',
       'Y', '[', 'Z', ']', '_', 'a', 'c', 'b', 'e', 'd', 'g', 'f', 'i',
       'h', 'm', 'l', 'o', 'n', 's', 'r', 'u', 't', 'y']

def trans_drug(x):
    temp = list(x)
    temp = [i if i in smiles_char else '?' for i in temp]
    if len(temp) < MAX_SEQ_DRUG:
        temp = temp + ['?'] * (MAX_SEQ_DRUG-len(temp))
    else:
        temp = temp [:MAX_SEQ_DRUG]
    return temp

class DataFolder_ver3_BMS(object):

    def __init__(self, data_folder, ids_path, batch_size, transform, path_img, shuffle=True):
        self.data_folder = data_folder
        self.data_files = [fn for fn in os.listdir(data_folder)]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.transform = transform
        self.path_img = path_img
        self.ids_path = ids_path
        self.drug_encoding = 'CNN'

    def __len__(self):
        return len(self.data_files) * 50

    def __iter__(self):
        for fn_t in self.data_files:
            fn = os.path.join(self.data_folder, fn_t)
            with open(fn, 'rb') as f:
                batches = pickle.load(f)
                
            fn_ids = os.path.join(self.ids_path, fn_t)
            with open(fn_ids, 'rb') as f:
                batches_ids, batches_labels = pickle.load(f)

            if self.shuffle: 
#                 random.shuffle(batches) #shuffle data before batch
                batches, batches_ids, batches_labels = skutils(batches, batches_ids, batches_labels)
            
            for batch, batch_ids, batches_label in zip(batches, batches_ids, batches_labels):
                images_list = []
                for id_temp in batch_ids:
                    image = cv2.imread(os.path.join(self.path_img, id_temp + '.png'))
                    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                    transformed = self.transform(image=image)
                    image = transformed['image']
                    images_list.append(image)
                batch_image = torch.stack(images_list)
                
                df_data = pd.DataFrame(batches_label)
                df_data.rename(columns={0:'SMILES',},inplace=True)
#                 df_data = encode_drug(df_data, self.drug_encoding)
                unique = pd.Series(df_data['SMILES'].unique()).apply(trans_drug)
                unique_dict = dict(zip(df_data['SMILES'].unique(), unique))
                df_data['drug_encoding'] = [unique_dict[i] for i in df_data['SMILES']]
                
                v_d = [drug_2_embed(i) for i in df_data['drug_encoding'].values]
                
                yield batch, batch_image, batches_label, torch.from_numpy(np.array(v_d)).float()

            del batches, batch_image, batches_ids, batches_labels
            gc.collect()


In [15]:
dataset = DataFolder_ver3_BMS(args.train, '/home/quang/working/Theory_of_ML/hgraph2graph/train_processed_bms_ver2/ids', args.batch_size, transform, '/home/quang/working/Theory_of_ML/mbs-molecular-captioning/images/train/')
for batch, images_list, labels, v_ds in tqdm(dataset):
    break

  0%|          | 0/200 [00:00<?, ?it/s]


## Training

In [16]:
def vae_loss(x_decoded_mean, x, z_mean, z_logvar):
    xent_loss = F.binary_cross_entropy(x_decoded_mean, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return xent_loss + kl_loss

In [31]:
total_step = 0
meters = np.zeros(1)
for epoch in range(args.epoch):
    dataset = DataFolder_ver3_BMS(args.train, '/home/quang/working/Theory_of_ML/hgraph2graph/train_processed_bms_ver2/ids', args.batch_size, transform, '/home/quang/working/Theory_of_ML/mbs-molecular-captioning/images/train/')
    for batch, images_list, labels, v_ds in tqdm(dataset):
        total_step += 1
#         images_list, tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder,batch_idx) = batch_t
#         batch = tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder,batch_idx)
        images_list = images_list.to(device)
        v_ds = v_ds.to(device)
        optimizer.zero_grad()
        mu_vec, var_vec = model_cnn(images_list)
        z = model_vae.reparametrize_ver2(mu_vec, var_vec)
        output = model_vae.decode(z)
#         print (output.shape, v_ds.shape)
        loss = vae_loss(output, v_ds.transpose(1,2), mu_vec, var_vec)
        loss.backward()
        optimizer.step()
        
#         try:
#             optimizer.zero_grad()
#             tree_vec, mol_vec = model_cnn(images_list)
#             loss, kl_div, wacc, iacc, tacc, sacc = model.forward_edit(*batch, mol_vec, beta=beta)
# #             loss, kl_div, wacc, tacc, sacc = model.forward_edit(batch, beta, tree_vec, mol_vec)
#             loss.backward()
#             optimizer.step()
#         except Exception as e:
#             print (e)
#             continue

        meters = meters + np.array([loss.item()])

        if total_step % args.print_iter == 0:
            meters /= args.print_iter
            print("[%d] KL: loss: %.3f" % (total_step, meters[0]))
            sys.stdout.flush()
            meters *= 0

#         if total_step % args.save_iter == 0:
#             torch.save(model.state_dict(), args.save_dir + "/model.iter-" + str(total_step))

#         if total_step % args.anneal_iter == 0:
#             scheduler.step()
#             print ("learning rate: %.6f" % scheduler.get_lr()[0])

#         if total_step % args.kl_anneal_iter == 0 and total_step >= args.warmup:
#             beta = min(args.max_beta, beta + args.step_beta)
            
torch.save(model_cnn.state_dict(),  "./model_cnn/model_cnn_CNNRNN_ver5.iter-" + str(total_step))

 24%|██▍       | 49/200 [00:09<00:27,  5.51it/s]

[50] KL: loss: 25951.596


 50%|████▉     | 99/200 [00:19<00:19,  5.12it/s]

[100] KL: loss: 25342.919


 74%|███████▍  | 149/200 [00:29<00:09,  5.65it/s]

[150] KL: loss: 24515.640


100%|█████████▉| 199/200 [00:39<00:00,  5.60it/s]

[200] KL: loss: 23964.124


245it [00:48,  5.04it/s]                         
  2%|▏         | 4/200 [00:01<00:52,  3.75it/s]

[250] KL: loss: 23866.978


 27%|██▋       | 54/200 [00:11<00:28,  5.11it/s]

[300] KL: loss: 23483.858


 52%|█████▏    | 104/200 [00:21<00:18,  5.09it/s]

[350] KL: loss: 23099.925


 77%|███████▋  | 154/200 [00:31<00:08,  5.18it/s]

[400] KL: loss: 22581.579


204it [00:41,  5.52it/s]                         

[450] KL: loss: 22352.718


245it [00:49,  4.96it/s]
  4%|▍         | 9/200 [00:02<00:38,  4.93it/s]

[500] KL: loss: 21777.814


 30%|██▉       | 59/200 [00:12<00:27,  5.12it/s]

[550] KL: loss: 21596.499


 55%|█████▍    | 109/200 [00:22<00:17,  5.08it/s]

[600] KL: loss: 21536.215


 80%|███████▉  | 159/200 [00:32<00:08,  5.03it/s]

[650] KL: loss: 21068.637


209it [00:42,  5.17it/s]                         

[700] KL: loss: 21056.430


245it [00:50,  4.89it/s]
  7%|▋         | 14/200 [00:03<00:34,  5.40it/s]

[750] KL: loss: 20666.541


 32%|███▏      | 64/200 [00:13<00:31,  4.25it/s]

[800] KL: loss: 20798.271


 57%|█████▋    | 114/200 [00:22<00:16,  5.07it/s]

[850] KL: loss: 21360.875


 82%|████████▏ | 164/200 [00:33<00:07,  5.05it/s]

[900] KL: loss: 26788.226


214it [00:43,  5.14it/s]                         

[950] KL: loss: 26217.368


245it [00:49,  4.95it/s]
 10%|▉         | 19/200 [00:04<00:33,  5.37it/s]

[1000] KL: loss: 25639.614


 34%|███▍      | 69/200 [00:14<00:25,  5.20it/s]

[1050] KL: loss: 24755.501


 60%|█████▉    | 119/200 [00:23<00:15,  5.15it/s]

[1100] KL: loss: 25166.908


 84%|████████▍ | 169/200 [00:34<00:06,  5.02it/s]

[1150] KL: loss: 23426.791


219it [00:44,  5.11it/s]                         

[1200] KL: loss: 22968.400


245it [00:49,  4.93it/s]
 12%|█▏        | 24/200 [00:04<00:31,  5.62it/s]

[1250] KL: loss: 22477.990


 37%|███▋      | 74/200 [00:15<00:22,  5.54it/s]

[1300] KL: loss: 21959.384


 62%|██████▏   | 124/200 [00:25<00:22,  3.35it/s]

[1350] KL: loss: 22014.977


 87%|████████▋ | 174/200 [00:34<00:05,  4.87it/s]

[1400] KL: loss: 20910.205


224it [00:44,  5.21it/s]                         

[1450] KL: loss: 20603.796


245it [00:49,  4.97it/s]
 14%|█▍        | 29/200 [00:05<00:33,  5.09it/s]

[1500] KL: loss: 20388.427


 40%|███▉      | 79/200 [00:16<00:22,  5.32it/s]

[1550] KL: loss: 20223.172


 64%|██████▍   | 129/200 [00:26<00:14,  5.04it/s]

[1600] KL: loss: 19652.487


 90%|████████▉ | 179/200 [00:35<00:04,  5.14it/s]

[1650] KL: loss: 19203.091


229it [00:46,  5.03it/s]                         

[1700] KL: loss: 18978.948


245it [00:49,  4.96it/s]
 17%|█▋        | 34/200 [00:07<00:33,  4.96it/s]

[1750] KL: loss: 18952.498


 42%|████▏     | 84/200 [00:16<00:22,  5.19it/s]

[1800] KL: loss: 18465.655


 67%|██████▋   | 134/200 [00:27<00:12,  5.40it/s]

[1850] KL: loss: 18441.688


 92%|█████████▏| 184/200 [00:37<00:06,  2.52it/s]

[1900] KL: loss: 18121.203


234it [00:47,  5.36it/s]                         

[1950] KL: loss: 17979.032


245it [00:49,  4.93it/s]
 20%|█▉        | 39/200 [00:07<00:31,  5.16it/s]

[2000] KL: loss: 17720.836


 44%|████▍     | 89/200 [00:17<00:20,  5.34it/s]

[2050] KL: loss: 17628.522


 70%|██████▉   | 139/200 [00:28<00:10,  5.56it/s]

[2100] KL: loss: 17046.915


 94%|█████████▍| 189/200 [00:38<00:02,  4.46it/s]

[2150] KL: loss: 17147.146


239it [00:48,  4.84it/s]                         

[2200] KL: loss: 16906.157


245it [00:49,  4.95it/s]
 22%|██▏       | 44/200 [00:08<00:30,  5.13it/s]

[2250] KL: loss: 16699.071


 47%|████▋     | 94/200 [00:18<00:21,  4.96it/s]

[2300] KL: loss: 16651.907


 72%|███████▏  | 144/200 [00:29<00:10,  5.51it/s]

[2350] KL: loss: 16325.635


 97%|█████████▋| 194/200 [00:39<00:01,  5.31it/s]

[2400] KL: loss: 16087.209


244it [00:49,  5.17it/s]                         

[2450] KL: loss: 16173.945


245it [00:49,  4.96it/s]
 24%|██▍       | 49/200 [00:09<00:29,  5.08it/s]

[2500] KL: loss: 15919.018


 50%|████▉     | 99/200 [00:20<00:19,  5.07it/s]

[2550] KL: loss: 15722.376


 74%|███████▍  | 149/200 [00:30<00:09,  5.18it/s]

[2600] KL: loss: 15874.764


100%|█████████▉| 199/200 [00:40<00:00,  5.42it/s]

[2650] KL: loss: 15411.560


245it [00:48,  5.01it/s]                         
  2%|▏         | 4/200 [00:01<00:50,  3.88it/s]

[2700] KL: loss: 15648.081


 27%|██▋       | 54/200 [00:10<00:27,  5.41it/s]

[2750] KL: loss: 15107.784


 52%|█████▏    | 104/200 [00:20<00:18,  5.08it/s]

[2800] KL: loss: 15272.413


 77%|███████▋  | 154/200 [00:31<00:08,  5.33it/s]

[2850] KL: loss: 15104.729


204it [00:41,  5.51it/s]                         

[2900] KL: loss: 14948.318


245it [00:48,  5.02it/s]
  4%|▍         | 9/200 [00:02<00:36,  5.24it/s]

[2950] KL: loss: 14961.644


 30%|██▉       | 59/200 [00:11<00:26,  5.41it/s]

[3000] KL: loss: 14596.683


 55%|█████▍    | 109/200 [00:21<00:18,  4.95it/s]

[3050] KL: loss: 14721.758


 80%|███████▉  | 159/200 [00:32<00:08,  4.98it/s]

[3100] KL: loss: 14506.791


209it [00:42,  5.58it/s]                         

[3150] KL: loss: 14584.994


245it [00:50,  4.88it/s]
  7%|▋         | 14/200 [00:03<00:35,  5.24it/s]

[3200] KL: loss: 14258.527


 32%|███▏      | 64/200 [00:13<00:30,  4.39it/s]

[3250] KL: loss: 14137.169


 57%|█████▋    | 114/200 [00:22<00:16,  5.07it/s]

[3300] KL: loss: 14258.106


 82%|████████▏ | 164/200 [00:33<00:07,  5.07it/s]

[3350] KL: loss: 14134.579


214it [00:43,  5.37it/s]                         

[3400] KL: loss: 14016.578


245it [00:49,  5.00it/s]
 10%|▉         | 19/200 [00:03<00:32,  5.54it/s]

[3450] KL: loss: 13810.775


 34%|███▍      | 69/200 [00:13<00:24,  5.36it/s]

[3500] KL: loss: 13808.166


 60%|█████▉    | 119/200 [00:23<00:14,  5.49it/s]

[3550] KL: loss: 13864.460


 84%|████████▍ | 169/200 [00:34<00:06,  5.00it/s]

[3600] KL: loss: 13726.614


219it [00:44,  5.24it/s]                         

[3650] KL: loss: 13719.541


245it [00:49,  4.98it/s]
 12%|█▏        | 24/200 [00:05<00:34,  5.12it/s]

[3700] KL: loss: 13647.904


 37%|███▋      | 74/200 [00:15<00:24,  5.07it/s]

[3750] KL: loss: 13341.923


 62%|██████▏   | 124/200 [00:25<00:23,  3.30it/s]

[3800] KL: loss: 13464.528


 87%|████████▋ | 174/200 [00:35<00:05,  5.07it/s]

[3850] KL: loss: 13333.379


224it [00:45,  5.14it/s]                         

[3900] KL: loss: 13288.749


245it [00:50,  4.88it/s]
 14%|█▍        | 29/200 [00:06<00:31,  5.46it/s]

[3950] KL: loss: 13153.597


 40%|███▉      | 79/200 [00:15<00:23,  5.23it/s]

[4000] KL: loss: 13076.322


 64%|██████▍   | 129/200 [00:25<00:14,  4.95it/s]

[4050] KL: loss: 13175.998


 90%|████████▉ | 179/200 [00:35<00:04,  4.97it/s]

[4100] KL: loss: 13036.190


229it [00:45,  5.05it/s]                         

[4150] KL: loss: 12971.094


245it [00:49,  4.98it/s]
 17%|█▋        | 34/200 [00:06<00:33,  5.02it/s]

[4200] KL: loss: 12774.095


 42%|████▏     | 84/200 [00:17<00:23,  5.02it/s]

[4250] KL: loss: 12947.860


 67%|██████▋   | 134/200 [00:27<00:13,  5.00it/s]

[4300] KL: loss: 12936.975


 92%|█████████▏| 184/200 [00:37<00:06,  2.44it/s]

[4350] KL: loss: 12772.173


234it [00:46,  5.14it/s]                         

[4400] KL: loss: 12780.486


245it [00:48,  5.00it/s]
 20%|█▉        | 39/200 [00:08<00:31,  5.04it/s]

[4450] KL: loss: 12759.638


 44%|████▍     | 89/200 [00:18<00:21,  5.11it/s]

[4500] KL: loss: 13427.392


 70%|██████▉   | 139/200 [00:28<00:11,  5.49it/s]

[4550] KL: loss: 13398.419


 94%|█████████▍| 189/200 [00:38<00:02,  4.58it/s]

[4600] KL: loss: 13057.318


239it [00:48,  5.08it/s]                         

[4650] KL: loss: 13011.198


245it [00:49,  4.91it/s]
 22%|██▏       | 44/200 [00:09<00:30,  5.07it/s]

[4700] KL: loss: 12757.615


 47%|████▋     | 94/200 [00:19<00:19,  5.50it/s]

[4750] KL: loss: 12737.320


 72%|███████▏  | 144/200 [00:29<00:10,  5.47it/s]

[4800] KL: loss: 12534.748


 97%|█████████▋| 194/200 [00:39<00:01,  5.22it/s]

[4850] KL: loss: 12472.723


244it [00:48,  5.07it/s]                         

[4900] KL: loss: 12484.207


245it [00:49,  5.00it/s]


## Validation

In [17]:
model_cnn.load_state_dict(torch.load('./model_cnn/model_cnn_CNNRNN_ver1.iter-4900'))

<All keys matched successfully>

In [None]:
# from sklearn.utils import shuffle as skutils 
# import cv2
dataset = DataFolder_ver3_BMS(args.train, '/home/quang/working/Theory_of_ML/hgraph2graph/train_processed_bms_ver2/ids', args.batch_size, transform, '/home/quang/working/Theory_of_ML/mbs-molecular-captioning/images/train/')


In [32]:
dataset_val = DataFolder_ver3_BMS('./val_processed_bms/mol/', '/home/quang/working/Theory_of_ML/hgraph2graph/val_processed_bms/ids', args.batch_size, transform, '/home/quang/working/Theory_of_ML/mbs-molecular-captioning/images/val/')
total_labels = []
total_preds = []
for batch, images_list, labels, _ in tqdm(dataset_val):
    images_list = images_list.to(device)
    with torch.no_grad():
        mu_vec, var_vec = model_cnn(images_list)
        z = model_vae.reparametrize_ver2(mu_vec, var_vec)
        outputs = model_vae.decode(z)
        for i in range(len(outputs)):
            sampled = outputs[i].reshape(1, 100, len(charset)).argmax(axis=2)[0].cpu()
            total_preds.append(decode_smiles_from_indexes_new(sampled, enc_drug))
    total_labels.extend(labels)
#     total_preds.extend(outputs)

114it [00:21,  5.27it/s]                         


## Levenshtein distance for SMILES code

In [33]:
import Levenshtein
def get_score(y_true, y_pred):
    scores = []
    for true, pred in zip(y_true, y_pred):
        score = Levenshtein.distance(true, pred)
        scores.append(score)
    avg_score = np.mean(scores)
    return avg_score

score = get_score(total_labels, total_preds)
print (score)

42.912527593818986


## Validity score

In [37]:
invalid_smiles = 0
for smile_temp in tqdm(total_preds):
    m = Chem.MolFromSmiles(smile_temp)
    if m is None:
        invalid_smiles += 1
        
print (1 - invalid_smiles/len(total_preds))

100%|██████████| 3624/3624 [00:00<00:00, 37121.27it/s]

0.003863134657836609





## Visualization

In [38]:
for i in range(len(labels)):
    print ('GT: ',labels[i])
    sampled = output[i].reshape(1, 100, len(charset)).argmax(axis=2)[0].cpu()
    print ('Pred: ', decode_smiles_from_indexes_new(sampled, enc_drug))

GT:  CN1C(=O)CCN(C(=O)c2cnsn2)c2ccccc21
Pred:  CCc1(C=)OOOOO3CC33333CCC[[[CC
GT:  C#CCCCNCc1cccc(C(F)(F)F)c1
Pred:  CCcc((C=N)OOOO??CCCC[[[[[
GT:  c1ccc2c3c(ccc2c1)ONC3
Pred:  CC(C)OC?C???CCCCCCCC[[[
GT:  COCCNCC(=O)N1CCCN(Cc2ccc(Cl)cc2)CC1
Pred:  CC11(C=OOO2CCC???[[C
GT:  COCOc1ccc(-c2ccc3cc(O)ccc3c2Cc2ccc(O)cc2)cc1
Pred:  CCcc((C=))OO)))))))))))))))OO???CCCCCCCCCCCCCCC[[[[[[[[[[[[[[[[[[[[[[[[[[---------------------------
GT:  COc1ccc(/C=C/C(=O)OCCC(C)C)cc1
Pred:  CCc((C=))OCCCCCCCCCCCCCCCCCCC?????[[[[[[[
GT:  CS(=O)(=O)C1CCCC(Nc2cnn(-c3cccc(C(F)(F)F)c3)c(=O)c2Cl)C1
Pred:  CCcc(CCC)C)))))CCCCCCCCCCCC????CCCCCCCCCCCCCCCC????????????????[[[[??????????????????????????[O??[O
GT:  OC(=Nc1cnc(Oc2ccccc2)nc1)c1ccc(F)cc1F
Pred:  CC11(C)CCCCCCCCCCCCCCCCCCCCCCCCCCCCCC[[[[[[[[
GT:  O=S(=O)(c1cc(Cl)ccc1Cl)N1CC[C@H](n2ccnn2)C1
Pred:  CC(C)O))C?2C2222CCCCCCCCC??????CCC[[[
GT:  CC(C)(C)C(CCN)CCC(O)=NCCc1cccs1
Pred:  CC(CCCCCCC33333CCCCCCC[[[[[[[C
GT:  CCC(C#N)Oc1ccc(C)cc1CNCCOC
Pred:  CCc(CC)CC)CC?CC

In [None]:
from IPython.display import SVG
from rdkit import Chem
from rdkit.Chem import Draw

In [None]:
m = Chem.MolFromSmiles('Cc1ccc(C(=O)NCC2(CO)CC2)cc1C')
print (Chem.MolToInchi(m))
# SVG(Draw.MolsToGridImage([m], molsPerRow=4, subImgSize=(180, 150), useSVG=True))

In [None]:
smiles_code = 'C1#CCCC[CH2:1]C#CCC1'
m = Chem.MolFromSmiles(smiles_code)
print (Chem.MolToInchi(m))
SVG(Draw.MolsToGridImage([m], molsPerRow=4, subImgSize=(180, 150), useSVG=True))

In [None]:
smiles_code = 'C1#CCC[CH2:1]CC#CCC1'
m = Chem.MolFromSmiles(smiles_code)
print (Chem.MolToInchi(m))
SVG(Draw.MolsToGridImage([m], molsPerRow=4, subImgSize=(180, 150), useSVG=True))

In [None]:
smiles_code = 'C1#CC[CH2:1][CH2:1][CH2:1]C#CCC1'
m = Chem.MolFromSmiles(smiles_code)
print (Chem.MolToInchi(m))
SVG(Draw.MolsToGridImage([m], molsPerRow=4, subImgSize=(180, 150), useSVG=True))