<a href="https://colab.research.google.com/github/alessiodevoto/gnns_xai_liverpool/blob/main/notebooks/A_Primer_on_Graph_Neural_Networks_(Liverpool).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A Primer on Graph Neural Networks

**Author**: [Alessio Devoto](https://alessiodevoto.github.io/)

This is an introductory tutorial to [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/), a library for designing deep Graph Neural Networks (GNNs). The first part adapts material from the PyG documentation and trains a GNN for graph classification on the [MUTAG](https://paperswithcode.com/dataset/mutag) dataset (using [PyTorch Lightning](https://www.pytorchlightning.ai/) for the training loop). The notebook is inspired by [Simone Scardapane](https://sscardapane.it/)'s material on GNNs.




## 1. 🚗 Set up the Colab environment


In [None]:
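# We use a CPU-based installation of PyTorch Geometric
# More info here: https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
# The wheel index below targets torch 2.0.0 (CPU): adjust the version in the URL to match the installed torch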
!pip install torch_geometric
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cpu.html
!pip install pytorch-lightning --quiet

In [None]:
# PyTorch imports
import torch
from torch.nn import functional as F

In [None]:
# PyTorch Geometric imports
import torch_geometric as ptgeom
import torch_scatter, torch_sparse

In [None]:
import pytorch_lightning as ptlight
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics.functional import accuracy

In [None]:
# Other imports
import numpy as np
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.model_selection import train_test_split

In [None]:
matplotlib.rcParams['figure.dpi'] = 120 # I like higher resolution plots :)

## 2. 💾 Data

### 2.1 Download & Explore Dataset

PyTorch Geometric provides a number of off-the-shelf datasets for all kinds of graph tasks (graph-, node-, or edge-level). Find a complete list [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/data_cheatsheet.html).

In this tutorial, we will use the **MUTAG** dataset, a small collection of 188 molecular graphs. See the MUTAG page on [Papers With Code](https://paperswithcode.com/dataset/mutag) and related papers for more information about the dataset. It is a toy dataset, so we will not worry too much about the final performance.

In [None]:
# Download the dataset


In [None]:
# Useful info stored in the dataset class, e.g. num_node_features

In [None]:
# Each graph in the dataset is represented as an instance of the generic Data object:
# https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data


What does each of these fields mean?

![](https://raw.githubusercontent.com/alessiodevoto/gnns_xai_liverpool/main/images/simple_graph_labels.png)

In [None]:
# node features

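In MUTAG, each row of `x` is a one-hot vector over 7 atom types. A minimal sketch of turning those rows back into atom symbols with `argmax`, assuming the index order given in the TUDataset MUTAG README (0 C, 1 N, 2 O, 3 F, 4 I, 5 Cl, 6 Br); `x_toy` and `decode_atoms` are hypothetical names used only for illustration:

```python
import torch

# Atom order assumed from the TUDataset MUTAG README (index -> symbol)
MUTAG_ATOMS = ['C', 'N', 'O', 'F', 'I', 'Cl', 'Br']

# A made-up feature matrix for a 3-node "molecule": one row per node,
# one column per atom type, with a single 1.0 per row (one-hot encoding)
x_toy = torch.tensor([
    [1., 0., 0., 0., 0., 0., 0.],  # C
    [0., 1., 0., 0., 0., 0., 0.],  # N
    [0., 0., 1., 0., 0., 0., 0.],  # O
])

def decode_atoms(x: torch.Tensor) -> list:
    """Map each one-hot row to its atom symbol via the index of the hot entry."""
    return [MUTAG_ATOMS[i] for i in x.argmax(dim=1).tolist()]

print(decode_atoms(x_toy))  # ['C', 'N', 'O']
```

On the real data, `decode_atoms(mutag[0].x)` should list the atoms of the first molecule.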

In [None]:
# graph class (remember we only have two classes, as this is a binary classification problem)


In [None]:
# the adjacency matrix stored in COO format


In [None]:
# let's take a look at the first 4 edges

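In COO format, `edge_index` is a `[2, num_edges]` integer tensor: column `k` stores the source and target node of edge `k`, and each undirected bond appears twice, once per direction. A plain-PyTorch sketch on a made-up 3-node path graph (not a MUTAG molecule) that rebuilds the dense adjacency matrix; for a single graph this should match `torch_geometric.utils.to_dense_adj`, up to the extra leading batch dimension that function returns:

```python
import torch

# Path graph 0 - 1 - 2, each undirected edge stored in both directions
edge_index = torch.tensor([
    [0, 1, 1, 2],  # source nodes
    [1, 0, 2, 1],  # target nodes
])
num_nodes = 3

def coo_to_dense(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Build a dense adjacency matrix with A[src, dst] = 1 for every edge."""
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    return adj

adj = coo_to_dense(edge_index, num_nodes)
print(adj)
# First 4 edges as (source, target) pairs, like edge_index[:, :4] on a MUTAG graph
print(edge_index[:, :4].t().tolist())  # [[0, 1], [1, 0], [1, 2], [2, 1]]
```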

In [None]:
# Inside utils (https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html)
# there are a number of useful tools, e.g., we can check that the graph is undirected (the adjacency matrix is symmetric)


In [None]:
# Are there self-loops in this graph?


In [None]:
# Are there any isolated nodes?

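The three checks above have simple definitions. A plain-PyTorch sketch of them on a made-up toy graph; these helpers are illustrative stand-ins for `torch_geometric.utils.is_undirected`, `contains_self_loops` and `contains_isolated_nodes`, and may differ from PyG's versions on edge cases (for instance, nodes whose only edge is a self-loop):

```python
import torch

def is_undirected(edge_index: torch.Tensor) -> bool:
    """True if every edge (u, v) also appears as (v, u)."""
    edges = set(map(tuple, edge_index.t().tolist()))
    return all((v, u) in edges for u, v in edges)

def has_self_loops(edge_index: torch.Tensor) -> bool:
    """True if some edge connects a node to itself."""
    return bool((edge_index[0] == edge_index[1]).any())

def has_isolated_nodes(edge_index: torch.Tensor, num_nodes: int) -> bool:
    """True if some node never appears as the source or target of an edge."""
    connected = torch.zeros(num_nodes, dtype=torch.bool)
    connected[edge_index.flatten()] = True
    return not bool(connected.all())

# Toy graph: undirected path 0 - 1 - 2, plus a fourth node (3) with no edges
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
print(is_undirected(edge_index))          # True
print(has_self_loops(edge_index))         # False
print(has_isolated_nodes(edge_index, 4))  # True
```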

### 2.2 Data Visualization

As always, it is crucial to visualize (when possible) the data structures we are dealing with.

For graphs, this can be impractical when the number of nodes is very high. Luckily, all our molecules are quite small.

Let's define a simple function to plot a graph using `matplotlib` and the `networkx` package.

In [None]:
# Adapted from: https://colab.research.google.com/drive/1fLJbFPz0yMCQg81DdCP5I8jXw9LoggKO?usp=sharing
import networkx as nx
from torch_geometric.utils import to_networkx

# Atom types encoded by MUTAG's one-hot node features, in index order
# (TUDataset MUTAG README: 0 C, 1 N, 2 O, 3 F, 4 I, 5 Cl, 6 Br)
ATOM_MAP = ['C', 'N', 'O', 'F', 'I', 'Cl', 'Br']

# transform the PyTorch Geometric graph into networkx format
def to_molecule(data: ptgeom.data.Data) -> nx.DiGraph:
    g = to_networkx(data, node_attrs=['x'])
    for _, attrs in g.nodes(data=True):
        # each node feature is a one-hot vector: the hot index is the atom type
        attrs['name'] = ATOM_MAP[attrs['x'].index(1.0)]
        del attrs['x']
    return g

# plot the molecule
def draw_molecule(g, edge_mask=None, draw_edge_labels=True, draw_node_labels=True, ax=None, figsize=None):
    # only open a new figure when we are not drawing on a given axis
    if ax is None:
        plt.figure(figsize=figsize or (4, 3))

    # convert to a networkx graph if it has not been converted already
    if not isinstance(g, nx.Graph):
        g = to_molecule(g)

    g = g.copy().to_undirected()
    node_labels = {u: attrs['name'] for u, attrs in g.nodes(data=True)}
    pos = nx.planar_layout(g)
    pos = nx.spring_layout(g, pos=pos)
    if edge_mask is None:
        edge_color = 'black'
        widths = 1.0
    else:
        edge_color = [edge_mask[(u, v)] for u, v in g.edges()]
        widths = [x * 10 for x in edge_color]
    # with_labels is set explicitly: passing labels=None alone would still draw node ids
    nx.draw(g, pos=pos, ax=ax, with_labels=draw_node_labels, labels=node_labels,
            width=widths, edge_color=edge_color, edge_cmap=plt.cm.Blues,
            node_color='azure')

    if draw_edge_labels and edge_mask is not None:
        edge_labels = {k: ('%.2f' % v) for k, v in edge_mask.items()}
        nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels,
                                     font_color='red', ax=ax)

    if ax is None:
        plt.show()


In [None]:
# plot

In [None]:
# @title Visualize some graphs { run: "auto" }
mutag_idx = 6 # @param {type:"slider", min:0, max:187, step:1}

draw_molecule(mutag[mutag_idx])


### 2.3: Transformations

Transformations are a quick way to include standard preprocessing when loading the graphs (e.g., automatically computing edges from node positions). They work much like torchvision's transforms.

See the full list of available transformations here:

https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html

In [None]:
# As an experiment, we load the dataset with a sparse adjacency matrix instead of the COO edge list

In [None]:
# The format has a number of useful methods that are already implemented: https://github.com/rusty1s/pytorch_sparse
# For example, we can perform a single step of diffusion on the node features efficiently with a sparse-dense matrix multiplication

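One diffusion step replaces each node's features with the sum of its neighbours' features, i.e. the product $AX$. A plain-PyTorch sketch of the same idea with `torch.sparse_coo_tensor` and `torch.sparse.mm` (instead of the `torch_sparse` format used above), on a made-up 3-node path graph:

```python
import torch

# Path graph 0 - 1 - 2 in COO format (both directions stored)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
num_nodes = 3
x = torch.tensor([[1.0], [10.0], [100.0]])  # one scalar feature per node

# Sparse adjacency with A[i, j] = 1 for every edge (i, j); the graph is
# undirected, so A is symmetric and row i of A @ x sums over i's neighbours
values = torch.ones(edge_index.size(1))
adj = torch.sparse_coo_tensor(edge_index, values, (num_nodes, num_nodes))

# One diffusion step as a sparse-dense matrix multiplication
x_diffused = torch.sparse.mm(adj, x)
print(x_diffused)  # node 0 -> 10, node 1 -> 1 + 100 = 101, node 2 -> 10
```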
🔥 **Warmup Exercise no. 1**

Imagine you are a (probably crazy) chemist and you want to *add self loops* to all of the molecules in your dataset.

What would you do? Plot a graph after adding self loops. Hint: [this transform](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.AddSelfLoops.html#torch_geometric.transforms.AddSelfLoops)

In [None]:
# We load the graph and add self loops to all nodes (probably doesn't make sense from a chemical point of view 🤔)

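Conceptually, adding self-loops means appending one edge `(i, i)` per node to `edge_index`. A plain-PyTorch sketch of that step; the `add_self_loops` helper here is a hypothetical simplification, which (unlike PyG's `AddSelfLoops` transform) ignores edge attributes and does not check for loops that already exist:

```python
import torch

def add_self_loops(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Append a self-loop (i, i) for every node i to a COO edge_index."""
    loops = torch.arange(num_nodes).repeat(2, 1)  # shape [2, num_nodes]
    return torch.cat([edge_index, loops], dim=1)

# Path graph 0 - 1 - 2
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
with_loops = add_self_loops(edge_index, num_nodes=3)
print(with_loops)
# tensor([[0, 1, 1, 2, 0, 1, 2],
#         [1, 0, 2, 1, 0, 1, 2]])
```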

### 2.4 Data loading

Data loaders are a nice utility to automatically build mini-batches (either a subset of graphs, or a subgraph extracted from a single graph) from the dataset.


PyTorch Geometric builds a mini-batch by [stacking the adjacency matrices](https://pytorch-geometric.readthedocs.io/en/latest/advanced/batching.html) block-diagonally, so the batch becomes a single large graph whose disconnected components are the original graphs.

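A minimal plain-PyTorch sketch of this collation, assuming each graph is an `(x, edge_index)` pair: node features are concatenated, each graph's `edge_index` is shifted by the number of nodes that precede it (which is what makes the adjacency block-diagonal), and a `batch` vector records which graph every node came from. `collate_graphs` is a hypothetical helper, not PyG's batching code, which handles many more attributes:

```python
import torch

def collate_graphs(graphs):
    """Merge (x, edge_index) pairs into one disconnected graph plus a batch vector."""
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        num_nodes = x.size(0)
        xs.append(x)
        # shift node ids so they point at this graph's rows in the merged x
        edge_indices.append(edge_index + offset)
        batch.append(torch.full((num_nodes,), graph_id, dtype=torch.long))
        offset += num_nodes
    return torch.cat(xs), torch.cat(edge_indices, dim=1), torch.cat(batch)

# Two toy graphs: a single 2-node edge and a 3-node path
g1 = (torch.ones(2, 4), torch.tensor([[0, 1], [1, 0]]))
g2 = (torch.zeros(3, 4), torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]))

x, edge_index, batch = collate_graphs([g1, g2])
print(x.shape)     # torch.Size([5, 4])
print(edge_index)  # the second graph's nodes are now numbered 2, 3, 4
print(batch)       # tensor([0, 0, 1, 1, 1])
```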


In [None]:
# Plain MUTAG without self loops
mutag = ptgeom.datasets.TUDataset(root='.', name='MUTAG')

In [None]:
# First, we split the dataset into training and test sets, stratifying on the class so both keep the same class proportions
train_idx, test_idx = train_test_split(range(len(mutag)), stratify=[m.y[0].item() for m in mutag], test_size=0.25, random_state=11)

train_mutag = mutag[train_idx]
test_mutag = mutag[test_idx]

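Stratifying on the label keeps the class proportions of the full dataset in both splits. A small check of this with the same `train_test_split` call as above, on synthetic labels (a made-up 60/40 class split, not MUTAG's actual class counts):

```python
from collections import Counter
from sklearn.model_selection import train_test_split

# Synthetic labels: 60 samples of class 0, 40 samples of class 1
labels = [0] * 60 + [1] * 40

train_idx, test_idx = train_test_split(
    range(len(labels)), stratify=labels, test_size=0.25, random_state=11)

def class_ratio(idx):
    """Fraction of samples with label 1 among the given indices."""
    counts = Counter(labels[i] for i in idx)
    return counts[1] / len(idx)

print(len(train_idx), len(test_idx))                  # 75 25
print(class_ratio(train_idx), class_ratio(test_idx))  # 0.4 0.4
```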
In [None]:
# Build the two loaders
train_loader = ptgeom.loader.DataLoader(train_mutag, batch_size=32, shuffle=True)
test_loader = ptgeom.loader.DataLoader(test_mutag, batch_size=32)

In [None]:
# Let us inspect the first batch of data


In [None]:
# The batch is built by treating all the graphs as a single giant graph with disconnected components
# Let us explore some of the components of the batch



🔥 **Warmup Exercise no. 2**

As we said, PyTorch Geometric creates batches by stacking together small graphs into a single large one.

Create a dataloader with `batch_size = 2` and plot the first batch to check its content.

In [None]:
# Don't do this with large batch size

In [None]:
# If we build this new huge graph, how do we keep track of all the small subgraphs 🤔?
# There is an additional property linking each node to its corresponding graph index
# Print the batch

![](https://raw.githubusercontent.com/alessiodevoto/gnns_xai_liverpool/main/images/batching.png)

In [None]:
# We can perform a graph-level average with torch_scatter, see the figure here for a visual explanation:
# https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html


In [None]:
# Alternatively, PyG has this implemented as a functional layer in nn.global_mean_pool

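Both cells above compute, for each graph, the mean of its node embeddings, using the `batch` vector as the group index. A plain-PyTorch sketch of that scatter-mean with `Tensor.index_add_`, on made-up embeddings for two graphs with 2 and 3 nodes (the helper name `scatter_mean` is hypothetical here):

```python
import torch

def scatter_mean(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Average the rows of x that share the same batch index."""
    sums = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
    counts = torch.zeros(num_graphs).index_add_(0, batch, torch.ones(x.size(0)))
    # clamp avoids a division by zero for graphs without nodes
    return sums / counts.clamp(min=1).unsqueeze(1)

x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],    # graph 0: rows 0-1
                  [0.0, 0.0],
                  [3.0, 3.0],
                  [6.0, 9.0]])   # graph 1: rows 2-4
batch = torch.tensor([0, 0, 1, 1, 1])
print(scatter_mean(x, batch, num_graphs=2))
# tensor([[2., 3.],
#         [3., 4.]])
```

The result should match `torch_scatter.scatter(x, batch, dim=0, reduce='mean')` and `ptgeom.nn.global_mean_pool(x, batch)` on the same inputs.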

## 3. 🪄 Design and train the Graph Neural Network

We have explored the data and created the data loaders we will use for training. We are finally ready to build the model!

In [None]:
# Layers in PyG are very similar to PyTorch, e.g., this is a standard graph convolutional layer GCNConv


In [None]:
# Pay attention to the forward arguments

In [None]:
# Different layers have different properties, see this "cheatsheet" from the documentation:
# https://pytorch-geometric.readthedocs.io/en/latest/notes/cheatsheet.html
# For example, GCNConv accepts an additional "edge_weight" parameter to weight each edge.

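The layer behind `GCNConv` propagates features as $X' = \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} X \Theta$, where $\hat{A} = A + I$ adds self-loops and $\hat{D}_{ii} = \sum_j \hat{A}_{ij}$; when `edge_weight` is given, the entries of $A$ are those weights instead of ones. Below is a dense plain-PyTorch sketch of that rule, assuming an undirected graph and omitting the bias that PyG's layer adds by default; PyG computes the normalization with sparse message passing rather than dense matrices:

```python
import torch

def gcn_layer(x, edge_index, weight, edge_weight=None):
    """Dense sketch of GCN propagation: D^-1/2 (A + I) D^-1/2 X W (no bias)."""
    num_nodes = x.size(0)
    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1))
    # Dense weighted adjacency A[src, dst] (duplicate edges are summed)
    adj = torch.zeros(num_nodes, num_nodes)
    adj.index_put_((edge_index[0], edge_index[1]), edge_weight, accumulate=True)
    # Self-loops of weight 1, so each node keeps part of its own features
    adj = adj + torch.eye(num_nodes)
    # Symmetric degree normalization
    deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    norm_adj = deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)
    return norm_adj @ x @ weight

# Path graph 0 - 1 - 2; identity features and weights expose the normalized adjacency
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
out = gcn_layer(torch.eye(3), edge_index, torch.eye(3))
print(out)

# Doubling every edge weight makes neighbours count more relative to the self-loop
out_weighted = gcn_layer(torch.eye(3), edge_index, torch.eye(3),
                         edge_weight=torch.full((4,), 2.0))
```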
In [None]:
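# If you are not used to PyTorch Lightning, see the 5-minute intro here:
# https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html

train_losses = []
eval_accs = []

class MUTAGClassifier(ptlight.LightningModule):

  def __init__(self, hidden_features: int, in_features: int = 7, num_classes: int = 2):
    super().__init__()
    # Defaults match MUTAG: 7 one-hot atom types as node features, 2 graph classes
    self.gcn = ptgeom.nn.GCNConv(in_features, hidden_features)
    self.head = torch.nn.Linear(hidden_features, num_classes)

  def forward(self, x, edge_index=None, batch=None, edge_weight=None):

    # unwrap the graph if the whole Data object was passed
    if edge_index is None:
      x, edge_index, batch = x.x, x.edge_index, x.batch

    # We go gcn -> F.relu -> mean_pool -> F.dropout -> Linear
    x = F.relu(self.gcn(x, edge_index, edge_weight))
    x = ptgeom.nn.global_mean_pool(x, batch)  # one embedding per graph
    x = F.dropout(x, training=self.training)
    logits = self.head(x)
    return logits

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
    return optimizer

  def training_step(self, batch, _):
    # Training step: cross-entropy between graph logits and labels
    logits = self(batch)
    loss = F.cross_entropy(logits, batch.y)

    # Log loss
    self.log('train_loss', loss, batch_size=batch.num_graphs)
    train_losses.append(loss.item())

    return loss

  def validation_step(self, batch, _):
    # Eval step: the predicted class is the argmax of the logits
    logits = self(batch)
    preds = logits.argmax(dim=-1)
    acc = accuracy(preds, batch.y, task='multiclass', num_classes=self.head.out_features)

    # Log accuracy
    self.log('val_acc', acc, batch_size=batch.num_graphs)
    eval_accs.append(acc.item())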

In [None]:
# print the model
model = MUTAGClassifier(256)
model

In [None]:
# forward one batch

### 3.1 Training the model



In [None]:
# We save checkpoints every 50 epochs
# This is like taking 'snapshots' of the model every 50 epochs
# We will use this in the next notebook

checkpoint_callback = ptlight.callbacks.ModelCheckpoint(
    dirpath='./checkpoints/',
    filename='gnn-{epoch:02d}',
    every_n_epochs=50,
    save_top_k=-1)

In [None]:
# define the trainer
trainer = ptlight.Trainer(
    max_epochs=80,
    callbacks=[checkpoint_callback])

In [None]:
# This is not a particularly well-designed model; we expect roughly 80% test accuracy
trainer.fit(model, train_loader, test_loader)

In [None]:
# simple plots to visualize metrics

plt.figure(figsize=(5,4))
plt.plot(train_losses)
plt.plot(eval_accs)

plt.legend(['Loss', 'Accuracy'])
plt.show()

In [None]:
# In most cases this does not work in Colab because of browser cookie settings
%reload_ext tensorboard
%tensorboard --logdir=/content/lightning_logs

## 4. 💪 Exercise time

PyTorch Geometric offers a wide range of graph convolutional layers. You can find them [here](https://pytorch-geometric.readthedocs.io/en/latest/cheatsheet/gnn_cheatsheet.html).

1. Instead of the simple `GCNConv` we used, build a model that uses a different layer, e.g. `GATConv`. Train the model and compare the results with the ones we obtained. Are they better or worse?

2. (If we have time) Can we change the forward function of our model to also use edge weights? Does it help training?

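`GATConv` replaces the fixed GCN normalization with learned attention coefficients. In the single-head formulation, a score $e_{ij} = \mathrm{LeakyReLU}(a^\top [W h_i \,\|\, W h_j])$ is computed for every edge $j \to i$ (including a self-loop), normalized with a softmax over the incoming edges of node $i$, and used to average the transformed neighbour features. A plain-PyTorch sketch under those assumptions (single head, no bias, no dropout; `gat_layer` is a hypothetical helper, and PyG's layer adds multi-head concatenation and further options):

```python
import torch
import torch.nn.functional as F

def gat_layer(x, edge_index, weight, att_src, att_dst, negative_slope=0.2):
    """Single-head graph attention; messages flow from edge_index[0] to edge_index[1]."""
    num_nodes = x.size(0)
    # Append a self-loop per node so that each node also attends to itself
    loops = torch.arange(num_nodes).repeat(2, 1)
    src, dst = torch.cat([edge_index, loops], dim=1)
    h = x @ weight                                     # [num_nodes, out_features]
    # a^T [W h_i || W h_j], with a split into a target half and a source half
    scores = F.leaky_relu(h[dst] @ att_dst + h[src] @ att_src, negative_slope)
    # Softmax over the incoming edges of every target node (shifted by the per-node max)
    node_max = torch.zeros(num_nodes).scatter_reduce(0, dst, scores, reduce='amax', include_self=False)
    scores = (scores - node_max[dst]).exp()
    denom = torch.zeros(num_nodes).index_add_(0, dst, scores)
    alpha = scores / denom[dst]
    # Attention-weighted sum of the source features at each target node
    out = torch.zeros_like(h).index_add_(0, dst, alpha.unsqueeze(1) * h[src])
    return out, alpha

# Path graph 0 - 1 - 2 with a single scalar feature per node
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.tensor([[1.0], [10.0], [100.0]])

# With zero attention vectors every incoming edge gets the same weight,
# so the layer reduces to a mean over each node's neighbours plus itself
out, alpha = gat_layer(x, edge_index, torch.eye(1), torch.zeros(1), torch.zeros(1))
print(out)  # node 0: (10 + 1) / 2, node 1: (1 + 100 + 10) / 3, node 2: (10 + 100) / 2
```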
In [None]:
from torch_geometric.nn import GATConv

# Keep track of metrics here
train_losses = []
train_accs = []
eval_accs = []

# Define the new model
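class MyCoolGNN(ptlight.LightningModule):

  def __init__(self, hidden_features: int, heads: int):
    super().__init__()
    # TODO: define the layers, e.g. GATConv(in_channels, hidden_features, heads=heads)

  def forward(self, x, edge_index=None, batch=None, edge_weight=None):

    # unwrap the graph if the whole Data object was passed
    if edge_index is None:
      x, edge_index, batch = x.x, x.edge_index, x.batch

    # TODO: process the input and return the logits
    raise NotImplementedError

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return optimizer

  def training_step(self, batch, _):
    # TODO: compute and log the training loss, then return it
    raise NotImplementedError

  def validation_step(self, batch, _):
    # TODO: compute and log the validation accuracy
    raise NotImplementedError


model = MyCoolGNN(256, 2)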

In [None]:
# Test on one batch


In [None]:
# Train (no callbacks needed this time)


In [None]:
# Plot the metrics
