In [None]:
# Reload edited Python modules automatically before executing each cell.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import sys
import time
import copy

import numpy as np

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem.Crippen import MolLogP

from sklearn.metrics import mean_absolute_error

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from tqdm import tnrange, tqdm_notebook
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
# Experiment configuration. The notebook has no command line to parse, so the
# namespace is built directly instead of via ArgumentParser().parse_args("").
args = argparse.Namespace(
    seed=123,        # seeds the numpy and torch RNGs
    val_size=0.1,    # fraction of molecules used for validation (see partition)
    test_size=0.1,   # fraction of molecules used for testing (see partition)
    shuffle=True,    # passed to the training DataLoader
)

In [4]:
# Seed the numpy and torch random number generators from the configured seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# New float tensors are created on the GPU when CUDA is available, otherwise on the CPU.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch releases
# (torch.set_default_dtype / torch.set_default_device replace it) — confirm the
# installed version before upgrading.
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if torch.cuda.is_available() else 'torch.FloatTensor'
)

# 1. Pre-Processing

In this step, SMILES strings are read from the `ZINC.smiles` file and converted into a PyTorch `Dataset` that holds the node-feature matrix, the adjacency matrix, and the logP value of each molecule. Molecules with more than 50 atoms are skipped during the graph conversion. A SMILES string is converted into a feature matrix as follows:
<img src="files/Graph_Generating_Process.png">

In [4]:
def read_ZINC_smiles(file_name, num_mol):
    """Read the first ``num_mol`` SMILES strings from a file and compute their logP.

    Parameters
    ----------
    file_name : str
        Path to a text file with one SMILES string per line (e.g. ``ZINC.smiles``).
    num_mol : int
        Number of lines to read from the start of the file.

    Returns
    -------
    smi_list : list of str
        The whitespace-stripped SMILES strings, in file order.
    logP_list : numpy.ndarray
        Float array of Crippen logP values (RDKit ``MolLogP``), aligned with ``smi_list``.

    Raises
    ------
    ValueError
        If the file has fewer than ``num_mol`` lines, or a SMILES string cannot
        be parsed by RDKit.
    """
    from itertools import islice

    # Read only the lines that are needed; the context manager closes the file
    # even if something below fails.
    with open(file_name, 'r') as f:
        contents = list(islice(f, max(num_mol, 0)))
    if len(contents) < num_mol:
        raise ValueError('{} contains only {} lines, fewer than num_mol={}'.format(
            file_name, len(contents), num_mol))

    smi_list = []
    logP_list = []
    for i in tqdm_notebook(range(num_mol), desc='Reading Data'):
        smi = contents[i].strip()
        mol = Chem.MolFromSmiles(smi)
        # MolFromSmiles returns None for invalid input; fail with a clear message
        # instead of an obscure error from MolLogP(None).
        if mol is None:
            raise ValueError('could not parse SMILES on line {} of {}: {!r}'.format(
                i + 1, file_name, smi))
        smi_list.append(smi)
        logP_list.append(MolLogP(mol))

    return smi_list, np.asarray(logP_list, dtype=float)


def smiles_to_onehot(smi_list, vocab_path='./vocab.npy', max_length=120):
    """One-hot encode SMILES strings character by character.

    Each string is right-padded with spaces to ``max_length``, and every
    character is mapped to its first position in the vocabulary loaded from
    ``vocab_path``.

    Parameters
    ----------
    smi_list : iterable of str
        SMILES strings to encode.
    vocab_path : str, optional
        ``.npy`` file holding the vocabulary. Every SMILES character, and the
        padding space ``" "``, must appear in it. Defaults to ``'./vocab.npy'``.
    max_length : int, optional
        Length every string is padded to. Defaults to 120.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(len(smi_list), len(vocab), max_length)``
        (an empty array when ``smi_list`` is empty).

    Raises
    ------
    ValueError
        If a SMILES string is longer than ``max_length`` or contains a
        character that is not in the vocabulary.
    """
    # Build the token -> row lookup once, instead of rebuilding list(vocab) and
    # scanning it with list.index() for every character of every molecule.
    # setdefault keeps the first occurrence, matching list.index() semantics.
    vocab = [str(token) for token in np.load(vocab_path)]
    token_to_row = {}
    for row, token in enumerate(vocab):
        token_to_row.setdefault(token, row)

    def smiles_to_vector(smiles):
        # Without this check an over-long string fails later with an opaque IndexError.
        if len(smiles) > max_length:
            raise ValueError('SMILES {!r} has {} characters, more than max_length={}'.format(
                smiles, len(smiles), max_length))
        padded = smiles.ljust(max_length)
        try:
            rows = [token_to_row[ch] for ch in padded]
        except KeyError as exc:
            raise ValueError('character {!r} in SMILES {!r} is not in the vocabulary'.format(
                exc.args[0], smiles)) from exc
        one_hot = np.zeros((len(vocab), max_length), dtype=int)
        one_hot[rows, np.arange(max_length)] = 1
        return one_hot

    smi_total = [smiles_to_vector(smi)
                 for smi in tqdm_notebook(smi_list, desc='Converting to One Hot')]
    return np.asarray(smi_total)

def convert_to_graph(smiles_list, max_num_atoms=50, num_features=58):
    """Convert SMILES strings into zero-padded node-feature and adjacency matrices.

    Parameters
    ----------
    smiles_list : sequence of str
        SMILES strings (surrounding whitespace is stripped).
    max_num_atoms : int, optional
        Padding size for both matrices. Molecules with more atoms are skipped.
        Defaults to 50.
    num_features : int, optional
        Number of columns of the node-feature matrix. It must equal the length
        of the vector returned by ``atom_feature`` (58). Defaults to 58.

    Returns
    -------
    features : numpy.ndarray
        Array of shape ``(n_kept, max_num_atoms, num_features)``. Rows past a
        molecule's atom count are zero.
    adj : list of numpy.ndarray
        One ``(max_num_atoms, max_num_atoms)`` matrix per kept molecule: the
        RDKit adjacency matrix plus the identity (self-loops), zero-padded.

    Raises
    ------
    ValueError
        If RDKit cannot parse a SMILES string.

    Warns
    -----
    UserWarning
        When molecules are skipped for exceeding ``max_num_atoms``. The outputs
        are then shorter than ``smiles_list`` and no longer index-aligned with
        any per-molecule labels.
    """
    import warnings

    features = []
    adj = []
    skipped = []
    for idx, smi in enumerate(tqdm_notebook(smiles_list, desc='Converting to Graph')):
        mol = Chem.MolFromSmiles(smi.strip())
        if mol is None:
            raise ValueError('could not parse SMILES at index {}: {!r}'.format(idx, smi))
        adj_tmp = Chem.rdmolops.GetAdjacencyMatrix(mol)
        num_atoms = adj_tmp.shape[0]
        if num_atoms > max_num_atoms:
            skipped.append(idx)
            continue

        # Node features, zero-padded up to max_num_atoms rows.
        feature = np.zeros((max_num_atoms, num_features))
        if num_atoms:  # an empty molecule would fail the broadcast assignment
            feature[:num_atoms, :num_features] = [atom_feature(atom) for atom in mol.GetAtoms()]
        features.append(feature)

        # Adjacency with self-loops (identity added), zero-padded.
        adj_padded = np.zeros((max_num_atoms, max_num_atoms))
        adj_padded[:num_atoms, :num_atoms] = adj_tmp + np.eye(num_atoms)
        adj.append(adj_padded)

    if skipped:
        # Dropping molecules silently would shift labels relative to the graphs.
        warnings.warn(
            '{} molecule(s) with more than {} atoms were skipped (first input indices: {}); '
            'the returned arrays are no longer aligned with the input list'.format(
                len(skipped), max_num_atoms, skipped[:10]))

    return np.asarray(features), adj
    
def atom_feature(atom):
    """Build the 58-element boolean feature vector of an RDKit atom.

    Segments, in order (length in brackets):
      * element symbol [40]: symbols outside the list map to the last entry ('Pb')
      * degree [6]: raises if the degree is not in 0..5
      * total number of hydrogens [5]: values outside 0..4 map to 4
      * implicit valence [6]: values outside 0..5 map to 5
      * aromaticity flag [1]
    """
    element_symbols = [
        'C', 'N', 'O', 'S', 'F', 'H', 'Si', 'P', 'Cl', 'Br',
        'Li', 'Na', 'K', 'Mg', 'Ca', 'Fe', 'As', 'Al', 'I', 'B',
        'V', 'Tl', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn',
        'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'Mn', 'Cr', 'Pt', 'Hg', 'Pb',
    ]
    bits = one_of_k_encoding_unk(atom.GetSymbol(), element_symbols)
    bits += one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
    bits += one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
    bits += one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5])
    bits.append(atom.GetIsAromatic())
    return np.array(bits)

def one_of_k_encoding(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set``, rejecting unknown values.

    Parameters
    ----------
    x : object
        Value to encode.
    allowable_set : sequence
        Ordered collection of permitted values.

    Returns
    -------
    list of bool
        ``[x == s for s in allowable_set]``.

    Raises
    ------
    ValueError
        If ``x`` is not in ``allowable_set``. ValueError subclasses Exception,
        so existing ``except Exception`` handlers still catch it.
    """
    if x not in allowable_set:
        raise ValueError("input {0} not in allowable set {1}".format(x, allowable_set))
    return [x == s for s in allowable_set]

def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x``, sending values outside ``allowable_set`` to its last slot.

    Returns a list of booleans, one per element of ``allowable_set``.
    """
    key = x if x in allowable_set else allowable_set[-1]
    return [key == option for option in allowable_set]


class GCNDataset(Dataset):
    """Map-style dataset returning ``(feature, adjacency, logP)`` triples by index.

    The dataset length is taken from ``list_feature`` alone; the three
    sequences are not checked for equal length here.
    """

    def __init__(self, list_feature, list_adj, list_logP):
        self.list_feature, self.list_adj, self.list_logP = list_feature, list_adj, list_logP

    def __len__(self):
        return len(self.list_feature)

    def __getitem__(self, index):
        feature = self.list_feature[index]
        adjacency = self.list_adj[index]
        logP = self.list_logP[index]
        return feature, adjacency, logP


def partition(list_feature, list_adj, list_logP, args):
    """Split aligned features, adjacency matrices and logP values into GCNDatasets.

    The split follows input order (no shuffling). The first ``num_train`` items
    go to train, the next ``num_val`` to validation, and the last ``num_test``
    to test, where ``num_val = int(n * args.val_size)`` and
    ``num_test = int(n * args.test_size)``. Train receives every remaining
    item, so float rounding never leaves samples unused.

    Parameters
    ----------
    list_feature, list_adj, list_logP : sequence
        Per-molecule node features, adjacency matrices and logP labels. They
        must have equal lengths and be index-aligned.
    args : argparse.Namespace
        Must provide ``val_size`` and ``test_size`` as fractions of the data.

    Returns
    -------
    dict
        ``{'train': GCNDataset, 'val': GCNDataset, 'test': GCNDataset}``.

    Raises
    ------
    ValueError
        If the inputs differ in length (e.g. because ``convert_to_graph``
        skipped molecules above its atom limit), or if ``val_size + test_size``
        leaves a negative number of training samples.
    """
    num_total = len(list_feature)
    if not (len(list_adj) == num_total == len(list_logP)):
        raise ValueError(
            'inputs must have equal lengths, got {} features, {} adjacency matrices '
            'and {} logP values'.format(num_total, len(list_adj), len(list_logP)))

    num_val = int(num_total * args.val_size)
    num_test = int(num_total * args.test_size)
    # Train takes whatever is left; computing it separately with int() could
    # leave a gap of unused samples between the validation and test slices.
    num_train = num_total - num_val - num_test
    if num_train < 0:
        raise ValueError('val_size ({}) + test_size ({}) exceed the available data'.format(
            args.val_size, args.test_size))

    val_end = num_train + num_val
    splits = {
        'train': GCNDataset(list_feature[:num_train],
                            list_adj[:num_train],
                            list_logP[:num_train]),
        'val': GCNDataset(list_feature[num_train:val_end],
                          list_adj[num_train:val_end],
                          list_logP[num_train:val_end]),
        'test': GCNDataset(list_feature[val_end:],
                           list_adj[val_end:],
                           list_logP[val_end:]),
    }
    return splits

In [5]:
NUM_MOLECULES = 1000  # number of SMILES lines read from ZINC.smiles

list_smi, list_logP = read_ZINC_smiles('ZINC.smiles', NUM_MOLECULES)
list_feature, list_adj = convert_to_graph(list_smi)
# convert_to_graph drops molecules above its atom limit without touching the
# labels; stop here rather than partition graphs and logP values that are shifted.
assert len(list_feature) == len(list_adj) == len(list_logP), \
    'graph features ({}) and logP labels ({}) are misaligned'.format(len(list_feature), len(list_logP))
dict_partition = partition(list_feature, list_adj, list_logP, args)

HBox(children=(IntProgress(value=0, description='Reading Data', max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, description='Converting to Graph', max=1000), HTML(value='')))




# 2. Model Construction

The message passing neural network (MPNN) framework updates each node with the following formula:

$$ H_i^{(l+1)} = U\left(H_i^{(l)}, m_i^{(l+1)}\right) $$

The $i$-th node is updated from its previous state, $H_i^{(l)}$, and the message $m_i^{(l+1)}$ aggregated from its adjacent nodes. In general, the message is computed as follows:

$$ m_i^{(l+1)} = \sum_{j \in N_i} M\left(H_i^{(l)}, H_j^{(l)}, e_{ij}\right) $$

Using the initial edge information, $e_{ij}$, messages can be computed differently for different relations, such as single, double, and aromatic bonds.

In this GGNN framework, only the connectivity between nodes is considered. For simplicity, the message is defined as the plain sum of the adjacent node states:

$$ m_i^{(l+1)} = \sum_{j \in N_i} H_j^{(l)} $$

(Because the adjacency matrices built during pre-processing include self-loops, $N_i$ in practice also contains node $i$ itself.)

A gated recurrent unit (GRU) is used as the update function $U$. It takes the message as its input and the previous node state as its hidden state:

$$ H_i^{(l+1)} = \mathrm{GRU}\left(H_i^{(l)}, \sum_{j \in N_i} H_j^{(l)}\right) $$

In [None]:
class GatedSkipConnection(nn.Module):
    """Gated residual connection that mixes a block's input with its output.

    Computes ``z * out_x + (1 - z) * in_x`` with the gate
    ``z = sigmoid(linear_coef_in(in_x) + linear_coef_out(out_x))``.
    When ``in_dim != out_dim``, ``in_x`` is first projected to ``out_dim`` by a
    bias-free linear layer.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # Created even when in_dim == out_dim, in which case forward() never uses it.
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        self.linear_coef_in = nn.Linear(out_dim, out_dim)
        self.linear_coef_out = nn.Linear(out_dim, out_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, in_x, out_x):
        needs_projection = self.in_dim != self.out_dim
        residual = self.linear(in_x) if needs_projection else in_x
        gate = self.gate_coefficient(residual, out_x)
        return gate * out_x + (1.0 - gate) * residual

    def gate_coefficient(self, in_x, out_x):
        """Return the sigmoid gate computed from the (projected) input and the output."""
        return self.sigmoid(self.linear_coef_in(in_x) + self.linear_coef_out(out_x))

In [44]:
class GGNNLayer(nn.Module):
    """One gated graph neural network (GGNN) propagation step.

    Every node state is updated by a GRU cell whose input is the aggregated
    message ``m = adj @ x`` (sum of neighbour states) and whose hidden state is
    the node's current state: ``H_i' = GRUCell(m_i, H_i)``.

    Parameters
    ----------
    in_dim : int
        Size of the incoming node features. If it differs from ``out_dim``,
        the features are first projected by a bias-free linear layer.
    out_dim : int
        Size of the node states produced by the layer.
    n_layer, dropout, bidirectional :
        Only used to build ``self.gru``, which ``forward`` does not use.
    """

    def __init__(self, in_dim, out_dim, n_layer=1, dropout=0, bidirectional=False):
        super(GGNNLayer, self).__init__()

        self.linear = None
        if in_dim != out_dim:
            # Created with the default dtype, like the GRU cell below, so that
            # the two always agree.
            self.linear = nn.Linear(in_dim, out_dim, bias=False)
        # NOTE(review): this nn.GRU is never used in forward(). It is kept so the
        # constructor arguments, state_dict keys and module repr stay unchanged,
        # but its parameters are still registered (and handed to any optimizer
        # built from model.parameters()). Remove it if checkpoint compatibility
        # does not matter.
        self.gru = nn.GRU(input_size=in_dim,
                          hidden_size=out_dim,
                          num_layers=n_layer,
                          dropout=dropout,
                          bidirectional=bidirectional)
        self.gruCell = nn.GRUCell(input_size=out_dim,
                                  hidden_size=out_dim)

    def forward(self, x, adj):
        """Update all node states once.

        Parameters
        ----------
        x : torch.Tensor
            Node states of shape ``(batch, num_nodes, in_dim)``.
        adj : torch.Tensor
            Adjacency matrices of shape ``(batch, num_nodes, num_nodes)``.

        Returns
        -------
        torch.Tensor
            New node states of shape ``(batch, num_nodes, out_dim)``. The batch
            axis is kept even when ``batch == 1``.
        """
        if self.linear is not None:
            x = self.linear(x)
        batch_size, num_nodes, dim = x.shape

        # Message for each node: sum of the states of its adjacent nodes
        # (including itself when adj carries self-loops).
        m = torch.matmul(adj, x)

        # GRUCell processes rows independently, so every node of every molecule
        # is updated in one call instead of one Python-level call per node.
        # Flattening also replaces the former .squeeze(), which dropped the batch
        # axis when batch_size == 1 and produced a wrongly shaped output.
        h_new = self.gruCell(m.reshape(-1, dim), x.reshape(-1, dim))
        return h_new.reshape(batch_size, num_nodes, dim)

In [None]:
class Predictor(nn.Module):
    """Linear layer with Xavier-uniform weight initialisation and an optional activation.

    Parameters
    ----------
    in_dim : int
        Input feature size.
    out_dim : int
        Output feature size.
    act : callable or None, optional
        Applied to the linear output when not None.
    """

    def __init__(self, in_dim, out_dim, act=None):
        super(Predictor, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim

        self.linear = nn.Linear(self.in_dim,
                                self.out_dim)
        nn.init.xavier_uniform_(self.linear.weight)
        self.activation = act

    def forward(self, x):
        out = self.linear(x)
        # Identity check is the idiomatic None test (PEP 8) and cannot be
        # affected by a custom __eq__/__ne__ on the activation object.
        if self.activation is not None:
            out = self.activation(out)
        return out

In [None]:
class GGNNBlock(nn.Module):
    """Stack of ``GGNNLayer`` propagation steps wrapped in a gated skip connection.

    NOTE(review): the original cell was an unfinished stub
    (``def __init__(self, )`` with no body), which is a SyntaxError that breaks
    Restart & Run All. This is a proposed implementation assembled from
    ``GGNNLayer`` and ``GatedSkipConnection``; confirm it matches the intended
    architecture.

    Parameters
    ----------
    n_layer : int
        Number of stacked ``GGNNLayer`` steps (must be >= 1).
    in_dim : int
        Size of the incoming node features.
    hidden_dim : int
        Node-state size between intermediate layers (unused when ``n_layer == 1``).
    out_dim : int
        Node-state size produced by the block.
    """

    def __init__(self, n_layer, in_dim, hidden_dim, out_dim):
        super(GGNNBlock, self).__init__()
        if n_layer < 1:
            raise ValueError('n_layer must be >= 1, got {}'.format(n_layer))
        dims = [in_dim] + [hidden_dim] * (n_layer - 1) + [out_dim]
        self.layers = nn.ModuleList(
            GGNNLayer(dims[k], dims[k + 1]) for k in range(n_layer)
        )
        # Mixes the block input (projected to out_dim when needed) with the stack output.
        self.skip = GatedSkipConnection(in_dim, out_dim)

    def forward(self, x, adj):
        """Run all layers on ``(x, adj)`` and gate the result against the input ``x``."""
        out = x
        for layer in self.layers:
            out = layer(out, adj)
        return self.skip(x, out)

In [45]:
# Model / training hyperparameters, attached to the shared args namespace.
vars(args).update(
    batch_size=128,
    in_dim=58,            # length of the atom feature vector built by atom_feature
    out_dim=64,
    n_layer=1,
    dropout=0,
    bidirectional=False,
)

In [46]:
# Mini-batch loader over the training split; shuffling follows args.shuffle.
data_train = DataLoader(
    dataset=dict_partition['train'],
    batch_size=args.batch_size,
    shuffle=args.shuffle,
)

In [47]:
# Build one GGNN layer and put it on the GPU only when one is available; a bare
# .cuda() fails on CPU-only machines, unlike the CPU fallback in the setup cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GGNNLayer(args.in_dim, args.out_dim, args.n_layer, args.dropout, args.bidirectional).to(device)
model

GGNNLayer(
  (linear): Linear(in_features=58, out_features=64, bias=False)
  (gru): GRU(58, 64)
  (gruCell): GRUCell(64, 64)
)

In [50]:
# Smoke test: push the first training batch through the layer.
# NOTE(review): these assignments overwrite the full-dataset list_feature /
# list_adj / list_logP from the pre-processing cell with a single batch.
device = next(model.parameters()).device
batch = next(iter(data_train))

# The DataLoader's default collation already yields tensors, so torch.as_tensor
# avoids the copy (and warning) of torch.tensor(<tensor>). .to() casts to float32
# and moves to the model's device in one step, without a hard-coded .cuda().
list_feature = torch.as_tensor(batch[0]).to(device=device, dtype=torch.float32)
list_adj = torch.as_tensor(batch[1]).to(device=device, dtype=torch.float32)
list_logP = torch.as_tensor(batch[2]).to(device=device, dtype=torch.float32)

out = model(list_feature, list_adj)

In [51]:
# Inspect the dtype of the GGNN layer output.
out.dtype

torch.float32

In [52]:
# Inspect the dtype of the batch features that were fed to the layer.
list_feature.dtype

torch.float32