# GNN Experiments on AMiner Dataset

Testing out different GNN architectures (GAT, GCN, GraphSAGE) on the AMiner author collaboration network. 
This dataset has around 10k authors across 8 research fields.

**Goal:** See which architecture works best for classifying authors by their research area based on collaboration patterns.

In [None]:
# Setting up the environment - this takes a minute
!pip install -q torch torchvision
!pip install -q torch-geometric
!pip install -q ogb
!pip install -q matplotlib seaborn pandas numpy scikit-learn

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GCNConv, SAGEConv
from torch_geometric.data import Data
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import seaborn as sns
from datetime import datetime

# Prefer the first CUDA device when PyTorch can see one; otherwise stay on CPU.
# `device` is read by the training code further down.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    print(f"Using device: {device}")

## Loading AMiner Data

The AMiner dataset contains author collaboration networks. Each node is an author, edges represent collaborations.

In [None]:
from torch_geometric.datasets import AMiner
import torch_geometric.transforms as T

# Load the dataset
# Dataset files live under /tmp/AMiner; NormalizeFeatures is attached as the
# dataset transform. Only the first graph of the dataset is used.
# NOTE(review): PyG documents AMiner as a heterogeneous graph (author / paper /
# venue node types) - the homogeneous accessors used throughout this notebook
# (data.x, data.y, data.edge_index, data.num_features) presumably need
# adapting to a single node type; confirm against the dataset docs.
print("Loading AMiner dataset...")
dataset = AMiner(root='/tmp/AMiner', transform=T.NormalizeFeatures())
data = dataset[0]

# Basic stats
# Average degree is reported as edges per node, using the edge count as stored.
print(f"\nDataset stats:")
print(f"  Nodes (authors): {data.num_nodes}")
print(f"  Edges: {data.num_edges}")
print(f"  Features per node: {data.num_features}")
print(f"  Research fields: {dataset.num_classes}")
print(f"  Average degree: {data.num_edges / data.num_nodes:.2f}")

In [None]:
# Self-loop preprocessing for message passing: first strip any loops the raw
# graph already contains, then add loops back for all `data.num_nodes` nodes.
from torch_geometric.utils import add_self_loops, remove_self_loops

edge_index, _ = remove_self_loops(data.edge_index)
edge_index, _ = add_self_loops(edge_index, num_nodes=data.num_nodes)
data.edge_index = edge_index

print(f"After adding self-loops: {data.num_edges} edges")

## Creating Train/Val/Test Splits

Using a 60/20/20 train/validation/test split, drawn separately within each class so every split keeps roughly the same class proportions.

In [None]:
# Create masks for train/val/test
num_authors = data.num_nodes
num_classes = dataset.num_classes

train_mask = torch.zeros(num_authors, dtype=torch.bool)
val_mask = torch.zeros(num_authors, dtype=torch.bool)
test_mask = torch.zeros(num_authors, dtype=torch.bool)

# A dedicated, seeded generator makes the split identical on every run (so
# results are comparable between runs) without touching the global RNG that
# model initialisation and dropout draw from.
split_generator = torch.Generator().manual_seed(42)

# Split per class to keep it balanced
for c in range(num_classes):
    class_indices = (data.y == c).nonzero(as_tuple=True)[0]
    n = len(class_indices)

    perm = torch.randperm(n, generator=split_generator)
    train_size = int(0.6 * n)
    val_size = int(0.2 * n)

    # 60% train, 20% val; the test split takes whatever is left, so rounding
    # remainders end up in test.
    train_mask[class_indices[perm[:train_size]]] = True
    val_mask[class_indices[perm[train_size:train_size+val_size]]] = True
    test_mask[class_indices[perm[train_size+val_size:]]] = True

data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask

print(f"Train: {train_mask.sum()} | Val: {val_mask.sum()} | Test: {test_mask.sum()}")

## Model Architectures

Implementing three different GNN models to compare:
1. **GAT** - uses attention mechanisms
2. **GCN** - classic graph convolution
3. **GraphSAGE** - neighborhood aggregation (trained full-batch here, without neighbor sampling)

In [None]:
class GAT(nn.Module):
    """Two-layer graph attention network that returns per-node log-probabilities.

    The first layer runs ``heads`` attention heads, and the second layer is
    built to accept ``hidden_channels * heads`` inputs; it uses a single,
    non-concatenated head to produce ``out_channels`` scores per node.
    ``dropout`` is used both inside the attention layers and on the node
    features fed into each layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GATConv(
            in_channels,
            hidden_channels,
            heads=heads,
            dropout=dropout,
        )
        self.conv2 = GATConv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            concat=False,
            dropout=dropout,
        )

    def forward(self, x, edge_index):
        """Return log-softmax class scores for every node."""
        # Dropout is only active while the module is in training mode.
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = F.elu(self.conv1(h, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        logits = self.conv2(h, edge_index)
        return F.log_softmax(logits, dim=1)

class GCN(nn.Module):
    """Two-layer graph convolutional network that returns per-node log-probabilities.

    Layout: dropout -> GCNConv -> ReLU -> dropout -> GCNConv -> log-softmax.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return log-softmax class scores for every node."""
        # Dropout is only active while the module is in training mode.
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = F.relu(self.conv1(h, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        logits = self.conv2(h, edge_index)
        return F.log_softmax(logits, dim=1)

class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE network that returns per-node log-probabilities.

    Layout: dropout -> SAGEConv -> ReLU -> dropout -> SAGEConv -> log-softmax.
    The layers are applied to the ``edge_index`` passed to ``forward``; no
    neighbor sampling happens inside this module.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return log-softmax class scores for every node."""
        # Dropout is only active while the module is in training mode.
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = F.relu(self.conv1(h, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        logits = self.conv2(h, edge_index)
        return F.log_softmax(logits, dim=1)

# Visible confirmation that the model-definition cell ran to completion.
print("Models defined!")

## Training Setup

Using class weights to handle imbalanced data - some research fields have way more authors than others.

In [None]:
def _balanced_class_weights(labels, num_classes):
    # Inverse-frequency class weights scaled to average 1 across classes;
    # classes with no samples get weight 0 instead of inf/nan.
    counts = torch.bincount(labels, minlength=num_classes).float()
    weights = torch.zeros_like(counts)
    seen = counts > 0
    weights[seen] = 1.0 / counts[seen]
    return weights / weights.sum() * len(weights)


def train_model(model, data, model_name, epochs=200, lr=0.01, weight_decay=5e-4, patience=40):
    """
    Training loop with early stopping

    Trains ``model`` on the nodes selected by ``data.train_mask`` using Adam
    and a class-weighted NLL loss, measures accuracy on ``data.val_mask``
    after every epoch, and stops once validation accuracy has not improved
    for ``patience`` consecutive epochs. The weights from the best validation
    epoch are restored before returning.

    Args:
        model: module called as ``model(data.x, data.edge_index)`` and
            returning per-node log-probabilities.
        data: graph object exposing ``x``, ``edge_index``, ``y``,
            ``train_mask`` and ``val_mask``, and supporting ``.to(device)``.
        model_name: label used in progress messages.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        patience: epochs without validation improvement before stopping.

    Returns:
        ``(model, train_losses, val_accs)`` - the model with its best weights
        loaded, plus the per-epoch training loss and validation accuracy.

    Raises:
        ValueError: if ``data.train_mask`` selects no nodes.

    Both ``model`` and ``data`` are moved to the module-level ``device``.
    """
    import copy

    model = model.to(device)
    data = data.to(device)

    train_labels = data.y[data.train_mask]
    if train_labels.numel() == 0:
        raise ValueError("train_mask selects no nodes; cannot train")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Built on the first forward pass: the model's output width is the
    # authoritative class count, even if some class is missing from the
    # training split.
    criterion = None

    # Start below any achievable accuracy so the first epoch always counts as
    # an improvement, and keep an initial snapshot so there is always a valid
    # state to restore (e.g. when epochs == 0).
    best_val_acc = -1.0
    best_model_state = copy.deepcopy(model.state_dict())
    patience_counter = 0

    train_losses = []
    val_accs = []

    print(f"\nTraining {model_name}...")

    for epoch in range(1, epochs + 1):
        # Training
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        if criterion is None:
            class_weights = _balanced_class_weights(train_labels, out.size(1))
            criterion = nn.NLLLoss(weight=class_weights)
        loss = criterion(out[data.train_mask], train_labels)
        loss.backward()
        optimizer.step()

        # Validation
        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            pred = out.argmax(dim=1)
            val_acc = (pred[data.val_mask] == data.y[data.val_mask]).sum().item() / data.val_mask.sum().item()

        train_losses.append(loss.item())
        val_accs.append(val_acc)

        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Deep copy: state_dict() returns tensors that share storage with
            # the live parameters, so a shallow copy would be silently
            # overwritten by later optimizer steps.
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch % 20 == 0:
            print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f}")

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    # Load best model
    model.load_state_dict(best_model_state)

    return model, train_losses, val_accs

def evaluate_model(model, data):
    """
    Get test set metrics

    Runs ``model`` in eval mode on the full graph and scores the nodes
    selected by ``data.test_mask``.

    Args:
        model: module called as ``model(data.x, data.edge_index)`` and
            returning per-node class scores.
        data: graph object exposing ``x``, ``edge_index``, ``y`` and
            ``test_mask``, and supporting ``.to(device)``.

    Returns:
        ``(test_acc, precision, recall, f1, y_true, y_pred)`` - accuracy,
        the weighted-average precision/recall/F1, and the true and predicted
        test labels as NumPy arrays.
    """
    # Evaluate on whatever device the model lives on, so this does not depend
    # on an earlier call having already moved `data` there.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        data = data.to(first_param.device)

    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)

    test_true = data.y[data.test_mask]
    test_pred = pred[data.test_mask]
    test_acc = (test_pred == test_true).sum().item() / data.test_mask.sum().item()

    # Get predictions for metrics
    y_true = test_true.cpu().numpy()
    y_pred = test_pred.cpu().numpy()

    # zero_division=0 scores classes that are never predicted as 0 instead of
    # emitting a warning.
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)

    return test_acc, precision, recall, f1, y_true, y_pred

## Training GAT Model

Starting with GAT - attention mechanism should help with varying node degrees

In [None]:
# Build, train and score the GAT: 256 hidden channels per head, 4 heads.
gat = GAT(data.num_features, 256, num_classes, heads=4, dropout=0.6)
gat_model, gat_losses, gat_val_accs = train_model(gat, data, 'GAT', epochs=200, lr=0.005)
(
    gat_test_acc,
    gat_precision,
    gat_recall,
    gat_f1,
    gat_y_true,
    gat_y_pred,
) = evaluate_model(gat_model, data)

# One print call; sep="\n" puts each metric on its own line.
print(
    f"\nGAT Results:",
    f"  Test Accuracy: {gat_test_acc:.4f}",
    f"  Precision: {gat_precision:.4f}",
    f"  Recall: {gat_recall:.4f}",
    f"  F1 Score: {gat_f1:.4f}",
    sep="\n",
)

## Training GCN Model

In [None]:
# Build, train and score the GCN with 256 hidden channels.
gcn = GCN(data.num_features, 256, num_classes, dropout=0.6)
gcn_model, gcn_losses, gcn_val_accs = train_model(gcn, data, 'GCN', epochs=200, lr=0.01)
(
    gcn_test_acc,
    gcn_precision,
    gcn_recall,
    gcn_f1,
    gcn_y_true,
    gcn_y_pred,
) = evaluate_model(gcn_model, data)

# One print call; sep="\n" puts each metric on its own line.
print(
    f"\nGCN Results:",
    f"  Test Accuracy: {gcn_test_acc:.4f}",
    f"  Precision: {gcn_precision:.4f}",
    f"  Recall: {gcn_recall:.4f}",
    f"  F1 Score: {gcn_f1:.4f}",
    sep="\n",
)

## Training GraphSAGE Model

GraphSAGE was designed around neighborhood sampling, but here it is trained full-batch on the whole graph like the other two models - curious to see how it compares

In [None]:
# Build, train and score GraphSAGE with 256 hidden channels.
sage = GraphSAGE(data.num_features, 256, num_classes, dropout=0.6)
sage_model, sage_losses, sage_val_accs = train_model(sage, data, 'GraphSAGE', epochs=200, lr=0.01)
(
    sage_test_acc,
    sage_precision,
    sage_recall,
    sage_f1,
    sage_y_true,
    sage_y_pred,
) = evaluate_model(sage_model, data)

# One print call; sep="\n" puts each metric on its own line.
print(
    f"\nGraphSAGE Results:",
    f"  Test Accuracy: {sage_test_acc:.4f}",
    f"  Precision: {sage_precision:.4f}",
    f"  Recall: {sage_recall:.4f}",
    f"  F1 Score: {sage_f1:.4f}",
    sep="\n",
)

## Comparing Results

Let's see which model performed best

In [None]:
# Collect the test metrics of all three models into one table (one row per
# model, one column per metric).
results_df = pd.DataFrame({
    'Model': ['GAT', 'GCN', 'GraphSAGE'],
    'Test Accuracy': [gat_test_acc, gcn_test_acc, sage_test_acc],
    'Precision': [gat_precision, gcn_precision, sage_precision],
    'Recall': [gat_recall, gcn_recall, sage_recall],
    'F1 Score': [gat_f1, gcn_f1, sage_f1]
})

# Framed report: blank line, rule, heading, rule, table, rule.
print(
    "\n" + "="*60,
    "FINAL RESULTS",
    "="*60,
    results_df.to_string(index=False),
    "="*60,
    sep="\n",
)

## Visualizations

Plotting training curves and confusion matrices

In [None]:
# Training curves: training loss on the left panel, validation accuracy on
# the right, one line per model in each.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = [
    (axes[0], [gat_losses, gcn_losses, sage_losses],
     'Training Loss', 'Training Loss Over Time'),
    (axes[1], [gat_val_accs, gcn_val_accs, sage_val_accs],
     'Validation Accuracy', 'Validation Accuracy Over Time'),
]

for ax, curves, y_label, panel_title in panels:
    # Curves can differ in length because each run may stop early on its own.
    for curve, curve_label in zip(curves, ['GAT', 'GCN', 'GraphSAGE']):
        ax.plot(curve, label=curve_label, alpha=0.7)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Confusion matrices on the test set, one panel per model (all three models
# are shown side by side).
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

models_data = [
    ('GAT', gat_y_true, gat_y_pred),
    ('GCN', gcn_y_true, gcn_y_pred),
    ('GraphSAGE', sage_y_true, sage_y_pred)
]

for ax, (name, y_true, y_pred) in zip(axes, models_data):
    cm = confusion_matrix(y_true, y_pred)
    # fmt='d' annotates each cell with its integer count.
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_title(f'{name} Confusion Matrix')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')

plt.tight_layout()
plt.show()

## Saving Models

Saving the trained models for later use

In [None]:
# Bundle the three trained state dicts with the input/output sizes and the
# results table, then write everything to a single checkpoint file.
checkpoint = {
    'gat_state_dict': gat_model.state_dict(),
    'gcn_state_dict': gcn_model.state_dict(),
    'sage_state_dict': sage_model.state_dict(),
    'num_features': data.num_features,
    'num_classes': num_classes,
    'results': results_df.to_dict()
}
torch.save(checkpoint, 'aminer_models.pt')

print("Models saved to aminer_models.pt")