# Lesson 5: Advanced GNN Architectures
## GraphSAGE and Graph Isomorphism Networks (GIN)

In this notebook, we'll implement and compare two advanced GNN architectures:
1. **GraphSAGE**: For inductive learning with neighborhood sampling
2. **GIN**: For maximal expressiveness among message-passing GNNs (as powerful as the 1-WL test) through injective sum aggregation

We'll explore sampling strategies, inductive learning, and expressiveness analysis.

## Setup and Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Dict, Set, Optional
import networkx as nx
from collections import defaultdict
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # PyG 2.x location of the graph DataLoader
from torch_geometric.nn import GCNConv, SAGEConv
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

# Set seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

## Part 1: Understanding GraphSAGE

### 1.1 Neighborhood Sampling

Let's start by implementing neighborhood sampling mechanisms.

In [None]:
class NeighborhoodSampler:
    """Base class for neighborhood sampling strategies"""
    
    def __init__(self, graph_dict: Dict[int, List[int]]):
        """
        Args:
            graph_dict: Dictionary mapping node -> list of neighbors
        """
        self.graph_dict = graph_dict
        self.node_degrees = {v: len(neighbors) for v, neighbors in graph_dict.items()}
    
    def sample_neighbors(self, node: int, sample_size: int) -> List[int]:
        """Sample neighbors of a node. To be implemented by subclasses."""
        raise NotImplementedError

class UniformSampler(NeighborhoodSampler):
    """Uniform random sampling of neighbors"""
    
    def sample_neighbors(self, node: int, sample_size: int) -> List[int]:
        neighbors = self.graph_dict.get(node, [])
        if len(neighbors) == 0:
            return []
        sample_size = min(sample_size, len(neighbors))
        return list(np.random.choice(neighbors, size=sample_size, replace=False))

class ImportanceSampler(NeighborhoodSampler):
    """Sample neighbors based on importance scores (e.g., node degree)"""
    
    def __init__(self, graph_dict: Dict[int, List[int]], importance_fn=None):
        super().__init__(graph_dict)
        if importance_fn is None:
            # Default: degree-based importance
            importance_fn = lambda node: self.node_degrees.get(node, 1)
        self.importance_fn = importance_fn
    
    def sample_neighbors(self, node: int, sample_size: int) -> List[int]:
        neighbors = self.graph_dict.get(node, [])
        if len(neighbors) == 0:
            return []
        
        # Compute importance scores
        scores = np.array([self.importance_fn(n) for n in neighbors])
        scores = scores / scores.sum()  # Normalize to probabilities
        
        sample_size = min(sample_size, len(neighbors))
        return list(np.random.choice(neighbors, size=sample_size, replace=False, p=scores))

class AdaptiveSampler(NeighborhoodSampler):
    """Sampling with externally supplied importance scores (e.g., learned or precomputed)"""
    
    def __init__(self, graph_dict: Dict[int, List[int]], importance_scores: Dict[int, float]):
        super().__init__(graph_dict)
        self.importance_scores = importance_scores
    
    def sample_neighbors(self, node: int, sample_size: int) -> List[int]:
        neighbors = self.graph_dict.get(node, [])
        if len(neighbors) == 0:
            return []
        
        # Look up the supplied importance scores (default 1.0 for unscored nodes)
        scores = np.array([self.importance_scores.get(n, 1.0) for n in neighbors])
        scores = np.maximum(scores, 0.01)  # Avoid zero probabilities
        scores = scores / scores.sum()
        
        sample_size = min(sample_size, len(neighbors))
        return list(np.random.choice(neighbors, size=sample_size, replace=False, p=scores))

print("Neighborhood Sampler classes defined successfully!")
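The default `ImportanceSampler` draws each neighbor with probability proportional to its degree, p_u = deg(u) / Σ_w deg(w). The standalone sketch below (NumPy only; the toy neighbor degrees are made up for illustration) computes these probabilities and checks them empirically with single draws.

```python
import numpy as np

def degree_probabilities(neighbor_degrees):
    """Sampling probability of each neighbor, proportional to its degree."""
    degrees = np.asarray(neighbor_degrees, dtype=float)
    return degrees / degrees.sum()

# Hypothetical neighborhood: three low-degree neighbors and one hub
neighbor_degrees = [1, 2, 3, 10]
probs = degree_probabilities(neighbor_degrees)
print(probs)  # the hub gets 10/16 = 0.625

# Empirical check: draw a single neighbor many times and count how often each is picked
rng = np.random.default_rng(0)
draws = rng.choice(len(neighbor_degrees), size=20000, p=probs)
empirical = np.bincount(draws, minlength=len(neighbor_degrees)) / len(draws)
print(np.round(empirical, 3))
```

With `replace=False` and a sample size above 1, NumPy draws sequentially and renormalizes after each pick, so the inclusion probability of a neighbor in a multi-element sample is not exactly p_u.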

### 1.2 GraphSAGE Aggregators

Implement different aggregation functions for GraphSAGE.

In [None]:
class GraphSAGEAggregator(nn.Module):
    """Base class for GraphSAGE aggregators"""
    
    def __init__(self, input_dim: int, output_dim: int, bias: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
    
    def forward(self, node_features: torch.Tensor, neighbor_features: List[torch.Tensor]) -> torch.Tensor:
        """Aggregate neighbor features and combine with node features.
        
        Args:
            node_features: Features of central node(s), shape (batch_size, input_dim)
            neighbor_features: One tensor per central node; entry i has shape
                (num_neighbors_i, input_dim) and may have zero rows
        
        Returns:
            aggregated: Combined features, shape (batch_size, output_dim)
        """
        raise NotImplementedError
    
    @staticmethod
    def reduce_per_node(node_features: torch.Tensor, neighbor_features: List[torch.Tensor],
                        reduce_fn) -> torch.Tensor:
        """Apply reduce_fn to each node's neighbor tensor; nodes without neighbors get zeros."""
        rows = [reduce_fn(nf) if nf.shape[0] > 0 else torch.zeros_like(node_features[i])
                for i, nf in enumerate(neighbor_features)]
        return torch.stack(rows)  # (batch_size, input_dim)

class MeanAggregator(GraphSAGEAggregator):
    """Mean aggregator: average neighbor embeddings"""
    
    def __init__(self, input_dim: int, output_dim: int, bias: bool = True):
        super().__init__(input_dim, output_dim, bias)
        self.linear = nn.Linear(input_dim * 2, output_dim, bias=bias)
    
    def forward(self, node_features: torch.Tensor, neighbor_features: List[torch.Tensor]) -> torch.Tensor:
        # Per-node mean over a variable number of neighbors (zero vector if none)
        neighbor_mean = self.reduce_per_node(node_features, neighbor_features,
                                             lambda nf: nf.mean(dim=0))
        combined = torch.cat([node_features, neighbor_mean], dim=1)
        return F.relu(self.linear(combined))

class PoolingAggregator(GraphSAGEAggregator):
    """Pooling aggregator: max pooling over neighbor features"""
    
    def __init__(self, input_dim: int, output_dim: int, bias: bool = True):
        super().__init__(input_dim, output_dim, bias)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, input_dim)
        )
        self.linear = nn.Linear(input_dim * 2, output_dim, bias=bias)
    
    def forward(self, node_features: torch.Tensor, neighbor_features: List[torch.Tensor]) -> torch.Tensor:
        # Transform each neighbor with the MLP, then take the element-wise max
        neighbor_pool = self.reduce_per_node(node_features, neighbor_features,
                                             lambda nf: self.mlp(nf).max(dim=0).values)
        combined = torch.cat([node_features, neighbor_pool], dim=1)
        return F.relu(self.linear(combined))

class LSTMAggregator(GraphSAGEAggregator):
    """LSTM aggregator: sequential aggregation of neighbors"""
    
    def __init__(self, input_dim: int, output_dim: int, bias: bool = True):
        super().__init__(input_dim, output_dim, bias)
        self.lstm = nn.LSTM(input_dim, input_dim, batch_first=True)
        self.linear = nn.Linear(input_dim * 2, output_dim, bias=bias)
    
    def _lstm_reduce(self, nf: torch.Tensor) -> torch.Tensor:
        # LSTMs are order-sensitive; as in the GraphSAGE paper, feed the neighbors
        # in a random permutation rather than a fixed order
        perm = torch.randperm(nf.shape[0], device=nf.device)
        _, (hidden, _) = self.lstm(nf[perm].unsqueeze(0))  # input: (1, num_neighbors, dim)
        return hidden[-1, 0]                               # final hidden state: (dim,)
    
    def forward(self, node_features: torch.Tensor, neighbor_features: List[torch.Tensor]) -> torch.Tensor:
        neighbor_lstm = self.reduce_per_node(node_features, neighbor_features, self._lstm_reduce)
        combined = torch.cat([node_features, neighbor_lstm], dim=1)
        return F.relu(self.linear(combined))

print("Aggregator classes defined successfully!")
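A mean-aggregator step computes h_v' = ReLU(W [h_v ‖ mean_{u∈N(v)} h_u] + b), which is what `MeanAggregator` does with `nn.Linear(2 * input_dim, output_dim)`. A NumPy sketch with hand-picked toy numbers (not learned weights) makes the arithmetic concrete, including the zero-vector fallback for an isolated node.

```python
import numpy as np

def mean_aggregate(h_v, neighbor_feats, W, b):
    """One GraphSAGE mean-aggregator step for a single node.

    h_v: (d,) own features; neighbor_feats: (k, d); W: (out, 2d); b: (out,)
    """
    neighbor_feats = np.asarray(neighbor_feats, dtype=float)
    # Mean over neighbors, or zeros when the node has no neighbors
    neighbor_mean = neighbor_feats.mean(axis=0) if len(neighbor_feats) else np.zeros_like(h_v)
    combined = np.concatenate([h_v, neighbor_mean])  # (2d,)
    return np.maximum(W @ combined + b, 0.0)         # ReLU

h_v = np.array([1.0, 0.0])
neighbors = [[0.0, 2.0], [2.0, 4.0]]    # mean = [1.0, 3.0]
W = np.array([[1.0, 1.0, 1.0, 1.0],     # sums the concatenated vector [1, 0, 1, 3]
              [1.0, 0.0, -1.0, -1.0]])  # 1 - 1 - 3 = -3, clipped to 0 by ReLU
b = np.zeros(2)
out = mean_aggregate(h_v, neighbors, W, b)
print(out)  # [5. 0.]
```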

### 1.3 GraphSAGE Model Implementation

In [None]:
class GraphSAGE(nn.Module):
    """GraphSAGE: Inductive Representation Learning on Large Graphs
    
    Key features:
    - Inductive learning: learns to embed unseen nodes
    - Mini-batch training: scalable to large graphs
    - Multiple aggregators: flexible architecture
    """
    
    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int,
                 aggregator_type: str = 'mean', dropout: float = 0.0):
        """
        Args:
            input_dim: Input feature dimension
            hidden_dims: List of hidden layer dimensions
            output_dim: Output dimension
            aggregator_type: 'mean', 'pooling', or 'lstm'
            dropout: Dropout probability
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        self.dropout = dropout
        
        # Create aggregators for each layer
        self.aggregators = nn.ModuleList()
        dims = [input_dim] + hidden_dims + [output_dim]
        
        aggregator_class = {
            'mean': MeanAggregator,
            'pooling': PoolingAggregator,
            'lstm': LSTMAggregator
        }[aggregator_type]
        
        for i in range(len(dims) - 1):
            self.aggregators.append(aggregator_class(dims[i], dims[i + 1]))
    
    def forward(self, node_features: torch.Tensor, neighbors_per_layer: List[List[List[int]]]) -> torch.Tensor:
        """
        Forward pass with multi-hop aggregation.
        
        This simplified version updates every node at every layer, so layer k+1
        aggregates the layer-k representations of each node's sampled neighbors.
        
        Args:
            node_features: Node feature embeddings, shape (num_nodes, input_dim)
            neighbors_per_layer: neighbors_per_layer[layer][node_idx] = list of (sampled)
                neighbor indices; one list per layer, with one entry per node
        
        Returns:
            embeddings: Final embeddings, shape (num_nodes, output_dim)
        """
        if len(neighbors_per_layer) != len(self.aggregators):
            raise ValueError(f"Expected {len(self.aggregators)} neighbor lists (one per layer), "
                             f"got {len(neighbors_per_layer)}")
        
        h = node_features
        for layer_idx, aggregator in enumerate(self.aggregators):
            # Gather the *current* representations of each node's neighbors
            neighbor_features = [h[neighbors] if neighbors else h.new_zeros(0, h.shape[1])
                                 for neighbors in neighbors_per_layer[layer_idx]]
            
            # Aggregate, then apply dropout
            h = aggregator(h, neighbor_features)
            h = F.dropout(h, p=self.dropout, training=self.training)
        
        return h

print("GraphSAGE model defined successfully!")
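The mini-batch training listed in the `GraphSAGE` docstring rests on a layered expansion: starting from a batch of target nodes, hop k keeps at most S_k sampled neighbors per frontier node, so the receptive field stays bounded by the fanouts rather than by the graph size. A pure-Python sketch of that expansion; the function name `sample_computation_graph`, the fanouts, and the toy graph are illustrative assumptions.

```python
import random
from typing import Dict, List, Set

def sample_computation_graph(adj: Dict[int, List[int]], batch: List[int],
                             fanouts: List[int], seed: int = 0) -> List[Set[int]]:
    """Return the node sets needed at each hop for a mini-batch (hop 0 = the batch).

    At hop k, every frontier node keeps at most fanouts[k] uniformly sampled neighbors.
    """
    rng = random.Random(seed)
    layers = [set(batch)]
    for fanout in fanouts:
        frontier = set()
        for node in layers[-1]:
            neighbors = adj.get(node, [])
            frontier.update(rng.sample(neighbors, min(fanout, len(neighbors))))
        layers.append(frontier)
    return layers

# Toy graph: hubs 0 and 21 are each connected to the leaves 1..20
adj = {0: list(range(1, 21)), 21: list(range(1, 21))}
for leaf in range(1, 21):
    adj[leaf] = [0, 21]

layers = sample_computation_graph(adj, batch=[0], fanouts=[5, 2])
print([len(s) for s in layers])  # [1, 5, 2]
```

A full implementation then computes representations from the deepest hop back toward the batch, so each layer only touches sampled nodes; PyG's `NeighborLoader` in `torch_geometric.loader` provides this style of sampling.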

## Part 2: Graph Isomorphism Networks (GIN)

### 2.1 Understanding Weisfeiler-Lehman Test

Implement the Weisfeiler-Lehman (WL) graph isomorphism test to understand GIN's theoretical foundation.

In [None]:
# Shared color table: identical signatures receive identical colors across *all*
# calls, so color histograms from different graphs can be compared directly.
# With a fresh table per call, the colors 0, 1, ... of two graphs would be unrelated.
WL_COLOR_TABLE: Dict[Tuple, int] = {}

def weisfeiler_lehman_test(graph_dict: Dict[int, List[int]], num_iterations: int = 3,
                           color_table: Optional[Dict[Tuple, int]] = None) -> Dict:
    """Perform Weisfeiler-Lehman (1-WL) color refinement on a graph.
    
    Args:
        graph_dict: Adjacency list representation
        num_iterations: Number of WL iterations
        color_table: Signature -> color mapping (defaults to the shared WL_COLOR_TABLE)
    
    Returns:
        Dictionary with iteration-by-iteration color assignments
    """
    if color_table is None:
        color_table = WL_COLOR_TABLE
    nodes = sorted(graph_dict)
    color_sequence = {}
    
    # Iteration 0: Color by degree
    colors = {node: len(graph_dict[node]) for node in nodes}
    color_sequence[0] = colors.copy()
    
    # Perform WL iterations
    for iteration in range(num_iterations):
        new_colors = {}
        
        for node in nodes:
            # Signature = (own color, sorted multiset of neighbor colors)
            neighbor_colors = sorted(colors[n] for n in graph_dict[node])
            signature = (colors[node], tuple(neighbor_colors))
            
            # Map signature to a color, allocating a new one for unseen signatures
            if signature not in color_table:
                color_table[signature] = len(color_table)
            new_colors[node] = color_table[signature]
        
        colors = new_colors
        color_sequence[iteration + 1] = colors.copy()
    
    return color_sequence

print("Weisfeiler-Lehman test function defined successfully!")
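A tiny self-contained trace of one refinement round, using the same signature rule as `weisfeiler_lehman_test` but without a color table, on the 5-node path 0-1-2-3-4. Degree coloring yields only two classes, but the signatures split the degree-2 nodes into the center and its two neighbors.

```python
from typing import Dict, List, Tuple

def wl_round(adj: Dict[int, List[int]], colors: Dict[int, int]) -> Dict[int, Tuple]:
    """One WL refinement round, returning raw signatures (own color, sorted neighbor colors)."""
    return {v: (colors[v], tuple(sorted(colors[u] for u in adj[v]))) for v in adj}

path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
degree_colors = {v: len(nbrs) for v, nbrs in path.items()}
signatures = wl_round(path, degree_colors)
for v, sig in signatures.items():
    print(v, sig)
```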

### 2.2 GIN Layer Implementation

In [None]:
class GINLayer(nn.Module):
    """Graph Isomorphism Network (GIN) Layer
    
    Mathematical formulation:
    h_v^{k+1} = MLP^k((1 + eps_k) * h_v^k + sum_{u in N(v)} h_u^k)
    
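    Key property: sum aggregation is injective on multisets of (countable) features, and the
    MLP can represent an injective update, so distinct neighborhoods can map to distinct embeddings.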
    """
    
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: Optional[int] = None,
                 eps: float = 0.0, learn_eps: bool = False):
        """
        Args:
            input_dim: Input feature dimension
            output_dim: Output feature dimension
            hidden_dim: Hidden dimension for MLP (default: output_dim)
            eps: Initial epsilon value
            learn_eps: Whether to learn epsilon
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.learn_eps = learn_eps
        
        if hidden_dim is None:
            hidden_dim = output_dim
        
        # Register epsilon
        if learn_eps:
            self.eps = nn.Parameter(torch.tensor(eps, dtype=torch.float32))
        else:
            self.register_buffer('eps', torch.tensor(eps, dtype=torch.float32))
        
        # MLP: (1+eps)*x + sum(neighbors) -> output
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, node_features: torch.Tensor, neighbor_sum: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            node_features: Features of nodes, shape (num_nodes, input_dim)
            neighbor_sum: Sum of neighbor features, shape (num_nodes, input_dim)
        
        Returns:
            updated: Updated node features, shape (num_nodes, output_dim)
        """
        # Apply GIN formula: MLP((1+eps)*x + sum(neighbors))
        aggregated = (1 + self.eps) * node_features + neighbor_sum
        return self.mlp(aggregated)

class GIN(nn.Module):
    """Graph Isomorphism Network
    
    Theoretically grounded architecture: message-passing GNNs are at most as
    expressive as the 1-WL test, and GIN can match that bound.
    """
    
    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int,
                 dropout: float = 0.0, learn_eps: bool = False):
        """
        Args:
            input_dim: Input feature dimension
            hidden_dims: List of hidden layer dimensions
            output_dim: Output dimension
            dropout: Dropout probability
            learn_eps: Whether to learn epsilon in GIN layers
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        self.dropout = dropout
        
        dims = [input_dim] + hidden_dims + [output_dim]
        self.gin_layers = nn.ModuleList()
        
        for i in range(len(dims) - 1):
            self.gin_layers.append(GINLayer(dims[i], dims[i + 1], learn_eps=learn_eps))
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            x: Node features, shape (num_nodes, input_dim)
            edge_index: Edge indices, shape (2, num_edges)
        
        Returns:
            embeddings: Final embeddings, shape (num_nodes, output_dim)
        """
        # PyG convention: messages flow from edge_index[0] (source) to edge_index[1] (target)
        src, dst = edge_index
        
        for i, gin_layer in enumerate(self.gin_layers):
            # Sum of neighbor features: scatter-add each source's features into its target
            neighbor_sum = torch.zeros_like(x).index_add(0, dst, x[src])
            
            # Apply GIN layer
            x = gin_layer(x, neighbor_sum)
            
            if i < len(self.gin_layers) - 1:
                x = F.dropout(x, p=self.dropout, training=self.training)
        
        return x
    
    def readout(self, x: torch.Tensor) -> torch.Tensor:
        """Graph-level readout: sum all node embeddings"""
        return torch.sum(x, dim=0, keepdim=True)

print("GIN model defined successfully!")
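Why sum? The GIN paper's argument is that mean and max aggregators collapse some distinct multisets of neighbor features, while sum keeps them apart. A NumPy check on scalar features (the multisets are toy examples chosen for illustration); the last pair shows that a raw sum is not injective on arbitrary reals either, which is why the theory relies on countable features passed through a suitable MLP.

```python
import numpy as np

def aggregate(multiset, how):
    """Aggregate a multiset of scalar features with sum, mean, or max."""
    x = np.asarray(multiset, dtype=float)
    return {"sum": x.sum(), "mean": x.mean(), "max": x.max()}[how]

pairs = {
    "{1} vs {1, 1}": ([1], [1, 1]),                     # same mean and max, different size
    "{1, 2} vs {1, 1, 2, 2}": ([1, 2], [1, 1, 2, 2]),   # same proportions, twice as many
    "{0, 2} vs {1, 1}": ([0, 2], [1, 1]),               # same sum and mean, different max
}
results = {}
for name, (a, b) in pairs.items():
    # True means the aggregator maps both multisets to the same value (fails to separate them)
    results[name] = {how: bool(aggregate(a, how) == aggregate(b, how))
                     for how in ("sum", "mean", "max")}
    print(name, results[name])
```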

## Part 3: Practical Experiments

### 3.1 Load and Prepare Dataset

In [None]:
# Load Cora dataset for node classification
dataset = Planetoid(root='/tmp/cora', name='Cora')
data = dataset[0]

print(f"Dataset: {dataset}")
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")
print(f"Feature dimension: {data.num_features}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Train/Val/Test split: {data.train_mask.sum()}/{data.val_mask.sum()}/{data.test_mask.sum()}")

### 3.2 Neighborhood Sampling Analysis

In [None]:
# Convert to adjacency list
def edge_index_to_adj_list(edge_index: torch.Tensor, num_nodes: int) -> Dict[int, List[int]]:
    """Convert edge_index to adjacency list"""
    adj_list = defaultdict(list)
    for i in range(edge_index.shape[1]):
        u, v = edge_index[0, i].item(), edge_index[1, i].item()
        adj_list[u].append(v)
    # Ensure all nodes are in dict
    for i in range(num_nodes):
        if i not in adj_list:
            adj_list[i] = []
    return dict(adj_list)

adj_list = edge_index_to_adj_list(data.edge_index, data.num_nodes)

# Compare sampling strategies
uniform_sampler = UniformSampler(adj_list)
importance_sampler = ImportanceSampler(adj_list)

# Use the highest-degree node so that sample_size is smaller than its neighborhood
# (a low-degree node would simply return its full neighbor list every time)
node_id = max(adj_list, key=lambda n: len(adj_list[n]))
sample_size = 5

print(f"Node {node_id} neighbors (first 10): {adj_list[node_id][:10]}...")
print(f"Degree: {len(adj_list[node_id])}")
print()

# Sample multiple times to analyze variance
num_samples = 100
uniform_samples = []
importance_samples = []

for _ in range(num_samples):
    uniform_samples.append(set(uniform_sampler.sample_neighbors(node_id, sample_size)))
    importance_samples.append(set(importance_sampler.sample_neighbors(node_id, sample_size)))

def mean_pairwise_jaccard(samples: List[Set[int]], n: int = 10) -> float:
    """Average Jaccard similarity over all pairs among the first n samples"""
    return np.mean([len(samples[i] & samples[j]) / len(samples[i] | samples[j])
                    for i in range(n) for j in range(i + 1, n)])

print(f"Uniform sampler - avg sample size: {np.mean([len(s) for s in uniform_samples]):.2f}")
print(f"Importance sampler - avg sample size: {np.mean([len(s) for s in importance_samples]):.2f}")
print()
print("Mean pairwise Jaccard similarity (lower = more diverse samples):")
print(f"  Uniform:    {mean_pairwise_jaccard(uniform_samples):.3f}")
print(f"  Importance: {mean_pairwise_jaccard(importance_samples):.3f}")
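Mini-batched aggregation is easiest when every node has the same number of sampled neighbors, so the samples stack into a dense index matrix. The samplers above return fewer than `sample_size` neighbors for low-degree nodes. One common variant, an assumption here rather than code from this notebook, samples without replacement when the degree is large enough and with replacement otherwise (`fixed_size_sample` is a hypothetical helper).

```python
import numpy as np

def fixed_size_sample(neighbors, k, rng):
    """Return exactly k neighbor ids (an empty array for isolated nodes)."""
    neighbors = np.asarray(neighbors, dtype=int)
    if len(neighbors) == 0:
        return neighbors
    replace = len(neighbors) < k  # pad low-degree nodes by resampling
    return rng.choice(neighbors, size=k, replace=replace)

rng = np.random.default_rng(42)
low = fixed_size_sample([7, 9], k=5, rng=rng)
high = fixed_size_sample(list(range(100)), k=5, rng=rng)
print(low, high)

# Rows of equal length can be stacked into a (num_nodes, k) index matrix
batch = np.stack([low, high])
print(batch.shape)  # (2, 5)
```

Resampling duplicates some neighbors of low-degree nodes, which slightly reweights them under mean aggregation; that is the price of the fixed shape.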

### 3.3 Weisfeiler-Lehman Test Demonstration

In [None]:
# Create small example graphs
def create_cycle_graph(n: int) -> Dict[int, List[int]]:
    """Create a cycle graph"""
    graph = {i: [] for i in range(n)}
    for i in range(n):
        graph[i] = [(i - 1) % n, (i + 1) % n]
    return graph

def create_complete_bipartite_graph(n: int) -> Dict[int, List[int]]:
    """Create complete bipartite graph K_{n,n}"""
    graph = {i: [] for i in range(2 * n)}
    for i in range(n):
        for j in range(n, 2 * n):
            graph[i].append(j)
            graph[j].append(i)
    return graph

# Create example graphs
cycle_graph = create_cycle_graph(6)
bipartite_graph = create_complete_bipartite_graph(3)

print("Cycle graph (6 nodes):")
print(cycle_graph)
print()
print("Complete bipartite graph K_{3,3}:")
print(bipartite_graph)
print()

# Run WL test
cycle_wl = weisfeiler_lehman_test(cycle_graph, num_iterations=2)
bipartite_wl = weisfeiler_lehman_test(bipartite_graph, num_iterations=2)

print("Cycle graph WL color evolution:")
for it, colors in cycle_wl.items():
    unique_colors = len(set(colors.values()))
    print(f"  Iteration {it}: {unique_colors} unique colors")

print()
print("Bipartite graph WL color evolution:")
for it, colors in bipartite_wl.items():
    unique_colors = len(set(colors.values()))
    print(f"  Iteration {it}: {unique_colors} unique colors")

# Check if graphs are distinguished
cycle_final = tuple(sorted(cycle_wl[2].values()))
bipartite_final = tuple(sorted(bipartite_wl[2].values()))
print()
print(f"Graphs distinguished by WL test: {cycle_final != bipartite_final}")
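Color histograms are only comparable when both graphs are refined with the same signature-to-color table. A simple way to guarantee that is to run 1-WL on the disjoint union of the two graphs and compare the per-graph histograms. A self-contained sketch (the helper names are ours); it starts from a uniform color, which after one round is equivalent to the degree coloring used above.

```python
from collections import Counter
from typing import Dict, List

def wl_histograms(g1: Dict[int, List[int]], g2: Dict[int, List[int]], iterations: int = 3):
    """Refine the disjoint union of g1 and g2; return per-graph color histograms."""
    # Relabel nodes as (graph_id, node) so both graphs share one node space
    union = {(0, v): [(0, u) for u in nbrs] for v, nbrs in g1.items()}
    union.update({(1, v): [(1, u) for u in nbrs] for v, nbrs in g2.items()})
    colors = {v: 0 for v in union}  # uniform initial color
    for _ in range(iterations):
        signatures = {v: (colors[v], tuple(sorted(colors[u] for u in union[v]))) for v in union}
        # One palette for the whole union, so equal signatures get equal colors in both graphs
        palette = {sig: i for i, sig in enumerate(sorted(set(signatures.values())))}
        colors = {v: palette[signatures[v]] for v in union}

    def histogram(gid):
        return Counter(c for (g, _), c in colors.items() if g == gid)
    return histogram(0), histogram(1)

def cycle(n):
    return {i: [(i - 1) % n, (i + 1) % n] for i in range(n)}

def complete_bipartite(n):
    return {i: (list(range(n, 2 * n)) if i < n else list(range(n))) for i in range(2 * n)}

h1, h2 = wl_histograms(cycle(6), complete_bipartite(3))
print("C6 vs K33 distinguished:", h1 != h2)  # True (2-regular vs 3-regular)
h1, h2 = wl_histograms(cycle(4), complete_bipartite(2))
print("C4 vs K22 distinguished:", h1 != h2)  # False (the two graphs are isomorphic)
```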

### 3.4 GIN Layer Expressiveness Test

Check that an (untrained) GIN maps two graphs that the WL test distinguishes to different graph-level readouts.

In [None]:
def adjacency_list_to_edge_index(adj_list: Dict[int, List[int]]) -> torch.Tensor:
    """Convert adjacency list to edge_index tensor"""
    edges = []
    for u, neighbors in adj_list.items():
        for v in neighbors:
            edges.append([u, v])
    if edges:
        return torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        return torch.zeros((2, 0), dtype=torch.long)

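# Create small graphs for expressiveness test.
# Note: C_4 and K_{2,2} are isomorphic (both are the 4-cycle), so they cannot serve as a
# pair of different structures. C_6 and K_{3,3} both have 6 nodes but are 2- vs 3-regular.
graph1 = create_cycle_graph(6)
graph2 = create_complete_bipartite_graph(3)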

edge1 = adjacency_list_to_edge_index(graph1)
edge2 = adjacency_list_to_edge_index(graph2)

num_nodes_1 = len(graph1)
num_nodes_2 = len(graph2)

print(f"Graph 1 (cycle): {num_nodes_1} nodes, {edge1.shape[1]} edges")
print(f"Graph 2 (bipartite): {num_nodes_2} nodes, {edge2.shape[1]} edges")
print()

# Create GIN model
gin = GIN(input_dim=1, hidden_dims=[16], output_dim=8)

# Initialize node features (all ones for simplicity)
features1 = torch.ones(num_nodes_1, 1)
features2 = torch.ones(num_nodes_2, 1)

print("Testing GIN expressiveness on different graph structures...")
print()

# Forward pass
with torch.no_grad():
    embeddings1 = gin(features1, edge1)
    embeddings2 = gin(features2, edge2)
    
    # Compute graph-level readouts
    readout1 = gin.readout(embeddings1)
    readout2 = gin.readout(embeddings2)
    
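    # Compare readouts: cosine similarity plus Euclidean distance. Sum readouts of
    # isomorphic graphs coincide exactly (up to floating-point error).
    similarity = torch.cosine_similarity(readout1, readout2).item()
    distance = torch.norm(readout1 - readout2).item()

print(f"Graph 1 readout shape: {readout1.shape}")
print(f"Graph 2 readout shape: {readout2.shape}")
print(f"Cosine similarity: {similarity:.4f}")
print(f"Euclidean distance between readouts: {distance:.4f}")
print(f"Readouts differ: {distance > 1e-5}")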
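A property worth checking alongside this test: GIN's sum readout is invariant to node relabeling. Permuting the nodes permutes the rows of the features and of the adjacency matrix, every layer then produces permuted rows, and the sum ignores the order. A NumPy sketch under illustrative assumptions (dense adjacency, ε = 0, bias-free two-layer MLPs with random weights):

```python
import numpy as np

def gin_readout(A, X, weights, eps=0.0):
    """Dense GIN forward: h <- MLP((1 + eps) * h + A @ h), then a sum readout."""
    h = X
    for W1, W2 in weights:
        agg = (1 + eps) * h + A @ h         # self term plus sum over neighbors
        h = np.maximum(agg @ W1, 0.0) @ W2  # two-layer MLP with ReLU
    return h.sum(axis=0)

rng = np.random.default_rng(0)
weights = [(rng.normal(size=(1, 16)), rng.normal(size=(16, 8))),
           (rng.normal(size=(8, 16)), rng.normal(size=(16, 8)))]

# 5-node path graph with constant features
A = np.zeros((5, 5))
for i in range(4):
    A[i, i + 1] = A[i + 1, i] = 1.0
X = np.ones((5, 1))

P = np.eye(5)[rng.permutation(5)]  # permutation matrix
r_original = gin_readout(A, X, weights)
r_permuted = gin_readout(P @ A @ P.T, P @ X, weights)
print(np.allclose(r_original, r_permuted))  # True
```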

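### 3.5 Training: Transductive Node Classification with GIN

Train GIN in the transductive, full-batch setting: the whole Cora graph (including validation and test nodes) is visible during training, and only the labels are masked. The inductive comparison is the subject of Exercise 3.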

In [None]:
class NodeClassifier(nn.Module):
    """Wrapper for node classification task"""
    
    def __init__(self, gnn_model: nn.Module, num_classes: int):
        super().__init__()
        self.gnn = gnn_model
        self.classifier = nn.Linear(gnn_model.output_dim, num_classes)
    
    def forward(self, *args, **kwargs):
        embeddings = self.gnn(*args, **kwargs)
        logits = self.classifier(embeddings)
        return logits

def train_epoch(model: nn.Module, optimizer: torch.optim.Optimizer,
                x: torch.Tensor, edge_index: torch.Tensor,
                labels: torch.Tensor, mask: torch.Tensor) -> float:
    """Train for one epoch"""
    model.train()
    optimizer.zero_grad()
    
    logits = model(x, edge_index)
    loss = F.cross_entropy(logits[mask], labels[mask])
    loss.backward()
    optimizer.step()
    
    return loss.item()

def evaluate(model: nn.Module, x: torch.Tensor, edge_index: torch.Tensor,
             labels: torch.Tensor, mask: torch.Tensor) -> Tuple[float, float]:
    """Evaluate model"""
    model.eval()
    with torch.no_grad():
        logits = model(x, edge_index)
        loss = F.cross_entropy(logits[mask], labels[mask])
        pred = logits[mask].argmax(dim=1)
        acc = accuracy_score(labels[mask].cpu(), pred.cpu())
    return loss.item(), acc

# Create models
gin_model = GIN(input_dim=data.num_features, hidden_dims=[64, 64], output_dim=32, dropout=0.5)
gin_classifier = NodeClassifier(gin_model, dataset.num_classes)

# Train GIN
print("Training GIN model...")
optimizer = Adam(gin_classifier.parameters(), lr=0.01, weight_decay=5e-4)

train_losses = []
val_accs = []
test_accs = []

for epoch in range(100):
    train_loss = train_epoch(gin_classifier, optimizer, data.x, data.edge_index,
                             data.y, data.train_mask)
    val_loss, val_acc = evaluate(gin_classifier, data.x, data.edge_index,
                                 data.y, data.val_mask)
    test_loss, test_acc = evaluate(gin_classifier, data.x, data.edge_index,
                                   data.y, data.test_mask)
    
    train_losses.append(train_loss)
    val_accs.append(val_acc)
    test_accs.append(test_acc)
    
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d} | Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f} | Test Acc: {test_acc:.4f}")

print(f"\nFinal GIN Performance:")
print(f"  Val Accuracy: {val_accs[-1]:.4f}")
print(f"  Test Accuracy: {test_accs[-1]:.4f}")
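Reporting test accuracy at the last epoch ties the result to an arbitrary stopping point. A common alternative is to select the epoch with the best validation accuracy and report the test accuracy at that epoch (patience-based early stopping is a variant). A minimal helper that works on lists like `val_accs` and `test_accs`; the toy curves in the usage example are made up.

```python
from typing import List, Tuple

def select_by_validation(val_accs: List[float], test_accs: List[float]) -> Tuple[int, float, float]:
    """Return (best_epoch, val_acc, test_acc) at the epoch with the highest validation accuracy.

    Ties go to the earliest epoch.
    """
    best_epoch = max(range(len(val_accs)), key=lambda e: (val_accs[e], -e))
    return best_epoch, val_accs[best_epoch], test_accs[best_epoch]

# Toy curves: validation peaks at epoch 2, then the model starts to overfit
toy_val = [0.60, 0.72, 0.75, 0.74, 0.70]
toy_test = [0.58, 0.70, 0.73, 0.74, 0.69]
print(select_by_validation(toy_val, toy_test))  # (2, 0.75, 0.73)
```

In this notebook, `select_by_validation(val_accs, test_accs)` can stand in for the last-epoch report.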

### 3.6 Architecture Comparison

In [None]:
# Compare different aggregators in a toy GraphSAGE model
print("Comparing GraphSAGE aggregators on small graph...")
print()

# Create small feature matrix
small_x = torch.randn(10, 8)
edge_index_small = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dtype=torch.long)

# Define sampled neighborhoods: one neighbor list per node (reused at every layer below)
neighbors_layer0 = [[1, 2], [0, 2, 3], [1, 3], [2, 4], [3, 5], [4], [7], [6, 8], [7, 9], [8]]

# Test each aggregator
print("Aggregator Comparison:")
print("-" * 60)

for agg_name in ['mean', 'pooling', 'lstm']:
    try:
        sage = GraphSAGE(input_dim=8, hidden_dims=[16], output_dim=8, aggregator_type=agg_name)
        with torch.no_grad():
            output = sage(small_x, [neighbors_layer0] * len(sage.aggregators))  # one list per layer
        params = sum(p.numel() for p in sage.parameters())
        print(f"{agg_name.capitalize():10s} | Output shape: {output.shape} | Params: {params}")
    except Exception as e:
        print(f"{agg_name.capitalize():10s} | Error: {str(e)[:40]}")

print()
print("Note: Different aggregators have different computational characteristics:")
print("  - Mean: Fastest, simplest")
print("  - Pooling: Captures non-linear neighbor interactions")
print("  - LSTM: Sequential processing, order-dependent")

### 3.7 Sampling Impact on Performance

Analyze how different sampling strategies affect model performance.

In [None]:
# Analyze sampling impact
print("Sampling Strategy Analysis")
print("=" * 60)
print()

# Evaluate on the 10 highest-degree nodes (computed once, outside the loop)
high_degree_nodes = [n for _, n in sorted(((len(adj_list[n]), n) for n in range(data.num_nodes)),
                                          reverse=True)[:10]]

sample_sizes = [5, 10, 25, 50]
results = {'sample_size': [], 'coverage': [], 'uniform_nbr_deg': [], 'importance_nbr_deg': []}

for sample_size in sample_sizes:
    coverage, uniform_deg, importance_deg = [], [], []
    
    for node in high_degree_nodes:
        degree = len(adj_list[node])
        if degree > 0:
            u_sample = uniform_sampler.sample_neighbors(node, sample_size)
            i_sample = importance_sampler.sample_neighbors(node, sample_size)
            # Both samplers return min(sample_size, degree) distinct neighbors, so
            # coverage is identical by construction; only *which* neighbors differ
            coverage.append(len(u_sample) / degree)
            uniform_deg.append(np.mean([len(adj_list[n]) for n in u_sample]))
            importance_deg.append(np.mean([len(adj_list[n]) for n in i_sample]))
    
    results['sample_size'].append(sample_size)
    results['coverage'].append(np.mean(coverage))
    results['uniform_nbr_deg'].append(np.mean(uniform_deg))
    results['importance_nbr_deg'].append(np.mean(importance_deg))

print(f"{'Sample Size':<15} {'Coverage':<15} {'Uniform nbr deg':<18} {'Importance nbr deg':<18}")
print("-" * 69)
for i, size in enumerate(sample_sizes):
    print(f"{size:<15} {results['coverage'][i]:<15.4f} "
          f"{results['uniform_nbr_deg'][i]:<18.2f} {results['importance_nbr_deg'][i]:<18.2f}")

print()
print("Observations:")
print("  - Coverage = min(sample_size, degree) / degree grows with sample size")
print("  - Importance sampling should pick higher-degree neighbors on average;")
print("    the gap shrinks as the sample approaches the full neighborhood")
print("  - Non-uniform sampling biases the neighborhood estimate unless it is reweighted")
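The expected effect of degree-proportional sampling can be computed in closed form for a single draw: a uniform draw returns a neighbor whose expected degree is the plain mean Σd_i / n, while a degree-proportional draw gives Σd_i² / Σd_i, which is never smaller (by Cauchy-Schwarz, n·Σd_i² ≥ (Σd_i)²). Worked numbers for a hypothetical neighborhood:

```python
import numpy as np

def expected_neighbor_degree(degrees, weighted: bool) -> float:
    """Expected degree of one sampled neighbor under a uniform or degree-proportional draw."""
    d = np.asarray(degrees, dtype=float)
    return float((d * d).sum() / d.sum()) if weighted else float(d.mean())

degrees = [1, 2, 3, 10]
print(expected_neighbor_degree(degrees, weighted=False))  # 4.0
print(expected_neighbor_degree(degrees, weighted=True))   # 7.125  (= 114 / 16)
```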

### 3.8 Training Dynamics Visualization

In [None]:
# Visualize training dynamics
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Training loss
axes[0].plot(train_losses, label='Training Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('GIN: Training Loss Over Time')
axes[0].grid(True, alpha=0.3)
axes[0].legend()

# Plot 2: Validation and test accuracy
axes[1].plot(val_accs, label='Validation Accuracy', linewidth=2)
axes[1].plot(test_accs, label='Test Accuracy', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('GIN: Node Classification Accuracy')
axes[1].grid(True, alpha=0.3)
axes[1].legend()
axes[1].set_ylim([0, 1])

plt.tight_layout()
plt.savefig('training_dynamics.png', dpi=150, bbox_inches='tight')
plt.show()

print("Training dynamics plot saved to training_dynamics.png")

### 3.9 Model Complexity Analysis

In [None]:
def count_parameters(model: nn.Module) -> int:
    """Count trainable parameters in model"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model Complexity Analysis")
print("=" * 60)
print()

# Create different model sizes
configs = [
    {'name': 'Small', 'hidden_dims': [32]},
    {'name': 'Medium', 'hidden_dims': [64, 64]},
    {'name': 'Large', 'hidden_dims': [128, 128, 128]},
]

print(f"{'Model':<12} {'Architecture':<30} {'Parameters':<15}")
print("-" * 57)

for config in configs:
    gin = GIN(input_dim=data.num_features, hidden_dims=config['hidden_dims'],
              output_dim=32)
    classifier = NodeClassifier(gin, dataset.num_classes)
    params = count_parameters(classifier)
    arch = [str(data.num_features)] + [str(h) for h in config['hidden_dims']] + ['32', str(dataset.num_classes)]
    arch_str = '-'.join(arch)
    print(f"{config['name']:<12} {arch_str:<30} {params:<15}")

print()
print("Observation: Larger models have more capacity but risk overfitting.")
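Almost all of these parameters sit in the first GIN layer, because it maps Cora's 1433-dimensional bag-of-words features. With `hidden_dim` defaulting to the layer's output size, a `GINLayer(d_in, d_out)` has d_in·d_out + d_out + d_out² + d_out parameters (ε is a buffer unless `learn_eps=True`). A pure-Python sketch of that count, including the final linear classifier:

```python
from typing import List

def gin_classifier_params(input_dim: int, hidden_dims: List[int], output_dim: int,
                          num_classes: int, learn_eps: bool = False) -> int:
    """Trainable parameters of GIN + linear classifier as defined in this notebook."""
    dims = [input_dim] + hidden_dims + [output_dim]
    total = 0
    for d_in, d_out in zip(dims, dims[1:]):
        total += d_in * d_out + d_out   # Linear(d_in, hidden_dim=d_out)
        total += d_out * d_out + d_out  # Linear(hidden_dim, d_out)
        total += 1 if learn_eps else 0  # epsilon, only when it is learned
    return total + output_dim * num_classes + num_classes  # classifier

# Cora: 1433 input features, 7 classes; the "Small" configuration from the table above
print(gin_classifier_params(1433, [32], 32, 7))  # 49287
```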

## Part 4: Exercises

### Exercise 1: Implement Custom Aggregator

Create a custom GraphSAGE aggregator that replaces the uniform mean with an attention-weighted mean of neighbor features.

In [None]:
# TODO: Exercise 1
# Implement an AttentionAggregator for GraphSAGE that:
# 1. Computes attention weights over neighbors
# 2. Aggregates neighbors using attention-weighted sum
# 3. Combines with node features

class AttentionAggregator(GraphSAGEAggregator):
    """Custom attention-based aggregator"""
    
    def __init__(self, input_dim: int, output_dim: int, bias: bool = True):
        super().__init__(input_dim, output_dim, bias)
        # TODO: Implement attention mechanism
        # Hints:
        # - Create attention parameters
        # - Use softmax for normalized weights
        # - Combine attention-weighted neighbors with node features
        pass
    
    def forward(self, node_features: torch.Tensor, neighbor_features: List[torch.Tensor]) -> torch.Tensor:
        # TODO: Implement forward pass
        pass

# Test your implementation
print("Exercise 1: Implement custom aggregator")
print("See TODO above for implementation details.")

### Exercise 2: Analyze GIN Expressiveness

Create non-isomorphic graphs that the WL test cannot distinguish, and verify that GIN cannot distinguish them either.

In [None]:
# TODO: Exercise 2
# Find or create pairs of non-isomorphic graphs that:
# 1. Have same WL coloring sequence
# 2. Are structurally different
# 3. Cannot be distinguished by GIN

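# Classic hard cases: strongly regular graphs with identical parameters, e.g. the
# Shrikhande graph vs. the 4x4 rook's graph (both srg(16, 6, 2, 2))
# Simpler alternative: regular graphs with the same number of nodes and the same degree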

print("Exercise 2: GIN Expressiveness Analysis")
print("TODO: Create pairs of graphs that GIN cannot distinguish")
print()
print("Hints:")
print("  1. Start with regular graphs (all nodes have same degree)")
print("  2. Create different graph structures with same degree")
print("  3. Run WL test to verify they have same coloring")
print("  4. Pass both to GIN and verify identical readouts (up to floating-point error)")
print()
print("Example: 6-node cycle vs 6-node disjoint edges")

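To check hint 3 without any graph library, here is a minimal sketch of 1-WL color refinement in plain Python. The helpers `cycle` and `wl_histograms` are hypothetical names; the color `palette` is shared across graphs so that equal signatures receive equal color ids. It confirms that C6 and two disjoint triangles (clearly non-isomorphic, since only C6 is connected) produce identical histograms, while C6 vs. three disjoint edges differ in degree and are separated after one iteration.

```python
from collections import Counter
from typing import Dict, List

Graph = Dict[int, List[int]]

def cycle(n: int, offset: int = 0) -> Graph:
    """Adjacency dict of an n-node cycle on nodes offset, ..., offset + n - 1."""
    return {offset + i: [offset + (i - 1) % n, offset + (i + 1) % n] for i in range(n)}

def wl_histograms(adj: Graph, palette: dict, iterations: int = 3) -> List[Counter]:
    """1-WL color refinement; returns the color histogram before and after each iteration.

    `palette` must be shared between the graphs being compared so that equal
    signatures map to equal color ids.
    """
    colors = {v: 0 for v in adj}  # no node features: every node starts with one color
    history = [Counter(colors.values())]
    for it in range(iterations):
        new_colors = {}
        for v in adj:
            # Signature: (iteration, own color, sorted multiset of neighbor colors)
            sig = (it, colors[v], tuple(sorted(colors[u] for u in adj[v])))
            new_colors[v] = palette.setdefault(sig, len(palette))
        colors = new_colors
        history.append(Counter(colors.values()))
    return history

palette: dict = {}
c6 = cycle(6)
two_triangles = {**cycle(3), **cycle(3, offset=3)}
three_edges = {0: [1], 1: [0], 2: [3], 3: [2], 4: [5], 5: [4]}

h_c6 = wl_histograms(c6, palette)
h_triangles = wl_histograms(two_triangles, palette)
h_edges = wl_histograms(three_edges, palette)

print("C6 vs 2 x C3 indistinguishable by 1-WL:", h_c6 == h_triangles)  # True
print("C6 vs 3 x K2 indistinguishable by 1-WL:", h_c6 == h_edges)      # False
```

Since GIN is at most as discriminative as 1-WL, any pair that passes this check should also receive identical GIN graph embeddings when node features are constant.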
### Exercise 3: Implement Inductive Learning

Modify GraphSAGE to perform true inductive learning on unseen nodes.
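The bookkeeping behind an inductive split can be sketched in plain Python, using sets and edge lists instead of the boolean masks the exercise asks for. The helpers `inductive_split_sets` and `induced_edges` are hypothetical; the validation set is assumed to receive the remaining `1 - train_ratio - test_ratio` fraction. In PyTorch Geometric, the tensor counterpart of `induced_edges` is `torch_geometric.utils.subgraph`.

```python
import random
from typing import List, Set, Tuple

Edge = Tuple[int, int]

def inductive_split_sets(num_nodes: int, train_ratio: float = 0.6,
                         test_ratio: float = 0.2, seed: int = 42) -> Tuple[Set[int], Set[int], Set[int]]:
    """Disjoint train/val/test node sets; val receives the remaining fraction."""
    rng = random.Random(seed)
    perm = list(range(num_nodes))
    rng.shuffle(perm)
    # round() avoids off-by-one errors from floating-point products
    n_train = int(round(train_ratio * num_nodes))
    n_test = int(round(test_ratio * num_nodes))
    train = set(perm[:n_train])
    test = set(perm[n_train:n_train + n_test])
    val = set(perm[n_train + n_test:])
    return train, val, test

def induced_edges(edges: List[Edge], keep: Set[int]) -> List[Edge]:
    """Edges whose endpoints are both in `keep` (the only edges visible during training)."""
    return [(u, v) for u, v in edges if u in keep and v in keep]

ring = [(i, (i + 1) % 10) for i in range(10)]
train, val, test = inductive_split_sets(10)
train_edges = induced_edges(ring, train)
print(len(train), len(val), len(test))  # 6 2 2
```

Training would sample neighborhoods only from `train_edges`; at evaluation time the learned aggregators are applied to the full edge list, so test nodes are embedded without ever having been seen.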

In [None]:
# TODO: Exercise 3
# Implement inductive node classification where:
# 1. Train on subset of graph
# 2. Evaluate on completely unseen nodes
# 3. Use only aggregation functions (no node embeddings)

def inductive_split(num_nodes: int, train_ratio: float = 0.6,
                    test_ratio: float = 0.2) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Split graph into train/val/test where test nodes are completely unseen during training.
    
    Returns:
        train_mask, val_mask, test_mask
    """
    # TODO: Implement split
    pass

print("Exercise 3: Inductive Learning")
print("TODO: Implement inductive split and train GraphSAGE on unseen nodes")
print()
print("Steps:")
print("  1. Create an inductive split (train/val/test node sets are disjoint)")
print("  2. During training, sample neighbors only within the subgraph induced by training nodes")
print("  3. Evaluate on test nodes by applying the learned aggregators to the full graph")
print("  4. Compare inductive vs transductive performance")

### Exercise 4: Hyperparameter Tuning

Find optimal hyperparameters for both architectures.
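Two small building blocks make the grid search manageable: expanding a `{name: candidates}` dictionary into concrete configurations with `itertools.product`, and picking the best epoch with patience-based early stopping. Both helpers (`expand_grid`, `early_stopping_epoch`) are hypothetical names, and the validation curve below is made-up illustration data, not a result.

```python
import itertools
from typing import Any, Dict, List, Sequence, Tuple

def expand_grid(space: Dict[str, Sequence[Any]]) -> List[Dict[str, Any]]:
    """Cartesian product of a {name: candidate values} dict, one dict per configuration."""
    keys = list(space)
    return [dict(zip(keys, values)) for values in itertools.product(*(space[k] for k in keys))]

def early_stopping_epoch(val_accs: Sequence[float], patience: int = 10) -> Tuple[int, float]:
    """Best (epoch, val accuracy), scanning until `patience` epochs pass without improvement."""
    best_epoch, best_acc = 0, float("-inf")
    for epoch, acc in enumerate(val_accs):
        if acc > best_acc:
            best_epoch, best_acc = epoch, acc
        elif epoch - best_epoch >= patience:
            break  # no improvement for `patience` epochs
    return best_epoch, best_acc

grid = expand_grid({'learning_rate': [0.001, 0.01], 'dropout': [0.0, 0.3], 'hidden_dim': [32, 64]})
print(len(grid), grid[0])  # 8 {'learning_rate': 0.001, 'dropout': 0.0, 'hidden_dim': 32}
# Illustrative curve: with patience=3 the scan stops at epoch 4 and never sees the late spike
print(early_stopping_epoch([0.50, 0.60, 0.55, 0.58, 0.59, 0.90], patience=3))  # (1, 0.6)
```

In the exercise cell, `expand_grid(hyperparams)` yields the configurations to train; the final example shows why the patience value is itself a trade-off between compute and the chance of missing late improvements.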

In [None]:
# TODO: Exercise 4
# Perform hyperparameter search for GIN and GraphSAGE:
# - Learning rate: [0.001, 0.01, 0.1]
# - Dropout: [0.0, 0.3, 0.5]
# - Hidden dims: [[32], [64], [64, 64], [128, 128]]
# - Layer depth: [1, 2, 3, 4]

import itertools

# Reduced grid for a quick first pass (a subset of the ranges listed above)
hyperparams = {
    'learning_rate': [0.001, 0.01],
    'dropout': [0.0, 0.3],
    'hidden_dim': [32, 64],
}

print("Exercise 4: Hyperparameter Tuning")
print(f"Total configurations: {np.prod([len(v) for v in hyperparams.values()])}")
print()
print("TODO: Implement grid search over hyperparameters")
print("Hints:")
print("  1. Use early stopping to avoid overfitting")
print("  2. Select hyperparameters by validation accuracy, never test accuracy")
print("  3. Store results in a dictionary for comparison")
print("  4. Visualize results (accuracy vs hyperparameters)")

### Exercise 5: Extend to Link Prediction

Adapt the architectures for the link prediction task.

In [None]:
# TODO: Exercise 5
# Implement link prediction using GIN:
# 1. Use node embeddings for link scoring
# 2. Implement edge prediction layer
# 3. Create negative sampling for training
# 4. Evaluate using AUC metric

class LinkPredictor(nn.Module):
    """Link prediction head"""
    
    def __init__(self, embedding_dim: int):
        super().__init__()
        # TODO: Implement link scoring function
        # Options:
        # - Inner product: z_i^T z_j
        # - MLP: MLP([z_i || z_j])
        # - Bilinear: z_i^T W z_j
        pass
    
    def forward(self, z: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # TODO: Score edges
        pass

print("Exercise 5: Link Prediction")
print("TODO: Implement link prediction using graph embeddings")
print()
print("Steps:")
print("  1. Create LinkPredictor head for edge scoring")
print("  2. Implement negative sampling for training")
print("  3. Train link predictor end-to-end")
print("  4. Evaluate using ROC-AUC on validation/test edges")

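The pieces of Exercise 5 can be sketched in plain Python before wiring them into PyTorch: two of the scoring options listed in `LinkPredictor` (inner product and bilinear), uniform negative sampling of node pairs that are not edges, and ROC-AUC computed as the probability that a positive edge outscores a negative one. All helper names are hypothetical; the MLP scorer is omitted. In practice, PyTorch Geometric's `torch_geometric.utils.negative_sampling` and scikit-learn's `roc_auc_score` (imported in the setup cell) cover the last two steps.

```python
import random
from typing import List, Sequence, Set, Tuple

Edge = Tuple[int, int]

def inner_product_score(z: List[List[float]], i: int, j: int) -> float:
    """z_i^T z_j"""
    return sum(a * b for a, b in zip(z[i], z[j]))

def bilinear_score(z: List[List[float]], W: List[List[float]], i: int, j: int) -> float:
    """z_i^T W z_j for a square matrix W given as a list of rows."""
    w_zj = [sum(w_rc * x for w_rc, x in zip(row, z[j])) for row in W]
    return sum(a * b for a, b in zip(z[i], w_zj))

def sample_negative_edges(num_nodes: int, positive: List[Edge], num_samples: int,
                          seed: int = 0, max_tries: int = 10_000) -> Set[Edge]:
    """Uniformly sample undirected node pairs (stored as (min, max)) that are not edges."""
    rng = random.Random(seed)
    existing = {(min(u, v), max(u, v)) for u, v in positive}
    negatives: Set[Edge] = set()
    for _ in range(max_tries):  # bounded loop in case too few non-edges exist
        if len(negatives) >= num_samples:
            break
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        pair = (min(u, v), max(u, v))
        if u != v and pair not in existing:
            negatives.add(pair)
    return negatives

def roc_auc(pos_scores: Sequence[float], neg_scores: Sequence[float]) -> float:
    """ROC-AUC as P(score(pos) > score(neg)), counting ties as 1/2."""
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            wins += 1.0 if p > n else (0.5 if p == n else 0.0)
    return wins / (len(pos_scores) * len(neg_scores))

# Toy usage: in a 4-cycle the only non-edges are (0, 2) and (1, 3)
square = [(0, 1), (1, 2), (2, 3), (3, 0)]
negatives = sample_negative_edges(4, square, num_samples=2)
print(sorted(negatives))                  # [(0, 2), (1, 3)]
print(roc_auc([0.9, 0.8], [0.1, 0.8]))    # 0.875
```

With `W` set to the identity matrix, the bilinear score reduces to the inner product, so the bilinear head is a strict generalization at the cost of `d * d` extra parameters.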
## Summary

In this lesson, we explored two advanced GNN architectures:

### GraphSAGE (Inductive Learning)
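- **Key innovation**: Learned aggregation functions (rather than per-node embeddings) enable inductive learning; fixed-size neighborhood sampling keeps the per-node cost bounded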
- **Aggregators**: Mean, LSTM, pooling for flexible information aggregation
- **Scalability**: Mini-batch training for large graphs
- **Use case**: Production systems, evolving graphs, unseen nodes
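For reference, the layer update from the original GraphSAGE paper (Hamilton et al., 2017) can be written as below, with $\mathcal{S}(v)$ a fixed-size uniform sample of the neighborhood $\mathcal{N}(v)$ and $\Vert$ denoting concatenation. Because $\mathbf{W}^{(k)}$ and $\mathrm{AGG}_k$ are shared by all nodes, the same layer applies unchanged to nodes never seen during training.

$$
\begin{aligned}
\mathbf{h}_{\mathcal{N}(v)}^{(k)} &= \mathrm{AGG}_k\left(\left\{\mathbf{h}_u^{(k-1)} : u \in \mathcal{S}(v)\right\}\right) \\
\mathbf{h}_v^{(k)} &= \sigma\left(\mathbf{W}^{(k)} \left[\mathbf{h}_v^{(k-1)} \,\Vert\, \mathbf{h}_{\mathcal{N}(v)}^{(k)}\right]\right) \\
\mathbf{h}_v^{(k)} &\leftarrow \mathbf{h}_v^{(k)} \big/ \left\Vert \mathbf{h}_v^{(k)} \right\Vert_2
\end{aligned}
$$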

### GIN (Expressive Power)
- **Key innovation**: Injective aggregation matches WL test expressiveness
- **Theory**: Grounded in graph isomorphism theory
- **Simplicity**: Clean mathematical formulation
- **Use case**: Graph-level tasks, when theoretical guarantees matter
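GIN (Xu et al., 2019) updates each node as $\mathbf{h}_v^{(k)} = \mathrm{MLP}^{(k)}\big((1+\epsilon^{(k)})\,\mathbf{h}_v^{(k-1)} + \sum_{u \in \mathcal{N}(v)} \mathbf{h}_u^{(k-1)}\big)$. The sketch below shows why the sum matters: mean and max collapse two neighbor multisets that share the same distinct values but differ in multiplicities, while sum keeps them apart. The multisets are illustrative.

```python
from statistics import mean

a = [1.0, 1.0, 2.0, 2.0]  # neighbor multiset A
b = [1.0, 2.0]            # neighbor multiset B: same distinct values, different counts

distinguishes = {}
for name, agg in [("sum", sum), ("mean", mean), ("max", max)]:
    distinguishes[name] = agg(a) != agg(b)
    print(f"{name:>4}: {agg(a)} vs {agg(b)} -> distinguishes: {distinguishes[name]}")
```

Sum over raw scalars is not injective in general either ({3.0} and {1.0, 2.0} both sum to 3.0); the GIN analysis assumes node features from a countable set, mapped through learned MLPs so that the summed representation becomes injective.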

### Key Takeaways
1. **Sampling strategies** bound per-node computation and largely determine scalability
2. **Inductive learning** enables generalization to unseen nodes
3. **Expressiveness** bounds explain GNN limitations
4. **Trade-offs** exist between theoretical expressiveness, empirical performance, and scalability
5. **Architecture choice** depends on specific problem requirements
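A quick back-of-the-envelope calculation shows why sampling governs scalability: with per-layer fanouts the number of nodes touched per target is bounded regardless of degree, whereas full expansion grows with degree to the power of the depth. The fanouts (25, 10) and the average degree of 100 are illustrative assumptions, and the full-neighborhood count is a worst-case bound that ignores overlapping neighbors.

```python
from typing import List

def sampled_receptive_field(fanouts: List[int]) -> int:
    """Upper bound on nodes touched per target node (target included) with per-layer fanouts."""
    total, frontier = 1, 1
    for k in fanouts:
        frontier *= k  # each frontier node samples k neighbors
        total += frontier
    return total

def full_receptive_field(avg_degree: int, num_layers: int) -> int:
    """Worst-case neighborhood size when every neighbor is expanded (no overlap)."""
    return sum(avg_degree ** k for k in range(num_layers + 1))

print(sampled_receptive_field([25, 10]))  # 276 = 1 + 25 + 250
print(full_receptive_field(100, 2))       # 10101 = 1 + 100 + 10000
```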

### Next Steps
- Implement custom aggregators
- Explore more sampling strategies (importance, subgraph)
- Study higher-order GNN expressiveness
- Apply to real-world problems (recommendation systems, knowledge graphs)
- Investigate temporal and heterogeneous graphs