In [7]:
import numpy as np
import os

# matplotlib for plotting
import matplotlib
%matplotlib inline 
import matplotlib.pyplot as plt

# pytorch for neural network stuff
import torch
from torch import nn
from torch.nn import BatchNorm1d, MSELoss # loss function, normalization
import torch.optim as optim # optimizer
import torch.nn.functional as F

# torch-geometric: graph neural network layers, data structures and datasets on top of PyTorch
from torch_geometric.nn import GCNConv,global_mean_pool
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import from_networkx, to_networkx

import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data

# datasets 
from torch_geometric.datasets import TUDataset

# pytorch lightning for automating training for us 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# tqdm for progress bars
from tqdm.notebook import tqdm



# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
# Both branches must build a torch.device: torch.cuda is a module, not a callable,
# so torch.cuda("cpu") would raise a TypeError on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


### Load data and take a look
We load the MUTAG dataset (see [here](https://paperswithcode.com/dataset/mutag) for more info). Each sample is a small molecule represented as a graph (nodes are atoms, edges are bonds), and the task is binary graph classification: predicting whether the compound is mutagenic.

This notebook follows [this](https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial7/GNN_overview.ipynb) tutorial.

In [6]:
dataset = TUDataset(root='~/tmp/MUTAG', name='MUTAG')
data = dataset[0]  # first graph, kept around for interactive inspection

# Graph labels of the whole dataset (y=[188] in the output below), cast to float
# once so min/max/mean can be computed on them.
labels = dataset.data.y.float()

print("Data object:", dataset.data)
print("Length:", len(dataset))
print(f"Min label {labels.min().item()}, Max label: {labels.max().item()}")
print("Average label: %4.2f" % (labels.mean().item()))


Data object: Data(edge_attr=[7442, 4], edge_index=[2, 7442], x=[3371, 7], y=[188])
Length: 188
Min label 0.0, Max label: 1.0
Average label: 0.66


We have 188 samples (graphs). The labels are binary (0 or 1) and the average label is 0.66, i.e. about two thirds of the graphs belong to class 1. We could print more information or plot the graphs (left as an exercise to the reader :) ).

In [8]:
# Seed torch's global RNG so the shuffle below (and therefore the split) is reproducible.
torch.manual_seed(42)

# Number of graphs used for training; the remaining graphs form the test set.
N_TRAIN = 150

# Shuffle into a new name instead of overwriting `dataset`: with the seed reset
# above, re-running this cell reproduces the same split. Reassigning `dataset`
# would shuffle the already-shuffled dataset again and silently change the split.
shuffled_dataset = dataset.shuffle()

train_dataset = shuffled_dataset[:N_TRAIN]
test_dataset = shuffled_dataset[N_TRAIN:]


In [10]:
# Number of graphs per mini-batch for every loader below.
BATCH_SIZE = 64

# torch-geometric DataLoaders collate several graphs into one Batch object.
# Only the training loader reshuffles the graphs every epoch.
graph_train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# NOTE: validation and test sets should never share data. MUTAG is so small that
# no separate validation split is carved out here, so the "validation" loader
# reuses the test split. Checkpoints are selected on val_acc, so the reported
# test accuracy is optimistically biased. With a larger dataset, give
# graph_val_loader its own split.
graph_val_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
graph_test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)


In [12]:
# Grab one batch to see how torch-geometric packs several graphs together.
batch = next(iter(graph_test_loader))

# A torch-geometric Batch holds many graphs in a single object. Which graph a
# node belongs to is recorded in the `batch` attribute (one graph index per
# node), not in edge_index: in the output below, the first nodes all have index 0
# (graph 0), then the following ones have index 1 (graph 1), and so on.
print("Batch:", batch)
print("Labels:", batch.y[:10])
print("Batch indices:", batch.batch[:40])

Batch: Batch(batch=[668], edge_attr=[1466, 4], edge_index=[2, 1466], ptr=[39], x=[668, 7], y=[38])
Labels: tensor([1, 0, 1, 0, 0, 1, 1, 1, 1, 1])
Batch indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


### Let's define our model

In [13]:
# gnn_layer_by_name = {
#     "GCN": geom_nn.GCNConv,
#     "GAT": geom_nn.GATConv,
#     "GraphConv": geom_nn.GraphConv
# }

# Regular (node-level) GNN model, used as the base of the graph-level model below.
# It applies a sequence of graph layers (GraphConv in this notebook), with
# ReLU (rectified linear unit) activations and dropout for regularization in between.
class GNNModel(nn.Module):
    """
    Node-level GNN: a stack of message-passing layers with ReLU and dropout in between.

    Layer layout:
        [gnn_layer -> ReLU -> Dropout] * (num_layers - 1) -> gnn_layer
    The last graph layer has no activation or dropout after it, so its output
    can be fed directly into a pooling layer or a linear head.
    """

    def __init__(self, c_in, c_hidden, c_out, num_layers=2, dp_rate=0.1,
                 gnn_layer=geom_nn.GraphConv, **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Total number of graph layers, including the output layer
                         (values <= 1 give a single graph layer c_in -> c_out)
            dp_rate - Dropout rate applied after every graph layer except the last
            gnn_layer - Graph layer class to stack (default: GraphConv). It must subclass
                        geom_nn.MessagePassing (forward() relies on that to pass edge_index),
                        accept in_channels/out_channels keyword arguments, and produce
                        exactly out_channels features per node.
            kwargs - Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()

        layers = []
        in_channels = c_in
        # Hidden blocks: graph layer -> ReLU -> Dropout, all with width c_hidden.
        for _ in range(num_layers - 1):
            layers += [
                gnn_layer(in_channels=in_channels,
                          out_channels=c_hidden,
                          **kwargs),
                nn.ReLU(inplace=True),
                nn.Dropout(dp_rate)
            ]
            in_channels = c_hidden
        # Output graph layer, mapping to c_out features per node.
        layers += [gnn_layer(in_channels=in_channels,
                             out_channels=c_out,
                             **kwargs)]
        # ModuleList (rather than a plain list) so the layers' parameters are registered.
        self.layers = nn.ModuleList(layers)

    def forward(self, x, edge_index):
        """
        Forward pass: run every layer in order.

        Inputs:
            x - Input features per node (presumably shape [num_nodes, c_in] — TODO confirm)
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
        Returns:
            Node features after the last graph layer (c_out features per node).
        """
        for layer in self.layers:
            # Graph layers (all subclasses of MessagePassing) also need the graph
            # connectivity; ReLU and Dropout only take the node features.
            if isinstance(layer, geom_nn.MessagePassing):
                x = layer(x, edge_index)
            else:
                x = layer(x)
        return x
    

    
# For graph-level prediction we need a pooling step: the node-level GNN above
# produces one embedding per node, and batch_idx tells us which graph each node
# belongs to. Global mean pooling averages the node features of each graph into
# a single vector per graph, which a dropout + linear head then maps to the output.
class GraphGNNModel(nn.Module):
    """
    Graph-level model: node-level GNNModel -> per-graph mean pooling -> dropout + linear head.
    """

    def __init__(self, c_in, c_hidden, c_out, dp_rate_linear=0.5, **kwargs):
        """
        Inputs:
            c_in - Dimension of input node features
            c_hidden - Width of the GNN layers and of the pooled graph embedding
            c_out - Dimension of output features (usually number of classes)
            dp_rate_linear - Dropout rate before the linear layer (usually much higher than inside the GNN)
            kwargs - Additional arguments forwarded to GNNModel (e.g. num_layers, dp_rate)
        """
        super().__init__()
        # The backbone stays in the hidden dimension; the head below produces
        # the actual c_out prediction from the pooled graph embedding.
        self.GNN = GNNModel(c_in=c_in,
                            c_hidden=c_hidden,
                            c_out=c_hidden,
                            **kwargs)
        dropout = nn.Dropout(dp_rate_linear)
        classifier = nn.Linear(c_hidden, c_out)
        self.head = nn.Sequential(dropout, classifier)

    def forward(self, x, edge_index, batch_idx):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
            batch_idx - Index of the graph (batch element) each node belongs to
        Returns:
            One row of c_out outputs per graph in the batch.
        """
        node_embeddings = self.GNN(x, edge_index)
        # Average the node embeddings of each graph into one vector per graph.
        graph_embeddings = geom_nn.global_mean_pool(node_embeddings, batch_idx)
        return self.head(graph_embeddings)

### Let's use a PyTorch Lightning module to handle the training for us

From the docs: "PyTorch Lightning is just organized PyTorch".
It takes care of a lot of boilerplate for us (training loop, device placement, logging, checkpointing).

For more info on how to use Lightning modules, see the [PyTorch Lightning GitHub repository](https://github.com/PyTorchLightning/pytorch-lightning).

In [14]:
class GraphLevelGNN(pl.LightningModule):
    """
    LightningModule wrapping GraphGNNModel for graph classification.

    With c_out == 1 the model outputs one logit per graph and is trained with
    BCEWithLogitsLoss; otherwise it outputs one logit per class and is trained
    with CrossEntropyLoss.
    """

    def __init__(self, **model_kwargs):
        """
        Inputs:
            model_kwargs - Keyword arguments for GraphGNNModel; must include c_out
                           (used to pick the loss function)
        """
        super().__init__()
        # Store model_kwargs in self.hparams (and in checkpoints), which is what
        # lets load_from_checkpoint rebuild the model without extra arguments.
        self.save_hyperparameters()

        self.model = GraphGNNModel(**model_kwargs)
        self.loss_module = nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()

    # forward in a Lightning module is not the same as in GNNModel:
    # here it computes the loss and accuracy for one batch.
    def forward(self, data, mode="train"):
        """
        Inputs:
            data - torch-geometric Batch with x, edge_index, batch and y attributes
            mode - Name of the current stage ("train"/"val"/"test"); currently unused
        Returns:
            (loss, acc) for the batch; acc is the fraction of correctly classified graphs.
        """
        x, edge_index, batch_idx = data.x, data.edge_index, data.batch
        logits = self.model(x, edge_index, batch_idx)
        # Drop the trailing size-1 dimension in the single-logit (binary) case.
        logits = logits.squeeze(dim=-1)

        # Use local targets rather than overwriting data.y, so the caller's batch
        # is left untouched.
        if self.hparams.c_out == 1:
            # A logit > 0 corresponds to a predicted probability > 0.5.
            preds = (logits > 0).float()
            targets = data.y.float()  # BCEWithLogitsLoss expects float targets
        else:
            preds = logits.argmax(dim=-1)
            targets = data.y
        loss = self.loss_module(logits, targets)
        acc = (preds == targets).sum().float() / preds.shape[0]
        return loss, acc

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=1e-2, weight_decay=0.0)  # High lr because of small dataset and small model
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc = self.forward(batch, mode="train")
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="val")
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="test")
        self.log('test_acc', acc)

In [15]:
# Train a graph-level classifier (or load a pretrained one) and report train/test accuracy.
def train_graph_classifier(model_name, **model_kwargs):
    """
    Train (or load) a GraphLevelGNN and evaluate it on the train and test loaders.

    Reads these notebook globals: CHECKPOINT_PATH, device, dataset,
    graph_train_loader, graph_val_loader and graph_test_loader.

    Inputs:
        model_name - Name suffix for the log directory ("GraphLevel" + model_name)
                     and for the optional pretrained checkpoint file
        model_kwargs - Extra keyword arguments forwarded to GraphLevelGNN
                       (c_in and c_out are derived from `dataset`)
    Returns:
        (model, result) where result = {"test": test accuracy, "train": train accuracy}
    """
    pl.seed_everything(42)

    # Lightning trainer with a checkpoint callback that keeps the weights with
    # the highest val_acc.
    # NOTE: in this notebook graph_val_loader iterates over the test split, so this
    # model selection looks at test data and the reported test accuracy is optimistic.
    root_dir = os.path.join(CHECKPOINT_PATH, "GraphLevel" + model_name)
    os.makedirs(root_dir, exist_ok=True)

    trainer = pl.Trainer(default_root_dir=root_dir,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
                         gpus=1 if str(device).startswith("cuda") else 0,
                         max_epochs=500,
                         progress_bar_refresh_rate=0)

    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether a pretrained model exists. If yes, load it and skip training.
    # NOTE(review): nothing in this function writes to this path (checkpoints from
    # trainer.fit presumably land under root_dir), so the file has to be provided
    # separately — confirm.
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "GraphLevel%s.ckpt" % model_name)
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = GraphLevelGNN.load_from_checkpoint(pretrained_filename)
    else:
        # Re-seed right before building the model so weight initialisation is reproducible.
        pl.seed_everything(42)
        # Binary tasks use a single logit (c_out=1, BCE loss); otherwise one logit per class.
        model = GraphLevelGNN(c_in=dataset.num_node_features,
                              c_out=1 if dataset.num_classes == 2 else dataset.num_classes,
                              **model_kwargs)
        trainer.fit(model, graph_train_loader, graph_val_loader)
        # Reload the best checkpoint (highest val_acc) instead of keeping the last-epoch weights.
        model = GraphLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Evaluate the selected model on the training set and on the test set.
    # trainer.test logs the metric as 'test_acc' in both cases (see GraphLevelGNN.test_step).
    train_result = trainer.test(model, test_dataloaders=graph_train_loader, verbose=False)
    test_result = trainer.test(model, test_dataloaders=graph_test_loader, verbose=False)
    result = {"test": test_result[0]['test_acc'], "train": train_result[0]['test_acc']}
    return model, result

In [16]:
# Directory under which train_graph_classifier keeps logs and checkpoints
# (it reads CHECKPOINT_PATH as a global).
CHECKPOINT_PATH = './checkpoints'

# Three GraphConv layers of width 256, no dropout inside the GNN and
# 50% dropout in front of the linear head.
model, result = train_graph_classifier(
    model_name="GraphConv",
    c_hidden=256,
    num_layers=3,
    dp_rate_linear=0.5,
    dp_rate=0.0,
)

Global seed set to 42
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Global seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

  | Name        | Type              | Params
--------------------------------------------------
0 | model       | GraphGNNModel     | 266 K 
1 | loss_module | BCEWithLogitsLoss | 0     
--------------------------------------------------
266 K     Trainable params
0         Non-trainable params
266 K     Total params
1.067     Total estimated model params size (MB)
Global seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [17]:
# `result` holds accuracies in [0, 1]; report them as percentages.
train_pct = 100.0 * result['train']
test_pct = 100.0 * result['test']
print(f"Train performance: {train_pct:4.2f}%")
print(f"Test performance:  {test_pct:4.2f}%")

Train performance: 82.10%
Test performance:  86.84%
