In [3]:
from rdkit import Chem
import numpy as np
import pandas as pd
from IPython.display import display
import tqdm
import torch
import deepchem as dc

# Train a embedding model on ChemBL

In [4]:
%%time
# Load raw dataset
chembl_tasks, datasets, transformers = dc.molnet.load_chembl(shard_size=2000, featurizer="raw", set="5thresh", splitter="random")
train_dataset, valid_dataset, test_dataset = datasets

CPU times: user 4.57 ms, sys: 3.62 ms, total: 8.19 ms
Wall time: 7.01 ms


In [11]:
%%time
# Featurize the input
f = dc.feat.MolGraphConvFeaturizer(use_edges=True, use_partial_charge=True)
y = train_dataset.y

CPU times: user 7min 45s, sys: 385 ms, total: 7min 46s
Wall time: 7min 46s


%%time

# Train the MPNN model
model = dc.models.torch_models.MPNNModel(len(chembl_tasks), number_atom_features=31, number_bond_features=11)

print(model.model)

model.fit(dc.data.NumpyDataset(f.featurize(train_dataset.X), y)) # On original dataset

# Lets see some scores (rmse)
avg_rms = dc.metrics.Metric(dc.metrics.rms_score, np.mean)
model.evaluate(dc.data.NumpyDataset(x, y), [avg_rms], transformers), model.evaluate(dc.data.NumpyDataset(f.featurize(valid_dataset.X), valid_dataset.y), [avg_rms], transformers), model.evaluate(dc.data.NumpyDataset(f.featurize(test_dataset.X), test_dataset.y), [avg_rms], transformers)

# Output of 1 sample for qualitative comparision
list(zip(train_dataset.y[:1].flatten(), model.predict(dc.data.NumpyDataset(x[:1], y[:1])).flatten()))

In [23]:
%%time

# Train the MPNN model
model = dc.models.torch_models.MPNNModel(len(chembl_tasks), number_atom_features=31, number_bond_features=11)

print(model.model)

# There are single atom molecules in the actions - For these, featurizer does not work - so adding H's so each atom has at least some neighbors
X = []
for mol in tqdm.tqdm(train_dataset.X):
    X.append(Chem.AddHs(mol))
X = np.array(X)
x = f.featurize(X)

# train
model.fit(dc.data.NumpyDataset(x, y)) 

MPNN(
  (model): MPNNPredictor(
    (gnn): MPNNGNN(
      (project_node_feats): Sequential(
        (0): Linear(in_features=31, out_features=64, bias=True)
        (1): ReLU()
      )
      (gnn_layer): NNConv(
        (edge_func): Sequential(
          (0): Linear(in_features=11, out_features=128, bias=True)
          (1): ReLU()
          (2): Linear(in_features=128, out_features=4096, bias=True)
        )
      )
      (gru): GRU(64, 64)
    )
    (readout): Set2Set(
      n_iters=6
      (lstm): LSTM(128, 64, num_layers=3)
    )
    (predict): Sequential(
      (0): Linear(in_features=128, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=691, bias=True)
    )
  )
)




CPU times: user 3h 25min 27s, sys: 32min 40s, total: 3h 58min 8s
Wall time: 40min 46s


0.7483420848846436

In [24]:
# Lets see some scores (rmse)
avg_rms = dc.metrics.Metric(dc.metrics.rms_score, np.mean)
model.evaluate(dc.data.NumpyDataset(x, y), [avg_rms], transformers), \
    model.evaluate(dc.data.NumpyDataset(f.featurize(np.vectorize(Chem.AddHs)(valid_dataset.X)), valid_dataset.y), [avg_rms], transformers), \
    model.evaluate(dc.data.NumpyDataset(f.featurize(np.vectorize(Chem.AddHs)(test_dataset.X)), test_dataset.y), [avg_rms], transformers)

({'mean-rms_score': 0.2996347741008072},
 {'mean-rms_score': 0.29453254293992603},
 {'mean-rms_score': 0.29342535192793767})

In [25]:
# Output of 1 sample for qualitative comparision
list(zip(train_dataset.y[:1].flatten(), model.predict(dc.data.NumpyDataset(x[:1], y[:1])).flatten()))

[(-0.007236693326090918, 0.008164551),
 (-0.04238371855537214, -0.03503477),
 (0.0, -0.005800266),
 (0.0, -0.00035296706),
 (-0.027899843079487056, -0.050205093),
 (-0.025025194419930924, -0.026770085),
 (-0.01908282090230921, -0.008256391),
 (-0.02155178296682597, -0.10114729),
 (0.0, 7.620081e-05),
 (-0.022514736706494138, -0.05614335),
 (-0.023869481191052172, -0.0330522),
 (0.0, 0.000555055),
 (0.0, -0.0003162555),
 (-0.017423504840988296, -0.01330005),
 (-0.00723669332609096, -0.06322547),
 (-0.025991172597553413, 0.015069576),
 (0.0, 0.0017335708),
 (0.0, -0.0016625188),
 (0.0, 0.005104471),
 (-0.010232547811979919, -0.026510943),
 (-0.01601282273079471, 0.03235635),
 (-0.010137734473236647, 0.037284892),
 (0.0, 0.002554138),
 (-0.06908539966106848, 0.5791265),
 (-0.0846528929503079, -0.2929492),
 (-0.014448120136706365, -0.041189305),
 (-0.014447497539225454, -0.009178527),
 (-0.007236693326090872, -0.05063478),
 (-0.07556022955680562, -0.16636756),
 (-0.07963641257843566, -0.18

# Dump the model(s)

In [27]:
from torch import nn
import dgl

class MPNNMolEmbedder(nn.Module):
    """MPNN embedder."""
    def __init__(self, gnn, readout):
        super(MPNNMolEmbedder, self).__init__()

        self.gnn = gnn
        self.readout = readout

    def _prepare_batch(self, g):
        dgl_graphs = [graph.to_dgl_graph() for graph in g]
        inputs = dgl.batch(dgl_graphs).to("cpu")
        return inputs
        
    def forward(self, g):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        g : GraphData
            GraphData for a batch of graphs.

        Returns
        -------
        graph embeddings
        """
        dgl_g = self._prepare_batch(g)
        node_feats = self.gnn(dgl_g, dgl_g.ndata["x"], dgl_g.edata["edge_attr"])
        graph_feats = self.readout(dgl_g, node_feats)
        return graph_feats

class MPNNAtomEmbedder(nn.Module):
    """MPNN embedder."""
    def __init__(self, gnn):
        super(MPNNAtomEmbedder, self).__init__()
        self.gnn = gnn

    def _prepare_batch(self, g):
        dgl_graphs = [graph.to_dgl_graph() for graph in g]
        inputs = dgl.batch(dgl_graphs).to("cpu")
        return inputs
        
    def forward(self, g, idx):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        g : GraphData
            GraphData for a batch of graphs.

        Returns
        -------
        graph embeddings
        """
        dgl_g = self._prepare_batch(g)
        node_feats = self.gnn(dgl_g, dgl_g.ndata["x"], dgl_g.edata["edge_attr"])
        return node_feats[idx]

In [28]:
mol_embedder = MPNNMolEmbedder(*list(model.model.model.children())[:2])
atom_embedder = MPNNAtomEmbedder(*list(model.model.model.children())[:1])

In [29]:
mol_embedder

MPNNMolEmbedder(
  (gnn): MPNNGNN(
    (project_node_feats): Sequential(
      (0): Linear(in_features=31, out_features=64, bias=True)
      (1): ReLU()
    )
    (gnn_layer): NNConv(
      (edge_func): Sequential(
        (0): Linear(in_features=11, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=4096, bias=True)
      )
    )
    (gru): GRU(64, 64)
  )
  (readout): Set2Set(
    n_iters=6
    (lstm): LSTM(128, 64, num_layers=3)
  )
)

In [30]:
atom_embedder

MPNNAtomEmbedder(
  (gnn): MPNNGNN(
    (project_node_feats): Sequential(
      (0): Linear(in_features=31, out_features=64, bias=True)
      (1): ReLU()
    )
    (gnn_layer): NNConv(
      (edge_func): Sequential(
        (0): Linear(in_features=11, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=4096, bias=True)
      )
    )
    (gru): GRU(64, 64)
  )
)

In [31]:
atom_embedder([x[4]], 1).shape

torch.Size([64])

In [32]:
torch.save(mol_embedder, "models/MPNNMolEmbedder.pt")
torch.save(atom_embedder, "models/MPNNAtomEmbedder.pt")

# Load the model and test on a new molecule

In [33]:
import deepchem as dc
import torch

# Featurizer
f = dc.feat.MolGraphConvFeaturizer(use_edges=True, use_partial_charge=True)

# Model
mol_em_model = torch.load("models/MPNNMolEmbedder.pt")
atom_em_model = torch.load("models/MPNNAtomEmbedder.pt")

def mol_to_embedding(mol):
    features = f.featurize([mol])[0]
    return mol_em_model([features])[0]

def atom_to_embedding(mol, idx):
    features = f.featurize([mol])[0]
    return atom_em_model([features], idx)

mol_to_embedding(Chem.MolFromSmiles("CCCCCC")), atom_to_embedding(Chem.MolFromSmiles("CCCCCC"), 5)

(tensor([-0.4981, -0.3094,  0.0359, -0.0273,  0.0486,  0.0055,  0.0209, -0.2974,
          0.2711,  0.0229, -0.4001,  0.0598, -0.0237,  0.0394,  0.4154,  0.4887,
          0.0559, -0.0183, -0.2650,  0.3097,  0.2091,  0.6680, -0.0248,  0.1022,
          0.3671, -0.2866, -0.1959,  0.3826, -0.0631, -0.0816, -0.1895, -0.0070,
         -0.3358, -0.0399, -0.1000,  0.0863,  0.2740, -0.0144,  0.2664,  0.2931,
          0.0883, -0.0027, -0.1221, -0.1433,  0.2882, -0.0342, -0.2034, -0.0433,
         -0.0802,  0.1175, -0.0645,  0.4082, -0.0927, -0.0020, -0.0434, -0.0355,
         -0.4305,  0.1692,  0.2422, -0.1983,  0.1784, -0.2990, -0.1470, -0.3663,
          0.9347,  0.8867,  0.9883,  0.9934,  0.9891,  0.2513,  0.0481,  0.9916,
         -0.0480,  0.2216,  0.9288,  0.7891,  0.5612, -0.6109, -0.6209,  0.4739,
         -0.0296, -0.9944, -0.2888, -0.0268, -0.9340,  0.4385, -0.2914,  0.9926,
         -0.5496, -0.4660,  0.4663,  0.8730,  0.2880, -0.4318, -0.8252,  0.6199,
         -0.1942,  0.8182,  