# Testing GNNs on OGB arXiv Dataset

This notebook experiments with different GNN architectures on the OGB arXiv citation network.

**Dataset:** 169k CS papers, 40 subject areas, 1.17M citations

**Models:** GAT, GCN, GraphSAGE

The goal is to classify each paper into one of the 40 subject areas using its node features together with the citation graph.

## Setup and Installation

Install the required packages — this may take a couple of minutes.

In [None]:
!pip install -q torch torchvision
!pip install -q torch-geometric
!pip install -q ogb
!pip install -q matplotlib seaborn pandas numpy scikit-learn

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GCNConv, SAGEConv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import seaborn as sns

# Select the compute device: first CUDA GPU when available, otherwise the CPU
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

print(f"Running on: {device}")
if use_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Loading OGB arXiv Dataset

The dataset is fairly large, so the first download may take a while.

In [None]:
# PyTorch 2.6 made torch.load default to weights_only=True, which refuses to
# unpickle arbitrary classes. The processed OGB dataset is a pickled PyG Data
# object, so the PyG helper classes it references must be allow-listed first.
import torch.serialization
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.data.storage import GlobalStorage

# Older PyTorch releases have no add_safe_globals (and no weights_only default),
# so skip registration there instead of monkey-patching the torch module.
if hasattr(torch.serialization, 'add_safe_globals'):
    torch.serialization.add_safe_globals([DataEdgeAttr, DataTensorAttr, GlobalStorage])

In [None]:
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.utils import to_undirected

print("Downloading OGB arXiv dataset...")
dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='/tmp/ogb')
data = dataset[0]
split_idx = dataset.get_idx_split()

# ogbn-arxiv's edge_index is directed (per the OGB dataset docs, an edge means one
# paper cites another). With one-directional edges each paper would aggregate
# messages from only one side of its citations, so add the reverse edges, as the
# OGB reference baselines do for this dataset.
num_citations = data.num_edges
data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)

# Boolean node masks for the official train / valid / test split
for mask_name, split_name in (('train_mask', 'train'), ('val_mask', 'valid'), ('test_mask', 'test')):
    mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    mask[split_idx[split_name]] = True
    setattr(data, mask_name, mask)

# Labels come as a column of shape (num_nodes, 1); flatten to (num_nodes,) for NLLLoss
data.y = data.y.squeeze(1)

print("\nDataset loaded!")
print(f"  Nodes: {data.num_nodes:,}")
print(f"  Citations: {num_citations:,}")
print(f"  Edges (both directions): {data.num_edges:,}")
print(f"  Features: {data.num_features}")
print(f"  Classes: {dataset.num_classes}")
print(f"  Train: {split_idx['train'].shape[0]:,}")
print(f"  Val: {split_idx['valid'].shape[0]:,}")
print(f"  Test: {split_idx['test'].shape[0]:,}")

## Model Definitions

Define the three two-layer GNN architectures we'll compare.

In [None]:
class GAT(nn.Module):
    """Two-layer Graph Attention Network for node classification.

    Layer 1 runs ``heads`` attention heads and concatenates their outputs;
    layer 2 uses a single head whose output is the class scores.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output width of each attention head in layer 1.
        out_channels: Number of classes.
        heads: Number of attention heads in layer 1.
        dropout: Dropout rate for node features and for the attention
            coefficients inside each GATConv.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.5):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        # Concatenated heads give hidden_channels * heads inputs to layer 2
        self.conv2 = GATConv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            concat=False,
            dropout=dropout,
        )
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return per-node log-probabilities over the classes."""
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = F.elu(self.conv1(h, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        logits = self.conv2(h, edge_index)
        return F.log_softmax(logits, dim=1)

class GCN(nn.Module):
    """Two-layer Graph Convolutional Network for node classification.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the hidden layer.
        out_channels: Number of classes.
        dropout: Dropout rate applied before each convolution while training.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return per-node log-probabilities over the classes."""
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = F.relu(self.conv1(h, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return F.log_softmax(self.conv2(h, edge_index), dim=1)

class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE model for node classification.

    This model is trained full-batch in this notebook: each SAGEConv layer
    aggregates over every neighbour listed in ``edge_index``. No neighbourhood
    sampling is performed.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the hidden layer.
        out_channels: Number of classes.
        dropout: Dropout rate applied before each convolution while training.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.dropout = dropout
    
    def forward(self, x, edge_index):
        """Return per-node log-probabilities over the classes."""
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

print("Models ready!")

## Training Function

Full-batch training loop (the whole graph is processed every epoch) with early stopping based on validation accuracy.

In [None]:
def train_model(model, data, model_name, epochs=200, lr=0.01, weight_decay=5e-4, patience=40):
    """Train a GNN full-batch with early stopping on validation accuracy.

    Each epoch takes one Adam step on the training nodes and then scores the
    validation nodes. The weights from the epoch with the highest validation
    accuracy are restored before returning.

    Args:
        model: Module whose ``forward(x, edge_index)`` returns log-probabilities
            (it is paired with ``nn.NLLLoss`` below).
        data: Graph with ``x``, ``edge_index``, ``y``, ``train_mask`` and
            ``val_mask`` attributes.
        model_name: Label used in progress messages.
        epochs: Maximum number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        patience: Stop after this many consecutive epochs without a new best
            validation accuracy.

    Returns:
        ``(model, train_losses, val_accs)``: the model moved to ``device`` with
        its best-validation weights loaded, plus the per-epoch training losses
        and validation accuracies.
    """
    model = model.to(device)
    data = data.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.NLLLoss()
    # The validation split never changes, so count its nodes once
    num_val = data.val_mask.sum().item()

    best_val_acc = 0
    best_model_state = None
    patience_counter = 0

    train_losses = []
    val_accs = []

    print(f"\nTraining {model_name}...")
    print("-" * 50)

    for epoch in range(1, epochs + 1):
        # Training phase
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        # Validation phase: eval mode disables dropout; no_grad skips autograd
        model.eval()
        with torch.no_grad():
            pred = model(data.x, data.edge_index).argmax(dim=1)
            val_correct = (pred[data.val_mask] == data.y[data.val_mask]).sum().item()
            val_acc = val_correct / num_val

        loss_value = loss.item()
        train_losses.append(loss_value)
        val_accs.append(val_acc)

        # Track best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() holds references to the live parameter tensors and
            # dict.copy() is shallow, so later optimizer steps would overwrite
            # the snapshot. Clone each tensor to freeze this epoch's weights.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch % 20 == 0:
            print(f"Epoch {epoch:3d} | Loss: {loss_value:.4f} | Val Acc: {val_acc:.4f} | Best: {best_val_acc:.4f}")

        # Early stopping
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered at epoch {epoch}")
            break

    # Restore best model (no snapshot exists if no epoch ran or none improved)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print("No epoch improved validation accuracy; keeping the final weights.")
    print(f"Best validation accuracy: {best_val_acc:.4f}")

    return model, train_losses, val_accs

def evaluate_model(model, data, average='weighted'):
    """Evaluate a trained model on the test nodes.

    Args:
        model: Module whose ``forward(x, edge_index)`` returns per-class scores.
        data: Graph with ``x``, ``edge_index``, ``y`` and ``test_mask``. It is
            moved to the model's device before inference.
        average: Averaging mode passed to scikit-learn's
            ``precision_recall_fscore_support``. With the default ``'weighted'``,
            recall always equals accuracy for single-label multiclass
            predictions; ``'macro'`` weights every class equally instead.

    Returns:
        ``(test_acc, precision, recall, f1, y_true, y_pred)``. ``y_true`` and
        ``y_pred`` are NumPy arrays of the test labels and predictions.
    """
    model.eval()
    # Run wherever the model lives, rather than relying on train_model having
    # already moved ``data`` to the same device.
    data = data.to(next(model.parameters()).device)
    with torch.no_grad():
        pred = model(data.x, data.edge_index).argmax(dim=1)

    test_mask = data.test_mask

    # Test accuracy
    test_correct = (pred[test_mask] == data.y[test_mask]).sum().item()
    test_acc = test_correct / test_mask.sum().item()

    # Additional metrics
    y_true = data.y[test_mask].cpu().numpy()
    y_pred = pred[test_mask].cpu().numpy()

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average=average, zero_division=0
    )

    return test_acc, precision, recall, f1, y_true, y_pred

## Training All Models

Now let's train all three models and compare their performance

### GAT (Graph Attention Network)

In [None]:
# 4 attention heads of width 256 in layer 1; GAT uses a smaller learning rate
gat = GAT(data.num_features, 256, dataset.num_classes, heads=4, dropout=0.5)
gat_model, gat_losses, gat_val_accs = train_model(
    gat, data, 'GAT', epochs=200, lr=0.005, weight_decay=5e-4,
)

(gat_test_acc, gat_precision, gat_recall, gat_f1,
 gat_y_true, gat_y_pred) = evaluate_model(gat_model, data)

gat_report = [
    f"\n{'='*50}",
    f"GAT Test Results:",
    f"{'='*50}",
    f"Accuracy:  {gat_test_acc:.4f}",
    f"Precision: {gat_precision:.4f}",
    f"Recall:    {gat_recall:.4f}",
    f"F1 Score:  {gat_f1:.4f}",
]
print("\n".join(gat_report))

### GCN (Graph Convolutional Network)

In [None]:
# Hidden width 256, standard lr 0.01
gcn = GCN(data.num_features, 256, dataset.num_classes, dropout=0.5)
gcn_model, gcn_losses, gcn_val_accs = train_model(
    gcn, data, 'GCN', epochs=200, lr=0.01, weight_decay=5e-4,
)

(gcn_test_acc, gcn_precision, gcn_recall, gcn_f1,
 gcn_y_true, gcn_y_pred) = evaluate_model(gcn_model, data)

gcn_report = [
    f"\n{'='*50}",
    f"GCN Test Results:",
    f"{'='*50}",
    f"Accuracy:  {gcn_test_acc:.4f}",
    f"Precision: {gcn_precision:.4f}",
    f"Recall:    {gcn_recall:.4f}",
    f"F1 Score:  {gcn_f1:.4f}",
]
print("\n".join(gcn_report))

### GraphSAGE

In [None]:
# Hidden width 256, standard lr 0.01
sage = GraphSAGE(data.num_features, 256, dataset.num_classes, dropout=0.5)
sage_model, sage_losses, sage_val_accs = train_model(
    sage, data, 'GraphSAGE', epochs=200, lr=0.01, weight_decay=5e-4,
)

(sage_test_acc, sage_precision, sage_recall, sage_f1,
 sage_y_true, sage_y_pred) = evaluate_model(sage_model, data)

sage_report = [
    f"\n{'='*50}",
    f"GraphSAGE Test Results:",
    f"{'='*50}",
    f"Accuracy:  {sage_test_acc:.4f}",
    f"Precision: {sage_precision:.4f}",
    f"Recall:    {sage_recall:.4f}",
    f"F1 Score:  {sage_f1:.4f}",
]
print("\n".join(sage_report))

## Results Comparison

Collect the test metrics of all three models in one table, sorted by test accuracy, to see which model performed best.

In [None]:
# One row per model; columns are the test metrics, best accuracy first
metric_rows = [
    ('GAT', gat_test_acc, gat_precision, gat_recall, gat_f1),
    ('GCN', gcn_test_acc, gcn_precision, gcn_recall, gcn_f1),
    ('GraphSAGE', sage_test_acc, sage_precision, sage_recall, sage_f1),
]
results_df = (
    pd.DataFrame(
        metric_rows,
        columns=['Model', 'Test Accuracy', 'Precision', 'Recall', 'F1 Score'],
    )
    .sort_values('Test Accuracy', ascending=False)
    .reset_index(drop=True)
)

rule = "=" * 70
print("\n" + rule)
print("FINAL RESULTS - OGB ARXIV DATASET")
print(rule)
print(results_df.to_string(index=False))
print(rule)

## Visualizations

In [None]:
# Side-by-side training-loss and validation-accuracy curves for all models
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
loss_ax, acc_ax = axes

histories = [
    ('GAT', gat_losses, gat_val_accs),
    ('GCN', gcn_losses, gcn_val_accs),
    ('GraphSAGE', sage_losses, sage_val_accs),
]
for model_label, losses, accs in histories:
    loss_ax.plot(losses, label=model_label, linewidth=2, alpha=0.7)
    acc_ax.plot(accs, label=model_label, linewidth=2, alpha=0.7)

panel_text = [
    (loss_ax, 'Training Loss', 'Training Loss Curves'),
    (acc_ax, 'Validation Accuracy', 'Validation Accuracy Over Time'),
]
for panel, y_label, panel_title in panel_text:
    panel.set_xlabel('Epoch', fontsize=12)
    panel.set_ylabel(y_label, fontsize=12)
    panel.set_title(panel_title, fontsize=14, fontweight='bold')
    panel.legend(fontsize=11)
    panel.grid(alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Grouped bars: four metrics per model, centred on each model's tick
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(results_df))
width = 0.2

bar_specs = [
    (-1.5, 'Test Accuracy', 'Accuracy'),
    (-0.5, 'Precision', 'Precision'),
    (0.5, 'Recall', 'Recall'),
    (1.5, 'F1 Score', 'F1'),
]
for offset, column, legend_label in bar_specs:
    ax.bar(x + offset * width, results_df[column], width, label=legend_label, alpha=0.8)

ax.set_ylabel('Score', fontsize=12)
ax.set_title('Model Performance Comparison', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(results_df['Model'])
ax.legend()
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

## Saving Results

Let's save the trained models and results

In [None]:
# Bundle the three trained models' weights, the results table and basic
# dataset statistics into a single checkpoint file.
dataset_info = {
    'num_features': data.num_features,
    'num_classes': dataset.num_classes,
    'num_nodes': data.num_nodes,
    'num_edges': data.num_edges,
}
checkpoint = {
    'gat_state': gat_model.state_dict(),
    'gcn_state': gcn_model.state_dict(),
    'sage_state': sage_model.state_dict(),
    'results': results_df.to_dict(),
    'dataset_info': dataset_info,
}
torch.save(checkpoint, 'ogb_arxiv_models.pt')

print("\nAll models and results saved to 'ogb_arxiv_models.pt'")
print("\nDone! 🎉")