*Okay. Alright. That's fine.*

\- Drake



Hyperparameter optimization did not as expected, but we have more exciting things ahead of us: creating our own graph classifier. Let's approach this task using PyTorch Geometric using this [example](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=mHSP6-RBOqCE) as reference.

We can break this down into a few steps:
- Convert SMILES data into graph data
- Mini-batch graph data
- Define Graph Neural Network for graph classification
- Train model and evaluate - you know the drill 

Let's get to it 🤖

In [37]:
!pip install deepchem
!pip install 'deepchem[torch]'
!pip install rdkit
!pip install torch_geometric


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgr

In [38]:
import deepchem as dc

tasks, datasets, transformers = dc.molnet.load_tox21(featurizer='GraphConv', reload=False)
train_dataset, valid_dataset, test_dataset = datasets



Let's take a close look at this dataset by inspecting the training dataset.

In [39]:
print(train_dataset)
print(train_dataset.X)

<DiskDataset X.shape: (6264,), y.shape: (6264, 12), w.shape: (6264, 12), task_names: ['NR-AR' 'NR-AR-LBD' 'NR-AhR' ... 'SR-HSE' 'SR-MMP' 'SR-p53']>
[<deepchem.feat.mol_graphs.ConvMol object at 0x2cc648100>
 <deepchem.feat.mol_graphs.ConvMol object at 0x2c6a7b040>
 <deepchem.feat.mol_graphs.ConvMol object at 0x2a192dc00> ...
 <deepchem.feat.mol_graphs.ConvMol object at 0x29a92f6d0>
 <deepchem.feat.mol_graphs.ConvMol object at 0x29a96c070>
 <deepchem.feat.mol_graphs.ConvMol object at 0x29a96c760>]


It looks like there are **6264** graphs (molecules) in our training dataset, where *X* stores the molecules as ConvMol objects, *y* stores the hot-encoded output vector with one entry for each of the **12** measures/tasks, and a weights matrix *w*.

Here's one particular molecule:

In [40]:
test_mol = train_dataset.X[0]
print(f"This molecule has {test_mol.n_atoms} atoms with {test_mol.n_feat} features each.")

This molecule has 11 atoms with 75 features each.


In [41]:
len(train_dataset), len(valid_dataset), len(test_dataset)

(6264, 783, 784)

Note above sizes of our training, validation, and test datasets. However, we can't just operate on the above data directly as molecular information is stored as a [SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) string (that is, simplified molecular-input line-entry system). Try saying that five times fast. In short, the specification enables structural information of molecules to be encoded into a string.

For example,

In [42]:
train_dataset.ids[0]

'CC(O)(P(=O)(O)O)P(=O)(O)O'

Fortunately, open-source is once again our savior: [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.spmatrix.html) supplies a few helper functions to convert sparse adjacency matrices  into graph data that we *can* use with PyTorch Geometric. From there, we can extract additional metadata to help understand the numbers.

In [43]:
import scipy.sparse
import numpy as np
def adjacency_list_to_sparse(adj_list):
    num_nodes = len(adj_list)
    rows, cols = [], []

    for i, neighbors in enumerate(adj_list):
        rows.extend([i] * len(neighbors))
        cols.extend(neighbors)

    adjacency_matrix = scipy.sparse.coo_matrix((np.ones_like(rows), (rows, cols)),
                                              shape=(num_nodes, num_nodes),
                                              dtype=np.float32)

    return adjacency_matrix


In [44]:
import torch
from torch_geometric.data import Data
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.utils.convert import from_scipy_sparse_matrix


def to_data(mol_graph):

    x = mol_graph.get_atom_features()

    adj_list = mol_graph.get_adjacency_list()
    sparse_mat = adjacency_list_to_sparse(adj_list)
    edge_index, edge_attr = from_scipy_sparse_matrix(sparse_mat)
    
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

A little clarification on the `Data` object might help:

1. **x (Node Feature Matrix):**
   - This parameter represents the feature matrix for each node in the graph.
   - It is a PyTorch tensor with shape [num_nodes, num_node_features].
   - Each row corresponds to a node, and each column corresponds to a feature of that node.
   - For example, if you are representing atoms in a molecule, `x` could contain features like atomic number, charge, etc.


2. **edge_index (Graph Connectivity):**
   - `edge_index` represents the graph connectivity in COO (Coordinate List) format.
   - It is a PyTorch tensor with shape [2, num_edges].
   - Each column of `edge_index` contains the indices of two nodes that form an edge.
   - For an undirected graph, (i, j) and (j, i) should both be present in the `edge_index`.


3. **edge_attr (Edge Feature Matrix):**
   - `edge_attr` represents the feature matrix for each edge in the graph.
   - It is a PyTorch tensor with shape [num_edges, num_edge_features].
   - Each row corresponds to an edge, and each column corresponds to a feature of that edge.
   - This is often used to store information like bond types, distances, or any other edge-specific features.

While there are a few other parameters, these three collectively provide a comprehensive representation of the graph needed for our specific use case.

Let's convert all our datasets from ConvMol to Data objects.

In [45]:
train_dataset_graph = []
for mol_graph in train_dataset.X:
    data = to_data(mol_graph)
    train_dataset_graph.append(data)

valid_dataset_graph = []
for mol_graph in valid_dataset.X:
    data = to_data(mol_graph)
    valid_dataset_graph.append(data)

test_dataset_graph = []
for mol_graph in test_dataset.X:
    data = to_data(mol_graph)
    test_dataset_graph.append(data)

In [46]:
train_dataset_graph[0]

Data(x=[11, 75], edge_index=[2, 20], edge_attr=[20])

In [47]:
from torch_geometric.loader import DataLoader
train_loader = DataLoader(train_dataset_graph, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset_graph, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [48]:
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 2204], edge_attr=[2204], batch=[1089], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 1988], edge_attr=[1988], batch=[969], ptr=[65])

Step 3:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 2202], edge_attr=[2202], batch=[1085], ptr=[65])

Step 4:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 1814], edge_attr=[1814], batch=[907], ptr=[65])

Step 5:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 2130], edge_attr=[2130], batch=[1032], ptr=[65])

Step 6:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 2046], edge_attr=[2046], batch=[1004], ptr=[65])

Step 7:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 1988], edge_attr=[1988], batch=[985], ptr=[65])

Step 8:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge

Hmm. Slight issue with the x parameter, but no clear solution is popping out to me at the moment. As my friend Dan likes to say, "Let's circle back to this."

Also, small (*very* consequential) update: after spending 2 days trying to get perfectly preprocess this data, I have discovered that PyTorch Geometric supplies its own MoleculeNet class, complete with a Tox 21 dataset 🥲.

Despite how much my RAM has suffered due to tabs upon tabs of documentation, I'd say that this process helped clarify **how** and **why** we preprocess our data. Now let's condense yesterday's work into two lines!

In [49]:
from torch_geometric.datasets import MoleculeNet
dataset = MoleculeNet(root="./data/", name="Tox21")

print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


Dataset: Tox21(7831):
Number of graphs: 7831
Number of features: 9
Number of classes: 12

Data(x=[16, 9], edge_index=[2, 34], edge_attr=[34, 3], smiles='CCOc1ccc2nc(S(N)(=O)=O)sc2c1', y=[1, 12])
Number of nodes: 16
Number of edges: 34
Average node degree: 2.12
Has isolated nodes: False
Has self-loops: False
Is undirected: True


There are 12 labels, meaning 12 different tasks to predict. 
Predicting the presence of all 12 for every molecule is error-prone as an entry for every task is not present for every molecule (i.e., pesky NaNs!). Take a look:



In [50]:
data.y

tensor([[0., 0., 1., nan, nan, 0., 0., 1., 0., 0., 0., 0.]])

For this reason, let's tighten our scope: let's predict the presence of NR-AhR (column 3) in a molecule.

The first step to to **drop rows with a "NaN" value for NR-AhR**.

In [51]:
column_index_to_check = 2
dataset.data.x = dataset.data.x.to(torch.float32)
dataset.data.y = dataset.data.y.long()
dataset = dataset.shuffle()
filtered_dataset = [data for data in dataset if not torch.isnan(data.y[0][column_index_to_check]).any()]



Then we split our data into training and test datasets (a classic).

In [52]:
split = int(0.8 * len(dataset))

train_dataset, test_dataset = filtered_dataset[:split], filtered_dataset[split:]
print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

Number of training graphs: 6264
Number of test graphs: 1567


Its beautiful. We can finally move on...

Next step: **Mini-batching**.

Instead of "stacking" equally-sized matrices into a single mini-batch, as we may have done with image data, we take an alternative approach with graph data. "Why overcomplicate things?" you may ask. According to PyTorch Geometric [documentation](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=0gZ-l0npPIca):

1. GNN operators that rely on a **message passing scheme** (more on this later) do not need to be modified since messages are not exchanged between two nodes that belong to different graphs

2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries (*i.e.*, the edges)

PyTorch Geometric automatically takes care of **batching multiple graphs into a single giant graph** with the help of the [`torch_geometric.data.DataLoader`](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.DataLoader) class:

In [53]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(x=[1155, 9], edge_index=[2, 2404], edge_attr=[2404, 3], smiles=[64], y=[64, 12], batch=[1155], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(x=[1126, 9], edge_index=[2, 2336], edge_attr=[2336, 3], smiles=[64], y=[64, 12], batch=[1126], ptr=[65])

Step 3:
Number of graphs in the current batch: 64
DataBatch(x=[1117, 9], edge_index=[2, 2306], edge_attr=[2306, 3], smiles=[64], y=[64, 12], batch=[1117], ptr=[65])

Step 4:
Number of graphs in the current batch: 64
DataBatch(x=[1242, 9], edge_index=[2, 2558], edge_attr=[2558, 3], smiles=[64], y=[64, 12], batch=[1242], ptr=[65])

Step 5:
Number of graphs in the current batch: 64
DataBatch(x=[1210, 9], edge_index=[2, 2514], edge_attr=[2514, 3], smiles=[64], y=[64, 12], batch=[1210], ptr=[65])

Step 6:
Number of graphs in the current batch: 64
DataBatch(x=[1338, 9], edge_index=[2, 2784], edge_attr=[2784, 3], smiles=[64], y=[64, 12], batch=[1338], ptr=[65])

Step

Mini-batch data ✅

Next step: Graph. Neural. Networks.

FINALLY. Let's make like Merriam-Webster and define what our GNN is going to look like.

But before we that, let's quickly review message-passing (you're welcome, future me).

In [54]:
from torch_geometric.nn import GraphConv
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GraphConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 2)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

model = GNN(hidden_channels=64)
print(model)

GNN(
  (conv1): GraphConv(9, 64)
  (conv2): GraphConv(64, 64)
  (conv3): GraphConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


In [55]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = loss_fn(out, data.y[:, column_index_to_check])  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y[:, column_index_to_check]).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.


In [56]:
for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Train Acc: 0.9025, Test Acc: 0.9004
Epoch: 002, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 003, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 004, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 005, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 006, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 007, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 008, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 009, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 010, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 011, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 012, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 013, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 014, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 015, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 016, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 017, Train Acc: 0.9026, Test Acc: 0.9004
Epoch: 018, Train Acc: 0.9023, Test Acc: 0.9004
Epoch: 019, Train Acc: 0.9031, Test Acc: 0.9017
Epoch: 020, Train Acc: 0.9052, Test Acc: 0.9017
Epoch: 021, Train Acc: 0.9036, Test Acc: