*Okay. Alright. That's fine.*

\- Drake



Hyperparameter optimization did not go as expected, but we have more exciting things ahead of us: creating our own graph classifier. Let's approach this task with PyTorch Geometric, using this [example](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=mHSP6-RBOqCE) as a reference.

We can break this down into a few steps:
- Convert SMILES data into graph data
- Mini-batch graph data
- Define Graph Neural Network for graph classification
- Train model and evaluate - you know the drill 

Let's get to it 🤖

In [None]:
!pip install deepchem
!pip install 'deepchem[torch]'
!pip install rdkit
!pip install torch_geometric

In [43]:
import deepchem as dc

tasks, datasets, transformers = dc.molnet.load_tox21(featurizer='GraphConv', reload=False)
train_dataset, valid_dataset, test_dataset = datasets



In [44]:
len(train_dataset), len(valid_dataset), len(test_dataset)

(6264, 783, 784)

Note the sizes of our training, validation, and test datasets above. However, we can't operate on this data directly, because the molecular information is stored as a [SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) string (that is, Simplified Molecular-Input Line-Entry System). Try saying that five times fast. In short, the specification encodes the structure of a molecule as a string.

For example,

In [48]:
train_dataset.ids[0]

'CC(O)(P(=O)(O)O)P(=O)(O)O'
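To make the encoding a little less mysterious, here is a toy parser I put together for illustration. It is a sketch that assumes a tiny subset of SMILES (single-letter atoms, parentheses for branches, and the `-`, `=`, `#` bond symbols); the name `toy_smiles_to_edges` is made up, and it is nowhere near a full SMILES implementation (no rings, no aromatic atoms, no two-letter elements).

```python
def toy_smiles_to_edges(smiles):
    """Toy parser: single-letter atoms, '()' branches and '-', '=', '#' bonds only."""
    bond_orders = {"-": 1.0, "=": 2.0, "#": 3.0}
    atoms, edges = [], []
    stack = []    # atoms to return to when a branch closes
    prev = None   # index of the atom the next atom bonds to
    order = 1.0   # bond order of the next bond (single by default)
    for ch in smiles:
        if ch == "(":
            stack.append(prev)          # remember where the branch starts
        elif ch == ")":
            prev = stack.pop()          # go back to the branch point
        elif ch in bond_orders:
            order = bond_orders[ch]
        elif ch.isalpha() and ch.isupper():
            atoms.append(ch)
            idx = len(atoms) - 1
            if prev is not None:
                edges.append((prev, idx, order))
            prev, order = idx, 1.0
        else:
            raise ValueError(f"unsupported SMILES character: {ch!r}")
    return atoms, edges


atoms, edges = toy_smiles_to_edges("CC(O)(P(=O)(O)O)P(=O)(O)O")
print(len(atoms), "atoms,", len(edges), "bonds")
print(edges[:3])
```

Each atom becomes a node and each bond becomes an edge, which is exactly the shape of data a graph neural network wants.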

Fortunately, open-source is once again our savior: [RDKit](https://www.rdkit.org/) supplies a few helper functions to convert SMILES strings into graph data that we *can* use with PyTorch Geometric. From there, we can extract additional metadata to help understand the numbers.

In [49]:
import torch
from torch_geometric.data import Data
from rdkit import Chem

def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyTorch Geometric Data object."""
    mol = Chem.MolFromSmiles(smiles)

    # RDKit returns None for SMILES strings it cannot parse
    if mol is None:
        return None

    # Node features: one atomic number per atom
    atom_features = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    # Each bond is stored once, as (begin atom index, end atom index)
    edge_indices = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()]
    # Edge features: bond order (1.0 single, 1.5 aromatic, 2.0 double, 3.0 triple)
    bond_types = [bond.GetBondTypeAsDouble() for bond in mol.GetBonds()]

    x = torch.tensor(atom_features, dtype=torch.float).view(-1, 1)
    # view(2, -1) keeps the [2, num_edges] shape even for molecules without bonds
    edge_index = torch.tensor(list(zip(*edge_indices)), dtype=torch.long).view(2, -1)
    edge_attr = torch.tensor(bond_types, dtype=torch.float).view(-1, 1)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [53]:
data = smiles_to_graph(train_dataset.ids[0])
data

Data(x=[11, 1], edge_index=[2, 10], edge_attr=[10, 1])

If you are anything like me 15 minutes ago, you may be underwhelmed and/or perplexed by the above numbers. A little clarification on the `Data` object might help:

1. **x (Node Feature Matrix):**
   - This parameter represents the feature matrix for each node in the graph.
   - It is a PyTorch tensor with shape [num_nodes, num_node_features].
   - Each row corresponds to a node, and each column corresponds to a feature of that node.
   - For example, if you are representing atoms in a molecule, `x` could contain features like atomic number, charge, etc.


2. **edge_index (Graph Connectivity):**
   - `edge_index` represents the graph connectivity in COO (Coordinate List) format.
   - It is a PyTorch tensor with shape [2, num_edges].
   - Each column of `edge_index` contains the indices of two nodes that form an edge.
   - For an undirected graph, (i, j) and (j, i) should both be present in the `edge_index`.


3. **edge_attr (Edge Feature Matrix):**
   - `edge_attr` represents the feature matrix for each edge in the graph.
   - It is a PyTorch tensor with shape [num_edges, num_edge_features].
   - Each row corresponds to an edge, and each column corresponds to a feature of that edge.
   - This is often used to store information like bond types, distances, or any other edge-specific features.

While there are a few other parameters, these three collectively provide a comprehensive representation of the graph needed for our specific use case.
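Here is a tiny hand-built example of those three fields, using plain Python lists as stand-ins for tensors. It also shows the "both directions" point from the `edge_index` notes: our `smiles_to_graph` stores each bond once (10 bonds, 10 columns), so a helper like the hypothetical `mirror_edges` below would be needed to make the graph undirected. PyTorch Geometric ships `torch_geometric.utils.to_undirected` for real tensors; this is only a sketch of the idea.

```python
def mirror_edges(edges, edge_attrs):
    """Store every (i, j) edge as both (i, j) and (j, i), copying its attribute."""
    sources, targets, attrs = [], [], []
    for (i, j), attr in zip(edges, edge_attrs):
        sources += [i, j]
        targets += [j, i]
        attrs += [attr, attr]
    return [sources, targets], attrs


# Toy molecule with SMILES "CC=O": two carbons and an oxygen
x = [[6.0], [6.0], [8.0]]       # shape [num_nodes, num_node_features] = [3, 1]
bonds = [(0, 1), (1, 2)]        # each bond listed once
bond_orders = [[1.0], [2.0]]    # single bond, then double bond

edge_index, edge_attr = mirror_edges(bonds, bond_orders)
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(edge_attr)   # [[1.0], [1.0], [2.0], [2.0]]
```

Note that `edge_index` has two rows and one column per directed edge, while `edge_attr` has one row per directed edge.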

Let's convert all our datasets from SMILES to graph data.

In [62]:
def to_graphs(dataset):
    # Convert every SMILES string in a DeepChem dataset into a Data object
    return [smiles_to_graph(smiles) for smiles in dataset.ids]

train_dataset_graph = to_graphs(train_dataset)
valid_dataset_graph = to_graphs(valid_dataset)
test_dataset_graph = to_graphs(test_dataset)



Convert SMILES data into graph data ✅

Next step: **Mini-batching**.

Instead of "stacking" equally-sized matrices into a single mini-batch, as we may have done with image data, we take an alternative approach with graph data. "Why overcomplicate things?" you may ask. According to PyTorch Geometric [documentation](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=0gZ-l0npPIca):

1. GNN operators that rely on a **message passing scheme** (more on this later) do not need to be modified since messages are not exchanged between two nodes that belong to different graphs

2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries (*i.e.*, the edges)

PyTorch Geometric automatically takes care of **batching multiple graphs into a single giant graph** with the help of the [`torch_geometric.loader.DataLoader`](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.DataLoader) class:

In [65]:
from torch_geometric.loader import DataLoader
train_loader = DataLoader(train_dataset_graph, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset_graph, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset_graph, batch_size=64, shuffle=False)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(x=[1085, 1], edge_index=[2, 1101], edge_attr=[1101, 1], batch=[1085], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(x=[1131, 1], edge_index=[2, 1148], edge_attr=[1148, 1], batch=[1131], ptr=[65])

Step 3:
Number of graphs in the current batch: 64
DataBatch(x=[1079, 1], edge_index=[2, 1108], edge_attr=[1108, 1], batch=[1079], ptr=[65])

Step 4:
Number of graphs in the current batch: 64
DataBatch(x=[1015, 1], edge_index=[2, 1048], edge_attr=[1048, 1], batch=[1015], ptr=[65])

Step 5:
Number of graphs in the current batch: 64
DataBatch(x=[1182, 1], edge_index=[2, 1221], edge_attr=[1221, 1], batch=[1182], ptr=[65])

Step 6:
Number of graphs in the current batch: 64
DataBatch(x=[1043, 1], edge_index=[2, 1077], edge_attr=[1077, 1], batch=[1043], ptr=[65])

Step 7:
Number of graphs in the current batch: 64
DataBatch(x=[965, 1], edge_index=[2, 986], edge_attr=[986, 1], batch=[965], ptr=[65])

Step 8:
...
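The `batch` and `ptr` fields in that output are what hold the "giant graph" together. Below is a library-free sketch of the idea, assuming each graph is a plain dict of lists; the function name `collate_graphs` is hypothetical and this is not PyTorch Geometric's actual implementation. Node features are concatenated, edge indices are shifted by the number of nodes that came before, `batch` records which graph each node belongs to, and `ptr` marks where each graph starts and ends.

```python
def collate_graphs(graphs):
    """Merge small graphs into one big disconnected graph.

    Each graph is a dict with 'x' (list of node feature rows) and
    'edge_index' ([sources, targets]).
    """
    x, sources, targets, batch, ptr = [], [], [], [], [0]
    for graph_id, graph in enumerate(graphs):
        offset = ptr[-1]                      # nodes already in the big graph
        num_nodes = len(graph["x"])
        x.extend(graph["x"])
        sources.extend(i + offset for i in graph["edge_index"][0])
        targets.extend(j + offset for j in graph["edge_index"][1])
        batch.extend([graph_id] * num_nodes)  # node -> graph assignment
        ptr.append(offset + num_nodes)        # cumulative node count
    return {"x": x, "edge_index": [sources, targets], "batch": batch, "ptr": ptr}


g1 = {"x": [[6.0], [6.0], [8.0]], "edge_index": [[0, 1], [1, 2]]}
g2 = {"x": [[7.0], [6.0]], "edge_index": [[0], [1]]}

merged = collate_graphs([g1, g2])
print(merged["edge_index"])  # [[0, 1, 3], [1, 2, 4]]
print(merged["batch"])       # [0, 0, 0, 1, 1]
print(merged["ptr"])         # [0, 3, 5]
```

Because no edge crosses from one graph to another, messages never leak between molecules. It also explains the shapes above: `batch` has one entry per node, and `ptr` has one more entry than there are graphs, hence `ptr=[65]` for a batch of 64.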

Mini-batch data ✅

Next step: Graph. Neural. Networks.

FINALLY. Let's make like Merriam-Webster and define what our GNN is going to look like.

But before we do that, let's quickly review message-passing (you're welcome, future me).