# Training GraphSAGE and GAT Models

This notebook demonstrates training GraphSAGE and GAT models for friend recommendation.

## Models:
- **GraphSAGE**: Inductive learning via neighborhood aggregation (mean or pooling; this notebook uses `aggregator='mean'`)
- **GAT**: Graph Attention Network with multi-head attention

## Steps:
1. Load preprocessed data
2. Initialize models
3. Train models
4. Evaluate on the validation set
5. Generate top-K friend recommendations for an example user


In [None]:
import sys
import os

# Make the project root importable so `src.*` resolves. The notebook is normally
# run from a subdirectory of the project (e.g. notebooks/), in which case the root
# is the parent of the working directory; if it is launched from the root itself
# (a `src/` folder exists in the cwd), use the cwd instead.
_cwd = os.getcwd()
if os.path.isdir(os.path.join(_cwd, "src")):
    PROJECT_ROOT = _cwd
else:
    PROJECT_ROOT = os.path.dirname(_cwd)
# Add the path only once so re-running this cell does not keep growing sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
# NOTE(review): the data/ paths used later are relative to the cwd, not to
# PROJECT_ROOT — confirm where data/ lives relative to where the notebook runs.

import torch
import numpy as np
import matplotlib.pyplot as plt
from src.models.link_predictor import GraphSAGELinkPredictor, GATLinkPredictor
from src.training import Trainer, set_seed, EarlyStopping
from src.evaluation import compute_metrics, compute_ranking_metrics, get_top_k_recommendations

# Set seed for reproducibility (which RNGs are seeded is up to src.training.set_seed).
SEED = 42
set_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


## 1. Load Data


In [None]:
# Load processed data.
# The graph file holds a pickled graph object (presumably a torch_geometric `Data`,
# given the .x / .edge_index / .num_nodes attributes used below), not a plain
# tensor or state_dict. Since PyTorch 2.6, torch.load defaults to
# weights_only=True, which may reject such objects depending on the installed
# torch / torch_geometric versions, so opt out explicitly.
# Only do this for files you produced yourself: unpickling untrusted files can
# execute arbitrary code.
data = torch.load("data/processed/facebook_combined.pt", weights_only=False)
# NOTE(review): assumed to hold only tensors inside a dict, which the default
# weights_only loader accepts; pass weights_only=False here too if loading fails.
link_data = torch.load("data/processed/facebook_link_data.pt")

# Sanity check: each labelled split needs exactly one label per edge column,
# otherwise metrics computed later from (predictions, labels) are misaligned.
for split in ("train", "val"):
    n_split_edges = link_data[f"{split}_edges"].size(1)
    n_split_labels = len(link_data[f"{split}_labels"])
    assert n_split_edges == n_split_labels, (
        f"{split}: {n_split_edges} edges but {n_split_labels} labels"
    )

# `// 2` assumes edge_index stores each undirected friendship in both directions.
print(f"Graph: {data.num_nodes} nodes, {data.edge_index.size(1) // 2} edges")
print(f"Node features: {data.x.shape}")
print(f"Train edges: {link_data['train_edges'].size(1)}")
print(f"Val edges: {link_data['val_edges'].size(1)}")
print(f"Test edges: {link_data['test_edges'].size(1)}")


## 2. Train GraphSAGE Model


In [None]:
# Build the GraphSAGE link predictor. These hyperparameters are also reused when
# the GAT model is constructed later in the notebook.
input_dim = data.x.size(1)  # node-feature width, read from the loaded graph
hidden_dim, embedding_dim = 64, 64
num_layers = 2
dropout = 0.5

model_graphsage = GraphSAGELinkPredictor(
    input_dim,
    hidden_dim,
    embedding_dim,
    num_layers,
    dropout,
    aggregator='mean',
    predictor_method='mlp',
)
model_graphsage = model_graphsage.to(device)

n_params_graphsage = sum(param.numel() for param in model_graphsage.parameters())
print(f"GraphSAGE parameters: {n_params_graphsage}")


In [None]:
# Optimiser settings for GraphSAGE (the GAT trainer later uses the same values).
trainer_graphsage = Trainer(model_graphsage, device, lr=0.01, weight_decay=5e-4)

# Pair each split's edge tensor with its labels; these dicts are reused for GAT.
train_data = dict(edges=link_data['train_edges'], labels=link_data['train_labels'])
val_data = dict(edges=link_data['val_edges'], labels=link_data['val_labels'])

# patience=10: presumably stops after 10 epochs without validation improvement —
# see src.training.EarlyStopping.
early_stopping = EarlyStopping(patience=10)

print("Training GraphSAGE...")
history_graphsage = trainer_graphsage.train(
    data,
    train_data,
    val_data,
    num_epochs=100,
    early_stopping=early_stopping,
    checkpoint_dir="data/checkpoints/graphsage",
)

best_val_loss_graphsage = history_graphsage['best_val_loss']
print(f"Best validation loss: {best_val_loss_graphsage:.4f}")


In [None]:
# Plot GraphSAGE train/validation loss per epoch on a single set of axes
# (one panel only, so the figure has no empty half).
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(history_graphsage['train_losses'], label='Train Loss')
ax.plot(history_graphsage['val_losses'], label='Val Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='GraphSAGE Training Curves')
ax.legend()
ax.grid(True)

fig.tight_layout()
plt.show()


## 3. Train GAT Model


In [None]:
# Re-seed before building GAT so its initial weights (and subsequent training)
# are the same whether or not the GraphSAGE section above was run. This assumes
# set_seed seeds torch's RNG, as its use at the top of the notebook suggests —
# confirm in src.training.
set_seed(42)

# Initialize GAT model, reusing the GraphSAGE hyperparameters defined earlier,
# with 4 attention heads.
model_gat = GATLinkPredictor(
    input_dim, hidden_dim, embedding_dim, num_layers,
    num_heads=4, dropout=dropout, predictor_method='mlp'
).to(device)

print(f"GAT parameters: {sum(p.numel() for p in model_gat.parameters())}")


In [None]:
# Same optimiser settings as the GraphSAGE trainer, so the comparison is fair.
trainer_gat = Trainer(model_gat, device, lr=0.01, weight_decay=5e-4)

# A fresh early-stopping tracker; GraphSAGE's one already holds its own state.
early_stopping_gat = EarlyStopping(patience=10)

print("Training GAT...")
history_gat = trainer_gat.train(
    data,
    train_data,
    val_data,
    num_epochs=100,
    early_stopping=early_stopping_gat,
    checkpoint_dir="data/checkpoints/gat",
)

best_val_loss_gat = history_gat['best_val_loss']
print(f"Best validation loss: {best_val_loss_gat:.4f}")


In [None]:
# Plot GAT train/validation loss per epoch on a single set of axes
# (one panel only, so the figure has no empty half).
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(history_gat['train_losses'], label='Train Loss')
ax.plot(history_gat['val_losses'], label='Val Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='GAT Training Curves')
ax.legend()
ax.grid(True)

fig.tight_layout()
plt.show()


## 4. Evaluate on Validation Set


In [None]:
# Score the validation edges with GraphSAGE.
# NOTE(review): this uses the model's current weights; whether those are the
# best-validation checkpoint depends on whether Trainer.train restores it — confirm.
val_loss_graphsage, val_pred_graphsage, val_labels_np = trainer_graphsage.evaluate(
    data, val_data['edges'], val_data['labels']
)

metrics_graphsage = compute_metrics(val_pred_graphsage, val_labels_np)
ranking_metrics_graphsage = compute_ranking_metrics(
    val_pred_graphsage, val_labels_np, k_values=[5, 10, 20]
)

print("GraphSAGE Validation Metrics:")
for metric_name, metric_source, metric_key in (
    ("AUC", metrics_graphsage, 'auc'),
    ("AP", metrics_graphsage, 'ap'),
    ("Precision@10", ranking_metrics_graphsage, 'precision@10'),
    ("Recall@10", ranking_metrics_graphsage, 'recall@10'),
    ("NDCG@10", ranking_metrics_graphsage, 'ndcg@10'),
):
    print(f"{metric_name}: {metric_source[metric_key]:.4f}")


In [None]:
# Score the same validation edges with GAT (val_labels_np is simply re-bound).
# NOTE(review): as for GraphSAGE, these are the model's current weights — confirm
# whether Trainer.train restores the best checkpoint.
val_loss_gat, val_pred_gat, val_labels_np = trainer_gat.evaluate(
    data, val_data['edges'], val_data['labels']
)

metrics_gat = compute_metrics(val_pred_gat, val_labels_np)
ranking_metrics_gat = compute_ranking_metrics(
    val_pred_gat, val_labels_np, k_values=[5, 10, 20]
)

print("GAT Validation Metrics:")
for metric_name, metric_source, metric_key in (
    ("AUC", metrics_gat, 'auc'),
    ("AP", metrics_gat, 'ap'),
    ("Precision@10", ranking_metrics_gat, 'precision@10'),
    ("Recall@10", ranking_metrics_gat, 'recall@10'),
    ("NDCG@10", ranking_metrics_gat, 'ndcg@10'),
):
    print(f"{metric_name}: {metric_source[metric_key]:.4f}")


## 5. Generate Top-K Recommendations

Generate top-K friend recommendations for an example user (user 0). Candidates are all other users except the user's existing friends.


In [None]:
# Example: Get top-10 recommendations for user 0
user_id = 0
k = 10

# Candidates are all nodes except the user and their existing friends.
# Neighbours in either edge direction are found with vectorised masks rather
# than a per-edge Python loop calling .item() (slow on large edge lists).
src_nodes, dst_nodes = data.edge_index[0], data.edge_index[1]
friend_ids = torch.cat([
    dst_nodes[src_nodes == user_id],
    src_nodes[dst_nodes == user_id],
]).unique().cpu()
existing_friends = set(friend_ids.tolist())

# NOTE(review): data.edge_index is whatever graph preprocessing saved; if it also
# contains held-out val/test edges, those friends are excluded here as well — confirm.
excluded = torch.zeros(data.num_nodes, dtype=torch.bool)
excluded[friend_ids] = True
excluded[user_id] = True
# Ascending node ids, same order as enumerating range(num_nodes).
candidate_nodes = torch.arange(data.num_nodes)[~excluded]

# Rank the candidate nodes for user_id with GraphSAGE and keep the top k.
top_k_nodes_graphsage, top_k_scores_graphsage = get_top_k_recommendations(
    model_graphsage, data, user_id, candidate_nodes, k=k, device=device
)

print(f"Top-{k} recommendations for user {user_id} (GraphSAGE):")
ranked_graphsage = zip(top_k_nodes_graphsage, top_k_scores_graphsage)
for rank, (node, score) in enumerate(ranked_graphsage, start=1):
    print(f"{rank}. User {node.item()}: {score.item():.4f}")


In [None]:
# Rank the same candidate nodes for user_id with GAT and keep the top k.
top_k_nodes_gat, top_k_scores_gat = get_top_k_recommendations(
    model_gat, data, user_id, candidate_nodes, k=k, device=device
)

print(f"Top-{k} recommendations for user {user_id} (GAT):")
ranked_gat = zip(top_k_nodes_gat, top_k_scores_gat)
for rank, (node, score) in enumerate(ranked_gat, start=1):
    print(f"{rank}. User {node.item()}: {score.item():.4f}")


## Summary

- **GraphSAGE**: Trained and evaluated
- **GAT**: Trained and evaluated
- Models saved to `data/checkpoints/`
- Top-K recommendations generated for an example user

Next steps:
1. Evaluate on test set (see `evaluation_and_ablation.ipynb`)
2. Compare with baselines
3. Train SEAL model (see `training_seal.ipynb`)
