In [17]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import numpy as np
import warnings
import pandas as pd
import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw
import networkx as nx
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch.utils.data as data
import torch.optim as optim
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import dense_to_sparse, add_self_loops, to_scipy_sparse_matrix
from torch_geometric.data import Data
import torch.nn.functional as F
from six.moves import urllib
import deepchem as dc

# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings("ignore")

# Plot styling shared by every figure in the notebook.
sns.set_context("notebook")
sns.set_style(
    "dark",
    {
        "xtick.bottom": True,
        "ytick.left": True,
        "xtick.color": "#666666",
        "ytick.color": "#666666",
        "axes.edgecolor": "#666666",
        "axes.linewidth": 0.8,
        "figure.dpi": 300,
    },
)
color_cycle = ["#1BBC9B", "#F06060", "#5C4B51", "#F3B562", "#6e5687"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=color_cycle)

# Install a global urllib opener that sends a browser-like User-agent header
# (presumably because some hosts reject the default Python client).
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

# Load the solubility dataset from the GitHub re-host only. The original
# Dataverse download was unreliable and its result was overwritten by this
# read anyway, so fetching it first could only make the cell fail.
soldata = pd.read_csv(
    "https://github.com/whitead/dmol-book/raw/main/data/curated-solubility-dataset.csv"
)
np.random.seed(0)

In [18]:
def gen_smiles2graph(smiles, featurizer=None):
    """Featurize a single SMILES string into a DeepChem graph object.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule to featurize.
    featurizer : optional
        Any object exposing ``featurize(smiles)``. When omitted, a new
        ``dc.feat.MolGraphConvFeaturizer(use_edges=True)`` is created, which
        matches the previous behaviour. Pass a shared instance to avoid
        rebuilding the featurizer on every call when looping over a dataset.

    Returns
    -------
    The first element of the featurizer's output. For a molecule that
    featurizes successfully this is a ``GraphData`` object. When DeepChem
    fails it logs "Appending empty array", so the result is presumably an
    empty array in that case -- callers should check before using it.
    """
    if featurizer is None:
        featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
    return featurizer.featurize(smiles)[0]

In [19]:
# IPython introspection: display the docstring of the featurizer class.
# NOTE(review): exploration helper; consider removing it from the final notebook.
dc.feat.MolGraphConvFeaturizer?

In [20]:
# Featurize the SMILES "CO" and print the graph, then each of its arrays
# followed by that array's type.
testCO = gen_smiles2graph("CO")
print(testCO)
print(type(testCO))
for field_name in ("node_features", "edge_index", "edge_features"):
    field_value = getattr(testCO, field_name)
    print(field_value)
    print(type(field_value))

GraphData(node_features=[2, 30], edge_index=[2, 2], edge_features=[2, 11])
<class 'deepchem.feat.graph_data.GraphData'>
[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
  0. 0. 0. 1. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 1. 0. 0. 0. 0.
  0. 1. 0. 0. 0. 0.]]
<class 'numpy.ndarray'>
[[0 1]
 [1 0]]
<class 'numpy.ndarray'>
[[1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]
<class 'numpy.ndarray'>


In [21]:
# Print the built-in help for the featurized object (class docstring and attributes).
help(testCO)

Help on GraphData in module deepchem.feat.graph_data object:

class GraphData(builtins.object)
 |  GraphData(node_features: numpy.ndarray, edge_index: numpy.ndarray, edge_features: Union[numpy.ndarray, NoneType] = None, node_pos_features: Union[numpy.ndarray, NoneType] = None, **kwargs)
 |  
 |  GraphData class
 |  
 |  This data class is almost same as `torch_geometric.data.Data
 |  <https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data>`_.
 |  
 |  Attributes
 |  ----------
 |  node_features: np.ndarray
 |    Node feature matrix with shape [num_nodes, num_node_features]
 |  edge_index: np.ndarray, dtype int
 |    Graph connectivity in COO format with shape [2, num_edges]
 |  edge_features: np.ndarray, optional (default None)
 |    Edge feature matrix with shape [num_edges, num_edge_features]
 |  node_pos_features: np.ndarray, optional (default None)
 |    Node position matrix with shape [num_nodes, num_dimensions].
 |  num_nodes: int
 |    The

In [22]:
# Featurize the SMILES "OCO" and print the graph, then each of its arrays
# followed by that array's type.
testOCO = gen_smiles2graph("OCO")
print(testOCO)
print(type(testOCO))
for field_name in ("node_features", "edge_index", "edge_features"):
    field_value = getattr(testOCO, field_name)
    print(field_value)
    print(type(field_value))

GraphData(node_features=[3, 30], edge_index=[2, 4], edge_features=[4, 11])
<class 'deepchem.feat.graph_data.GraphData'>
[[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 1. 0. 0. 0. 0.
  0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 1. 0. 0. 0. 0.
  0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
  0. 0. 1. 0. 0. 0.]]
<class 'numpy.ndarray'>
[[0 2 2 1]
 [2 0 1 2]]
<class 'numpy.ndarray'>
[[1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]
<class 'numpy.ndarray'>


In [23]:
# Print the built-in help for the featurized object (class docstring and attributes).
help(testOCO)

Help on GraphData in module deepchem.feat.graph_data object:

class GraphData(builtins.object)
 |  GraphData(node_features: numpy.ndarray, edge_index: numpy.ndarray, edge_features: Union[numpy.ndarray, NoneType] = None, node_pos_features: Union[numpy.ndarray, NoneType] = None, **kwargs)
 |  
 |  GraphData class
 |  
 |  This data class is almost same as `torch_geometric.data.Data
 |  <https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data>`_.
 |  
 |  Attributes
 |  ----------
 |  node_features: np.ndarray
 |    Node feature matrix with shape [num_nodes, num_node_features]
 |  edge_index: np.ndarray, dtype int
 |    Graph connectivity in COO format with shape [2, num_edges]
 |  edge_features: np.ndarray, optional (default None)
 |    Edge feature matrix with shape [num_edges, num_edge_features]
 |  node_pos_features: np.ndarray, optional (default None)
 |    Node position matrix with shape [num_nodes, num_dimensions].
 |  num_nodes: int
 |    The

In [24]:
# Use the GPU when PyTorch reports one is available; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device} device")

Using cpu device


In [25]:
# Featurize every molecule, keeping `graph` and `sol` index-aligned.
graph = []
sol = []
skipped_smiles = []
# zip iterates positionally, so this does not depend on the frame's index labels.
for smiles, solubility in zip(soldata.SMILES, soldata.Solubility):
    mol_graph = gen_smiles2graph(smiles)
    # When DeepChem cannot featurize a molecule it logs "Appending empty array"
    # and hands back a placeholder instead of a graph. Such placeholders have
    # no node features, so drop them (and their label) rather than storing
    # unusable entries in the dataset.
    if not hasattr(mol_graph, "node_features"):
        skipped_smiles.append(smiles)
        continue
    graph.append(mol_graph)
    sol.append(solubility)

assert len(graph) == len(sol)
print(f"Featurized {len(graph)} molecules; skipped {len(skipped_smiles)} that could not be featurized")

Failed to featurize datapoint 0, [Mo]. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 0, [Al+3].[Al+3].[Mo].[Mo].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Ca+2].[Mg+2].[O-2].[O-2]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Mg+2]. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 0, [Ca+2].[OH-].[OH-]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Ba+2].[Fe+3].[Fe+3].[Fe+3].[Fe+3].[Fe+3].[Fe+3].[Fe+3].[Fe+3].[Fe+3].[Fe+3].[Fe+3].[Fe+3].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2]. Appending em

Exception message: tuple index out of range
Failed to featurize datapoint 0, [H-].[H-].[Zr+2]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Mg+2].[Mg+2].[Nb+5].[Nb+5].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2].[O-2]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Hf]. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 0, [Hf+4].[O-2].[O-2]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Cl-].[Li+]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Cr].[F].[F].[F]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [F-].[F-].[Ni+2]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Cl-].[Cl

Exception message: tuple index out of range
Failed to featurize datapoint 0, S.[Ni]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [S-2].[S-2].[Sn+4]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Cl-].[Cu+2].[Cu+2].[OH-].[OH-].[OH-]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Mn]. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 0, [F-].[F-].[F-].[La+3]. Appending empty array
Exception message: tuple index out of range
Failed to featurize datapoint 0, [Ta]. Appending empty array
Exception message: More than one atom should be present in the molecule for this featurizer to work.
Failed to featurize datapoint 0, O.O.O.O.[Cl-].[Cl-].[Mn+2]. Appending empty array
Exception message: tuple index out of range
Failed to featurize

In [26]:
class CustomDataset(data.Dataset):
    """Pairs each graph with its solubility value.

    Items are returned as ``(graph, solubility)`` tuples. The optional
    ``transform`` is applied to the graph and ``target_transform`` to the
    solubility value when they are truthy.
    """

    def __init__(self, graphAll, solAll, transform=None, target_transform=None):
        self.graphInstances, self.solInstances = graphAll, solAll
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        # The dataset length is defined by the graph collection.
        return len(self.graphInstances)

    def __getitem__(self, idx):
        features = self.graphInstances[idx]
        target = self.solInstances[idx]
        features = self.transform(features) if self.transform else features
        target = self.target_transform(target) if self.target_transform else target
        return features, target

In [27]:
import multiprocessing

# CPU count as reported by multiprocessing; passed as `num_workers` to the
# DataLoaders created in the following cell. The bare `cores` expression
# makes the notebook display the value.
cores = multiprocessing.cpu_count() # Count the number of cores in a computer
cores

16

In [28]:
dataset = CustomDataset(graphAll=graph, solAll=sol)
# Loader over the full, unsplit dataset (kept for any later cell that uses it).
dataloader = data.DataLoader(dataset, batch_size=1,
                        shuffle=True, num_workers=cores)

# Split the Dataset itself, not the DataLoader: random_split indexes its first
# argument, and a DataLoader is iterable but not indexable, so the subsets built
# from `dataloader` could not be read. The seeded generator keeps the
# 200 / 200 / remainder split reproducible.
test_data, val_data, train_data = data.random_split(
    dataset,
    [200, 200, len(dataset) - 400],
    generator=torch.Generator().manual_seed(42),
)

# NOTE(review): the default collate function may not know how to batch the
# graph objects stored in the dataset -- confirm, and pass `collate_fn` if needed.
test_data_loader = data.DataLoader(test_data, batch_size=1, shuffle=True, num_workers=cores)
val_data_loader = data.DataLoader(val_data, batch_size=1, shuffle=True, num_workers=cores)
train_data_loader = data.DataLoader(train_data, batch_size=1, shuffle=True, num_workers=cores)

In [31]:
# neural network with two fully connected layers
class FCNN(torch.nn.Module):
    """Two fully connected layers with a ReLU in between.

    Parameters
    ----------
    c_in : int
        Size of the last dimension of the input features.
    c_out : int
        Size of the last dimension of the output features.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.fc1 = nn.Linear(c_in, c_out)
        # fc2 consumes the output of fc1, so its input width is c_out
        # (identical to the previous sizing whenever c_in == c_out).
        self.fc2 = nn.Linear(c_out, c_out)

    def forward(self, features):
        """Apply fc1 -> ReLU -> fc2 and return the transformed features."""
        features = self.fc1(features)
        features = F.relu(features)
        features = self.fc2(features)
        return features

# implementation of equation 5 in bondnet paper 
# https://pubs.rsc.org/en/content/articlepdf/2021/sc/d0sc05251e
class NodeFeatures(torch.nn.Module):
    """Residual node-feature update built from two FCNN blocks and an edge gate.

    ``forward`` computes
    ``node_features + relu(FCNN_one(node_features) + gate * FCNN_two(node_features))``
    where ``gate = sigmoid(edge_features) / (sigmoid(edge_features) + 1e-7)``.

    NOTE(review): the gate is normalised by itself, so it is always close to 1;
    the referenced equation presumably normalises over neighbouring edges --
    confirm against the paper.
    NOTE(review): the gate must broadcast against the node features. The
    featurizer used in this notebook yields 30 node features and 11 edge
    features per row, which will not broadcast as-is -- TODO confirm intended shapes.
    """

    def __init__(self):
        super().__init__()
        self.FCNN_one = FCNN(c_in=30, c_out=30)
        self.FCNN_two = FCNN(c_in=30, c_out=30)

    def forward(self, node_features, edge_features):
        """Return the updated node features (same shape as ``node_features``)."""
        FCNN_one_features = self.FCNN_one(node_features)
        FCNN_two_features = self.FCNN_two(node_features)
        epsilon = 1e-7
        # torch.sigmoid is the functional form; torch.nn.Sigmoid is a module
        # class and cannot be called with a tensor as its constructor argument.
        edge_gate = torch.sigmoid(edge_features)
        sigmoid_edge_features = edge_gate / (edge_gate + epsilon)
        # * is for elementwise Hadamard product
        FCNN_processed_two_features = sigmoid_edge_features * FCNN_two_features
        intermediate_features = FCNN_one_features + FCNN_processed_two_features
        intermediate_features = F.relu(intermediate_features)
        output_features = node_features + intermediate_features
        return output_features

# implementation of equation 4 in bondnet paper
# https://pubs.rsc.org/en/content/articlepdf/2021/sc/d0sc05251e
class EdgeFeatures(torch.nn.Module):
    """Residual edge-feature update built from two FCNN blocks.

    ``forward`` returns
    ``edge_features + relu(FCNN_one(node_features) + FCNN_two(edge_features))``.

    NOTE(review): the two FCNN outputs are added directly, so they must be
    broadcast-compatible; the blocks are sized for 30 node features and 11
    edge features -- TODO confirm how node features are meant to be mapped
    onto edges.
    """

    def __init__(self):
        super().__init__()
        self.FCNN_one = FCNN(c_in=30, c_out=30)
        self.FCNN_two = FCNN(c_in=11, c_out=11)

    def forward(self, node_features, edge_features):
        """Return the updated edge features."""
        node_term = self.FCNN_one(node_features)
        edge_term = self.FCNN_two(edge_features)
        residual = F.relu(node_term + edge_term)
        return edge_features + residual