# Tutorial: Node Classification and Link Prediction in Citation Networks

## Applying Graph Machine Learning to Academic Paper Analysis

**Authors**: Denis Troegubov  
**Date**: December 2025  
**GitHub**: [Link](https://github.com/BogGoro/predicting-paper-topics-and-connections-tutorial)

---

## Introduction

Welcome to this hands-on tutorial on **Graph Machine Learning**!  
In this notebook, we implement and compare two popular **Graph Neural Network (GNN)** architectures, **GraphSAGE** and **GATv2**, on the **Cora** citation network dataset.

We will solve two key graph learning tasks:

1. **Node Classification**: Predicting the research topic of academic papers  
2. **Link Prediction**: Recommending potential citations between papers  

The notebook is structured as follows:
- Data exploration and visualization
- Implementation of modular GNN systems for each task
- Training and evaluation pipelines
- **Comparative analysis** across 100 independent runs to assess model stability and performance

We use **PyTorch Geometric (PyG)** for graph learning and include **statistical tests** to validate differences between models.

### Prerequisites
- Basic knowledge of Python and PyTorch
- Familiarity with neural networks
- No prior experience with graphs required

Let's get started!

## 1. Environment Setup

We begin by installing and importing necessary libraries, setting random seeds for reproducibility, and checking GPU availability.

In [None]:
# Install PyTorch Geometric if you are running in Colab
!pip install torch-geometric

In [None]:
# Import libraries
import os
import torch
import torch.nn.functional as F
import torch.nn as nn

# PyTorch Geometric
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import SAGEConv, GCNConv
from torch_geometric.utils import negative_sampling, to_networkx
from torch_geometric.data import Data

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix
from sklearn.manifold import TSNE

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Create directory for visualizations (no error if it already exists)
os.makedirs("images", exist_ok=True)

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Understanding the Data: Cora Dataset

### What is Cora?
The Cora dataset is a classic citation network consisting of machine learning papers. It's widely used as a benchmark in graph ML research.

### Dataset Statistics:
- **Nodes**: 2,708 academic papers
- **Edges**: 10,556 directed edge entries, i.e. 5,278 undirected citation links, each stored in both directions
- **Features**: 1,433-dimensional binary word vectors (bag-of-words)
- **Classes**: 7 research topics
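
As a quick sanity check on these numbers, the arithmetic below derives the number of undirected links, the average degree, and the graph density. It is a minimal sketch in plain Python that hard-codes the node and edge counts from the list above and assumes that each citation link is stored in both directions; the variable names are illustrative only.

```python
# Counts taken from the dataset statistics above
num_nodes = 2708
num_directed_edges = 10556  # assumed: each undirected link stored in both directions

# Each undirected link contributes two directed entries
num_undirected_edges = num_directed_edges // 2

# Average degree: every directed entry adds one neighbor to its source node
avg_degree = num_directed_edges / num_nodes

# Density: fraction of all possible node pairs that are actually linked
possible_pairs = num_nodes * (num_nodes - 1) // 2
density = num_undirected_edges / possible_pairs

print(f"Undirected links: {num_undirected_edges}")
print(f"Average degree:   {avg_degree:.2f}")
print(f"Density:          {density:.5f}")
```

The graph is very sparse: far fewer than 1% of all possible paper pairs are linked, which is why the link prediction task later needs explicitly sampled negative examples.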

In [None]:
# Load the Cora dataset
dataset = Planetoid(root="/tmp/Cora", name="Cora")
data = dataset[0]

print("Dataset Information:")
print(f"Dataset: {dataset}")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")
print(f"Number of features: {data.num_features}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Has isolated nodes: {data.has_isolated_nodes()}")
print(f"Has self-loops: {data.has_self_loops()}")
print(f"Is undirected: {data.is_undirected()}")

In [None]:
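# Visualize the class distribution
# Class names listed by label index (0-6). This is the mapping commonly reported
# for the Planetoid version of Cora; the dataset object itself does not ship class
# names, so treat the order as an assumption and check it against class_counts
# (Neural Networks should be the largest class).
class_names = [
    "Theory",
    "Reinforcement Learning",
    "Genetic Algorithms",
    "Neural Networks",
    "Probabilistic Methods",
    "Case-Based",
    "Rule Learning",
]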

class_counts = torch.bincount(data.y).numpy()

plt.figure(figsize=(10, 6))
bars = plt.bar(class_names, class_counts, color=sns.color_palette("husl", 7))
plt.title(
    "Distribution of Paper Topics in Cora Dataset", fontsize=14, fontweight="bold"
)
plt.xlabel("Research Topic", fontsize=12)
plt.ylabel("Number of Papers", fontsize=12)
plt.xticks(rotation=45, ha="right")

# Add count labels on bars
for bar, count in zip(bars, class_counts):
    height = bar.get_height()
    plt.text(
        bar.get_x() + bar.get_width() / 2.0,
        height + 5,
        f"{count}",
        ha="center",
        va="bottom",
    )

plt.tight_layout()
plt.savefig("images/class_distribution.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Visualize a subgraph of the citation network
def visualize_citation_subgraph(data, num_nodes=100):
    """Visualize a small subgraph of the citation network"""
    # Keep only edges whose endpoints are both among the first num_nodes nodes
    mask = (data.edge_index[0] < num_nodes) & (data.edge_index[1] < num_nodes)
    subgraph_edges = data.edge_index[:, mask]

    # Create subgraph
    subgraph = Data(
        x=data.x[:num_nodes], edge_index=subgraph_edges, y=data.y[:num_nodes]
    )

    # Convert to NetworkX for visualization
    G = to_networkx(subgraph, to_undirected=True)

    # Create visualization
    plt.figure(figsize=(12, 8))

    # Node colors by class; the color scale is pinned to the 7 label indices so
    # that the legend below uses exactly the same colors as the nodes
    node_colors = [data.y[i].item() for i in range(num_nodes)]

    pos = nx.spring_layout(G, seed=42)
    nx.draw_networkx_nodes(
        G,
        pos,
        node_size=50,
        node_color=node_colors,
        cmap=plt.cm.Set2,
        vmin=0,
        vmax=6,
        alpha=0.8,
    )
    nx.draw_networkx_edges(G, pos, alpha=0.2, width=0.5)

    plt.title(
        f"Citation Network Subgraph (First {num_nodes} Papers)",
        fontsize=14,
        fontweight="bold",
    )
    plt.axis("off")

    # Create legend for classes (i / 6 matches the vmin=0, vmax=6 normalization)
    legend_elements = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=plt.cm.Set2(i / 6),
            markersize=10,
            label=class_names[i],
        )
        for i in range(7)
    ]
    plt.legend(
        handles=legend_elements,
        title="Research Topics",
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
    )

    plt.tight_layout()
    plt.savefig("images/citation_subgraph.png", dpi=150, bbox_inches="tight")
    plt.show()

    return G


# Visualize subgraph
G = visualize_citation_subgraph(data, num_nodes=1000)

## 3. Graph Neural Networks: Core Concepts
GNNs extend deep learning to graph-structured data through **message passing**:
1. Gather neighbor features
2. Aggregate them (sum, mean, max)
3. Update node representations
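
The three steps above can be sketched on a toy graph in plain Python. This is an illustration only: the graph, the feature values, and the helper names `mean_aggregate` and `update` are made up, and the update step simply concatenates vectors where a real layer such as `SAGEConv` would apply learned weights and a nonlinearity.

```python
# Toy graph: 4 nodes with 2-dimensional features (values are made up)
features = {
    0: [1.0, 0.0],
    1: [0.0, 1.0],
    2: [1.0, 1.0],
    3: [0.0, 0.0],
}
# Undirected edges stored as adjacency lists
neighbors = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}


def mean_aggregate(node, features, neighbors):
    """Steps 1-2: gather neighbor features and average them."""
    gathered = [features[n] for n in neighbors[node]]
    dim = len(features[node])
    return [sum(vec[d] for vec in gathered) / len(gathered) for d in range(dim)]


def update(node, features, neighbors):
    """Step 3: combine the node's own features with the aggregated message.

    A real layer would apply learned weight matrices and a nonlinearity here;
    this sketch just concatenates the two vectors.
    """
    return features[node] + mean_aggregate(node, features, neighbors)


new_features = {n: update(n, features, neighbors) for n in features}
for node, vec in new_features.items():
    print(node, vec)
```

Stacking two such rounds lets each node see information from its two-hop neighborhood, which is what the two-layer encoders in this tutorial do.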

We implement:
- **GraphSAGE**: Inductive, scalable, supports neighborhood sampling
- **GATv2**: Uses attention to weigh neighbor importance
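
To make "attention to weigh neighbor importance" concrete, here is a minimal single-head sketch of the GATv2 scoring rule, e_ij = a · LeakyReLU(W [h_i ‖ h_j]), followed by a softmax over the neighborhood. It is written in plain Python with made-up parameters `W` and `a`; in `GATv2Conv` these are learned, several heads run in parallel, and self-loops are added by default, none of which is modeled here.

```python
import math


def leaky_relu(v, slope=0.2):
    return [x if x > 0 else slope * x for x in v]


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def gatv2_attention(h_i, neighbor_feats, W, a):
    """Return softmax-normalized attention weights of node i over its neighbors.

    Score per neighbor j: e_ij = a . LeakyReLU(W [h_i || h_j])
    """
    scores = []
    for h_j in neighbor_feats:
        hidden = leaky_relu(matvec(W, h_i + h_j))  # h_i + h_j concatenates the lists
        scores.append(sum(a_k * z_k for a_k, z_k in zip(a, hidden)))
    # Softmax over the neighborhood (shifted by the max for numerical stability)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up parameters: W maps the 4-dim concatenation to 2 hidden units
W = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, -1.0]]
a = [1.0, 0.5]

h_i = [1.0, 0.0]
neighbor_feats = [[1.0, 0.0], [0.0, 1.0]]

alpha = gatv2_attention(h_i, neighbor_feats, W, a)
# Attention-weighted aggregation of the neighbor features
aggregated = [
    sum(w * h_j[d] for w, h_j in zip(alpha, neighbor_feats)) for d in range(len(h_i))
]
print("attention:", alpha)
print("aggregated:", aggregated)
```

Compared with the plain mean used by GraphSAGE, the weights are no longer uniform: with these toy parameters the first neighbor receives the larger share.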

## 4. Model Architecture
We design **separate model systems** for node classification and link prediction to prevent task interference.

### GraphSAGE System:
- `NodeEncoder` + `GraphSAGEClassifier` (with focal loss support)
- `LinkEncoder` + `LinkPredictor` (MLP-based)

### GATv2 System:
- `GATv2Encoder` + `GATv2Classifier` (multi-head attention)
- `GATv2LinkEncoder` + `GATv2LinkPredictor` (attention-enhanced)

Each system is wrapped in a unified class (`GraphSAGEModels`, `GATv2Models`) for easier training and evaluation.

In [None]:
class NodeEncoder(nn.Module):
    """Encoder specifically for node classification"""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return x

In [None]:
class LinkEncoder(nn.Module):
    """Encoder specifically optimized for link prediction"""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        # First SAGE layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)

        # Second SAGE layer
        x = self.conv2(x, edge_index)

        return x

    def get_embeddings(self, x, edge_index):
        """Get intermediate node embeddings"""
        with torch.no_grad():
            embeddings = self.conv1(x, edge_index)
            embeddings = F.relu(embeddings)
        return embeddings

In [None]:
class GraphSAGEClassifier(nn.Module):
    """GraphSAGE with focal loss for handling class imbalance"""

    def __init__(
        self, hidden_channels, num_classes, dropout=0.6, gamma=2.0, alpha=None
    ):
        super().__init__()
        self.sage1 = SAGEConv(hidden_channels, hidden_channels)
        self.sage2 = SAGEConv(hidden_channels, hidden_channels // 2)
        self.sage3 = SAGEConv(hidden_channels // 2, num_classes)

        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.bn2 = nn.BatchNorm1d(hidden_channels // 2)
        self.dropout = nn.Dropout(dropout)

        # Focal loss parameters
        self.gamma = gamma
        self.alpha = alpha  # Can be list of per-class weights

    def forward(self, x, edge_index):
        x = F.relu(self.bn1(self.sage1(x, edge_index)))
        x = self.dropout(x)

        x = F.relu(self.bn2(self.sage2(x, edge_index)))
        x = self.dropout(x)

        x = self.sage3(x, edge_index)
        return x  # Return logits

    def focal_loss(self, logits, labels):
        """Focal loss for imbalanced datasets"""
        ce_loss = F.cross_entropy(logits, labels, reduction="none")
        pt = torch.exp(-ce_loss)

        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            # Accept a list or a tensor of per-class weights
            alpha = torch.as_tensor(
                self.alpha, dtype=logits.dtype, device=logits.device
            )
            alpha_weight = alpha[labels]
            focal_loss = alpha_weight * focal_loss

        return focal_loss.mean()
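
The effect of the focal term can be checked by hand. The sketch below recomputes the per-sample quantities used in `focal_loss` (cross-entropy, `pt`, and the `(1 - pt) ** gamma` factor) in plain Python for three made-up true-class probabilities; the helper names are hypothetical and the per-class `alpha` weighting is left out.

```python
import math


def cross_entropy(p_true):
    """Cross-entropy for one sample, given the predicted probability of the true class."""
    return -math.log(p_true)


def focal(p_true, gamma=2.0):
    """Focal loss for one sample: (1 - p_t)^gamma * CE."""
    return (1.0 - p_true) ** gamma * cross_entropy(p_true)


for p_true in (0.9, 0.5, 0.1):
    ce = cross_entropy(p_true)
    fl = focal(p_true)
    print(f"p_t={p_true:.1f}  CE={ce:.4f}  focal={fl:.4f}  ratio={fl / ce:.2f}")
```

With `gamma=2.0`, a confidently correct sample (p_t = 0.9) keeps only about 1% of its cross-entropy, while a badly misclassified one (p_t = 0.1) keeps about 81%, so training focuses on the hard examples. Setting `gamma=0` recovers ordinary cross-entropy.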

In [None]:
# Define the Link Predictor
class LinkPredictor(nn.Module):
    """Link Predictor optimized for GraphSAGE embeddings"""

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        # GraphSAGE benefits from deeper MLP since embeddings are simpler
        self.lin1 = nn.Linear(in_channels * 2, hidden_channels * 2)
        self.lin2 = nn.Linear(hidden_channels * 2, hidden_channels)
        self.lin3 = nn.Linear(hidden_channels, 1)

        self.bn1 = nn.BatchNorm1d(hidden_channels * 2)
        self.bn2 = nn.BatchNorm1d(hidden_channels)
        self.dropout = nn.Dropout(0.4)

    def forward(self, z_src, z_dst):
        # Concatenate source and destination embeddings
        x = torch.cat([z_src, z_dst], dim=1)

        # Deeper MLP suitable for GraphSAGE
        x = self.lin1(x)
        x = F.relu(self.bn1(x))
        x = self.dropout(x)

        x = self.lin2(x)
        x = F.relu(self.bn2(x))
        x = self.dropout(x)

        x = self.lin3(x)
        # Squeeze only the last dimension so a single-edge batch keeps shape [1]
        return x.squeeze(-1)

In [None]:
class GATv2Encoder(nn.Module):
    """GATv2 Encoder for node classification"""

    def __init__(
        self, in_channels, hidden_channels, out_channels, heads=8, dropout=0.2
    ):
        super().__init__()
        from torch_geometric.nn import GATv2Conv

        self.conv1 = GATv2Conv(
            in_channels, hidden_channels, heads=heads, dropout=dropout, concat=True
        )
        self.conv2 = GATv2Conv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            dropout=dropout,
            concat=False,
        )
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None  # Store attention weights for analysis

    def forward(self, x, edge_index, return_attention=False):
        # Attention weights are always computed; they are stored only if requested
        x, attn1 = self.conv1(x, edge_index, return_attention_weights=True)
        x = F.elu(x)
        x = self.dropout(x)

        x, attn2 = self.conv2(x, edge_index, return_attention_weights=True)

        if return_attention:
            self.attention_weights = {"layer1": attn1, "layer2": attn2}

        return x

In [None]:
class GATv2LinkEncoder(nn.Module):
    """GATv2 Encoder specifically optimized for link prediction"""

    def __init__(
        self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.2
    ):
        super().__init__()
        from torch_geometric.nn import GATv2Conv

        # First layer with multiple heads for rich feature extraction
        self.conv1 = GATv2Conv(
            in_channels, hidden_channels, heads=heads, dropout=dropout, concat=True
        )
        # Second layer with attention to focus on relevant neighbors for link prediction
        self.conv2 = GATv2Conv(
            hidden_channels * heads,
            hidden_channels,
            heads=heads,
            dropout=dropout,
            concat=True,
        )
        # Final layer to produce link-focused embeddings
        self.conv3 = GATv2Conv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            dropout=dropout,
            concat=False,
        )

        # NOTE: bn1 and bn2 are registered here, but forward() below does not apply them
        self.bn1 = nn.BatchNorm1d(hidden_channels * heads)
        self.bn2 = nn.BatchNorm1d(hidden_channels * heads)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        # Layer 1: Initial feature transformation
        x = self.conv1(x, edge_index)
        x = F.elu(x + 1e-8)
        x = self.dropout(x)

        # Layer 2: Neighborhood aggregation with attention
        x = self.conv2(x, edge_index)
        x = F.elu(x + 1e-8)
        x = self.dropout(x)

        # Layer 3: Final embedding for link prediction
        x = self.conv3(x, edge_index)

        return x

In [None]:
class GATv2LinkPredictor(nn.Module):
    """Link Predictor optimized for GATv2 embeddings with attention mechanism"""

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        # GATv2 embeddings are already rich, so we use a shallower network
        # with attention-like operations
        self.lin_transform = nn.Linear(in_channels * 2, hidden_channels)
        self.lin_attention = nn.Linear(hidden_channels, 1)
        self.lin_output = nn.Linear(hidden_channels, 1)

        self.bn = nn.BatchNorm1d(hidden_channels)
        self.dropout = nn.Dropout(0.2)

    def forward(self, z_src, z_dst):
        # Concatenate embeddings
        x = torch.cat([z_src, z_dst], dim=1)

        # Transform to hidden space
        x = self.lin_transform(x)
        x = F.leaky_relu(self.bn(x), negative_slope=0.2)
        x = self.dropout(x)

        # Attention-like scoring
        attention_scores = torch.sigmoid(self.lin_attention(x))
        x = x * attention_scores

        # Final prediction
        x = self.lin_output(x)
        # Squeeze only the last dimension so a single-edge batch keeps shape [1]
        return x.squeeze(-1)

In [None]:
class GATv2Classifier(nn.Module):
    """GATv2 classifier with attention mechanisms"""

    def __init__(self, hidden_channels, num_classes, heads=8, dropout=0.2):
        super().__init__()
        from torch_geometric.nn import GATv2Conv

        self.gat1 = GATv2Conv(
            hidden_channels, hidden_channels, heads=heads, dropout=dropout, concat=True
        )
        self.gat2 = GATv2Conv(
            hidden_channels * heads,
            hidden_channels,
            heads=heads,
            dropout=dropout,
            concat=True,
        )
        self.gat3 = GATv2Conv(
            hidden_channels * heads, num_classes, heads=1, dropout=dropout, concat=False
        )

        self.bn1 = nn.BatchNorm1d(hidden_channels * heads)
        self.bn2 = nn.BatchNorm1d(hidden_channels * heads)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        x = F.elu(self.bn1(self.gat1(x, edge_index)))
        x = self.dropout(x)

        x = F.elu(self.bn2(self.gat2(x, edge_index)))
        x = self.dropout(x)

        x = self.gat3(x, edge_index)
        return x

In [None]:
class GraphSAGEModels(nn.Module):
    """Complete GraphSAGE model system"""

    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()

        # Node classification components
        self.node_encoder = NodeEncoder(in_channels, hidden_channels, hidden_channels)
        self.node_classifier = GraphSAGEClassifier(hidden_channels, num_classes)

        # Link prediction components (GraphSAGE optimized)
        self.link_encoder = LinkEncoder(in_channels, hidden_channels, hidden_channels)
        self.link_predictor = LinkPredictor(
            hidden_channels, hidden_channels // 2
        )

    def forward_node(self, x, edge_index):
        """Forward pass for node classification"""
        z = self.node_encoder(x, edge_index)
        node_logits = self.node_classifier(z, edge_index)
        return z, node_logits

    def forward_link(self, x, edge_index, edge_label_index=None):
        """Forward pass for link prediction"""
        z = self.link_encoder(x, edge_index)
        if edge_label_index is None:
            edge_label_index = edge_index

        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        link_pred = self.link_predictor(src, dst)

        return z, link_pred

    def get_node_embeddings(self, x, edge_index):
        """Get node embeddings for analysis"""
        with torch.no_grad():
            z = self.node_encoder(x, edge_index)
        return z

    def get_link_embeddings(self, x, edge_index):
        """Get link-focused embeddings"""
        with torch.no_grad():
            z = self.link_encoder(x, edge_index)
        return z

In [None]:
class GATv2Models(nn.Module):
    """Complete GATv2 model system"""

    def __init__(self, in_channels, hidden_channels, num_classes, heads=8):
        super().__init__()

        # Node classification components
        self.node_encoder = GATv2Encoder(
            in_channels, hidden_channels, hidden_channels, heads
        )
        self.node_classifier = GATv2Classifier(hidden_channels, num_classes, heads)

        # Link prediction components (GATv2 optimized)
        self.link_encoder = GATv2LinkEncoder(
            in_channels, hidden_channels, hidden_channels, heads // 2
        )
        self.link_predictor = GATv2LinkPredictor(hidden_channels, hidden_channels // 2)
        # Alternative: self.link_predictor = MultiHeadGATv2LinkPredictor(hidden_channels, hidden_channels // 2)

    def forward_node(self, x, edge_index, return_attention=False):
        """Forward pass for node classification"""
        z = self.node_encoder(x, edge_index, return_attention=return_attention)
        node_logits = self.node_classifier(z, edge_index)
        return z, node_logits

    def forward_link(self, x, edge_index, edge_label_index=None):
        """Forward pass for link prediction"""
        z = self.link_encoder(x, edge_index)
        if edge_label_index is None:
            edge_label_index = edge_index

        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        link_pred = self.link_predictor(src, dst)

        return z, link_pred

    def get_attention_weights(self, x, edge_index, node_idx=None):
        """Get attention weights for specific nodes"""
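        # Run only the encoder in eval mode: the attention weights live there, and
        # skipping the classifier avoids updating its BatchNorm statistics
        self.node_encoder.eval()
        with torch.no_grad():
            self.node_encoder(x, edge_index, return_attention=True)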

            if node_idx is not None:
                # Extract attention for specific node
                attention_data = {}
                for layer_name, (
                    edge_index_layer,
                    attn_weights,
                ) in self.node_encoder.attention_weights.items():
                    # Find edges where node_idx is the target
                    mask = edge_index_layer[1] == node_idx
                    attention_data[layer_name] = {
                        "neighbors": edge_index_layer[0][mask].cpu().numpy(),
                        "weights": attn_weights[mask].cpu().numpy(),
                    }
                return attention_data
            else:
                return self.node_encoder.attention_weights

    def get_node_embeddings(self, x, edge_index):
        """Get node embeddings for analysis"""
        with torch.no_grad():
            # Call the encoder directly; the classifier is not needed for embeddings
            z = self.node_encoder(x, edge_index)
        return z

    def get_link_embeddings(self, x, edge_index):
        """Get link-focused embeddings"""
        with torch.no_grad():
            z = self.link_encoder(x, edge_index)
        return z

## 5. Data Preparation for Link Prediction

For link prediction, we need to:
1. Split edges into train/val/test sets
2. Create negative examples (non-existent edges)
3. Balance positive and negative samples
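
Steps 2 and 3 can be illustrated without any graph library. The sketch below uses rejection sampling to draw node pairs that are not connected, then builds a balanced, labeled edge list. The function name `sample_negative_edges` and the toy graph are hypothetical; the tutorial itself relies on PyG's `negative_sampling` for this step.

```python
import random


def sample_negative_edges(num_nodes, positive_edges, num_samples, seed=0):
    """Sample node pairs that are not connected in the graph.

    positive_edges: iterable of (u, v) pairs, treated as undirected.
    Returns a list of `num_samples` distinct undirected non-edges.
    """
    rng = random.Random(seed)
    existing = {frozenset(e) for e in positive_edges}
    max_non_edges = num_nodes * (num_nodes - 1) // 2 - len(existing)
    if num_samples > max_non_edges:
        raise ValueError("not enough non-edges to sample from")
    negatives = set()
    while len(negatives) < num_samples:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        pair = frozenset((u, v))
        # Reject self-loops and true edges; the set removes duplicates
        if u != v and pair not in existing:
            negatives.add(pair)
    return [tuple(sorted(p)) for p in negatives]


# Toy graph with 6 nodes and 4 positive edges
positives = [(0, 1), (1, 2), (2, 3), (4, 5)]
negatives = sample_negative_edges(num_nodes=6, positive_edges=positives, num_samples=4)

# Balanced training set: label 1 for real edges, 0 for sampled non-edges
edges = positives + negatives
labels = [1] * len(positives) + [0] * len(negatives)
print(edges)
print(labels)
```

Sampling as many negatives as positives keeps the binary classification problem balanced, which is what `num_neg_samples=pos_edges.size(1)` achieves in the training code.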

In [None]:
def prepare_link_prediction_data(data, val_ratio=0.1, test_ratio=0.1):
    """Prepare data for link prediction task"""
    edge_index = data.edge_index

    # Keep each undirected edge once (row < col) so both directions land in the same split
    row, col = edge_index
    mask = row < col
    row, col = row[mask], col[mask]

    n_edges = row.size(0)

    # Random permutation
    perm = torch.randperm(n_edges)
    row, col = row[perm], col[perm]

    # Split sizes
    n_val = int(n_edges * val_ratio)
    n_test = int(n_edges * test_ratio)

    # Create splits
    val_edges_pos = torch.stack([row[:n_val], col[:n_val]], dim=0)
    test_edges_pos = torch.stack(
        [row[n_val : n_val + n_test], col[n_val : n_val + n_test]], dim=0
    )
    train_edges_pos = torch.stack([row[n_val + n_test :], col[n_val + n_test :]], dim=0)

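    # Message-passing graph: training + validation positives (test edges held out),
    # symmetrized so that messages flow in both directions along each citation
    from torch_geometric.utils import to_undirected

    train_edge_index = to_undirected(
        torch.cat([train_edges_pos, val_edges_pos], dim=1)
    )

    print("Link Prediction Data Split:")
    print(f"Message-passing edges (both directions): {train_edge_index.size(1)}")
    print(f"Training positive edges: {train_edges_pos.size(1)}")
    print(f"Validation positive edges: {val_edges_pos.size(1)}")
    print(f"Test positive edges: {test_edges_pos.size(1)}")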

    return {
        "train_edge_index": train_edge_index,
        "train_edges_pos": train_edges_pos,
        "val_edges_pos": val_edges_pos,
        "test_edges_pos": test_edges_pos,
    }


# Prepare link prediction data
link_data = prepare_link_prediction_data(data)

## 6. Training Functions
We define separate training loops for:
- `train_graphsage_models`
- `train_gatv2_models`

Each function:
- Trains the node classification module first
- Then trains the link prediction module
- Uses separate optimizers and loss functions per task
- Keeps the node-classification weights that achieve the best validation accuracy (the link-prediction module uses its final-epoch weights)
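
Keeping the best weights only works if the snapshot is a copy of the weights and not a reference to them. The toy loop below shows the pattern in plain Python: a dict stands in for the model weights, the validation scores are made up, and the function name `train_with_best_snapshot` is hypothetical. The same concern applies in PyTorch, where `state_dict()` returns references to the live parameter tensors.

```python
import copy


def train_with_best_snapshot(state, val_scores):
    """Toy training loop that keeps the weights from the best validation epoch.

    state: dict standing in for model weights (mutated in place every epoch)
    val_scores: pre-computed validation score per epoch (made-up numbers)
    """
    best_score = float("-inf")
    best_state = None
    for epoch, score in enumerate(val_scores):
        # Stand-in for an optimizer step: the live weights change every epoch
        state["w"][0] = float(epoch)
        if score > best_score:
            best_score = score
            # Copy the weights; storing `state` itself would only keep a
            # reference that later epochs overwrite
            best_state = copy.deepcopy(state)
    return best_score, best_state


state = {"w": [0.0]}
best_score, best_state = train_with_best_snapshot(state, [0.60, 0.75, 0.70, 0.72])
print(best_score, best_state, state)
```

Here the best epoch is the second one, so the snapshot holds that epoch's weights while `state` has moved on to the final epoch.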

In [None]:
def train_graphsage_models(
    sage_models, data, link_data, epochs_node=100, epochs_link=100
):
    """Complete training pipeline for GraphSAGE models"""
    print("\nTraining GraphSAGE Model System\n")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sage_models = sage_models.to(device)
    data = data.to(device)

    # Convert link_data to device
    link_data_device = {}
    for key, value in link_data.items():
        link_data_device[key] = value.to(device)

    # Phase 1: Train node classification
    print("\n1. Training Node Classification...")

    node_params = list(sage_models.node_encoder.parameters()) + list(
        sage_models.node_classifier.parameters()
    )
    node_optimizer = torch.optim.Adam(node_params, lr=0.01, weight_decay=5e-4)
    criterion_node = nn.CrossEntropyLoss()

    node_losses, node_accs = [], []
    best_node_acc = 0

    for epoch in range(epochs_node):
        sage_models.node_encoder.train()
        sage_models.node_classifier.train()
        node_optimizer.zero_grad()

        # Forward through GraphSAGE node models
        _, node_logits = sage_models.forward_node(data.x, data.edge_index)

        # Compute loss
        loss = criterion_node(node_logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        node_optimizer.step()

        # Validation
        sage_models.node_encoder.eval()
        sage_models.node_classifier.eval()
        with torch.no_grad():
            _, node_logits = sage_models.forward_node(data.x, data.edge_index)
            val_pred = node_logits[data.val_mask].argmax(dim=1)
            val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()

            if val_acc > best_node_acc:
                best_node_acc = val_acc
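                # Clone the tensors: state_dict() returns references to the live
                # parameters, which later optimizer steps would overwrite
                best_node_state = {
                    "encoder": {
                        k: v.detach().clone()
                        for k, v in sage_models.node_encoder.state_dict().items()
                    },
                    "classifier": {
                        k: v.detach().clone()
                        for k, v in sage_models.node_classifier.state_dict().items()
                    },
                }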

        node_losses.append(loss.item())
        node_accs.append(val_acc)

        if (epoch + 1) % 20 == 0:
            print(f"Node Epoch {epoch+1:03d}: Loss={loss:.4f}, Val Acc={val_acc:.4f}")

    # Load best node model
    sage_models.node_encoder.load_state_dict(best_node_state["encoder"])
    sage_models.node_classifier.load_state_dict(best_node_state["classifier"])

    # Phase 2: Train link prediction
    print("\n2. Training Link Prediction...")

    link_params = list(sage_models.link_encoder.parameters()) + list(
        sage_models.link_predictor.parameters()
    )
    link_optimizer = torch.optim.Adam(link_params, lr=0.01, weight_decay=1e-4)
    criterion_link = nn.BCEWithLogitsLoss()

    # Prepare training data for link prediction
    pos_edges = link_data_device["train_edges_pos"]
    neg_edges = negative_sampling(
        edge_index=link_data_device["train_edge_index"],
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edges.size(1),
    ).to(device)

    train_edges = torch.cat([pos_edges, neg_edges], dim=1)
    train_labels = torch.cat(
        [
            torch.ones(pos_edges.size(1), device=device),
            torch.zeros(neg_edges.size(1), device=device),
        ]
    )

    link_losses = []

    for epoch in range(epochs_link):
        sage_models.link_encoder.train()
        sage_models.link_predictor.train()
        link_optimizer.zero_grad()

        # Forward through GraphSAGE link models
        _, link_scores = sage_models.forward_link(
            data.x, link_data_device["train_edge_index"], train_edges
        )

        # Compute loss
        loss = criterion_link(link_scores, train_labels)
        loss.backward()
        link_optimizer.step()

        link_losses.append(loss.item())

        if (epoch + 1) % 20 == 0:
            print(f"Link Epoch {epoch+1:03d}: Loss={loss.item():.4f}")

    print(f"\nGraphSAGE Training Completed:")
    print(f"  Best Node Accuracy: {best_node_acc:.4f}")
    print(f"  Final Link Loss: {link_losses[-1]:.4f}")

    return sage_models, node_losses, node_accs, link_losses

In [None]:
def train_gatv2_models(gatv2_models, data, link_data, epochs_node=100, epochs_link=100):
    """Complete training pipeline for GATv2 models"""
    print("\nTraining GATv2 Model System\n")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gatv2_models = gatv2_models.to(device)
    data = data.to(device)

    # Convert link_data to device
    link_data_device = {}
    for key, value in link_data.items():
        link_data_device[key] = value.to(device)

    # Phase 1: Train node classification with attention
    print("\n1. Training Node Classification with Attention...")

    node_params = list(gatv2_models.node_encoder.parameters()) + list(
        gatv2_models.node_classifier.parameters()
    )
    node_optimizer = torch.optim.Adam(node_params, lr=0.005, weight_decay=5e-4)
    criterion_node = nn.CrossEntropyLoss()

    node_losses, node_accs = [], []
    best_node_acc = 0

    for epoch in range(epochs_node):
        gatv2_models.node_encoder.train()
        gatv2_models.node_classifier.train()
        node_optimizer.zero_grad()

        # Forward through GATv2 node models
        _, node_logits = gatv2_models.forward_node(data.x, data.edge_index)

        # Compute loss
        loss = criterion_node(node_logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        node_optimizer.step()

        # Validation
        gatv2_models.node_encoder.eval()
        gatv2_models.node_classifier.eval()
        with torch.no_grad():
            _, node_logits = gatv2_models.forward_node(data.x, data.edge_index)
            val_pred = node_logits[data.val_mask].argmax(dim=1)
            val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()

            if val_acc > best_node_acc:
                best_node_acc = val_acc
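                # Clone the tensors: state_dict() returns references to the live
                # parameters, which later optimizer steps would overwrite
                best_node_state = {
                    "encoder": {
                        k: v.detach().clone()
                        for k, v in gatv2_models.node_encoder.state_dict().items()
                    },
                    "classifier": {
                        k: v.detach().clone()
                        for k, v in gatv2_models.node_classifier.state_dict().items()
                    },
                }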

        node_losses.append(loss.item())
        node_accs.append(val_acc)

        if (epoch + 1) % 20 == 0:
            print(f"Node Epoch {epoch+1:03d}: Loss={loss:.4f}, Val Acc={val_acc:.4f}")

    # Load best node model
    gatv2_models.node_encoder.load_state_dict(best_node_state["encoder"])
    gatv2_models.node_classifier.load_state_dict(best_node_state["classifier"])

    # Phase 2: Train link prediction with attention-based predictor
    print("\n2. Training Link Prediction with Attention Mechanisms...")

    link_params = list(gatv2_models.link_encoder.parameters()) + list(
        gatv2_models.link_predictor.parameters()
    )
    link_optimizer = torch.optim.Adam(link_params, lr=0.005, weight_decay=1e-4)
    criterion_link = nn.BCEWithLogitsLoss()

    # Prepare training data for link prediction
    pos_edges = link_data_device["train_edges_pos"]
    neg_edges = negative_sampling(
        edge_index=link_data_device["train_edge_index"],
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edges.size(1),
    ).to(device)

    train_edges = torch.cat([pos_edges, neg_edges], dim=1)
    train_labels = torch.cat(
        [
            torch.ones(pos_edges.size(1), device=device),
            torch.zeros(neg_edges.size(1), device=device),
        ]
    )

    link_losses = []

    for epoch in range(epochs_link):
        gatv2_models.link_encoder.train()
        gatv2_models.link_predictor.train()
        link_optimizer.zero_grad()

        # Forward through GATv2 link models
        _, link_scores = gatv2_models.forward_link(
            data.x, link_data_device["train_edge_index"], train_edges
        )

        # Compute loss
        loss = criterion_link(link_scores, train_labels)
        loss.backward()
        link_optimizer.step()

        link_losses.append(loss.item())

        if (epoch + 1) % 20 == 0:
            print(f"Link Epoch {epoch+1:03d}: Loss={loss.item():.4f}")

    print(f"\nGATv2 Training Completed:")
    print(f"  Best Node Accuracy: {best_node_acc:.4f}")
    print(f"  Final Link Loss: {link_losses[-1]:.4f}")

    return gatv2_models, node_losses, node_accs, link_losses

## 7. Evaluation Functions
We evaluate both tasks:
- **Node classification**: Accuracy on test set
- **Link prediction**: AUC-ROC, Average Precision (AP), Precision@k
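
To show what the three link prediction metrics measure, the sketch below computes them by hand in plain Python on six made-up scores. The function names are hypothetical and ties between scores are handled only in the simplest way; the evaluation code in this tutorial uses scikit-learn's `roc_auc_score` and `average_precision_score`.

```python
def roc_auc(labels, scores):
    """AUC as the probability that a random positive outranks a random negative."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(labels, scores):
    """Mean of the precision values at the rank of each positive example."""
    ranked = sorted(zip(scores, labels), key=lambda t: t[0], reverse=True)
    hits, precisions = 0, []
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions)


def precision_at_k(labels, scores, k):
    """Fraction of true edges among the k highest-scoring candidates."""
    ranked = sorted(zip(scores, labels), key=lambda t: t[0], reverse=True)
    return sum(y for _, y in ranked[:k]) / k


# Made-up scores for 3 true edges (label 1) and 3 non-edges (label 0)
labels = [1, 1, 1, 0, 0, 0]
scores = [0.9, 0.8, 0.4, 0.7, 0.3, 0.1]

print(f"AUC: {roc_auc(labels, scores):.3f}")
print(f"AP:  {average_precision(labels, scores):.3f}")
print(f"P@2: {precision_at_k(labels, scores, 2):.3f}")
```

AUC looks at the whole ranking, whereas Precision@k only asks whether the top of the list is correct, which is the more relevant question for a citation recommender that shows a handful of suggestions.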

Evaluation functions:
- `evaluate_graphsage_models`
- `evaluate_gatv2_models`

In [None]:
def evaluate_graphsage_models(sage_models, data, link_data):
    """Evaluate GraphSAGE models on both tasks"""
    sage_models.eval()
    device = data.x.device  # Get the device from data

    # Convert link_data to the same device
    link_data_device = {}
    for key, value in link_data.items():
        link_data_device[key] = value.to(device)

    with torch.no_grad():
        # Node Classification Evaluation
        _, node_logits = sage_models.forward_node(data.x, data.edge_index)
        test_pred = node_logits[data.test_mask].argmax(dim=1)
        node_acc = (test_pred == data.y[data.test_mask]).float().mean().item()

        # Link Prediction Evaluation
        z = sage_models.get_link_embeddings(data.x, link_data_device["train_edge_index"])

        # Positive test edges
        pos_edges = link_data_device["test_edges_pos"]
        src_pos = z[pos_edges[0]]
        dst_pos = z[pos_edges[1]]
        pos_scores = torch.sigmoid(sage_models.link_predictor(src_pos, dst_pos))

        # Negative test edges: sample against the full graph so that no true
        # citation (including held-out test edges) is drawn as a negative
        neg_edges = negative_sampling(
            edge_index=data.edge_index,
            num_nodes=data.num_nodes,
            num_neg_samples=pos_edges.size(1),
        ).to(device)

        src_neg = z[neg_edges[0]]
        dst_neg = z[neg_edges[1]]
        neg_scores = torch.sigmoid(sage_models.link_predictor(src_neg, dst_neg))

        # Combine predictions
        all_scores = torch.cat([pos_scores, neg_scores]).cpu().numpy()
        all_labels = torch.cat([
            torch.ones(pos_scores.size(0), device=device),
            torch.zeros(neg_scores.size(0), device=device)
        ]).cpu().numpy()

        # Compute metrics
        auc = roc_auc_score(all_labels, all_scores)
        ap = average_precision_score(all_labels, all_scores)

        # Precision@k
        k = min(100, len(all_scores) // 2)
        top_k_idx = np.argsort(all_scores)[-k:]
        precision_at_k = all_labels[top_k_idx].mean()

    return {
        'node_accuracy': node_acc,
        'link_auc': auc,
        'link_ap': ap,
        'link_precision_at_k': precision_at_k
    }

In [None]:
def evaluate_gatv2_models(gatv2_models, data, link_data):
    """Evaluate GATv2 models on both tasks"""
    gatv2_models.eval()
    device = data.x.device  # Get the device from data

    # Convert link_data to the same device
    link_data_device = {}
    for key, value in link_data.items():
        link_data_device[key] = value.to(device)

    with torch.no_grad():
        # Node Classification Evaluation
        _, node_logits = gatv2_models.forward_node(data.x, data.edge_index)
        test_pred = node_logits[data.test_mask].argmax(dim=1)
        node_acc = (test_pred == data.y[data.test_mask]).float().mean().item()

        # Link Prediction Evaluation
        z = gatv2_models.get_link_embeddings(data.x, link_data_device["train_edge_index"])

        # Positive test edges
        pos_edges = link_data_device["test_edges_pos"]
        src_pos = z[pos_edges[0]]
        dst_pos = z[pos_edges[1]]
        pos_scores = torch.sigmoid(gatv2_models.link_predictor(src_pos, dst_pos))

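        # Negative test edges: sample node pairs that are not known edges.
        # The positive test edges are excluded as well, so that no true edge
        # can be drawn as a negative example.
        neg_edges = negative_sampling(
            edge_index=torch.cat(
                [
                    link_data_device["train_edge_index"],
                    link_data_device["val_edges_pos"],
                    pos_edges,
                ],
                dim=1,
            ),
            num_nodes=data.num_nodes,
            num_neg_samples=pos_edges.size(1),
        ).to(device)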

        src_neg = z[neg_edges[0]]
        dst_neg = z[neg_edges[1]]
        neg_scores = torch.sigmoid(gatv2_models.link_predictor(src_neg, dst_neg))

        # Combine predictions
        all_scores = torch.cat([pos_scores, neg_scores]).cpu().numpy()
        all_labels = torch.cat([
            torch.ones(pos_scores.size(0), device=device),
            torch.zeros(neg_scores.size(0), device=device)
        ]).cpu().numpy()

        # Compute metrics
        auc = roc_auc_score(all_labels, all_scores)
        ap = average_precision_score(all_labels, all_scores)

        # Precision@k
        k = min(100, len(all_scores) // 2)
        top_k_idx = np.argsort(all_scores)[-k:]
        precision_at_k = all_labels[top_k_idx].mean()

    return {
        'node_accuracy': node_acc,
        'link_auc': auc,
        'link_ap': ap,
        'link_precision_at_k': precision_at_k
    }
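
Both evaluation functions draw their negative test edges with PyG's `negative_sampling`. The idea behind such a sampler can be sketched in plain NumPy as rejection sampling: draw random node pairs and keep only those that are not known edges. This is an illustration of the concept, not PyG's implementation; the function name `sample_negative_edges` and the toy graph are hypothetical.

```python
import numpy as np


def sample_negative_edges(known_edges, num_nodes, num_samples, seed=0):
    """Rejection-sample directed node pairs that are not in known_edges."""
    known = {(int(u), int(v)) for u, v in known_edges}

    # Upper bound on available negatives: all ordered pairs without
    # self-loops, minus the known edges.
    non_loop_known = sum(1 for u, v in known if u != v)
    available = num_nodes * (num_nodes - 1) - non_loop_known
    if num_samples > available:
        raise ValueError("not enough non-edges to sample from")

    rng = np.random.default_rng(seed)
    negatives = set()
    while len(negatives) < num_samples:
        u, v = (int(x) for x in rng.integers(0, num_nodes, size=2))
        # Reject self-loops and anything that is a known positive edge.
        if u != v and (u, v) not in known:
            negatives.add((u, v))
    return sorted(negatives)


# Hypothetical 5-node graph: pool every known positive edge before sampling.
train_pos = [(0, 1), (1, 0), (1, 2), (2, 1)]
test_pos = [(3, 4), (4, 3)]
toy_negatives = sample_negative_edges(
    train_pos + test_pos, num_nodes=5, num_samples=4
)
print(toy_negatives)
```

Pooling training, validation, and test positives into the exclusion set keeps true edges from being scored as negatives, which would otherwise bias AUC and AP downward.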

## 8. Model Initialization
We initialize both GraphSAGE and GATv2 model systems, print their parameter counts, and compare architectural complexity.
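
The per-component counts can also be collected with a small helper instead of one `print` per submodule. The sketch below assumes the submodules are registered as direct children of the model; `TinyModel`, `count_parameters`, and `parameter_breakdown` are hypothetical names used only for this illustration.

```python
import torch.nn as nn


class TinyModel(nn.Module):
    """Stand-in model with two submodules (not one of the tutorial models)."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 3)     # 4*3 weights + 3 biases = 15
        self.classifier = nn.Linear(3, 2)  # 3*2 weights + 2 biases = 8


def count_parameters(module):
    """Number of trainable scalar parameters in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def parameter_breakdown(model):
    """Trainable parameter count for each direct child module."""
    return {name: count_parameters(child) for name, child in model.named_children()}


tiny = TinyModel()
breakdown = parameter_breakdown(tiny)
total = count_parameters(tiny)
for name, count in breakdown.items():
    print(f"  {name}: {count:,}")
print(f"Total: {total:,}")
```

Parameter count is only a rough proxy for complexity: attention layers also add per-edge computation that the count does not reflect.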

In [None]:
# Initialize both model systems
print("Initializing GraphSAGE and GATv2 model systems...")

# GraphSAGE models
sage_models = GraphSAGEModels(
    in_channels=data.num_features, hidden_channels=128, num_classes=dataset.num_classes
)

# GATv2 models
gatv2_models = GATv2Models(
    in_channels=data.num_features,
    hidden_channels=128,
    num_classes=dataset.num_classes,
    heads=8,
)

print("\nModel Parameter Counts:")
print(
    f"GraphSAGE Total Parameters: {sum(p.numel() for p in sage_models.parameters()):,}"
)
print(
    f"  Node Encoder: {sum(p.numel() for p in sage_models.node_encoder.parameters()):,}"
)
print(
    f"  Node Classifier: {sum(p.numel() for p in sage_models.node_classifier.parameters()):,}"
)
print(
    f"  Link Encoder: {sum(p.numel() for p in sage_models.link_encoder.parameters()):,}"
)
print(
    f"  Link Predictor: {sum(p.numel() for p in sage_models.link_predictor.parameters()):,}"
)

print(
    f"\nGATv2 Total Parameters: {sum(p.numel() for p in gatv2_models.parameters()):,}"
)
print(
    f"  Node Encoder: {sum(p.numel() for p in gatv2_models.node_encoder.parameters()):,}"
)
print(
    f"  Node Classifier: {sum(p.numel() for p in gatv2_models.node_classifier.parameters()):,}"
)
print(
    f"  Link Encoder: {sum(p.numel() for p in gatv2_models.link_encoder.parameters()):,}"
)
print(
    f"  Link Predictor: {sum(p.numel() for p in gatv2_models.link_predictor.parameters()):,}"
)

## 9. Comparative Analysis: Multi-Run Experiment
We run **100 independent training sessions** for each model, each with its own random seed, to:
- Measure performance stability
- Compare average accuracy, AUC, and training time
- Perform statistical significance tests (paired t-tests)
- Visualize distributions, learning curves, and performance trade-offs
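
A paired t-test works on the per-run differences between the two models rather than on the two samples separately. As a rough sketch of what the test computes, the statistic is the mean difference divided by its standard error. The function name `paired_t_statistic` and the accuracy values are made up for illustration; the notebook obtains the statistic and the p-value from `scipy.stats.ttest_rel`.

```python
import numpy as np


def paired_t_statistic(a, b):
    """t statistic and degrees of freedom for H0: mean(b - a) == 0."""
    d = np.asarray(b, dtype=float) - np.asarray(a, dtype=float)
    n = d.size
    # Standard error of the mean difference, using the sample std (ddof=1).
    standard_error = d.std(ddof=1) / np.sqrt(n)
    return d.mean() / standard_error, n - 1


# Hypothetical accuracies: run i of both models uses the same seed.
model_a = [0.80, 0.81, 0.79, 0.82, 0.80]
model_b = [0.82, 0.82, 0.81, 0.83, 0.83]

t_stat, dof = paired_t_statistic(model_a, model_b)
print(f"t = {t_stat:.3f} with {dof} degrees of freedom")
```

The sign of the statistic depends only on the order of subtraction. A large absolute value means the difference is consistent across runs, even when the two accuracy distributions overlap.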

Key outputs:
- Summary table of mean ± std metrics
- Box plots of accuracy and AUC distributions
- Learning curves with standard deviation bands
- Pareto frontier analysis (Accuracy vs. AUC)
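
A run lies on the Pareto frontier when no other run is at least as good on both metrics and strictly better on one. The following is a minimal sketch of that dominance check, assuming higher is better on both axes; `pareto_mask` and the toy values are illustrative names and numbers, not results from this notebook.

```python
import numpy as np


def pareto_mask(xs, ys):
    """Boolean mask of points that no other point dominates."""
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    mask = np.ones(len(xs), dtype=bool)
    for i in range(len(xs)):
        # A point dominates i if it is at least as good on both metrics
        # and strictly better on at least one of them.
        at_least_as_good = (xs >= xs[i]) & (ys >= ys[i])
        strictly_better = (xs > xs[i]) | (ys > ys[i])
        if np.any(at_least_as_good & strictly_better):
            mask[i] = False
    return mask


# Hypothetical (accuracy, AUC) pairs; the 2nd and 4th runs are identical.
toy_acc = [0.80, 0.82, 0.78, 0.82, 0.81]
toy_auc = [0.90, 0.88, 0.91, 0.88, 0.87]

frontier = pareto_mask(toy_acc, toy_auc)
print(frontier)
```

Requiring a strict improvement on one metric matters when runs tie: two identical runs should both stay on the frontier instead of eliminating each other.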

In [None]:
def plot_multiple_runs_results(all_results, training_times, num_runs):
    """Visualize results from multiple runs"""

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # 1. Node Accuracy Distribution
    ax = axes[0, 0]
    positions = [1, 2]
    sage_accs = all_results["GraphSAGE"]["node_accuracies"]
    gatv2_accs = all_results["GATv2"]["node_accuracies"]

    bp = ax.boxplot(
        [sage_accs, gatv2_accs], positions=positions, widths=0.6, patch_artist=True
    )

    # Color boxes
    colors = ["tab:blue", "tab:orange"]
    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax.set_xticks(positions)
    ax.set_xticklabels(["GraphSAGE", "GATv2"])
    ax.set_title(
        f"Node Accuracy Distribution\n({num_runs} runs)", fontsize=12, fontweight="bold"
    )
    ax.set_ylabel("Accuracy")
    ax.grid(True, alpha=0.3, axis="y")

    # Add mean points
    ax.scatter(
        [1], [np.mean(sage_accs)], color="darkblue", s=100, marker="D", label="Mean"
    )
    ax.scatter([2], [np.mean(gatv2_accs)], color="darkorange", s=100, marker="D")

    # 2. Link AUC Distribution
    ax = axes[0, 1]
    sage_aucs = all_results["GraphSAGE"]["link_aucs"]
    gatv2_aucs = all_results["GATv2"]["link_aucs"]

    bp = ax.boxplot(
        [sage_aucs, gatv2_aucs], positions=positions, widths=0.6, patch_artist=True
    )

    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax.set_xticks(positions)
    ax.set_xticklabels(["GraphSAGE", "GATv2"])
    ax.set_title(
        f"Link Prediction AUC Distribution\n({num_runs} runs)", fontsize=12, fontweight="bold"
    )
    ax.set_ylabel("AUC-ROC")
    ax.grid(True, alpha=0.3, axis="y")

    # Add mean points
    ax.scatter([1], [np.mean(sage_aucs)], color="darkblue", s=100, marker="D")
    ax.scatter([2], [np.mean(gatv2_aucs)], color="darkorange", s=100, marker="D")

    # 3. Training Time Comparison
    ax = axes[0, 2]
    sage_times = training_times["GraphSAGE"]
    gatv2_times = training_times["GATv2"]

    bp = ax.boxplot(
        [sage_times, gatv2_times], positions=positions, widths=0.6, patch_artist=True
    )

    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax.set_xticks(positions)
    ax.set_xticklabels(["GraphSAGE", "GATv2"])
    ax.set_title(
        f"Training Time Distribution\n({num_runs} runs)", fontsize=12, fontweight="bold"
    )
    ax.set_ylabel("Time (seconds)")
    ax.grid(True, alpha=0.3, axis="y")

    # Add mean points
    ax.scatter([1], [np.mean(sage_times)], color="darkblue", s=100, marker="D")
    ax.scatter([2], [np.mean(gatv2_times)], color="darkorange", s=100, marker="D")

    # 4. Average Training Curves (Node Loss)
    ax = axes[1, 0]

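    # Longest node-loss curve across both models, so that every run can be
    # padded to a common length before averaging
    max_len = max(
        len(losses)
        for model_name in ("GraphSAGE", "GATv2")
        for losses in all_results[model_name]["node_losses_list"]
    )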

    # Pad and average GraphSAGE losses
    sage_losses_padded = []
    for losses in all_results["GraphSAGE"]["node_losses_list"]:
        if len(losses) < max_len:
            losses = np.pad(losses, (0, max_len - len(losses)), "edge")
        sage_losses_padded.append(losses)

    sage_losses_mean = np.mean(sage_losses_padded, axis=0)
    sage_losses_std = np.std(sage_losses_padded, axis=0)

    # Pad and average GATv2 losses
    gatv2_losses_padded = []
    for losses in all_results["GATv2"]["node_losses_list"]:
        if len(losses) < max_len:
            losses = np.pad(losses, (0, max_len - len(losses)), "edge")
        gatv2_losses_padded.append(losses)

    gatv2_losses_mean = np.mean(gatv2_losses_padded, axis=0)
    gatv2_losses_std = np.std(gatv2_losses_padded, axis=0)

    epochs = np.arange(len(sage_losses_mean))

    ax.plot(epochs, sage_losses_mean, label="GraphSAGE", color="tab:blue", linewidth=2)
    ax.fill_between(
        epochs,
        sage_losses_mean - sage_losses_std,
        sage_losses_mean + sage_losses_std,
        alpha=0.2,
        color="tab:blue",
    )

    ax.plot(epochs, gatv2_losses_mean, label="GATv2", color="tab:orange", linewidth=2)
    ax.fill_between(
        epochs,
        gatv2_losses_mean - gatv2_losses_std,
        gatv2_losses_mean + gatv2_losses_std,
        alpha=0.2,
        color="tab:orange",
    )

    ax.set_title(
        "Average Node Classification Loss\n(with std deviation)",
        fontsize=12,
        fontweight="bold",
    )
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 5. Average Training Curves (Link Loss)
    ax = axes[1, 1]

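    # Longest link-loss curve across both models, so that every run can be
    # padded to a common length before averaging
    max_len = max(
        len(losses)
        for model_name in ("GraphSAGE", "GATv2")
        for losses in all_results[model_name]["link_losses_list"]
    )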

    # Pad and average GraphSAGE losses
    sage_losses_padded = []
    for losses in all_results["GraphSAGE"]["link_losses_list"]:
        if len(losses) < max_len:
            losses = np.pad(losses, (0, max_len - len(losses)), "edge")
        sage_losses_padded.append(losses)

    sage_losses_mean = np.mean(sage_losses_padded, axis=0)
    sage_losses_std = np.std(sage_losses_padded, axis=0)

    # Pad and average GATv2 losses
    gatv2_losses_padded = []
    for losses in all_results["GATv2"]["link_losses_list"]:
        if len(losses) < max_len:
            losses = np.pad(losses, (0, max_len - len(losses)), "edge")
        gatv2_losses_padded.append(losses)

    gatv2_losses_mean = np.mean(gatv2_losses_padded, axis=0)
    gatv2_losses_std = np.std(gatv2_losses_padded, axis=0)

    epochs = np.arange(len(sage_losses_mean))

    ax.plot(epochs, sage_losses_mean, label="GraphSAGE", color="tab:blue", linewidth=2)
    ax.fill_between(
        epochs,
        sage_losses_mean - sage_losses_std,
        sage_losses_mean + sage_losses_std,
        alpha=0.2,
        color="tab:blue",
    )

    ax.plot(epochs, gatv2_losses_mean, label="GATv2", color="tab:orange", linewidth=2)
    ax.fill_between(
        epochs,
        gatv2_losses_mean - gatv2_losses_std,
        gatv2_losses_mean + gatv2_losses_std,
        alpha=0.2,
        color="tab:orange",
    )

    ax.set_title(
        "Average Link Prediction Loss\n(with std deviation)",
        fontsize=12,
        fontweight="bold",
    )
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 6. Performance Trade-off Analysis
    ax = axes[1, 2]

    # Scatter plot: Node Accuracy vs Link AUC
    sage_accs = np.array(all_results["GraphSAGE"]["node_accuracies"])
    sage_aucs = np.array(all_results["GraphSAGE"]["link_aucs"])
    gatv2_accs = np.array(all_results["GATv2"]["node_accuracies"])
    gatv2_aucs = np.array(all_results["GATv2"]["link_aucs"])

    ax.scatter(
        sage_accs,
        sage_aucs,
        color="tab:blue",
        s=100,
        alpha=0.7,
        label="GraphSAGE",
        edgecolors="darkblue",
        linewidth=1,
    )
    ax.scatter(
        gatv2_accs,
        gatv2_aucs,
        color="tab:orange",
        s=100,
        alpha=0.7,
        label="GATv2",
        edgecolors="darkorange",
        linewidth=1,
    )

    # Add mean points
    ax.scatter(
        np.mean(sage_accs),
        np.mean(sage_aucs),
        color="darkblue",
        s=200,
        marker="*",
        label="GraphSAGE Mean",
    )
    ax.scatter(
        np.mean(gatv2_accs),
        np.mean(gatv2_aucs),
        color="darkorange",
        s=200,
        marker="*",
        label="GATv2 Mean",
    )

    ax.set_xlabel("Node Accuracy")
    ax.set_ylabel("Link Prediction AUC")
    ax.set_title(
        "Performance Trade-off Analysis\n(Each point = 1 run)",
        fontsize=12,
        fontweight="bold",
    )
    ax.grid(True, alpha=0.3)

    # Add Pareto frontier
    all_accs = np.concatenate([sage_accs, gatv2_accs])
    all_aucs = np.concatenate([sage_aucs, gatv2_aucs])

    # Simple Pareto frontier: drop every run that another run dominates,
    # i.e. is at least as good on both metrics and strictly better on one.
    # The strict check keeps tied runs from eliminating each other.
    pareto_mask = np.ones(len(all_accs), dtype=bool)
    for i in range(len(all_accs)):
        for j in range(len(all_accs)):
            if (
                all_accs[j] >= all_accs[i]
                and all_aucs[j] >= all_aucs[i]
                and (all_accs[j] > all_accs[i] or all_aucs[j] > all_aucs[i])
            ):
                pareto_mask[i] = False
                break

    pareto_accs = all_accs[pareto_mask]
    pareto_aucs = all_aucs[pareto_mask]

    # Sort for plotting
    sort_idx = np.argsort(pareto_accs)
    ax.plot(
        pareto_accs[sort_idx],
        pareto_aucs[sort_idx],
        "k--",
        alpha=0.5,
        linewidth=1,
        label="Pareto Frontier",
    )

    # Draw the legend last so that it includes the Pareto frontier entry
    ax.legend()

    plt.tight_layout()
    os.makedirs("images", exist_ok=True)  # savefig fails if the folder is missing
    plt.savefig("images/multiple_runs_comparison.png", dpi=150, bbox_inches="tight")
    plt.show()
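
The learning-curve panels rely on one small trick: loss curves of different lengths are padded with their last value before being averaged per epoch. A minimal sketch of that step, with a hypothetical helper name `mean_std_curve` and made-up loss values:

```python
import numpy as np


def mean_std_curve(curves):
    """Pad ragged loss curves to a common length, then average per epoch."""
    max_len = max(len(c) for c in curves)
    padded = np.stack(
        [
            # mode="edge" repeats the last recorded loss for missing epochs.
            np.pad(np.asarray(c, dtype=float), (0, max_len - len(c)), mode="edge")
            for c in curves
        ]
    )
    return padded.mean(axis=0), padded.std(axis=0)


# Hypothetical curves: the second run stopped one epoch earlier.
toy_curves = [[1.0, 0.5, 0.25], [1.0, 0.7]]
mean_curve, std_curve = mean_std_curve(toy_curves)
print(mean_curve, std_curve)
```

Edge padding treats a run that stopped early as if its loss stayed constant afterwards, which avoids artificial jumps in the averaged curve.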

In [None]:
import time
import pandas as pd
from collections import defaultdict


def run_multiple_experiments(
    data, link_data, num_runs=100, epochs_node=80, epochs_link=120
):
    """Run multiple independent training sessions and collect statistics"""

    print(f"\nRunning {num_runs} Independent Experiments")
    print("GraphSAGE vs GATv2 - Averaging Results\n")

    # Storage for results
    all_results = {
        "GraphSAGE": {
            "node_accuracies": [],
            "link_aucs": [],
            "link_aps": [],
            "precisions_at_k": [],
            "node_losses_list": [],
            "link_losses_list": [],
        },
        "GATv2": {
            "node_accuracies": [],
            "link_aucs": [],
            "link_aps": [],
            "precisions_at_k": [],
            "node_losses_list": [],
            "link_losses_list": [],
        },
    }

    training_times = {"GraphSAGE": [], "GATv2": []}

    for run in range(num_runs):
        print(f"\nExperiment Run {run+1}/{num_runs}\n")

        # Set different random seed for each run
        torch.manual_seed(42 + run)
        np.random.seed(42 + run)

        print(f"\n[Run {run+1}] Training GraphSAGE...")
        start_time = time.time()

        # Reinitialize fresh GraphSAGE models
        sage_models = GraphSAGEModels(
            in_channels=data.num_features,
            hidden_channels=128,
            num_classes=dataset.num_classes,
        )

        # Train
        sage_models, sage_node_losses, sage_node_accs, sage_link_losses = (
            train_graphsage_models(
                sage_models,
                data,
                link_data,
                epochs_node=epochs_node,
                epochs_link=epochs_link,
            )
        )

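        # Stop the clock before evaluation so that the recorded time covers
        # model setup and training only
        sage_time = time.time() - start_time
        training_times["GraphSAGE"].append(sage_time)

        # Evaluate
        sage_metrics = evaluate_graphsage_models(sage_models, data, link_data)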

        # Store results
        all_results["GraphSAGE"]["node_accuracies"].append(
            sage_metrics["node_accuracy"]
        )
        all_results["GraphSAGE"]["link_aucs"].append(sage_metrics["link_auc"])
        all_results["GraphSAGE"]["link_aps"].append(sage_metrics["link_ap"])
        all_results["GraphSAGE"]["precisions_at_k"].append(
            sage_metrics["link_precision_at_k"]
        )
        all_results["GraphSAGE"]["node_losses_list"].append(sage_node_losses)
        all_results["GraphSAGE"]["link_losses_list"].append(sage_link_losses)

        print(
            f"GraphSAGE Run {run+1}: Node Acc={sage_metrics['node_accuracy']:.4f}, "
            f"Link AUC={sage_metrics['link_auc']:.4f}, Time={sage_time:.1f}s"
        )

        print(f"\n[Run {run+1}] Training GATv2...")
        start_time = time.time()

        # Reinitialize fresh GATv2 models
        gatv2_models = GATv2Models(
            in_channels=data.num_features,
            hidden_channels=128,
            num_classes=dataset.num_classes,
            heads=8,
        )

        # Train
        gatv2_models, gatv2_node_losses, gatv2_node_accs, gatv2_link_losses = (
            train_gatv2_models(
                gatv2_models,
                data,
                link_data,
                epochs_node=epochs_node,
                epochs_link=epochs_link,
            )
        )

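        # Stop the clock before evaluation so that the recorded time covers
        # model setup and training only
        gatv2_time = time.time() - start_time
        training_times["GATv2"].append(gatv2_time)

        # Evaluate
        gatv2_metrics = evaluate_gatv2_models(gatv2_models, data, link_data)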

        # Store results
        all_results["GATv2"]["node_accuracies"].append(gatv2_metrics["node_accuracy"])
        all_results["GATv2"]["link_aucs"].append(gatv2_metrics["link_auc"])
        all_results["GATv2"]["link_aps"].append(gatv2_metrics["link_ap"])
        all_results["GATv2"]["precisions_at_k"].append(
            gatv2_metrics["link_precision_at_k"]
        )
        all_results["GATv2"]["node_losses_list"].append(gatv2_node_losses)
        all_results["GATv2"]["link_losses_list"].append(gatv2_link_losses)

        print(
            f"GATv2 Run {run+1}: Node Acc={gatv2_metrics['node_accuracy']:.4f}, "
            f"Link AUC={gatv2_metrics['link_auc']:.4f}, Time={gatv2_time:.1f}s"
        )

        # Clean up to free memory
        del sage_models, gatv2_models
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"\nAnalysis (Mean ± Std over {num_runs} runs)\n")

    def compute_stats(values):
        mean = np.mean(values)
        std = np.std(values)
        min_val = np.min(values)
        max_val = np.max(values)
        return mean, std, min_val, max_val

    # Create summary table
    summary_data = []

    for model_name in ["GraphSAGE", "GATv2"]:
        # Node Accuracy
        node_mean, node_std, node_min, node_max = compute_stats(
            all_results[model_name]["node_accuracies"]
        )

        # Link AUC
        auc_mean, auc_std, auc_min, auc_max = compute_stats(
            all_results[model_name]["link_aucs"]
        )

        # Link AP
        ap_mean, ap_std, ap_min, ap_max = compute_stats(
            all_results[model_name]["link_aps"]
        )

        # Precision@k
        prec_mean, prec_std, prec_min, prec_max = compute_stats(
            all_results[model_name]["precisions_at_k"]
        )

        # Training time
        time_mean, time_std, time_min, time_max = compute_stats(
            training_times[model_name]
        )

        summary_data.append(
            {
                "Model": model_name,
                "Node Accuracy": f"{node_mean:.4f} ± {node_std:.4f} [{node_min:.4f}-{node_max:.4f}]",
                "Link AUC": f"{auc_mean:.4f} ± {auc_std:.4f} [{auc_min:.4f}-{auc_max:.4f}]",
                "Link AP": f"{ap_mean:.4f} ± {ap_std:.4f} [{ap_min:.4f}-{ap_max:.4f}]",
                "Precision@100": f"{prec_mean:.4f} ± {prec_std:.4f} [{prec_min:.4f}-{prec_max:.4f}]",
                "Training Time (s)": f"{time_mean:.1f} ± {time_std:.1f} [{time_min:.1f}-{time_max:.1f}]",
            }
        )

    # Display summary table
    summary_df = pd.DataFrame(summary_data)
    print("\n" + summary_df.to_string(index=False))

    print("\nComparison\n")

    from scipy import stats

    # Compare metrics between models
    for metric_name, sage_values, gatv2_values in [
        (
            "Node Accuracy",
            all_results["GraphSAGE"]["node_accuracies"],
            all_results["GATv2"]["node_accuracies"],
        ),
        (
            "Link AUC",
            all_results["GraphSAGE"]["link_aucs"],
            all_results["GATv2"]["link_aucs"],
        ),
        (
            "Link AP",
            all_results["GraphSAGE"]["link_aps"],
            all_results["GATv2"]["link_aps"],
        ),
    ]:
        # Paired t-test: run i of both models uses the same seed, so the
        # per-run results are compared pairwise
        t_stat, p_value = stats.ttest_rel(sage_values, gatv2_values)

        sage_mean = np.mean(sage_values)
        gatv2_mean = np.mean(gatv2_values)
        diff = gatv2_mean - sage_mean

        print(f"\n{metric_name}:")
        print(f"  GraphSAGE: {sage_mean:.4f} ± {np.std(sage_values):.4f}")
        print(f"  GATv2:     {gatv2_mean:.4f} ± {np.std(gatv2_values):.4f}")
        print(f"  Difference: {diff:.4f} (GATv2 - GraphSAGE)")
        print(f"  p-value: {p_value:.6f}")

        if p_value < 0.05:
            if diff > 0:
                print("  GATv2 is significantly better (p < 0.05)")
            else:
                print("  GraphSAGE is significantly better (p < 0.05)")
        else:
            print("  No significant difference (p ≥ 0.05)")

    print(f"\nVisualizing results across {num_runs} runs")

    plot_multiple_runs_results(all_results, training_times, num_runs)

    return all_results, summary_df

In [None]:
all_results, summary_df = run_multiple_experiments(
    data, link_data, num_runs=100, epochs_node=60, epochs_link=120
)