# GraphSAGE & GAT - Kaggle GPU Training

Trains GraphSAGE and GAT node classifiers (fraud vs. legitimate transactions) on the Elliptic++ dataset, using a temporal train/validation/test split.

**Requirements:**
- GPU enabled (Settings → Accelerator → GPU T4 x2); the notebook trains on the first GPU only
- Elliptic dataset linked (Add Data → elliptic-fraud-detection)

**Expected Runtime:** ~25-30 minutes (both models)

## 1. Install Dependencies

In [None]:
# Install PyTorch Geometric, which provides the SAGEConv / GATConv layers imported below.
# `!` is IPython shell magic: this cell only runs inside a notebook, not as a plain .py file.
!pip install torch-geometric -q
print("✓ Dependencies installed")

## 2. Imports & Setup

In [None]:
import os
import json
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GATConv

from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    f1_score,
    precision_recall_curve,
    roc_curve
)

# Global seaborn theme for any plots drawn later in the notebook.
sns.set_style('whitegrid')

# Report the runtime so a missing/misconfigured GPU is visible before training starts.
print(f"✓ PyTorch version: {torch.__version__}")
_has_cuda = torch.cuda.is_available()
print(f"✓ CUDA available: {_has_cuda}")
if _has_cuda:
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")

## 3. Set Seed

In [None]:
def set_seed(seed=42):
    """Seed every RNG the notebook touches and pin cuDNN to deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU RNG and all
    CUDA device RNGs, then disables cuDNN autotuning (``benchmark``) and forces
    deterministic cuDNN algorithms.

    Note: ``torch.use_deterministic_algorithms`` is not enabled here, so
    non-cuDNN CUDA kernels (e.g. scatter-based message passing) may still
    vary run to run.

    Args:
        seed: integer seed shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Single seed for the whole notebook run.
SEED = 42
set_seed(SEED)
print(f"✓ Seed set to {SEED}")

## 4. Load Dataset (Same as GCN)

In [None]:
DATA_PATH = '/kaggle/input/elliptic-fraud-detection/'

print("Loading Elliptic++ dataset...")

# Transaction-level tables: node features, node labels, and the edge list (read in that order).
features_df, classes_df, edges_df = (
    pd.read_csv(os.path.join(DATA_PATH, _csv_name))
    for _csv_name in ('txs_features.csv', 'txs_classes.csv', 'txs_edgelist.csv')
)

for _label, _frame in (("Features", features_df), ("Classes", classes_df), ("Edges", edges_df)):
    print(f"✓ {_label}: {_frame.shape}")

In [None]:
# Merge and process
# Left-join labels onto features so every transaction keeps a row; rows with no
# label entry are filled with class 3, the code reported as "Unlabeled" below.
data_df = features_df.merge(classes_df, on='txId', how='left')
data_df['class'] = data_df['class'].fillna(3).astype(int)

print(f"Total: {len(data_df):,}")
print(f"Fraud: {(data_df['class'] == 1).sum():,}")
print(f"Legit: {(data_df['class'] == 2).sum():,}")
print(f"Unlabeled: {(data_df['class'] == 3).sum():,}")

# Features
# Every column except the id, the time step and the label is a model input.
feature_cols = [col for col in data_df.columns if col not in ['txId', 'Time step', 'class']]
# Map +/-inf to NaN, then NaN to 0, so the statistics below are finite.
feat_df = data_df[feature_cols].replace([np.inf, -np.inf], np.nan).fillna(0.0)
x_np = feat_df.astype(np.float32).values

# Normalize
# Per-column z-score; near-constant columns (std < 1e-6) keep std = 1 to avoid
# dividing by ~0.
# NOTE(review): mean/std are computed over ALL rows, including the validation and
# test time steps defined later, so test-period feature statistics leak into the
# normalization. Consider fitting them on training-period rows only.
mean = x_np.mean(axis=0)
std = x_np.std(axis=0)
std[std < 1e-6] = 1.0
x_np = (x_np - mean) / std

x = torch.from_numpy(x_np).float()
# Defensive: clear any non-finite values that survived the normalization.
x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)

# Labels
timestamps = data_df['Time step'].values
y_raw = data_df['class'].values
# Training targets: class 1 -> 1 (fraud), class 2 -> 0 (legit), anything else -> -1.
# The -1 rows are dropped later by the `y >= 0` labeled mask.
y = np.where(y_raw == 1, 1, np.where(y_raw == 2, 0, -1))
y = torch.LongTensor(y)

print(f"\n✓ Features: {x.shape}")
print(f"✓ Labels: {y.shape}")

In [None]:
# Build graph
# Node id = row position of the transaction in data_df (the same order as x / y).
tx_ids = data_df['txId'].values
tx_id_to_idx = dict(zip(tx_ids, range(len(tx_ids))))

# Keep only edges whose two endpoints both appear in the feature table.
both_known = edges_df['txId1'].isin(tx_id_to_idx) & edges_df['txId2'].isin(tx_id_to_idx)
valid_edges = edges_df.loc[both_known]

# Translate raw transaction ids to node ids -> [2, num_edges] index tensor (row 0 = source).
edge_src = valid_edges['txId1'].map(tx_id_to_idx).values
edge_dst = valid_edges['txId2'].map(tx_id_to_idx).values
edge_index = torch.LongTensor(np.vstack([edge_src, edge_dst]))

# Append one i -> i self-loop per node after the real edges.
# NOTE(review): PyG's GATConv also adds self-loops by default (add_self_loops=True);
# confirm the explicit loops here are intended mainly for SAGEConv.
num_nodes = len(data_df)
self_loop_src = np.arange(num_nodes)
self_loop_dst = np.arange(num_nodes)
loop_index = torch.LongTensor(np.vstack([self_loop_src, self_loop_dst]))
edge_index = torch.cat([edge_index, loop_index], dim=1)

print(f"✓ Edges (with self-loops): {edge_index.shape[1]:,}")

In [None]:
# Temporal splits
# Chronological split over the distinct time steps: the first 60% train, the next
# 20% validate, the remainder test. Unlabeled nodes (y == -1) are excluded from
# all three masks.
sorted_times = np.unique(timestamps)  # np.unique already returns the values sorted
n_timesteps = len(sorted_times)

train_end_idx = int(n_timesteps * 0.6)
val_end_idx = int(n_timesteps * 0.8)
train_time_end, val_time_end = sorted_times[train_end_idx - 1], sorted_times[val_end_idx - 1]

labeled_mask = y >= 0
labeled_np = labeled_mask.numpy()

train_mask = torch.from_numpy((timestamps <= train_time_end) & labeled_np)
val_mask = torch.from_numpy((timestamps > train_time_end) & (timestamps <= val_time_end) & labeled_np)
test_mask = torch.from_numpy((timestamps > val_time_end) & labeled_np)

for split_name, split_mask in (("Train: ", train_mask), ("Val:   ", val_mask), ("Test:  ", test_mask)):
    print(f"{split_name}{split_mask.sum():,}")

## 5. Define Models

In [None]:
# GraphSAGE Model
class GraphSAGE(nn.Module):
    """Stack of ``SAGEConv`` layers for full-graph node classification.

    Hidden layers are followed by ReLU and dropout; the final layer returns raw
    logits of width ``out_channels`` (no activation). With ``num_layers <= 1``
    the model is a single ``SAGEConv(in_channels, out_channels)``.

    Args:
        in_channels: input feature width.
        hidden_channels: width of every hidden layer.
        out_channels: number of output logits per node.
        num_layers: total number of convolution layers.
        dropout: feature-dropout probability applied between layers.
    """

    def __init__(self, in_channels, hidden_channels=128, out_channels=2, num_layers=2, dropout=0.4):
        super().__init__()
        self.dropout = dropout

        # Layers are constructed in the same order as a plain append loop would,
        # so parameter initialisation consumes the RNG identically.
        layers = [SAGEConv(in_channels, hidden_channels)]
        layers.extend(SAGEConv(hidden_channels, hidden_channels) for _ in range(num_layers - 2))
        if num_layers > 1:
            layers.append(SAGEConv(hidden_channels, out_channels))
        else:
            # Single-layer model: the input layer maps straight to the outputs.
            layers[0] = SAGEConv(in_channels, out_channels)
        self.convs = nn.ModuleList(layers)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution layer."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Return per-node logits for features ``x`` over graph ``edge_index``."""
        *hidden_layers, output_layer = self.convs
        h = x
        for layer in hidden_layers:
            h = F.dropout(F.relu(layer(h, edge_index)), p=self.dropout, training=self.training)
        return output_layer(h, edge_index)

# GAT Model
class GAT(nn.Module):
    """Stack of ``GATConv`` layers for full-graph node classification.

    Hidden layers concatenate their ``heads`` attention heads (output width
    ``hidden_channels * heads``) and are followed by ELU and dropout. The final
    layer uses ``concat=False`` (heads combined rather than concatenated) and
    returns raw logits of width ``out_channels``. ``dropout`` is passed to every
    GATConv as well as used for feature dropout between layers.

    Args:
        in_channels: input feature width.
        hidden_channels: per-head width of every hidden layer.
        out_channels: number of output logits per node.
        num_layers: total number of attention layers.
        heads: attention heads per layer.
        dropout: dropout probability (inside GATConv and between layers).
    """

    def __init__(self, in_channels, hidden_channels=128, out_channels=2, num_layers=2, heads=4, dropout=0.4):
        super().__init__()
        self.dropout = dropout

        # Hidden layers concatenate heads, so each one after the first sees this width.
        concat_dim = hidden_channels * heads

        # Built in the same order as sequential appends, so parameter
        # initialisation consumes the RNG identically.
        layers = [GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout, concat=True)]
        layers.extend(
            GATConv(concat_dim, hidden_channels, heads=heads, dropout=dropout, concat=True)
            for _ in range(num_layers - 2)
        )
        if num_layers > 1:
            layers.append(GATConv(concat_dim, out_channels, heads=heads, dropout=dropout, concat=False))
        else:
            # Single-layer model: the input layer maps straight to the outputs.
            layers[0] = GATConv(in_channels, out_channels, heads=heads, dropout=dropout, concat=False)
        self.convs = nn.ModuleList(layers)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters of every attention layer."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Return per-node logits for features ``x`` over graph ``edge_index``."""
        *hidden_layers, output_layer = self.convs
        h = x
        for layer in hidden_layers:
            h = F.dropout(F.elu(layer(h, edge_index)), p=self.dropout, training=self.training)
        return output_layer(h, edge_index)

print("✓ Models defined")

## 6. Train GraphSAGE

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Full-batch training: every tensor the models and metrics touch lives on one device.
x, y, edge_index = (t.to(device) for t in (x, y, edge_index))
train_mask, val_mask, test_mask = (m.to(device) for m in (train_mask, val_mask, test_mask))

print(f"✓ Data on {device}")

In [None]:
# Initialize GraphSAGE
model_sage = GraphSAGE(in_channels=x.shape[1], hidden_channels=128, out_channels=2, num_layers=2, dropout=0.4)
model_sage = model_sage.to(device)

print(f"✓ GraphSAGE params: {sum(p.numel() for p in model_sage.parameters()):,}")

optimizer_sage = torch.optim.Adam(model_sage.parameters(), lr=0.001, weight_decay=0.0005)
# Plain (unweighted) cross-entropy. Only rows selected by train_mask / val_mask are
# scored, and those masks contain labeled nodes only, so -1 targets never reach it.
criterion = nn.CrossEntropyLoss()

print("\nTraining GraphSAGE...")

# Early stopping: keep the weights with the best validation PR-AUC and stop after
# `patience` consecutive epochs without improvement.
best_val_pr_auc = 0
best_epoch = 0
patience = 15
patience_counter = 0
best_state_sage = None
history_sage = {'train_loss': [], 'val_loss': [], 'val_pr_auc': []}

for epoch in range(100):
    # Train: one full-batch step over the whole graph.
    model_sage.train()
    optimizer_sage.zero_grad()
    out = model_sage(x, edge_index)

    if torch.isnan(out).any():
        print(f"Epoch {epoch+1}: NaN, skipping")
        continue

    loss = criterion(out[train_mask], y[train_mask])
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model_sage.parameters(), 1.0)
    optimizer_sage.step()

    # Validate (dropout off, no autograd graph).
    model_sage.eval()
    with torch.no_grad():
        out = model_sage(x, edge_index)
        if torch.isnan(out).any():
            continue

        val_loss = criterion(out[val_mask], y[val_mask]).item()
        val_probs = F.softmax(out[val_mask], dim=1)[:, 1].cpu().numpy()
        val_labels = y[val_mask].cpu().numpy()

        if np.isnan(val_probs).any():
            continue

        val_pr_auc = average_precision_score(val_labels, val_probs)

    # The logged train loss was measured before this epoch's optimizer step, with dropout on.
    history_sage['train_loss'].append(loss.item())
    history_sage['val_loss'].append(val_loss)
    history_sage['val_pr_auc'].append(val_pr_auc)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:03d}: Train={loss.item():.4f}, Val={val_loss:.4f}, PR-AUC={val_pr_auc:.4f}")

    if val_pr_auc > best_val_pr_auc:
        best_val_pr_auc = val_pr_auc
        best_epoch = epoch
        patience_counter = 0
        # state_dict() tensors share storage with the live parameters, so a shallow
        # dict copy would keep changing as the optimizer updates the model in place.
        # Clone every tensor to freeze this epoch's weights.
        best_state_sage = {k: v.detach().clone() for k, v in model_sage.state_dict().items()}
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\nEarly stopping at epoch {epoch+1}")
        break

# best_state_sage is set exactly when some epoch improved on the initial PR-AUC of 0.
if best_state_sage is not None:
    model_sage.load_state_dict(best_state_sage)
    print(f"\n✓ GraphSAGE complete! Best Val PR-AUC: {best_val_pr_auc:.4f} at epoch {best_epoch+1}")
else:
    print("\nWARNING: No valid training")

## 7. Train GAT

In [None]:
# Initialize GAT
model_gat = GAT(in_channels=x.shape[1], hidden_channels=64, out_channels=2, num_layers=2, heads=4, dropout=0.4)
model_gat = model_gat.to(device)

print(f"✓ GAT params: {sum(p.numel() for p in model_gat.parameters()):,}")

optimizer_gat = torch.optim.Adam(model_gat.parameters(), lr=0.005, weight_decay=0.0005)

print("\nTraining GAT...")

# Early stopping mirrors the GraphSAGE cell.
# NOTE: `criterion` and `patience` are reused from the GraphSAGE training cell,
# which must therefore have been run first.
best_val_pr_auc_gat = 0
best_epoch_gat = 0
patience_counter_gat = 0
best_state_gat = None
history_gat = {'train_loss': [], 'val_loss': [], 'val_pr_auc': []}

for epoch in range(100):
    # Train: one full-batch step over the whole graph.
    model_gat.train()
    optimizer_gat.zero_grad()
    out = model_gat(x, edge_index)

    if torch.isnan(out).any():
        print(f"Epoch {epoch+1}: NaN, skipping")
        continue

    loss = criterion(out[train_mask], y[train_mask])
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model_gat.parameters(), 1.0)
    optimizer_gat.step()

    # Validate (dropout off, no autograd graph).
    model_gat.eval()
    with torch.no_grad():
        out = model_gat(x, edge_index)
        if torch.isnan(out).any():
            continue

        val_loss = criterion(out[val_mask], y[val_mask]).item()
        val_probs = F.softmax(out[val_mask], dim=1)[:, 1].cpu().numpy()
        val_labels = y[val_mask].cpu().numpy()

        if np.isnan(val_probs).any():
            continue

        val_pr_auc = average_precision_score(val_labels, val_probs)

    # The logged train loss was measured before this epoch's optimizer step, with dropout on.
    history_gat['train_loss'].append(loss.item())
    history_gat['val_loss'].append(val_loss)
    history_gat['val_pr_auc'].append(val_pr_auc)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:03d}: Train={loss.item():.4f}, Val={val_loss:.4f}, PR-AUC={val_pr_auc:.4f}")

    if val_pr_auc > best_val_pr_auc_gat:
        best_val_pr_auc_gat = val_pr_auc
        best_epoch_gat = epoch
        patience_counter_gat = 0
        # state_dict() tensors share storage with the live parameters, so a shallow
        # dict copy would keep changing as the optimizer updates the model in place.
        # Clone every tensor to freeze this epoch's weights.
        best_state_gat = {k: v.detach().clone() for k, v in model_gat.state_dict().items()}
    else:
        patience_counter_gat += 1

    if patience_counter_gat >= patience:
        print(f"\nEarly stopping at epoch {epoch+1}")
        break

# best_state_gat is set exactly when some epoch improved on the initial PR-AUC of 0.
if best_state_gat is not None:
    model_gat.load_state_dict(best_state_gat)
    print(f"\n✓ GAT complete! Best Val PR-AUC: {best_val_pr_auc_gat:.4f} at epoch {best_epoch_gat+1}")
else:
    print("\nWARNING: No valid training")

## 8. Evaluate Both Models

In [None]:
def evaluate_model(model, name):
    """Score a trained model on the test split, print a report and return the metrics.

    Runs one full-graph forward pass over the notebook-level ``x`` / ``edge_index``
    and reads the global ``y``, ``val_mask`` and ``test_mask``. The F1 decision
    threshold is the one that maximizes F1 on the *validation* split; it is then
    applied to the test scores, so test labels are not used to choose it.

    Args:
        model: trained node classifier returning 2-class logits per node.
        name: label used in the printed report header.

    Returns:
        dict with float values under 'pr_auc', 'roc_auc', 'f1', 'threshold',
        'recall@0.5%', 'recall@1.0%' and 'recall@2.0%'.
    """
    model.eval()
    with torch.no_grad():
        out = model(x, edge_index)
        test_probs = F.softmax(out[test_mask], dim=1)[:, 1].cpu().numpy()
        test_labels = y[test_mask].cpu().numpy()

        val_probs = F.softmax(out[val_mask], dim=1)[:, 1].cpu().numpy()
        val_labels = y[val_mask].cpu().numpy()

    # Threshold-free ranking metrics on the test split.
    test_pr_auc = average_precision_score(test_labels, test_probs)
    test_roc_auc = roc_auc_score(test_labels, test_probs)

    # Threshold: the validation cutoff with the highest F1.
    precision, recall, thresholds = precision_recall_curve(val_labels, val_probs)
    # precision/recall carry one extra final point (precision=1, recall=0) that has no
    # threshold; drop it so f1_scores[i] corresponds to thresholds[i].
    precision, recall = precision[:-1], recall[:-1]
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    best_threshold = thresholds[np.argmax(f1_scores)] if len(thresholds) > 0 else 0.5

    test_preds = (test_probs >= best_threshold).astype(int)
    test_f1 = f1_score(test_labels, test_preds)

    # Recall@K
    def recall_at_k(y_true, y_score, k_frac):
        """Fraction of all positives found among the top ``k_frac`` highest scores.

        Returns 0.0 when ``y_true`` has no positives (recall is undefined),
        instead of dividing by zero and emitting NaN into the metrics JSON.
        """
        n_pos = y_true.sum()
        if n_pos == 0:
            return 0.0
        k = max(1, int(len(y_true) * k_frac))
        top_k_idx = np.argsort(y_score)[-k:]
        return float(y_true[top_k_idx].sum() / n_pos)

    # Compute each Recall@K once and reuse it for both the report and the returned dict.
    recall_0_5 = recall_at_k(test_labels, test_probs, 0.005)
    recall_1_0 = recall_at_k(test_labels, test_probs, 0.01)
    recall_2_0 = recall_at_k(test_labels, test_probs, 0.02)

    print(f"\n{'='*60}")
    print(f"{name} TEST RESULTS")
    print(f"{'='*60}")
    print(f"PR-AUC:      {test_pr_auc:.4f}")
    print(f"ROC-AUC:     {test_roc_auc:.4f}")
    print(f"F1 Score:    {test_f1:.4f}")
    print(f"Threshold:   {best_threshold:.4f}")
    print(f"Recall@0.5%: {recall_0_5:.4f}")
    print(f"Recall@1.0%: {recall_1_0:.4f}")
    print(f"Recall@2.0%: {recall_2_0:.4f}")

    return {
        'pr_auc': float(test_pr_auc),
        'roc_auc': float(test_roc_auc),
        'f1': float(test_f1),
        'threshold': float(best_threshold),
        'recall@0.5%': recall_0_5,
        'recall@1.0%': recall_1_0,
        'recall@2.0%': recall_2_0,
    }

# Score both trained models on the same test split; the returned dicts are
# written to graphsage_metrics.json / gat_metrics.json when results are saved.
metrics_sage = evaluate_model(model_sage, "GraphSAGE")
metrics_gat = evaluate_model(model_gat, "GAT")

## 9. Save Results

In [None]:
# Save metrics: one pretty-printed JSON file per model.
for _path, _metrics in (('graphsage_metrics.json', metrics_sage), ('gat_metrics.json', metrics_gat)):
    with open(_path, 'w') as f:
        json.dump(_metrics, f, indent=2)

# Save models: the current state_dict of each model.
for _path, _model in (('graphsage_best.pt', model_sage), ('gat_best.pt', model_gat)):
    torch.save(_model.state_dict(), _path)

print("\n✓ All results saved!")
print("\nDownload:")
for _path in ('graphsage_metrics.json', 'gat_metrics.json', 'graphsage_best.pt', 'gat_best.pt'):
    print(f"  - {_path}")