In [2]:
import sys
sys.path.append('/notebooks')

import os
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import matplotlib.pyplot as plt

import numpy as np
import scipy.sparse as sp
import pickle
import pandas as pd
import torch_scatter
from collections import Counter


from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

from src.basicGNN import basicGNN

##  MUTAG DATASET
**Dataset Summary**

The MUTAG dataset is a collection of nitroaromatic compounds; the goal is to predict their mutagenicity on *Salmonella typhimurium*.


**Supported Tasks and Leaderboards**

MUTAG is used for molecular property prediction: a binary classification task that predicts whether or not a molecule has a mutagenic effect on a given bacterium. The standard metric is accuracy under 10-fold cross-validation.

In [3]:
# Root folder handed to TUDataset for the MUTAG files.
dataset_directory = os.path.join('/notebooks', 'data')
dataset = TUDataset(name='MUTAG', root=dataset_directory)

In [4]:
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# Use a dedicated name for the sample graph: assigning to `data` here would
# overwrite the `torch.utils.data` module imported as `data` in the imports cell.
first_graph = dataset[0]  # Get the first graph object.

print()
print(first_graph)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {first_graph.num_nodes}')
print(f'Number of edges: {first_graph.num_edges}')
# edge_index stores an undirected graph with both directions of every edge,
# so (directed edge count) / (node count) equals the mean node degree.
print(f'Average node degree: {first_graph.num_edges / first_graph.num_nodes:.2f}')
print(f'Has isolated nodes: {first_graph.has_isolated_nodes()}')
print(f'Has self-loops: {first_graph.has_self_loops()}')
print(f'Is undirected: {first_graph.is_undirected()}')


Dataset: MUTAG(188):
Number of graphs: 188
Number of features: 7
Number of classes: 2

Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
Number of nodes: 17
Number of edges: 38
Average node degree: 2.24
Has isolated nodes: False
Has self-loops: False
Is undirected: True


In [5]:
# Split configuration: fixed seed so the shuffle (and hence the split) is
# reproducible, and the fraction of graphs used for training.
SPLIT_SEED = 12345
TRAIN_FRACTION = 0.8

torch.manual_seed(SPLIT_SEED)
# NOTE: `dataset` is rebound to its shuffled copy, so re-running this cell
# shuffles an already-shuffled dataset and will typically yield a different
# split. Re-run from the loading cell to reproduce the original split.
dataset = dataset.shuffle()
num_train = int(len(dataset) * TRAIN_FRACTION)
train_dataset = dataset[:num_train]
test_dataset = dataset[num_train:]
assert len(train_dataset) + len(test_dataset) == len(dataset), 'split lost graphs'

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

Number of training graphs: 150
Number of test graphs: 38


In [6]:
# Mini-batches of 64 graphs. Only the training split is reshuffled each epoch;
# the test split keeps a fixed iteration order.
train_loader, test_loader = (
    DataLoader(split, batch_size=64, shuffle=is_train)
    for split, is_train in ((train_dataset, True), (test_dataset, False))
)

In [7]:
# Preview of the layer stack: three 64-unit hidden layers mapping node
# features to class scores. The training cell below builds its own model.
model = basicGNN(
    in_channels=dataset.num_node_features,
    hidden_channels=[64, 64, 64],
    out_channels=dataset.num_classes,
)
print(model)


basicGNN(
  (conv_layers): ModuleList(
    (0): GraphConv(7, 64)
    (1): GraphConv(64, 64)
    (2): GraphConv(64, 64)
  )
  (final_layer): GraphConv(64, 2)
)


In [9]:
# Model used for training: the same hidden stack with the `mlp` and `pooling`
# options switched on (their exact behaviour is defined in src.basicGNN).
model = basicGNN(
    in_channels=dataset.num_node_features,
    hidden_channels=[64, 64, 64],
    out_channels=dataset.num_classes,
    mlp=True,
    pooling=True,
)

# Multi-class cross-entropy on the model's raw class scores, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

def train(model, loader=None, opt=None, loss_fn=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network called as ``model(x, edge_index, batch)``; its output
            is expected to be class scores aligned with ``batch.y``.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``. Defaults to the global ``train_loader``.
        opt: optimizer over ``model``'s parameters. Defaults to the global
            ``optimizer``. That default only updates ``model`` if it was
            built from this same model's parameters.
        loss_fn: callable ``(scores, targets) -> scalar loss tensor``.
            Defaults to the global ``criterion``.

    Returns:
        The average of the per-batch loss values as a Python float, or
        ``nan`` if ``loader`` yields no batches.
    """
    # Fall back to the notebook globals so existing `train(model)` calls work.
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    model.train()
    losses = []
    for batch in loader:
        # Clear gradients before the backward pass so no stale gradients
        # (e.g. from an earlier cell) leak into this update.
        opt.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = loss_fn(out, batch.y)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return sum(losses) / len(losses) if losses else float('nan')
        
def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        out = model(data.x, data.edge_index, data.batch)  
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


# 30 epochs: one optimisation pass, then fresh accuracy on both splits.
for epoch in range(1, 31):
    train(model)
    train_acc, test_acc = test(train_loader), test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 005, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 006, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 007, Train Acc: 0.6533, Test Acc: 0.7368
Epoch: 008, Train Acc: 0.6533, Test Acc: 0.7368
Epoch: 009, Train Acc: 0.7200, Test Acc: 0.8158
Epoch: 010, Train Acc: 0.7333, Test Acc: 0.8158
Epoch: 011, Train Acc: 0.7467, Test Acc: 0.8158
Epoch: 012, Train Acc: 0.7533, Test Acc: 0.8158
Epoch: 013, Train Acc: 0.7467, Test Acc: 0.8421
Epoch: 014, Train Acc: 0.7867, Test Acc: 0.8421
Epoch: 015, Train Acc: 0.7533, Test Acc: 0.7895
Epoch: 016, Train Acc: 0.8133, Test Acc: 0.8421
Epoch: 017, Train Acc: 0.8000, Test Acc: 0.8158
Epoch: 018, Train Acc: 0.7933, Test Acc: 0.8158
Epoch: 019, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 020, Train Acc: 0.7867, Test Acc: 0.8158
Epoch: 021, Train Acc: 0.7933, Test Acc: