# Lab E.3: Graph Attention Networks (GAT)

**Module:** E - Graph Neural Networks  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐⭐ (Advanced)

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand the limitations of GCN's equal neighbor weighting
- [ ] Implement the Graph Attention mechanism from scratch
- [ ] Build a multi-head Graph Attention Network
- [ ] Train GAT on Cora and compare to GCN
- [ ] Visualize attention weights to understand what the model learns
- [ ] Analyze which edges get high attention scores

---

## 📚 Prerequisites

- Completed: Lab E.2 (GCN from Scratch)
- Knowledge of: Attention mechanisms (helpful but not required)

---

## 🌍 Real-World Context

**Why do we need attention on graphs?**

In GCN, all neighbors contribute equally. But in reality:
- In social networks, your best friend's opinion matters more than an acquaintance's
- In citation networks, a highly relevant paper is more important than a tangential one
- In molecules, certain atom connections are more important for chemical properties

**Graph Attention Networks learn to weight neighbors differently** based on their relevance to the task!

---

## 🧒 ELI5: What Is Graph Attention?

> **Imagine you're at a party asking for movie recommendations:**
>
> With GCN (equal weighting):
> - You ask 10 people and give everyone's opinion equal weight
> - Your movie-buff friend counts the same as your uncle who hasn't been to a cinema in years
>
> With GAT (attention weighting):
> - You **learn** who gives good recommendations
> - Your movie-buff friend gets 50% weight, uncle gets 2%
> - The weights are learned from experience!
>
> **The magic:** GAT doesn't just average neighbors - it learns which neighbors are important for each specific task.
>
> **In AI terms:** For each pair of connected nodes, GAT computes an "attention score" that represents how important node j is to node i. These scores are learned during training!

---

## Part 1: Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import add_self_loops, softmax
import time

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load Cora
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0].to(device)

print(f"\nCora: {data.num_nodes} nodes, {data.num_edges} directed edges (each undirected link stored twice)")
print(f"Features: {dataset.num_features}, Classes: {dataset.num_classes}")

---

## Part 2: Understanding the Attention Mechanism

### 2.1 The GAT Formula

For each pair of connected nodes $(i, j)$, GAT computes:

**Step 1: Compute raw attention scores**
$$e_{ij} = \text{LeakyReLU}(\mathbf{a}^T [\mathbf{W}h_i \| \mathbf{W}h_j])$$

Where:
- $\mathbf{W}$ = learnable weight matrix
- $\mathbf{a}$ = learnable attention vector
- $\|$ = concatenation
- LeakyReLU = nonlinearity with a small negative slope (0.2 in the paper), so negative scores are scaled down rather than zeroed

**Step 2: Normalize with softmax over neighbors**
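$$\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}$$

Here $\mathcal{N}(i)$ includes $i$ itself (a self-loop), so each node also attends to its own features.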

**Step 3: Weighted aggregation**
$$h'_i = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} \mathbf{W}h_j\right)$$
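
The three steps can be traced numerically on a toy star graph. This is a minimal sketch, not the layer built later: `W` and `a` are random rather than learned, the sizes (2 input features, 3 output features) are arbitrary, node 0 is the receiving node $i$ with neighbors 1, 2, 3 plus itself, and $\sigma$ is assumed to be ELU (the activation used by the model later in this lab).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy setup (assumed sizes): 4 nodes with 2 features, projected to 3 features
h = torch.tensor([[1.0, 0.5],   # node 0 (the receiving node i)
                  [0.8, 0.6],   # node 1
                  [0.2, 0.1],   # node 2
                  [0.9, 0.4]])  # node 3
W = torch.randn(2, 3)           # shared projection W (stored as [in, out]); random, not learned
a = torch.randn(6)              # attention vector a of length 2 * out; random, not learned

i = 0
neighbors = [0, 1, 2, 3]        # N(i) includes i itself (self-loop)

Wh = h @ W                      # [4, 3]: projected features W h_j for every node

# Step 1: e_ij = LeakyReLU(a^T [W h_i || W h_j]) for each neighbor j
e = torch.stack([
    F.leaky_relu(a @ torch.cat([Wh[i], Wh[j]]), negative_slope=0.2)
    for j in neighbors
])

# Step 2: softmax over the neighborhood of i
alpha = torch.softmax(e, dim=0)

# Step 3: attention-weighted sum of projected neighbor features, then sigma = ELU
h_new = F.elu((alpha.unsqueeze(1) * Wh[neighbors]).sum(dim=0))

print("alpha:", alpha)
print("h'_0:", h_new)
```

Whatever values `W` and `a` take, the four weights in `alpha` are positive and sum to 1; training only changes how that unit of weight is split among the neighbors.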

### 🧒 ELI5: Breaking Down Attention

> Think of it as a **compatibility score**:
>
> 1. Transform features: "Convert everyone's profile to a common format"
> 2. Compute compatibility: "How well do these two profiles match?"
> 3. Normalize: "Turn scores into percentages (must sum to 100%)"
> 4. Aggregate: "Weight everyone's input by their importance percentage"

In [None]:
# Visualize attention mechanism with a simple example

def visualize_attention_concept():
    """
    Demonstrate attention computation on a tiny graph.
    """
    print("🔍 ATTENTION MECHANISM EXAMPLE")
    print("=" * 50)
    
    # Simple graph: Node 0 connected to nodes 1, 2, 3
    # Node 0 features
    h0 = torch.tensor([1.0, 0.5])
    
    # Neighbor features
    h1 = torch.tensor([0.8, 0.6])  # Similar to h0
    h2 = torch.tensor([0.2, 0.1])  # Much smaller than h0
    h3 = torch.tensor([1.0, 0.6])  # Very similar to h0
    
    print("Node 0 features:", h0.tolist())
    print("Neighbor 1 features:", h1.tolist(), "(similar)")
    print("Neighbor 2 features:", h2.tolist(), "(different)")
    print("Neighbor 3 features:", h3.tolist(), "(very similar)")
    
    # Toy score: a plain dot product, used only for intuition. GAT itself uses
    # the learned score LeakyReLU(a^T [Wh_i || Wh_j]). A dot product also grows
    # with vector length, so it is not a pure "similarity" measure.
    print("\n📊 Computing attention scores (dot product):")
    
    e01 = torch.dot(h0, h1)
    e02 = torch.dot(h0, h2)
    e03 = torch.dot(h0, h3)
    
    print(f"  e(0,1) = {e01.item():.3f}")
    print(f"  e(0,2) = {e02.item():.3f}")
    print(f"  e(0,3) = {e03.item():.3f}")
    
    # Softmax normalization (stack the 0-d score tensors into one vector)
    scores = torch.stack([e01, e02, e03])
    attention = F.softmax(scores, dim=0)
    
    print("\n📊 Normalized attention weights (softmax):")
    print(f"  α(0,1) = {attention[0].item():.3f} ({attention[0].item()*100:.1f}%)")
    print(f"  α(0,2) = {attention[1].item():.3f} ({attention[1].item()*100:.1f}%)")
    print(f"  α(0,3) = {attention[2].item():.3f} ({attention[2].item()*100:.1f}%)")
    
    print("\n💡 Notice: higher-scoring neighbors get MORE attention!")
    print(f"   Node 3 (very similar) gets ~{attention[2].item()*100:.0f}%")
    print(f"   Node 2 (different) gets only ~{attention[1].item()*100:.0f}%")

visualize_attention_concept()

---

## Part 3: Implementing GAT from Scratch

### 3.1 Single-Head Attention Layer

In [None]:
class GATLayerScratch(nn.Module):
    """
    Graph Attention Layer - Single Head Implementation.
    
    Implements:
        e_ij = LeakyReLU(a^T [Wh_i || Wh_j])
        α_ij = softmax_j(e_ij)
        h'_i = Σ_j α_ij * Wh_j
        where i is the receiving node and j ranges over its neighbors (incl. i)
    
    Args:
        in_channels: Input feature dimension
        out_channels: Output feature dimension
        dropout: Dropout probability for attention coefficients
        negative_slope: LeakyReLU negative slope
    
    Example:
        >>> layer = GATLayerScratch(1433, 64)
        >>> out, attention = layer(x, edge_index, return_attention=True)
    """
    
    def __init__(self, in_channels: int, out_channels: int, 
                 dropout: float = 0.6, negative_slope: float = 0.2):
        super().__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dropout = dropout
        self.negative_slope = negative_slope
        
        # Linear transformation W: [in_channels] -> [out_channels]
        self.W = nn.Linear(in_channels, out_channels, bias=False)
        
        # Attention mechanism: a ∈ R^(2*out_channels)
        # We split a = [a_left; a_right] for efficiency
        self.a_left = nn.Parameter(torch.empty(out_channels, 1))
        self.a_right = nn.Parameter(torch.empty(out_channels, 1))
        
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        
        self.reset_parameters()
    
    def reset_parameters(self):
        """Initialize parameters with Xavier/Glorot."""
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.xavier_uniform_(self.a_left)
        nn.init.xavier_uniform_(self.a_right)
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, 
                return_attention: bool = False):
        """
        Forward pass.
        
        Args:
            x: Node features [num_nodes, in_channels]
            edge_index: Graph edges [2, num_edges]
            return_attention: If True, also return attention weights
            
        Returns:
            Updated node features [num_nodes, out_channels]
            (Optional) Tuple (edge_index_with_self_loops, alpha), where alpha has
            shape [num_edges + num_nodes]: one weight per edge, incl. self-loops
        """
        num_nodes = x.size(0)
        
        # Step 1: Add self-loops
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        
        # Step 2: Linear transformation: Wh for all nodes
        Wh = self.W(x)  # [num_nodes, out_channels]
        
        # Step 3: Compute attention scores for each edge
        src, dst = edge_index  # Messages flow from src (neighbor j) to dst (receiving node i)
        
        # e_ij = LeakyReLU(a_left^T Wh_i + a_right^T Wh_j), with i = dst and j = src.
        # This equals a^T [Wh_i || Wh_j] for a = [a_left; a_right], but the two
        # dot products are computed once per node instead of once per edge
        e_left = (Wh @ self.a_left).squeeze(-1)   # [num_nodes], score of Wh_i
        e_right = (Wh @ self.a_right).squeeze(-1)  # [num_nodes], score of Wh_j
        
        # Attention scores for each edge
        e = e_left[dst] + e_right[src]  # [num_edges]
        e = self.leaky_relu(e)
        
        # Step 4: Softmax over neighbors (for each destination node)
        # PyG's softmax groups by destination node
        alpha = softmax(e, dst, num_nodes=num_nodes)
        
        # Apply dropout to attention weights
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        
        # Step 5: Weighted aggregation
        out = torch.zeros_like(Wh)
        src_features = Wh[src] * alpha.view(-1, 1)  # Weighted source features
        out.scatter_add_(0, dst.view(-1, 1).expand_as(src_features), src_features)
        
        if return_attention:
            return out, (edge_index, alpha)
        return out
    
    def __repr__(self):
        return f'GATLayerScratch({self.in_channels}, {self.out_channels})'
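
The comments in `forward` claim that splitting $\mathbf{a}$ into two halves gives the same score as the concatenation form $\mathbf{a}^T[\mathbf{W}h_i \| \mathbf{W}h_j]$. A quick numerical check with random tensors (the size `C = 8` is arbitrary):

```python
import torch

torch.manual_seed(0)
C = 8                                    # out_channels (arbitrary for this check)
Wh_i, Wh_j = torch.randn(C), torch.randn(C)
a = torch.randn(2 * C)                   # full attention vector a

# Concatenation form: a^T [Wh_i || Wh_j]
concat_form = a @ torch.cat([Wh_i, Wh_j])

# Split form: the first half of a scores Wh_i, the second half scores Wh_j
a_left, a_right = a[:C], a[C:]
split_form = a_left @ Wh_i + a_right @ Wh_j

print(concat_form.item(), split_form.item())  # equal up to float rounding
```

The split form is what makes the layer cheap: it needs one scalar per node (two matrix-vector products over all nodes) and then a gather per edge, instead of building a 2C-dimensional concatenated vector for every edge.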

In [None]:
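# Test the layer
layer = GATLayerScratch(dataset.num_features, 64).to(device)
print(f"Layer: {layer}")

# Forward pass with attention. eval() disables attention dropout, so the
# returned weights are the exact softmax outputs
layer.eval()
with torch.no_grad():
    out, (edge_idx, attn) = layer(data.x, data.edge_index, return_attention=True)

print(f"\nInput shape: {data.x.shape}")
print(f"Output shape: {out.shape}")
print(f"Attention weights shape: {attn.shape} (edges + one self-loop per node)")

# Each node's incoming attention weights should sum to 1
per_node_sum = torch.zeros(data.num_nodes, device=attn.device).index_add_(0, edge_idx[1], attn)
print(f"Per-node attention sums: min={per_node_sum.min().item():.4f}, "
      f"max={per_node_sum.max().item():.4f} (both should be 1.0000)")
print("\n✅ GAT layer working!")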

### 3.2 Multi-Head Attention

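As in Transformers, GAT can run several independent attention "heads", each with its own $\mathbf{W}^k$ and $\mathbf{a}^k$. The GAT paper found this stabilizes training, and different heads are free to weight the same neighbors differently.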

**Multi-head attention:**
$$h'_i = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha^k_{ij} \mathbf{W}^k h_j\right)$$

Where $\|$ means concatenation of $K$ heads.
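
For the final (prediction) layer, the GAT paper averages the heads instead of concatenating them and applies the nonlinearity (softmax, for classification) after the average. This is what `concat=False` selects in the layer below:

$$h'_i = \sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j \in \mathcal{N}(i)} \alpha^k_{ij} \mathbf{W}^k h_j\right)$$

In the concatenated case, $\sigma$ acts element by element, so applying it to each head and then concatenating gives the same result as concatenating first. That is why the model later in this lab applies ELU once, after the multi-head layer.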

In [None]:
class MultiHeadGATLayer(nn.Module):
    """
    Multi-Head Graph Attention Layer.
    
    Runs multiple attention heads in parallel, then concatenates
    (or averages) the results.
    
    Args:
        in_channels: Input feature dimension
        out_channels: Output dimension PER HEAD
        heads: Number of attention heads (default: 8)
        concat: If True, concatenate heads. If False, average them.
        dropout: Dropout probability
    
    Output dimension:
        - If concat=True: heads * out_channels
        - If concat=False: out_channels
    """
    
    def __init__(self, in_channels: int, out_channels: int, 
                 heads: int = 8, concat: bool = True, dropout: float = 0.6):
        super().__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.dropout = dropout
        
        # Create multiple attention heads
        self.attention_heads = nn.ModuleList([
            GATLayerScratch(in_channels, out_channels, dropout=dropout)
            for _ in range(heads)
        ])
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                return_attention: bool = False):
        """
        Forward pass with multiple heads.
        """
        # Run each head (one after another; every head has its own W and a)
        if return_attention:
            head_outputs = []
            head_attentions = []
            for head in self.attention_heads:
                out, attn = head(x, edge_index, return_attention=True)
                head_outputs.append(out)
                head_attentions.append(attn)
        else:
            head_outputs = [head(x, edge_index) for head in self.attention_heads]
        
        # Combine heads
        if self.concat:
            # Concatenate: [num_nodes, heads * out_channels]
            out = torch.cat(head_outputs, dim=-1)
        else:
            # Average: [num_nodes, out_channels]
            out = torch.stack(head_outputs, dim=0).mean(dim=0)
        
        if return_attention:
            return out, head_attentions
        return out
    
    def __repr__(self):
        return f'MultiHeadGATLayer({self.in_channels}, {self.out_channels}, heads={self.heads})'
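
The loop above runs the heads one after another, and every head re-adds self-loops. Library layers such as PyG's `GATConv` instead project all heads with a single weight matrix and carry a head dimension (`[num_nodes, heads, out_channels]`). The sketch below shows that idea with plain PyTorch ops. `BatchedGATHeads` is a hypothetical name, not part of this lab or of PyG; dropout and bias are omitted, the per-node softmax is hand-written, and it assumes `edge_index` has no self-loops yet.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchedGATHeads(nn.Module):
    """Hypothetical sketch: all K heads computed with one projection.

    Returns concatenated features [num_nodes, heads * out_channels] and
    attention weights [num_edges + num_nodes, heads].
    """

    def __init__(self, in_channels, out_channels, heads, negative_slope=0.2):
        super().__init__()
        self.heads, self.out_channels = heads, out_channels
        self.negative_slope = negative_slope
        # One weight matrix holds W^1..W^K side by side
        self.W = nn.Linear(in_channels, heads * out_channels, bias=False)
        # One attention vector per head, split into receiving-node / neighbor halves
        self.a_dst = nn.Parameter(torch.empty(heads, out_channels))
        self.a_src = nn.Parameter(torch.empty(heads, out_channels))
        nn.init.xavier_uniform_(self.a_dst)
        nn.init.xavier_uniform_(self.a_src)

    def forward(self, x, edge_index):
        n = x.size(0)
        loops = torch.arange(n, device=x.device)
        edge_index = torch.cat([edge_index, torch.stack([loops, loops])], dim=1)
        src, dst = edge_index

        Wh = self.W(x).view(n, self.heads, self.out_channels)   # [N, K, C]
        score_dst = (Wh * self.a_dst).sum(dim=-1)                # [N, K]
        score_src = (Wh * self.a_src).sum(dim=-1)                # [N, K]
        e = F.leaky_relu(score_dst[dst] + score_src[src], self.negative_slope)  # [E, K]

        # Softmax over each node's incoming edges, separately for every head
        e_max = torch.full((n, self.heads), float("-inf"), device=x.device)
        e_max = e_max.scatter_reduce(0, dst.unsqueeze(1).expand_as(e), e, reduce="amax")
        exp_e = torch.exp(e - e_max[dst])                        # max-shift for stability
        denom = torch.zeros(n, self.heads, device=x.device).index_add_(0, dst, exp_e)
        alpha = exp_e / denom[dst]                               # [E, K]

        # Attention-weighted sum of neighbor features into each receiving node
        messages = alpha.unsqueeze(-1) * Wh[src]                 # [E, K, C]
        out = torch.zeros_like(Wh).index_add_(0, dst, messages)  # [N, K, C]
        return out.reshape(n, self.heads * self.out_channels), alpha
```

Usage mirrors `MultiHeadGATLayer(..., concat=True)`: `out, alpha = BatchedGATHeads(1433, 8, heads=8)(x, edge_index)`. The maths per head is the same; the difference is that one large matrix multiply replaces `heads` small ones, which usually suits a GPU better.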

In [None]:
# Test multi-head attention
mh_layer = MultiHeadGATLayer(dataset.num_features, 8, heads=8, concat=True).to(device)
print(f"Layer: {mh_layer}")

out = mh_layer(data.x, data.edge_index)
print(f"\nInput: {data.x.shape}")
print(f"Output: {out.shape} (8 heads × 8 dims = 64)")
print("\n✅ Multi-head GAT working!")

---

## Part 4: Building the Complete GAT Model

### Activation Functions: ELU vs ReLU

The GAT paper uses **ELU** (Exponential Linear Unit) instead of ReLU:

| Activation | Formula | Advantage |
|------------|---------|-----------|
| ReLU | max(0, x) | Simple, fast |
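| ELU | x if x > 0, else α(e^x - 1) | Smooth, can output negatives |

**Why ELU for GAT?** The GAT paper uses ELU after its hidden layer, and this lab follows that choice. Unlike ReLU, ELU can output negative values (down to -α) and keeps a nonzero gradient for negative inputs, so hidden units cannot get stuck at zero. Treat it as the paper's convention rather than a requirement; swapping in ReLU is a reasonable experiment.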

```python
# In PyTorch:
x = F.relu(x)   # Standard ReLU
x = F.elu(x)    # ELU (default α=1.0)
```

In [None]:
class GAT(nn.Module):
    """
    Two-layer Graph Attention Network for node classification.
    
    Architecture:
        Input → Multi-Head GAT (8 heads, concat) → ELU → Dropout
              → GAT (1 head, no concat) → Output
    
    Args:
        num_features: Input feature dimension
        num_classes: Number of output classes
        hidden_dim: Hidden dimension per head (default: 8)
        heads: Number of attention heads (default: 8)
        dropout: Dropout probability (default: 0.6)
    """
    
    def __init__(self, num_features: int, num_classes: int,
                 hidden_dim: int = 8, heads: int = 8, dropout: float = 0.6):
        super().__init__()
        
        # Layer 1: Multi-head attention with concatenation
        self.gat1 = MultiHeadGATLayer(
            in_channels=num_features,
            out_channels=hidden_dim,
            heads=heads,
            concat=True,
            dropout=dropout
        )
        
        # Layer 2: Single-head attention (for classification)
        self.gat2 = MultiHeadGATLayer(
            in_channels=hidden_dim * heads,  # Output of layer 1
            out_channels=num_classes,
            heads=1,
            concat=False,
            dropout=dropout
        )
        
        self.dropout = dropout
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                return_attention: bool = False):
        """
        Forward pass.
        """
        # Input dropout
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Layer 1 + ELU activation
        if return_attention:
            x, attn1 = self.gat1(x, edge_index, return_attention=True)
        else:
            x = self.gat1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Layer 2 (no activation - raw logits)
        if return_attention:
            x, attn2 = self.gat2(x, edge_index, return_attention=True)
            return x, (attn1, attn2)
        
        x = self.gat2(x, edge_index)
        return x
    
    def get_embeddings(self, x: torch.Tensor, edge_index: torch.Tensor):
        """Get intermediate embeddings (after layer 1)."""
        x = self.gat1(x, edge_index)
        return F.elu(x)

In [None]:
# Create model
model = GAT(
    num_features=dataset.num_features,
    num_classes=dataset.num_classes,
    hidden_dim=8,
    heads=8,
    dropout=0.6
).to(device)

print("GAT Model Architecture:")
print("=" * 50)
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")
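
The printed total can be checked by hand. Each scratch head has a bias-free `W` (in × out) plus `a_left` and `a_right` (out values each); layer 1 has 8 such heads on Cora's 1,433 features, and layer 2 has one head mapping 8 × 8 = 64 features to 7 classes. The helper below is hypothetical and only mirrors the classes defined in this notebook:

```python
def scratch_gat_param_count(num_features, num_classes, hidden_dim=8, heads=8):
    """Parameter count of this lab's scratch GAT (GATLayerScratch has no bias)."""
    per_head_l1 = num_features * hidden_dim + 2 * hidden_dim          # W + a_left + a_right
    layer1 = heads * per_head_l1
    layer2 = (hidden_dim * heads) * num_classes + 2 * num_classes     # one head, concat=False
    return layer1 + layer2

print(scratch_gat_param_count(1433, 7))  # → 92302
```

Almost all of these (91,712) sit in the first projection, 1,433 × 64; the attention vectors add only 2 × out values per head.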

---

## Part 5: Training GAT

In [None]:
def train(model, data, optimizer):
    """Train for one epoch."""
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(model, data):
    """Evaluate on train/val/test."""
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = pred[mask].eq(data.y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return tuple(accs)

In [None]:
# Training configuration
model = GAT(
    num_features=dataset.num_features,
    num_classes=dataset.num_classes,
    hidden_dim=8,
    heads=8,
    dropout=0.6
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

# Training history
history = {'loss': [], 'train_acc': [], 'val_acc': [], 'test_acc': []}

print("Training GAT on Cora...")
print("=" * 60)

best_val_acc = 0
best_test_acc = 0  # test accuracy at the best-validation epoch
best_epoch = 0
start_time = time.time()

for epoch in range(300):
    loss = train(model, data, optimizer)
    train_acc, val_acc, test_acc = evaluate(model, data)
    
    history['loss'].append(loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['test_acc'].append(test_acc)
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc
        best_epoch = epoch
    
    if epoch % 50 == 0 or epoch == 299:
        print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}")

train_time = time.time() - start_time
print("\n" + "=" * 60)
print(f"🎉 Best validation accuracy: {best_val_acc:.4f} at epoch {best_epoch}")
print(f"📊 Test accuracy at best val: {best_test_acc:.4f}")
print(f"⏱️ Training time: {train_time:.1f} seconds")

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss
axes[0].plot(history['loss'], color='steelblue', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(history['train_acc'], label='Train', linewidth=2)
axes[1].plot(history['val_acc'], label='Validation', linewidth=2)
axes[1].plot(history['test_acc'], label='Test', linewidth=2)
axes[1].axhline(y=0.81, color='red', linestyle='--', label='81% target')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Accuracy Curves')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## Part 6: Visualizing Attention Weights

This is the exciting part - let's see which neighbors the model pays attention to!

In [None]:
# Get attention weights from trained model
model.eval()
with torch.no_grad():
    _, (layer1_attn, layer2_attn) = model(data.x, data.edge_index, return_attention=True)

# Layer 1 has 8 heads, let's look at the first head
edge_index_with_loops, alpha = layer1_attn[0]  # First head
edge_index_np = edge_index_with_loops.cpu().numpy()
alpha_np = alpha.cpu().numpy()

print(f"Edge index shape: {edge_index_np.shape}")
print(f"Attention weights shape: {alpha_np.shape}")
print(f"\nAttention statistics:")
print(f"  Min: {alpha_np.min():.4f}")
print(f"  Max: {alpha_np.max():.4f}")
print(f"  Mean: {alpha_np.mean():.4f}")
print(f"  Std: {alpha_np.std():.4f}")
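
One caution when reading these statistics: because every node's incoming weights sum to 1, the mean over all edges is fixed at num_nodes / num_edges (self-loops included), no matter what the model learned. For Cora (2,708 nodes, 10,556 directed edges plus 2,708 self-loops) that is 2708 / 13264 ≈ 0.204; only the spread carries information. A toy check with random scores and a hand-written per-node softmax (the graph below is made up):

```python
import torch

torch.manual_seed(0)
num_nodes = 5
# Toy directed edges (src -> dst); the last five are one self-loop per node
src = torch.tensor([1, 2, 3, 0, 4, 0, 1, 2, 3, 4])
dst = torch.tensor([0, 0, 0, 1, 1, 0, 1, 2, 3, 4])
e = torch.randn(src.numel())                    # arbitrary raw scores

# Per-destination softmax: exp(e) divided by the sum over each node's incoming edges
exp_e = torch.exp(e)
denom = torch.zeros(num_nodes).index_add_(0, dst, exp_e)
alpha = exp_e / denom[dst]

print(alpha.mean().item(), num_nodes / src.numel())  # equal up to float error
```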

In [None]:
# Visualize attention for a specific node
import networkx as nx

def visualize_node_attention(node_id, data, edge_index, alpha, top_k=5):
    """
    Visualize attention weights for a specific node.
    
    Args:
        node_id: The node to analyze
        data: PyG data object
        edge_index: Edge index tensor [2, num_edges]
        alpha: Attention weights [num_edges]
        top_k: Number of top-attended neighbors to list
    """
    # Find edges where node_id is the destination (receiving attention)
    edge_index_np = edge_index.cpu().numpy() if torch.is_tensor(edge_index) else edge_index
    alpha_np = alpha.cpu().numpy() if torch.is_tensor(alpha) else alpha
    
    mask = edge_index_np[1] == node_id
    neighbors = edge_index_np[0][mask]
    neighbor_attentions = alpha_np[mask]
    
    # Sort by attention
    sorted_idx = np.argsort(neighbor_attentions)[::-1]
    
    print(f"\n🔍 Attention Analysis for Node {node_id}")
    print(f"   Label: {data.y[node_id].item()}")
    print(f"   Number of neighbors (incl. self-loop): {len(neighbors)}")
    print(f"\n   Top-{top_k} Attended Neighbors:")
    print("   " + "-" * 40)
    
    labels = data.y.cpu().numpy()
    node_label = labels[node_id]
    
    for i, idx in enumerate(sorted_idx[:top_k]):
        neighbor = neighbors[idx]
        attn = neighbor_attentions[idx]
        neighbor_label = labels[neighbor]
        same_class = "✓" if neighbor_label == node_label else "✗"
        print(f"   {i+1}. Node {neighbor}: α={attn:.4f} (class {neighbor_label}) {same_class}")
    
    # Create visualization
    fig, ax = plt.subplots(figsize=(10, 8))
    
    # Create subgraph with the node and its neighbors (the self-loop puts
    # node_id among its own neighbors, so skip that duplicate)
    all_nodes = [node_id] + [n for n in neighbors if n != node_id]
    G = nx.DiGraph()
    G.add_nodes_from(all_nodes)
    
    for neighbor, attn in zip(neighbors, neighbor_attentions):
        G.add_edge(neighbor, node_id, weight=attn)
    
    pos = nx.spring_layout(G, seed=42, k=2)
    
    # Draw nodes
    node_colors = [labels[n] for n in all_nodes]
    node_sizes = [1000 if n == node_id else 500 for n in all_nodes]
    
    nx.draw_networkx_nodes(G, pos, nodelist=all_nodes, 
                          node_color=node_colors, cmap=plt.cm.Set3,
                          node_size=node_sizes, alpha=0.8)
    
    # Draw edges with width proportional to attention
    edges = G.edges()
    weights = [G[u][v]['weight'] * 5 for u, v in edges]  # Scale for visibility
    
    nx.draw_networkx_edges(G, pos, edgelist=edges, width=weights,
                          alpha=0.7, edge_color='gray',
                          arrows=True, arrowsize=15)
    
    # Labels
    nx.draw_networkx_labels(G, pos, font_size=10)
    
    ax.set_title(f"Attention to Node {node_id}\n(Edge width = attention weight)")
    ax.axis('off')
    plt.tight_layout()
    plt.show()
    
    return neighbors, neighbor_attentions

# Analyze node 0 (try other node IDs as well)
visualize_node_attention(0, data, edge_index_with_loops, alpha)

In [None]:
# Analyze: Do nodes pay more attention to same-class neighbors?

labels = data.y.cpu().numpy()
edge_index_np = edge_index_with_loops.cpu().numpy()
src_np, dst_np = edge_index_np

# Self-loops always join a node to itself (same class), so leave them out
is_self_loop = src_np == dst_np

# Weights are normalized per receiving node, so low-degree nodes give each
# neighbor a larger share. Rescale by in-degree: 1.0 = equal weighting (plain mean)
deg = np.bincount(dst_np, minlength=data.num_nodes)  # includes the self-loop
rel_alpha = alpha_np * deg[dst_np]

same_class = (labels[src_np] == labels[dst_np]) & ~is_self_loop
diff_class = labels[src_np] != labels[dst_np]

# Compare relative attention for same-class vs different-class edges
same_class_attn = rel_alpha[same_class].mean()
diff_class_attn = rel_alpha[diff_class].mean()

print("\n📊 ATTENTION PATTERN ANALYSIS (layer 1, head 0, self-loops excluded)")
print("=" * 50)
print(f"Mean relative attention to SAME-class neighbors:      {same_class_attn:.4f}")
print(f"Mean relative attention to DIFFERENT-class neighbors: {diff_class_attn:.4f}")
print("(1.0 = the weight an equal split over the node's neighbors would give)")
print(f"\nRatio (same / different): {same_class_attn / diff_class_attn:.2f}x")

# Plot distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram
axes[0].hist(rel_alpha[same_class], bins=50, alpha=0.7, label='Same class', density=True)
axes[0].hist(rel_alpha[diff_class], bins=50, alpha=0.7, label='Different class', density=True)
axes[0].set_xlabel('Relative attention (α × in-degree)')
axes[0].set_ylabel('Density')
axes[0].set_title('Distribution of Attention Weights')
axes[0].legend()

# Box plot (tick labels set separately so this works across Matplotlib versions)
axes[1].boxplot([rel_alpha[same_class], rel_alpha[diff_class]])
axes[1].set_xticks([1, 2])
axes[1].set_xticklabels(['Same Class', 'Different Class'])
axes[1].set_ylabel('Relative attention (α × in-degree)')
axes[1].set_title('Attention by Class Relationship')

plt.tight_layout()
plt.show()

if same_class_attn > diff_class_attn:
    print("\n💡 In this head, same-class neighbors get more attention than different-class ones,")
    print("   which helps propagate label-relevant information for classification.")
else:
    print("\n💡 In this head, attention does not favor same-class neighbors; try the other heads.")
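
To judge that ratio, it helps to know how homophilous the graph already is: the edge homophily is the fraction of (non-self-loop) edges whose endpoints share a label. A minimal sketch; `edge_homophily` is a hypothetical helper name, and the toy graph is made up:

```python
import numpy as np

def edge_homophily(edge_index, labels):
    """Fraction of edges (self-loops excluded) that connect same-label nodes."""
    src, dst = np.asarray(edge_index)
    labels = np.asarray(labels)
    keep = src != dst
    return float(np.mean(labels[src[keep]] == labels[dst[keep]]))

# Toy graph: 4 nodes with classes [0, 0, 1, 1]; 3 of its 4 non-loop edges stay within a class
toy_edges = np.array([[0, 1, 2, 1, 3],
                      [1, 0, 3, 2, 3]])
print(edge_homophily(toy_edges, [0, 0, 1, 1]))  # → 0.75
```

On Cora, `edge_homophily(data.edge_index.cpu().numpy(), data.y.cpu().numpy())` should be high (Cora is commonly described as strongly homophilous, around 0.8). In other words, most neighbors already share a node's label, which is one reason plain GCN averaging works well on it.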

---

## Part 7: Comparison with GCN

In [None]:
from torch_geometric.nn import GCNConv, GATConv

# Train GCN for comparison
class PyGGCN(nn.Module):
    def __init__(self, num_features, num_classes, hidden_dim=64):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)
    
    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

class PyGGAT(nn.Module):
    def __init__(self, num_features, num_classes, hidden_dim=8, heads=8):
        super().__init__()
        self.conv1 = GATConv(num_features, hidden_dim, heads=heads, dropout=0.6)
        self.conv2 = GATConv(hidden_dim * heads, num_classes, heads=1, 
                            concat=False, dropout=0.6)
    
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x

def train_model(Model, num_epochs=200, **kwargs):
    """Train a model and return best test accuracy."""
    model = Model(
        num_features=dataset.num_features,
        num_classes=dataset.num_classes,
        **kwargs
    ).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
    
    best_val = 0
    best_test = 0
    
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        
        model.eval()
        with torch.no_grad():
            pred = model(data.x, data.edge_index).argmax(dim=1)
            val_acc = (pred[data.val_mask] == data.y[data.val_mask]).float().mean().item()
            test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()
            
            if val_acc > best_val:
                best_val = val_acc
                best_test = test_acc
    
    return best_test, sum(p.numel() for p in model.parameters())

# Compare models
print("Model Comparison on Cora")
print("=" * 50)

gcn_acc, gcn_params = train_model(PyGGCN, hidden_dim=64)
print(f"GCN:  {gcn_acc:.4f} accuracy, {gcn_params:,} params")

gat_acc, gat_params = train_model(PyGGAT, hidden_dim=8, heads=8)
print(f"GAT:  {gat_acc:.4f} accuracy, {gat_params:,} params")

print("\n" + "=" * 50)
gap_pts = abs(gat_acc - gcn_acc) * 100  # accuracy gap in percentage points
if gat_acc > gcn_acc:
    print(f"🎉 GAT outperforms GCN by {gap_pts:.1f} percentage points")
else:
    print(f"📊 GCN matches or outperforms GAT by {gap_pts:.1f} percentage points")

print(f"\n💡 Parameter ratio (GAT / GCN): {gat_params / gcn_params:.2f}")
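
A single run per model is noisy: Cora's standard test set has only 1,000 nodes, so a gap of about a point can flip between seeds. A fairer comparison averages several seeds. This is a minimal sketch; `compare_over_seeds` is a hypothetical helper, and seeding only `torch.manual_seed` may not make GPU runs fully deterministic.

```python
import statistics

def compare_over_seeds(train_fn, seeds):
    """Run train_fn(seed) for each seed; return mean, sample std, and all accuracies.

    train_fn is any callable that takes a seed and returns a test accuracy
    (a hypothetical interface; see the commented usage below).
    """
    accs = [train_fn(seed) for seed in seeds]
    mean = statistics.mean(accs)
    std = statistics.stdev(accs) if len(accs) > 1 else 0.0
    return mean, std, accs

# Usage in this notebook (assumes torch, train_model and PyGGAT from the cell above):
# def gat_run(seed):
#     torch.manual_seed(seed)
#     return train_model(PyGGAT, hidden_dim=8, heads=8)[0]
# mean, std, _ = compare_over_seeds(gat_run, seeds=range(5))
# print(f"GAT: {mean:.4f} ± {std:.4f}")
```

If the two models' mean ± std ranges overlap heavily, the single-run winner printed above should not be over-interpreted.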

---

## ✋ Try It Yourself: Exercise 1

**Task:** Experiment with different numbers of attention heads.

Train GAT models with heads = [1, 2, 4, 8, 16] and compare:
1. Test accuracy
2. Number of parameters
3. Training time

Is more heads always better?

In [None]:
# Your code here!

head_counts = [1, 2, 4, 8, 16]
results = []

for heads in head_counts:
    # Train GAT with this number of heads
    # Record accuracy, params, time
    pass

# Plot results

<details>
<summary>üí° Hint</summary>

```python
for heads in head_counts:
    start = time.time()
    acc, params = train_model(PyGGAT, num_epochs=200, hidden_dim=8, heads=heads)
    train_time = time.time() - start
    results.append((heads, acc, params, train_time))
    print(f"Heads={heads}: {acc:.4f} acc, {params} params, {train_time:.1f}s")
```
</details>

---

## ✋ Try It Yourself: Exercise 2

**Task:** Find the most "important" edges in the graph based on attention.

1. Compute attention weights for all edges across all heads
2. Find the top-10 edges with highest attention
3. Are these edges within the same class or between classes?
4. What does this tell us about the model's strategy?

In [None]:
# Your code here!

# Get attention from all 8 heads in layer 1
# Average attention across heads
# Find top-10 edges
# Analyze same-class vs different-class

---

## ⚠️ Common Mistakes

### Mistake 1: Softmax over wrong dimension
```python
# ❌ Wrong: Softmax over all edges
alpha = F.softmax(e, dim=0)  # All attention sums to 1

# ✅ Right: Softmax over neighbors of each node
from torch_geometric.utils import softmax
alpha = softmax(e, dst, num_nodes=num_nodes)  # Per-node normalization
```
**Why:** Each node's attention weights should sum to 1 independently.
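
The difference is easy to see on a toy graph. This sketch replaces PyG's `softmax` with a hand-written grouped softmax (so it runs with plain PyTorch) and compares the per-node sums; the edges and scores are made up:

```python
import torch
import torch.nn.functional as F

num_nodes = 3
dst = torch.tensor([0, 0, 1, 1, 1, 2])            # receiving node of each edge
e = torch.tensor([1.0, 2.0, 0.5, 0.5, 3.0, -1.0])  # raw attention scores

# Wrong: one softmax over every edge in the graph
alpha_global = F.softmax(e, dim=0)

# Right: softmax within each receiving node's group of incoming edges
exp_e = torch.exp(e - e.max())                     # a constant shift cancels within each group
denom = torch.zeros(num_nodes).index_add_(0, dst, exp_e)
alpha_local = exp_e / denom[dst]

def per_node_sum(alpha):
    return torch.zeros(num_nodes).index_add_(0, dst, alpha)

print(per_node_sum(alpha_global))  # only the grand total is 1
print(per_node_sum(alpha_local))   # every entry is 1 (up to float error)
```

With the global version, a node's share depends on how many edges the whole graph has; node 2, with a single incoming edge, should get weight 1 on it but gets only a small fraction.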

### Mistake 2: Not using dropout on attention
```python
# ❌ Missing: no dropout on attention weights
out = torch.zeros_like(Wh).index_add_(0, dst, alpha.view(-1, 1) * Wh[src])

# ✅ Right: apply dropout to attention weights before aggregating
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
out = torch.zeros_like(Wh).index_add_(0, dst, alpha.view(-1, 1) * Wh[src])
```
**Why:** Attention dropout randomly removes neighbors for each node during training (the GAT paper uses p = 0.6 on Cora), which keeps the model from overfitting to specific neighbors.
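
A side effect worth knowing: `F.dropout` zeroes each weight with probability p and rescales the survivors by 1/(1-p), so during training a node's attention weights sum to 1 only in expectation. In eval mode it leaves the tensor unchanged. A quick check on made-up weights for one node:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
alpha = torch.tensor([0.5, 0.3, 0.2])   # one node's attention weights (sum to 1)
p = 0.6

train_alpha = F.dropout(alpha, p=p, training=True)
eval_alpha = F.dropout(alpha, p=p, training=False)

# Every surviving weight is scaled by 1 / (1 - p) = 2.5; dropped weights are 0
kept = train_alpha != 0
print(train_alpha, train_alpha.sum())
print(eval_alpha)  # unchanged: tensor([0.5000, 0.3000, 0.2000])
```

This is why exact checks such as "each node's weights sum to 1" should be run with the model in `eval()` mode.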

### Mistake 3: Wrong concatenation in multi-head
```python
# ❌ Wrong: Concatenating along wrong dimension
out = torch.cat(head_outputs, dim=0)  # Stacks nodes!

# ✅ Right: Concatenate along feature dimension
out = torch.cat(head_outputs, dim=-1)  # [num_nodes, heads * dim]
```
**Why:** Multi-head concatenation should combine features, not duplicate nodes.

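### Mistake 4: Using ReLU instead of LeakyReLU for attention scores
```python
# ❌ Unusual: standard ReLU (all negative scores become 0)
e = F.relu(e_left + e_right)

# ✅ Standard: LeakyReLU with negative_slope=0.2
e = F.leaky_relu(e_left + e_right, negative_slope=0.2)
```
**Why:** With ReLU, every negative score collapses to 0, so those neighbors become indistinguishable after softmax (each gets weight proportional to exp(0) = 1) and pass no gradient back to the attention vector. LeakyReLU keeps negative scores ordered, so softmax can still rank them.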
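
A three-neighbor example makes the tie visible. The raw scores are made up; both versions then go through the same softmax:

```python
import torch
import torch.nn.functional as F

e_raw = torch.tensor([-3.0, -1.0, 2.0])   # raw scores for three neighbors

alpha_relu = F.softmax(F.relu(e_raw), dim=0)
alpha_leaky = F.softmax(F.leaky_relu(e_raw, negative_slope=0.2), dim=0)

print(alpha_relu)   # the two negative-score neighbors get identical weight
print(alpha_leaky)  # the -3 neighbor now gets less weight than the -1 neighbor
```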

---

## 🎉 Checkpoint

You've learned:
- ✅ Why learned neighbor weights can beat equal weighting (the model learns which neighbors matter)
- ✅ The GAT attention formula (e_ij, softmax, weighted aggregation)
- ✅ Multi-head attention (different heads can capture different relationship types)
- ✅ How to visualize and interpret attention weights
- ✅ How to check whether GAT pays more attention to same-class neighbors

---

## 🚀 Challenge (Optional)

**Advanced Challenge:** Implement GATv2.

The original GAT computes attention as:
```
e_ij = LeakyReLU(a_left * Wh_i + a_right * Wh_j)
```

GATv2 (Brody et al., 2021) uses:
```
e_ij = a^T * LeakyReLU(W * [h_i || h_j])
```

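In GAT, the LeakyReLU comes after the dot product with `a`. Because LeakyReLU is monotonic and the `a_left * Wh_i` term is the same for every neighbor of node i, all nodes rank a shared set of candidate neighbors in the same order, by `a_right * Wh_j` ("static" attention). GATv2 applies the LeakyReLU *before* the dot product with `a`, so the ranking can depend on the receiving node i ("dynamic" attention), which makes it strictly more expressive.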
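
The static/dynamic difference can be seen on scores alone. This sketch is not a GATv2 layer (that remains the challenge); all weights are random, the sizes are arbitrary, and for illustration every node is treated as a candidate neighbor of every other node, so each row of the score matrix ranks the same set of candidates.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, f, c = 6, 4, 5                    # nodes, input features, hidden size (arbitrary)
h = torch.randn(n, f)

# GAT: e_ij = LeakyReLU(a_left . W h_i + a_right . W h_j)
W = torch.randn(f, c)                # stored as [in, out]
a_left, a_right = torch.randn(c), torch.randn(c)
Wh = h @ W
e_gat = F.leaky_relu((Wh @ a_left)[:, None] + (Wh @ a_right)[None, :], 0.2)  # [i, j]

# GATv2: e_ij = a^T LeakyReLU(W2 [h_i || h_j]), computed for every (i, j) pair
W2 = torch.randn(2 * f, c)           # stored as [in, out]
a2 = torch.randn(c)
pairs = torch.cat([h[:, None, :].expand(n, n, f),
                   h[None, :, :].expand(n, n, f)], dim=-1)  # [i, j, 2f]
e_v2 = F.leaky_relu(pairs @ W2, 0.2) @ a2                   # [i, j]

print("GAT top neighbor per query:  ", e_gat.argmax(dim=1).tolist())  # same j in every row
print("GATv2 top neighbor per query:", e_v2.argmax(dim=1).tolist())   # may differ by row
```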

Implement GATv2 and compare to GAT on Cora!

In [None]:
# Advanced Challenge: GATv2

class GATv2Layer(nn.Module):
    """
    GATv2: Improved attention mechanism.
    
    Key difference: the LeakyReLU is applied before the dot product with a,
    so a node's ranking of its neighbors can depend on the node itself.
    """
    
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Your code here!
        pass
    
    def forward(self, x, edge_index):
        # Your code here!
        pass

---

## 📖 Further Reading

- [GAT Paper](https://arxiv.org/abs/1710.10903) - Original 2018 paper
- [GATv2 Paper](https://arxiv.org/abs/2105.14491) - Improved attention (2021)
- [Attention in Graphs Survey](https://arxiv.org/abs/2202.13060) - Comprehensive review
- [PyG GATConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GATConv)

---

## 🧹 Cleanup

In [None]:
# Clear GPU memory
import gc

del model, layer, mh_layer
del layer1_attn, layer2_attn, alpha, edge_index_with_loops

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / 1e6:.1f} MB")

print("✅ Cleanup complete!")

---

## ⏭️ Next Steps

So far we've classified **nodes**. But what about classifying entire **graphs**?

**In Lab E.4: Graph Classification**, you'll:
- Learn about graph-level pooling operations
- Build classifiers for molecules and social networks
- Implement mean, max, and attention-based pooling
- Predict molecular properties on the MUTAG dataset!

Let's classify some molecules! 🧪