*Okay. Alright. That's fine.*

\- Drake



Hyperparameter optimization did not go as expected, but we have more exciting things ahead of us: creating our own graph classifier. Let's approach this task with PyTorch Geometric, using this [example](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=mHSP6-RBOqCE) as a reference.

We can break this down into a few steps:
- Convert SMILES data into graph data
- Mini-batch graph data
- Define Graph Neural Network for graph classification
- Train model and evaluate - you know the drill 

Let's get to it 🤖

In [1]:
!pip install deepchem
!pip install 'deepchem[torch]'
!pip install rdkit
!pip install torch_geometric

In [2]:
import deepchem as dc

tasks, datasets, transformers = dc.molnet.load_tox21(featurizer='GraphConv', reload=False)
train_dataset, valid_dataset, test_dataset = datasets


Let's take a close look at this dataset by inspecting the training dataset.

In [3]:
print(train_dataset.X.shape)

(6264,)


It looks like there are **6264** graphs (molecules) in our training dataset, where *X* stores the molecules as ConvMol objects, *y* stores a binary label vector with one entry for each of the **12** measures/tasks, and *w* stores a matrix of per-sample weights.

Here's one particular molecule:

In [4]:
test_mol = train_dataset.X[0]
print(f"This molecule has {test_mol.n_atoms} atoms with {test_mol.n_feat} features each.")

This molecule has 11 atoms with 75 features each.


In [5]:
len(train_dataset), len(valid_dataset), len(test_dataset)

(6264, 783, 784)

Note the sizes of our training, validation, and test datasets above. However, we can't hand this data to PyTorch Geometric directly. Each molecule is identified by a [SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) string (that is, simplified molecular-input line-entry system; try saying that five times fast) and featurized as a DeepChem ConvMol object, and neither is the graph format PyTorch Geometric expects. In short, the SMILES specification encodes a molecule's structure as a string.

For example,

In [6]:
train_dataset.ids[0]

'CC(O)(P(=O)(O)O)P(=O)(O)O'

Fortunately, open source is once again our savior: [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.spmatrix.html) supplies sparse matrix classes, and PyTorch Geometric supplies a helper (`from_scipy_sparse_matrix`) that converts a sparse adjacency matrix into graph data that we *can* use. From there, we can extract additional metadata to help understand the numbers.

In [7]:
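import scipy.sparse
import numpy as np

def adjacency_list_to_sparse(adj_list):
    """Convert an adjacency list (neighbors of each atom) into a sparse COO adjacency matrix."""
    num_nodes = len(adj_list)
    rows, cols = [], []

    # Each neighbor j listed for atom i becomes one directed entry (i, j).
    for i, neighbors in enumerate(adj_list):
        rows.extend([i] * len(neighbors))
        cols.extend(neighbors)

    # Every stored entry is a 1, marking that the edge exists.
    adjacency_matrix = scipy.sparse.coo_matrix((np.ones_like(rows), (rows, cols)),
                                              shape=(num_nodes, num_nodes),
                                              dtype=np.float32)

    return adjacency_matrix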
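To see what this conversion produces, here is a plain-Python sketch (no SciPy or PyTorch needed) of the COO layout that ends up in `edge_index`: one row of source indices and one row of target indices. `adjacency_list_to_coo` is a hypothetical helper for illustration, and the three-atom chain is a made-up toy rather than a Tox21 molecule. When every bond is listed from both ends, each bond yields two directed edges, which is consistent with the 11-atom, acyclic molecule above (10 bonds) showing up later as `edge_index=[2, 20]`.

```python
def adjacency_list_to_coo(adj_list):
    """Hypothetical helper: flatten an adjacency list into COO (row, col) index lists."""
    rows, cols = [], []
    for i, neighbors in enumerate(adj_list):
        for j in neighbors:
            rows.append(i)  # source node
            cols.append(j)  # target node
    return [rows, cols]  # shape [2, num_edges], the same layout as PyG's edge_index


# Made-up three-atom chain 0-1-2: each bond appears once in each endpoint's list.
toy_adj = [[1], [0, 2], [1]]
edge_index = adjacency_list_to_coo(toy_adj)
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```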


In [8]:
import torch
from torch_geometric.data import Data
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.utils.convert import from_scipy_sparse_matrix


def to_data(mol_graph):

    x = mol_graph.get_atom_features()

    adj_list = mol_graph.get_adjacency_list()
    sparse_mat = adjacency_list_to_sparse(adj_list)
    edge_index, edge_attr = from_scipy_sparse_matrix(sparse_mat)
    
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

A little clarification on the `Data` object might help:

1. **x (Node Feature Matrix):**
   - This parameter represents the feature matrix for each node in the graph.
   - It is a PyTorch tensor with shape [num_nodes, num_node_features].
   - Each row corresponds to a node, and each column corresponds to a feature of that node.
   - For example, if you are representing atoms in a molecule, `x` could contain features like atomic number, charge, etc.


2. **edge_index (Graph Connectivity):**
   - `edge_index` represents the graph connectivity in COO (Coordinate List) format.
   - It is a PyTorch tensor with shape [2, num_edges].
   - Each column of `edge_index` contains the indices of two nodes that form an edge.
   - For an undirected graph, (i, j) and (j, i) should both be present in the `edge_index`.


3. **edge_attr (Edge Feature Matrix):**
   - `edge_attr` represents the feature matrix for each edge in the graph.
   - It is a PyTorch tensor with shape [num_edges, num_edge_features].
   - Each row corresponds to an edge, and each column corresponds to a feature of that edge.
   - This is often used to store information like bond types, distances, or any other edge-specific features.
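Points 2 and 3 interact: in an undirected molecule each bond is stored twice, so its edge features must be repeated too, keeping row k of `edge_attr` aligned with column k of `edge_index`. A minimal plain-Python sketch, assuming a hypothetical bond list and a made-up one-hot bond-type feature (`undirected_edges` is an illustrative helper, not a PyG or RDKit function):

```python
def undirected_edges(bonds, bond_features):
    """Hypothetical helper: expand undirected bonds into directed edges with aligned features.

    bonds: list of (i, j) atom-index pairs, each bond listed once.
    bond_features: one feature list per bond, in the same order.
    """
    src, dst, edge_attr = [], [], []
    for (i, j), feat in zip(bonds, bond_features):
        # Store both directions so messages can flow i -> j and j -> i.
        src += [i, j]
        dst += [j, i]
        # Duplicate the feature row so edge_attr[k] describes edge (src[k], dst[k]).
        edge_attr += [feat, feat]
    return [src, dst], edge_attr


# Made-up C=C-O fragment: bond 0-1 is double, bond 1-2 is single.
# Features are a made-up one-hot [is_single, is_double].
edge_index, edge_attr = undirected_edges([(0, 1), (1, 2)], [[0, 1], [1, 0]])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(edge_attr)   # [[0, 1], [0, 1], [1, 0], [1, 0]]
```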

While there are a few other parameters, these three collectively provide a comprehensive representation of the graph needed for our specific use case.

Let's convert all our datasets from ConvMol to Data objects.

In [9]:
train_dataset_graph = []
for mol_graph in train_dataset.X:
    data = to_data(mol_graph)
    train_dataset_graph.append(data)

valid_dataset_graph = []
for mol_graph in valid_dataset.X:
    data = to_data(mol_graph)
    valid_dataset_graph.append(data)

test_dataset_graph = []
for mol_graph in test_dataset.X:
    data = to_data(mol_graph)
    test_dataset_graph.append(data)

In [10]:
train_dataset_graph[0]

Data(x=[11, 75], edge_index=[2, 20], edge_attr=[20])

In [11]:
from torch_geometric.loader import DataLoader
train_loader = DataLoader(train_dataset_graph, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset_graph, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset_graph, batch_size=64, shuffle=False)

In [12]:
import itertools
for step, data in enumerate(itertools.islice(train_loader, 5)):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 2208], edge_attr=[2208], batch=[1078], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 1876], edge_attr=[1876], batch=[948], ptr=[65])

Step 3:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 2260], edge_attr=[2260], batch=[1109], ptr=[65])

Step 4:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 2404], edge_attr=[2404], batch=[1166], ptr=[65])

Step 5:
Number of graphs in the current batch: 64
DataBatch(x=[64], edge_index=[2, 2122], edge_attr=[2122], batch=[1052], ptr=[65])



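Hmm. Slight issue with the x parameter: instead of a single [num_nodes, 75] feature matrix, each batch shows `x=[64]`, one entry per molecule. No clear solution is popping out to me at the moment, though. As my friend Dan likes to say, "Let's circle back to this."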
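A likely culprit (a guess, not verified here): `get_atom_features()` returns a NumPy array rather than a `torch.Tensor`. PyTorch Geometric's batching concatenates tensor attributes along the node dimension but appears to fall back to collecting other objects into a Python list, which would explain `x=[64]`. What we want is for the per-molecule feature matrices to be stacked row-wise. The NumPy sketch below shows that target shape with two made-up molecules of 11 and 5 atoms; wrapping the features with `torch.tensor(..., dtype=torch.float)` inside `to_data` should let the real `DataLoader` do this.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two made-up molecules with 11 and 5 atoms, 75 features per atom
# (75 matches the ConvMol feature width printed earlier).
x_a = rng.random((11, 75), dtype=np.float32)
x_b = rng.random((5, 75), dtype=np.float32)

# Desired batched node-feature matrix: rows of all graphs stacked together.
x_batch = np.concatenate([x_a, x_b], axis=0)
print(x_batch.shape)  # (16, 75)

# The batch vector records which graph each stacked row came from.
batch = np.repeat(np.arange(2), [x_a.shape[0], x_b.shape[0]])
print(np.bincount(batch).tolist())  # [11, 5]
```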

Also, small (*very* consequential) update: after spending 2 days trying to perfectly preprocess this data, I discovered that PyTorch Geometric supplies its own MoleculeNet class, complete with a Tox21 dataset 🥲.

Despite how much my RAM has suffered due to tabs upon tabs of documentation, I'd say that this process helped clarify **how** and **why** we preprocess our data. Now let's condense yesterday's work into three lines!

In [13]:
from torch_geometric.datasets import MoleculeNet
dataset = MoleculeNet(root="./data/", name="Tox21")
dataset.data



Data(x=[145459, 9], edge_index=[2, 302190], edge_attr=[302190, 3], smiles=[7831], y=[7831, 12])

In [14]:
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


Dataset: Tox21(7831):
Number of graphs: 7831
Number of features: 9
Number of classes: 12

Data(x=[16, 9], edge_index=[2, 34], edge_attr=[34, 3], smiles='CCOc1ccc2nc(S(N)(=O)=O)sc2c1', y=[1, 12])
Number of nodes: 16
Number of edges: 34
Average node degree: 2.12
Has isolated nodes: False
Has self-loops: False
Is undirected: True
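These statistics are simple functions of `edge_index`. Below is a plain-Python sketch of roughly what those `Data` checks compute; it is an illustrative re-implementation (the `graph_stats` helper is hypothetical, not PyG's source), applied to a made-up 4-node path graph. Because each bond is stored in both directions, `num_edges / num_nodes` works out to the average number of bonds per atom.

```python
def graph_stats(edge_index, num_nodes):
    """Hypothetical helper mirroring the Data checks printed above."""
    src, dst = edge_index
    edges = set(zip(src, dst))
    touched = set(src) | set(dst)
    return {
        # Each undirected bond is stored twice, so this is the mean bond count per atom.
        "avg_degree": len(src) / num_nodes,
        # An isolated node never appears in any edge.
        "has_isolated_nodes": len(touched) < num_nodes,
        "has_self_loops": any(i == j for i, j in edges),
        # Undirected means every edge (i, j) has a matching reverse edge (j, i).
        "is_undirected": all((j, i) in edges for i, j in edges),
    }


# Made-up path graph 0-1-2-3 with every bond stored in both directions.
edge_index = [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]
print(graph_stats(edge_index, num_nodes=4))
# {'avg_degree': 1.5, 'has_isolated_nodes': False, 'has_self_loops': False, 'is_undirected': True}
```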


There are 12 targets to predict, which is 11 too many to do at once, so let's select NR-AhR (the third column, index 2) as the target. Some rows in our dataset have a `NaN` value for this task, so we must filter them out:

In [15]:
dataset[0].y[0]

tensor([0., 0., 1., nan, nan, 0., 0., 1., 0., 0., 0., 0.])
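The filter keeps a molecule only if its label in the chosen column is a real number. In plain Python the idea looks like this; the label rows are made up, and index 2 is NR-AhR's zero-based position, matching `column_index_to_check` below. Since NaN never compares equal to anything, itself included, an explicit check such as `math.isnan` (or `torch.isnan` on tensors) is needed rather than `== nan`.

```python
import math

nan = float("nan")
column_index_to_check = 2  # zero-based index of the NR-AhR column

# Made-up label rows (fewer tasks than Tox21's 12, for brevity).
labels = [
    [0.0, 0.0, 1.0, nan],
    [0.0, 1.0, nan, 0.0],  # NR-AhR missing, so this row is dropped
    [1.0, 0.0, 0.0, 0.0],
]

kept = [row for row in labels if not math.isnan(row[column_index_to_check])]
print(len(kept))   # 2
print(nan == nan)  # False, which is why an explicit NaN check is needed
```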

Some datatype manipulation so we don't run into issues downstream!

In [16]:
dataset.y = dataset.y.long()

In [17]:
dataset = dataset.shuffle()  # shuffle() returns a shuffled copy rather than shuffling in place
column_index_to_check = 2
filtered_data_list = [data for data in dataset if not torch.isnan(data.y[0, column_index_to_check]).any()]

***KEEP AN EYE ON THIS METRIC - a crucial bit of foreshadowing***

In [18]:
print("Number of 0s:", list(dataset.data.y[:, column_index_to_check]).count(0))
print("Number of 1s: ", list(dataset.data.y[:, column_index_to_check]).count(1))

Number of 0s: 5781
Number of 1s:  768




In [19]:
split = int(len(filtered_data_list) * 0.80)
train_dataset, test_dataset = filtered_data_list[:split], filtered_data_list[split:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

Number of training graphs: 5239
Number of test graphs: 1310


Now we're cooking with this dataset: shuffled, filtered, and split. We can finally move on...

Next step: **Mini-batching**.

Instead of "stacking" equally-sized matrices into a single mini-batch, as we may have done with image data, we take an alternative approach with graph data: adjacency matrices are stacked diagonally, so a mini-batch becomes one large graph made of disconnected subgraphs. "Why overcomplicate things?" you may ask. According to the PyTorch Geometric [tutorial](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=0gZ-l0npPIca):

1. GNN operators that rely on a **message passing scheme** (more on this later) do not need to be modified since messages are not exchanged between two nodes that belong to different graphs

2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries (*i.e.*, the edges)
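Both points follow from how the giant graph is built. Here is a plain-Python sketch of the idea, a simplified re-implementation for illustration (the `collate` helper is hypothetical, not PyG's actual code): each graph's `edge_index` is shifted by the number of nodes placed before it, and two bookkeeping vectors are recorded, `batch` (graph id per node) and `ptr` (where each graph's nodes start). This matches the shape of the batches printed in the next cell, where `batch` is as long as `x` and `ptr` has 65 entries for 64 graphs.

```python
def collate(graphs):
    """Hypothetical helper: merge graphs given as (num_nodes, edge_index) into one disconnected graph."""
    src, dst, batch, ptr = [], [], [], [0]
    for graph_id, (num_nodes, (g_src, g_dst)) in enumerate(graphs):
        offset = ptr[-1]  # nodes already placed by earlier graphs
        # Shift indices so this graph's nodes don't collide with earlier ones;
        # no edge ever links two different graphs.
        src += [i + offset for i in g_src]
        dst += [j + offset for j in g_dst]
        batch += [graph_id] * num_nodes
        ptr.append(offset + num_nodes)
    return [src, dst], batch, ptr


# Two made-up graphs: a 2-node edge and a 3-node path, both stored undirected.
g1 = (2, ([0, 1], [1, 0]))
g2 = (3, ([0, 1, 1, 2], [1, 0, 2, 1]))
edge_index, batch, ptr = collate([g1, g2])
print(edge_index)  # [[0, 1, 2, 3, 3, 4], [1, 0, 3, 2, 4, 3]]
print(batch)       # [0, 0, 1, 1, 1]
print(ptr)         # [0, 2, 5]
```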

PyTorch Geometric automatically takes care of **batching multiple graphs into a single giant graph** with the help of the [`torch_geometric.loader.DataLoader`](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.DataLoader) class:

In [20]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

for step, data in enumerate(itertools.islice(train_loader, 5)):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(x=[1207, 9], edge_index=[2, 2500], edge_attr=[2500, 3], smiles=[64], y=[64, 12], batch=[1207], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(x=[1370, 9], edge_index=[2, 2880], edge_attr=[2880, 3], smiles=[64], y=[64, 12], batch=[1370], ptr=[65])

Step 3:
Number of graphs in the current batch: 64
DataBatch(x=[1172, 9], edge_index=[2, 2440], edge_attr=[2440, 3], smiles=[64], y=[64, 12], batch=[1172], ptr=[65])

Step 4:
Number of graphs in the current batch: 64
DataBatch(x=[1073, 9], edge_index=[2, 2220], edge_attr=[2220, 3], smiles=[64], y=[64, 12], batch=[1073], ptr=[65])

Step 5:
Number of graphs in the current batch: 64
DataBatch(x=[1138, 9], edge_index=[2, 2370], edge_attr=[2370, 3], smiles=[64], y=[64, 12], batch=[1138], ptr=[65])



Mini-batch data ✅

Next step: Graph. Neural. Networks.

FINALLY. Let's make like Merriam-Webster and define what our GNN is going to look like.

Thank you to the wonderful people at [PyTorch](https://www.learnpytorch.io/02_pytorch_classification/) and [PyTorch Geometric](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=mHSP6-RBOqCE) for these thoroughly-documented references (they low-key carried). Take a look at them if you haven't already!

In [21]:
from torch_geometric.nn import GraphConv
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool


class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GraphConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 1) # One raw logit per graph; sigmoid + rounding later turn it into a 0/1 prediction

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)

        x = x.to(self.lin.weight.dtype)
        x = self.lin(x)
        
        return x

model = GNN(hidden_channels=64)
print(model)

GNN(
  (conv1): GraphConv(9, 64)
  (conv2): GraphConv(64, 64)
  (conv3): GraphConv(64, 64)
  (lin): Linear(in_features=64, out_features=1, bias=True)
)
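For intuition about what each layer does: per PyG's documentation, `GraphConv` computes $x_i' = W_1 x_i + W_2 \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot x_j$ (sum aggregation by default), and `global_mean_pool` averages each graph's node embeddings using the `batch` vector. The NumPy sketch below re-implements one layer plus the pooling with made-up identity weights, ignoring edge weights and the bias, just to show the data flow; it is not PyG's code.

```python
import numpy as np


def graph_conv(x, edge_index, w_root, w_rel):
    """GraphConv-style layer: self term plus summed neighbor features (no bias, no edge weights)."""
    src, dst = edge_index
    agg = np.zeros_like(x)
    # Each edge (src -> dst) adds the source node's features to the target's sum.
    np.add.at(agg, dst, x[src])
    return x @ w_root.T + agg @ w_rel.T


def global_mean_pool(x, batch, num_graphs):
    """Average node embeddings per graph, using the batch vector."""
    sums = np.zeros((num_graphs, x.shape[1]))
    np.add.at(sums, batch, x)
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return sums / counts


# Made-up batch: graph 0 is nodes 0-1 (one bond), graph 1 is node 2 alone.
x = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edge_index = np.array([[0, 1], [1, 0]])
batch = np.array([0, 0, 1])
identity = np.eye(2)

h = graph_conv(x, edge_index, identity, identity)
print(h.tolist())                              # [[1.0, 1.0], [1.0, 1.0], [2.0, 2.0]]
print(global_mean_pool(h, batch, 2).tolist())  # [[1.0, 1.0], [2.0, 2.0]]
```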


Define Graph Neural Network for graph classification ✅


And finally, the code that we'll use to train and test our model.

In [22]:
def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item() # torch.eq() calculates where two tensors are equal
    acc = (correct / len(y_pred)) * 100 
    return acc

In [23]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.BCEWithLogitsLoss()


def train(loader):
    model.train()

    train_loss, train_acc = 0, 0
    for data in loader:  # Iterate in batches over the training dataset.

        # 1. Forward pass (model outputs raw logits, shape [batch_size, 1] -> [batch_size])
        y_logits = model(data.x.to(torch.float32), data.edge_index, data.batch).squeeze(1)
        y_preds = torch.round(torch.sigmoid(y_logits))
        y_train = data.y[:, column_index_to_check].float()  # extract target column; BCE needs float targets

        # 2. Calculate loss/accuracy
        loss = loss_fn(y_logits, y_train)
        train_loss += loss.item()  # .item() so we don't keep every batch's autograd graph alive
        train_acc += accuracy_fn(y_true=y_train, y_pred=y_preds)

        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backwards
        loss.backward()

        # 5. Optimizer step
        optimizer.step()
    train_loss /= len(loader)
    train_acc /= len(loader)
    print(f"Train loss: {train_loss:.5f} | Train accuracy: {train_acc:.2f}%")


@torch.no_grad()  # no gradients are needed during evaluation
def test(loader):
    model.eval()

    test_loss, test_acc = 0, 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        test_logits = model(data.x.to(torch.float32), data.edge_index, data.batch).squeeze(1)
        test_preds = torch.round(torch.sigmoid(test_logits))
        labels = data.y[:, column_index_to_check].float()  # extract target column

        test_loss += loss_fn(test_logits, labels).item()
        test_acc += accuracy_fn(y_true=labels, y_pred=test_preds)
    test_loss /= len(loader)
    test_acc /= len(loader)
    print(f"Test loss: {test_loss:.5f} | Test accuracy: {test_acc:.2f}%\n")
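Two details in `train`/`test` are worth unpacking. `BCEWithLogitsLoss` folds the sigmoid into the loss; PyTorch's docs describe this as more numerically stable than a plain `Sigmoid` followed by `BCELoss`, and the usual stable form for a logit $z$ and label $y$ is $\max(z, 0) - z y + \log(1 + e^{-|z|})$. Also, rounding `sigmoid(z)` at 0.5 amounts to checking the sign of the logit. A plain-Python sketch of both, with made-up logits and labels (`bce_with_logits` and `naive_bce` are hypothetical helpers, not PyTorch functions):

```python
import math


def bce_with_logits(z, y):
    """Binary cross-entropy on a raw logit z, in the numerically stable form."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def naive_bce(z, y):
    """Same loss via an explicit sigmoid; fine here, but breaks down for large |z|."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


logits = [-2.0, -0.3, 0.4, 3.0]
labels = [0.0, 1.0, 1.0, 0.0]

losses = [bce_with_logits(z, y) for z, y in zip(logits, labels)]
mean_loss = sum(losses) / len(losses)  # BCEWithLogitsLoss averages over the batch by default

# Thresholding sigmoid(z) at 0.5 is equivalent to checking whether the logit is positive.
preds = [1.0 if z > 0 else 0.0 for z in logits]
print(preds)  # [0.0, 0.0, 1.0, 1.0]
```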



In [24]:
from tqdm.auto import tqdm

epochs = 100
for epoch in tqdm(range(epochs)):
    train(train_loader)
    test(test_loader)

  1%|          | 1/100 [00:00<00:58,  1.68it/s]

Train loss: 0.46709 | Train accuracy: 86.55%
Test loss: 0.30009 | Test accuracy: 88.61%



  2%|▏         | 2/100 [00:01<00:50,  1.94it/s]

Train loss: 0.32101 | Train accuracy: 88.04%
Test loss: 0.29506 | Test accuracy: 88.61%



  3%|▎         | 3/100 [00:01<00:50,  1.93it/s]

Train loss: 0.30924 | Train accuracy: 88.25%
Test loss: 0.31156 | Test accuracy: 88.61%



  4%|▍         | 4/100 [00:02<00:49,  1.95it/s]

Train loss: 0.30842 | Train accuracy: 88.24%
Test loss: 0.29471 | Test accuracy: 88.61%



  5%|▌         | 5/100 [00:02<00:49,  1.90it/s]

Train loss: 0.30941 | Train accuracy: 88.25%
Test loss: 0.29283 | Test accuracy: 88.61%



  6%|▌         | 6/100 [00:03<00:47,  1.99it/s]

Train loss: 0.30037 | Train accuracy: 88.25%
Test loss: 0.28817 | Test accuracy: 88.61%



  7%|▋         | 7/100 [00:03<00:47,  1.96it/s]

Train loss: 0.30182 | Train accuracy: 88.24%
Test loss: 0.29172 | Test accuracy: 88.61%



  8%|▊         | 8/100 [00:04<00:46,  1.97it/s]

Train loss: 0.30725 | Train accuracy: 88.24%
Test loss: 0.28652 | Test accuracy: 88.61%



  9%|▉         | 9/100 [00:04<00:49,  1.84it/s]

Train loss: 0.29302 | Train accuracy: 88.34%
Test loss: 0.28141 | Test accuracy: 89.28%



 10%|█         | 10/100 [00:05<00:48,  1.86it/s]

Train loss: 0.29313 | Train accuracy: 88.46%
Test loss: 0.28418 | Test accuracy: 89.26%



 11%|█         | 11/100 [00:05<00:50,  1.76it/s]

Train loss: 0.29138 | Train accuracy: 88.13%
Test loss: 0.28410 | Test accuracy: 88.61%



 12%|█▏        | 12/100 [00:06<00:46,  1.91it/s]

Train loss: 0.29365 | Train accuracy: 88.14%
Test loss: 0.27985 | Test accuracy: 89.57%



 13%|█▎        | 13/100 [00:06<00:45,  1.93it/s]

Train loss: 0.28542 | Train accuracy: 88.39%
Test loss: 0.29385 | Test accuracy: 88.61%



 14%|█▍        | 14/100 [00:07<00:44,  1.92it/s]

Train loss: 0.28814 | Train accuracy: 88.47%
Test loss: 0.28579 | Test accuracy: 88.75%



 15%|█▌        | 15/100 [00:08<00:48,  1.77it/s]

Train loss: 0.29431 | Train accuracy: 88.16%
Test loss: 0.28014 | Test accuracy: 88.61%



 16%|█▌        | 16/100 [00:08<00:45,  1.86it/s]

Train loss: 0.29337 | Train accuracy: 88.31%
Test loss: 0.28316 | Test accuracy: 88.90%



 17%|█▋        | 17/100 [00:08<00:42,  1.93it/s]

Train loss: 0.28476 | Train accuracy: 88.33%
Test loss: 0.27442 | Test accuracy: 88.90%



 18%|█▊        | 18/100 [00:09<00:42,  1.95it/s]

Train loss: 0.28628 | Train accuracy: 88.66%
Test loss: 0.27424 | Test accuracy: 88.90%



 19%|█▉        | 19/100 [00:09<00:40,  1.99it/s]

Train loss: 0.28955 | Train accuracy: 88.57%
Test loss: 0.26551 | Test accuracy: 89.20%



 20%|██        | 20/100 [00:10<00:37,  2.11it/s]

Train loss: 0.28256 | Train accuracy: 88.53%
Test loss: 0.27328 | Test accuracy: 89.20%



 21%|██        | 21/100 [00:10<00:35,  2.21it/s]

Train loss: 0.28081 | Train accuracy: 88.73%
Test loss: 0.28542 | Test accuracy: 88.83%



 22%|██▏       | 22/100 [00:11<00:37,  2.09it/s]

Train loss: 0.28357 | Train accuracy: 88.18%
Test loss: 0.27393 | Test accuracy: 88.61%



 23%|██▎       | 23/100 [00:11<00:37,  2.04it/s]

Train loss: 0.28436 | Train accuracy: 88.53%
Test loss: 0.27294 | Test accuracy: 89.57%



 24%|██▍       | 24/100 [00:12<00:35,  2.14it/s]

Train loss: 0.28507 | Train accuracy: 88.78%
Test loss: 0.31944 | Test accuracy: 88.34%



 25%|██▌       | 25/100 [00:12<00:33,  2.23it/s]

Train loss: 0.29816 | Train accuracy: 88.32%
Test loss: 0.30190 | Test accuracy: 88.61%



 26%|██▌       | 26/100 [00:13<00:33,  2.18it/s]

Train loss: 0.27892 | Train accuracy: 88.55%
Test loss: 0.29413 | Test accuracy: 89.80%



 27%|██▋       | 27/100 [00:13<00:34,  2.10it/s]

Train loss: 0.27929 | Train accuracy: 88.69%
Test loss: 0.26694 | Test accuracy: 89.57%



 28%|██▊       | 28/100 [00:14<00:34,  2.07it/s]

Train loss: 0.27746 | Train accuracy: 88.82%
Test loss: 0.27373 | Test accuracy: 90.23%



 29%|██▉       | 29/100 [00:14<00:34,  2.08it/s]

Train loss: 0.28101 | Train accuracy: 88.38%
Test loss: 0.26796 | Test accuracy: 89.87%



 30%|███       | 30/100 [00:15<00:33,  2.08it/s]

Train loss: 0.27827 | Train accuracy: 88.87%
Test loss: 0.26674 | Test accuracy: 89.42%



 31%|███       | 31/100 [00:15<00:32,  2.14it/s]

Train loss: 0.28963 | Train accuracy: 88.47%
Test loss: 0.26934 | Test accuracy: 89.80%



 32%|███▏      | 32/100 [00:15<00:31,  2.17it/s]

Train loss: 0.28173 | Train accuracy: 88.82%
Test loss: 0.27876 | Test accuracy: 89.33%



 33%|███▎      | 33/100 [00:16<00:29,  2.25it/s]

Train loss: 0.27777 | Train accuracy: 88.84%
Test loss: 0.25772 | Test accuracy: 89.95%



 34%|███▍      | 34/100 [00:16<00:29,  2.27it/s]

Train loss: 0.27536 | Train accuracy: 88.82%
Test loss: 0.27152 | Test accuracy: 88.83%



 35%|███▌      | 35/100 [00:17<00:30,  2.16it/s]

Train loss: 0.27757 | Train accuracy: 89.02%
Test loss: 0.26418 | Test accuracy: 90.17%



 36%|███▌      | 36/100 [00:17<00:33,  1.93it/s]

Train loss: 0.27145 | Train accuracy: 89.04%
Test loss: 0.27573 | Test accuracy: 89.20%



 37%|███▋      | 37/100 [00:18<00:34,  1.85it/s]

Train loss: 0.28028 | Train accuracy: 88.79%
Test loss: 0.27125 | Test accuracy: 89.50%



 38%|███▊      | 38/100 [00:18<00:31,  1.96it/s]

Train loss: 0.27065 | Train accuracy: 88.79%
Test loss: 0.27789 | Test accuracy: 90.09%



 39%|███▉      | 39/100 [00:19<00:29,  2.09it/s]

Train loss: 0.27477 | Train accuracy: 88.95%
Test loss: 0.26854 | Test accuracy: 90.39%



 40%|████      | 40/100 [00:19<00:28,  2.12it/s]

Train loss: 0.27238 | Train accuracy: 89.12%
Test loss: 0.29529 | Test accuracy: 88.88%



 41%|████      | 41/100 [00:20<00:27,  2.15it/s]

Train loss: 0.27191 | Train accuracy: 88.92%
Test loss: 0.26758 | Test accuracy: 90.24%



 42%|████▏     | 42/100 [00:20<00:27,  2.10it/s]

Train loss: 0.27508 | Train accuracy: 88.95%
Test loss: 0.27113 | Test accuracy: 90.39%



 43%|████▎     | 43/100 [00:21<00:26,  2.17it/s]

Train loss: 0.26888 | Train accuracy: 89.24%
Test loss: 0.26417 | Test accuracy: 90.09%



 44%|████▍     | 44/100 [00:21<00:25,  2.18it/s]

Train loss: 0.26999 | Train accuracy: 88.85%
Test loss: 0.26199 | Test accuracy: 89.65%



 45%|████▌     | 45/100 [00:22<00:25,  2.19it/s]

Train loss: 0.27061 | Train accuracy: 88.68%
Test loss: 0.27786 | Test accuracy: 90.09%

Train loss: 0.27112 | Train accuracy: 88.84%
Test loss: 0.26477 | Test accuracy: 90.54%



 47%|████▋     | 47/100 [00:24<00:41,  1.27it/s]

Train loss: 0.26435 | Train accuracy: 88.91%
Test loss: 0.26815 | Test accuracy: 89.87%



 48%|████▊     | 48/100 [00:25<00:39,  1.30it/s]

Train loss: 0.27173 | Train accuracy: 88.86%
Test loss: 0.26941 | Test accuracy: 90.24%



 49%|████▉     | 49/100 [00:25<00:36,  1.41it/s]

Train loss: 0.27647 | Train accuracy: 89.29%
Test loss: 0.26146 | Test accuracy: 89.95%



 50%|█████     | 50/100 [00:26<00:32,  1.54it/s]

Train loss: 0.26515 | Train accuracy: 89.39%
Test loss: 0.27926 | Test accuracy: 88.90%



 51%|█████     | 51/100 [00:26<00:28,  1.70it/s]

Train loss: 0.27248 | Train accuracy: 88.88%
Test loss: 0.27278 | Test accuracy: 89.35%



 52%|█████▏    | 52/100 [00:27<00:26,  1.83it/s]

Train loss: 0.26824 | Train accuracy: 89.05%
Test loss: 0.26793 | Test accuracy: 89.05%



 53%|█████▎    | 53/100 [00:27<00:24,  1.94it/s]

Train loss: 0.27553 | Train accuracy: 88.65%
Test loss: 0.26629 | Test accuracy: 89.80%



 54%|█████▍    | 54/100 [00:27<00:22,  2.06it/s]

Train loss: 0.27759 | Train accuracy: 88.78%
Test loss: 0.27020 | Test accuracy: 89.65%



 55%|█████▌    | 55/100 [00:28<00:21,  2.12it/s]

Train loss: 0.27225 | Train accuracy: 88.86%
Test loss: 0.26552 | Test accuracy: 89.57%



 56%|█████▌    | 56/100 [00:29<00:23,  1.87it/s]

Train loss: 0.27219 | Train accuracy: 88.90%
Test loss: 0.28975 | Test accuracy: 88.68%



 57%|█████▋    | 57/100 [00:29<00:22,  1.88it/s]

Train loss: 0.26829 | Train accuracy: 89.26%
Test loss: 0.26680 | Test accuracy: 89.80%



 58%|█████▊    | 58/100 [00:30<00:22,  1.87it/s]

Train loss: 0.27208 | Train accuracy: 88.88%
Test loss: 0.26531 | Test accuracy: 90.02%



 59%|█████▉    | 59/100 [00:30<00:22,  1.84it/s]

Train loss: 0.26821 | Train accuracy: 89.02%
Test loss: 0.27403 | Test accuracy: 88.98%



 60%|██████    | 60/100 [00:31<00:20,  1.91it/s]

Train loss: 0.26541 | Train accuracy: 89.14%
Test loss: 0.26731 | Test accuracy: 90.02%



 61%|██████    | 61/100 [00:31<00:21,  1.84it/s]

Train loss: 0.26494 | Train accuracy: 89.05%
Test loss: 0.26445 | Test accuracy: 90.24%



 62%|██████▏   | 62/100 [00:32<00:20,  1.86it/s]

Train loss: 0.26416 | Train accuracy: 89.06%
Test loss: 0.26744 | Test accuracy: 90.02%



 63%|██████▎   | 63/100 [00:32<00:19,  1.93it/s]

Train loss: 0.27232 | Train accuracy: 88.68%
Test loss: 0.28596 | Test accuracy: 88.75%



 64%|██████▍   | 64/100 [00:33<00:19,  1.87it/s]

Train loss: 0.27294 | Train accuracy: 89.31%
Test loss: 0.28264 | Test accuracy: 88.68%



 65%|██████▌   | 65/100 [00:33<00:18,  1.88it/s]

Train loss: 0.26360 | Train accuracy: 89.33%
Test loss: 0.26412 | Test accuracy: 90.17%



 66%|██████▌   | 66/100 [00:34<00:17,  1.94it/s]

Train loss: 0.26035 | Train accuracy: 89.44%
Test loss: 0.26436 | Test accuracy: 90.24%



 67%|██████▋   | 67/100 [00:34<00:17,  1.92it/s]

Train loss: 0.26475 | Train accuracy: 89.21%
Test loss: 0.27027 | Test accuracy: 89.72%



 68%|██████▊   | 68/100 [00:35<00:16,  1.92it/s]

Train loss: 0.25999 | Train accuracy: 89.19%
Test loss: 0.26555 | Test accuracy: 90.17%



 69%|██████▉   | 69/100 [00:35<00:16,  1.91it/s]

Train loss: 0.26857 | Train accuracy: 88.86%
Test loss: 0.26577 | Test accuracy: 89.87%



 70%|███████   | 70/100 [00:36<00:15,  1.96it/s]

Train loss: 0.26045 | Train accuracy: 89.47%
Test loss: 0.26726 | Test accuracy: 89.65%



 71%|███████   | 71/100 [00:36<00:15,  1.83it/s]

Train loss: 0.26724 | Train accuracy: 89.11%
Test loss: 0.30283 | Test accuracy: 89.05%



 72%|███████▏  | 72/100 [00:37<00:14,  1.87it/s]

Train loss: 0.26502 | Train accuracy: 89.56%
Test loss: 0.27504 | Test accuracy: 90.09%



 73%|███████▎  | 73/100 [00:38<00:14,  1.87it/s]

Train loss: 0.26136 | Train accuracy: 89.22%
Test loss: 0.25707 | Test accuracy: 89.20%



 74%|███████▍  | 74/100 [00:38<00:13,  1.90it/s]

Train loss: 0.25800 | Train accuracy: 89.34%
Test loss: 0.26317 | Test accuracy: 90.24%



 75%|███████▌  | 75/100 [00:39<00:13,  1.91it/s]

Train loss: 0.25971 | Train accuracy: 89.40%
Test loss: 0.26921 | Test accuracy: 89.50%



 76%|███████▌  | 76/100 [00:39<00:12,  1.93it/s]

Train loss: 0.26427 | Train accuracy: 89.22%
Test loss: 0.27239 | Test accuracy: 90.39%



 77%|███████▋  | 77/100 [00:40<00:11,  1.98it/s]

Train loss: 0.26197 | Train accuracy: 89.50%
Test loss: 0.26724 | Test accuracy: 89.87%



 78%|███████▊  | 78/100 [00:40<00:10,  2.06it/s]

Train loss: 0.26272 | Train accuracy: 89.13%
Test loss: 0.26822 | Test accuracy: 88.98%



 79%|███████▉  | 79/100 [00:40<00:09,  2.14it/s]

Train loss: 0.26116 | Train accuracy: 89.23%
Test loss: 0.26574 | Test accuracy: 89.72%



 80%|████████  | 80/100 [00:41<00:09,  2.15it/s]

Train loss: 0.25870 | Train accuracy: 89.47%
Test loss: 0.26045 | Test accuracy: 90.09%



 81%|████████  | 81/100 [00:41<00:08,  2.12it/s]

Train loss: 0.26068 | Train accuracy: 89.41%
Test loss: 0.26254 | Test accuracy: 89.13%



 82%|████████▏ | 82/100 [00:42<00:08,  2.17it/s]

Train loss: 0.25773 | Train accuracy: 89.31%
Test loss: 0.26747 | Test accuracy: 89.87%



 83%|████████▎ | 83/100 [00:42<00:08,  2.11it/s]

Train loss: 0.26009 | Train accuracy: 89.31%
Test loss: 0.26785 | Test accuracy: 89.80%



 84%|████████▍ | 84/100 [00:43<00:07,  2.10it/s]

Train loss: 0.25792 | Train accuracy: 89.07%
Test loss: 0.26587 | Test accuracy: 90.09%



 85%|████████▌ | 85/100 [00:43<00:06,  2.17it/s]

Train loss: 0.25971 | Train accuracy: 89.19%
Test loss: 0.28663 | Test accuracy: 90.01%



 86%|████████▌ | 86/100 [00:44<00:06,  2.15it/s]

Train loss: 0.26520 | Train accuracy: 89.06%
Test loss: 0.28156 | Test accuracy: 89.28%



 87%|████████▋ | 87/100 [00:44<00:05,  2.20it/s]

Train loss: 0.26490 | Train accuracy: 89.29%
Test loss: 0.25390 | Test accuracy: 90.39%



 88%|████████▊ | 88/100 [00:45<00:05,  2.25it/s]

Train loss: 0.26048 | Train accuracy: 89.15%
Test loss: 0.27182 | Test accuracy: 89.80%



 89%|████████▉ | 89/100 [00:45<00:05,  2.03it/s]

Train loss: 0.25951 | Train accuracy: 89.41%
Test loss: 0.25865 | Test accuracy: 90.17%



 90%|█████████ | 90/100 [00:46<00:04,  2.03it/s]

Train loss: 0.25731 | Train accuracy: 89.49%
Test loss: 0.26551 | Test accuracy: 89.95%



 91%|█████████ | 91/100 [00:46<00:04,  2.08it/s]

Train loss: 0.26023 | Train accuracy: 89.07%
Test loss: 0.27571 | Test accuracy: 89.13%



 92%|█████████▏| 92/100 [00:47<00:03,  2.13it/s]

Train loss: 0.25869 | Train accuracy: 89.24%
Test loss: 0.25922 | Test accuracy: 90.24%



 93%|█████████▎| 93/100 [00:47<00:03,  2.18it/s]

Train loss: 0.26001 | Train accuracy: 89.14%
Test loss: 0.26961 | Test accuracy: 89.35%



 94%|█████████▍| 94/100 [00:47<00:02,  2.21it/s]

Train loss: 0.25758 | Train accuracy: 89.43%
Test loss: 0.27950 | Test accuracy: 89.35%



 95%|█████████▌| 95/100 [00:48<00:02,  2.16it/s]

Train loss: 0.25329 | Train accuracy: 89.74%
Test loss: 0.26774 | Test accuracy: 90.17%



 96%|█████████▌| 96/100 [00:48<00:01,  2.12it/s]

Train loss: 0.25557 | Train accuracy: 89.39%
Test loss: 0.26281 | Test accuracy: 90.09%



 97%|█████████▋| 97/100 [00:49<00:01,  1.91it/s]

Train loss: 0.26000 | Train accuracy: 89.19%
Test loss: 0.26906 | Test accuracy: 89.87%



 98%|█████████▊| 98/100 [00:49<00:01,  1.97it/s]

Train loss: 0.25769 | Train accuracy: 89.32%
Test loss: 0.26985 | Test accuracy: 90.09%



 99%|█████████▉| 99/100 [00:50<00:00,  2.04it/s]

Train loss: 0.26069 | Train accuracy: 89.07%
Test loss: 0.27434 | Test accuracy: 89.71%



100%|██████████| 100/100 [00:51<00:00,  1.96it/s]

Train loss: 0.25498 | Train accuracy: 89.47%
Test loss: 0.29478 | Test accuracy: 88.89%






Train model and evaluate ✅

We did it!

Let's pause here...

Just kidding! Job not finished 😤

There are two key issues with the above (alongside many other things that I am probably overlooking):
1. VERY unbalanced dataset of targets: take a look at the number of 0s vs. the number of 1s. This inflates the model's accuracy, since a model that always predicts the majority class (0) would already score high.
2. Accuracy may not be the best performance metric here. [MoleculeNet](https://moleculenet.org/datasets-1) recommends ROC-AUC (Area Under the Receiver Operating Characteristic Curve) as a classification metric.
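Both points can be checked with a little arithmetic. Using the NR-AhR label counts from earlier (5781 zeros, 768 ones), always predicting 0 would be right 88.27% of the time, in the same neighborhood as the 88.61% test accuracy the model sat at for its first several epochs. ROC-AUC instead measures how often a randomly chosen positive is scored above a randomly chosen negative, so a constant score earns 0.5 however skewed the classes are. A plain-Python sketch of that pairwise formulation; `roc_auc` is a hypothetical helper (in practice something like scikit-learn's `roc_auc_score` would do this), and the toy scores are made up:

```python
def roc_auc(labels, scores):
    """ROC-AUC as the probability a random positive outranks a random negative (ties count 1/2)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Majority-class baseline using the NR-AhR counts printed earlier.
n_neg, n_pos = 5781, 768
baseline_accuracy = n_neg / (n_neg + n_pos)
print(f"{baseline_accuracy:.2%}")  # 88.27%

# Toy data: a constant scorer vs. one that mostly ranks positives higher.
labels = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
constant = [0.1] * len(labels)
ranked = [0.1, 0.2, 0.1, 0.3, 0.2, 0.1, 0.4, 0.2, 0.35, 0.9]
print(roc_auc(labels, constant))  # 0.5
print(roc_auc(labels, ranked))    # 0.9375
```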

Let's integrate these changes. Head to `mutag.ipynb`!