By the end of this week, you will:
- Understand the motivations behind advanced GNN models
- Implement GAT, GraphSAGE, and GIN using PyTorch Geometric
- Train and evaluate them on the Cora dataset
- Compare performance and visualize learned embeddings

# 1. Recap — Why Go Beyond GCN?
| Problem with GCN                            | Solution Model                              |
| ------------------------------------------- | ------------------------------------------- |
| Treats all neighbors equally                | **GAT** — uses attention to weigh neighbors |
| Requires full-graph training (not scalable) | **GraphSAGE** — samples neighbors           |
| Limited expressive power                    | **GIN** — injects MLP expressiveness        |


# 2. Dataset Setup (Cora Again)

In [6]:
from torch_geometric.datasets import Planetoid
import torch
import torch.nn.functional as F

dataset = Planetoid(root="data/Planetoid", name="Cora")
data = dataset[0]


# 3. Graph Attention Network (GAT)

Idea:

Each node attends differently to its neighbors.
Attention coefficients $α_{ij}$

 determine the importance of neighbor j to node i.

In [7]:
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


In [8]:
model = GAT(dataset.num_features, 8, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

model.eval()
pred = out.argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).sum() / int(data.test_mask.sum())
print(f"GAT Accuracy: {acc:.4f}")


GAT Accuracy: 0.6080


# 4. GraphSAGE (Inductive Learning)

Idea:

Instead of full adjacency aggregation, sample a fixed number of neighbors.
Supports inductive tasks (new unseen nodes).

In [9]:
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


In [10]:
model = GraphSAGE(dataset.num_features, 16, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

model.eval()
pred = out.argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).sum() / int(data.test_mask.sum())
print(f"GraphSAGE Accuracy: {acc:.4f}")


GraphSAGE Accuracy: 0.7400


# 5. Graph Isomorphism Network (GIN)

Idea:

Aims to be as powerful as the Weisfeiler-Lehman (WL) test for graph isomorphism.
Uses sum aggregation + MLP for expressive updates.

In [11]:
from torch_geometric.nn import GINConv
from torch.nn import Linear, Sequential, ReLU

class GIN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        nn1 = Sequential(Linear(in_channels, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv1 = GINConv(nn1)

        nn2 = Sequential(Linear(hidden_channels, hidden_channels), ReLU(), Linear(hidden_channels, out_channels))
        self.conv2 = GINConv(nn2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


In [12]:
model = GIN(dataset.num_features, 32, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

model.eval()
pred = out.argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).sum() / int(data.test_mask.sum())
print(f"GIN Accuracy: {acc:.4f}")


GIN Accuracy: 0.7390


# 6. Compare Models

| Model     | Aggregation        | Key Feature                   | Typical Accuracy (Cora) |
| --------- | ------------------ | ----------------------------- | ----------------------- |
| GCN       | Mean               | Spectral Conv                 | ~80%                    |
| GAT       | Attention-weighted | Learnable neighbor importance | ~60%                    |
| GraphSAGE | Mean/Max/Sum       | Inductive sampling            | ~74%                    |
| GIN       | Sum + MLP          | High expressiveness           | ~73%                    |


# Exercises

## EX1: Try Different Aggregations in GraphSAGE



In [13]:
#Replace:

#SAGEConv(in_channels, hidden_channels, aggr="max")

#Compare "mean", "max", "sum" aggregations.

# EX2 Exercise 2 — Attention Heads in GAT



In [None]:

#Try:
##model = GAT(dataset.num_features, 8, dataset.num_classes, heads=4)
##Then heads=8, heads=16.
#Measure how multi-head attention changes performance.

## Exercise 3 — Embedding Visualization

In [14]:
# Use t-SNE again:

# from sklearn.manifold import TSNE
# import matplotlib.pyplot as plt
# 
# z = model.conv1(data.x, data.edge_index).detach()
# z_2d = TSNE(n_components=2).fit_transform(z)
# plt.scatter(z_2d[:,0], z_2d[:,1], c=data.y, cmap='tab10', s=15)
# plt.title("Node Embeddings Visualization")
# plt.show()