# Graph Classification with PyTorch Geometric

Graph classification is a task where we aim to predict a label for an entire graph, rather than individual nodes or edges.
This is a natural extension to node or edge classification and has applications in areas like molecule classification, social network classification, etc.
In this notebook, we'll dive into this task using PyTorch Geometric.

## Table of Contents

1. [Loading the Dataset](#Loading-the-Dataset)
2. [Defining the Model](#Defining-the-Model)
3. [Setting up the Training Loop](#Setting-up-the-Training-Loop)
4. [Evaluating Training Results](#Evaluating-Training-Results)


In [1]:
# Uncomment the following line to install the required packages
# !pip install torch torch-geometric

## Loading the Dataset <a name="Loading-the-Dataset"></a>

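PyTorch Geometric bundles the TUDataset collection of graph classification benchmarks. We use PROTEINS, where each graph represents a protein and the binary label indicates whether it is an enzyme.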

In [2]:
from torch_geometric.datasets import TUDataset

# Load the PROTEINS dataset (downloaded to /tmp/PROTEINS on first use)
dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS')

# Split 80/20 by slicing; slicing keeps the file order (no shuffling)
train_dataset = dataset[:len(dataset) // 10 * 8]
test_dataset = dataset[len(dataset) // 10 * 8:]

print(f'Number of classes: {dataset.num_classes}')
print(f'Number of node features: {dataset.num_features}')


Number of classes: 2
Number of node features: 3
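Because the slice above keeps the file order, it is worth checking that both splits contain both classes in similar proportions. PyG examples typically call `dataset.shuffle()` before slicing; the pure-Python sketch below goes one step further and builds a stratified split from a list of graph labels, so each class is divided 80/20 on its own. The helpers `class_counts` and `stratified_split` are hypothetical (not part of PyG), and the toy labels only mimic a file that lists one class after the other.

```python
import random
from collections import Counter


def class_counts(labels):
    """Return a {label: count} dict, sorted by label."""
    return dict(sorted(Counter(labels).items()))


def stratified_split(labels, train_frac=0.8, seed=0):
    """Split indices so that every class is divided train_frac / (1 - train_frac).

    Returns (train_idx, test_idx), two lists of indices into `labels`.
    """
    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for label in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == label]
        rng.shuffle(idx)
        cut = round(len(idx) * train_frac)
        train_idx.extend(idx[:cut])
        test_idx.extend(idx[cut:])
    # Mix the classes within each split
    rng.shuffle(train_idx)
    rng.shuffle(test_idx)
    return train_idx, test_idx


# Toy labels stored grouped by class
labels = [0] * 60 + [1] * 40

# Naive 80/20 slice: the test set ends up containing only class 1
cut = len(labels) // 10 * 8
print(class_counts(labels[:cut]), class_counts(labels[cut:]))
# → {0: 60, 1: 20} {1: 20}

train_idx, test_idx = stratified_split(labels)
print(class_counts([labels[i] for i in train_idx]),
      class_counts([labels[i] for i in test_idx]))
# → {0: 48, 1: 32} {0: 12, 1: 8}
```

In this notebook, the labels could be gathered with `labels = [int(data.y) for data in dataset]`, and the index lists passed back as `dataset[train_idx]` and `dataset[test_idx]`, since PyG datasets accept a list of indices.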


## Defining the Model <a name="Defining-the-Model"></a>

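We'll use a Graph Convolutional Network (GCN) followed by global pooling: two `GCNConv` layers compute node embeddings, `global_mean_pool` averages the node embeddings of each graph into a single vector, and a linear layer maps that vector to class log-probabilities.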
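Written as equations, and assuming `GCNConv`'s default behavior (self-loops added, symmetric degree normalization; bias terms omitted for brevity), the model computes the following, with $X$ the node feature matrix, $\tilde{A} = A + I$ the adjacency matrix with self-loops, and $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$ its degree matrix:

$$
H^{(1)} = \mathrm{ReLU}\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} X W^{(1)}\right), \qquad
H^{(2)} = \mathrm{ReLU}\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} H^{(1)} W^{(2)}\right)
$$

$$
h_G = \frac{1}{|V_G|} \sum_{v \in V_G} h^{(2)}_v, \qquad
\hat{y}_G = \log\mathrm{softmax}\left(W_{\mathrm{fc}}\, h_G + b_{\mathrm{fc}}\right)
$$

Here $h^{(2)}_v$ is the row of $H^{(2)}$ for node $v$ and $V_G$ is the node set of graph $G$, so $\hat{y}_G$ holds the log-probabilities that `F.nll_loss` consumes during training.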

In [3]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        
        return F.log_softmax(self.fc(x), dim=1)

# Use an accelerator if available: CUDA, then Apple-silicon MPS, otherwise CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model = GraphClassifier(input_dim=dataset.num_features, hidden_dim=64, output_dim=dataset.num_classes).to(device)
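`global_mean_pool(x, batch)` averages the rows of `x` that share the same value in `batch`. The plain-Python sketch below reproduces that behavior on lists, assuming graph ids in `batch` run from 0 upward; PyG likewise infers the number of graphs from the largest id when no `size` is passed. `mean_pool` is a hypothetical helper for illustration, not the PyG implementation, which works on tensors with a scatter operation.

```python
def mean_pool(x, batch):
    """Average node feature rows per graph.

    x:     list of node feature vectors (lists of floats), one per node
    batch: list of graph ids; batch[i] is the graph that node i belongs to
    Returns one averaged feature vector per graph id, ordered by id.
    """
    num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, g in zip(x, batch):
        counts[g] += 1
        for j, value in enumerate(row):
            sums[g][j] += value
    # Ids with no nodes get a zero vector instead of dividing by zero
    return [[s / max(counts[g], 1) for s in sums[g]] for g in range(num_graphs)]


# Nodes 0 and 1 belong to graph 0, node 2 to graph 1
print(mean_pool([[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]], [0, 0, 1]))
# → [[2.0, 3.0], [10.0, 0.0]]
```

Averaging, rather than summing, makes the graph vector independent of graph size, which suits datasets whose graphs vary a lot in node count.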

## Setting up the Training Loop <a name="Setting-up-the-Training-Loop"></a>

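Let's train our graph classification model. PyG's `DataLoader` merges each mini-batch of graphs into one large disconnected graph and stores in `data.batch` the index of the graph each node belongs to; this is the vector `global_mean_pool` uses to pool per graph. We train for 50 epochs with Adam (learning rate 0.01) and the negative log-likelihood loss, which matches the model's `log_softmax` output.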
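PyG's `DataLoader` processes several graphs at once by merging each mini-batch into one large disconnected graph. The pure-Python sketch below shows what that merge involves: node features are concatenated, each graph's `edge_index` is shifted by the number of nodes that came before it, and a `batch` list records graph membership. The `collate` helper and the dict-based graph format are assumptions for illustration; in PyG, `Batch.from_data_list` does this on tensors and handles many more attributes.

```python
def collate(graphs):
    """Merge graphs into one disconnected graph.

    Each graph is a dict with:
      'x':          list of node feature vectors
      'edge_index': [sources, targets], node indices local to that graph
      'y':          graph label
    """
    x, src, dst, batch, y = [], [], [], [], []
    offset = 0
    for g_id, g in enumerate(graphs):
        num_nodes = len(g['x'])
        x.extend(g['x'])
        # Shift local node indices so they point into the merged node list
        src.extend(s + offset for s in g['edge_index'][0])
        dst.extend(d + offset for d in g['edge_index'][1])
        batch.extend([g_id] * num_nodes)
        y.append(g['y'])
        offset += num_nodes
    return {'x': x, 'edge_index': [src, dst], 'batch': batch, 'y': y}


# A 2-node graph with one undirected edge and a 3-node path
g1 = {'x': [[1.0], [2.0]], 'edge_index': [[0, 1], [1, 0]], 'y': 0}
g2 = {'x': [[3.0], [4.0], [5.0]], 'edge_index': [[0, 1, 1, 2], [1, 0, 2, 1]], 'y': 1}
merged = collate([g1, g2])
print(merged['edge_index'])  # → [[0, 1, 2, 3, 3, 4], [1, 0, 3, 2, 4, 3]]
print(merged['batch'])       # → [0, 0, 1, 1, 1]
```

Because no edge crosses a graph boundary, message passing on the merged graph gives the same node embeddings as running each graph separately, while the GPU sees one large input.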

In [4]:
from torch_geometric.loader import DataLoader

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(50):
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
    
    # Every 10 epochs, print the loss of the last mini-batch
    if epoch % 10 == 0:
        print(f'Epoch: {epoch}, Loss: {loss.item()}')


Epoch: 0, Loss: 0.45274659991264343
Epoch: 10, Loss: 0.47858163714408875
Epoch: 20, Loss: 0.49092409014701843
Epoch: 30, Loss: 0.5761591792106628
Epoch: 40, Loss: 0.49991852045059204
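The loss printed above is the loss of the last mini-batch in each printed epoch, so it is noisy: a single batch is a small sample of the training set. A steadier signal is the mean loss over all training graphs in the epoch, weighting each batch by its number of graphs (available as `data.num_graphs` on a PyG batch). The sketch below shows the weighting on plain numbers; `epoch_mean_loss` is a hypothetical helper.

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Graph-weighted mean of per-batch mean losses.

    batch_losses: mean loss of each mini-batch (e.g. loss.item())
    batch_sizes:  number of graphs in each mini-batch (e.g. data.num_graphs)
    """
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Two full batches of 64 graphs and a smaller final batch of 8
print(epoch_mean_loss([0.5, 0.7, 0.2], [64, 64, 8]))
```

Inside the training loop, the same idea amounts to accumulating `total_loss += loss.item() * data.num_graphs` and dividing by `len(train_loader.dataset)` at the end of the epoch. Weighting matters because `F.nll_loss` returns a per-batch mean, and a plain average of batch means would overweight a small final batch.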


## Evaluating Training Results <a name="Evaluating-Training-Results"></a>

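It's time to evaluate the model's performance on the test set. Accuracy here is the fraction of test graphs whose predicted class, the argmax of the model's log-probabilities, matches the true label.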

In [5]:
@torch.no_grad()  # gradients are not needed for evaluation
def compute_accuracy(model, loader):
    model.eval()

    correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
        
    return correct / len(loader.dataset)

test_accuracy = compute_accuracy(model, test_loader)
print(f'Test Accuracy: {test_accuracy:.4f}')

Test Accuracy: 0.2756


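The accuracy is very low, and the model's simplicity does not fully explain it: on a binary task, an accuracy below 0.5 means that predicting one of the two classes for every graph would do better.
A more likely culprit is the split. Slicing the dataset keeps the file order, so if the PROTEINS graphs are stored grouped by label, the training and test sets end up with very different class distributions. Shuffling before splitting (for example with `dataset.shuffle()`) and checking the class counts of both splits is the first thing to try.
Beyond that, we can try out different architectures, hyperparameters, and datasets, and, as before, use Optuna and DeepHyper for hyperparameter tuning.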
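Before tuning anything, it helps to put an accuracy number like the 0.2756 above in context. The sketch below computes overall accuracy, per-class accuracy (recall), and the test accuracy of always predicting the most frequent training class, from plain lists of labels and predictions. The function names are hypothetical; in this notebook the lists could be collected from `data.y` and `out.argmax(dim=1)` while iterating over `test_loader`.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of positions where the prediction equals the label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_accuracy(y_true, y_pred):
    """Accuracy restricted to the examples of each true class."""
    result = {}
    for label in sorted(set(y_true)):
        pairs = [(t, p) for t, p in zip(y_true, y_pred) if t == label]
        result[label] = sum(t == p for t, p in pairs) / len(pairs)
    return result


def majority_baseline(y_train, y_test):
    """Test accuracy of always predicting the most common training label."""
    majority = Counter(y_train).most_common(1)[0][0]
    return accuracy(y_test, [majority] * len(y_test))


# Toy example: training labels dominated by class 0, test labels all class 1
y_train = [0] * 6 + [1] * 2
y_test = [1] * 4
y_pred = [0, 0, 0, 1]
print(accuracy(y_test, y_pred))            # → 0.25
print(per_class_accuracy(y_test, y_pred))  # → {1: 0.25}
print(majority_baseline(y_train, y_test))  # → 0.0
```

A per-class breakdown makes a train/test label mismatch visible at a glance: if the test set contains only one class, the per-class dictionary has a single entry, and a majority baseline fitted on the training labels can score near zero.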