In [4]:
import argparse
import sys

import numpy as np

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem.Crippen import MolLogP

from sklearn.metrics import accuracy_score, roc_auc_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm_notebook
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
paser = argparse.ArgumentParser()
args = paser.parse_args("")
args.seed = 123
args.val_size = 0.1
args.test_size = 0.1
args.shuffle = True

In [3]:
np.random.seed(args.seed)
torch.manual_seed(args.seed)

<torch._C.Generator at 0x7f875e8f8330>

# 1. Pre-Processing

In [12]:
def read_ZINC_smiles(file_name, num_mol):
    f = open(file_name, 'r')
    contents = f.readlines()
    
    smi_list = list()
    logP_list = list()
    
    for i in tqdm_notebook(range(num_mol), desc='Reading Data'):
        smi = contents[i].strip()
        m = Chem.MolFromSmiles(smi)
        smi_list.append(smi)
        logP_list.append(MolLogP(m))
        
    logP_list = np.asarray(logP_list).astype(float)
    
    return smi_list, logP_list

def smiles_to_onehot(smi_list):
    def smiles_to_vector(smiles, vocab, max_length):
        while len(smiles)<max_length:
            smiles +=" "
        vector = [vocab.index(str(x)) for x in smiles]
        one_hot = np.zeros((len(vocab), max_length), dtype=int)
        for i, elm in enumerate(vector):
            one_hot[elm][i] = 1
        return one_hot
        
    vocab = np.load('./vocab.npy')
    smi_total = []
    for i, smi in tqdm_notebook(enumerate(smi_list), desc='Converting Data'):
        smi_onehot = smiles_to_vector(smi, list(vocab), 120)
        smi_total.append(smi_onehot)

    return np.asarray(smi_total)

class OneHotLogPDataSet(Dataset):
    def __init__(self, list_one_hot, list_logP):
        self.list_one_hot = list_one_hot
        self.list_logP = list_logP
        
    def __len__(self):
        return len(self.list_one_hot)
    
    def __getitem__(self, index):
        return self.list_one_hot[index], self.list_logP[index]
    
def partition(list_one_hot, list_logP, args):
    num_total = list_one_hot.shape[0]
    num_train = int(num_total*(1-args.test_size-args.val_size))
    num_val = int(num_total*args.val_size)
    num_test = int(num_total*args.test_size)
    
    one_hot_train = list_one_hot[:num_train]
    logP_train = list_logP[:num_train]
    one_hot_val = list_one_hot[num_train:num_train+num_val]
    logP_val = list_logP[num_train:num_train+num_val]
    one_hot_test = list_one_hot[num_total-num_test:]
    logP_test = list_logP[num_total-num_test:]
    
    train_set = OneHotLogPDataSet(one_hot_train, logP_train)
    val_set = OneHotLogPDataSet(one_hot_val, logP_val)
    test_set = OneHotLogPDataSet(one_hot_test, logP_test)
    
    partition = {
        'train' : train_set,
        'val' : val_set,
        'test' : test_set
    }
    
    return partition

In [14]:
list_smi, list_logP = read_ZINC_smiles('ZINC.smiles', 50000)
list_one_hot = smiles_to_onehot(list_smi)
dict_partition = partition(list_one_hot, list_logP, args)

HBox(children=(IntProgress(value=0, description='Reading Data', max=50000), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', description='Converting Data', max=1), HTML(value='')))




# 2. Model Construction

In [9]:
class SkipConnectionBlock(nn.Module):
    
    def __init__(self, in_planes, planes):
        super(SkipConnectionBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, 
                               planes, 
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        #torch.nn.init.xavier_uniform(self.conv1.weight)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        
    def forward(self, x):
        residual = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        
        return out

In [10]:
class BasicConv2d(nn.Module):
    
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
    

class InceptionBlock(nn.Module):
    '''(32,120)->(15, 59)'''
    def __init__(self, in_channels):
        super(InceptionBlock, self).__init__()
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
        
        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)
        
        self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)
        
    def forward(self, x):
        branch3x3 = self.branch3x3(x)
        
        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3db1_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
        
        branch_pool = self.branch_pool(x)
        
        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

In [11]:
class Flatten(nn.Module):
    
    def forward(self, x):
        return x.view(x.size(0), -1)

In [12]:
args.batch_size = 100
args.lr = 0.001
args.l2_coef = 0.001
args.optim = optim.Adam
args.criterion = nn.MSELoss()
args.epoch = 10
args.device = 'cuda'

In [16]:
model = nn.Sequential(SkipConnectionBlock(1, 64),
                      Flatten(),
                      nn.Linear(245760, 1))
model.to(args.device)

list_train_loss = list()
list_val_loss = list()
acc = 0

optimizer = args.optim(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.l2_coef)

data_train = DataLoader(args.dict_partition['train'], 
                        batch_size=args.batch_size,
                        shuffle=args.shuffle)
print("Loaded data for training")

data_val = DataLoader(args.dict_partition['val'],
                      batch_size=args.batch_size,
                      shuffle=args.shuffle)
print("Loaded data for validation")

increment = args.epoch//40
point = args.epoch//100
for epoch in range(args.epoch):
    model.train()
    epoch_train_loss = 0
    for i, batch in enumerate(data_train):
        one_hots = torch.tensor(np.expand_dims(batch[0], axis=1),
                                dtype=torch.float,
                                device=args.device)
        logPs = torch.tensor(batch[1],
                             dtype=torch.float,
                             args.device)
        logPs = logPs.view(-1, 1)
        
        optimizer.zero_grad()
        pred_logPs = model(one_hots)
        pred_logPs.require_grad = False
        train_loss = args.criterion(pred_logPs, logPs)
        epoch_train_loss += train_loss.item()
        train_loss.backward()
        optimizer.step()
        
        print("Epoch: ", epoch, "\tbatch: ", i, "\tTraining")
        
    list_train_loss.append(epoch_train_loss/len(data_train))
    
    model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for i, batch in enumertae(data_val):
            one_hots = torch.tensor(np.expand_dims(batch[0], axis=1),
                                dtype=torch.float,
                                device=args.device)
            logPs = torch.tensor(batch[1],
                                 dtype=torch.float,
                                 args.device)
            logPs = logPs.view(-1, 1)
            
            pred_logPs = model(one_hots)
            val_loss = args.criterion(pred_logPs, logPs)
            epoch_val_loss += val_loss.item()
            
            print("Epoch: ", epoch, "\tbatch: ", i, "\tValidating")
    
    list_val_loss.append(epoch_val_loss/len(data_val))
    
    sys.stdout.write("\r["+"="*(i//increment)+" "*((args.epoch-i)//increment)+"]"+str(i/point)+"%")
    sys.stdout.flush()
    

data_test = DataLoader(args.dict_partition['test'],
                      batch_size=args.batch_size,
                      shuffle=args.shuffle)
    
model.eval()
with torch.no_grad():
    list_logP = list()
    list_pred_logP = list()
    for i, batch in enumerate(data_test):
        one_hots = torch.tensor(np.expand_dims(batch[0], axis=1),
                                dtype=torch.float,
                                device=args.device)
        logPs = torch.tensor(batch[1],
                             dtype=torch.float,
                             device=args.device)
        logPs = logPs.view(-1, 1)
        
        pred_logPs = model(one_hots)
        
        list_logP += torch.squeeze(toxs).tolist()
        list_pred_logP += pred_logPs.tolist()
        
    acc = accuracy_score(list_logP, list_pred_logP)

SyntaxError: positional argument follows keyword argument (<ipython-input-16-45bc14d39670>, line 35)

In [112]:
data_train = DataLoader(args.dict_partition['train'], 
                        batch_size=args.batch_size,
                        shuffle=args.shuffle)
one_hots = list()
logPs = list()
for i, batch in enumerate(data_train):
    if i==0:
        one_hots = torch.tensor(np.expand_dims(batch[0], axis=1),
                                dtype=torch.float,
                                device=args.device)
        logPs = torch.tensor(batch[1],
                             dtype=torch.float,
                             device=args.device)

In [113]:
one_hots.shape, logPs.shape

(torch.Size([100, 1, 32, 120]), torch.Size([100]))