In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score,
    precision_score, recall_score, f1_score,
    roc_auc_score, brier_score_loss, classification_report
)
from sklearn.calibration import calibration_curve

import torch
from pytorch_tabnet.tab_model import TabNetClassifier

# Reproducibility: seed PyTorch's default generator and NumPy's global RNG
torch.manual_seed(42)
np.random.seed(42)

# Load data
df = pd.read_parquet('features_full.parquet')

# Every column except the target ('choice') and the two grouping columns
# ('OD', 'Obs_ID') is used as a model feature; column order is preserved.
feature_cols = [
    col for col in df.columns
    if col not in {'choice', 'OD', 'Obs_ID'}
]
groups = df['OD'].values  # use df['Obs_ID'] instead to group folds by Obs_ID
y = df['choice'].values
X = df[feature_cols].values


In [None]:
# CV setup: group-aware K-fold so that rows sharing a group never straddle
# the train/validation boundary.
n_splits = 5
gkf = GroupKFold(n_splits=n_splits)

# One list per metric; the CV loop appends one value per fold.
# Key order is kept because it determines the column order of the summary table.
metrics = {
    name: []
    for name in (
        'accuracy', 'balanced_acc',
        'precision', 'recall', 'f1',
        'roc_auc', 'brier',
        'baseline_acc', 'baseline_bal_acc',
        'group_acc', 'baseline_group_acc',
    )
}

# Prepare calibration plot: open the figure and draw the diagonal reference
# line; per-fold curves are added to this same (current) figure later.
plt.figure(figsize=(6, 6))
plt.plot((0, 1), (0, 1), 'k--', label='Perfect calibration')

# ------------------------------------------------------------------------------
# 3) CV loop
# ------------------------------------------------------------------------------
# For every outer GroupKFold fold:
#   1. fit TabNet on the training groups (early stopping on an inner split),
#   2. score the held-out fold per row and per choice set (Obs_ID),
#   3. compare against the rank_TT == 1 baseline,
#   4. add the fold's reliability curve to the current figure.
for fold, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups), start=1):
    print(f"\n=== Fold {fold}/{n_splits} ===")
    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]

    # Subset of the original DataFrame for later group-level ops
    # (index reset so positional and label indexing coincide).
    val_df = df.iloc[val_idx].copy().reset_index(drop=True)

    # Inner, group-aware split of the TRAINING fold used only for early
    # stopping. TabNet picks the stopping epoch / best weights from its
    # eval_set, so passing the outer validation fold there would tune the
    # model on the very rows it is then scored on and inflate every metric.
    inner_split = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    fit_idx, stop_idx = next(inner_split.split(X_tr, y_tr, groups[train_idx]))

    # Impute: medians are learned from the training fold only
    imp = SimpleImputer(strategy='median')
    X_tr_imp  = imp.fit_transform(X_tr)
    X_val_imp = imp.transform(X_val)

    # Train TabNet on the fit portion; monitor the inner early-stopping portion.
    # The outer validation fold is never seen during fitting.
    clf = TabNetClassifier(verbose=0, device_name='auto')
    clf.fit(
        X_tr_imp[fit_idx], y_tr[fit_idx],
        eval_set=[(X_tr_imp[stop_idx], y_tr[stop_idx])],
        eval_name=['val'],
        eval_metric=['auc', 'accuracy'],
        max_epochs=100, patience=10,
        batch_size=1024, virtual_batch_size=128,
        compute_importance=True
    )

    # Row-level predictions on the untouched outer validation fold
    y_proba = clf.predict_proba(X_val_imp)[:, 1]
    y_pred  = (y_proba >= 0.5).astype(int)

    # Attach to val_df
    val_df['proba']       = y_proba
    val_df['pred_label']  = y_pred
    # baseline per-row pred: choose rank_TT == 1
    val_df['pred_base']   = (val_df['rank_TT'] == 1).astype(int)

    # **Per-row metrics**
    acc      = accuracy_score(y_val, y_pred)
    bal_acc  = balanced_accuracy_score(y_val, y_pred)
    prec     = precision_score(y_val, y_pred, zero_division=0)
    rec      = recall_score(y_val, y_pred, zero_division=0)
    f1       = f1_score(y_val, y_pred, zero_division=0)
    roc_auc  = roc_auc_score(y_val, y_proba)
    brier    = brier_score_loss(y_val, y_proba)

    base_acc = accuracy_score(y_val, val_df['pred_base'])
    base_bal = balanced_accuracy_score(y_val, val_df['pred_base'])

    # **Group-level accuracy**
    #  - model: for each Obs_ID, the row with the highest predicted
    #    probability counts as a hit when its choice == 1.
    #    (Vectorized idxmax instead of a per-group Python lambda.)
    top_rows   = val_df.groupby('Obs_ID')['proba'].idxmax()
    group_hits = val_df.loc[top_rows, 'choice'].eq(1)
    group_acc  = group_hits.mean()

    #  - baseline: for each Obs_ID, the first row with rank_TT == 1 counts as
    #    a hit when its choice == 1.
    #    NOTE: per the original author's remark this fastest-time baseline is
    #    deprecated; the report compares against "fewest transfers, then
    #    fastest time" instead.
    #    A choice set without any rank_TT == 1 row is scored as a miss rather
    #    than raising IndexError.
    choice_sets = val_df['Obs_ID'].dropna().unique()
    first_base  = (
        val_df.loc[val_df['pred_base'] == 1]
        .drop_duplicates(subset='Obs_ID')   # keep the first baseline pick per set
    )
    base_hits = (
        first_base.set_index('Obs_ID')['choice'].eq(1)
        .reindex(choice_sets, fill_value=False)
    )
    baseline_group_acc = base_hits.mean()

    # Store metrics
    metrics['accuracy'].append(acc)
    metrics['balanced_acc'].append(bal_acc)
    metrics['precision'].append(prec)
    metrics['recall'].append(rec)
    metrics['f1'].append(f1)
    metrics['roc_auc'].append(roc_auc)
    metrics['brier'].append(brier)
    metrics['baseline_acc'].append(base_acc)
    metrics['baseline_bal_acc'].append(base_bal)
    metrics['group_acc'].append(group_acc)
    metrics['baseline_group_acc'].append(baseline_group_acc)

    # Print results
    print(f"Model (per-row)    → acc: {acc:.3f}, bal_acc: {bal_acc:.3f}, "
          f"AUC: {roc_auc:.3f}, Brier: {brier:.3f}")
    print(f"Model (group)      → accuracy: {group_acc:.3f}")
    print(f"Baseline (per-row) → acc: {base_acc:.3f}, bal_acc: {base_bal:.3f}")
    print(f"Baseline (group)   → accuracy: {baseline_group_acc:.3f}")

    print("\nClassification Report (per-row):")
    print(classification_report(y_val, y_pred, zero_division=0))

    # Calibration curve: calibration_curve returns (fraction of positives,
    # mean predicted probability) per bin; plot with the prediction on x.
    frac_pos, mean_pred = calibration_curve(y_val, y_proba, n_bins=10)
    plt.plot(mean_pred, frac_pos, 'o-', label=f'Fold {fold}')

# CV summary & plots
# Finish the reliability diagram on the axes that are still current
# (the ones the per-fold calibration curves were drawn on).
cal_ax = plt.gca()
cal_ax.set_xlabel('Mean predicted probability')
cal_ax.set_ylabel('Fraction of positives')
cal_ax.set_title('Reliability diagram (GroupKFold CV)')
cal_ax.legend()
plt.tight_layout()
plt.show()

# Summary table: one row per metric, columns = mean and std across folds
cv_df = pd.DataFrame(metrics)
print("\nCross-validation summary (mean ± std):")
print(cv_df.agg(['mean', 'std']).transpose())

# Bar chart: group accuracy per fold, model and baseline side by side
folds = np.arange(1, n_splits + 1)
bar_fig, bar_ax = plt.subplots(figsize=(6, 4))
bar_ax.bar(folds - 0.15, metrics['group_acc'], width=0.3, label='Model')
bar_ax.bar(folds + 0.15, metrics['baseline_group_acc'], width=0.3, label='Baseline')
bar_ax.set_xlabel('Fold')
bar_ax.set_ylabel('Group-level Accuracy')
bar_ax.set_title('Choice-set Accuracy per Fold')
bar_ax.set_xticks(folds)
bar_ax.legend()
bar_fig.tight_layout()
plt.show()
