In [1]:
import os

In [9]:
import numpy as np
import pandas as pd

In [46]:
from ipywidgets import FloatProgress
from IPython.display import display

Widgets to display progress bars.

In [53]:
from collections import OrderedDict

Two OrderedDicts will represent the molecule

In [54]:

import deepchem as dc
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from deepchem.feat.graph_features import *

In [55]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [56]:
N = 133885 # number of molecules in the dataset
D = 75     # hidden dimension of each atom
E = 6      # dimension of each edge
T = 3      # number of time steps the message phase will run for
P = 32     # dimensions of the output from the readout phase, the penultimate output before the target layer
V = 12     # dimensions of the molecular targets or tasks

N = TRAIN_SIZE + VALID_SIZE + TEST_SIZE + delta

The delta can be incorporated by having the final batch size smaller than the rest.

In [57]:
TRAIN_SIZE = 113880
VALID_SIZE = 10000
TEST_SIZE  = 10000
BATCH_SIZE = 20
NUM_EPOCHS = 7

LR is the initial learning rate.
DF is the decay factor.
LF is the final learning rate.

In [58]:
DF = np.random.uniform(0.01, 1)
LR = np.random.uniform(1e-5, 5e-4)
LF = DF * LR

In [59]:
print('decay factor          : %.6f'%(DF))
print('initial learning rate : %.6f'%(LR))
print('final learning rate   : %.6f'%(LF))

decay factor          : 0.633189
initial learning rate : 0.000223
final learning rate   : 0.000141


Get the qm9.csv file from this url : http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/qm9.csv

In [60]:
qm9 = pd.read_csv('qm9.csv')

In [61]:
qm9.head()

Unnamed: 0,mol_id,smiles,A,B,C,mu,alpha,homo,lumo,gap,...,zpve,u0,u298,h298,g298,cv,u0_atom,u298_atom,h298_atom,g298_atom
0,gdb_1,C,157.7118,157.70997,157.70699,0.0,13.21,-0.3877,0.1171,0.5048,...,0.044749,-40.47893,-40.476062,-40.475117,-40.498597,6.469,-395.999595,-398.64329,-401.014647,-372.471772
1,gdb_2,N,293.60975,293.54111,191.39397,1.6256,9.46,-0.257,0.0829,0.3399,...,0.034358,-56.525887,-56.523026,-56.522082,-56.544961,6.316,-276.861363,-278.620271,-280.399259,-259.338802
2,gdb_3,O,799.58812,437.90386,282.94545,1.8511,6.31,-0.2928,0.0687,0.3615,...,0.021375,-76.404702,-76.401867,-76.400922,-76.422349,6.002,-213.087624,-213.974294,-215.159658,-201.407171
3,gdb_4,C#C,0.0,35.610036,35.610036,0.0,16.28,-0.2845,0.0506,0.3351,...,0.026841,-77.308427,-77.305527,-77.304583,-77.327429,8.574,-385.501997,-387.237686,-389.016047,-365.800724
4,gdb_5,C#N,0.0,44.593883,44.593883,2.8937,12.99,-0.3604,0.0191,0.3796,...,0.016601,-93.411888,-93.40937,-93.408425,-93.431246,6.278,-301.820534,-302.906752,-304.091489,-288.720028


The chemical accuracies for various targets are from Table 5 of Gilmer et. al.

In [62]:
chemical_accuracy_dict = {'mu': [0.1],
                          'alpha': [0.1],
                          'homo': [0.043],
                          'lumo': [0.043],
                          'gap': [0.043],
                          'r2': [1.2],
                          'zpve': [0.0012],
                          'u0': [0.043],
                          'u298': [0.043],
                          'h298': [0.043],
                          'g298': [0.043],
                          'cv': [0.50]}

In [63]:
chemical_accuracy = pd.DataFrame(chemical_accuracy_dict)

In [64]:
chemical_accuracy

Unnamed: 0,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv
0,0.1,0.1,0.043,0.043,0.043,1.2,0.0012,0.043,0.043,0.043,0.043,0.5


In [65]:
structures = ['smiles']
tasks = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv']

In [66]:
X = qm9[structures]
y = qm9[tasks]

In [67]:
from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler

In [68]:
y.head()

Unnamed: 0,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv
0,0.0,13.21,-0.3877,0.1171,0.5048,35.3641,0.044749,-40.47893,-40.476062,-40.475117,-40.498597,6.469
1,1.6256,9.46,-0.257,0.0829,0.3399,26.1563,0.034358,-56.525887,-56.523026,-56.522082,-56.544961,6.316
2,1.8511,6.31,-0.2928,0.0687,0.3615,19.0002,0.021375,-76.404702,-76.401867,-76.400922,-76.422349,6.002
3,0.0,16.28,-0.2845,0.0506,0.3351,59.5248,0.026841,-77.308427,-77.305527,-77.304583,-77.327429,8.574
4,2.8937,12.99,-0.3604,0.0191,0.3796,48.7476,0.016601,-93.411888,-93.40937,-93.408425,-93.431246,6.278


In [69]:
y.describe()

Unnamed: 0,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv
count,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0
mean,2.706037,75.191296,-0.239977,0.011124,0.2511,1189.52745,0.148524,-411.543985,-411.535513,-411.534569,-411.577397,31.600676
std,1.530394,8.187793,0.022131,0.046936,0.047519,279.757172,0.033274,40.06023,40.060012,40.060012,40.060741,4.062471
min,0.0,6.31,-0.4286,-0.175,0.0246,19.0002,0.015951,-714.568061,-714.560153,-714.559209,-714.602138,6.002
25%,1.5887,70.38,-0.2525,-0.0238,0.2163,1018.3226,0.125289,-437.913936,-437.905942,-437.904997,-437.947682,28.942
50%,2.5,75.5,-0.241,0.012,0.2494,1147.5858,0.148329,-417.864758,-417.857351,-417.856407,-417.895731,31.555
75%,3.6361,80.52,-0.2287,0.0492,0.2882,1308.8166,0.17115,-387.049166,-387.039746,-387.038802,-387.083279,34.276
max,29.5564,196.62,-0.1017,0.1935,0.6221,3374.7532,0.273944,-40.47893,-40.476062,-40.475117,-40.498597,46.969


In [70]:
scaler = StandardScaler()

In [71]:
y = pd.DataFrame(scaler.fit_transform(y), index=y.index, columns=y.columns)

In [72]:
y.describe()

Unnamed: 0,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv
count,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0,133885.0
mean,6.453449e-17,-4.24569e-16,6.572328e-16,-6.113793000000001e-17,-3.362586e-16,-9.272587e-16,-9.170690000000001e-17,6.249655e-16,-1.379e-15,-1.902069e-16,7.472414e-16,1.148035e-15
std,1.000004,1.000004,1.000004,1.000004,1.000004,1.000004,1.000004,1.000004,1.000004,1.000004,1.000004,1.000004
min,-1.768203,-8.412713,-8.5229,-3.965489,-4.766551,-4.184099,-3.984321,-7.56424,-7.564296,-7.564296,-7.56416,-6.30128
25%,-0.7301006,-0.5876204,-0.5658624,-0.7440738,-0.7323509,-0.6119789,-0.6983094,-0.6582601,-0.6582756,-0.6582756,-0.65826,-0.6544503
50%,-0.1346308,0.03770307,-0.0462376,0.01866871,-0.03578319,-0.1499222,-0.005872165,-0.1577823,-0.1578098,-0.1578098,-0.1577194,-0.01124342
75%,0.6077298,0.6508132,0.5095351,0.8112391,0.7807373,0.4264041,0.6799833,0.611452,0.6114791,0.6114791,0.6114268,0.6585484
max,17.5448,14.83051,6.248001,3.885645,7.807443,7.811181,3.769323,9.262714,9.262624,9.262624,9.262939,3.783013


Detach is not called on batch_mse_loss as it is required for backpropagation

In [73]:
def batch_mse_loss(pred, true):
    return F.mse_loss(pred, true) / BATCH_SIZE

Detach is called on valid_mse_loss as this is to be used only for tracking.

In [74]:
def valid_mse_loss(pred, true):
    return (F.mse_loss(pred, true)).detach() / VALID_SIZE

In [75]:
scale_batch_to_train = BATCH_SIZE / TRAIN_SIZE

The following class defines the master edge from each molecule to the master node.

In [76]:
class MasterEdge(nn.Module):
    
    def __init__(self):
        super(MasterEdge, self).__init__()
        
        self.l1 = nn.Linear(D, P)
        nn.init.kaiming_normal_(self.l1.weight)
        self.l2 = nn.Linear(P, 2*E)
        nn.init.kaiming_normal_(self.l2.weight)
        self.l3 = nn.Linear(2*E, E)
        nn.init.kaiming_normal_(self.l3.weight)
        
    def forward(self, x):
        return F.relu(self.l3(F.relu(self.l2(F.relu(self.l1(x))))))

In [77]:
master_edge_learner = MasterEdge()

The 2 OrderedDicts are G and H.

G maps the index of each atom to a list of tuples, each containing an edge to, and the index of, an adjacent atom.

H maps atom indices to their hidden states.

In [78]:
def construct_multigraph(smile):
    g = OrderedDict({})
    h = OrderedDict({})
    h[-1] = 0
    molecule = Chem.MolFromSmiles(smile)
    for i in range(molecule.GetNumAtoms()):
        atom_i = molecule.GetAtomWithIdx(i)
        atom_i_featurized = dc.feat.graph_features.atom_features(atom_i)
        atom_i_tensorized = torch.FloatTensor(atom_i_featurized).view(1, D)
        h[i] = atom_i_tensorized
        h[-1] += h[i]
        master_edge = master_edge_learner(h[i])
        g.setdefault(i, []).append((master_edge, -1))
        g.setdefault(-1, []).append((master_edge, i))
        for j in range(molecule.GetNumAtoms()):
            bond_ij = molecule.GetBondBetweenAtoms(i, j)
            if bond_ij: # bond_ij is None when there is no bond.
                #atom_j = molecule.GetAtomWithIdx(j)
                #atom_j_featurized = dc.feat.graph_features.atom_features(atom_j)
                #atom_j_tensorized = torch.FloatTensor(atom_j_featurized).view(1, 75)
                bond_ij_featurized = dc.feat.graph_features.bond_features(bond_ij).astype(int)
                bond_ij_tensorized = torch.FloatTensor(bond_ij_featurized).view(1, E)
                g.setdefault(i, []).append((bond_ij_tensorized, j))
    return g, h

The EdgeMappingNeuralNetwork takes as input (1, E) and returns (D, D)

In [79]:
class EdgeMappingNeuralNetwork(nn.Module):
    
    def __init__(self):
        super(EdgeMappingNeuralNetwork, self).__init__()
        
        self.fc1 = nn.Linear(E, D)
        nn.init.kaiming_normal_(self.fc1.weight)
        self.fc2 = nn.Linear(1, D)
        nn.init.kaiming_normal_(self.fc2.weight)
        
    def f1(self, x):
        return F.relu(self.fc1(x))
        
    def f2(self, x):
        return F.relu(self.fc2(x.permute(1, 0)))
        
    def forward(self, x):
        return self.f2(self.f1(x))

The MessagePhase combines the message passing and update functions into a single dictionary comprehension. We use G to iterate over every vertex in the graph. For each vertex v, we process its adjacency list of edge and neighbour tuples. We use A to project edge e_vw to a (D, D) vector and then matrix multiply it with the vertex w's hidden state. These messages are summed at each vertex v and fed as input to a GRUCell with weights tied for each time step. This creates the OrderedDict hT for input hT-1.

In [80]:
class MessagePhase(nn.Module):
    
    def __init__(self):
        super(MessagePhase, self).__init__()
        self.A = EdgeMappingNeuralNetwork()
        self.U = {i:nn.GRUCell(D, D) for i in range(T)}
        
    def forward(self, smile):
        
        g, h = construct_multigraph(smile)
        g0, h0 = construct_multigraph(smile)
        
        for k in range(T):
            h = OrderedDict(
                {
                    v:
                    self.U[k](
                        sum(torch.matmul(h[w], self.A(e_vw)) for e_vw, w in en), 
                        h[v]
                    )
                    for v, en in g.items()
                }
            )
        
        return h, h0

We use the readout function provided in Gilmer et. al.

In [81]:
class Readout(nn.Module):
    
    def __init__(self):
        super(Readout, self).__init__()
        
        self.i1 = nn.Linear(2*D, 2*P)
        nn.init.kaiming_normal_(self.i1.weight)
        self.i2 = nn.Linear(2*P, P)
        nn.init.kaiming_normal_(self.i2.weight)
        
        self.j1 = nn.Linear(D, P)
        nn.init.kaiming_normal_(self.j1.weight)
        
    def i(self, h_v, h0_v):
        return F.relu(self.i2(F.relu(self.i1(torch.cat([h_v, h0_v], dim=1)))))
    
    def j(self, h_v):        
        return F.relu(self.j1(h_v))

    def r(self, h, h0):
        return sum(torch.sigmoid(self.i(h[v], h0[v])) * self.j(h[v]) for v in h.keys())
                      
    def forward(self, h, h0):
        return self.r(h, h0)

Finally, we package all these stages into a single class.

In [82]:
class MPNN(nn.Module):
    
    def __init__(self):
        super(MPNN, self).__init__()
        
        self.M = MessagePhase()
        self.R = Readout()
        
        self.p1 = nn.Linear(P, P)
        nn.init.kaiming_normal_(self.p1.weight)
        self.p2 = nn.Linear(P, P)
        nn.init.kaiming_normal_(self.p2.weight)
        self.p3 = nn.Linear(P, V)
        nn.init.kaiming_normal_(self.p3.weight)
        
    def p(self, ro):
        return F.relu(self.p3(F.relu(self.p2(F.relu(self.p1(ro))))))
    
    def forward(self, smile):
        h, h0 = self.M(smile)
        embed = self.R(h, h0)
        return self.p(embed)

In [83]:
from sklearn.model_selection import train_test_split

In [84]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=143)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=VALID_SIZE, random_state=143)

In [85]:
model = MPNN()
optimizer = optim.Adam(model.parameters(), lr=LR)

In [86]:
for epoch in range(NUM_EPOCHS):
    print("epoch [%d/%d]"%(epoch+1, NUM_EPOCHS))
    train_loss = 0
    train_bar = FloatProgress(min=0, max=TRAIN_SIZE)
    display(train_bar)
    for batch in range(0, TRAIN_SIZE, BATCH_SIZE):
        batch_loss = 0
        optimizer.zero_grad()
        for sample in range(BATCH_SIZE):
            index = sample + batch
            smile = X_train.iloc[index]['smiles']
            y_hat = model(smile)
            y_tru = torch.Tensor(y_train.iloc[index].values.reshape(1, V))
            batch_loss += batch_mse_loss(y_hat, y_tru)
            train_bar.value += 1
        train_loss += (batch_loss * scale_batch_to_train).detach()
        batch_loss.backward()
        optimizer.step()
    valid_loss = 0
    accu_check = 0
    valid_bar = FloatProgress(min=0, max=VALID_SIZE)
    display(valid_bar)
    for sample in range(VALID_SIZE):
        index = sample
        smile = X_val.iloc[index]['smiles']
        y_hat = model(smile)
        y_tru = torch.Tensor(y_val.iloc[index].values.reshape(1, V))
        valid_loss += valid_mse_loss(y_hat, y_tru)
        accu_check += np.abs(scaler.inverse_transform(y_hat.detach()) - \
                             scaler.inverse_transform(y_tru.detach())) / VALID_SIZE
        valid_bar.value += 1
    print('train_loss [%4.2f]'%(train_loss.item()))
    print('valid_loss [%4.2f]'%(valid_loss.item()))
    print(accu_check/chemical_accuracy)    

epoch [1/7]


FloatProgress(value=0.0, max=113880.0)

FloatProgress(value=0.0, max=10000.0)

train_loss [0.93]
valid_loss [0.85]
          mu      alpha      homo     lumo      gap         r2       zpve  \
0  12.021016  63.404932  0.347198  0.90254  0.92222  128.51059  13.500393   

           u0        u298       h298        g298        cv  
0  528.105935  723.636051  723.69611  523.347278  6.418866  
epoch [2/7]


FloatProgress(value=0.0, max=113880.0)

FloatProgress(value=0.0, max=10000.0)

train_loss [0.80]
valid_loss [0.77]
          mu      alpha      homo     lumo      gap          r2       zpve  \
0  12.021016  63.404932  0.307155  0.90254  0.92222  120.765673  12.530256   

           u0        u298       h298        g298        cv  
0  441.710228  457.666841  723.69611  439.037545  6.419103  
epoch [3/7]


FloatProgress(value=0.0, max=113880.0)

FloatProgress(value=0.0, max=10000.0)

train_loss [0.76]
valid_loss [0.76]
          mu      alpha      homo     lumo      gap          r2       zpve  \
0  12.021016  63.404932  0.294215  0.90254  0.92222  117.654228  12.265405   

           u0        u298       h298       g298        cv  
0  418.175232  423.546015  723.69611  420.06914  6.419122  
epoch [4/7]


FloatProgress(value=0.0, max=113880.0)

FloatProgress(value=0.0, max=10000.0)

train_loss [0.75]
valid_loss [0.76]
          mu      alpha      homo     lumo      gap          r2       zpve  \
0  12.021016  63.404932  0.287768  0.90254  0.92222  114.845619  12.394311   

           u0        u298       h298        g298        cv  
0  420.079653  425.312441  723.69611  419.648548  6.419047  
epoch [5/7]


FloatProgress(value=0.0, max=113880.0)

FloatProgress(value=0.0, max=10000.0)

train_loss [0.75]
valid_loss [0.75]
          mu      alpha      homo     lumo      gap          r2       zpve  \
0  12.021016  63.404932  0.284066  0.90254  0.92222  113.291143  12.263235   

           u0        u298       h298        g298        cv  
0  418.055157  414.771368  723.69611  416.125896  6.418866  
epoch [6/7]


FloatProgress(value=0.0, max=113880.0)

FloatProgress(value=0.0, max=10000.0)

train_loss [0.74]
valid_loss [0.75]
          mu      alpha      homo     lumo      gap          r2       zpve  \
0  12.021016  63.404932  0.282286  0.90254  0.92222  112.248942  12.014668   

           u0        u298       h298        g298        cv  
0  412.121396  413.405308  723.69611  410.840278  6.418866  
epoch [7/7]


FloatProgress(value=0.0, max=113880.0)

FloatProgress(value=0.0, max=10000.0)

train_loss [0.74]
valid_loss [0.75]
          mu      alpha      homo     lumo      gap          r2       zpve  \
0  12.021016  63.404932  0.281452  0.90254  0.92222  111.530266  11.938114   

           u0        u298       h298        g298        cv  
0  411.717969  407.359234  723.69611  413.426155  6.418866  
