# Graph Classification with GNN
- Based on the tutorial [Graph Classification with Graph Neural Networks](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing)

In [7]:
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

# For CUDA GPUs (ignore if CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


## Download TU Dortmund University's MUTAG Dataset
- [More info on MUTAG](https://paperswithcode.com/dataset/mutag)

MUTAG is a collection of nitroaromatic compounds, and the goal is to predict their mutagenicity on Salmonella typhimurium. Each input graph represents one chemical compound: vertices stand for atoms and are labeled by atom type (one-hot encoded), while edges represent bonds between the corresponding atoms. The dataset includes 188 compounds with 7 discrete node labels.

The MUTAG dataset consists of 188 chemical compounds divided into two 
classes according to their mutagenic effect on a bacterium. 

The chemical data was obtained from http://cdb.ics.uci.edu and converted 
to graphs, where vertices represent atoms and edges represent chemical 
bonds. Explicit hydrogen atoms have been removed and vertices are labeled
by atom type and edges by bond type (single, double, triple or aromatic).
Chemical data was processed using the Chemistry Development Kit (v1.4).

Node labels:

- 0  C
- 1  N
- 2  O
- 3  F
- 4  I
- 5  Cl
- 6  Br

Edge labels:

- 0  aromatic
- 1  single
- 2  double
- 3  triple

This dataset provides 188 different graphs, and the task is to classify each graph into one out of two classes.
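
As a small illustration of the one-hot node encoding described above, the sketch below maps feature rows back to atom symbols. The `ATOM_TYPES` list and the `decode_atoms` helper are hypothetical names introduced here for illustration; the ordering is assumed to follow the node-label table above.

```python
import torch

# Atom symbols in the order of the node-label table (index 0 = C, ..., 6 = Br).
# Assumption: column i of the feature matrix corresponds to node label i.
ATOM_TYPES = ['C', 'N', 'O', 'F', 'I', 'Cl', 'Br']

def decode_atoms(x):
    """Map a [num_nodes, 7] one-hot feature matrix to a list of atom symbols."""
    return [ATOM_TYPES[i] for i in x.argmax(dim=1).tolist()]

# Toy example: three nodes encoded as C, N and O.
x = torch.eye(7)[[0, 1, 2]]
atoms = decode_atoms(x)
print(atoms)
```

Applied to a real graph, the same helper could be called as `decode_atoms(dataset[0].x)` once the dataset has been loaded.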

In [4]:
dataset = TUDataset(root='../data/TUDataset', name='MUTAG')
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


Dataset: MUTAG(188):
Number of graphs: 188
Number of features: 7
Number of classes: 2

Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
Number of nodes: 17
Number of edges: 38
Average node degree: 2.24
Has isolated nodes: False
Has self-loops: False
Is undirected: True


## Select Training/Testing Split

In [5]:
split_index = 150

torch.manual_seed(42)
dataset = dataset.shuffle()

train_dataset = dataset[:split_index]
test_dataset  = dataset[split_index:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

Number of training graphs: 150
Number of test graphs: 38


### GNN Minibatching with Dataloader
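- Conventional mini-batching (padding every graph to the same size and stacking along a new batch dimension) is wasteful for graphs
    - Padded adjacency matrices with an extra batch dimension increase memory consumption unnecessarily

- Instead, the adjacency matrices are stacked diagonally into one large block-diagonal matrix, and node features are concatenated along the node dimension
    - Message passing is not affected since the graphs in a batch stay disconnected from each other
    - There is no memory overhead because adjacency matrices are stored sparsely (only the non-zero entries, i.e. the edges, are kept)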
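
The diagonal stacking described above can be sketched in plain PyTorch. This is only an illustration of the idea, not the `DataLoader` implementation itself; `batch_graphs` is a hypothetical helper. Each graph's node indices are shifted by the number of nodes that came before it, which is equivalent to placing its adjacency matrix on the diagonal of a larger sparse matrix.

```python
import torch

# Two toy graphs, each described by an edge_index that uses local node ids.
edge_index_a = torch.tensor([[0, 1], [1, 0]])              # 2 nodes, edge 0-1
edge_index_b = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # 3 nodes, path 0-1-2

def batch_graphs(graphs):
    """Merge (edge_index, num_nodes) pairs into one large disconnected graph."""
    edge_indices, batch, offset = [], [], 0
    for graph_id, (edge_index, num_nodes) in enumerate(graphs):
        edge_indices.append(edge_index + offset)  # shift ids into the global range
        batch.append(torch.full((num_nodes,), graph_id, dtype=torch.long))
        offset += num_nodes
    return torch.cat(edge_indices, dim=1), torch.cat(batch)

edge_index, batch = batch_graphs([(edge_index_a, 2), (edge_index_b, 3)])
print(edge_index)  # edges of the second graph now refer to nodes 2, 3 and 4
print(batch)       # maps every node to the graph it came from
```

The resulting `batch` vector plays the same role as the `batch` attribute of the `DataBatch` objects printed below: it records which graph each node belongs to.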

In [8]:
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 2604], x=[1172, 7], edge_attr=[2604, 4], y=[64], batch=[1172], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 2536], x=[1155, 7], edge_attr=[2536, 4], y=[64], batch=[1155], ptr=[65])

Step 3:
Number of graphs in the current batch: 22
DataBatch(edge_index=[2, 836], x=[376, 7], edge_attr=[836, 4], y=[22], batch=[376], ptr=[23])



## GNN Training
Training a GNN for graph classification usually follows a simple recipe:

1. Embed each node by performing multiple rounds of message passing
2. Aggregate node embeddings into a unified graph embedding (**readout layer**)
3. Train a final classifier on the graph embedding

Multiple **readout layers** exist in the literature, but the most common one simply takes the average of the node embeddings:

$$
\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{x}^{(L)}_v
$$

PyTorch Geometric provides this functionality via [`torch_geometric.nn.global_mean_pool`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.global_mean_pool), which takes the node embeddings of all nodes in the mini-batch together with the assignment vector `batch` and computes one embedding per graph, yielding an output of size `[batch_size, hidden_channels]`.
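
To make the readout concrete, here is a minimal plain-PyTorch sketch of mean pooling over a `batch` assignment vector. It is an illustrative re-implementation under simple assumptions (a hypothetical `mean_readout` helper, CPU tensors), not the library's own code.

```python
import torch

def mean_readout(x, batch, num_graphs):
    """Average node embeddings per graph, given the node-to-graph vector `batch`."""
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype)
    sums.index_add_(0, batch, x)  # add each node's embedding to its graph's row
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return sums / counts.unsqueeze(1).to(x.dtype)

# Three nodes with 2-dimensional embeddings: the first two belong to graph 0,
# the last one to graph 1.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
batch = torch.tensor([0, 0, 1])
out = mean_readout(x, batch, num_graphs=2)
print(out)  # one row per graph
```

Graph 0 receives the average of its two node embeddings, while graph 1 keeps the embedding of its single node.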

The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:

In [9]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

model = GCN(hidden_channels=64)
print(model)

GCN(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


In [10]:
model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

@torch.no_grad()  # Gradients are not needed during evaluation.
def test(loader):
    model.eval()

    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)  # Use the class with the highest score.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Train Acc: 0.6933, Test Acc: 0.5526
Epoch: 002, Train Acc: 0.6933, Test Acc: 0.5526
Epoch: 003, Train Acc: 0.6933, Test Acc: 0.5526
Epoch: 004, Train Acc: 0.6933, Test Acc: 0.5526
Epoch: 005, Train Acc: 0.6933, Test Acc: 0.5526
Epoch: 006, Train Acc: 0.7000, Test Acc: 0.5789
Epoch: 007, Train Acc: 0.7600, Test Acc: 0.6316
Epoch: 008, Train Acc: 0.7600, Test Acc: 0.6842
Epoch: 009, Train Acc: 0.7267, Test Acc: 0.6842
Epoch: 010, Train Acc: 0.7533, Test Acc: 0.6579
Epoch: 011, Train Acc: 0.7333, Test Acc: 0.6579
Epoch: 012, Train Acc: 0.7467, Test Acc: 0.6842
Epoch: 013, Train Acc: 0.7667, Test Acc: 0.7105
Epoch: 014, Train Acc: 0.7800, Test Acc: 0.6842
Epoch: 015, Train Acc: 0.7600, Test Acc: 0.7105
Epoch: 016, Train Acc: 0.7667, Test Acc: 0.6842
Epoch: 017, Train Acc: 0.7667, Test Acc: 0.6842
Epoch: 018, Train Acc: 0.7733, Test Acc: 0.6316
Epoch: 019, Train Acc: 0.7733, Test Acc: 0.6842
Epoch: 020, Train Acc: 0.7800, Test Acc: 0.6842
Epoch: 021, Train Acc: 0.7733, Test Acc:

## Improving with GraphConv
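- Omits neighborhood normalization, which can decrease a GNN's expressivity in distinguishing certain graph structures
- Adds a skip-connection, so each node's own features are transformed separately from the sum over its neighbors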

$$
\mathbf{x}_v^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_v^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{w \in \mathcal{N}(v)} \mathbf{x}_w^{(\ell)}
$$
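
The update rule above can be written out in a few lines of plain PyTorch. The sketch below is a simplified illustration: `graph_conv_step` is a hypothetical helper, bias terms are omitted, and the weights are passed in explicitly rather than learned.

```python
import torch

def graph_conv_step(x, edge_index, w1, w2):
    """Compute W1 x_v plus W2 times the unnormalized sum of neighbor features."""
    src, dst = edge_index
    agg = torch.zeros_like(x)
    agg.index_add_(0, dst, x[src])  # sum the features of each node's neighbors
    return x @ w1.T + agg @ w2.T    # skip-connection term + neighbor term

# Path graph 0-1-2 with scalar node features and identity weights.
x = torch.tensor([[1.0], [2.0], [4.0]])
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
w1 = torch.tensor([[1.0]])
w2 = torch.tensor([[1.0]])
out = graph_conv_step(x, edge_index, w1, w2)
print(out)
```

With identity weights, each output is the node's own feature plus the sum of its neighbors' features, so the middle node (two neighbors) receives the largest value. No division by the node degree takes place, in contrast to the normalized aggregation of `GCNConv`.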

In [11]:
from torch_geometric.nn import GraphConv


class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GraphConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        x = global_mean_pool(x, batch)

        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

model = GNN(hidden_channels=64)
print(model)

GNN(
  (conv1): GraphConv(7, 64)
  (conv2): GraphConv(64, 64)
  (conv3): GraphConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


In [13]:
model = GNN(hidden_channels=64)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1, 201):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

GNN(
  (conv1): GraphConv(7, 64)
  (conv2): GraphConv(64, 64)
  (conv3): GraphConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)
Epoch: 001, Train Acc: 0.3067, Test Acc: 0.4474
Epoch: 002, Train Acc: 0.6933, Test Acc: 0.5526
Epoch: 003, Train Acc: 0.6933, Test Acc: 0.5526
Epoch: 004, Train Acc: 0.6933, Test Acc: 0.5526
Epoch: 005, Train Acc: 0.6933, Test Acc: 0.5526
Epoch: 006, Train Acc: 0.7000, Test Acc: 0.5789
Epoch: 007, Train Acc: 0.7733, Test Acc: 0.6842
Epoch: 008, Train Acc: 0.7733, Test Acc: 0.6579
Epoch: 009, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 010, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 011, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 012, Train Acc: 0.7667, Test Acc: 0.6842
Epoch: 013, Train Acc: 0.8067, Test Acc: 0.7895
Epoch: 014, Train Acc: 0.8000, Test Acc: 0.7895
Epoch: 015, Train Acc: 0.7533, Test Acc: 0.6842
Epoch: 016, Train Acc: 0.8133, Test Acc: 0.8158
Epoch: 017, Train Acc: 0.7933, Test Acc: 0.7895
Epoch: 018, Train Acc: 0.8000, T