# GCN Baseline for Fraud Detection

This notebook trains a Graph Convolutional Network (GCN) on the Elliptic++ dataset for Bitcoin transaction fraud detection.

**Goal:** Establish a reproducible GCN baseline with temporal splits and honest evaluation.

---

## Notebook TODO
- [x] Load real Elliptic++ from `data/elliptic/`
- [x] Set seeds + deterministic flags
- [x] Train GCN model end-to-end
- [x] Save: `reports/gcn_metrics.json`, `reports/plots/*.png`, `checkpoints/gcn_best.pt`; append `reports/metrics_summary.csv`
- [x] Verify metrics + artifacts paths printed in last cell
- [x] Clear TODOs/placeholders before commit

## 1. Setup & Imports

In [None]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Add src to path
sys.path.append('..')

from src.data import EllipticDataset
from src.models.gcn import GCN, GCNTrainer
from src.utils.seed import set_all_seeds
from src.utils.metrics import (
    compute_metrics,
    find_best_f1_threshold,
    compute_recall_at_k
)
from src.utils.logger import save_metrics_json, append_metrics_to_csv

# Plot defaults applied to every figure created later in the notebook.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)

# Environment banner: confirms the imports resolved and records the PyTorch
# build / CUDA availability next to the results for reproducibility.
print("✅ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Set Reproducibility

In [None]:
# Single seed for the whole run; it is also stored in the saved checkpoint.
SEED = 42

# Project helper from src.utils.seed.
# NOTE(review): presumably seeds Python/NumPy/PyTorch and sets the
# deterministic flags -- confirm against the helper's implementation.
set_all_seeds(SEED)
print(f"✅ All seeds set to {SEED}")
# NOTE(review): this message is not verified here; it assumes set_all_seeds
# actually enables deterministic operations.
print("✅ Deterministic operations enabled")

## 3. Load Dataset

In [None]:
# Load Elliptic++ dataset (path is relative to the notebook's directory).
dataset = EllipticDataset(root='../data/elliptic')
data = dataset.load(verbose=True)

# Summary of the loaded graph: sizes are read from the feature matrix,
# the edge index, and the three split masks.
print("\n📊 Dataset Summary:")
print(f"   Nodes: {data.x.shape[0]:,}")
print(f"   Edges: {data.edge_index.shape[1]:,}")
print(f"   Features: {data.x.shape[1]}")
# int(...) converts the mask sum to a plain Python int so the ',' thousands
# separator is applied to an int rather than relying on the array/tensor
# type implementing that format spec.
print(f"   Train nodes: {int(data.train_mask.sum()):,}")
print(f"   Val nodes: {int(data.val_mask.sum()):,}")
print(f"   Test nodes: {int(data.test_mask.sum()):,}")

## 4. Initialize Model

In [None]:
# Device: use the GPU when one is visible, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Model configuration. Kept as a dict because it is unpacked into GCN(...)
# and later stored verbatim in the checkpoint.
config = {
    'in_channels': data.x.shape[1],  # one input channel per node feature
    'hidden_channels': 128,
    'out_channels': 2,
    'num_layers': 2,
    'dropout': 0.4
}

# Initialize model
model = GCN(**config)
print("\n✅ GCN Model initialized")
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"   Architecture: {config}")

## 5. Train Model

In [None]:
# Initialize trainer with the model, the loaded graph, and the optimiser
# hyper-parameters (learning rate and weight decay).
trainer = GCNTrainer(
    model=model,
    data=data,
    device=device,
    lr=0.001,
    weight_decay=0.0005
)

print("🚀 Starting training...\n")

# Train with early stopping: at most 100 epochs, patience of 15, with
# validation PR-AUC as the monitored metric.
history = trainer.fit(
    epochs=100,
    patience=15,
    eval_metric='pr_auc',
    verbose=True
)

print("\n✅ Training complete!")
print(f"   Best validation PR-AUC: {trainer.best_val_metric:.4f}")
# NOTE(review): +1 assumes trainer.best_epoch is 0-indexed -- confirm in GCNTrainer.
print(f"   Best epoch: {trainer.best_epoch + 1}")

## 6. Plot Training History

In [None]:
# Output location for the figure. The directory is created up front because
# plt.savefig does not create missing parent directories and would raise
# FileNotFoundError on a fresh checkout.
history_plot_path = Path('../reports/plots/gcn_training_history.png')
history_plot_path.parent.mkdir(parents=True, exist_ok=True)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot: train vs. validation loss per epoch, with the best epoch marked.
axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0].axvline(trainer.best_epoch, color='r', linestyle='--', alpha=0.5, label='Best Epoch')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Metric plot: validation metric per epoch; the horizontal line marks the
# best value reached and the vertical line the epoch it was reached at.
axes[1].plot(history['val_metric'], label='Val PR-AUC', linewidth=2, color='green')
axes[1].axvline(trainer.best_epoch, color='r', linestyle='--', alpha=0.5, label='Best Epoch')
axes[1].axhline(trainer.best_val_metric, color='g', linestyle=':', alpha=0.5, label=f'Best: {trainer.best_val_metric:.4f}')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('PR-AUC')
axes[1].set_title('Validation PR-AUC Over Time')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(history_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print("✅ Training history plot saved")

## 7. Evaluate on Test Set

In [None]:
from sklearn.metrics import (
    precision_recall_curve,
    roc_curve,
    average_precision_score,
    roc_auc_score
)

# Get test predictions. Column 1 of the probability output is used as the
# fraud score. .detach() is a no-op for tensors without grad history and
# keeps .numpy() from failing if evaluate() returns grad-tracking tensors.
test_loss, test_preds, test_probs = trainer.evaluate(data.test_mask)
test_labels = data.y[data.test_mask].cpu().numpy()
test_probs_fraud = test_probs[:, 1].detach().cpu().numpy()

print("📊 Test Set Evaluation:")
print(f"   Test Loss: {test_loss:.4f}")

# Find best threshold on validation set. The decision threshold is tuned on
# validation data only, so the test F1 reported later is not tuned on test labels.
val_loss, val_preds, val_probs = trainer.evaluate(data.val_mask)
val_labels = data.y[data.val_mask].cpu().numpy()
val_probs_fraud = val_probs[:, 1].detach().cpu().numpy()

best_threshold, best_f1_val = find_best_f1_threshold(val_labels, val_probs_fraud)
print(f"\n   Best threshold (from val): {best_threshold:.4f}")
print(f"   Val F1 at best threshold: {best_f1_val:.4f}")

## 8. Compute Metrics

In [None]:
# Compute comprehensive metrics on test set, applying the threshold that was
# selected on the validation set.
test_metrics = compute_metrics(test_labels, test_probs_fraud, threshold=best_threshold)
recall_at_k = compute_recall_at_k(test_labels, test_probs_fraud, k_fracs=[0.005, 0.01, 0.02])

# Combine metrics into a single dict; this is what gets written to disk later.
test_metrics.update(recall_at_k)

# NOTE(review): the 'recall@X%' keys below must match the names produced by
# compute_recall_at_k for k_fracs=[0.005, 0.01, 0.02] -- confirm in
# src/utils/metrics.py if the fractions are ever changed.
print("\n📈 Test Set Metrics:")
print(f"   PR-AUC:      {test_metrics['pr_auc']:.4f}")
print(f"   ROC-AUC:     {test_metrics['roc_auc']:.4f}")
print(f"   F1 Score:    {test_metrics['f1']:.4f}")
print(f"   Threshold:   {test_metrics['threshold']:.4f}")
print(f"\n   Recall@0.5%: {test_metrics['recall@0.5%']:.4f}")
print(f"   Recall@1.0%: {test_metrics['recall@1.0%']:.4f}")
print(f"   Recall@2.0%: {test_metrics['recall@2.0%']:.4f}")

## 9. Plot PR and ROC Curves

In [None]:
# Output location for the figure. The directory is created up front because
# plt.savefig does not create missing parent directories and would raise
# FileNotFoundError on a fresh checkout.
curves_plot_path = Path('../reports/plots/gcn_pr_roc_curves.png')
curves_plot_path.parent.mkdir(parents=True, exist_ok=True)

# Compute curves from the test labels and the fraud-class scores.
precision, recall, pr_thresholds = precision_recall_curve(test_labels, test_probs_fraud)
fpr, tpr, roc_thresholds = roc_curve(test_labels, test_probs_fraud)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# PR Curve. The horizontal reference line is drawn at the positive-class
# prevalence (mean of the test labels).
axes[0].plot(recall, precision, linewidth=2.5, label=f'GCN (PR-AUC={test_metrics["pr_auc"]:.4f})')
axes[0].axhline(test_labels.mean(), color='red', linestyle='--', linewidth=1.5, alpha=0.7, label='Baseline (random)')
axes[0].set_xlabel('Recall', fontsize=12)
axes[0].set_ylabel('Precision', fontsize=12)
axes[0].set_title('Precision-Recall Curve (Test Set)', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# ROC Curve. The diagonal is the chance-level reference.
axes[1].plot(fpr, tpr, linewidth=2.5, label=f'GCN (ROC-AUC={test_metrics["roc_auc"]:.4f})')
axes[1].plot([0, 1], [0, 1], color='red', linestyle='--', linewidth=1.5, alpha=0.7, label='Baseline (random)')
axes[1].set_xlabel('False Positive Rate', fontsize=12)
axes[1].set_ylabel('True Positive Rate', fontsize=12)
axes[1].set_title('ROC Curve (Test Set)', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(curves_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print("✅ PR and ROC curves saved")

## 10. Save Artifacts

In [None]:
# Save model checkpoint: weights plus everything needed to rebuild and
# identify the run (architecture config, test metrics, best epoch, seed).
# NOTE(review): model.state_dict() captures the model's *current* weights.
# Confirm that GCNTrainer.fit restores the best-epoch weights on exit;
# otherwise 'gcn_best.pt' holds the last-epoch model.
checkpoint_path = Path('../checkpoints/gcn_best.pt')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    'config': config,
    'metrics': test_metrics,
    'best_epoch': trainer.best_epoch,
    'seed': SEED
}, checkpoint_path)
print(f"✅ Model checkpoint saved: {checkpoint_path}")

# Save metrics JSON. The reports directory is created explicitly, mirroring
# the checkpoint directory above, so this cell does not depend on an earlier
# cell (or the helper) having created it.
metrics_json_path = Path('../reports/gcn_metrics.json')
metrics_json_path.parent.mkdir(parents=True, exist_ok=True)
save_metrics_json(test_metrics, metrics_json_path)
print(f"✅ Metrics JSON saved: {metrics_json_path}")

# Append to summary CSV (shares the reports directory created above).
summary_csv_path = Path('../reports/metrics_summary.csv')
append_metrics_to_csv(
    metrics=test_metrics,
    filepath=summary_csv_path,
    experiment_name='elliptic-gnn-baselines',
    model_name='GCN',
    split='test'
)
print(f"✅ Results appended to: {summary_csv_path}")

## 11. Final Summary

In [None]:
# Final report: restates the dataset size, model configuration, test metrics
# and the artifact paths written by the earlier cells.
print("="*60)
print("GCN BASELINE - FINAL SUMMARY")
print("="*60)
print("\n📊 Dataset:")
print(f"   Total nodes: {data.x.shape[0]:,}")
print(f"   Total edges: {data.edge_index.shape[1]:,}")
print(f"   Features: {data.x.shape[1]}")
print("\n🎯 Model: GCN")
print(f"   Hidden channels: {config['hidden_channels']}")
print(f"   Num layers: {config['num_layers']}")
print(f"   Dropout: {config['dropout']}")
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
print("\n📈 Test Results:")
print(f"   PR-AUC:      {test_metrics['pr_auc']:.4f} ⭐ (primary metric)")
print(f"   ROC-AUC:     {test_metrics['roc_auc']:.4f}")
print(f"   F1 Score:    {test_metrics['f1']:.4f}")
print(f"   Recall@0.5%: {test_metrics['recall@0.5%']:.4f}")
print(f"   Recall@1.0%: {test_metrics['recall@1.0%']:.4f}")
print(f"   Recall@2.0%: {test_metrics['recall@2.0%']:.4f}")
print("\n📁 Artifacts Saved:")
print(f"   ✅ {checkpoint_path}")
print(f"   ✅ {metrics_json_path}")
print(f"   ✅ {summary_csv_path}")
# The plot paths are repeated literally here; keep them in sync with the
# plt.savefig calls in the plotting cells.
print("   ✅ ../reports/plots/gcn_training_history.png")
print("   ✅ ../reports/plots/gcn_pr_roc_curves.png")
print("\n" + "="*60)
print("✅ GCN BASELINE COMPLETE!")
print("="*60)