This notebook is adapted from [Stanford CS224W](http://snap.stanford.edu/class/cs224w-2021/) by Jure Leskovec.

# Graph Convolutional Neural Networks

In this notebook, we will construct our own graph neural network using PyTorch Geometric (PyG) and then apply that model to two Open Graph Benchmark (OGB) datasets. These two datasets will be used to benchmark the model's performance on two different graph-based tasks: 1) node property prediction, predicting properties of single nodes, and 2) graph property prediction, predicting properties of entire graphs or subgraphs.

First, we will load and inspect the Open Graph Benchmark (OGB) datasets by using the `ogb` package. OGB is a collection of realistic, large-scale, and diverse benchmark datasets for machine learning on graphs. The `ogb` package provides not only data loaders for each dataset but also model evaluators.

Lastly, we will build our own graph neural network using PyTorch Geometric. We will then train and evaluate our model on the OGB node property prediction and graph property prediction tasks.

In [None]:
import torch
import os
print("PyTorch has version {}".format(torch.__version__))

In [None]:
!pip install torch-geometric==2.4.0
!pip install ogb

The cell above installs the packages needed for PyG. Make sure that the versions you install are compatible with the torch version printed by the first cell. In case of any issues, more information can be found on the [PyG installation page](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html).

In [None]:
from torch_geometric.datasets import TUDataset
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
import torch
import pandas as pd
import torch.nn.functional as F
print(torch.__version__)

# The PyG built-in GCNConv
from torch_geometric.nn import GCNConv


# 1) PyTorch Geometric (Datasets and Data)


PyTorch Geometric has two modules for storing and/or transforming graphs into tensor format. One is `torch_geometric.datasets`, which contains a variety of common graph datasets. The other is `torch_geometric.data`, which handles graph data as PyTorch tensors.

In this section, we will learn how to use `torch_geometric.datasets` and `torch_geometric.data` together.

## PyG Datasets

The `torch_geometric.datasets` module has many common graph datasets. Here we will explore its usage through one example dataset.

In [None]:
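# Note: despite the variable name, name='PROTEINS' loads the PROTEINS dataset from the TU collection
enzymes_dataset = TUDataset(root='.', name='PROTEINS').shuffle()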

## Question 1: What is the number of classes and number of features in the dataset loaded above (PROTEINS)? (5 points)

In [None]:
print(f'Dataset: {enzymes_dataset}:')
print('======================')
print(f'Number of graphs: {len(enzymes_dataset)}')
print(f'Number of features: {enzymes_dataset.num_features}')
print(f'Number of classes: {enzymes_dataset.num_classes}')

print(enzymes_dataset.x[0])

## PyG Data

Each PyG dataset stores a list of `torch_geometric.data.Data` objects, where each `torch_geometric.data.Data` object represents a graph. We can easily get the `Data` object by indexing into the dataset.

For more information such as what is stored in the `Data` object, please refer to the [documentation](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data).
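
To make the `edge_index` field concrete, here is a minimal sketch in plain Python (no PyG required) of the COO layout that PyG uses for connectivity: a 2 x E structure whose first row holds source nodes and whose second row holds target nodes. The helper name `to_edge_index` and the toy triangle graph are assumptions made for this illustration; they are not part of PyG.

```python
def to_edge_index(undirected_edges):
    """Build a 2 x E COO edge list, storing each undirected edge in both directions."""
    sources, targets = [], []
    for u, v in undirected_edges:
        sources += [u, v]
        targets += [v, u]
    return [sources, targets]

# Toy triangle graph on nodes 0, 1, 2 (hypothetical example data)
edge_index = to_edge_index([(0, 1), (1, 2), (0, 2)])
num_directed_edges = len(edge_index[0])
print(edge_index)
print(num_directed_edges)
```

Because an undirected edge is stored once per direction, the second dimension of `edge_index` for an undirected graph is twice the number of undirected edges.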

In [None]:
print(enzymes_dataset[0])
print('Graph with index {} has label {}'.format(200, enzymes_dataset[200].y.item()))
print('Graph with index {} has {} number of edges'.format(200, enzymes_dataset[200].edge_index.shape[1]))

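# 3) GNN: Graph Property Prediction

In this section we will build our first graph neural network using PyTorch Geometric. Then we will apply it to the task of graph property prediction (graph classification): the models in this section pool node embeddings into one vector per graph and predict one label per graph.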

Specifically, we will use GCN as the foundation for our graph neural network ([Kipf and Welling (2017)](https://arxiv.org/pdf/1609.02907.pdf)). To do so, we will work with PyG's built-in `GCNConv` layer.
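
As a rough illustration of what one GCN layer aggregates, the sketch below applies the symmetric normalization from the GCN paper, $\hat{A} = \tilde{D}^{-1/2}(A + I)\tilde{D}^{-1/2}$, to a toy three-node path graph with one scalar feature per node. It is written in plain Python for readability and leaves out the learnable weight matrix, bias, and nonlinearity. The function name `gcn_aggregate` and the toy data are assumptions made for this example, not PyG's implementation.

```python
import math

def gcn_aggregate(num_nodes, undirected_edges, features):
    """Return A_hat @ features, where A_hat = D^-1/2 (A + I) D^-1/2."""
    # Adjacency matrix with self-loops added
    adj = [[1.0 if i == j else 0.0 for j in range(num_nodes)] for i in range(num_nodes)]
    for u, v in undirected_edges:
        adj[u][v] = 1.0
        adj[v][u] = 1.0
    # Node degrees, counting the self-loop
    deg = [sum(row) for row in adj]
    # Each neighbor j contributes features[j] / sqrt(deg_i * deg_j)
    return [
        sum(adj[i][j] * features[j] / math.sqrt(deg[i] * deg[j]) for j in range(num_nodes))
        for i in range(num_nodes)
    ]

# Path graph 0 - 1 - 2 with one scalar feature per node (toy data)
out = gcn_aggregate(3, [(0, 1), (1, 2)], [1.0, 2.0, 3.0])
print(out)
```

Each output value is a degree-weighted mix of the node's own feature and its neighbors' features, which is why stacking several such layers lets information travel several hops.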

## Load and Preprocess the Dataset

In [None]:
from torch_geometric.loader import DataLoader
demo_loader = DataLoader(enzymes_dataset[:4], batch_size=3, shuffle=False)

print(enzymes_dataset[0])
print(enzymes_dataset[1])

print(enzymes_dataset[2])
print(enzymes_dataset[3])
print('============')
for idx, batch in enumerate(demo_loader):
    print(batch, batch.ptr)
    print(batch.num_graphs)


# Advanced Graph Batching

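What happened inside the data loader? Each graph stores its connectivity using node indices that are local to that graph, so an edge in graph 0 could read `[[0], [1]]`, and an edge in graph 2 could have exactly the same indices. These two edges do not belong to the same graph, however, so each must remain a distinct connection in the aggregated (batched) graph. PyG achieves this by offsetting the node indices of each graph by the number of nodes that precede it in the batch, which turns the batch into one large graph made of disconnected components.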
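
To see how graphs with identical local indices can share one batch, here is a simplified plain-Python sketch of the idea. It is not PyG's actual code: the function `collate_graphs` and the two toy graphs are assumptions made for this example. Each graph's node indices are shifted by the number of nodes already placed in the batch, while a `batch` list records which graph every node belongs to and a `ptr` list records where each graph starts, mirroring the `batch.ptr` values printed in the cell above.

```python
def collate_graphs(graphs):
    """Merge (num_nodes, edge_index) pairs into one disconnected graph."""
    sources, targets, batch, ptr = [], [], [], [0]
    for graph_id, (num_nodes, edge_index) in enumerate(graphs):
        offset = ptr[-1]  # number of nodes already placed in the batch
        sources += [u + offset for u in edge_index[0]]
        targets += [v + offset for v in edge_index[1]]
        batch += [graph_id] * num_nodes  # graph membership of every node
        ptr.append(offset + num_nodes)
    return [sources, targets], batch, ptr

# Two toy graphs that both contain the local edge 0 -> 1
g0 = (2, [[0], [1]])
g1 = (3, [[0, 1], [1, 2]])
edge_index, batch, ptr = collate_graphs([g0, g1])
print(edge_index)  # the second graph's edges are shifted by 2
print(batch)
print(ptr)
```

After collation the two copies of the local edge `0 -> 1` become `0 -> 1` and `2 -> 3`, so they no longer collide.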

In [None]:
from torch_geometric.loader import DataLoader
demo_loader2 = DataLoader(enzymes_dataset, batch_size=32, shuffle=True)

for idx, batch in enumerate(demo_loader2):
    print(batch)
    print(batch.num_graphs)
    if idx > 20:
        break



In [None]:
from torch_geometric.utils import to_networkx
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

import networkx as nx
import matplotlib.pyplot as plt

G = to_networkx(enzymes_dataset[35], to_undirected=True)

# 3D spring layout
pos = nx.spring_layout(G, dim=3, seed=0)

# Extract node and edge positions from the layout
node_xyz = np.array([pos[v] for v in sorted(G)])
edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()])

# Create the 3D figure
fig = plt.figure(figsize=(16,16))
ax = fig.add_subplot(111, projection="3d")

# Suppress tick labels
for dim in (ax.xaxis, ax.yaxis, ax.zaxis):
    dim.set_ticks([])

# Plot the nodes - alpha is scaled by "depth" automatically
ax.scatter(*node_xyz.T, s=500, c="#0A047A")

# Plot the edges
for vizedge in edge_xyz:
    ax.plot(*vizedge.T, color="tab:gray")

# fig.tight_layout()
plt.show()

In [None]:
from torch_geometric.loader import DataLoader

# Create training and validation sets (80/20 split; no separate test set is held out)
train_idx = int(len(enzymes_dataset)*0.8)
train_dataset = enzymes_dataset[:train_idx]
val_dataset   = enzymes_dataset[train_idx:]

print(f'Training set   = {len(train_dataset)} graphs')
print(f'Validation set = {len(val_dataset)} graphs')

# Create mini-batches
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)

print('\nTrain loader:')
for i, batch in enumerate(train_loader):
    # Each item is a mini-batch of whole graphs, not a subgraph
    print(f' - Batch {i}: {batch}')

print('\nValidation loader:')
for i, batch in enumerate(val_loader):
    print(f' - Batch {i}: {batch}')


## GCN Model

Now we will implement our GCN model!
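
Because the task needs one prediction per graph, the models in the next cell collapse node embeddings into a single vector per graph with a readout such as `global_mean_pool` or `global_add_pool`. The plain-Python sketch below shows the idea behind mean pooling over a `batch` vector. The function name `mean_pool` and the toy embeddings are assumptions made for this example; PyG's own functions operate on tensors.

```python
def mean_pool(node_embeddings, batch, num_graphs):
    """Average node embeddings per graph, using batch[i] as node i's graph id."""
    dim = len(node_embeddings[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for embedding, graph_id in zip(node_embeddings, batch):
        counts[graph_id] += 1
        for k, value in enumerate(embedding):
            sums[graph_id][k] += value
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]

# Five nodes with 2-dimensional embeddings, split across two graphs (toy data)
h = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [3.0, 6.0], [6.0, 3.0]]
batch = [0, 0, 1, 1, 1]
pooled = mean_pool(h, batch, num_graphs=2)
print(pooled)
```

Skipping the division by `counts` would give sum pooling instead, which is the readout the GIN model uses.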


In [None]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.nn import global_mean_pool, global_add_pool


class EnzymesGCN(torch.nn.Module):
    """GCN"""
    def __init__(self, dim_h):
        super(EnzymesGCN, self).__init__()
        self.conv1 = GCNConv(train_dataset.num_node_features, dim_h)
        self.conv2 = GCNConv(dim_h, dim_h)
        self.conv3 = GCNConv(dim_h, dim_h)
        self.conv4 = GCNConv(dim_h, dim_h)
        self.conv5 = GCNConv(dim_h, dim_h)
        self.lin = Linear(dim_h, train_dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # Node embeddings
        h = self.conv1(x, edge_index)
        h = h.relu()
        h = self.conv2(h, edge_index)
        h = h.relu()
        h = self.conv3(h, edge_index)
        h = h.relu()
        h = self.conv4(h, edge_index)
        h = h.relu()
        h = self.conv5(h, edge_index)
        h = h.relu()

        # Graph-level readout
        hG = global_mean_pool(h, batch)

        # Classifier
        h = F.dropout(hG, p=0.5, training=self.training)
        h = self.lin(h)

        return hG, F.log_softmax(h, dim=1)

print('GCN', EnzymesGCN)


class GIN(torch.nn.Module):
    """GIN"""
    def __init__(self, dim_h):
        super(GIN, self).__init__()
        self.conv1 = GINConv(
            Sequential(Linear(train_dataset.num_node_features, dim_h),
                       BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.conv2 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.conv3 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.lin1 = Linear(dim_h*3, dim_h*3)
        self.lin2 = Linear(dim_h*3, train_dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # Node embeddings
        h1 = self.conv1(x, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)

        # Graph-level readout
        h1 = global_add_pool(h1, batch)
        h2 = global_add_pool(h2, batch)
        h3 = global_add_pool(h3, batch)

        # Concatenate graph embeddings
        h = torch.cat((h1, h2, h3), dim=1)

        # Classifier
        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin2(h)

        return h, F.log_softmax(h, dim=1)

In [None]:
def train(model, train_loader, val_loader):
    print('Train')
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.01)
    epochs = 100


    for epoch in range(epochs+1):
        print('epoch', epoch)

        total_loss = 0
        acc = 0
        val_loss = 0
        val_acc = 0

        # Train on batches
        model.train()
        for data in train_loader:

            optimizer.zero_grad()
            _, out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            # Accumulate a Python float so the autograd graph is not kept alive
            total_loss += loss.item() / len(train_loader)
            acc += accuracy(out.argmax(dim=1), data.y) / len(train_loader)
            loss.backward()
            optimizer.step()

        # Validation
        val_loss, val_acc = test(model, val_loader)

        # Print metrics for this epoch
        print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} '
              f'| Train Acc: {acc*100:>5.2f}% '
              f'| Val Loss: {val_loss:.2f} '
              f'| Val Acc: {val_acc*100:.2f}%')

    return model


@torch.no_grad()  # no gradients are needed during evaluation
def test(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss = 0
    acc = 0

    for data in loader:
        _, out = model(data.x, data.edge_index, data.batch)
        # Average the per-batch loss and accuracy over all batches
        loss += criterion(out, data.y).item() / len(loader)
        acc += accuracy(out.argmax(dim=1), data.y) / len(loader)

    return loss, acc

def accuracy(pred_y, y):
    """Calculate accuracy."""
    return ((pred_y == y).sum() / len(y)).item()

In [None]:
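# Note: despite the variable name, this trains the GIN model;
# use EnzymesGCN(dim_h=32) here to train the GCN instead
egcn = GIN(dim_h=32)
egcn = train(egcn, train_loader, val_loader)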