In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import copy
import math
import sys
import time


import numpy as np

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm_notebook
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [82]:
paser = argparse.ArgumentParser()
args = paser.parse_args("")
args.seed = 123
args.max_mol = 11100
args.max_peaks = 150
args.max_atoms = 250
args.max_partial_charge = 4.0
args.min_partial_charge = -1.0
args.num_feature = 59
args.val_size = 0.1
args.test_size = 0.1
args.shuffle = True

In [83]:
np.random.seed(args.seed)
torch.manual_seed(args.seed)

if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

# 1. Pre-Processing

In [146]:
# 134 max peaks, 23 mean peaks
# max 209 atoms per molecule
# total 11100 molecules

def read_nmrDB_pickle(file_name, num_mol, max_peak):
    df = pd.read_pickle(file_name).head(num_mol)
    
    inchi_list = df['inchi'].tolist()
    peaks_list = df['peaks'].tolist()
    
    for i, inchi in enumerate(inchi_list):
        if '\n' in inchi:
            inchi_list[i] = inchi.split('\n')[0]
    
    shift_list = list()
    intensity_list = list()
    for peaks in peaks_list:
        shifts = [0] * max_peak
        intensities = [0] * max_peak
        for i, peak in enumerate(peaks):
            shifts[i] = float(peak[1])
            intensities[i] = float(peak[2])
        shift_list.append(shifts)
        intensity_list.append(intensities)
    
    return inchi_list, shift_list, intensity_list

def convert_inchi_to_graph(inchi_list, max_atoms):
    adj = list()
    adj_norm = list()
    features = list()
    for inchi in inchi_list:
        # Generate mol from InChI code.
        iMol = Chem.inchi.MolFromInchi(inchi)
        # Add H atoms to the mol.
        iMol = Chem.rdmolops.AddHs(iMol)
        iAdjTmp = Chem.rdmolops.GetAdjacencyMatrix(iMol)
        if (iAdjTmp.shape[0] <= max_atoms):
            # Preprocess features
            iFeature = np.zeros((max_atoms, num_feature))
            iFeatureTmp = []
            AllChem.ComputeGasteigerCharges(iMol)
            for atom in iMol.GetAtoms():
                iFeatureTmp.append(atom_feature(atom))
            iFeature[0:len(iFeatureTmp), 0:num_feature] = iFeatureTmp
            features.append(iFeature)
            # Preprocess adjacency matrix
            iAdj = np.zeros((max_atoms, max_atoms))
            iAdj[0:len(iFeatureTmp), 0:len(iFeatureTmp)] = iAdjTmp + np.eye(len(iFeatureTmp))
            adj.append(np.asarray(iAdj))
    features = np.asarray(features)
    
    return features, adj

def normalized_partial_charge_of_atom(atom):
    partial_charge = float(atom.GetProp("_GasteigerCharge"))
    partial_charge = (partial_charge-min_partial_charge)/(max_partial_charge-min_partial_charge)
    return partial_charge

def atom_feature(atom):
    return np.array(one_of_k_encoding_unk(atom.GetSymbol(),
                                      ['C', 'N', 'O', 'S', 'F', 'H', 'Si', 'P', 'Cl', 'Br',
                                       'Li', 'Na', 'K', 'Mg', 'Ca', 'Fe', 'As', 'Al', 'I', 'B',
                                       'V', 'Tl', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn',
                                       'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'Mn', 'Cr', 'Pt', 'Hg', 'Pb']) +
                    one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5]) +
                    one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
                    one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5]) +
                    [atom.GetIsAromatic()] +
                    [normalized_partial_charge_of_atom(atom)])    # (40, 6, 5, 6, 1, 1)

def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))

In [147]:
class NMRDataset(Dataset):
    
    def __init__(self, list_feature, list_adj, list_shift, list_intensity):
        self.list_feature = list_feature
        self.list_adj = list_adj
        self.list_shift = list_shift
        self.list_intensity = list_intensity
        
    def __len__(self):
        return len(self.list_feature)
    
    def __getitem__(self, index):
        feature = self.list_feature[index]
        adj = self.list_adj[index]
        shift = self.list_shift[index]
        intensity = self.list_intensity[index]
        return feature, adj, shift, intensity
    
def partition(list_feature, list_adj, list_shift, list_intensity, args):
    num_total = list_feature.shape[0]
    num_train = int(num_total * (1 - args.test_size - args.val_size))
    num_val = int(num_total * args.val_size)
    num_test = int(num_total * args.test_size)
    
    feature_train = list_feature[:num_train]
    adj_train = list_adj[:num_train]
    shift_train = list_shift[:num_train]
    intensity_train = list_intensity[:num_train]
    feature_val = list_feature[num_train:num_train + num_val]
    adj_val = list_adj[num_train:num_train + num_val]
    shift_val = list_shift[num_train:num_train + num_val]
    intensity_val = list_intensity[num_train:num_train + num_val]
    feature_test = list_feature[num_total - num_test:]
    adj_test = list_adj[num_total - num_test:]
    shift_test = list_shift[num_total - num_test:]
    intensity_test = list_intensity[num_total - num_test:]
    
    train_set = NMRDataset(feature_train, adj_train, shift_train, intensity_train)
    val_set = NMRDataset(feature_val, adj_val, shift_val, intensity_val)
    test_set = NMRDataset(feature_test, adj_test, shift_test, intensity_test)
    
    partition = {
        'train': train_set,
        'val': val_set,
        'test': test_set
    }

    return partition

In [148]:
list_inchi, list_shift, list_intensity = read_nmrDB_pickle('nmrDB_deduplicated.pkl', 1000, args.max_peaks)
list_feature, list_adj = convert_inchi_to_graph(list_inchi, args.max_atoms)
dict_partition = partition(list_feature, list_adj, list_shift, list_intensity, args)

# 2. Model Construction

In [89]:
class SkipConnection(nn.Module):
    
    def __init__(self, in_dim, out_dim):
        super(SkipConnection, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        
    def forward(self, in_x, out_x):
        if (self.in_dim != self.out_dim):
            in_x = self.linear(in_x)
        out = in_x + out_x
        return out

In [90]:
class GatedSkipConnection(nn.Module):
    
    def __init__(self, in_dim, out_dim):
        super(GatedSkipConnection, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        self.linear_coef_in = nn.Linear(out_dim, out_dim)
        self.linear_coef_out = nn.Linear(out_dim, out_dim)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, in_x, out_x):
        if (self.in_dim != self.out_dim):
            in_x = self.linear(in_x)
        z = self.gate_coefficient(in_x, out_x)
        out = torch.mul(z, out_x) + torch.mul(1.0-z, in_x)
        return out
            
    def gate_coefficient(self, in_x, out_x):
        x1 = self.linear_coef_in(in_x)
        x2 = self.linear_coef_out(out_x)
        return self.sigmoid(x1+x2)

In [91]:
class Attention(nn.Module):
    
    def __init__(self, in_dim, output_dim, num_head):
        super(Attention, self).__init__()
        
        self.num_head = num_head
        self.atn_dim = output_dim // num_head
        
        self.linears = nn.ModuleList()
        self.corelations = nn.ParameterList()
        for i in range(self.num_head):
            self.linears.append(nn.Linear(in_dim, self.atn_dim))
            corelation = torch.FloatTensor(self.atn_dim, self.atn_dim)
            nn.init.xavier_uniform_(corelation)
            self.corelations.append(nn.Parameter(corelation))
            
        self.tanh = nn.Tanh()
        
    def forward(self, x, adj):
        heads = list()
        for i in range(self.num_head):
            x_transformed = self.linears[i](x)
            alpha = self.attention_matrix(x_transformed, self.corelations[i], adj)
            x_head = torch.matmul(alpha, x_transformed)
            heads.append(x_head)
        output = torch.cat(heads, dim=2)
        return output
            
    def attention_matrix(self, x_transformed, corelation, adj):
        x = torch.einsum('akj,ij->aki', (x_transformed, corelation))
        alpha = torch.matmul(x, torch.transpose(x_transformed, 1, 2))
        alpha = torch.mul(alpha, adj)
        alpha = self.tanh(alpha)
        return alpha

In [92]:
class GCNLayer(nn.Module):
    
    def __init__(self, in_dim, out_dim, n_atom, act=None, bn=False, atn=False, num_head=1, dropout=0):
        super(GCNLayer, self).__init__()
        
        self.use_bn = bn
        self.use_atn = atn
        self.linear = nn.Linear(in_dim, out_dim)
        nn.init.xavier_uniform_(self.linear.weight)
        self.bn = nn.BatchNorm1d(n_atom)
        self.attention = Attention(out_dim, out_dim, num_head)
        self.activation = act
        self.dropout_rate = dropout
        self.dropout = nn.Dropout2d(self.dropout_rate)
        
    def forward(self, x, adj):
        out = self.linear(x)
        if self.use_atn:
            out = self.attention(out, adj)
        else:
            out = torch.matmul(adj, out)
        if self.use_bn:
            out = self.bn(out)
        if self.activation != None:
            out = self.activation(out)
        if self.dropout_rate > 0:
            out = self.dropout(out)
        return out, adj

In [94]:
class GCNBlock(nn.Module):
    
    def __init__(self, n_layer, in_dim, hidden_dim, out_dim, n_atom, bn=True, atn=True, num_head=1, sc='gsc', dropout=0):
        super(GCNBlock, self).__init__()
        
        self.layers = nn.ModuleList()
        for i in range(n_layer):
            self.layers.append(GCNLayer(in_dim if i==0 else hidden_dim,
                                        out_dim if i==n_layer-1 else hidden_dim,
                                        n_atom,
                                        nn.ReLU() if i!=n_layer-1 else None,
                                        bn,
                                        atn,
                                        num_head,
                                        dropout))
        self.relu = nn.ReLU()
        if sc=='gsc':
            self.sc = GatedSkipConnection(in_dim, out_dim)
        elif sc=='sc':
            self.sc = SkipConnection(in_dim, out_dim)
        elif sc=='no':
            self.sc = None
        else:
            assert False, "Wrong sc type."
        
    def forward(self, x, adj):
        residual = x
        for i, layer in enumerate(self.layers):
            out, adj = layer((x if i==0 else out), adj)
        if self.sc != None:
            out = self.sc(residual, out)
        out = self.relu(out)
        return out, adj

In [95]:
class ReadOut(nn.Module):
    
    def __init__(self, in_dim, out_dim, act=None):
        super(ReadOut, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim= out_dim
        
        self.linear = nn.Linear(self.in_dim, 
                                self.out_dim)
        nn.init.xavier_uniform_(self.linear.weight)
        self.activation = act

    def forward(self, x):
        out = self.linear(x)
        out = torch.sum(out, 1)
        if self.activation != None:
            out = self.activation(out)
        return out

In [96]:
class Predictor(nn.Module):
    
    def __init__(self, in_dim, out_dim, act=None):
        super(Predictor, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        self.linear = nn.Linear(self.in_dim,
                                self.out_dim)
        nn.init.xavier_uniform_(self.linear.weight)
        self.activation = act
        
    def forward(self, x):
        out = self.linear(x)
        if self.activation != None:
            out = self.activation(out)
        return out

In [120]:
class GCNNet(nn.Module):
    
    def __init__(self, args):
        super(GCNNet, self).__init__()
        
        self.blocks = nn.ModuleList()
        for i in range(args.n_block):
            self.blocks.append(GCNBlock(args.n_layer,
                                        args.in_dim if i==0 else args.hidden_dim,
                                        args.hidden_dim,
                                        args.hidden_dim,
                                        args.n_atom,
                                        args.bn,
                                        args.atn,
                                        args.num_head,
                                        args.sc,
                                        args.dropout))
        self.readout = ReadOut(args.hidden_dim, 
                               args.readout_dim,
                               act=nn.ReLU())
        self.shift_pred1 = Predictor(args.readout_dim,
                               args.shift_pred_dim1,
                               act=nn.ReLU())
        self.shift_pred2 = Predictor(args.shift_pred_dim1,
                               args.shift_pred_dim2,
                               act=nn.ReLU())
        self.shift_pred3 = Predictor(args.shift_pred_dim2,
                               args.out_dim)
        self.intensity_pred1 = Predictor(args.readout_dim,
                                         args.intensity_pred_dim1,
                                         act=nn.ReLU())
        self.intensity_pred2 = Predictor(args.intensity_pred_dim1,
                                         args.intensity_pred_dim2,
                                         act=nn.ReLU())
        self.intensity_pred3 = Predictor(args.intensity_pred_dim2,
                                         args.out_dim)
        
    def forward(self, x, adj):
        for i, block in enumerate(self.blocks):
            out, adj = block((x if i==0 else out), adj)
        out = self.readout(out)
        shift_out = self.shift_pred1(out)
        shift_out = self.shift_pred2(shift_out)
        shift_out = self.shift_pred3(shift_out)
        intensity_out = self.intensity_pred1(out)
        intensity_out = self.intensity_pred2(intensity_out)
        intensity_out = self.intensity_pred3(intensity_out)
        return shift_out, intensity_out

# 3. Train, Validate, and Test

In [172]:
def train(model, device, optimizer, criterion, data_train, bar, args):
    epoch_train_loss = 0
    for i, batch in enumerate(data_train):
        # [batch_size, max_atom, num_feature], [100, 250, 59]
        list_feature = torch.tensor(batch[0]).to(device).float()
        list_adj = torch.tensor(batch[1]).to(device).float()
        # [batch_size, max_peak]
        # [150, 100]
        print(len(batch[2]), len(batch[2][0]))
        list_shift = torch.tensor(batch[2]).to(device)
        list_intensity = torch.tensor(batch[3]).to(device)
        list_shift = list_shift.view(-1,1)
        list_intensity = list_intensity.view(-1,1)        
        
        model.train()
        optimizer.zero_grad()
        
        list_pred_shift, list_pred_intensity = model(list_feature, list_adj)
        list_pred_shift.require_grad = False
        list_pred_intensity.require_grad = False
        
        train_shift_loss = criterion(list_pred_shift, list_shift)
        train_intensity_loss = criterion(list_pred_intensity, list_intensity)
        
        train_loss = train_shift_loss + train_intensity_loss
        epoch_train_loss += train_loss.item()
        
        train_loss.backward()
        optimizer.step()
       
        bar.update(len(list_feature))

    epoch_train_loss /= len(data_train)
    
    return model, epoch_train_loss

In [161]:
def validate(model, device, optimizer, criterion, data_val, bar, args):
    epoch_val_loss = 0
    for i, batch in enumerate(data_val):
        list_feature = torch.tensor(batch[0]).to(device).float()
        list_adj = torch.tensor(batch[1]).to(device).float()
        list_shift = torch.tensor(batch[2]).to(device).float()
        list_intensity = torch.tensor(batch[3]).to(device).float()
        list_shift = list_shift.view(-1,1)
        list_intensity = list_intensity.view(-1,1)        
        
        model.eval()
        
        list_pred_shift, list_pred_intensity = model(list_feature, list_adj)
        list_pred_shift.require_grad = False
        list_pred_intensity.require_grad = False
        
        val_shift_loss = criterion(list_pred_shift, list_shift)
        val_intensity_loss = criterion(list_pred_intensity, list_intensity)
        
        val_loss = val_shift_loss + val_intensity_loss
        epoch_val_loss += val_loss.item()
       
        bar.update(len(list_feature))

    epoch_val_loss /= len(data_train)
    
    return model, epoch_val_loss

In [162]:
def test(model, device, data_test, args):
    model.eval()
    with torch.no_grad():
        shift_total = list()
        pred_shift_total = list()
        intensity_total = list()
        pred_intensity_total = list()
        for i, batch in enumerate(data_test):
            list_feature = torch.tensor(batch[0]).to(device).float()
            list_adj = torch.tensor(batch[1]).to(device).float()
            list_shift = torch.tensor(batch[2]).to(device).float()
            list_intensity = torch.tensor(batch[3]).to(device).float()
            shift_total.append(list_shift.tolist())
            intensity_total.append(list_intensity.tolist())
            list_shift = list_shift.view(-1,1)
            list_intensity = list_intensity.view(-1,1)
            
            list_pred_shift, list_pred_intensity = model(list_feature, list_adj)
            pred_shift_total.append(list_pred_shift.view(-1).tolist())
            pred_intensity_total.append(list_pred_intensity.view(-1).tolist())
        
        mae_shift = 0
        mae_intensity = 0
        for i in range(len(shift_total)):
            mae_shift += mean_absolute_error(shift_total[i], pred_shift_total[i])/len(shift_total[i])
            mae_intensity += mean_absolute_error(intensity_total[i], pred_intensity_total[i])/len(intensity_total[i])
        mae_shift /= len(shift_total)
        mae_intensity /= len(intensity_total)
        
    return mae_shift, mae_intensity, shift_total, pred_shift_total, intensity_total, pred_intensity_total

In [104]:
def experiment(dict_partition, device, bar, args):
    time_start = time.time()
    
    model = GCNNet(args)
    model.to(device)
    
    if args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.l2_coef)
    elif args.optim == 'RMSprop':
        optimizer = optim.RMSprop(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.l2_coef)
    elif args.optim == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.l2_coef)
    else:
        assert False, 'Undefined Optimizer Type'
        
    criterion = nn.MSELoss()
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.step_size,
                                          gamma=args.gamma)
    
    list_train_loss = list()
    list_val_loss = list()

    data_train = DataLoader(dict_partition['train'], 
                            batch_size=args.batch_size,
                            shuffle=args.shuffle)

    data_val = DataLoader(dict_partition['val'],
                          batch_size=args.batch_size,
                          shuffle=args.shuffle)
    
    for epoch in range(args.epoch):
        scheduler.step()
        model, train_loss = train(model, device, optimizer, criterion, data_train, bar, args)
        list_train_loss.append(train_loss)
        
        model, val_loss = validate(model, device, criterion, data_val, bar, args)
        list_val_loss.append(val_loss)
        
    data_test = DataLoader(dict_partition['test'],
                           batch_size=args.batch_size,
                           shuffle=args.shuffle)

    mae_shift, mae_intensity, shift_total, pred_shift_total, intensity_total, pred_intensity_total = test(model, device, data_test, args)
    
    time_end = time.time()
    time_required = time_end - time_start
    
    args.list_train_loss = list_train_loss
    args.list_val_loss = list_val_loss
    args.shift_total = shift_total
    args.pred_shift_total = pred_shift_total
    args.intensity_total = intensity_total
    args.pred_intensity_total = pred_intensity_total
    args.mae_shift = mae_shift
    args.mae_intensity = mae_intensity
    args.time_required = time_required
    
    return args

In [149]:
args.batch_size = 100
args.lr = 0.0001
args.l2_coef = 0
args.optim = 'Adam'
args.epoch = 30
args.n_block = 2
args.n_layer = 2
args.n_atom = args.max_atoms
args.in_dim = args.num_feature
args.hidden_dim = 64
args.readout_dim = 256
args.shift_pred_dim1 = 256
args.shift_pred_dim2 = 128
args.intensity_pred_dim1 = 256
args.intensity_pred_dim2 = 128
args.out_dim = args.max_peaks
args.bn = True
args.sc = 'no'
args.atn = False
args.num_head = 16
args.dropout = 0
args.step_size = 10
args.gamma = 0.1

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [171]:
dict_result = dict()
n_iter = args.epoch*(len(dict_partition['train'])+len(dict_partition['val']))
bar = tqdm_notebook(total=n_iter, file=sys.stdout, position=0)

args.exp_name = "result"
result = vars(experiment(dict_partition, device, bar, args))
dict_result[args.exp_name] = copy.deepcopy(result)

torch.cuda.empty_cache()
bar.close()

df_result = pd.DataFrame(dict_result).transpose()
df_result.to_json('result.JSON', orient='table')

HBox(children=(IntProgress(value=0, max=27000), HTML(value='')))

100 250
150 100


ValueError: only one element tensors can be converted to Python scalars