# Lesson 7: Graph Pooling & Hierarchical GNNs

## Practical Implementation and Experiments

In this notebook, we'll implement and experiment with various graph pooling techniques:
1. Global pooling operations
2. Hierarchical pooling
3. DiffPool
4. SAGPool
5. Graph classification tasks
6. Molecular property prediction
7. Visualization of pooled graphs

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, List, Optional
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split

try:
    import torch_geometric
except ImportError:
    print("PyTorch Geometric not installed. Installing...")
    import subprocess
    import sys
    # Install into the interpreter that runs this notebook
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'torch-geometric'])
    import torch_geometric

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader as GeoDataLoader  # PyG >= 2.0
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.utils import to_networkx, subgraph
import torch_geometric.transforms as T
import networkx as nx
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split

print("All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")
print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

---

## Part 1: Global Pooling Operations

Let's start by implementing and understanding basic pooling operations.

In [None]:
class GlobalPooling(nn.Module):
    """Implements different global pooling operations."""

    def __init__(self, method: str = 'mean'):
        """
        Args:
            method: 'sum', 'mean', 'max', or 'concat' (concatenates all three)
        """
        super().__init__()
        assert method in ['sum', 'mean', 'max', 'concat']
        self.method = method

    def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Node features [num_nodes, num_features]
            batch: Batch assignment tensor [num_nodes] - which graph each node belongs to

        Returns:
            Graph-level embeddings [num_graphs, num_features or 3*num_features]
        """
        if self.method == 'sum':
            return global_add_pool(x, batch)
        elif self.method == 'mean':
            return global_mean_pool(x, batch)
        elif self.method == 'max':
            return global_max_pool(x, batch)
        else:  # concat
            sum_pool = global_add_pool(x, batch)
            mean_pool = global_mean_pool(x, batch)
            max_pool = global_max_pool(x, batch)
            return torch.cat([sum_pool, mean_pool, max_pool], dim=1)

    def __repr__(self):
        return f"GlobalPooling(method='{self.method}')"


# Test the pooling operations
print("Global Pooling Operations\n" + "="*50)

# Create sample node embeddings from 3 graphs
# Graph 1: 4 nodes, Graph 2: 3 nodes, Graph 3: 5 nodes
x = torch.tensor([
    # Graph 1
    [1.0, 2.0],
    [3.0, 4.0],
    [5.0, 6.0],
    [7.0, 8.0],
    # Graph 2
    [2.0, 3.0],
    [4.0, 5.0],
    [6.0, 7.0],
    # Graph 3
    [1.0, 1.0],
    [2.0, 2.0],
    [3.0, 3.0],
    [4.0, 4.0],
    [5.0, 5.0],
], dtype=torch.float32)

# Batch assignment: [0,0,0,0, 1,1,1, 2,2,2,2,2]
batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2])

# Test different pooling methods
for method in ['sum', 'mean', 'max', 'concat']:
    pool = GlobalPooling(method=method)
    result = pool(x, batch)
    print(f"\n{method.upper()} Pooling:")
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {result.shape}")
    print(f"  Output:\n{result}")
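
To check the tensors printed above by hand, it can help to have a dependency-free reference. The sketch below is only an illustration of what the three readouts compute, not how PyTorch Geometric implements them; `reference_global_pool`, `x_ref`, and `batch_ref` are names introduced here for the example.

```python
from collections import defaultdict


def reference_global_pool(x, batch):
    """Group node feature rows by graph id and reduce each feature column."""
    groups = defaultdict(list)
    for row, graph_id in zip(x, batch):
        groups[graph_id].append(row)

    pooled = {}
    for graph_id, rows in sorted(groups.items()):
        columns = list(zip(*rows))  # one tuple per feature dimension
        pooled[graph_id] = {
            "sum": [sum(col) for col in columns],
            "mean": [sum(col) / len(col) for col in columns],
            "max": [max(col) for col in columns],
        }
    return pooled


# Same node features and batch vector as the test above
x_ref = [
    [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],              # Graph 1
    [2.0, 3.0], [4.0, 5.0], [6.0, 7.0],                          # Graph 2
    [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0],  # Graph 3
]
batch_ref = [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2]

pooled_ref = reference_global_pool(x_ref, batch_ref)
for graph_id, stats in pooled_ref.items():
    print(graph_id, stats)
```

Note that Graph 1 and Graph 2 end up with the same mean even though they differ in size and content. Sum and max still tell them apart, which is one motivation for the `'concat'` option.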

### Visualizing Pooling Effects

Let's visualize how different pooling methods aggregate information differently.

In [None]:
def visualize_pooling_effects():
    """Visualize how different pooling methods affect information aggregation."""

    # Create sample feature vectors with different characteristics
    features = {
        'Uniform': torch.ones(10, 3),
        'Sparse': torch.tensor(
            [[1, 0, 0]] * 5 + [[0, 0, 0]] * 5,
            dtype=torch.float32
        ),
        'Varied': torch.tensor(
            [[i * 0.5, i * 0.3, i * 0.7] for i in range(10)],
            dtype=torch.float32
        ),
        'Outliers': torch.tensor(
            [[1, 1, 1]] * 9 + [[100, 100, 100]],
            dtype=torch.float32
        ),
    }

    batch = torch.zeros(10, dtype=torch.long)

    fig, axes = plt.subplots(len(features), 4, figsize=(14, 3 * len(features)))

    for row, (name, feat) in enumerate(features.items()):
        # Original features
        axes[row, 0].imshow(feat.numpy(), cmap='viridis', aspect='auto')
        axes[row, 0].set_title(f'{name}: Original Features')
        axes[row, 0].set_xlabel('Feature Dimension')
        axes[row, 0].set_ylabel('Node')

        # Sum pooling
        sum_pool = global_add_pool(feat, batch)
        axes[row, 1].bar(range(3), sum_pool[0].detach().numpy())
        axes[row, 1].set_title(f'{name}: Sum Pooling')
        axes[row, 1].set_ylabel('Aggregated Value')

        # Mean pooling
        mean_pool = global_mean_pool(feat, batch)
        axes[row, 2].bar(range(3), mean_pool[0].detach().numpy())
        axes[row, 2].set_title(f'{name}: Mean Pooling')
        axes[row, 2].set_ylabel('Aggregated Value')

        # Max pooling
        max_pool = global_max_pool(feat, batch)
        axes[row, 3].bar(range(3), max_pool[0].detach().numpy())
        axes[row, 3].set_title(f'{name}: Max Pooling')
        axes[row, 3].set_ylabel('Aggregated Value')

    plt.tight_layout()
    plt.show()

    # Summary table
    print("\nPooling Methods Comparison")
    print("=" * 99)
    print(
        f"{'Method':<15} {'Uniform':<20} {'Sparse':<20} {'Varied':<20} {'Outliers':<20}"
    )
    print("-" * 99)

    for method_name, method_func in [
        ('Sum', global_add_pool),
        ('Mean', global_mean_pool),
        ('Max', global_max_pool),
    ]:
        results = []
        for name, feat in features.items():
            batch = torch.zeros(len(feat), dtype=torch.long)
            pooled = method_func(feat, batch)
            # Use first feature dimension for comparison
            results.append(f"{pooled[0, 0].item():.2f}")

        print(f"{method_name:<15} {results[0]:<20} {results[1]:<20} {results[2]:<20} {results[3]:<20}")


visualize_pooling_effects()
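
The same comparison can be worked through numerically for the first feature dimension of each pattern. This is a small plain-Python sketch under the assumption that only column 0 is of interest; `pool_first_dim` and `patterns` are illustrative names, not part of the notebook's API.

```python
def pool_first_dim(values):
    """Sum, mean, and max of one feature column across all nodes of a graph."""
    return {
        "sum": sum(values),
        "mean": sum(values) / len(values),
        "max": max(values),
    }


# First feature dimension of the four patterns used in the plots
patterns = {
    "Uniform": [1.0] * 10,
    "Sparse": [1.0] * 5 + [0.0] * 5,
    "Varied": [i * 0.5 for i in range(10)],
    "Outliers": [1.0] * 9 + [100.0],
}

summary = {name: pool_first_dim(vals) for name, vals in patterns.items()}
for name, stats in summary.items():
    print(f"{name:<10} sum={stats['sum']:.2f} mean={stats['mean']:.2f} max={stats['max']:.2f}")
```

Two things stand out. Max pooling gives the same value for the Uniform and Sparse patterns, so it cannot see how many nodes carry the feature. A single outlier node also dominates the sum and max and pulls the mean well away from the value shared by the other nine nodes.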

---

## Part 2: Implementing Hierarchical Pooling

Now let's implement a learnable hierarchical pooling layer (Top-K pooling).

In [None]:
class TopKPooling(nn.Module):
    """
    Top-K Pooling layer.
    Selects the top-k nodes based on learned importance scores.
    """

    def __init__(self, in_channels: int, ratio: float = 0.8, nonlinearity=torch.sigmoid):
        """
        Args:
            in_channels: Size of node features
            ratio: Pooling ratio (fraction of nodes to keep)
            nonlinearity: Nonlinearity for scoring (sigmoid or tanh)
        """
        super().__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.nonlinearity = nonlinearity

        # Learnable scoring weights
        self.weight = nn.Parameter(torch.ones(in_channels))
        nn.init.xavier_uniform_(self.weight.view(-1, 1))

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Node features [num_nodes, in_channels]
            edge_index: Edge indices [2, num_edges]
            batch: Batch assignment [num_nodes]

        Returns:
            x_pool: Pooled node features
            edge_index_pool: Pooled edge indices
            batch_pool: Pooled batch assignment
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

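        # Compute importance scores by projecting features onto the learned vector
        scores = self.nonlinearity(torch.sum(x * self.weight, dim=1))

        # Select the top-k nodes separately within each graph, so that a
        # batched graph can never lose all of its nodes
        num_nodes = x.size(0)
        keep_parts = []
        for graph_id in torch.unique(batch):
            node_idx = (batch == graph_id).nonzero(as_tuple=True)[0]
            k = max(1, int(self.ratio * node_idx.numel()))  # Keep at least 1 node
            keep_parts.append(node_idx[torch.topk(scores[node_idx], k=k)[1]])
        keep_idx = torch.cat(keep_parts)
        num_keep = keep_idx.numel()

        keep_mask = torch.zeros(num_nodes, dtype=torch.bool, device=x.device)
        keep_mask[keep_idx] = True

        # Pool node features, gating them by their scores so that gradients
        # reach self.weight (the top-k indices alone are not differentiable)
        x_pool = x[keep_idx] * scores[keep_idx].unsqueeze(-1)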

        # Pool batch assignment
        batch_pool = batch[keep_idx]

        # Pool edge indices (keep edges between kept nodes)
        mask = keep_mask[edge_index[0]] & keep_mask[edge_index[1]]
        edge_index_pool = edge_index[:, mask]

        # Remap node indices
        node_idx_mapping = torch.full((num_nodes,), -1, dtype=torch.long, device=x.device)
        node_idx_mapping[keep_idx] = torch.arange(num_keep, device=x.device)
        edge_index_pool = node_idx_mapping[edge_index_pool]

        return x_pool, edge_index_pool, batch_pool

    def __repr__(self):
        return f"TopKPooling(in_channels={self.in_channels}, ratio={self.ratio})"


# Test Top-K Pooling
print("Top-K Pooling Implementation\n" + "="*50)

# Create a simple graph
x = torch.randn(10, 8)
edge_index = torch.tensor(
    [
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 4, 6, 8],
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 3, 5, 7, 9, 1],
    ],
    dtype=torch.long,
)
batch = torch.zeros(10, dtype=torch.long)

pool_layer = TopKPooling(in_channels=8, ratio=0.7)
x_pool, edge_index_pool, batch_pool = pool_layer(x, edge_index, batch)

print(f"Original graph:")
print(f"  Nodes: {x.shape[0]}, Edges: {edge_index.shape[1]}")
print(f"\nAfter Top-K Pooling (ratio=0.7):")
print(f"  Nodes: {x_pool.shape[0]}, Edges: {edge_index_pool.shape[1]}")
print(f"  Reduction: {x.shape[0]} → {x_pool.shape[0]} nodes")
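
The selection, edge filtering, and re-indexing steps are easier to follow on a tiny batch. Below is a plain-Python sketch, assuming the scores have already been computed; `topk_pool_reference` and the `*_demo` variables are hypothetical names for this illustration. It selects nodes per graph, which is the usual choice for batched inputs, and compares that with a single top-k over the whole batch.

```python
def topk_pool_reference(scores, edges, batch, ratio):
    """Per-graph top-k selection followed by edge filtering and re-indexing."""
    # Group node ids by the graph they belong to
    nodes_by_graph = {}
    for node, graph_id in enumerate(batch):
        nodes_by_graph.setdefault(graph_id, []).append(node)

    # Keep the highest-scoring nodes of each graph (at least one per graph)
    kept = []
    for graph_id in sorted(nodes_by_graph):
        nodes = nodes_by_graph[graph_id]
        k = max(1, int(ratio * len(nodes)))
        kept.extend(sorted(nodes, key=lambda n: scores[n], reverse=True)[:k])

    # Old node id -> new consecutive id; dropped nodes have no entry
    new_id = {old: new for new, old in enumerate(kept)}

    # Keep only edges whose two endpoints both survive, then relabel them
    pooled_edges = [
        (new_id[u], new_id[v]) for u, v in edges if u in new_id and v in new_id
    ]
    pooled_batch = [batch[n] for n in kept]
    return kept, pooled_edges, pooled_batch


# Two graphs in one batch: nodes 0-3 form graph 0, nodes 4-5 form graph 1
scores_demo = [0.9, 0.8, 0.7, 0.6, 0.2, 0.1]
edges_demo = [(0, 1), (1, 2), (2, 3), (4, 5)]
batch_demo = [0, 0, 0, 0, 1, 1]

kept, pooled_edges, pooled_batch = topk_pool_reference(
    scores_demo, edges_demo, batch_demo, ratio=0.5
)
print("kept nodes:", kept)
print("pooled edges:", pooled_edges)
print("pooled batch:", pooled_batch)

# For comparison: one global top-k over the whole batch (k = int(0.5 * 6) = 3)
global_kept = sorted(
    range(len(scores_demo)), key=lambda n: scores_demo[n], reverse=True
)[:3]
print("global top-k:", global_kept)
```

In this example the global selection keeps only nodes of graph 0, so graph 1 would disappear from the batch and the number of pooled graphs would no longer match the number of labels. Selecting per graph avoids that.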

### Visualizing Pooling Selection

Let's visualize which nodes are selected by the Top-K pooling.

In [None]:
def visualize_topk_selection():
    """Visualize which nodes are selected by Top-K pooling."""
    
    # Create a simple synthetic graph
    np.random.seed(42)
    num_nodes = 20
    
    # Create node positions for visualization
    pos = np.random.randn(num_nodes, 2)
    
    # Create edges
    edges = []
    for i in range(num_nodes):
        for j in range(i+1, num_nodes):
            # Connect nearby nodes
            dist = np.linalg.norm(pos[i] - pos[j])
            if dist < 1.5:
                edges.append([i, j])
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Create node features
    x = torch.randn(num_nodes, 16)
    batch = torch.zeros(num_nodes, dtype=torch.long)
    
    # Apply pooling with different ratios
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    ratios = [1.0, 0.8, 0.6, 0.4, 0.2]
    
    for idx, ratio in enumerate(ratios):
        ax = axes.flatten()[idx]
        
        # Apply pooling
        pool = TopKPooling(in_channels=16, ratio=ratio)
        x_pool, edge_index_pool, _ = pool(x, edge_index, batch)
        
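        # Recover which nodes were kept by re-scoring with the layer's weights
        num_keep = len(x_pool)
        with torch.no_grad():
            scores = pool.nonlinearity(torch.sum(x * pool.weight, dim=1))
        kept = set(torch.topk(scores, k=num_keep)[1].tolist())

        # Create network graph
        G = nx.Graph()
        G.add_nodes_from(range(num_nodes))

        for edge in edge_index.t().numpy():
            G.add_edge(edge[0], edge[1])

        # Draw: blue = kept, gray = removed
        node_colors = ['lightblue' if i in kept else 'lightgray'
                       for i in range(num_nodes)]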
        
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, 
                              node_size=300, ax=ax)
        nx.draw_networkx_edges(G, pos, alpha=0.3, ax=ax)
        
        ax.set_title(f'Pooling Ratio: {ratio:.1f}\n({num_keep}/{num_nodes} nodes kept)')
        ax.axis('off')
    
    # Remove extra subplot
    axes.flatten()[-1].remove()
    
    plt.tight_layout()
    plt.show()
    
    print("Blue nodes: Selected by Top-K Pooling")
    print("Gray nodes: Removed by pooling")

visualize_topk_selection()

---

## Part 3: SAGPool (Self-Attention Graph Pooling)

Next, we implement SAGPool, which scores nodes with a GNN layer so that each score depends on both the node's features and its neighborhood.

In [None]:
class SAGPooling(nn.Module):
    """
    Self-Attention Graph Pooling.
    Uses attention-based scoring to select important nodes.
    """

    def __init__(self, in_channels: int, ratio: float = 0.8, GNN=GCNConv):
        """
        Args:
            in_channels: Size of input features
            ratio: Pooling ratio
            GNN: GNN layer to use for feature computation
        """
        super().__init__()
        self.in_channels = in_channels
        self.ratio = ratio

        # GNN that maps each node's neighborhood to a scalar attention score
        self.gnn = GNN(in_channels, 1)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Node features [num_nodes, in_channels]
            edge_index: Edge indices [2, num_edges]
            batch: Batch assignment [num_nodes]

        Returns:
            x_pool: Pooled node features
            edge_index_pool: Pooled edge indices
            batch_pool: Pooled batch assignment
            scores: Attention score of every input node [num_nodes]
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

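        # Compute attention scores using GNN
        scores = self.gnn(x, edge_index).squeeze(-1)
        scores = torch.sigmoid(scores)

        # Select the top-k nodes separately within each graph, so that a
        # batched graph can never lose all of its nodes
        num_nodes = x.size(0)
        keep_parts = []
        for graph_id in torch.unique(batch):
            node_idx = (batch == graph_id).nonzero(as_tuple=True)[0]
            k = max(1, int(self.ratio * node_idx.numel()))
            keep_parts.append(node_idx[torch.topk(scores[node_idx], k=k)[1]])
        keep_idx = torch.cat(keep_parts)
        num_keep = keep_idx.numel()

        keep_mask = torch.zeros(num_nodes, dtype=torch.bool, device=x.device)
        keep_mask[keep_idx] = True

        # Pool node features, gated by their attention scores so that the
        # scoring GNN receives gradients
        x_pool = x[keep_idx] * scores[keep_idx].unsqueeze(-1)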

        # Pool batch assignment
        batch_pool = batch[keep_idx]

        # Pool edge indices
        mask = keep_mask[edge_index[0]] & keep_mask[edge_index[1]]
        edge_index_pool = edge_index[:, mask]

        # Remap node indices
        node_idx_mapping = torch.full((num_nodes,), -1, dtype=torch.long, device=x.device)
        node_idx_mapping[keep_idx] = torch.arange(num_keep, device=x.device)
        edge_index_pool = node_idx_mapping[edge_index_pool]

        return x_pool, edge_index_pool, batch_pool, scores

    def __repr__(self):
        return f"SAGPooling(in_channels={self.in_channels}, ratio={self.ratio})"


# Test SAGPool
print("SAGPool (Self-Attention Graph Pooling)\n" + "="*50)

x = torch.randn(10, 8)
edge_index = torch.tensor(
    [
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 4, 6, 8],
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 3, 5, 7, 9, 1],
    ],
    dtype=torch.long,
)
batch = torch.zeros(10, dtype=torch.long)

pool_layer = SAGPooling(in_channels=8, ratio=0.7)
x_pool, edge_index_pool, batch_pool, scores = pool_layer(x, edge_index, batch)

print(f"Original graph: {x.shape[0]} nodes, {edge_index.shape[1]} edges")
print(f"After SAGPooling: {x_pool.shape[0]} nodes, {edge_index_pool.shape[1]} edges")
print(f"\nNode importance scores:")
for i, score in enumerate(scores):
    print(f"  Node {i}: {score.item():.4f}")

---

## Part 4: Graph Classification Task

Now let's build a complete GNN for graph classification using hierarchical pooling.

In [None]:
class HierarchicalGNN(nn.Module):
    """
    Hierarchical GNN with multiple pooling layers.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        num_classes: int,
        num_layers: int = 3,
        pooling_ratio: float = 0.8,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_classes = num_classes
        self.num_layers = num_layers

        # Input layer
        self.embed = nn.Linear(in_channels, hidden_channels)

        # GNN and pooling layers
        self.convs = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.global_pools = nn.ModuleList()

        for i in range(num_layers):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.pools.append(TopKPooling(hidden_channels, ratio=pooling_ratio))
            self.global_pools.append(GlobalPooling(method='concat'))

        # Output layer
        # 3 * hidden_channels from concatenated pooling (sum, mean, max)
        self.mlp = nn.Sequential(
            nn.Linear(3 * hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_channels, num_classes),
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Node features
            edge_index: Edge indices
            batch: Batch assignment

        Returns:
            Graph-level predictions
        """
        # Embedding
        x = self.embed(x)

        # GNN + Pooling layers
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x, edge_index, batch = self.pools[i](x, edge_index, batch)

        # Global pooling (readout) on the final coarsened graph; only the last module in self.global_pools is used
        x = self.global_pools[-1](x, batch)

        # Readout
        x = self.mlp(x)

        return x

    def __repr__(self):
        return (
            f"HierarchicalGNN(in_channels={self.in_channels}, "
            f"hidden_channels={self.hidden_channels}, "
            f"num_classes={self.num_classes}, "
            f"num_layers={self.num_layers})"
        )


# Test the model
print("Hierarchical GNN Model\n" + "="*50)

model = HierarchicalGNN(
    in_channels=8,
    hidden_channels=32,
    num_classes=2,
    num_layers=2,
)
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")

# Test forward pass
x = torch.randn(10, 8)
edge_index = torch.tensor(
    [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 4, 6, 8],
     [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 3, 5, 7, 9, 1]],
    dtype=torch.long,
)
batch = torch.zeros(10, dtype=torch.long)

output = model(x, edge_index, batch)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Output logits: {output}")
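
Because each pooling layer keeps a fixed fraction of the remaining nodes, the graph shrinks geometrically with depth. The sketch below works out the node count per level for a single graph, assuming the same `max(1, int(ratio * n))` rule used by the pooling layers above; `nodes_per_level` is a helper name introduced here.

```python
def nodes_per_level(num_nodes, ratio, num_layers):
    """Node count of one graph before pooling and after each pooling layer."""
    counts = [num_nodes]
    for _ in range(num_layers):
        counts.append(max(1, int(ratio * counts[-1])))
    return counts


schedule_small = nodes_per_level(10, ratio=0.8, num_layers=2)
schedule_deep = nodes_per_level(20, ratio=0.5, num_layers=3)

print("10 nodes, ratio 0.8, 2 layers:", schedule_small)
print("20 nodes, ratio 0.5, 3 layers:", schedule_deep)
```

For the 10-node test graph and two layers, the final readout therefore aggregates over 6 nodes. With small graphs and aggressive ratios the count quickly reaches the floor of one node, so depth and ratio are worth choosing together.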

---

## Part 5: Synthetic Graph Classification Dataset

Create a synthetic graph classification dataset and train models on it.

In [None]:
def create_synthetic_graph_dataset(num_graphs: int = 100, num_classes: int = 2):
    """
    Create a synthetic graph dataset.
    Class 0: Random graphs
    Class 1: Graphs with strong community structure
    """
    graphs = []

    for graph_id in range(num_graphs):
        if graph_id % 2 == 0:
            # Class 0: Random graph
            num_nodes = np.random.randint(10, 30)
            p = 0.15
            G = nx.erdos_renyi_graph(num_nodes, p)
            label = 0
        else:
            # Class 1: Graph with community structure
            num_nodes = np.random.randint(10, 30)
            num_communities = np.random.randint(2, 4)
            
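            # Create community structure (register every node up front so that
            # isolated nodes keep the node ids contiguous in 0..num_nodes-1)
            G = nx.Graph()
            G.add_nodes_from(range(num_nodes))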
            nodes_per_community = num_nodes // num_communities
            node_idx = 0

            for c in range(num_communities):
                # Dense connections within community
                for i in range(nodes_per_community):
                    for j in range(i + 1, nodes_per_community):
                        if np.random.random() < 0.7:
                            G.add_edge(node_idx + i, node_idx + j)

                # Few connections between communities
                if c < num_communities - 1:
                    for _ in range(2):
                        u = np.random.randint(node_idx, node_idx + nodes_per_community)
                        v = np.random.randint(
                            node_idx + nodes_per_community,
                            min(node_idx + 2 * nodes_per_community, num_nodes),
                        )
                        if v < num_nodes:
                            G.add_edge(u, v)

                node_idx += nodes_per_community

            label = 1

        if G.number_of_nodes() > 0 and G.number_of_edges() > 0:
            # Convert to PyG Data
            edge_index = torch.tensor(
                list(G.edges()), dtype=torch.long
            ).t().contiguous()
            
            # The check above guarantees at least one edge, so this guard is only defensive
            if edge_index.size(1) > 0:
                # Add reverse edges (undirected)
                edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
                edge_index = torch.unique(edge_index, dim=1)

            # Node features: random features + degree
            num_nodes = G.number_of_nodes()
            degree = torch.tensor([G.degree(i) for i in range(num_nodes)], dtype=torch.float32)
            x = torch.randn(num_nodes, 8)
            x[:, 0] = degree  # Add degree as a feature

            data = Data(
                x=x,
                edge_index=edge_index,
                y=torch.tensor([label], dtype=torch.long),
            )
            graphs.append(data)

    return graphs


# Create dataset
print("Creating Synthetic Graph Classification Dataset\n" + "="*50)
dataset = create_synthetic_graph_dataset(num_graphs=100, num_classes=2)
print(f"Total graphs: {len(dataset)}")

# Dataset statistics
num_nodes_list = [g.num_nodes for g in dataset]
num_edges_list = [g.num_edges for g in dataset]

print(f"\nDataset Statistics:")
print(f"  Number of graphs: {len(dataset)}")
print(f"  Number of nodes: min={min(num_nodes_list)}, max={max(num_nodes_list)}, avg={np.mean(num_nodes_list):.1f}")
print(f"  Number of edges: min={min(num_edges_list)}, max={max(num_edges_list)}, avg={np.mean(num_edges_list):.1f}")

# Class distribution
labels = [g.y.item() for g in dataset]
print(f"\nClass distribution:")
print(f"  Class 0 (Random): {labels.count(0)}")
print(f"  Class 1 (Community): {labels.count(1)}")

# Visualize some graphs
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for idx, graph_idx in enumerate([0, 1, 10, 11, 20, 21]):
    ax = axes.flatten()[idx]
    
    data = dataset[graph_idx]
    G = nx.Graph()
    G.add_nodes_from(range(data.num_nodes))
    for edge in data.edge_index.t().numpy():
        G.add_edge(edge[0], edge[1])
    
    pos = nx.spring_layout(G, seed=42)
    label_text = "Random" if data.y.item() == 0 else "Community"
    
    nx.draw_networkx_nodes(G, pos, node_size=100, node_color='lightblue', ax=ax)
    nx.draw_networkx_edges(G, pos, alpha=0.3, ax=ax)
    ax.set_title(f"{label_text} Graph (n={data.num_nodes}, m={data.num_edges//2})")
    ax.axis('off')

plt.tight_layout()
plt.show()
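
The dataset code stores every undirected edge in both directions and removes duplicates, which is why the plot titles divide `num_edges` by 2. Here is a plain-Python sketch of that symmetrization step; `to_symmetric_edges` is an illustrative helper, not a PyTorch Geometric function.

```python
def to_symmetric_edges(edges):
    """Return both directions of every undirected edge, de-duplicated and sorted."""
    directed = set()
    for u, v in edges:
        directed.add((u, v))
        directed.add((v, u))
    return sorted(directed)


# A triangle (0, 1, 2) with one extra node attached to node 2
triangle_with_tail = [(0, 1), (1, 2), (2, 0), (2, 3)]
symmetric = to_symmetric_edges(triangle_with_tail)

# Row layout used by edge_index: first row sources, second row targets
sources = [u for u, _ in symmetric]
targets = [v for _, v in symmetric]

print("directed edges:", symmetric)
print("sources:", sources)
print("targets:", targets)
```

Four undirected edges become eight directed entries, so the number of undirected edges is half the number of columns in `edge_index`.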

### Training a Graph Classification Model

In [None]:
def train_graph_classifier(model, train_loader, optimizer, device):
    """Train the model for one epoch."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Forward pass
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = F.cross_entropy(out, batch.y.view(-1))

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = out.argmax(dim=1)
        correct += (pred == batch.y.view(-1)).sum().item()
        total += batch.y.size(0)

    return total_loss / len(train_loader), correct / total


def evaluate_graph_classifier(model, loader, device):
    """Evaluate the model."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)
            pred = out.argmax(dim=1)
            correct += (pred == batch.y.view(-1)).sum().item()
            total += batch.y.size(0)

    return correct / total


# Train model
print("Training Graph Classification Model\n" + "="*50)

device = torch.device('cpu')  # Use CPU for compatibility

# Split dataset
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_set, val_set, test_set = random_split(
    dataset, [train_size, val_size, test_size]
)

# Data loaders
batch_size = 8
train_loader = GeoDataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = GeoDataLoader(val_set, batch_size=batch_size)
test_loader = GeoDataLoader(test_set, batch_size=batch_size)

print(f"Train size: {len(train_set)}, Val size: {len(val_set)}, Test size: {len(test_set)}")

# Create and train model
model = HierarchicalGNN(
    in_channels=8,
    hidden_channels=32,
    num_classes=2,
    num_layers=2,
)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# Training loop
num_epochs = 50
train_losses = []
train_accs = []
val_accs = []

print(f"\nTraining for {num_epochs} epochs...")
for epoch in range(num_epochs):
    train_loss, train_acc = train_graph_classifier(model, train_loader, optimizer, device)
    val_acc = evaluate_graph_classifier(model, val_loader, device)

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    if (epoch + 1) % 10 == 0:
        print(
            f"Epoch {epoch+1:3d} | Loss: {train_loss:.4f} | "
            f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}"
        )

# Test performance
test_acc = evaluate_graph_classifier(model, test_loader, device)
print(f"\nTest Accuracy: {test_acc:.4f}")

# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(train_losses, label='Training Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(train_accs, label='Training Accuracy')
axes[1].plot(val_accs, label='Validation Accuracy')
axes[1].axhline(y=test_acc, color='r', linestyle='--', label=f'Test Accuracy ({test_acc:.4f})')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
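
One detail of the training function is worth a quick check: it averages the per-batch mean losses, so a smaller final batch counts as much as a full one. With 70 training graphs and a batch size of 8, the last batch would hold only 6 graphs. The sketch below compares the two averaging schemes on made-up loss values; both helper names are hypothetical.

```python
def mean_of_batch_means(batch_losses):
    """Average of per-batch mean losses (what total_loss / len(loader) computes)."""
    return sum(loss for loss, _ in batch_losses) / len(batch_losses)


def per_sample_mean(batch_losses):
    """Average loss per graph, weighting each batch by its size."""
    total = sum(loss * size for loss, size in batch_losses)
    return total / sum(size for _, size in batch_losses)


# (mean loss of the batch, number of graphs in the batch); values are made up
batch_losses = [(0.50, 8), (0.50, 8), (1.00, 4)]

unweighted = mean_of_batch_means(batch_losses)
weighted = per_sample_mean(batch_losses)
print(f"mean of batch means: {unweighted:.4f}")
print(f"per-sample mean:     {weighted:.4f}")
```

The difference is usually small and does not affect the gradients, only the reported curve. If exact per-graph averages matter, multiplying each batch loss by the batch size before summing is a simple adjustment.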

---

## Part 6: Molecular Property Prediction

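Now let's work with molecule-like graphs. To keep the notebook self-contained, the dataset below is synthetic. With real data you would parse SMILES strings into graphs using RDKit.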

In [None]:
# Create a simple molecular dataset
# In practice, you would use RDKit to convert SMILES to graphs

def create_molecule_dataset(num_molecules: int = 50):
    """
    Create a synthetic molecular dataset.
    In practice, this would use RDKit and real molecular data.
    """
    molecules = []

    for mol_id in range(num_molecules):
        # Create random molecular-like graphs
        num_atoms = np.random.randint(5, 20)
        
        # Create a random tree-like structure (like a molecule)
        G = nx.Graph()
        G.add_nodes_from(range(num_atoms))
        
        # Add edges to form connected structure
        for i in range(1, num_atoms):
            parent = np.random.randint(0, i)
            G.add_edge(parent, i)
        
        # Randomly add extra edges
        for _ in range(np.random.randint(0, 3)):
            u = np.random.randint(0, num_atoms)
            v = np.random.randint(0, num_atoms)
            if u != v:
                G.add_edge(u, v)
        
        # Create property label (simulate bioactivity)
        # Higher connectivity → higher activity (simplified)
        avg_degree = sum(dict(G.degree()).values()) / num_atoms
        activity = 1 if avg_degree > 2.0 else 0
        
        # Convert to PyG Data
        edge_index = torch.tensor(
            list(G.edges()), dtype=torch.long
        ).t().contiguous()
        
        if edge_index.size(1) > 0:
            edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
            edge_index = torch.unique(edge_index, dim=1)
        
        # Atom features: random placeholders, with the node degree stored in column 0
        degree = torch.tensor([G.degree(i) for i in range(num_atoms)], dtype=torch.float32)
        x = torch.randn(num_atoms, 8)
        x[:, 0] = degree
        
        data = Data(
            x=x,
            edge_index=edge_index,
            y=torch.tensor([activity], dtype=torch.long),
        )
        molecules.append(data)
    
    return molecules


# Create molecular dataset
print("Creating Molecular Dataset\n" + "="*50)
mol_dataset = create_molecule_dataset(num_molecules=100)
print(f"Total molecules: {len(mol_dataset)}")

# Dataset statistics
num_atoms_list = [g.num_nodes for g in mol_dataset]
labels = [g.y.item() for g in mol_dataset]

print(f"\nMolecular Dataset Statistics:")
print(f"  Number of molecules: {len(mol_dataset)}")
print(f"  Atoms per molecule: min={min(num_atoms_list)}, max={max(num_atoms_list)}, avg={np.mean(num_atoms_list):.1f}")
print(f"  Active molecules: {labels.count(1)}")
print(f"  Inactive molecules: {labels.count(0)}")

# Visualize some molecules
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for idx in range(6):
    ax = axes.flatten()[idx]
    
    data = mol_dataset[idx]
    G = nx.Graph()
    G.add_nodes_from(range(data.num_nodes))
    for edge in data.edge_index.t().numpy():
        G.add_edge(edge[0], edge[1])
    
    pos = nx.spring_layout(G, seed=42, k=0.5)
    activity = "Active" if data.y.item() == 1 else "Inactive"
    
    nx.draw_networkx_nodes(G, pos, node_size=150, node_color='lightcoral', ax=ax)
    nx.draw_networkx_edges(G, pos, alpha=0.3, ax=ax)
    ax.set_title(f"{activity} Molecule (atoms={data.num_nodes})")
    ax.axis('off')

plt.tight_layout()
plt.show()
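
For real molecules, the conversion from SMILES to a graph could look roughly like the sketch below. It assumes RDKit is installed and degrades to `None` when it is not. `smiles_to_graph` is a hypothetical helper, and using only the atomic number as the node feature is a simplification chosen for this example.

```python
try:
    from rdkit import Chem
except ImportError:  # RDKit is optional; the sketch returns None without it
    Chem = None


def smiles_to_graph(smiles):
    """Return (atomic_numbers, directed_edges) for a SMILES string, or None."""
    if Chem is None:
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # RDKit returns None for strings it cannot parse
        return None

    # One node per heavy atom; hydrogens stay implicit by default
    atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]

    # One pair of directed edges per bond, matching the symmetric edge_index used above
    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges += [(i, j), (j, i)]
    return atomic_numbers, edges


ethanol = smiles_to_graph("CCO")
print(ethanol)
```

The two returned lists map directly onto the `x` and `edge_index` fields of a `Data` object. Richer atom and bond features, such as formal charge or bond type, would be added in the same loops.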

### Training on Molecular Dataset

In [None]:
# Train on molecular dataset
print("Training on Molecular Dataset\n" + "="*50)

# Split dataset
mol_train_size = int(0.7 * len(mol_dataset))
mol_val_size = int(0.15 * len(mol_dataset))
mol_test_size = len(mol_dataset) - mol_train_size - mol_val_size

mol_train_set, mol_val_set, mol_test_set = random_split(
    mol_dataset, [mol_train_size, mol_val_size, mol_test_size]
)

# Data loaders
mol_train_loader = GeoDataLoader(mol_train_set, batch_size=8, shuffle=True)
mol_val_loader = GeoDataLoader(mol_val_set, batch_size=8)
mol_test_loader = GeoDataLoader(mol_test_set, batch_size=8)

print(f"Train: {len(mol_train_set)}, Val: {len(mol_val_set)}, Test: {len(mol_test_set)}")

# Create model
mol_model = HierarchicalGNN(
    in_channels=8,
    hidden_channels=32,
    num_classes=2,
    num_layers=2,
)
mol_model = mol_model.to(device)
mol_optimizer = Adam(mol_model.parameters(), lr=0.001, weight_decay=1e-5)

# Train
num_epochs = 50
mol_train_losses = []
mol_train_accs = []
mol_val_accs = []

print(f"\nTraining for {num_epochs} epochs...")
for epoch in range(num_epochs):
    train_loss, train_acc = train_graph_classifier(mol_model, mol_train_loader, mol_optimizer, device)
    val_acc = evaluate_graph_classifier(mol_model, mol_val_loader, device)

    mol_train_losses.append(train_loss)
    mol_train_accs.append(train_acc)
    mol_val_accs.append(val_acc)

    if (epoch + 1) % 10 == 0:
        print(
            f"Epoch {epoch+1:3d} | Loss: {train_loss:.4f} | "
            f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}"
        )

# Test
mol_test_acc = evaluate_graph_classifier(mol_model, mol_test_loader, device)
print(f"\nTest Accuracy: {mol_test_acc:.4f}")

# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(mol_train_losses, label='Training Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Molecular Dataset: Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(mol_train_accs, label='Training Accuracy')
axes[1].plot(mol_val_accs, label='Validation Accuracy')
axes[1].axhline(y=mol_test_acc, color='r', linestyle='--', label=f'Test Accuracy ({mol_test_acc:.4f})')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Molecular Dataset: Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## Part 7: Visualizing Pooled Graph Hierarchies

Let's visualize what happens at different levels of hierarchical pooling.
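
Before running the model, it helps to know roughly how many nodes should survive each pooling step. The sketch below assumes each layer keeps `ceil(ratio * n)` nodes, which is the convention in PyTorch Geometric's `TopKPooling`; the `TopKPooling` class defined earlier in this notebook may round differently. `expected_node_counts` is a hypothetical helper, not part of the notebook.

```python
import math
from typing import List, Sequence, Union


def expected_node_counts(
    num_nodes: int,
    ratios: Union[float, Sequence[float]],
    num_layers: int = 1,
) -> List[int]:
    """Node count at the input and after each pooling layer.

    Assumes each layer keeps ceil(ratio * n) nodes (an assumption; check the
    rounding used by your own pooling layer).
    """
    # A single ratio is repeated for every layer; a sequence gives one ratio per layer.
    if isinstance(ratios, (int, float)):
        ratios = [float(ratios)] * num_layers
    counts = [num_nodes]
    for ratio in ratios:
        counts.append(math.ceil(ratio * counts[-1]))
    return counts


# A 16-node graph pooled three times with ratio 0.5
print(expected_node_counts(16, 0.5, num_layers=3))  # [16, 8, 4, 2]
```

The node counts printed in the summary at the end of this part can be checked against this schedule. Passing a list such as `[0.5, 0.25]` gives a different ratio to each layer.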

In [None]:
class VisualizableHierarchicalGNN(nn.Module):
    """
    Hierarchical GNN that returns intermediate representations for visualization.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        num_classes: int,
        num_layers: int = 3,
        pooling_ratio: float = 0.8,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_classes = num_classes
        self.num_layers = num_layers

        self.embed = nn.Linear(in_channels, hidden_channels)

        self.convs = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.global_pools = nn.ModuleList()

        for i in range(num_layers):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.pools.append(TopKPooling(hidden_channels, ratio=pooling_ratio))
            self.global_pools.append(GlobalPooling(method='concat'))

        self.mlp = nn.Sequential(
            nn.Linear(3 * hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_channels, num_classes),
        )

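    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
        """
        Returns the class logits and a list of detached (x, edge_index, batch)
        snapshots: one for the embedded input, then one after each pooling layer.
        """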
        x = self.embed(x)

        intermediates = []
        intermediates.append((x.clone().detach(), edge_index.clone().detach(), batch.clone().detach()))

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x, edge_index, batch = self.pools[i](x, edge_index, batch)
            intermediates.append((x.clone().detach(), edge_index.clone().detach(), batch.clone().detach()))

        # Readout on the final pooled graph (only the last global pooling module is used)
        x = self.global_pools[-1](x, batch)
        x = self.mlp(x)

        return x, intermediates


# Create and test
print("Visualizing Hierarchical Pooling\n" + "="*50)

viz_model = VisualizableHierarchicalGNN(
    in_channels=8,
    hidden_channels=32,
    num_classes=2,
    num_layers=3,
    pooling_ratio=0.7,
)
viz_model.eval()

# Pick a single test graph. PyG identifies graphs with a batch vector,
# so no extra batch dimension is needed on x.
test_graph = mol_dataset[0]
x = test_graph.x
batch = torch.zeros(test_graph.num_nodes, dtype=torch.long)  # all nodes belong to graph 0
edge_index = test_graph.edge_index

# Forward pass to get intermediate representations
with torch.no_grad():
    output, intermediates = viz_model(x, edge_index, batch)

# Visualize hierarchy: one panel for the embedded input plus one per pooling layer
fig, axes = plt.subplots(1, len(intermediates), figsize=(4 * len(intermediates), 4))

for layer, (x_layer, edge_layer, batch_layer) in enumerate(intermediates):
    # Create networkx graph
    G = nx.Graph()
    num_nodes = x_layer.shape[0]
    G.add_nodes_from(range(num_nodes))

    # Convert to plain Python ints so edge endpoints match the node IDs added above
    for src, dst in edge_layer.t().tolist():
        if src < num_nodes and dst < num_nodes:
            G.add_edge(src, dst)

    # Draw
    if G.number_of_nodes() > 0:
        pos = nx.spring_layout(G, seed=42, k=0.5)
        
        # Node colors based on features
        node_colors = x_layer[:, 0].numpy()  # Use first feature
        
        nodes = nx.draw_networkx_nodes(
            G, pos, node_color=node_colors, node_size=200,
            cmap='viridis', ax=axes[layer]
        )
        nx.draw_networkx_edges(G, pos, alpha=0.3, ax=axes[layer])
        
        axes[layer].set_title(
            f"Layer {layer}\n({num_nodes} nodes, {G.number_of_edges()} edges)"
        )
        axes[layer].axis('off')

plt.tight_layout()
plt.show()

print("\nHierarchical Pooling Summary:")
print("="*50)
for layer, (x_layer, edge_layer, batch_layer) in enumerate(intermediates):
    # Assumes edge_index stores each undirected edge in both directions, hence the // 2
    print(f"Layer {layer}: {x_layer.shape[0]} nodes, {edge_layer.shape[1]//2} edges")

---

## Part 8: Comparing Pooling Methods

Let's compare the performance and properties of different pooling methods.

In [None]:
class SimpleGNNWithGlobalPool(nn.Module):
    """GNN with global pooling only (baseline)."""

    def __init__(self, in_channels: int, hidden_channels: int, num_classes: int):
        super().__init__()
        self.embed = nn.Linear(in_channels, hidden_channels)
        self.conv = GCNConv(hidden_channels, hidden_channels)
        self.pool = GlobalPooling(method='concat')
        self.mlp = nn.Sequential(
            nn.Linear(3 * hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_channels, num_classes),
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        x = self.embed(x)
        x = self.conv(x, edge_index)
        x = F.relu(x)
        x = self.pool(x, batch)
        return self.mlp(x)


class GNNWithTopKPool(nn.Module):
    """GNN with Top-K pooling."""

    def __init__(self, in_channels: int, hidden_channels: int, num_classes: int):
        super().__init__()
        self.embed = nn.Linear(in_channels, hidden_channels)
        self.conv1 = GCNConv(hidden_channels, hidden_channels)
        self.pool = TopKPooling(hidden_channels, ratio=0.7)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.global_pool = GlobalPooling(method='concat')
        self.mlp = nn.Sequential(
            nn.Linear(3 * hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_channels, num_classes),
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        x = self.embed(x)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x, edge_index, batch = self.pool(x, edge_index, batch)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.global_pool(x, batch)
        return self.mlp(x)


print("Comparing Pooling Methods\n" + "="*50)

models_to_compare = {
    'Global Pool': SimpleGNNWithGlobalPool(8, 32, 2),
    'Top-K Pool': GNNWithTopKPool(8, 32, 2),
    'Hierarchical': HierarchicalGNN(8, 32, 2, num_layers=2),
}

# Train each model
results = {}
for model_name, model in models_to_compare.items():
    print(f"\nTraining {model_name}...")
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

    train_losses = []
    val_accs = []

    for epoch in range(30):
        train_loss, _ = train_graph_classifier(model, mol_train_loader, optimizer, device)
        val_acc = evaluate_graph_classifier(model, mol_val_loader, device)
        train_losses.append(train_loss)
        val_accs.append(val_acc)

    test_acc = evaluate_graph_classifier(model, mol_test_loader, device)

    # Count parameters
    num_params = sum(p.numel() for p in model.parameters())

    results[model_name] = {
        'val_accs': val_accs,
        'test_acc': test_acc,
        'num_params': num_params,
    }

    print(f"  Test Accuracy: {test_acc:.4f}, Parameters: {num_params:,}")

# Visualize comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Validation accuracy comparison
for model_name, result in results.items():
    axes[0].plot(result['val_accs'], label=model_name, marker='o', markersize=3)

axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Validation Accuracy')
axes[0].set_title('Validation Accuracy Comparison')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Test accuracy and parameter count
model_names = list(results.keys())
test_accs = [results[m]['test_acc'] for m in model_names]
num_params = [results[m]['num_params'] for m in model_names]

ax2 = axes[1]
ax3 = ax2.twinx()

# One color per metric so the two bar groups can be told apart
acc_color, param_color = 'C0', 'C1'
x_pos = np.arange(len(model_names))

bars1 = ax2.bar(x_pos - 0.2, test_accs, 0.4, label='Test Accuracy', color=acc_color)
ax2.set_ylabel('Test Accuracy', color=acc_color)
ax2.set_ylim([0.4, 1.0])

bars2 = ax3.bar(x_pos + 0.2, [p/1000 for p in num_params], 0.4, label='Parameters (K)', color=param_color)
ax3.set_ylabel('Parameters (Thousands)', color=param_color)

ax2.set_xticks(x_pos)
ax2.set_xticklabels(model_names)
ax2.set_title('Test Accuracy vs Model Complexity')
ax2.legend(handles=[bars1, bars2], loc='upper left')

fig.tight_layout()
plt.show()

# Print comparison table
print("\n" + "="*70)
print("Model Comparison Summary")
print("="*70)
print(f"{'Model':<20} {'Test Accuracy':<20} {'Parameters':<20}")
print("-"*70)
for model_name in model_names:
    test_acc = results[model_name]['test_acc']
    num_params = results[model_name]['num_params']
    print(f"{model_name:<20} {test_acc:<20.4f} {num_params:<20,}")

---

## Exercises

### Exercise 1: Global Pooling Comparison
Compare the global pooling methods (sum, mean, max) on datasets with different graph-size distributions.
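
As a possible starting point, the dependency-free sketch below shows why graph size matters for the choice of readout. `global_pool` and `compare_readouts` are hypothetical helpers that operate on scalar node features in plain Python; the exercise itself would use the notebook's `GlobalPooling` module or PyG's `global_add_pool`, `global_mean_pool`, and `global_max_pool`.

```python
from typing import Dict, List


def global_pool(node_values: List[float], method: str) -> float:
    """Reduce one graph's scalar node features to a single value."""
    if method == "sum":
        return sum(node_values)
    if method == "mean":
        return sum(node_values) / len(node_values)
    if method == "max":
        return max(node_values)
    raise ValueError(f"Unknown pooling method: {method}")


def compare_readouts(graph_sizes: List[int]) -> Dict[int, Dict[str, float]]:
    """Pool graphs of different sizes whose nodes all carry the feature 1.0."""
    return {
        n: {m: global_pool([1.0] * n, m) for m in ("sum", "mean", "max")}
        for n in graph_sizes
    }


for size, readouts in compare_readouts([5, 50]).items():
    print(size, readouts)
```

With identical node features, the sum readout grows with the number of nodes while mean and max stay constant. Sum pooling therefore encodes graph size in the embedding, whereas mean and max discard it.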

### Exercise 2: Custom Pooling Ratio
Modify the Top-K pooling layer to have different pooling ratios at different layers. Train a model and compare results.

### Exercise 3: SAGPool Implementation
Complete the SAGPool implementation to use neighborhood information for scoring, not just individual node features.
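
One way to think about the missing piece: SAGPool scores nodes with a graph convolution, so each score depends on a node's neighbors as well as its own features. The plain-Python sketch below makes two simplifying assumptions for illustration: node features are scalars, and the attention weights reduce to a single scalar `theta`. It applies GCN-style symmetric normalization with self-loops, keeps the `ceil(ratio * n)` highest-scoring nodes, and gates their features by the score. `sag_scores` and `sag_select` are hypothetical names.

```python
import math
from typing import List, Tuple


def sag_scores(x: List[float], edges: List[Tuple[int, int]], theta: float) -> List[float]:
    """Compute tanh(D^-1/2 (A + I) D^-1/2 x * theta) for scalar node features."""
    n = len(x)
    neighbors = [{i} for i in range(n)]  # start with self-loops
    for u, v in edges:  # edges are treated as undirected
        neighbors[u].add(v)
        neighbors[v].add(u)
    deg = [len(nb) for nb in neighbors]  # degree including the self-loop
    return [
        math.tanh(theta * sum(x[j] / math.sqrt(deg[i] * deg[j]) for j in neighbors[i]))
        for i in range(n)
    ]


def sag_select(
    x: List[float], edges: List[Tuple[int, int]], theta: float = 1.0, ratio: float = 0.5
) -> Tuple[List[int], List[float]]:
    """Keep the top ceil(ratio * n) nodes by score and gate their features."""
    scores = sag_scores(x, edges, theta)
    k = math.ceil(ratio * len(x))
    ranked = sorted(range(len(x)), key=lambda i: scores[i], reverse=True)
    keep = sorted(ranked[:k])
    gated = [x[i] * scores[i] for i in keep]
    return keep, gated


# Path graph 0-1-2-3 where only node 0 has a non-zero feature
x = [1.0, 0.0, 0.0, 0.0]
edges = [(0, 1), (1, 2), (2, 3)]
print(sag_scores(x, edges, theta=1.0))
print(sag_select(x, edges))
```

Node 1 has a zero feature but still gets a positive score because it neighbors node 0. A scorer that looks only at each node's own features cannot capture this.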

### Exercise 4: Visualization Analysis
Create a function that visualizes the hierarchical clustering created by pooling layers.

### Exercise 5: Large Graph Scalability
Test the pooling methods on increasingly large graphs and measure how runtime scales with graph size.
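
A minimal timing harness for this exercise might look like the sketch below. `time_scaling` is a hypothetical helper; `make_input` and `run` are callables you would supply (for example, one that builds a random graph with `n` nodes and one that applies a pooling layer to it). The workload in the usage example is a plain Python `sum`, used only as a stand-in for a real pooling layer.

```python
import time
from typing import Any, Callable, Dict, Sequence


def time_scaling(
    make_input: Callable[[int], Any],
    run: Callable[[Any], Any],
    sizes: Sequence[int],
    repeats: int = 3,
) -> Dict[int, float]:
    """Return the best-of-`repeats` wall-clock time in seconds for each size."""
    timings: Dict[int, float] = {}
    for n in sizes:
        data = make_input(n)  # build the input outside the timed region
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            run(data)
            best = min(best, time.perf_counter() - start)
        timings[n] = best
    return timings


# Stand-in workload: summing a list of n numbers
timings = time_scaling(lambda n: list(range(n)), sum, sizes=[1_000, 10_000, 100_000])
for n, seconds in timings.items():
    print(f"{n:>7d} nodes: {seconds * 1e3:.3f} ms")
```

Plotting the timings against size on log-log axes makes the scaling trend easier to read. When timing on a GPU, call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels launch asynchronously.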

### Exercise 6: Hyperparameter Search
Implement a grid search over pooling ratios and number of hierarchical layers.
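
The search loop itself can be kept independent of the model. The sketch below is one possible structure: `grid_search` is a hypothetical helper, and `dummy_evaluate` is a placeholder scoring function invented for illustration. In the notebook, the evaluation callback would build a model with the given ratio and depth, train it, and return its validation accuracy.

```python
from itertools import product
from typing import Callable, Dict, Sequence, Tuple

Config = Tuple[float, int]


def grid_search(
    ratios: Sequence[float],
    layer_counts: Sequence[int],
    evaluate: Callable[[float, int], float],
) -> Tuple[Config, Dict[Config, float]]:
    """Score every (ratio, num_layers) pair and return the best one plus all scores."""
    results: Dict[Config, float] = {}
    for ratio, num_layers in product(ratios, layer_counts):
        results[(ratio, num_layers)] = evaluate(ratio, num_layers)
    best = max(results, key=results.get)
    return best, results


def dummy_evaluate(ratio: float, num_layers: int) -> float:
    # Placeholder score (not a real experiment): peaks at ratio 0.7 with 2 layers
    return 1.0 - abs(ratio - 0.7) - 0.1 * abs(num_layers - 2)


best, results = grid_search([0.5, 0.7, 0.9], [1, 2, 3], dummy_evaluate)
print(best)  # (0.7, 2)
```

Select the configuration on validation accuracy and evaluate on the test set only once, for the chosen configuration, so the test score stays an unbiased estimate.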

### Exercise 7: Real Molecular Data
Use RDKit to convert SMILES strings to molecular graphs and train a model for property prediction.

### Exercise 8: Attention Visualization
Implement and visualize attention weights in SAGPool to show which nodes are important.