# Lesson 3: Message Passing & GNN Foundations

## Overview

This lesson explores the core concept behind Graph Neural Networks: **message passing**. We'll see how information propagates through a graph, implement message passing from scratch, and visualize each step of the process.

### Learning Objectives
1. Understand the message passing mechanism
2. Implement message passing from scratch
3. Explore different aggregation functions
4. Understand receptive fields in GNNs
5. Recognize the over-smoothing problem
6. Build a simple GNN for node classification

### Key Concepts
- **Message Passing**: Core mechanism where nodes exchange information with neighbors
- **Aggregation**: Combining messages from multiple neighbors (sum, mean, max, etc.)
- **Receptive Field**: The set of nodes that can influence a target node's representation
- **Over-smoothing**: Problem where nodes become indistinguishable after many layers
- **Multi-layer GNN**: Stacking multiple message passing layers for deeper models

## Part 1: Setup and Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from collections import defaultdict, deque
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import seaborn as sns
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch
import matplotlib.patches as mpatches

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Set style for better visualizations
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

print('All imports successful!')
print(f'PyTorch version: {torch.__version__}')

## Part 2: Understanding Message Passing

### Concept

Message passing is the fundamental operation in Graph Neural Networks. The idea is simple:

1. **Message Computation**: Each node prepares a message based on its features
2. **Message Passing**: Messages are sent along edges to neighbors
3. **Aggregation**: Each node collects messages from all its neighbors
4. **Update**: Node updates its representation based on aggregated messages

### Mathematical Formulation

For a given layer $l$:

$$m_i^{(l)} = \text{AGGREGATE}\left(\{h_j^{(l)} : j \in \mathcal{N}(i)\}\right)$$

$$h_i^{(l+1)} = \sigma\left(W \cdot [h_i^{(l)} || m_i^{(l)}]\right)$$

where:
- $h_i^{(l)}$ is the hidden state of node $i$ at layer $l$
- $\mathcal{N}(i)$ is the set of neighbors of node $i$
- $m_i^{(l)}$ is the aggregated message
- $W$ is a learnable weight matrix
- $\sigma$ is an activation function
- $||$ denotes concatenation
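
To make the two equations concrete, here is a minimal NumPy sketch of a single update for one node. The toy graph, the feature values, and the weight matrix `W` are made up for illustration; mean is assumed as the AGGREGATE function and tanh as the activation.

```python
import numpy as np

# Toy graph (hypothetical): node 0 has neighbors 1 and 2.
h = np.array([
    [1.0, 0.0],   # h_0
    [0.0, 2.0],   # h_1
    [4.0, 2.0],   # h_2
])
neighbors = {0: [1, 2], 1: [0], 2: [0]}

# AGGREGATE: mean of the neighbors' hidden states.
m_0 = h[neighbors[0]].mean(axis=0)       # [2.0, 2.0]

# Concatenate the node's own state with the aggregated message: [h_0 || m_0].
concat = np.concatenate([h[0], m_0])     # shape (4,)

# W maps the concatenated vector (size 4) back to the hidden size (2).
# The values are arbitrary; in a real model W is learned.
W = np.array([
    [0.5, 0.0, 0.5, 0.0],
    [0.0, 0.5, 0.0, 0.5],
])

# Update: h_0^{(l+1)} = tanh(W [h_0 || m_0]).
h_0_next = np.tanh(W @ concat)
print(m_0, h_0_next)
```

Because the input to `W` is the concatenation, `W` has twice as many columns as the hidden size. This is worth keeping in mind when comparing against implementations that blend the two vectors instead of concatenating them.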

## Part 3: Message Passing from Scratch

Let's implement message passing step-by-step without using any GNN libraries.

In [None]:
class SimpleMessagePassing:
    """
    A simple message passing implementation from scratch.
    
    This class demonstrates the basic message passing mechanism:
    1. Each node sends its features to its neighbors
    2. Each node aggregates messages from neighbors
    3. Each node updates its representation
    """
    
    def __init__(self, num_nodes, feature_dim, aggregation='mean'):
        """
        Initialize the message passing layer.
        
        Args:
            num_nodes: Number of nodes in the graph
            feature_dim: Dimension of node features
            aggregation: Type of aggregation ('mean', 'sum', 'max')
        """
        self.num_nodes = num_nodes
        self.feature_dim = feature_dim
        self.aggregation = aggregation
        
        # Initialize node features
        self.features = np.random.randn(num_nodes, feature_dim)
        
        # Initialize learnable weight matrix
        self.W = np.random.randn(feature_dim, feature_dim) * 0.1
        
        # Store adjacency list
        self.adj_list = defaultdict(list)
        
        # Store history for visualization
        self.history = []
    
    def add_edge(self, u, v):
        """Add an undirected edge between nodes u and v."""
        self.adj_list[u].append(v)
        self.adj_list[v].append(u)
    
    def message_fn(self, node_features):
        """
        Message function: prepares features as messages.
        """
        return node_features
    
    def aggregate_fn(self, messages):
        """
        Aggregation function: combines messages from neighbors.
        """
        if len(messages) == 0:
            return np.zeros(self.feature_dim)
        
        messages = np.array(messages)
        
        if self.aggregation == 'mean':
            return np.mean(messages, axis=0)
        elif self.aggregation == 'sum':
            return np.sum(messages, axis=0)
        elif self.aggregation == 'max':
            return np.max(messages, axis=0)
        else:
            raise ValueError(f'Unknown aggregation: {self.aggregation}')
    
    def update_fn(self, node_feature, aggregated_message):
        """
        Update function: combines node's features with aggregated message.

        For simplicity this uses a fixed 50/50 blend followed by tanh
        rather than the learned transform W applied to [h || m], so
        self.W is not used here.
        """
        updated = node_feature * 0.5 + aggregated_message * 0.5
        updated = np.tanh(updated)
        return updated
    
    def forward(self, num_layers=1, return_history=False):
        """
        Perform message passing for specified number of layers.
        """
        if return_history:
            history = [self.features.copy()]
        
        for layer in range(num_layers):
            new_features = np.zeros_like(self.features)
            
            for node_id in range(self.num_nodes):
                neighbor_ids = self.adj_list[node_id]
                neighbor_features = [self.features[nid] for nid in neighbor_ids]
                
                messages = [self.message_fn(f) for f in neighbor_features]
                aggregated = self.aggregate_fn(messages)
                new_features[node_id] = self.update_fn(self.features[node_id], aggregated)
            
            self.features = new_features
            
            if return_history:
                history.append(self.features.copy())
        
        if return_history:
            return self.features, history
        return self.features


print('SimpleMessagePassing class created successfully!')

## Part 4: Visualization of Message Passing Steps

Let's create a simple graph and visualize how information flows through message passing.

In [None]:
# Create a simple graph
G = nx.Graph()
G.add_edges_from([
    (0, 1), (1, 2), (2, 3), (3, 4),
    (1, 5), (5, 6), (3, 7), (4, 8)
])

# Initialize message passing
mp = SimpleMessagePassing(num_nodes=9, feature_dim=3, aggregation='mean')

# Add edges to message passing
for edge in G.edges():
    mp.add_edge(edge[0], edge[1])

# Perform message passing and collect history
_, history = mp.forward(num_layers=3, return_history=True)

print(f'Graph has {len(G.nodes())} nodes and {len(G.edges())} edges')
print(f'Collected {len(history)} feature snapshots (initial state + {len(history) - 1} layers)')
print(f'Feature dimension: {mp.feature_dim}')

In [None]:
def visualize_message_passing_steps(G, history):
    """
    Visualize how node features evolve through message passing layers.
    """
    fig, axes = plt.subplots(1, len(history), figsize=(15, 4))
    
    pos = nx.spring_layout(G, seed=42, k=2)
    
    for layer, (ax, features) in enumerate(zip(axes, history)):
        node_colors = features[:, 0]
        node_colors_norm = (node_colors - node_colors.min()) / (node_colors.max() - node_colors.min() + 1e-8)
        
        nx.draw_networkx_edges(G, pos, ax=ax, width=1.5, alpha=0.6)
        nodes = nx.draw_networkx_nodes(
            G, pos, ax=ax,
            node_color=node_colors_norm,
            node_size=800,
            cmap='viridis',
            vmin=0, vmax=1
        )
        nx.draw_networkx_labels(G, pos, ax=ax, font_size=10, font_weight='bold')
        
        ax.set_title(f'Layer {layer}', fontsize=12, fontweight='bold')
        ax.axis('off')
    
    plt.colorbar(nodes, ax=axes[-1], label='Feature 0 (min-max normalized per layer)')
    plt.suptitle('Message Passing Evolution Through Layers', fontsize=14, fontweight='bold', y=1.02)
    plt.show()


visualize_message_passing_steps(G, history)

## Part 5: Comparing Aggregation Functions

Different aggregation functions lead to different information propagation patterns. Let's compare them.
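
One difference is easy to see on a toy example. The sketch below uses two hypothetical neighborhoods that contain the same feature value but differ in size: mean and max return identical results for both, while sum separates them.

```python
import numpy as np

# Two hypothetical neighborhoods with identical feature values
# but different numbers of neighbors.
small = np.array([[1.0], [1.0]])                  # 2 neighbors
large = np.array([[1.0], [1.0], [1.0], [1.0]])    # 4 neighbors

aggregators = {"mean": np.mean, "sum": np.sum, "max": np.max}

distinguishes = {}
for name, fn in aggregators.items():
    a, b = fn(small, axis=0), fn(large, axis=0)
    # An aggregator "distinguishes" the neighborhoods if its outputs differ.
    distinguishes[name] = not np.allclose(a, b)
    print(f"{name:>4}: small={a}, large={b}, distinguishes={distinguishes[name]}")
```

If neighborhood size carries signal for the task (for example, node degree), sum keeps that information while mean and max discard it. The trade-off is that sum produces larger values on high-degree nodes.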

In [None]:
def compare_aggregations(G, num_layers=2):
    """
    Compare how different aggregation functions affect message passing.
    """
    aggregations = ['mean', 'sum', 'max']
    results = {}
    
    for agg in aggregations:
        mp = SimpleMessagePassing(num_nodes=len(G.nodes()), feature_dim=4, aggregation=agg)
        for edge in G.edges():
            mp.add_edge(edge[0], edge[1])
        
        features, history = mp.forward(num_layers=num_layers, return_history=True)
        results[agg] = history
    
    return results


agg_results = compare_aggregations(G, num_layers=3)

# Analyze and visualize
fig, axes = plt.subplots(3, 4, figsize=(16, 12))
fig.suptitle('Message Passing with Different Aggregation Functions', fontsize=16, fontweight='bold')

pos = nx.spring_layout(G, seed=42, k=2)

for row, (agg_type, history) in enumerate(agg_results.items()):
    for col, (layer, features) in enumerate(enumerate(history)):
        ax = axes[row, col]
        
        node_colors = features[:, 0]
        node_colors_norm = (node_colors - node_colors.min()) / (node_colors.max() - node_colors.min() + 1e-8)
        
        nx.draw_networkx_edges(G, pos, ax=ax, width=1.5, alpha=0.5)
        nx.draw_networkx_nodes(
            G, pos, ax=ax,
            node_color=node_colors_norm,
            node_size=600,
            cmap='plasma',
            vmin=0, vmax=1
        )
        nx.draw_networkx_labels(G, pos, ax=ax, font_size=8)
        
        ax.set_title(f'{agg_type.upper()} - Layer {layer}', fontweight='bold')
        ax.axis('off')

plt.tight_layout()
plt.show()

print('Aggregation Comparison Complete!')
print('Observations:')
print('- MEAN: Averages neighbor features, preserves scale')
print('- SUM: Accumulates neighbor features, can lead to larger values')
print('- MAX: Takes maximum element-wise, emphasizes strongest signals')

## Part 6: Receptive Field Visualization

The receptive field is the set of nodes that can influence a target node's representation. It grows with each message passing layer.
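
The same growth can be read off the adjacency matrix: the nonzero entries in row $i$ of $(A + I)^k$ mark the nodes within $k$ hops of node $i$. The sketch below is an illustrative alternative to a breadth-first search, using the edge list of this lesson's example graph and node 0 as the target.

```python
import numpy as np

# Same edge list as the example graph used in this lesson.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (1, 5), (5, 6), (3, 7), (4, 8)]
n = 9

A = np.zeros((n, n), dtype=int)
for u, v in edges:
    A[u, v] = A[v, u] = 1

# Self-loops keep already-reached nodes in the field at every step.
A_hat = A + np.eye(n, dtype=int)

target = 0
reach = np.eye(n, dtype=int)   # reachability pattern of (A + I)^0
sizes = []
for k in range(6):
    # Nonzero entries in row `target` are the nodes within k hops.
    sizes.append(int(np.count_nonzero(reach[target])))
    # Multiply once more and clip to 0/1 so entries never overflow.
    reach = (reach @ A_hat > 0).astype(int)

print(sizes)  # [1, 2, 4, 6, 8, 9]
```

After five layers node 0 can be influenced by every node in this graph, since its farthest node (node 8) is five hops away.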

In [None]:
def compute_receptive_field(G, target_node, max_distance):
    """
    Compute the receptive field of a target node at different distances.
    """
    receptive_field = {}
    
    visited = {target_node: 0}
    queue = deque([target_node])
    
    while queue:
        node = queue.popleft()
        current_distance = visited[node]
        
        if current_distance >= max_distance:
            continue
        
        for neighbor in G.neighbors(node):
            if neighbor not in visited:
                visited[neighbor] = current_distance + 1
                queue.append(neighbor)
    
    for node, distance in visited.items():
        if distance not in receptive_field:
            receptive_field[distance] = set()
        receptive_field[distance].add(node)
    
    return receptive_field


# Visualize receptive fields
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Receptive Field Growth with Message Passing Layers', fontsize=14, fontweight='bold')

pos = nx.spring_layout(G, seed=42, k=2)
target_node = 0

for layer_idx in range(6):
    ax = axes[layer_idx // 3, layer_idx % 3]
    
    rf = compute_receptive_field(G, target_node, max_distance=layer_idx)
    
    node_colors = []
    for node in G.nodes():
        if node == target_node:
            node_colors.append('red')
        elif node in rf.get(layer_idx, set()):
            node_colors.append('lightblue')
        elif node in [n for d in range(layer_idx) for n in rf.get(d, set())]:
            node_colors.append('lightgreen')
        else:
            node_colors.append('lightgray')
    
    nx.draw_networkx_edges(G, pos, ax=ax, width=1.5, alpha=0.5)
    nx.draw_networkx_nodes(
        G, pos, ax=ax,
        node_color=node_colors,
        node_size=800,
        edgecolors='black',
        linewidths=2
    )
    nx.draw_networkx_labels(G, pos, ax=ax, font_size=10, font_weight='bold')
    
    all_in_rf = sum(len(rf.get(d, set())) for d in range(layer_idx + 1))
    
    ax.set_title(f'Layer {layer_idx}\n(Size: {all_in_rf})', fontweight='bold')
    ax.axis('off')

legend_elements = [
    mpatches.Patch(facecolor='red', label='Target Node (0)'),
    mpatches.Patch(facecolor='lightblue', label='New in Layer'),
    mpatches.Patch(facecolor='lightgreen', label='Previous Layers'),
    mpatches.Patch(facecolor='lightgray', label='Outside')
]
fig.legend(handles=legend_elements, loc='lower center', ncol=4, fontsize=10, bbox_to_anchor=(0.5, -0.02))

plt.tight_layout()
plt.show()

print(f'Receptive Field Analysis for Node {target_node}:')
for layer in range(4):
    rf = compute_receptive_field(G, target_node, max_distance=layer)
    total = sum(len(rf.get(d, set())) for d in range(layer + 1))
    print(f'  Layer {layer}: {total} nodes in receptive field')

## Part 7: Over-smoothing Problem

A critical challenge in deep GNNs: as layers increase, node representations become increasingly similar, reducing their distinctiveness. This is called **over-smoothing**.
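
The effect can be isolated with a minimal NumPy sketch that applies only the averaging step, with no weights or nonlinearity. The 5-node path graph and the random features are assumptions for illustration. Each round replaces a node's features with a convex combination of its own and its neighbors' features, so the spread of values across nodes can never grow.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 5-node path graph: 0 - 1 - 2 - 3 - 4.
A = np.zeros((5, 5))
for u, v in [(0, 1), (1, 2), (2, 3), (3, 4)]:
    A[u, v] = A[v, u] = 1.0

# Row-normalized adjacency with self-loops (mean over self and neighbors).
A_hat = A + np.eye(5)
P = A_hat / A_hat.sum(axis=1, keepdims=True)

H = rng.normal(size=(5, 3))   # random initial node features
spreads = []
for layer in range(31):
    # Spread: largest (max - min) across nodes, over all feature columns.
    spreads.append(float((H.max(axis=0) - H.min(axis=0)).max()))
    H = P @ H                 # one round of pure averaging

print([round(s, 4) for s in spreads[::10]])
```

The printed spreads shrink toward zero, meaning all nodes drift toward the same feature vector. Learned weights and nonlinearities change the details but not this underlying tendency of repeated averaging.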

In [None]:
def demonstrate_over_smoothing(num_layers=10, num_nodes=20):
    """
    Demonstrate the over-smoothing problem in deep GNNs.
    """
    G_random = nx.erdos_renyi_graph(num_nodes, p=0.3, seed=42)
    
    adj = nx.to_numpy_array(G_random)
    adj = adj + np.eye(num_nodes)
    row_sum = adj.sum(axis=1, keepdims=True)
    adj = adj / row_sum
    adj = torch.FloatTensor(adj)
    
    X = torch.randn(num_nodes, 16)
    
    feature_distances = []
    feature_vars = []
    
    current_features = X.clone()
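    # Measure the true initial distance instead of using a placeholder
    initial_distances = torch.cdist(current_features, current_features)
    off_diagonal = ~torch.eye(num_nodes, dtype=torch.bool)
    feature_distances.append(initial_distances[off_diagonal].mean().item())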
    feature_vars.append(current_features.var().item())
    
    for layer_idx in range(num_layers):
        aggregated = torch.matmul(adj, current_features)
        
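        # Scale W by 1/sqrt(fan_in) so the weights roughly preserve feature
        # magnitude; with much smaller weights the shrinking distances would
        # mostly reflect the weight scale rather than the neighbor averaging.
        W = torch.randn(16, 16) / 4.0
        current_features = torch.matmul(aggregated, W)
        current_features = torch.tanh(current_features)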
        
        pairwise_distances = torch.cdist(current_features, current_features)
        mask = ~torch.eye(num_nodes, dtype=torch.bool)
        avg_distance = pairwise_distances[mask].mean().item()
        
        feature_distances.append(avg_distance)
        feature_vars.append(current_features.var().item())
    
    return feature_distances, feature_vars


feature_distances, feature_vars = demonstrate_over_smoothing(num_layers=15, num_nodes=20)

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

ax = axes[0]
ax.plot(range(len(feature_distances)), feature_distances, marker='o', linewidth=2, markersize=6, color='darkblue')
ax.fill_between(range(len(feature_distances)), feature_distances, alpha=0.3, color='blue')
ax.set_xlabel('Layer', fontsize=12)
ax.set_ylabel('Average Pairwise Distance', fontsize=12)
ax.set_title('Over-smoothing: Nodes Become More Similar', fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.axvline(x=5, color='red', linestyle='--', alpha=0.5, label='Problem Zone Begins')
ax.legend()

ax = axes[1]
ax.plot(range(len(feature_vars)), feature_vars, marker='s', linewidth=2, markersize=6, color='darkgreen')
ax.fill_between(range(len(feature_vars)), feature_vars, alpha=0.3, color='green')
ax.set_xlabel('Layer', fontsize=12)
ax.set_ylabel('Feature Variance', fontsize=12)
ax.set_title('Feature Variance Decrease with Depth', fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print('Over-smoothing Analysis:')
print(f'Average distance after layer 1: {feature_distances[1]:.4f}')
print(f'Average distance after layer {len(feature_distances) - 1}: {feature_distances[-1]:.4f}')
print(f'Distance decreased by: {(1 - feature_distances[-1]/feature_distances[1]) * 100:.1f}%')

## Part 8: Simple GNN Implementation

Now let's implement a complete GNN model for node classification using PyTorch.
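
The model below relies on the product of the normalized adjacency matrix and the feature matrix to perform aggregation. As a sanity check, this sketch (the toy graph and random features are assumptions) confirms that multiplying by a row-normalized adjacency matrix with self-loops gives the same result as explicitly averaging each node with its neighbors.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 4-node path graph with edges 0-1, 1-2, 2-3.
A = np.array([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
], dtype=float)
X = rng.normal(size=(4, 3))

# Matrix form: add self-loops, row-normalize, multiply.
A_hat = A + np.eye(4)
adj = A_hat / A_hat.sum(axis=1, keepdims=True)
matrix_result = adj @ X

# Loop form: each node averages itself and its neighbors explicitly.
loop_result = np.zeros_like(X)
for i in range(4):
    group = [i] + [j for j in range(4) if A[i, j] > 0]
    loop_result[i] = X[group].mean(axis=0)

print(np.allclose(matrix_result, loop_result))  # True
```

The matrix form does the same work as the per-node loop in one vectorized operation, which is why dense GNN implementations typically use it.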

In [None]:
class GNNModel(nn.Module):
    """
    A simple yet effective GNN model for node classification.
    
    Architecture:
    Input -> GNN Layer 1 -> ReLU -> GNN Layer 2 -> Output
    """
    
    def __init__(self, in_features, hidden_features, out_features, num_layers=2):
        """
        Initialize GNN model.
        
        Args:
            in_features: Input feature dimension
            hidden_features: Hidden feature dimension
            out_features: Output feature dimension (classes)
            num_layers: Number of GNN layers (at least 2; each layer beyond 2 adds a hidden-to-hidden layer)
        """
        super(GNNModel, self).__init__()
        
        self.num_layers = num_layers
        
        self.linear1 = nn.Linear(in_features, hidden_features)
        
        self.hidden_layers = nn.ModuleList()
        for i in range(num_layers - 2):
            self.hidden_layers.append(nn.Linear(hidden_features, hidden_features))
        
        self.linear2 = nn.Linear(hidden_features, out_features)
        
        self.relu = nn.ReLU()
    
    def forward(self, X, adj):
        """
        Forward pass: apply message passing and classification.
        
        Args:
            X: Node features (num_nodes, in_features)
            adj: Normalized adjacency matrix (num_nodes, num_nodes)
        
        Returns:
            Node predictions (num_nodes, out_features)
        """
        x = torch.matmul(adj, X)
        x = self.linear1(x)
        x = self.relu(x)
        
        for hidden_layer in self.hidden_layers:
            x = torch.matmul(adj, x)
            x = hidden_layer(x)
            x = self.relu(x)
        
        x = torch.matmul(adj, x)
        x = self.linear2(x)
        
        return x


print('GNNModel class defined successfully!')

In [None]:
def create_synthetic_graph_dataset(num_nodes=100, num_features=10, num_classes=3, seed=42):
    """
    Create a synthetic graph dataset for node classification.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    
    G = nx.connected_watts_strogatz_graph(num_nodes, k=4, p=0.3, seed=seed)
    
    X = torch.randn(num_nodes, num_features)
    
    y = torch.zeros(num_nodes, dtype=torch.long)
    for i in range(num_nodes):
        y[i] = (i // (num_nodes // num_classes)) % num_classes
    
    for i in range(num_nodes):
        X[i, :int(num_features/2)] += float(y[i])
    
    adj = nx.to_numpy_array(G)
    adj = adj + np.eye(num_nodes)
    row_sum = adj.sum(axis=1, keepdims=True)
    adj = adj / row_sum
    adj = torch.FloatTensor(adj)
    
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    
    indices = np.arange(num_nodes)
    np.random.shuffle(indices)
    
    train_idx = indices[:int(0.6 * num_nodes)]
    val_idx = indices[int(0.6 * num_nodes):int(0.8 * num_nodes)]
    test_idx = indices[int(0.8 * num_nodes):]
    
    train_mask[train_idx] = True
    val_mask[val_idx] = True
    test_mask[test_idx] = True
    
    return X, y, adj, train_mask, val_mask, test_mask


X, y, adj, train_mask, val_mask, test_mask = create_synthetic_graph_dataset(
    num_nodes=100, num_features=16, num_classes=3
)

print(f'Dataset created:')
print(f'  Nodes: {X.shape[0]}')
print(f'  Features: {X.shape[1]}')
print(f'  Classes: {int(y.max().item()) + 1}')
print(f'  Train samples: {train_mask.sum().item()}')
print(f'  Validation samples: {val_mask.sum().item()}')
print(f'  Test samples: {test_mask.sum().item()}')

In [None]:
def train_gnn(model, X, y, adj, train_mask, val_mask, test_mask, epochs=100, lr=0.01):
    """
    Train the GNN model.
    """
    optimizer = Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    train_losses = []
    val_accs = []
    
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        
        logits = model(X, adj)
        loss = criterion(logits[train_mask], y[train_mask])
        loss.backward()
        optimizer.step()
        
        train_losses.append(loss.item())
        
        if epoch % 10 == 0:
            model.eval()
            with torch.no_grad():
                logits = model(X, adj)
                val_preds = logits[val_mask].argmax(dim=1)
                val_acc = (val_preds == y[val_mask]).float().mean().item()
                val_accs.append(val_acc)
    
    model.eval()
    with torch.no_grad():
        logits = model(X, adj)
        test_preds = logits[test_mask].argmax(dim=1)
        test_acc = (test_preds == y[test_mask]).float().mean().item()
    
    return train_losses, val_accs, test_acc


model = GNNModel(in_features=16, hidden_features=32, out_features=3, num_layers=2)
train_losses, val_accs, test_acc = train_gnn(
    model, X, y, adj, train_mask, val_mask, test_mask,
    epochs=100, lr=0.01
)

print(f'Training completed!')
print(f'Final Test Accuracy: {test_acc:.4f}')

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

ax = axes[0]
ax.plot(range(len(train_losses)), train_losses, linewidth=2, color='darkblue')
ax.fill_between(range(len(train_losses)), train_losses, alpha=0.3, color='blue')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Training Loss', fontsize=12)
ax.set_title('Training Loss Over Epochs', fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3)

ax = axes[1]
val_epochs = [i * 10 for i in range(len(val_accs))]
ax.plot(val_epochs, val_accs, marker='o', linewidth=2, markersize=8, color='darkgreen')
ax.fill_between(val_epochs, val_accs, alpha=0.3, color='green')
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Validation Accuracy', fontsize=12)
ax.set_title('Validation Accuracy Over Epochs', fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.set_ylim([0, 1])

plt.tight_layout()
plt.show()

print(f'Training Summary:')
print(f'  Initial loss: {train_losses[0]:.4f}')
print(f'  Final loss: {train_losses[-1]:.4f}')
print(f'  Loss reduction: {(1 - train_losses[-1]/train_losses[0]) * 100:.1f}%')
print(f'  Max validation accuracy: {max(val_accs):.4f}')
print(f'  Test accuracy: {test_acc:.4f}')

## Part 9: Key Insights Summary

### Message Passing Mechanism
- Nodes exchange information with neighbors through edges
- Information propagates through multiple layers, creating larger receptive fields
- Different aggregation functions (mean, sum, max) lead to different behaviors

### Receptive Fields
- Layer 0: Node only sees itself
- Layer 1: Node sees itself and 1-hop neighbors
- Layer k: Node sees all nodes within k hops
- Deeper models capture longer-range dependencies

### Over-smoothing Problem
- As layers increase, node representations become similar
- Models deeper than roughly 4-6 layers often suffer from this, though the exact depth depends on the graph and architecture
- Common mitigations: skip (residual) connections and attention mechanisms

### GNN Design Considerations
1. **Number of layers**: Balance between receptive field and over-smoothing
2. **Aggregation function**: Choose based on problem characteristics
3. **Feature dimension**: Larger hidden dimensions capture more information
4. **Normalization**: Proper normalization of adjacency matrix is crucial
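
To illustrate the normalization point, the sketch below compares two common choices on a hypothetical star graph: the row normalization $D^{-1}(A + I)$ used in this lesson's code, and the symmetric normalization $D^{-1/2}(A + I)D^{-1/2}$ from Kipf & Welling's GCN. The graph itself is an assumption chosen to make the degree difference visible.

```python
import numpy as np

# Hypothetical star graph: node 0 is connected to nodes 1, 2, and 3.
A = np.zeros((4, 4))
for v in (1, 2, 3):
    A[0, v] = A[v, 0] = 1.0

A_hat = A + np.eye(4)       # add self-loops
deg = A_hat.sum(axis=1)     # degrees including the self-loop: [4, 2, 2, 2]

# Row normalization D^-1 (A + I): each row sums to 1 (mean aggregation).
row_norm = A_hat / deg[:, None]

# Symmetric normalization D^-1/2 (A + I) D^-1/2.
d_inv_sqrt = 1.0 / np.sqrt(deg)
sym_norm = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]

print(row_norm.sum(axis=1))                 # every row sums to 1
print(np.allclose(sym_norm, sym_norm.T))    # symmetric matrix
print(row_norm[1, 0], sym_norm[1, 0])       # weight of the hub as seen by a leaf
```

Row normalization weights a message only by the receiver's degree, while symmetric normalization also down-weights messages from high-degree senders such as the hub. Without any normalization, repeated multiplication by $A + I$ can make feature magnitudes grow with node degree.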

## Part 10: Exercises

### Exercise 1: Implement Attention-based Aggregation

Instead of simple mean/sum aggregation, implement attention-based aggregation where neighbors with higher attention weights have more influence.

```python
class AttentionAggregation:
    '''
    Implement attention-based message aggregation.
    
    TODO:
    1. For each node, compute attention weights for each neighbor
    2. Weights should be learned or based on feature similarity
    3. Aggregate messages as weighted sum using attention weights
    4. Compare with uniform aggregation
    '''
    pass
```

### Exercise 2: Explore the Over-smoothing Trade-off

Train GNN models with varying depths (1, 2, 3, 4, 5, 6, 7, 8 layers) and measure:
- Training accuracy
- Validation accuracy
- Feature diversity (variance of node representations)

Plot the relationship between depth and performance.
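
As a starting point for the feature diversity measurement, one possible metric is the variance across nodes, averaged over feature dimensions. The helper name `feature_diversity` is hypothetical, and other measures (such as mean pairwise distance) would work as well.

```python
import numpy as np

def feature_diversity(H):
    """Variance across nodes, averaged over feature dimensions.

    H: array-like of shape (num_nodes, num_features).
    """
    H = np.asarray(H, dtype=float)
    return float(H.var(axis=0).mean())

identical = np.ones((5, 4))                 # all nodes share one representation
distinct = np.arange(20.0).reshape(5, 4)    # nodes differ along every dimension

print(feature_diversity(identical))  # 0.0
print(feature_diversity(distinct))   # 32.0
```

A value near zero means the node representations have collapsed toward a single vector. For a PyTorch tensor, convert it first, for example with `H.detach().numpy()`.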

### Exercise 3: Implement Skip Connections

Add residual connections to help mitigate over-smoothing:

```python
class GNNWithSkipConnections(nn.Module):
    '''
    Implement GNN with skip/residual connections.
    
    Modify the forward pass to:
    h_new = h_old + GNN_layer(h_old)
    
    This helps maintain diversity in node representations.
    '''
    pass
```

### Exercise 4: Design Your Own Graph Dataset

Create a meaningful graph dataset where:
1. Nodes represent entities (people, papers, molecules, etc.)
2. Edges represent relationships
3. Node features have semantic meaning
4. Labels reflect an interesting classification task

Train the GNN on your dataset and analyze the results.

### Exercise 5: Analyze Aggregation Function Impact

For different graph structures (dense, sparse, community-based):
1. Train GNN models with different aggregation functions
2. Measure convergence speed
3. Compare final accuracy
4. Explain why certain aggregations work better for specific graphs

### Exercise 6: Challenge - Implement GNN from Scratch

Implement a custom GNN class that clearly shows all message passing steps without using pre-built modules where possible.

In [None]:
# Challenge: Implement a custom GNN from scratch
# Fill in the TODO sections

class CustomGNNFromScratch(nn.Module):
    """
    A GNN implementation that clearly shows all message passing steps.
    
    TODO: Implement the following methods:
    1. message_fn: Compute messages from source nodes
    2. aggregate_fn: Aggregate messages at target nodes
    3. update_fn: Update node representation using aggregated messages
    4. forward: Orchestrate the message passing process
    """
    
    def __init__(self, in_features, out_features):
        super(CustomGNNFromScratch, self).__init__()
        self.W_msg = nn.Linear(in_features, out_features)
        self.W_upd = nn.Linear(in_features + out_features, out_features)
    
    def message_fn(self, node_features):
        """
        TODO: Implement message function
        Input: node_features (source node features)
        Output: message (transformed feature vector)
        """
        return self.W_msg(node_features)
    
    def aggregate_fn(self, messages):
        """
        TODO: Implement aggregation function
        Input: messages (list of message tensors from neighbors)
        Output: aggregated_message (single tensor)
        """
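        if len(messages) == 0:
            # No neighbors: return a zero message so update_fn can still concatenate
            return torch.zeros(self.W_msg.out_features)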
        return torch.mean(torch.stack(messages), dim=0)
    
    def update_fn(self, node_feature, aggregated_message):
        """
        TODO: Implement update function
        Input: node_feature (original node feature), aggregated_message
        Output: updated_feature (new node representation)
        """
        combined = torch.cat([node_feature, aggregated_message], dim=-1)
        return self.W_upd(combined)
    
    def forward(self, X, adj):
        """
        TODO: Implement forward pass using message passing
        
        Steps:
        1. For each node, compute messages to send (message_fn)
        2. Pass messages along edges (matrix multiplication with adj)
        3. At each node, aggregate received messages (aggregate_fn)
        4. Update node features using aggregated messages (update_fn)
        
        Input: X (node features), adj (adjacency matrix)
        Output: updated_X (new node features)
        """
        messages = self.message_fn(X)
        aggregated = torch.matmul(adj, messages)
        updated = self.update_fn(X, aggregated)
        return updated


print('CustomGNNFromScratch class template created!')
print('Now it\'s your turn to complete the TODO sections!')

## Part 11: References and Further Reading

### Key Papers
1. **"A Comprehensive Survey on Graph Neural Networks" (Wu et al., 2021)**
   - Comprehensive overview of GNN architectures and applications

2. **"Semi-Supervised Classification with Graph Convolutional Networks" (Kipf & Welling, 2017)**
   - Foundational GCN paper introducing normalized message passing

3. **"Graph Attention Networks" (Veličković et al., 2018)**
   - Introduces attention mechanism for aggregation

4. **"Deeper Insights into Graph Convolutional Networks for Semi-Supervised Learning" (Li et al., 2018)**
   - Analyzes graph convolution as Laplacian smoothing and identifies the over-smoothing problem

### Related Concepts
- **Graph Convolutional Networks (GCN)**: Specific instantiation of message passing motivated by spectral graph theory
- **Graph Attention Networks (GAT)**: Attention-based aggregation
- **GraphSAGE**: Inductive framework that samples a fixed-size set of neighbors and applies learned aggregators
- **Message Passing Neural Networks (MPNN)**: General framework for message passing

### Applications
- Node classification (recommendation systems, social network analysis)
- Graph classification (molecular property prediction, network clustering)
- Link prediction (knowledge graph completion, friend suggestion)
- Graph generation (molecule design, architecture discovery)