# Lab E.4 Solutions: Graph Classification

Complete solutions to all exercises in Lab E.4.

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, TopKPooling

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Exercise 1 Solution: PROTEINS Dataset

In [None]:
# Load the PROTEINS dataset through TUDataset, stored under /tmp/PROTEINS.
proteins = TUDataset(root='/tmp/PROTEINS', name='PROTEINS')

print("=" * 50)
print("PROTEINS DATASET")
print("=" * 50)
print(f"Number of graphs: {len(proteins)}")
print(f"Number of classes: {proteins.num_classes}")
print(f"Number of node features: {proteins.num_node_features}")

# Collect node counts, edge counts and labels in a single pass over the graphs.
num_nodes, num_edges, labels = [], [], []
for graph in proteins:
    num_nodes.append(graph.num_nodes)
    num_edges.append(graph.num_edges)
    labels.append(graph.y.item())

print(f"\nNodes per graph: min={min(num_nodes)}, max={max(num_nodes)}, avg={np.mean(num_nodes):.1f}")
print(f"Edges per graph: min={min(num_edges)}, max={max(num_edges)}, avg={np.mean(num_edges):.1f}")

# Class distribution: how many graphs carry each label, and their share of the dataset.
for c in range(proteins.num_classes):
    count = sum(1 for label in labels if label == c)
    print(f"Class {c}: {count} ({100*count/len(proteins):.1f}%)")

In [None]:
# Shuffle the dataset, then use the first 80% of graphs for training and the
# remaining 20% for testing.
proteins = proteins.shuffle()
train_size = int(0.8 * len(proteins))
train_dataset, test_dataset = proteins[:train_size], proteins[train_size:]

# Mini-batches of 32 graphs; shuffle=True is requested for the training loader only.
test_loader = DataLoader(test_dataset, batch_size=32)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

class GraphClassifier(nn.Module):
    """
    Three-layer GCN followed by a global readout and an MLP classifier.

    The node features go through three GCNConv layers (ReLU after the first
    two, none after the third), are reduced with the selected global pooling,
    and are then classified by a two-layer MLP with dropout (p=0.5).

    Args:
        num_features: Input size of the first GCN layer.
        hidden_dim: Width of every GCN layer and of the MLP hidden layer.
        num_classes: Output size of the final linear layer.
        pooling: Readout to use: 'mean', 'max', or 'mean_max' (the mean and
            max readouts concatenated, which doubles the MLP input width).

    Raises:
        ValueError: If ``pooling`` is not one of the supported modes.
    """

    # Readouts that forward() knows how to apply.
    _POOLING_MODES = ('mean', 'max', 'mean_max')

    def __init__(self, num_features, hidden_dim, num_classes, pooling='mean_max'):
        super().__init__()

        # Fail fast: an unknown mode would otherwise skip pooling in forward()
        # and the classifier would be applied to un-pooled node features.
        if pooling not in self._POOLING_MODES:
            raise ValueError(
                f"Unknown pooling {pooling!r}; expected one of {self._POOLING_MODES}"
            )

        self.pooling = pooling
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # 'mean_max' concatenates two hidden_dim-wide readouts.
        pool_dim = hidden_dim * 2 if pooling == 'mean_max' else hidden_dim
        self.classifier = nn.Sequential(
            nn.Linear(pool_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x, edge_index, batch):
        """
        Compute class logits from node features.

        Args:
            x: Node features, passed to the first GCN layer.
            edge_index: Graph connectivity, passed to every GCN layer.
            batch: Passed to the global pooling functions to group nodes
                (presumably the per-node graph index produced by the
                DataLoader -- confirm against the caller).

        Returns:
            The output of the MLP classifier applied to the pooled features.

        Raises:
            ValueError: If ``self.pooling`` was changed to an unsupported mode
                after construction.
        """
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        if self.pooling == 'mean':
            x = global_mean_pool(x, batch)
        elif self.pooling == 'max':
            x = global_max_pool(x, batch)
        elif self.pooling == 'mean_max':
            x = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], dim=1)
        else:
            # Never fall through with un-pooled node features.
            raise ValueError(
                f"Unknown pooling {self.pooling!r}; expected one of {self._POOLING_MODES}"
            )

        return self.classifier(x)

def train_and_eval(pooling, epochs=100):
    """
    Train a fresh GraphClassifier and report its best test accuracy.

    A new model (hidden size 64) is trained with Adam (lr=0.01) on the
    notebook's ``train_loader`` and scored on ``test_loader`` after every epoch.

    Args:
        pooling: Readout mode forwarded to GraphClassifier.
        epochs: Number of passes over the training loader.

    Returns:
        The highest per-epoch test accuracy seen during training (0 if no
        epoch improved on the initial value).
    """
    net = GraphClassifier(
        proteins.num_node_features, 64, proteins.num_classes, pooling
    ).to(device)
    opt = torch.optim.Adam(net.parameters(), lr=0.01)

    best_acc = 0
    for _ in range(epochs):
        # One optimisation pass over the training graphs.
        net.train()
        for data in train_loader:
            data = data.to(device)
            opt.zero_grad()
            logits = net(data.x, data.edge_index, data.batch)
            F.cross_entropy(logits, data.y).backward()
            opt.step()

        # Score the current weights on the held-out graphs.
        net.eval()
        hits = 0
        with torch.no_grad():
            for data in test_loader:
                data = data.to(device)
                predicted = net(data.x, data.edge_index, data.batch).argmax(dim=1)
                hits += (predicted == data.y).sum().item()

        # The maximum is taken over test accuracies, so the reported number is
        # the best epoch as judged on the test split itself.
        best_acc = max(best_acc, hits / len(test_dataset))

    return best_acc

# Train one model per readout and print its best test accuracy.
print("\nComparing Pooling Strategies on PROTEINS")
print("=" * 50)

pooling_modes = ['mean', 'max', 'mean_max']
for pooling in pooling_modes:
    acc = train_and_eval(pooling)
    print(f"{pooling:10s}: {acc:.4f}")

## Exercise 2 Solution: TopK Pooling

In [None]:
from torch_geometric.nn import TopKPooling, global_mean_pool

class HierarchicalGraphClassifier(nn.Module):
    """
    Graph classifier that shrinks the graph with TopK pooling between GCN layers.

    Architecture:
        GCN → TopK Pool (ratio=0.5) → GCN → TopK Pool (ratio=0.5) → GCN
            → Global Mean Pool → MLP

    Args:
        num_features: Input size of the first GCN layer.
        hidden_dim: Width of the GCN layers, pooling layers and MLP hidden layer.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()

        # Stage 1: convolution followed by pooling.
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.pool1 = TopKPooling(hidden_dim, ratio=0.5)

        # Stage 2: convolution followed by pooling.
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.pool2 = TopKPooling(hidden_dim, ratio=0.5)

        # Stage 3: convolution only; the global readout follows it.
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # MLP head with dropout between its two linear layers.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x, edge_index, batch):
        """
        Compute class logits for the graphs described by the inputs.

        Args:
            x: Node features, passed to the first GCN layer.
            edge_index: Graph connectivity; replaced by each pooling layer's output.
            batch: Node grouping passed to the pooling layers and the global
                readout; replaced by each pooling layer's output.

        Returns:
            The output of the MLP classifier applied to the mean-pooled features.
        """
        pooled_stages = ((self.conv1, self.pool1), (self.conv2, self.pool2))
        for conv, pool in pooled_stages:
            x = F.relu(conv(x, edge_index))
            # Only the features, connectivity and grouping are carried forward;
            # the other three values returned by the pooling layer are unused.
            x, edge_index, _, batch, _, _ = pool(x, edge_index, batch=batch)

        x = F.relu(self.conv3(x, edge_index))

        graph_repr = global_mean_pool(x, batch)
        return self.classifier(graph_repr)

# Train the hierarchical (TopK pooling) model with the same settings as the
# pooling baselines: hidden size 64, Adam with lr=0.01, 100 epochs.
model = HierarchicalGraphClassifier(
    proteins.num_node_features, 64, proteins.num_classes
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

print("Training Hierarchical Graph Classifier with TopK Pooling")
print("=" * 60)

best_acc = 0
for epoch in range(100):
    # One optimisation pass over the training graphs.
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = F.cross_entropy(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Score the current weights on the test graphs; no gradients are needed.
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            pred = model(batch.x, batch.edge_index, batch.batch).argmax(dim=1)
            correct += (pred == batch.y).sum().item()
    acc = correct / len(test_dataset)

    # NOTE: the best epoch is picked using the test split itself, so this
    # number is an optimistic estimate.
    best_acc = max(best_acc, acc)

    # Report the mean batch loss and test accuracy every 20th epoch.
    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d} | Loss: {total_loss/len(train_loader):.4f} | Test: {acc:.4f}")

print(f"\n🎉 Best accuracy with TopK pooling: {best_acc:.4f}")

In [None]:
# Compare all methods
print("\n" + "=" * 50)
print("FINAL COMPARISON ON PROTEINS")
print("=" * 50)

# Train each global-pooling baseline 3 times and report mean ± std of the
# best test accuracy; results maps pooling name -> (mean, std).
results = {}

for pooling in ['mean', 'max', 'mean_max']:
    accs = [train_and_eval(pooling) for _ in range(3)]
    mean_acc, std_acc = np.mean(accs), np.std(accs)
    results[pooling] = (mean_acc, std_acc)
    print(f"{pooling:12s}: {mean_acc:.4f} ± {std_acc:.4f}")

# best_acc is whatever the TopK training loop left behind: one run, not an average.
print(f"{'topk':12s}: {best_acc:.4f} (single run)")

print("\n💡 Insights for PROTEINS:")
print("   - Proteins are larger graphs (avg ~39 nodes) vs MUTAG (~18 nodes)")
print("   - Hierarchical pooling can help capture multi-scale structure")
print("   - Mean+Max pooling is often a strong baseline!")

## Challenge Solution: Simplified DiffPool

In [None]:
class SimpleDiffPool(nn.Module):
    """
    Simplified, DiffPool-inspired readout layer.

    Two GCN layers are applied to the same input: ``embed_gnn`` produces node
    embeddings ``Z`` and ``assign_gnn`` produces soft cluster assignments
    ``S = softmax(GNN(X, A))``.  Full DiffPool would then coarsen the graph:

        X_new = S^T @ Z         # Coarsened features
        A_new = S^T @ A @ S     # Coarsened adjacency

    This simplified layer builds no coarsened graph.  It returns one vector
    per graph: the sum of that graph's node embeddings, each scaled by the
    total assignment mass of its node.

    NOTE: every row of ``S`` is a softmax output, so that mass equals 1 (up to
    floating-point rounding) and the pooled vector is effectively a plain sum
    of ``Z`` over the graph's nodes.  The assignments themselves are only
    exposed through the second return value.

    Args:
        in_channels: Input size of both GCN layers and output size of the
            embedding layer.
        num_clusters: Output size of the assignment layer.
    """

    def __init__(self, in_channels, num_clusters):
        super().__init__()
        self.num_clusters = num_clusters

        # GNN for computing cluster assignments
        self.assign_gnn = GCNConv(in_channels, num_clusters)

        # GNN for computing node embeddings
        self.embed_gnn = GCNConv(in_channels, in_channels)

    def forward(self, x, edge_index, batch):
        """
        Pool node embeddings into one vector per graph.

        Args:
            x: Node features, passed to both GCN layers.
            edge_index: Graph connectivity, passed to both GCN layers.
            batch: Integer index tensor with one entry per node giving the
                graph that node belongs to.  The number of graphs is taken to
                be ``batch.max() + 1``; an empty tensor means zero graphs.

        Returns:
            Tuple ``(pooled, s)``: ``pooled`` has one row per graph and as
            many columns as the node embeddings; ``s`` holds the softmax
            cluster assignments, one row per node.
        """
        # Node embeddings and soft cluster assignments from the same input.
        z = F.relu(self.embed_gnn(x, edge_index))
        s = F.softmax(self.assign_gnn(x, edge_index), dim=-1)

        # Read the graph count once; an empty batch has no maximum to read.
        num_graphs = int(batch.max()) + 1 if batch.numel() > 0 else 0

        # Per-node weight = total assignment mass of that node (see class NOTE).
        node_weight = s.sum(dim=1, keepdim=True)

        # Scatter-add every weighted node embedding into its graph's row;
        # this replaces a Python loop with one boolean mask per graph.
        pooled = z.new_zeros(num_graphs, z.size(1)).index_add(0, batch, node_weight * z)

        return pooled, s

# Closing notes: what the simplified layer leaves out compared with the paper.
print("SimpleDiffPool implemented!")
print("\n💡 Real DiffPool (Ying et al., 2018) is more complex:")
print("   - Properly handles the coarsened adjacency matrix")
print("   - Uses auxiliary loss for cluster quality")
print("   - Performs hierarchical coarsening in multiple stages")
print("\n   For production use, see torch_geometric.nn.dense.diff_pool")

---

## Key Takeaways

1. **PROTEINS is larger and harder** than MUTAG - more nodes, more complex structure
2. **TopK pooling** creates hierarchical representations - good for multi-scale patterns
3. **Mean+Max pooling** is a strong, simple baseline for most tasks
4. **DiffPool** learns soft cluster assignments for adaptive pooling
5. **Choose pooling based on task** - no single method is always best!