In [1]:
from rdkit import Chem
import numpy as np
import pandas as pd
from IPython.display import display
import tqdm
import torch
import deepchem as dc

2022-09-23 19:45:07.654947: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-09-23 19:45:07.759590: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-09-23 19:45:07.759606: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2022-09-23 19:45:07.780663: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-09-23 19:45:08.282526: W tensorflow/stream_executor/platform/de

# Train an embedding model on ChEMBL

In [2]:
%%time
# Fetch the ChEMBL benchmark ("5thresh" set) with the "raw" featurizer and a random split;
# graph featurization is done separately in a later cell.
chembl_tasks, datasets, transformers = dc.molnet.load_chembl(
    shard_size=2000,
    featurizer="raw",
    set="5thresh",
    splitter="random",
)
(train_dataset, valid_dataset, test_dataset) = datasets

CPU times: user 9.26 ms, sys: 350 µs, total: 9.61 ms
Wall time: 8.19 ms


In [3]:
%%time
# Turn the training inputs into graph objects (edge features and partial charges enabled)
f = dc.feat.MolGraphConvFeaturizer(
    use_edges=True,
    use_partial_charge=True,
)
y = train_dataset.y
x = f.featurize(train_dataset.X)

CPU times: user 2min 53s, sys: 608 ms, total: 2min 54s
Wall time: 2min 54s


In [None]:
%%time

# Build the MPNN with one task per ChEMBL target.
# NOTE(review): 31 / 11 are presumably the atom / bond feature widths produced by the
# featurizer configured above -- confirm if the featurizer options change.
model = dc.models.torch_models.MPNNModel(
    len(chembl_tasks),
    number_atom_features=31,
    number_bond_features=11,
)

# Show the layer structure of the wrapped network
print(model.model)

# Fit on the featurized training graphs and their labels
model.fit(dc.data.NumpyDataset(x, y))

In [None]:
# Let's see some scores: RMS score averaged with np.mean, reported as a
# (train, validation, test) tuple.
avg_rms = dc.metrics.Metric(dc.metrics.rms_score, np.mean)
(
    model.evaluate(
        dc.data.NumpyDataset(x, y),
        [avg_rms],
        transformers,
    ),
    model.evaluate(
        dc.data.NumpyDataset(f.featurize(valid_dataset.X), valid_dataset.y),
        [avg_rms],
        transformers,
    ),
    model.evaluate(
        dc.data.NumpyDataset(f.featurize(test_dataset.X), test_dataset.y),
        [avg_rms],
        transformers,
    ),
)

In [None]:
# (label, prediction) pairs for the first training sample, for a qualitative comparison
list(
    zip(
        train_dataset.y[:1].flatten(),
        model.predict(dc.data.NumpyDataset(x[:1], y[:1])).flatten(),
    )
)

# Dump the model(s)

In [None]:
# Display the network wrapped by the DeepChem model (the same object printed in the training cell)
model.model

In [4]:
from torch import nn
import dgl

class MPNNMolEmbedder(nn.Module):
    """Molecule-level embedder assembled from the layers of a trained MPNN.

    Applies the message-passing network to a batch of molecular graphs and
    pools the resulting node features with the readout layer. Only ``gnn``
    and ``readout`` are applied; nothing else is run on top of them.

    Parameters
    ----------
    gnn : callable
        Message-passing module, invoked as ``gnn(graph, node_feats, edge_feats)``.
    readout : callable
        Pooling module, invoked as ``readout(graph, node_feats)``.
    """
    def __init__(self, gnn, readout):
        super().__init__()

        self.gnn = gnn
        self.readout = readout

    def _prepare_batch(self, g):
        """Convert an iterable of GraphData into one batched DGL graph.

        The batch is moved to the device of this module's parameters rather
        than a hard-coded CPU, so the embedder also works when the wrapped
        layers live on another device. Falls back to CPU when the module has
        no parameters.
        """
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"
        dgl_graphs = [graph.to_dgl_graph() for graph in g]
        return dgl.batch(dgl_graphs).to(device)

    def forward(self, g):
        """Compute graph-level embeddings for a batch of molecules.

        Parameters
        ----------
        g : iterable of GraphData
            Featurized molecules. Each item must provide ``to_dgl_graph()``;
            the resulting graphs must carry node features under
            ``ndata["x"]`` and edge features under ``edata["edge_attr"]``.

        Returns
        -------
        The output of the readout layer for the batched graph
        (presumably one embedding row per input graph -- TODO confirm).
        """
        dgl_g = self._prepare_batch(g)
        node_feats = self.gnn(dgl_g, dgl_g.ndata["x"], dgl_g.edata["edge_attr"])
        graph_feats = self.readout(dgl_g, node_feats)
        return graph_feats

class MPNNAtomEmbedder(nn.Module):
    """Atom-level embedder assembled from the message-passing layers of a trained MPNN.

    Applies the message-passing network to a batch of molecular graphs and
    returns the node features selected by an index; no readout is applied.

    Parameters
    ----------
    gnn : callable
        Message-passing module, invoked as ``gnn(graph, node_feats, edge_feats)``.
    """
    def __init__(self, gnn):
        super().__init__()
        self.gnn = gnn

    def _prepare_batch(self, g):
        """Convert an iterable of GraphData into one batched DGL graph.

        The batch is moved to the device of this module's parameters rather
        than a hard-coded CPU, so the embedder also works when the wrapped
        layers live on another device. Falls back to CPU when the module has
        no parameters.
        """
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"
        dgl_graphs = [graph.to_dgl_graph() for graph in g]
        return dgl.batch(dgl_graphs).to(device)

    def forward(self, g, idx):
        """Compute node-level (atom) embeddings and select entries by index.

        Parameters
        ----------
        g : iterable of GraphData
            Featurized molecules. Each item must provide ``to_dgl_graph()``;
            the resulting graphs must carry node features under
            ``ndata["x"]`` and edge features under ``edata["edge_attr"]``.
        idx : int or any valid tensor index
            Index into the node features of the *batched* graph. With a
            single input graph this is simply the atom index. With several
            graphs all nodes share one index space (presumably in input
            order -- TODO confirm against ``dgl.batch``).

        Returns
        -------
        ``node_feats[idx]``: the embedding(s) of the selected node(s).
        """
        dgl_g = self._prepare_batch(g)
        node_feats = self.gnn(dgl_g, dgl_g.ndata["x"], dgl_g.edata["edge_attr"])
        return node_feats[idx]

In [111]:
# Reuse the trained sub-modules of the wrapped network: the first child feeds the atom
# embedder, the first two children (gnn, readout) feed the molecule embedder.
trained_layers = list(model.model.model.children())
mol_embedder = MPNNMolEmbedder(*trained_layers[:2])
atom_embedder = MPNNAtomEmbedder(*trained_layers[:1])

In [112]:
# Display the layer structure of the molecule embedder
mol_embedder

MPNNMolEmbedder(
  (gnn): MPNNGNN(
    (project_node_feats): Sequential(
      (0): Linear(in_features=31, out_features=64, bias=True)
      (1): ReLU()
    )
    (gnn_layer): NNConv(
      (edge_func): Sequential(
        (0): Linear(in_features=11, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=4096, bias=True)
      )
    )
    (gru): GRU(64, 64)
  )
  (readout): Set2Set(
    n_iters=6
    (lstm): LSTM(128, 64, num_layers=3)
  )
)

In [113]:
# Display the layer structure of the atom embedder
atom_embedder

MPNNAtomEmbedder(
  (gnn): MPNNGNN(
    (project_node_feats): Sequential(
      (0): Linear(in_features=31, out_features=64, bias=True)
      (1): ReLU()
    )
    (gnn_layer): NNConv(
      (edge_func): Sequential(
        (0): Linear(in_features=11, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=4096, bias=True)
      )
    )
    (gru): GRU(64, 64)
  )
)

In [115]:
# Smoke test: embed node index 1 of the featurized training item x[4] and check the vector size
atom_embedder([x[4]], 1).shape

torch.Size([64])

In [117]:
import os

# torch.save on a whole nn.Module pickles the object, so loading these files later
# requires the MPNNMolEmbedder / MPNNAtomEmbedder class definitions to be available.
# Create the output directory first so the cell also works on a fresh checkout.
os.makedirs("models", exist_ok=True)
torch.save(mol_embedder, "models/MPNNMolEmbedder.pt")
torch.save(atom_embedder, "models/MPNNAtomEmbedder.pt")

# Load the model and test on a new molecule

In [6]:
import deepchem as dc
import torch

# Featurizer: configured with the same options as the one used on the training inputs
f = dc.feat.MolGraphConvFeaturizer(
    use_edges=True,
    use_partial_charge=True,
)

# Embedders: restore the two modules written by the save cell
mol_em_model = torch.load("models/MPNNMolEmbedder.pt")
atom_em_model = torch.load("models/MPNNAtomEmbedder.pt")

def mol_to_embedding(mol):
    """Return the molecule-level embedding of a single molecule.

    The molecule is featurized with the module-level featurizer ``f``, passed
    to ``mol_em_model`` as a one-element batch, and the first entry of the
    model output is returned.
    """
    featurized = f.featurize([mol])
    graph = featurized[0]
    batch_embeddings = mol_em_model([graph])
    return batch_embeddings[0]

def atom_to_embedding(mol, idx):
    """Return the embedding of the atom at position ``idx`` in a single molecule.

    The molecule is featurized with the module-level featurizer ``f`` and
    passed, as a one-element batch together with ``idx``, to ``atom_em_model``;
    the model output is returned unchanged.
    """
    featurized = f.featurize([mol])
    graph = featurized[0]
    return atom_em_model([graph], idx)

# Demo on hexane ("CCCCCC"): whole-molecule embedding, plus the embedding of atom index 5
(
    mol_to_embedding(Chem.MolFromSmiles("CCCCCC")),
    atom_to_embedding(Chem.MolFromSmiles("CCCCCC"), 5),
)

(tensor([ 0.4273, -0.0054,  0.0300, -0.5470,  0.7364,  0.0445,  0.0112, -0.0166,
          0.1102,  0.2216,  0.0039,  0.1971, -0.3468, -0.0202, -0.2214,  0.0314,
         -0.0983, -0.5491, -0.2083,  0.1508, -0.7335,  0.0810, -0.0330, -0.0227,
          0.1960, -0.9099,  0.0609,  0.7126, -0.1994,  0.4139,  0.6269, -0.0964,
         -0.4431, -0.3666, -0.2090, -0.0110,  0.3658,  0.1680, -0.5915,  0.0269,
          0.0530,  0.1097, -0.0087,  0.3116,  0.0679,  0.1380, -0.0237,  0.4257,
          0.2114,  0.1945, -0.2312,  0.0227, -0.0225,  0.0219,  0.6036,  0.0149,
         -0.1436, -0.0034,  0.2187, -0.8046, -0.6189, -0.1587,  0.5385, -0.1924,
          0.0639,  0.8521, -0.2918,  0.9619,  0.9513,  0.9992,  0.4112,  0.7927,
         -0.3629,  0.9802,  0.9209,  0.9591,  0.3266,  0.3460,  0.9029,  0.7934,
          0.2595,  0.1027,  0.0793, -0.4858,  0.4777,  0.8709, -0.5346,  0.2575,
         -0.1317, -0.6196,  0.0855, -0.0145,  0.3551,  0.9084,  0.3139,  0.1671,
          0.8924,  0.2319,  