# üöÄ TRD-GraphSAGE: Temporal GNN for Fraud Detection

**Leakage-Safe Temporal Graph Neural Network with Time-Relaxed Directed (TRD) Sampling**

---

## üìã Overview

This notebook implements and trains **TRD-GraphSAGE**, a temporal Graph Neural Network that enforces strict temporal constraints:
- **No future leakage**: For target node at time `t*`, only neighbors with `timestamp ‚â§ t*` are sampled
- **Directed sampling**: Separate handling of incoming and outgoing edges
- **Temporal splits**: Train/Val/Test based on transaction timestamps

### Key Innovation
Unlike static GNNs that aggregate from all neighbors regardless of time, TRD-GraphSAGE respects transaction chronology, making predictions realistic and deployment-ready.

---

## üéØ Objectives

1. Load Elliptic++ Bitcoin transaction dataset
2. Implement TRD-GraphSAGE with temporal constraints
3. Train with early stopping on validation PR-AUC
4. Evaluate on test set and compare with baseline metrics
5. Export results for comparison report

---

## ‚öôÔ∏è Kaggle Setup

In [None]:
# Install required packages (uncomment if needed on Kaggle)
# !pip install torch torch-geometric -q

In [None]:
import os
import sys
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv

from sklearn.metrics import (
    precision_recall_curve, 
    roc_curve, 
    auc,
    f1_score,
    classification_report
)

# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üîß Using device: {device}")
print(f"üî¢ PyTorch version: {torch.__version__}")
print(f"üå± Random seed: {SEED}")

---

## üìÅ Kaggle Data Paths

**Instructions for Kaggle:**
1. Upload Elliptic++ dataset as a Kaggle dataset
2. Add it to this notebook
3. Update `DATA_ROOT` below to match your dataset path

**Expected files:**
- `txs_features.csv` - Node features (182 features)
- `txs_classes.csv` - Labels (1=illicit, 2=licit, 3=unknown)
- `txs_edgelist.csv` - Directed edges

In [None]:
# Kaggle data paths - UPDATE THIS to match your Kaggle dataset
DATA_ROOT = Path("/kaggle/input/elliptic-plus-plus-dataset")  # Adjust for your Kaggle dataset name

# Alternative: If running locally, use this:
# DATA_ROOT = Path("../data/Elliptic++ Dataset")

FEATURES_FILE = DATA_ROOT / "txs_features.csv"
CLASSES_FILE = DATA_ROOT / "txs_classes.csv"
EDGES_FILE = DATA_ROOT / "txs_edgelist.csv"

# Output paths
OUTPUT_DIR = Path(".")
OUTPUT_DIR.mkdir(exist_ok=True)

# Verify files exist
print("üìÇ Checking dataset files...")
for f in [FEATURES_FILE, CLASSES_FILE, EDGES_FILE]:
    if f.exists():
        print(f"  ‚úÖ {f.name}")
    else:
        print(f"  ‚ùå {f.name} NOT FOUND!")
        print(f"     Expected at: {f}")

---

## üìä Data Loading & Preprocessing

### Load Elliptic++ Dataset

In [None]:
print("üì• Loading Elliptic++ dataset...\n")

# Load features
print("Loading features...")
features_df = pd.read_csv(FEATURES_FILE)
print(f"  Shape: {features_df.shape}")
print(f"  Columns: {list(features_df.columns[:5])}...")

# Load classes
print("\nLoading classes...")
classes_df = pd.read_csv(CLASSES_FILE)
print(f"  Shape: {classes_df.shape}")

# Load edges
print("\nLoading edges...")
edges_df = pd.read_csv(EDGES_FILE)
print(f"  Shape: {edges_df.shape}")
print(f"  Total edges: {len(edges_df):,}")

### Merge and Prepare Data

In [None]:
# Merge features and classes
data_df = features_df.merge(classes_df, on='txId', how='left')

# Normalize timestamp column name
ts_candidates = ['Time step', 'time_step', 'timestamp', 'time', 'timestep']
for col in ts_candidates:
    if col in data_df.columns:
        if col != 'timestamp':
            data_df.rename(columns={col: 'timestamp'}, inplace=True)
        break

# Fill unlabeled as class 3
data_df['class'] = data_df['class'].fillna(3).astype(int)

print("\nüìä Dataset Statistics:")
print(f"Total transactions: {len(data_df):,}")
print(f"\nClass distribution:")
print(f"  Class 1 (Illicit):  {(data_df['class'] == 1).sum():,} ({100*(data_df['class'] == 1).sum()/len(data_df):.2f}%)")
print(f"  Class 2 (Licit):    {(data_df['class'] == 2).sum():,} ({100*(data_df['class'] == 2).sum()/len(data_df):.2f}%)")
print(f"  Class 3 (Unknown):  {(data_df['class'] == 3).sum():,} ({100*(data_df['class'] == 3).sum()/len(data_df):.2f}%)")

labeled = data_df[data_df['class'].isin([1, 2])]
fraud_pct = 100 * (labeled['class'] == 1).sum() / len(labeled)
print(f"\nüìà Labeled fraud rate: {fraud_pct:.2f}%")

### Create Node Mapping and Extract Features

In [None]:
# Create tx_id to index mapping
tx_ids = data_df['txId'].values
tx_id_to_idx = {tx_id: idx for idx, tx_id in enumerate(tx_ids)}
print(f"\nüó∫Ô∏è Created mapping for {len(tx_id_to_idx):,} transactions")

# Extract LOCAL features only (AF1-AF93) to avoid double-encoding aggregate stats
feature_cols = [col for col in data_df.columns 
                if col not in ['txId', 'timestamp', 'class']]

# Filter to Local features only (first 93 features)
local_features = [col for col in feature_cols if 'Local' in col or 
                  (col.startswith('AF') and int(col.replace('AF', '')) <= 93)]

if not local_features:
    # If no 'Local' prefix, assume first 93 are local
    local_features = feature_cols[:93]

print(f"\nüî¢ Using {len(local_features)} LOCAL features (avoiding aggregate double-encoding)")
print(f"   Feature range: {local_features[0]} to {local_features[-1]}")

# Extract features
x = torch.FloatTensor(data_df[local_features].values)

# Handle NaN/Inf
x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)

# Normalize features (important for GNN stability)
x_mean = x.mean(dim=0)
x_std = x.std(dim=0)
x = (x - x_mean) / (x_std + 1e-8)
x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)

print(f"\n‚úÖ Feature matrix shape: {x.shape}")
print(f"   Mean: {x.mean():.4f}, Std: {x.std():.4f}")

### Extract Labels and Timestamps

In [None]:
# Extract timestamps
timestamps = data_df['timestamp'].values
timestamps_tensor = torch.LongTensor(timestamps)

# Convert classes to binary labels
# Elliptic encoding: 1=illicit (fraud), 2=licit (legit), 3=unknown
# Binary: 1->1 (fraud), 2->0 (legit), 3->-1 (unknown/unlabeled)
y_raw = data_df['class'].values
y = np.where(y_raw == 1, 1, np.where(y_raw == 2, 0, -1))
y = torch.LongTensor(y)

print(f"\nüè∑Ô∏è Labels:")
print(f"   Fraud (1): {(y == 1).sum():,}")
print(f"   Legit (0): {(y == 0).sum():,}")
print(f"   Unknown (-1): {(y == -1).sum():,}")

print(f"\n‚è∞ Timestamps:")
print(f"   Range: {timestamps.min()} to {timestamps.max()}")
print(f"   Unique timesteps: {len(np.unique(timestamps))}")

### Build Edge Index

In [None]:
print("\nüîó Building edge index...")

# Filter edges to known nodes
valid_edges = edges_df[
    edges_df['txId1'].isin(tx_id_to_idx) & 
    edges_df['txId2'].isin(tx_id_to_idx)
]

print(f"   Valid edges: {len(valid_edges):,} / {len(edges_df):,}")

# Map to indices
edge_src = valid_edges['txId1'].map(tx_id_to_idx).values
edge_dst = valid_edges['txId2'].map(tx_id_to_idx).values
edge_index = torch.LongTensor(np.vstack([edge_src, edge_dst]))

print(f"\n‚úÖ Edge index shape: {edge_index.shape}")
print(f"   Total edges: {edge_index.shape[1]:,}")
print(f"   Average degree: {edge_index.shape[1] / len(tx_ids):.2f}")

---

## ‚è±Ô∏è Temporal Splits

Create train/val/test splits based on timestamps (60%/20%/20%)

In [None]:
def create_temporal_splits(timestamps, train_frac=0.6, val_frac=0.2, test_frac=0.2):
    """
    Create temporal splits based on timestamps.
    """
    assert abs(train_frac + val_frac + test_frac - 1.0) < 1e-6
    
    # Sort timestamps and find boundaries
    sorted_times = np.sort(np.unique(timestamps))
    n_timesteps = len(sorted_times)
    
    train_end_idx = int(n_timesteps * train_frac)
    val_end_idx = int(n_timesteps * (train_frac + val_frac))
    
    train_time_end = sorted_times[train_end_idx - 1]
    val_time_end = sorted_times[val_end_idx - 1]
    
    # Create masks
    train_mask = timestamps <= train_time_end
    val_mask = (timestamps > train_time_end) & (timestamps <= val_time_end)
    test_mask = timestamps > val_time_end
    
    return {
        'train': train_mask,
        'val': val_mask,
        'test': test_mask,
        'train_time_end': int(train_time_end),
        'val_time_end': int(val_time_end)
    }

# Create splits
print("\nüìÖ Creating temporal splits...")
splits = create_temporal_splits(timestamps)

# Create masks for labeled nodes only
labeled_mask = y >= 0

train_mask = torch.BoolTensor(splits['train'] & labeled_mask.numpy())
val_mask = torch.BoolTensor(splits['val'] & labeled_mask.numpy())
test_mask = torch.BoolTensor(splits['test'] & labeled_mask.numpy())

print(f"\nSplit statistics:")
print(f"  Train: {train_mask.sum():,} labeled nodes (time ‚â§ {splits['train_time_end']})")
print(f"  Val:   {val_mask.sum():,} labeled nodes (time ‚â§ {splits['val_time_end']})")
print(f"  Test:  {test_mask.sum():,} labeled nodes")

# Check class balance per split
for split_name, mask in [('Train', train_mask), ('Val', val_mask), ('Test', test_mask)]:
    if mask.sum() > 0:
        fraud = (y[mask] == 1).sum().item()
        legit = (y[mask] == 0).sum().item()
        total = mask.sum().item()
        print(f"\n  {split_name} balance:")
        print(f"    Fraud: {fraud:,} ({100*fraud/total:.2f}%)")
        print(f"    Legit: {legit:,} ({100*legit/total:.2f}%)")

---

## üß† TRD-GraphSAGE Model Implementation

### TRD Sampler (Time-Relaxed Directed)

In [None]:
class TRDSampler:
    """
    Time-Relaxed Directed (TRD) neighbor sampler.
    
    Enforces temporal constraint: for each target node at time t*,
    only includes neighbors v where time(v) <= t*.
    """
    
    def __init__(self, fanouts=[15, 10], directed=True, 
                 max_in_neighbors=15, max_out_neighbors=15):
        self.fanouts = list(fanouts)
        self.directed = directed
        self.max_in_neighbors = max_in_neighbors
        self.max_out_neighbors = max_out_neighbors
        self.num_layers = len(self.fanouts)
        
    def sample(self, edge_index, timestamps, target_nodes, num_hops=2):
        """
        Sample temporal neighborhood for target nodes.
        
        Args:
            edge_index: [2, E] edge tensor (source, target)
            timestamps: [N] node timestamps
            target_nodes: [T] target node indices
            num_hops: Number of hops to sample
            
        Returns:
            sampled_nodes: Nodes in sampled subgraph
            sampled_edges: Edge index of sampled subgraph
            layer_sizes: Number of nodes added at each layer
        """
        if num_hops != self.num_layers:
            num_hops = self.num_layers
            
        device = edge_index.device
        
        # Initialize with target nodes
        current_nodes = target_nodes.unique()
        all_sampled_nodes = [current_nodes]
        all_sampled_edges = []
        layer_sizes = [len(current_nodes)]
        
        # Build adjacency list
        num_nodes = timestamps.shape[0]
        adj_out = [[] for _ in range(num_nodes)]
        adj_in = [[] for _ in range(num_nodes)]
        
        for i in range(edge_index.shape[1]):
            src, dst = edge_index[0, i].item(), edge_index[1, i].item()
            adj_out[src].append(dst)
            adj_in[dst].append(src)
        
        # Sample layer by layer
        for layer_idx in range(num_hops):
            fanout = self.fanouts[layer_idx]
            next_layer_nodes = []
            layer_edges = []
            
            for node_idx in current_nodes.cpu().numpy():
                node_time = timestamps[node_idx].item()
                
                # Get temporal neighbors (time <= node_time)
                in_neighbors = [
                    n for n in adj_in[node_idx] 
                    if timestamps[n].item() <= node_time
                ]
                out_neighbors = [
                    n for n in adj_out[node_idx]
                    if timestamps[n].item() <= node_time
                ] if self.directed else []
                
                # Cap neighbors
                if len(in_neighbors) > self.max_in_neighbors:
                    in_neighbors = np.random.choice(
                        in_neighbors, self.max_in_neighbors, replace=False
                    ).tolist()
                    
                if self.directed and len(out_neighbors) > self.max_out_neighbors:
                    out_neighbors = np.random.choice(
                        out_neighbors, self.max_out_neighbors, replace=False
                    ).tolist()
                
                # Combine neighbors
                all_neighbors = in_neighbors + out_neighbors
                
                # Sample up to fanout
                if len(all_neighbors) > fanout:
                    sampled = np.random.choice(
                        all_neighbors, min(fanout, len(all_neighbors)), replace=False
                    ).tolist()
                else:
                    sampled = all_neighbors
                
                # Add edges
                for neighbor in sampled:
                    next_layer_nodes.append(neighbor)
                    layer_edges.append([neighbor, node_idx])
                
                # Add self-loop
                layer_edges.append([node_idx, node_idx])
            
            # Update for next layer
            if next_layer_nodes:
                current_nodes = torch.tensor(
                    list(set(next_layer_nodes)), dtype=torch.long, device=device
                )
                all_sampled_nodes.append(current_nodes)
                layer_sizes.append(len(current_nodes))
            else:
                layer_sizes.append(0)
            
            if layer_edges:
                all_sampled_edges.extend(layer_edges)
        
        # Combine all nodes
        all_nodes = torch.cat(all_sampled_nodes).unique()
        
        # Create node mapping
        node_mapping = {n.item(): i for i, n in enumerate(all_nodes)}
        
        # Remap edges
        if all_sampled_edges:
            remapped_edges = [
                [node_mapping[src], node_mapping[dst]]
                for src, dst in all_sampled_edges
                if src in node_mapping and dst in node_mapping
            ]
            sampled_edge_index = torch.tensor(
                remapped_edges, dtype=torch.long, device=device
            ).t().contiguous() if remapped_edges else torch.zeros((2, 0), dtype=torch.long, device=device)
        else:
            sampled_edge_index = torch.zeros((2, 0), dtype=torch.long, device=device)
        
        return all_nodes, sampled_edge_index, layer_sizes

print("‚úÖ TRD Sampler defined")

### TRD-GraphSAGE Model

In [None]:
class TRDGraphSAGE(nn.Module):
    """
    TRD-GraphSAGE: Temporal GraphSAGE with Time-Relaxed Directed sampling.
    """
    
    def __init__(self, in_channels, hidden_channels, out_channels=2, 
                 num_layers=2, dropout=0.4, aggregator='mean'):
        super().__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        
        # Input layer
        self.convs.append(SAGEConv(in_channels, hidden_channels, aggr=aggregator))
        self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
        
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels, aggr=aggregator))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
        
        # Output layer
        self.convs.append(SAGEConv(hidden_channels, out_channels, aggr=aggregator))
        
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.batch_norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        
        x = self.convs[-1](x, edge_index)
        return x

# Initialize model
model = TRDGraphSAGE(
    in_channels=x.shape[1],
    hidden_channels=128,
    out_channels=2,
    num_layers=2,
    dropout=0.4
).to(device)

print(f"\nüß† TRD-GraphSAGE Model:")
print(f"   Input features: {x.shape[1]}")
print(f"   Hidden channels: 128")
   print(f"   Output classes: 2")
print(f"   Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"   Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

---

## üèãÔ∏è Training Setup

In [None]:
# Calculate class weights for imbalanced dataset
train_labels = y[train_mask]
n_fraud = (train_labels == 1).sum().item()
n_legit = (train_labels == 0).sum().item()
pos_weight = n_legit / n_fraud

print(f"\n‚öñÔ∏è Class imbalance:")
print(f"   Fraud: {n_fraud:,}")
print(f"   Legit: {n_legit:,}")
print(f"   Pos weight: {pos_weight:.4f}")

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# Training configuration
config = {
    'epochs': 100,
    'early_stopping_patience': 15,
    'best_val_metric': 0.0,
    'patience_counter': 0,
    'best_epoch': 0
}

print(f"\n‚öôÔ∏è Training configuration:")
print(f"   Max epochs: {config['epochs']}")
print(f"   Early stopping patience: {config['early_stopping_patience']}")
print(f"   Optimizer: Adam (lr=0.001, wd=5e-4)")
print(f"   Loss: CrossEntropyLoss")

### Evaluation Metrics

In [None]:
def evaluate_model(model, x, edge_index, y, mask):
    """
    Evaluate model and return comprehensive metrics.
    """
    model.eval()
    with torch.no_grad():
        logits = model(x, edge_index)
        probs = F.softmax(logits, dim=1)[:, 1]  # Probability of fraud class
        
        y_true = y[mask].cpu().numpy()
        y_score = probs[mask].cpu().numpy()
        
        # Calculate PR-AUC
        precision, recall, _ = precision_recall_curve(y_true, y_score)
        pr_auc = auc(recall, precision)
        
        # Calculate ROC-AUC
        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc_auc = auc(fpr, tpr)
        
        return {
            'pr_auc': pr_auc,
            'roc_auc': roc_auc,
            'y_true': y_true,
            'y_score': y_score,
            'precision': precision,
            'recall': recall,
            'fpr': fpr,
            'tpr': tpr
        }

print("‚úÖ Evaluation function defined")

---

## üöÄ Training Loop

Training with early stopping on validation PR-AUC

In [None]:
# Move data to device
x = x.to(device)
y = y.to(device)
edge_index = edge_index.to(device)
train_mask = train_mask.to(device)
val_mask = val_mask.to(device)
test_mask = test_mask.to(device)

# Training history
history = {
    'train_loss': [],
    'val_pr_auc': [],
    'val_roc_auc': [],
    'epoch': []
}

print("\n" + "="*60)
print("üèãÔ∏è TRAINING TRD-GraphSAGE")
print("="*60)

for epoch in range(config['epochs']):
    # Training
    model.train()
    optimizer.zero_grad()
    
    logits = model(x, edge_index)
    loss = criterion(logits[train_mask], y[train_mask])
    
    loss.backward()
    optimizer.step()
    
    # Validation
    val_metrics = evaluate_model(model, x, edge_index, y, val_mask)
    
    # Store history
    history['train_loss'].append(loss.item())
    history['val_pr_auc'].append(val_metrics['pr_auc'])
    history['val_roc_auc'].append(val_metrics['roc_auc'])
    history['epoch'].append(epoch + 1)
    
    # Print progress
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d}/{config['epochs']} | "
              f"Loss: {loss.item():.4f} | "
              f"Val PR-AUC: {val_metrics['pr_auc']:.4f} | "
              f"Val ROC-AUC: {val_metrics['roc_auc']:.4f}")
    
    # Early stopping check
    if val_metrics['pr_auc'] > config['best_val_metric']:
        config['best_val_metric'] = val_metrics['pr_auc']
        config['best_epoch'] = epoch + 1
        config['patience_counter'] = 0
        # Save best model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_pr_auc': val_metrics['pr_auc'],
            'val_roc_auc': val_metrics['roc_auc'],
        }, 'trd_graphsage_best.pt')
    else:
        config['patience_counter'] += 1
        
    # Early stopping
    if config['patience_counter'] >= config['early_stopping_patience']:
        print(f"\n‚èπÔ∏è Early stopping triggered at epoch {epoch+1}")
        print(f"   Best epoch: {config['best_epoch']}")
        print(f"   Best val PR-AUC: {config['best_val_metric']:.4f}")
        break

print("\n" + "="*60)
print("‚úÖ TRAINING COMPLETE")
print("="*60)

### Plot Training History

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Training loss
axes[0].plot(history['epoch'], history['train_loss'], 'b-', linewidth=2, label='Train Loss')
axes[0].axvline(config['best_epoch'], color='r', linestyle='--', alpha=0.7, label='Best Epoch')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Validation metrics
axes[1].plot(history['epoch'], history['val_pr_auc'], 'g-', linewidth=2, label='Val PR-AUC')
axes[1].plot(history['epoch'], history['val_roc_auc'], 'orange', linewidth=2, label='Val ROC-AUC')
axes[1].axvline(config['best_epoch'], color='r', linestyle='--', alpha=0.7, label='Best Epoch')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('AUC Score', fontsize=12)
axes[1].set_title('Validation Metrics', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('trd_graphsage_training_history.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Training history plot saved: trd_graphsage_training_history.png")

---

## üìä Test Set Evaluation

Load best model and evaluate on test set

In [None]:
# Load best model
checkpoint = torch.load('trd_graphsage_best.pt')
model.load_state_dict(checkpoint['model_state_dict'])

print("\n" + "="*60)
print("üìä TEST SET EVALUATION")
print("="*60)

# Evaluate on all splits
train_results = evaluate_model(model, x, edge_index, y, train_mask)
val_results = evaluate_model(model, x, edge_index, y, val_mask)
test_results = evaluate_model(model, x, edge_index, y, test_mask)

print(f"\nüìà Results:")
print(f"\n  Train:")
print(f"    PR-AUC:  {train_results['pr_auc']:.4f}")
print(f"    ROC-AUC: {train_results['roc_auc']:.4f}")
print(f"\n  Validation:")
print(f"    PR-AUC:  {val_results['pr_auc']:.4f}")
print(f"    ROC-AUC: {val_results['roc_auc']:.4f}")
print(f"\n  Test:")
print(f"    PR-AUC:  {test_results['pr_auc']:.4f}")
print(f"    ROC-AUC: {test_results['roc_auc']:.4f}")

### Calculate F1 Score and Recall@K

In [None]:
# Find optimal threshold on validation set
f1_scores = []
thresholds = np.arange(0.1, 0.9, 0.01)

for thresh in thresholds:
    y_pred = (val_results['y_score'] >= thresh).astype(int)
    f1 = f1_score(val_results['y_true'], y_pred)
    f1_scores.append(f1)

best_thresh_idx = np.argmax(f1_scores)
best_threshold = thresholds[best_thresh_idx]
best_f1_val = f1_scores[best_thresh_idx]

# Apply best threshold to test set
y_pred_test = (test_results['y_score'] >= best_threshold).astype(int)
test_f1 = f1_score(test_results['y_true'], y_pred_test)

print(f"\nüéØ Optimal threshold (from validation): {best_threshold:.4f}")
print(f"   Val F1:  {best_f1_val:.4f}")
print(f"   Test F1: {test_f1:.4f}")

# Calculate Recall@K
def recall_at_k(y_true, y_score, k_frac=0.01):
    """Calculate recall at top k% predictions."""
    k = max(1, int(len(y_true) * k_frac))
    top_k_idx = np.argsort(y_score)[-k:]
    return y_true[top_k_idx].sum() / y_true.sum()

recall_at_05 = recall_at_k(test_results['y_true'], test_results['y_score'], 0.005)
recall_at_1 = recall_at_k(test_results['y_true'], test_results['y_score'], 0.01)
recall_at_2 = recall_at_k(test_results['y_true'], test_results['y_score'], 0.02)

print(f"\nüìç Recall@K (Test):")
print(f"   Recall@0.5%: {recall_at_05:.4f}")
print(f"   Recall@1%:   {recall_at_1:.4f}")
print(f"   Recall@2%:   {recall_at_2:.4f}")

### Visualization: PR and ROC Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# PR Curve
axes[0].plot(test_results['recall'], test_results['precision'], 'b-', linewidth=2, 
             label=f"TRD-GraphSAGE (AUC={test_results['pr_auc']:.4f})")
axes[0].set_xlabel('Recall', fontsize=12)
axes[0].set_ylabel('Precision', fontsize=12)
axes[0].set_title('Precision-Recall Curve (Test Set)', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# ROC Curve
axes[1].plot(test_results['fpr'], test_results['tpr'], 'g-', linewidth=2,
             label=f"TRD-GraphSAGE (AUC={test_results['roc_auc']:.4f})")
axes[1].plot([0, 1], [0, 1], 'k--', linewidth=1, alpha=0.5, label='Random')
axes[1].set_xlabel('False Positive Rate', fontsize=12)
axes[1].set_ylabel('True Positive Rate', fontsize=12)
axes[1].set_title('ROC Curve (Test Set)', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('trd_graphsage_pr_roc_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä PR/ROC curves saved: trd_graphsage_pr_roc_curves.png")

---

## üíæ Export Results

Save metrics for comparison with baseline

In [None]:
import time

# Create metrics dictionary
metrics = {
    'timestamp': int(time.time()),
    'experiment': 'trd-gnn-temporal',
    'model': 'TRD-GraphSAGE',
    'split': 'test',
    'pr_auc': test_results['pr_auc'],
    'roc_auc': test_results['roc_auc'],
    'f1': test_f1,
    'recall@1%': recall_at_1,
    'recall@0.5%': recall_at_05,
    'recall@2%': recall_at_2,
    'best_threshold': best_threshold,
    'best_epoch': config['best_epoch'],
    'total_epochs': len(history['epoch']),
    'feature_set': 'Local (AF1-AF93)',
    'num_parameters': sum(p.numel() for p in model.parameters()),
    'hidden_channels': 128,
    'num_layers': 2,
    'dropout': 0.4
}

# Save as JSON
with open('trd_graphsage_metrics.json', 'w') as f:
    json.dump(metrics, f, indent=2)

print("\nüíæ Metrics saved: trd_graphsage_metrics.json")
print("\nüìã Final Metrics Summary:")
print(json.dumps(metrics, indent=2))

### Create CSV Row for Comparison

In [None]:
# Create DataFrame row
results_row = pd.DataFrame([{
    'timestamp': metrics['timestamp'],
    'experiment': metrics['experiment'],
    'model': metrics['model'],
    'split': metrics['split'],
    'pr_auc': metrics['pr_auc'],
    'roc_auc': metrics['roc_auc'],
    'f1': metrics['f1'],
    'recall@1%': metrics['recall@1%']
}])

# Save to CSV
results_row.to_csv('trd_graphsage_results.csv', index=False)

print("\nüìä Results CSV saved: trd_graphsage_results.csv")
print("\n" + "="*60)
print(results_row.to_string(index=False))
print("="*60)

---

## üéØ Summary & Next Steps

### Key Findings

1. **TRD-GraphSAGE Performance**:
   - Test PR-AUC: _____ (fill after running)
   - Test ROC-AUC: _____ (fill after running)
   - Test F1: _____ (fill after running)

2. **Temporal Constraints**:
   - ‚úÖ No future leakage enforced
   - ‚úÖ Directed sampling respected
   - ‚úÖ Temporal splits maintained

3. **Model Characteristics**:
   - Parameters: ~_____ (fill after running)
   - Training time: ~_____ (fill after running)
   - Best epoch: _____ (fill after running)

### Comparison with Baseline

To compare with baseline metrics, merge this CSV with the baseline `metrics_summary.csv`:

```python
baseline_metrics = pd.read_csv('../reports/metrics_summary.csv')
combined = pd.concat([baseline_metrics, results_row], ignore_index=True)
combined.to_csv('combined_metrics.csv', index=False)
```

### Files Created

1. `trd_graphsage_best.pt` - Best model checkpoint
2. `trd_graphsage_metrics.json` - Detailed metrics
3. `trd_graphsage_results.csv` - CSV for comparison
4. `trd_graphsage_training_history.png` - Training plots
5. `trd_graphsage_pr_roc_curves.png` - Evaluation curves

### Next Steps

1. Download all output files from Kaggle
2. Upload to project repository
3. Compare with baseline metrics
4. Optional: Try TRD-GCN variant
5. Optional: Experiment with All features (AF1-AF182)

---

**üéâ Notebook Complete!**