# M7 — GraphSAGE Feature Ablation (Kaggle GPU)

In [None]:
import os
import sys
import json
import time
import shutil
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import average_precision_score, precision_recall_curve, roc_auc_score, roc_curve, f1_score
from torch_geometric.nn import SAGEConv

# Kaggle mounts the code dataset read-only under /kaggle/input; a writable copy
# is placed under /kaggle/working and used as the working directory.
KAGGLE_CODE = Path('/kaggle/input/elliptic-gnn-code')
WORKDIR = Path('/kaggle/working/elliptic-gnn-baselines')

_on_kaggle = KAGGLE_CODE.exists()
if _on_kaggle and not WORKDIR.exists():
    # First run in this session: copy the code tree into the writable area.
    shutil.copytree(KAGGLE_CODE, WORKDIR, dirs_exist_ok=True)
if _on_kaggle:
    os.chdir(WORKDIR)

# Whatever the current directory is now (Kaggle copy or local checkout) is the
# project root; put it on sys.path once.
PROJECT_ROOT = Path.cwd()
_root_entry = str(PROJECT_ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

print(f"Project root: {PROJECT_ROOT}")


In [None]:
SEED = 42


def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU + CUDA) random generators.

    Also sets the cuDNN flags ``deterministic=True`` and ``benchmark=False``.

    Args:
        seed: Integer seed passed to every generator.
    """
    import random

    # Seeding order: stdlib random, NumPy, torch CPU, torch CUDA (all devices).
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(SEED)

# Use the GPU whenever one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Using device: {DEVICE}")


In [None]:
# Look for the dataset at the Kaggle mount first, then under <project>/data.
DATA_PATH = Path('/kaggle/input/elliptic-fraud-data/Elliptic++ Dataset')
if not DATA_PATH.exists():
    DATA_PATH = PROJECT_ROOT / 'data' / 'Elliptic++ Dataset'
    if not DATA_PATH.exists():
        raise FileNotFoundError('Elliptic++ dataset not found.')

print('Loading Elliptic++ dataset...')
# The three transaction-level tables: node features, node classes, edge list.
features_df, classes_df, edges_df = (
    pd.read_csv(DATA_PATH / csv_name)
    for csv_name in ('txs_features.csv', 'txs_classes.csv', 'txs_edgelist.csv')
)
print(f"✓ Features: {features_df.shape}")
print(f"✓ Classes: {classes_df.shape}")
print(f"✓ Edges: {edges_df.shape}")


In [None]:
# Column-name groups used to build the ablation subsets.
LOCAL_FEATURES = [f"Local_feature_{idx}" for idx in range(1, 94)]          # 93 columns
AGGREGATE_FEATURES = [f"Aggregate_feature_{idx}" for idx in range(1, 73)]  # 72 columns

_TX_LEVEL_COLUMNS = [
    'in_txs_degree', 'out_txs_degree', 'total_BTC', 'fees', 'size',
    'num_input_addresses', 'num_output_addresses',
]
_INPUT_BTC_COLUMNS = [
    'in_BTC_min', 'in_BTC_max', 'in_BTC_mean', 'in_BTC_median', 'in_BTC_total',
]
_OUTPUT_BTC_COLUMNS = [
    'out_BTC_min', 'out_BTC_max', 'out_BTC_mean', 'out_BTC_median', 'out_BTC_total',
]
STRUCTURAL_FEATURES = _TX_LEVEL_COLUMNS + _INPUT_BTC_COLUMNS + _OUTPUT_BTC_COLUMNS

# (config name, column list) pairs; ``None`` means "use every feature column".
FEATURE_CONFIGS = [
    ('full', None),
    ('local_only', LOCAL_FEATURES),
    ('aggregate_only', AGGREGATE_FEATURES),
    ('local_plus_structural', [*LOCAL_FEATURES, *STRUCTURAL_FEATURES]),
]

print('Configured feature subsets:')
for name, cols in FEATURE_CONFIGS:
    print(f"  - {name}: {'all' if cols is None else len(cols)} columns")


In [None]:
def prepare_data(feature_subset):
    """Build node features, labels, edges and temporal split masks for one config.

    Reads the module-level ``features_df``, ``classes_df`` and ``edges_df``
    frames; none of them is modified.

    Args:
        feature_subset: Iterable of column names to use as node features, or
            ``None`` to use every column except ``txId``, ``Time step`` and
            ``class``. Requested names that are absent from the dataset are
            skipped with a printed warning.

    Returns:
        Tuple ``(x, y, edge_index, train_mask, val_mask, test_mask)``:
        ``x`` is a float tensor of standardised features (one row per row of
        the merged frame); ``y`` is a ``LongTensor`` with 1 for class 1, 0 for
        class 2 and -1 for everything else; ``edge_index`` is a ``2 x E``
        ``LongTensor`` of row indices with one self-loop per node appended;
        the three masks are ``BoolTensor`` selections of labelled nodes.

    Raises:
        ValueError: If ``feature_subset`` is given but none of its columns
            exist in the dataset.
    """
    data_df = features_df.merge(classes_df, on='txId', how='left')
    # Rows without a class entry are treated like class 3 (mapped to -1 below).
    data_df['class'] = data_df['class'].fillna(3).astype(int)

    base_feature_cols = [c for c in data_df.columns if c not in ['txId', 'Time step', 'class']]
    if feature_subset is None:
        feature_cols = base_feature_cols
    else:
        available = set(base_feature_cols)
        feature_cols = [c for c in feature_subset if c in available]
        if not feature_cols:
            raise ValueError('Selected feature subset has no matching columns in dataset.')
        # Surface partial matches: otherwise a config would silently train on
        # fewer features than its name suggests.
        missing = [c for c in feature_subset if c not in available]
        if missing:
            preview = ', '.join(str(c) for c in missing[:5])
            suffix = ', ...' if len(missing) > 5 else ''
            print(f"    WARNING: {len(missing)} requested feature column(s) not in dataset, skipped: {preview}{suffix}")

    # Clean non-finite values, then z-score each column. Constant columns keep
    # std=1 so they become all zeros instead of dividing by ~0.
    # NOTE(review): mean/std are computed over every row, including rows that
    # later land in the validation/test time steps.
    feat_df = data_df[feature_cols].replace([np.inf, -np.inf], np.nan).fillna(0.0)
    x_np = feat_df.astype(np.float32).values
    mean = x_np.mean(axis=0)
    std = x_np.std(axis=0)
    std[std < 1e-6] = 1.0
    x_np = (x_np - mean) / std
    x = torch.from_numpy(x_np).float()
    x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)

    timestamps = data_df['Time step'].values
    y_raw = data_df['class'].values
    # class 1 -> positive (1), class 2 -> negative (0), anything else -> -1.
    y = np.where(y_raw == 1, 1, np.where(y_raw == 2, 0, -1))
    y = torch.LongTensor(y)

    # Map transaction ids to row positions so edges can reference node indices.
    tx_ids = data_df['txId'].values
    tx_id_to_idx = {tx_id: idx for idx, tx_id in enumerate(tx_ids)}

    # Keep only edges whose two endpoints are both present in the feature table.
    valid_edges = edges_df[
        edges_df['txId1'].isin(tx_id_to_idx) &
        edges_df['txId2'].isin(tx_id_to_idx)
    ]
    edge_src = valid_edges['txId1'].map(tx_id_to_idx).values
    edge_dst = valid_edges['txId2'].map(tx_id_to_idx).values
    edge_index = torch.LongTensor(np.vstack([edge_src, edge_dst]))

    # Append one self-loop (i -> i) per node.
    num_nodes = len(data_df)
    self_loop = torch.arange(num_nodes)
    edge_index = torch.cat([
        edge_index,
        torch.stack([self_loop, self_loop], dim=0)
    ], dim=1)

    # Temporal split over the distinct time steps: first 60% train, next 20%
    # validation, remainder test.
    sorted_times = np.sort(np.unique(timestamps))
    n_timesteps = len(sorted_times)
    train_end_idx = int(n_timesteps * 0.6)
    val_end_idx = int(n_timesteps * 0.8)
    train_time_end = sorted_times[train_end_idx - 1]
    val_time_end = sorted_times[val_end_idx - 1]

    # Only nodes with a 0/1 label take part in any split.
    labeled_mask = y >= 0
    train_mask = torch.BoolTensor((timestamps <= train_time_end) & labeled_mask.numpy())
    val_mask = torch.BoolTensor((timestamps > train_time_end) & (timestamps <= val_time_end) & labeled_mask.numpy())
    test_mask = torch.BoolTensor((timestamps > val_time_end) & labeled_mask.numpy())

    print(f"    Train {train_mask.sum():,} | Val {val_mask.sum():,} | Test {test_mask.sum():,}")
    return x, y, edge_index, train_mask, val_mask, test_mask


In [None]:
class GraphSAGE(nn.Module):
    """Stack of ``SAGEConv`` layers with ReLU and dropout between layers.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of every intermediate layer.
        out_channels: Size of the final layer's output (raw logits).
        num_layers: Number of ``SAGEConv`` layers; values of 1 or less give a
            single ``in_channels -> out_channels`` layer.
        dropout: Dropout probability applied after each non-final layer.
    """

    def __init__(self, in_channels, hidden_channels=128, out_channels=2, num_layers=2, dropout=0.4):
        super().__init__()
        self.dropout = dropout

        # Input layer, then (num_layers - 2) hidden layers, then the output layer.
        layers = [SAGEConv(in_channels, hidden_channels)]
        layers.extend(
            SAGEConv(hidden_channels, hidden_channels) for _ in range(num_layers - 2)
        )
        if num_layers <= 1:
            # Single-layer model: the only layer maps straight to the outputs.
            layers[0] = SAGEConv(in_channels, out_channels)
        else:
            layers.append(SAGEConv(hidden_channels, out_channels))

        self.convs = nn.ModuleList(layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution layer."""
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Return the output of the last layer (no softmax applied)."""
        *hidden_layers, output_layer = self.convs
        h = x
        for layer in hidden_layers:
            h = F.relu(layer(h, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
        return output_layer(h, edge_index)


def train_graphsage(x, y, edge_index, train_mask, val_mask):
    """Train a ``GraphSAGE`` classifier full-batch, early-stopping on validation PR-AUC.

    Uses Adam (lr=1e-3, weight_decay=5e-4), cross-entropy loss on the training
    nodes, gradient-norm clipping at 1.0, at most 100 epochs and a patience of
    15 epochs without improvement in validation average precision.

    Args:
        x: Node feature tensor.
        y: Node label tensor; only entries selected by the masks are used.
        edge_index: ``2 x E`` edge index tensor.
        train_mask: Boolean mask of nodes used for the loss.
        val_mask: Boolean mask of nodes used for model selection.

    Returns:
        The model (on ``DEVICE``) carrying the weights of the epoch with the
        highest validation average precision. If no epoch ever scored above
        0.0, the weights of the last epoch are kept.
    """
    x = x.to(DEVICE)
    y = y.to(DEVICE)
    edge_index = edge_index.to(DEVICE)
    train_mask = train_mask.to(DEVICE)
    val_mask = val_mask.to(DEVICE)

    model = GraphSAGE(in_channels=x.shape[1]).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    best_val = 0.0
    best_state = None
    patience = 15
    patience_counter = 0

    # Validation labels never change, so move them to the host once.
    val_labels = y[val_mask].cpu().numpy()

    for epoch in range(100):
        # --- one full-batch optimisation step ---
        model.train()
        optimizer.zero_grad()
        out = model(x, edge_index)
        loss = criterion(out[train_mask], y[train_mask])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # --- validation PR-AUC with the updated weights ---
        model.eval()
        with torch.no_grad():
            val_logits = model(x, edge_index)[val_mask]
            val_probs = F.softmax(val_logits, dim=1)[:, 1].cpu().numpy()
            val_pr = average_precision_score(val_labels, val_probs)

        if (epoch + 1) % 10 == 0:
            print(f"    Epoch {epoch+1:03d}: loss={loss.item():.4f} | val_PR={val_pr:.4f}")

        if val_pr > best_val:
            best_val = val_pr
            # state_dict() hands back references to the live tensors, which the
            # optimizer keeps updating in place. Clone them so this snapshot
            # really holds the best epoch's weights.
            best_state = {
                key: tensor.detach().clone()
                for key, tensor in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"    Early stopping at epoch {epoch+1}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return model


def evaluate_model(model, x, y, edge_index, val_mask, test_mask):
    """Score a trained model on the test nodes, tuning the F1 threshold on validation.

    Args:
        model: Trained module called as ``model(x, edge_index)`` and returning
            per-node logits with the positive class in column 1.
        x: Node feature tensor.
        y: Node label tensor.
        edge_index: ``2 x E`` edge index tensor.
        val_mask: Boolean mask of validation nodes (threshold selection only).
        test_mask: Boolean mask of test nodes (all reported metrics).

    Returns:
        Dict of floats with keys ``pr_auc``, ``roc_auc``, ``f1``,
        ``threshold``, ``recall@0.5%``, ``recall@1.0%`` and ``recall@2.0%``.
        The recall@k entries are 0.0 when the test nodes contain no positives.
    """
    model.eval()
    x = x.to(DEVICE)
    y = y.to(DEVICE)
    edge_index = edge_index.to(DEVICE)
    val_mask = val_mask.to(DEVICE)
    test_mask = test_mask.to(DEVICE)

    with torch.no_grad():
        out = model(x, edge_index)
        val_probs = F.softmax(out[val_mask], dim=1)[:, 1].cpu().numpy()
        val_labels = y[val_mask].cpu().numpy()
        test_probs = F.softmax(out[test_mask], dim=1)[:, 1].cpu().numpy()
        test_labels = y[test_mask].cpu().numpy()

    # precision/recall hold one more entry than thresholds (the final point
    # has no threshold), so drop it to keep F1 and thresholds index-aligned.
    precision, recall, thresholds = precision_recall_curve(val_labels, val_probs)
    precision, recall = precision[:-1], recall[:-1]
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    best_threshold = thresholds[np.argmax(f1_scores)] if len(thresholds) else 0.5

    pr_auc = average_precision_score(test_labels, test_probs)
    roc_auc = roc_auc_score(test_labels, test_probs)
    preds = (test_probs >= best_threshold).astype(int)
    f1 = f1_score(test_labels, preds)

    def recall_at_k(y_true, scores, frac):
        """Share of all positives found in the top ``frac`` highest-scored nodes."""
        total_positives = y_true.sum()
        if total_positives == 0:
            # Avoid 0/0 -> NaN, which json.dump would write as non-standard JSON.
            return 0.0
        k = max(1, int(len(y_true) * frac))
        idx = np.argsort(scores)[-k:]
        return y_true[idx].sum() / total_positives

    return {
        'pr_auc': float(pr_auc),
        'roc_auc': float(roc_auc),
        'f1': float(f1),
        'threshold': float(best_threshold),
        'recall@0.5%': float(recall_at_k(test_labels, test_probs, 0.005)),
        'recall@1.0%': float(recall_at_k(test_labels, test_probs, 0.01)),
        'recall@2.0%': float(recall_at_k(test_labels, test_probs, 0.02))
    }


In [None]:
# Output locations for per-config metrics, plots and model checkpoints.
REPORTS_DIR = PROJECT_ROOT / 'reports'
PLOTS_DIR = REPORTS_DIR / 'plots'
CHECKPOINT_DIR = PROJECT_ROOT / 'checkpoints'
for path in (REPORTS_DIR, PLOTS_DIR, CHECKPOINT_DIR):
    path.mkdir(parents=True, exist_ok=True)

results = []

for config_name, feature_cols in FEATURE_CONFIGS:
    print('\n' + '='*80)
    print(f"Running GraphSAGE for config: {config_name}")
    print('='*80)

    # Re-seed for every config so each run starts from the same RNG state;
    # otherwise a config's initial weights and dropout stream depend on which
    # configs ran before it, and the ablation is not order-independent.
    set_seed(SEED)

    x, y, edge_index, train_mask, val_mask, test_mask = prepare_data(feature_cols)
    model = train_graphsage(x, y, edge_index, train_mask, val_mask)
    metrics = evaluate_model(model, x, y, edge_index, val_mask, test_mask)

    # Persist this config's metrics and weights under a config-specific prefix.
    prefix = f"graphsage_{config_name}"
    metrics_path = REPORTS_DIR / f"{prefix}_metrics.json"
    checkpoint_path = CHECKPOINT_DIR / f"{prefix}_best.pt"

    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=2)
    torch.save(model.state_dict(), checkpoint_path)

    print(f"    Saved {metrics_path.name}, {checkpoint_path.name}")
    print(f"    PR-AUC={metrics['pr_auc']:.4f} | ROC-AUC={metrics['roc_auc']:.4f} | F1={metrics['f1']:.4f}")

    results.append({
        'config': config_name,
        'feature_count': x.shape[1],
        **metrics,
    })

# One row per feature configuration; the bare expression displays the table
# when this cell is run in a notebook.
results_df = pd.DataFrame(results)
results_df


In [None]:
# Write the per-config results table next to the individual metric files.
summary_path = REPORTS_DIR.joinpath('m7_graphsage_ablation_summary.csv')
results_df.to_csv(path_or_buf=summary_path, index=False)
print(f"Summary CSV saved to {summary_path}")
