# GCN Fraud Detection - Kaggle GPU Training

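This notebook trains a two-layer Graph Convolutional Network (GCN) on the Elliptic++ transaction graph to flag fraudulent (illicit) Bitcoin transactions. It uses a temporal train/validation/test split and early stopping on validation PR-AUC.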

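**Requirements:**
- GPU enabled (Settings → Accelerator → GPU T4 x2). The notebook uses a single GPU and falls back to CPU if none is available.
- Elliptic++ dataset linked (Add Data → `elliptic-fraud-detection`), providing `txs_features.csv`, `txs_classes.csv` and `txs_edgelist.csv`.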

**Expected Runtime:** ~10-15 minutes

## 1. Install Dependencies

In [None]:
!pip install torch-geometric -q
print("✓ Dependencies installed")

## 2. Imports & Setup

In [None]:
import os
import json
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    f1_score,
    precision_recall_curve,
    roc_curve
)

# Global seaborn theme applied to every figure below
sns.set_style('whitegrid')

# Environment report: confirms which PyTorch build is loaded and whether an accelerator is attached
cuda_ok = torch.cuda.is_available()
print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")

## 3. Set Seed for Reproducibility

In [None]:
def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices).

    Also switches cuDNN to deterministic mode and turns off its benchmark
    autotuner, trading some speed for run-to-run repeatability.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)
print("✓ Seed set to 42")

## 4. Load Dataset

In [None]:
# Dataset location on Kaggle (adjust if your dataset slug differs).
DATA_PATH = '/kaggle/input/elliptic-fraud-detection/'

print("Loading Elliptic++ dataset...")

# os.path.join works whether or not DATA_PATH ends with '/', so editing the
# path above cannot silently produce e.g. '.../elliptic-fraud-detectiontxs_features.csv'.
features_df = pd.read_csv(os.path.join(DATA_PATH, 'txs_features.csv'))
classes_df = pd.read_csv(os.path.join(DATA_PATH, 'txs_classes.csv'))
edges_df = pd.read_csv(os.path.join(DATA_PATH, 'txs_edgelist.csv'))

print(f"✓ Features: {features_df.shape}")
print(f"✓ Classes: {classes_df.shape}")
print(f"✓ Edges: {edges_df.shape}")

## 5. Process Data

In [None]:
# Attach labels to every transaction. Class codes: 1 = fraud, 2 = legitimate,
# 3 = unlabeled; transactions missing from the classes file count as unlabeled.
# validate='one_to_one' fails fast on duplicate txIds, which would otherwise
# silently duplicate nodes and break the txId -> row mapping used for edges.
data_df = features_df.merge(classes_df, on='txId', how='left', validate='one_to_one')
data_df['class'] = data_df['class'].fillna(3).astype(int)
assert data_df['class'].isin([1, 2, 3]).all(), "unexpected class codes in txs_classes.csv"

print(f"Total transactions: {len(data_df):,}")
print(f"Fraud (class=1): {(data_df['class'] == 1).sum():,}")
print(f"Legitimate (class=2): {(data_df['class'] == 2).sum():,}")
print(f"Unlabeled (class=3): {(data_df['class'] == 3).sum():,}")

# Node features: every column except the id, time step and label
feature_cols = [col for col in data_df.columns if col not in ['txId', 'Time step', 'class']]
x = torch.FloatTensor(data_df[feature_cols].values)

# Impute missing values with the column mean before standardising. A single NaN
# would otherwise turn that column's mean/std (and every node's value) into NaN.
nan_mask = torch.isnan(x)
if nan_mask.any():
    x = torch.where(nan_mask, torch.nanmean(x, dim=0, keepdim=True), x)
    x = torch.nan_to_num(x, nan=0.0)  # columns that were entirely NaN

# Standardise each feature column; the epsilon keeps constant columns at 0 instead of dividing by 0.
# NOTE(review): statistics are computed over all time steps (including val/test);
# consider train-only statistics if strict temporal isolation is required.
x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-8)

# Per-node time step, used for the temporal split
timestamps = data_df['Time step'].values

# Convert labels: 1=fraud→1, 2=legit→0, 3=unlabeled→-1
y_raw = data_df['class'].values
y = np.where(y_raw == 1, 1, np.where(y_raw == 2, 0, -1))
y = torch.LongTensor(y)

print(f"\n✓ Features shape: {x.shape}")
print(f"✓ Labels shape: {y.shape}")

## 6. Build Graph

In [None]:
# Map each transaction id to its row position in data_df (and therefore in x / y)
tx_ids = data_df['txId'].values
tx_id_to_idx = dict(zip(tx_ids, range(len(tx_ids))))

# Drop edges whose endpoints are not both known nodes
both_known = edges_df['txId1'].isin(tx_id_to_idx) & edges_df['txId2'].isin(tx_id_to_idx)
valid_edges = edges_df.loc[both_known]

# Translate endpoint ids to row indices and stack into a 2 x E edge index (row 0 = source, row 1 = target)
edge_src, edge_dst = (valid_edges[col].map(tx_id_to_idx).values for col in ('txId1', 'txId2'))
edge_index = torch.LongTensor(np.vstack([edge_src, edge_dst]))

print(f"✓ Edge index shape: {edge_index.shape}")
print(f"✓ Total edges: {edge_index.shape[1]:,}")

## 7. Create Temporal Splits

In [None]:
# Temporal split (60/20/20) over the distinct time steps: earliest steps train,
# the next block validates, and the latest steps are held out for testing.
sorted_times = np.unique(timestamps)  # np.unique already returns sorted values
n_timesteps = len(sorted_times)

# Clamp to >= 1 so a very short time range never indexes sorted_times[-1],
# which would silently put every time step into the training split.
train_end_idx = max(1, int(n_timesteps * 0.6))
val_end_idx = max(train_end_idx, int(n_timesteps * 0.8))

train_time_end = sorted_times[train_end_idx - 1]
val_time_end = sorted_times[val_end_idx - 1]

# Only labelled nodes (y >= 0) enter any split. .cpu() keeps this cell
# re-runnable after y has been moved to the GPU by the model cell.
labeled_mask = y >= 0
labeled_np = labeled_mask.cpu().numpy()
train_mask = torch.BoolTensor((timestamps <= train_time_end) & labeled_np)
val_mask = torch.BoolTensor((timestamps > train_time_end) & (timestamps <= val_time_end) & labeled_np)
test_mask = torch.BoolTensor((timestamps > val_time_end) & labeled_np)

# An empty split makes the loss / PR-AUC below meaningless; stop here instead.
assert train_mask.any() and val_mask.any() and test_mask.any(), "empty train/val/test split"

print(f"Train: {train_mask.sum():,} nodes (time <= {train_time_end})")
print(f"Val:   {val_mask.sum():,} nodes (time <= {val_time_end})")
print(f"Test:  {test_mask.sum():,} nodes")

# Class balance: fraud count and share per split
y_cpu = y.cpu()
train_fraud = (y_cpu[train_mask] == 1).sum().item()
val_fraud = (y_cpu[val_mask] == 1).sum().item()
test_fraud = (y_cpu[test_mask] == 1).sum().item()

print(f"\nTrain: {train_fraud:,} fraud ({100*train_fraud/train_mask.sum():.2f}%)")
print(f"Val:   {val_fraud:,} fraud ({100*val_fraud/val_mask.sum():.2f}%)")
print(f"Test:  {test_fraud:,} fraud ({100*test_fraud/test_mask.sum():.2f}%)")

## 8. Define GCN Model

In [None]:
class GCN(nn.Module):
    """Stack of GCNConv layers for node classification.

    Hidden layers apply ReLU followed by dropout. The final layer returns raw
    logits (no activation), which is what ``nn.CrossEntropyLoss`` expects.

    Args:
        in_channels: number of input features per node.
        hidden_channels: width of every hidden layer (unused when num_layers == 1).
        out_channels: number of output logits per node.
        num_layers: total number of GCNConv layers; must be >= 1.
        dropout: dropout probability applied after each hidden layer while training.

    Raises:
        ValueError: if ``num_layers`` is less than 1.
    """

    def __init__(self, in_channels, hidden_channels=128, out_channels=2, num_layers=2, dropout=0.4):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.dropout = dropout

        # Layer widths: in -> hidden -> ... -> hidden -> out
        # (a single layer maps in -> out directly).
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = nn.ModuleList(
            [GCNConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x, edge_index):
        """Return per-node logits (one row per node, ``out_channels`` columns)."""
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)

# Use the accelerator when one is present; everything below runs on this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 2-layer GCN sized to the feature matrix, producing 2 logits per node
model = GCN(
    in_channels=x.shape[1],
    hidden_channels=128,
    out_channels=2,
    num_layers=2,
    dropout=0.4,
).to(device)

# The whole graph (features, labels, edges, split masks) lives on the same device as the model.
x, y, edge_index = (t.to(device) for t in (x, y, edge_index))
train_mask, val_mask, test_mask = (m.to(device) for m in (train_mask, val_mask, test_mask))

n_params = sum(p.numel() for p in model.parameters())
print(f"✓ Model on device: {device}")
print(f"✓ Parameters: {n_params:,}")

## 9. Train Model

In [None]:
# Optimisation hyper-parameters
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0005
MAX_EPOCHS = 100
patience = 15  # epochs without a val PR-AUC improvement before stopping

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss()


def _snapshot(module):
    """Return a detached copy of ``module.state_dict()``.

    ``state_dict()`` returns references to the live parameter tensors, so keeping
    it without cloning would silently follow later (worse) epochs instead of the best one.
    """
    return {k: v.detach().clone() for k, v in module.state_dict().items()}


best_val_pr_auc = 0
best_epoch = 0
patience_counter = 0
# Start from the initial weights so load_state_dict below always has a valid
# state, even if validation PR-AUC never improves (e.g. it is NaN).
best_state = _snapshot(model)

history = {'train_loss': [], 'val_loss': [], 'val_pr_auc': []}

print("Starting training...\n")

for epoch in range(MAX_EPOCHS):
    # Train: one full-graph step, loss on training nodes only
    model.train()
    optimizer.zero_grad()
    out = model(x, edge_index)
    loss = criterion(out[train_mask], y[train_mask])
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    train_loss = loss.item()

    # Validate: dropout off, no gradients
    model.eval()
    with torch.no_grad():
        out = model(x, edge_index)
        val_loss = criterion(out[val_mask], y[val_mask]).item()
        val_probs = F.softmax(out[val_mask], dim=1)[:, 1].cpu().numpy()
        val_labels = y[val_mask].cpu().numpy()
        val_pr_auc = average_precision_score(val_labels, val_probs)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_pr_auc'].append(val_pr_auc)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:03d}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val PR-AUC={val_pr_auc:.4f}")

    # Early stopping on validation PR-AUC
    if val_pr_auc > best_val_pr_auc:
        best_val_pr_auc = val_pr_auc
        best_epoch = epoch
        patience_counter = 0
        best_state = _snapshot(model)
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\nEarly stopping at epoch {epoch+1}")
        print(f"Best Val PR-AUC: {best_val_pr_auc:.4f} at epoch {best_epoch+1}")
        break
else:
    # Loop ran to MAX_EPOCHS without triggering early stopping
    print(f"\nReached max epochs ({MAX_EPOCHS})")
    print(f"Best Val PR-AUC: {best_val_pr_auc:.4f} at epoch {best_epoch+1}")

# Restore the best-epoch weights (not the last epoch's)
model.load_state_dict(best_state)
print("\n✓ Training complete!")

## 10. Plot Training History

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
ax_loss, ax_auc = axes

# Left panel: per-epoch train/val loss; the red dashed line marks the best epoch
# (history index i corresponds to the 0-based epoch i).
for key, label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    ax_loss.plot(history[key], label=label, linewidth=2)
ax_loss.axvline(best_epoch, color='r', linestyle='--', alpha=0.5)
ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Training History')
ax_loss.legend()
ax_loss.grid(alpha=0.3)

# Right panel: validation PR-AUC with the best epoch (dashed) and best score (dotted)
ax_auc.plot(history['val_pr_auc'], color='green', linewidth=2)
ax_auc.axvline(best_epoch, color='r', linestyle='--', alpha=0.5)
ax_auc.axhline(best_val_pr_auc, color='g', linestyle=':', alpha=0.5)
ax_auc.set(xlabel='Epoch', ylabel='PR-AUC', title='Validation PR-AUC')
ax_auc.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('gcn_training_history.png', dpi=150, bbox_inches='tight')
plt.show()
print("✓ Saved: gcn_training_history.png")

## 11. Evaluate on Test Set

In [None]:
# Score every node once with the restored best-epoch weights.
model.eval()
with torch.no_grad():
    out = model(x, edge_index)
    fraud_probs = F.softmax(out, dim=1)[:, 1]  # probability of label 1 (fraud) per node
    test_probs = fraud_probs[test_mask].cpu().numpy()
    test_labels = y[test_mask].cpu().numpy()
    # Recompute validation scores here: the val_probs/val_labels left over from
    # the training loop come from the *last* epoch, not the best model restored above.
    val_probs = fraud_probs[val_mask].cpu().numpy()
    val_labels = y[val_mask].cpu().numpy()

# Threshold-free ranking metrics
test_pr_auc = average_precision_score(test_labels, test_probs)
test_roc_auc = roc_auc_score(test_labels, test_probs)

# Pick the decision threshold that maximises F1 on validation, then apply it to test.
precision_val, recall_val, thresholds = precision_recall_curve(val_labels, val_probs)
# precision/recall carry one extra final point (precision=1, recall=0) that has no
# matching threshold; drop it so the argmax index is always valid for `thresholds`.
prec_pts, rec_pts = precision_val[:-1], recall_val[:-1]
f1_scores = 2 * (prec_pts * rec_pts) / (prec_pts + rec_pts + 1e-8)
best_threshold = thresholds[np.argmax(f1_scores)]

test_preds = (test_probs >= best_threshold).astype(int)
test_f1 = f1_score(test_labels, test_preds)


def recall_at_k(y_true, y_score, k_frac=0.01):
    """Fraction of all positives captured among the top ``k_frac`` highest-scored items.

    Args:
        y_true: binary labels (1 = positive).
        y_score: scores, where higher means more likely positive.
        k_frac: fraction of items to inspect; at least one item is always inspected.

    Returns:
        Recall within the top-k items, or NaN when ``y_true`` contains no positives.
    """
    y_true = np.asarray(y_true)
    n_pos = y_true.sum()
    if n_pos == 0:
        return float('nan')  # recall is undefined without positives
    k = max(1, int(len(y_true) * k_frac))
    top_k_idx = np.argsort(y_score)[-k:]
    return y_true[top_k_idx].sum() / n_pos


recall_05 = recall_at_k(test_labels, test_probs, 0.005)
recall_10 = recall_at_k(test_labels, test_probs, 0.01)
recall_20 = recall_at_k(test_labels, test_probs, 0.02)

print("\n" + "="*60)
print("TEST SET RESULTS")
print("="*60)
print(f"PR-AUC:      {test_pr_auc:.4f} ⭐ (primary)")
print(f"ROC-AUC:     {test_roc_auc:.4f}")
print(f"F1 Score:    {test_f1:.4f}")
print(f"Threshold:   {best_threshold:.4f}")
print(f"\nRecall@0.5%: {recall_05:.4f}")
print(f"Recall@1.0%: {recall_10:.4f}")
print(f"Recall@2.0%: {recall_20:.4f}")
print("="*60)

## 12. Plot PR and ROC Curves

In [None]:
precision_test, recall_test, _ = precision_recall_curve(test_labels, test_probs)
fpr, tpr, _ = roc_curve(test_labels, test_probs)


def _style_curve_axes(ax, xlabel, ylabel, title):
    """Apply the shared labels, legend and grid styling used by both curve panels."""
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)


fig, axes = plt.subplots(1, 2, figsize=(15, 6))
ax_pr, ax_roc = axes

# PR curve; the baseline is the test-set positive rate (precision of a random ranker)
ax_pr.plot(recall_test, precision_test, linewidth=2.5, label=f'GCN (PR-AUC={test_pr_auc:.4f})')
ax_pr.axhline(test_labels.mean(), color='red', linestyle='--', linewidth=1.5, alpha=0.7, label='Baseline')
_style_curve_axes(ax_pr, 'Recall', 'Precision', 'Precision-Recall Curve (Test Set)')

# ROC curve; the diagonal is the chance-level baseline
ax_roc.plot(fpr, tpr, linewidth=2.5, label=f'GCN (ROC-AUC={test_roc_auc:.4f})')
ax_roc.plot([0, 1], [0, 1], color='red', linestyle='--', linewidth=1.5, alpha=0.7, label='Baseline')
_style_curve_axes(ax_roc, 'False Positive Rate', 'True Positive Rate', 'ROC Curve (Test Set)')

plt.tight_layout()
plt.savefig('gcn_pr_roc_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print("✓ Saved: gcn_pr_roc_curves.png")

## 13. Save Results

In [None]:
# Test-set metrics plus the chosen threshold and early-stopping summary
metrics = {
    'pr_auc': float(test_pr_auc),
    'roc_auc': float(test_roc_auc),
    'f1': float(test_f1),
    'threshold': float(best_threshold),
    'recall@0.5%': float(recall_05),
    'recall@1.0%': float(recall_10),
    'recall@2.0%': float(recall_20),
    'best_epoch': int(best_epoch + 1),  # 1-based, matching the training log
    'best_val_pr_auc': float(best_val_pr_auc)
}

with open('gcn_metrics.json', 'w') as f:
    json.dump(metrics, f, indent=2)

print("✓ Saved: gcn_metrics.json")

# Read the architecture back from the trained model instead of re-typing the
# hyper-parameters, so GCN(**config) rebuilds a network that load_state_dict accepts.
model_config = {
    'in_channels': model.convs[0].in_channels,
    'hidden_channels': model.convs[0].out_channels,  # equals out_channels when there is a single layer
    'out_channels': model.convs[-1].out_channels,
    'num_layers': len(model.convs),
    'dropout': model.dropout,
}

torch.save({
    'model_state_dict': model.state_dict(),
    'metrics': metrics,
    'config': model_config
}, 'gcn_best.pt')

print("✓ Saved: gcn_best.pt")
print("\n" + "="*60)
print("ALL RESULTS SAVED!")
print("="*60)
print("\nDownload these files:")
for saved_name in ('gcn_metrics.json', 'gcn_training_history.png', 'gcn_pr_roc_curves.png', 'gcn_best.pt'):
    print(f"  - {saved_name}")