In [None]:
import os
# Radeon 890M (RDNA 3.5 / gfx1150) cannot be used with Pytorch
# os.environ["HSA_OVERRIDE_GFX_VERSION"] = "11.0.2"

import uproot
import awkward as ak
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    classification_report,
    precision_recall_fscore_support,
    balanced_accuracy_score,
    matthews_corrcoef,
    average_precision_score,
    f1_score,
    brier_score_loss,
)

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# Notebook-wide switches.
VERBOSE = False
FULL_TRAINING = False

# Key looked up in every input file (used as `file[main_branch]` when loading).
main_branch = "Events"

# Pixel-track branches read from each file: the track quantities first,
# followed by the matched/duplicate flags and the `tp*` columns.
tk_branches = [
    f"muon_pixel_tracks_{suffix}"
    for suffix in (
        "p",
        "pt",
        "ptErr",
        "eta",
        "etaErr",
        "phi",
        "phiErr",
        "chi2",
        "normalizedChi2",
        "nPixelHits",
        "nTrkLays",
        "nFoundHits",
        "nLostHits",
        "dsz",
        "dszErr",
        "dxy",
        "dxyErr",
        "dz",
        "dzErr",
        "qoverp",
        "qoverpErr",
        "lambdaErr",
        "matched",
        "duplicate",
        "tpPdgId",
        "tpPt",
        "tpEta",
        "tpPhi",
    )
]

# L1TkMu branches read alongside the track branches.
l1tkMuon_branches = [f"L1TkMu_{suffix}" for suffix in ("pt", "eta", "phi")]

# Input files
data_dir = "data/train"

# os.listdir() returns entries in arbitrary order, so the candidates are sorted
# first; otherwise "the second gun file" could differ between machines or runs.
gun_files = sorted(
    name
    for name in os.listdir(data_dir)
    if os.path.isfile(os.path.join(data_dir, name)) and "gun" in name
)

# Only the second matching file (index 1) is selected; the slice yields an
# empty list when fewer than two files match.
# NOTE(review): FULL_TRAINING is defined above but not consulted here - confirm
# whether it was meant to switch to using every matching file.
files = [os.path.join(data_dir, name) for name in gun_files[1:2]]

print(f"Selected {len(files)} input files:")
print(files)

In [None]:
from IPython.core.magic import register_cell_magic


@register_cell_magic
def skip_if(line, cell):
    """Cell magic that runs the cell body only when ``line`` evaluates truthy.

    Note the polarity: despite the name, the cell is *skipped* when the
    expression after ``%%skip_if`` is falsy (e.g. ``%%skip_if VERBOSE`` skips
    the cell while ``VERBOSE`` is False).

    Args:
        line: Text following ``%%skip_if``; evaluated with ``eval``, so it must
            only ever contain trusted, author-written expressions.
        cell: Source of the cell body; executed with ``exec`` against this
            module's globals when the condition holds.
    """
    if eval(line):
        exec(cell, globals())
        return
    print("Skipping cell...")

In [None]:
# Reproducibility
# NOTE: the seeding code below sits inside a bare string literal, so it is NOT
# executed. As written, no random seed is set in this cell.
"""
import random
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
"""
# Manual opt-in for the GPU: the device below is "cuda" only when this flag is
# True AND torch reports an available device.
use_gpu = False  # RDNA 3.5 (890M) is currently unstable
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Torch version HIP:", torch.version.hip)
if torch.cuda.is_available():
    print("CUDA device name:", torch.cuda.get_device_name(0))
# `device` is read by the later cells that build tensors and the model.
device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
print("Using device:", device)

In [None]:
# GPU Diagnostics for AMD Radeon 890M (RDNA 3.5)
def _run_gpu_diagnostics(dev):
    """Print device properties and run small allocation/matmul/stream tests.

    The test tensors are kept local to this function. Assigning them at cell
    level would create notebook globals named ``x`` and ``y``, and ``y`` is
    later used for the label array, so re-running the diagnostics after the
    data was loaded would overwrite the labels.

    Args:
        dev: torch device on which the test tensors are allocated.

    Raises:
        Exception: whatever the torch calls raise; handled by the caller.
    """
    # 1. Basic Properties
    props = torch.cuda.get_device_properties(0)
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"VRAM (Total): {props.total_memory / 1024**3:.2f} GB")
    print(f"Compute Capability: {props.major}.{props.minor}")

    # 2. Allocation Test
    print("\n[Test 1] Memory Allocation...")
    ones = torch.ones(1024, 1024, device=dev)
    print("Success")

    # 3. Computation Test
    print("\n[Test 2] Matrix Multiplication...")
    product = torch.matmul(ones, ones)
    print("Success")

    # 4. Stream Test
    print("\n[Test 3] CUDA Streams...")
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        torch.matmul(product, product)
    torch.cuda.synchronize()
    print("Success")


if device.type == "cuda":
    print("=" * 60)
    print("AMD GPU DIAGNOSTICS")
    print("=" * 60)

    try:
        _run_gpu_diagnostics(device)
        print("\nGPU seems functional")

    except Exception as e:
        # Best-effort diagnostics: report the failure and how to retry rather
        # than aborting the notebook.
        print(f"\nGPU FAILED: {e}")
        print(
            "Current Override: HSA_OVERRIDE_GFX_VERSION =",
            os.environ.get("HSA_OVERRIDE_GFX_VERSION", "Not Set"),
        )
        print("Try changing the override in the first cell:")
        print("os.environ['HSA_OVERRIDE_GFX_VERSION'] = '11.0.2'")
else:
    print(
        "Running on CPU. Integrated graphics not supported on PyTorch version: 2.10.0.dev20251210+rocm7.1"
    )

## Build Feature Matrix and Labels

In [None]:
# Raw per-track input features, in the column order used for the feature matrix.
features = [
    f"muon_pixel_tracks_{suffix}"
    for suffix in (
        "p",
        "pt",
        "ptErr",
        "eta",
        "etaErr",
        "phi",
        "phiErr",
        "chi2",
        "normalizedChi2",
        "nPixelHits",
        "nTrkLays",
        "nFoundHits",
        "nLostHits",
        "dsz",
        "dszErr",
        "dxy",
        "dxyErr",
        "dz",
        "dzErr",
        "qoverp",
        "qoverpErr",
        "lambdaErr",
    )
]

# Branch used as the binary target.
LABEL_FIELD = "muon_pixel_tracks_matched"
# When True, build_dataset appends the L1TkMu matching features.
useL1TkMuFeatures = True


def wrap_phi(phi):
    """Wrap an angle in radians into the range from -pi to pi.

    The input is shifted by pi, reduced modulo 2*pi and shifted back, so an
    input of exactly +pi maps to -pi. Works element-wise on arrays.
    """
    full_turn = 2 * np.pi
    shifted = phi + np.pi
    return (shifted % full_turn) - np.pi


def build_dataset(arr, file_labels_in):
    """Flatten per-event track branches into a per-track feature matrix.

    Tracks are kept when the first entry of ``features`` is ``>= 0``. Rows that
    contain any non-finite feature value are dropped at the end (for instance
    after a division by zero in one of the derived features).

    Parameters
    ----------
    arr : awkward array of events
        Must provide every branch listed in ``features``, ``LABEL_FIELD`` and,
        when ``useL1TkMuFeatures`` is set, ``L1TkMu_pt``/``L1TkMu_eta``/
        ``L1TkMu_phi``. The array is only read, never modified.
    file_labels_in : array-like
        One label per event; it is repeated so that each track carries the
        label of its event.

    Returns
    -------
    X : np.ndarray (float32)
        One row per kept track, one column per entry of ``feature_names``.
    y : np.ndarray (int8)
        Values of ``LABEL_FIELD`` for the kept tracks.
    file_labels_masked : np.ndarray
        Per-track file label aligned with ``X``.
    feature_names : list of str
        Column names: the raw features followed by the derived ones.
    """
    print("Building dataset...")
    feature_names = list(features)
    mask = arr[feature_names[0]] >= 0

    # Expand file labels from event level to track level
    n_tracks_per_event = ak.num(arr[feature_names[0]])
    file_labels_jagged = ak.unflatten(
        np.repeat(file_labels_in, n_tracks_per_event), n_tracks_per_event
    )
    file_labels_masked = ak.to_numpy(ak.flatten(file_labels_jagged[mask]))

    # Untransformed quantities used by the derived features below
    trk_pt_original = arr["muon_pixel_tracks_pt"]
    trk_ptErr_original = arr["muon_pixel_tracks_ptErr"]
    trk_eta_original = arr["muon_pixel_tracks_eta"]
    trk_phi_original = arr["muon_pixel_tracks_phi"]
    trk_chi2 = arr["muon_pixel_tracks_chi2"]
    trk_nFoundHits = arr["muon_pixel_tracks_nFoundHits"]
    trk_nLostHits = arr["muon_pixel_tracks_nLostHits"]
    trk_dxy = arr["muon_pixel_tracks_dxy"]
    trk_dz = arr["muon_pixel_tracks_dz"]
    trk_dxyErr = arr["muon_pixel_tracks_dxyErr"]
    trk_dzErr = arr["muon_pixel_tracks_dzErr"]
    trk_qoverp_original = arr["muon_pixel_tracks_qoverp"]
    trk_qoverpErr_original = arr["muon_pixel_tracks_qoverpErr"]

    cols = []
    for f in features:
        values = arr[f]
        if f in ["muon_pixel_tracks_p", "muon_pixel_tracks_pt"] or "Err" in f:
            if VERBOSE:
                # The min/max reductions are only needed for this message.
                minimum = ak.min(ak.flatten(values[mask]))
                maximum = ak.max(ak.flatten(values[mask]))
                print(f"Feature {f} min {minimum:.2f} max {maximum:.2f} -> log10")
            # Transform a local copy; the caller's array stays untouched.
            values = np.log10(values + 1e-6)
        flat = ak.to_numpy(ak.flatten(values[mask])).astype(np.float32)
        cols.append(flat)

    # DERIVED FEATURES - Help model generalize across physics processes
    print("\nAdding derived features...")

    # 1. Relative momentum uncertainty
    sigmaPtOverPt = np.log10(trk_ptErr_original / trk_pt_original + 1e-6)
    cols.append(ak.to_numpy(ak.flatten(sigmaPtOverPt[mask])).astype(np.float32))
    feature_names.append("muon_pixel_tracks_sigmaPtOverPt")

    # 2. Hit efficiency: found / (found + lost)
    hit_efficiency = trk_nFoundHits / (trk_nFoundHits + trk_nLostHits)
    cols.append(ak.to_numpy(ak.flatten(hit_efficiency[mask])).astype(np.float32))
    feature_names.append("muon_pixel_tracks_hitEfficiency")

    # 3. Chi2 per hit (quality per measurement)
    chi2_per_hit = trk_chi2 / trk_nFoundHits
    cols.append(
        ak.to_numpy(ak.flatten(np.log10(chi2_per_hit + 1e-6)[mask])).astype(np.float32)
    )
    feature_names.append("muon_pixel_tracks_chi2PerHit")

    # 4. Squared impact distance dxy^2 + dz^2 (not normalized by uncertainties)
    impact_param_3d = trk_dxy**2 + trk_dz**2
    cols.append(
        ak.to_numpy(ak.flatten(np.log10(impact_param_3d + 1e-6)[mask])).astype(
            np.float32
        )
    )
    feature_names.append("muon_pixel_tracks_impact3D")

    # 5. Impact parameter significance (normalized by uncertainty)
    dxy_significance = trk_dxy / trk_dxyErr
    dz_significance = trk_dz / trk_dzErr
    impact_significance_2d = np.sqrt(dxy_significance**2 + dz_significance**2)
    cols.append(
        ak.to_numpy(ak.flatten(np.log10(impact_significance_2d + 1e-6)[mask])).astype(
            np.float32
        )
    )
    feature_names.append("muon_pixel_tracks_impactSignificance")

    # 6. Relative uncertainties product (captures overall measurement quality)
    # The magnitude of qoverp is used: a negative qoverp (presumably it carries
    # the charge sign - TODO confirm) would make the product negative, log10
    # would return NaN, and the finite filter below would silently drop the track.
    rel_uncertainty_product = (trk_ptErr_original / trk_pt_original) * (
        trk_qoverpErr_original / np.abs(trk_qoverp_original)
    )
    cols.append(
        ak.to_numpy(ak.flatten(np.log10(rel_uncertainty_product + 1e-6)[mask])).astype(
            np.float32
        )
    )
    feature_names.append("muon_pixel_tracks_relUncertaintyProduct")

    if useL1TkMuFeatures:
        print("\nComputing L1TkMuon matching features...")

        l1_pt = arr["L1TkMu_pt"]
        l1_eta = arr["L1TkMu_eta"]
        l1_phi = arr["L1TkMu_phi"]

        trk_zip = ak.zip(
            {"pt": trk_pt_original, "eta": trk_eta_original, "phi": trk_phi_original}
        )
        l1_zip = ak.zip({"pt": l1_pt, "eta": l1_eta, "phi": l1_phi})

        # For every track, one inner list of (track, L1TkMu) pairs of the event
        pairs = ak.cartesian({"t": trk_zip, "l": l1_zip}, axis=1, nested=True)

        deta = pairs.t.eta - pairs.l.eta
        dphi = wrap_phi(pairs.t.phi - pairs.l.phi)
        dR2 = deta**2 + dphi**2

        # Pick, per track, the L1TkMu candidate with the smallest dR^2
        min_idx = ak.argmin(dR2, axis=2)
        dR2_min = ak.firsts(
            dR2[ak.local_index(dR2, axis=2) == min_idx[..., None]], axis=2
        )
        l1_pt_matched = ak.firsts(
            pairs.l.pt[ak.local_index(pairs.l.pt, axis=2) == min_idx[..., None]], axis=2
        )

        dPt_norm = np.abs(trk_pt_original - l1_pt_matched) / l1_pt_matched

        # L1 pT ratio
        pt_ratio = trk_pt_original / l1_pt_matched

        # Combined matching score (dR and pT)
        matching_score = dR2_min * (1 + dPt_norm)

        # Tracks without any L1TkMu candidate get the sentinel value 999
        dR2_min = ak.fill_none(dR2_min, 999.0)
        dPt_norm = ak.fill_none(dPt_norm, 999.0)
        pt_ratio = ak.fill_none(pt_ratio, 999.0)
        matching_score = ak.fill_none(matching_score, 999.0)

        dR2_min_flat = ak.to_numpy(ak.flatten(dR2_min[mask]))
        dPt_norm_flat = ak.to_numpy(ak.flatten(dPt_norm[mask]))
        pt_ratio_flat = ak.to_numpy(ak.flatten(pt_ratio[mask]))
        matching_score_flat = ak.to_numpy(ak.flatten(matching_score[mask]))

        if VERBOSE:
            valid_matches = dPt_norm_flat < 999.0
            print(
                f"  ΔR2_min: min={dR2_min_flat.min():.4f}, max={dR2_min_flat.max():.4f}, mean={dR2_min_flat.mean():.4f}"
            )
            print(
                f"  ΔpT_norm: min={dPt_norm_flat.min():.4f}, max={dPt_norm_flat.max():.4f}, mean={dPt_norm_flat.mean():.4f}"
            )

            print(
                f"  Valid L1TkMu matches: {valid_matches.sum()}/{len(valid_matches)} ({valid_matches.mean() * 100:.2f}%)"
            )
            if valid_matches.any():
                print(
                    f"  ΔpT_norm (valid only): mean={dPt_norm_flat[valid_matches].mean():.4f}, "
                    f"median={np.median(dPt_norm_flat[valid_matches]):.4f}, "
                    f"95th percentile={np.percentile(dPt_norm_flat[valid_matches], 95):.4f}"
                )

        cols.append(np.log10(dR2_min_flat + 1e-6).astype(np.float32))
        cols.append(np.log10(dPt_norm_flat + 1e-6).astype(np.float32))
        cols.append(np.log10(pt_ratio_flat + 1e-6).astype(np.float32))
        cols.append(np.log10(matching_score_flat + 1e-6).astype(np.float32))

        feature_names.append("L1TkMu_dR2min")
        feature_names.append("L1TkMu_dPtNorm")
        feature_names.append("L1TkMu_ptRatio")
        feature_names.append("L1TkMu_matchingScore")

    X = np.vstack(cols).T.astype(np.float32)
    y = ak.to_numpy(ak.flatten(arr[LABEL_FIELD][mask])).astype(np.int8)

    # Drop rows with NaN/inf in any column, keeping y and the labels aligned
    finite = np.isfinite(X).all(axis=1)
    if not finite.all():
        print(f"Removing {(~finite).sum()} non-finite samples")
        X = X[finite]
        y = y[finite]
        file_labels_masked = file_labels_masked[finite]

    return X, y, file_labels_masked, feature_names

In [None]:
# Incremental loading and processing to save memory
import gc

# Per-file chunks are accumulated and concatenated once at the end.
X_list = []
y_list = []
file_labels_list = []
feature_names = []
total_events = 0

print(f"Processing {len(files)} files incrementally...")

# Without this check an empty selection fails later inside np.concatenate
# with a message that does not point at the real cause.
if not files:
    raise RuntimeError(
        "No input files were selected; check the file selection in the first cell."
    )

for i, f in enumerate(files):
    print(f"[{i + 1}/{len(files)}] Loading {f}...")
    with uproot.open(f) as file:
        arrays_f = file[main_branch].arrays(tk_branches + l1tkMuon_branches)
        n_events = len(arrays_f)
        total_events += n_events

        # One label per event: the index of the file within `files`
        file_labels_temp = np.full(n_events, i)

        X_chunk, y_chunk, labels_chunk, feats = build_dataset(
            arrays_f, file_labels_temp
        )

        X_list.append(X_chunk)
        y_list.append(y_chunk)
        file_labels_list.append(labels_chunk)

        # Column names are taken from the first file only
        if i == 0:
            feature_names = feats

        # Free memory immediately
        del arrays_f
        del X_chunk, y_chunk, labels_chunk
        gc.collect()
        print(f"Loaded {n_events} events from {f}")

X = np.concatenate(X_list)
y = np.concatenate(y_list)
file_labels_flat = np.concatenate(file_labels_list)

print(f"Loaded {total_events} events from {len(files)} files")
print("Feature matrix shape:", X.shape)
print("Labels shape:", y.shape, "Positives:", y.sum(), f"({100 * y.mean():.2f}%)")
print("Feature order:", feature_names)

# Show distribution across files
print("\nSample distribution by file:")
for file_idx, fname in enumerate(files):
    n_samples = (file_labels_flat == file_idx).sum()
    n_pos = ((file_labels_flat == file_idx) & (y == 1)).sum()
    print(
        f"  File {file_idx} ({fname.split('/')[-1]}): {n_samples} samples ({n_pos} positive, {100 * n_pos / max(n_samples, 1):.2f}%)"
    )

## Train/Validation/Test Split and Scaling

In [None]:
train_frac = 0.7  # fraction of the full dataset kept for train+val (the remainder becomes the test set)
val_frac = 0.15  # fraction (of full dataset) reserved for validation (inside train+val)

# Create composite stratification label: combine class label + file source
# This ensures both class balance AND file representation in each split
stratify_label = (
    y * len(files) + file_labels_flat
)  # Unique label per (class, file) combination

# First split train_val vs test
# NOTE: no random_state is passed, so the split differs from run to run.
# Stratification is only requested when both classes are present.
X_train_val, X_test, y_train_val, y_test, file_train_val, file_test = train_test_split(
    X,
    y,
    file_labels_flat,
    train_size=train_frac,
    stratify=stratify_label if (y.sum() > 0 and y.sum() < len(y)) else None,
)

# Derive actual validation fraction relative to train_val portion
if val_frac > 0:
    rel_val = val_frac / train_frac  # portion of train_val to carve out as validation
    stratify_train_val = y_train_val * len(files) + file_train_val
    X_train, X_val, y_train, y_val, file_train, file_val = train_test_split(
        X_train_val,
        y_train_val,
        file_train_val,
        test_size=rel_val,
        stratify=stratify_train_val
        if (y_train_val.sum() > 0 and y_train_val.sum() < len(y_train_val))
        else None,
    )
else:
    # NOTE(review): with val_frac == 0 the validation arrays are the test arrays
    # themselves, so anything tuned on "validation" is tuned on the test set.
    X_train, y_train, file_train = X_train_val, y_train_val, file_train_val
    X_val, y_val, file_val = X_test, y_test, file_test

print(
    "Train size:", X_train.shape, "Val size:", X_val.shape, "Test size:", X_test.shape
)

In [None]:
%%skip_if VERBOSE
# Examine features pre-scaling
# The skip_if magic executes this cell body only when VERBOSE is truthy.
# Each line shows the last "_"-separated part of the feature name with the
# mean and standard deviation of that column of X_train.
print("Pre-scaling feature stats:")
for i in range(X_train.shape[1]):
    print(
        f"{feature_names[i].split('_')[-1]}: mean={X_train[:, i].mean():.4f}, std={X_train[:, i].std():.4f}"
    )

In [None]:
# Scale features to zero mean and unit variance
# The scaler is fitted on the training split only and then applied to the
# validation and test splits.
# NOTE: X_train/X_val/X_test are rebound to their scaled versions, so running
# this cell a second time re-fits the scaler on already-scaled data.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

if VERBOSE:
    print("\nPost-scaling feature stats:")
    for i in range(X_train.shape[1]):
        print(
            f"{feature_names[i].split('_')[-1]}: mean={X_train[:, i].mean():.4f}, std={X_train[:, i].std():.4f}"
        )

In [None]:
# Verify file distribution in each split
# For every split, report per input file: the number of samples, their share
# of the split, and how many of them are positive (label == 1).
print("\nFile Distribution Verification")
for split_name, file_split, y_split in [
    ("Train", file_train, y_train),
    ("Val", file_val, y_val),
    ("Test", file_test, y_test),
]:
    print(f"\n{split_name} set:")
    for file_idx, fname in enumerate(files):
        n_samples = (file_split == file_idx).sum()
        n_pos = ((file_split == file_idx) & (y_split == 1)).sum()
        pct_of_split = 100 * n_samples / len(file_split)
        print(
            f"  {fname.split('/')[-1]}: {n_samples} samples ({pct_of_split:.1f}% of {split_name.lower()}, {n_pos} pos)"
        )

In [None]:
# SIMPLIFIED CLASS BALANCING - Choose one strategy
# use_undersampling: subsample negatives of the training split (see below).
# use_weighted_loss: compute a pos_weight from the class counts (see below).
# NOTE(review): use_focal is assigned again (to True) in the training
# configuration cell further down - reconcile the two settings.
use_undersampling = False  # Try turning OFF undersampling
use_weighted_loss = True  # Use class weights instead
use_focal = False  # Don't use focal loss with pos_weight


class NumpyDataset(Dataset):
    """In-memory dataset built from a NumPy feature matrix and target vector.

    Both inputs are converted to float32 tensors once, at construction time.
    The targets receive a trailing dimension of size 1, so every item is a
    ``(features, target)`` pair whose target has shape ``(1,)``.
    """

    def __init__(self, X, y):
        feature_tensor = torch.from_numpy(X.astype(np.float32))
        target_tensor = torch.from_numpy(y.astype(np.float32))
        self.X = feature_tensor
        self.y = target_tensor.unsqueeze(1)

    def __len__(self):
        # Number of samples = number of rows of the feature tensor.
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = (self.X[idx], self.y[idx])
        return sample


# Simple random undersampling without file complexity
# NOTE: the draws below use NumPy's global random state, which this notebook
# does not seed, so the selected negatives differ between runs.
if use_undersampling:
    pos_indices = np.where(y_train == 1)[0]
    neg_indices = np.where(y_train == 0)[0]

    # Keep all positives, undersample negatives to 3:1 ratio
    n_neg_keep = min(len(neg_indices), len(pos_indices) * 3)
    neg_indices_sampled = np.random.choice(neg_indices, size=n_neg_keep, replace=False)

    balanced_indices = np.concatenate([pos_indices, neg_indices_sampled])
    np.random.shuffle(balanced_indices)

    X_train_balanced = X_train[balanced_indices]
    y_train_balanced = y_train[balanced_indices]

    print(f"Undersampling: {len(y_train)} -> {len(y_train_balanced)} samples")
    print(f"  Pos: {y_train_balanced.sum()} ({100 * y_train_balanced.mean():.1f}%)")

    train_ds = NumpyDataset(X_train_balanced, y_train_balanced)
    pos_count = y_train_balanced.sum()
    neg_count = len(y_train_balanced) - pos_count
else:
    train_ds = NumpyDataset(X_train, y_train)
    pos_count = y_train.sum()
    neg_count = len(y_train) - pos_count

# Compute pos_weight for BCE loss
# pos_weight is the negative/positive count ratio of whichever training set
# was selected above; it stays None when weighting is disabled or when there
# are no positives (which would otherwise divide by zero).
if use_weighted_loss and pos_count > 0:
    pos_weight_value = neg_count / pos_count
    # Don't multiply - let the ratio speak for itself
    pos_weight = torch.tensor([pos_weight_value], dtype=torch.float32, device=device)
    print(f"Using pos_weight={pos_weight.item():.2f}")
else:
    pos_weight = None
    print("Using unweighted loss")

## Torch Dataset and DataLoader

In [None]:
batch_size = 2048

val_ds = NumpyDataset(X_val, y_val)
test_ds = NumpyDataset(X_test, y_test)

# Alternative that was tried here and is currently disabled: an
# inverse-class-frequency torch.utils.data.WeightedRandomSampler (with
# replacement) instead of shuffle=True when undersampling is off.

# Pinned host memory is only requested when batches will actually be copied to
# a GPU; on a CPU-only run it serves no purpose.
pin_batches = device.type == "cuda"

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    pin_memory=pin_batches,
)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

# Sanity check: pull one batch and report its shape
xb, yb = next(iter(train_loader))
print("First batch shapes:", xb.shape, yb.shape)

# Labels of the set that train_ds was built from in the class-balancing cell
y_train_used = y_train_balanced if use_undersampling else y_train
print(
    "Class balance train: pos=",
    y_train_used.sum(),
    "neg=",
    len(y_train_used) - y_train_used.sum(),
)

## Feature importance

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
import pandas as pd

print("Training Random Forest for feature importance analysis...")
# Use a subset for faster training
# NOTE: the subset is drawn from NumPy's global random state, which is not
# seeded here, while the forest itself uses random_state=42.
sample_size = min(1000000, len(X_train))
sample_idx = np.random.choice(len(X_train), size=sample_size, replace=False)

rf = RandomForestClassifier(
    n_estimators=100, max_depth=10, min_samples_leaf=50, n_jobs=-1, random_state=42
)
rf.fit(X_train[sample_idx], y_train[sample_idx])

# Get feature importances
# `indices` orders the feature columns from most to least important; it is
# reused by the plotting and pruning cells.
importances = rf.feature_importances_
indices = np.argsort(importances)[::-1]

print("\n" + "=" * 70)
print("FEATURE IMPORTANCE RANKING (Random Forest)")
print("=" * 70)

importance_df = pd.DataFrame(
    {
        "Feature": [feature_names[i] for i in indices],
        "Importance": importances[indices],
        "Cumulative": np.cumsum(importances[indices]),
    }
)

for idx, row in importance_df.iterrows():
    feat_name = row["Feature"].split("_")[-1]
    print(
        f"{idx + 1:2d}. {feat_name:20s}: {row['Importance']:.4f} (cumulative: {row['Cumulative']:.4f})"
    )

# Find features contributing to 98% of importance
threshold_98 = 0.98
important_features_98 = importance_df[importance_df["Cumulative"] <= threshold_98][
    "Feature"
].tolist()
# Add one more feature to cross the threshold
if len(important_features_98) < len(feature_names):
    important_features_98.append(
        importance_df.iloc[len(important_features_98)]["Feature"]
    )

print(f"\n{len(important_features_98)} features explain 98% of importance")
print("Top features:", [f.split("_")[-1] for f in important_features_98[:10]])

# Permutation importance (more reliable but slower)
# Evaluated on the first 10000 rows of the validation split with the forest
# trained above.
print(
    "\nComputing permutation importance on validation set (this may take a minute)..."
)
perm_importance = permutation_importance(
    rf, X_val[:10000], y_val[:10000], n_repeats=5, random_state=42, n_jobs=-1
)

# Feature columns ordered from largest to smallest mean importance
perm_indices = np.argsort(perm_importance.importances_mean)[::-1]
print("\nTop 10 by Permutation Importance:")
for i in range(min(10, len(perm_indices))):
    idx = perm_indices[i]
    feat_name = feature_names[idx].split("_")[-1]
    # The separator is the plus-minus sign (it was garbled into "¬±").
    print(
        f"  {feat_name:20s}: {perm_importance.importances_mean[idx]:.4f} ± {perm_importance.importances_std[idx]:.4f}"
    )

In [None]:
# Visualize feature importance
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Plot 1: Bar chart of top features
ax = axes[0]
# Clamp to the number of available features so the bar positions and the bar
# lengths always have the same size.
top_n = min(15, len(indices))
top_features = [feature_names[i].split("_")[-1] for i in indices[:top_n]]
top_importance = importances[indices[:top_n]]
# Reversed so that the most important feature ends up at the top of the chart
ax.barh(range(top_n), top_importance[::-1])
ax.set_yticks(range(top_n))
ax.set_yticklabels(top_features[::-1])
ax.set_xlabel("Importance")
ax.set_title(f"Top {top_n} Features by Importance")
ax.grid(axis="x", alpha=0.3)

# Plot 2: Cumulative importance
ax = axes[1]
ax.plot(
    range(1, len(importances) + 1),
    np.cumsum(importances[indices]),
    marker="o",
    markersize=4,
)
ax.axhline(y=0.98, color="r", linestyle="--", label="98% threshold")
ax.axvline(x=len(important_features_98), color="r", linestyle="--", alpha=0.5)
ax.set_xlabel("Number of Features")
ax.set_ylabel("Cumulative Importance")
ax.set_title("Cumulative Feature Importance")
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories; make sure the target exists.
os.makedirs("model_output", exist_ok=True)
plt.savefig("model_output/feature_importance.png", dpi=300)
plt.show()

print(
    f"Feature analysis complete. Consider using top {len(important_features_98)} features."
)

In [None]:
# Use top N features by importance
use_feature_pruning = True
# Keep the 98%-importance set plus up to four more features (capped at 20);
# if that set already has 20 or more features, keep exactly that many.
n_features_to_keep = (
    min(20, len(important_features_98) + 4)
    if len(important_features_98) < 20
    else len(important_features_98)
)
if use_feature_pruning:
    # Select top features
    selected_feature_indices = indices[:n_features_to_keep]
    selected_feature_names = [feature_names[i] for i in selected_feature_indices]

    print(f"Pruning features: {len(feature_names)} -> {len(selected_feature_names)}")
    print("Selected features:", [f.split("_")[-1] for f in selected_feature_names])

    # Create pruned datasets
    X_train_pruned = X_train[:, selected_feature_indices]
    X_val_pruned = X_val[:, selected_feature_indices]
    X_test_pruned = X_test[:, selected_feature_indices]

    # Update for model training
    X_train_model = X_train_pruned
    X_val_model = X_val_pruned
    X_test_model = X_test_pruned
    feature_names_model = selected_feature_names

    print("Pruned feature matrix shapes:")
    print(f"  Train: {X_train_model.shape}")
    print(f"  Val:   {X_val_model.shape}")
    print(f"  Test:  {X_test_model.shape}")
else:
    X_train_model = X_train
    X_val_model = X_val
    X_test_model = X_test
    feature_names_model = feature_names
    print("Using all features")

# Recreate datasets with pruned features
if use_undersampling:
    # Reuse the rows drawn in the class-balancing cell. Rebuilding from the
    # full training split here would silently discard the undersampling while
    # pos_weight was computed from the undersampled class counts.
    train_ds = NumpyDataset(
        X_train_model[balanced_indices], y_train[balanced_indices]
    )
else:
    train_ds = NumpyDataset(X_train_model, y_train)
val_ds = NumpyDataset(X_val_model, y_val)
test_ds = NumpyDataset(X_test_model, y_test)

# Recreate dataloaders
# Alternative that was tried here and is currently disabled: an
# inverse-class-frequency torch.utils.data.WeightedRandomSampler (with
# replacement) instead of shuffle=True when undersampling is off.

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    # Pinned host memory only helps when batches are copied to a GPU.
    pin_memory=(device.type == "cuda"),
)

val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

## Model architecture

In [None]:
class ImprovedMLP(nn.Module):
    """
    Fully connected binary classifier that outputs one logit per sample.

    Structure:
    - an input projection (Linear -> BatchNorm -> ReLU) to ``layers[0]`` units
    - one Linear -> BatchNorm -> ReLU -> Dropout block per consecutive pair of
      widths in ``layers``; the dropout rate grows with depth and is capped at 0.5
    - optional residual connections: identity when a block keeps the width,
      otherwise a learned Linear projection of the block input
    - a final Linear layer producing a single output

    ``hidden_layers`` stores the blocks in order; when a block changes the
    width and residuals are enabled, its projection is stored directly after it.
    """

    def __init__(self, in_features, layers, dropout=0.2, use_residual=True):
        super().__init__()
        self.use_residual = use_residual

        # Input projection
        first_width = layers[0]
        self.input_layer = nn.Sequential(
            nn.Linear(in_features, first_width),
            nn.BatchNorm1d(first_width),
            nn.ReLU(),
        )

        # Hidden blocks, each optionally followed by its residual projection
        self.hidden_layers = nn.ModuleList()
        for depth, (width_in, width_out) in enumerate(zip(layers[:-1], layers[1:])):
            # Dropout rate increases by 15% of the base rate per block, max 0.5
            drop_rate = min(dropout * (1 + depth * 0.15), 0.5)

            self.hidden_layers.append(
                nn.ModuleList(
                    [
                        nn.Linear(width_in, width_out),
                        nn.BatchNorm1d(width_out),
                        nn.ReLU(),
                        nn.Dropout(drop_rate),
                    ]
                )
            )

            # A width change needs a projection so the skip path can be added
            if use_residual and width_in != width_out:
                self.hidden_layers.append(
                    nn.ModuleList([nn.Linear(width_in, width_out)])
                )

        # Output layer
        self.output = nn.Linear(layers[-1], 1)

        # Widths are kept so forward() can tell blocks and projections apart
        self.layer_dims = layers

    def forward(self, x):
        hidden = self.input_layer(x)

        cursor = 0  # position of the next entry in self.hidden_layers
        for width_in, width_out in zip(self.layer_dims[:-1], self.layer_dims[1:]):
            skip = hidden
            for op in self.hidden_layers[cursor]:
                hidden = op(hidden)
            cursor += 1

            if not self.use_residual:
                continue

            if width_in == width_out:
                # Same width: plain identity shortcut
                hidden = hidden + skip
            else:
                # Width changed: the entry after the block is its projection
                if cursor < len(self.hidden_layers):
                    hidden = hidden + self.hidden_layers[cursor][0](skip)
                cursor += 1

        return self.output(hidden)


# Widths of the network: the input projection maps to 32 units and a single
# hidden block maps 32 -> 16 before the output layer.
hidden_layers = [32, 16]

# Base dropout rate passed to the model; lr and weight_decay are defined here
# for use when the optimizer is created.
dropout = 0.3
lr = 1e-3  # Standard Adam LR
weight_decay = 1e-5

# Create improved model
# The input width follows the (possibly pruned) training feature matrix.
model = ImprovedMLP(
    in_features=X_train_model.shape[1],
    layers=hidden_layers,
    dropout=dropout,
    use_residual=True,
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n{'=' * 70}")
print("Model architecture")
print(f"{'=' * 70}")
print(model)
print("\nParameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")
# Size estimate assumes 4 bytes per parameter
print(f"  Model size: ~{total_params * 4 / 1024**2:.2f} MB (float32)")
print(f"{'=' * 70}\n")

## Training

In [None]:
# Multi-metric configuration
primary_metric = "f1"  # options: 'auc','f1','ap','balanced_accuracy','mcc'
optimize_threshold = True
threshold_opt_metric = "f1"  # options: 'f1', 'recall', 'precision'
target_recall = 0.99  # Target recall to guarantee (only used if threshold_opt_metric='recall')
min_improvement = 1e-4  # required relative improvement for early stopping reset
# NOTE(review): this overrides the `use_focal = False` set in the
# class-balancing cell; confirm which value is intended.
use_focal = True

@torch.no_grad()
def evaluate(model, loader, device, decision_threshold=0.5):
    """Run ``model`` over ``loader`` and compute binary classification metrics.

    Args:
        model: Network returning one logit per sample with shape (batch, 1).
        loader: Iterable yielding ``(features, targets)`` batches with targets
            of shape (batch, 1).
        device: Device the feature batches are moved to before the forward pass.
        decision_threshold: Probability at or above which a sample is
            predicted positive.

    Returns:
        dict with keys ``probs`` (sigmoid outputs), ``y`` (integer labels),
        ``preds`` (thresholded predictions), ``cm`` (2x2 confusion matrix with
        label order [0, 1]), ``auc``, ``ap``, ``bal_acc``, ``f1`` and ``mcc``.
        A metric that cannot be computed is returned as NaN.
    """
    model.eval()
    logits_list = []
    y_list = []
    for xb, yb in loader:
        logits = model(xb.to(device))
        logits_list.append(logits.cpu())
        # Targets are only used on the CPU, so they are never sent to the device.
        y_list.append(yb.cpu())
    logits = torch.cat(logits_list).squeeze(1)
    y_true = torch.cat(y_list).squeeze(1)
    probs = torch.sigmoid(logits).numpy()
    y_np = y_true.numpy().astype(int)
    preds = (probs >= decision_threshold).astype(int)
    # Fix the label set so the matrix stays 2x2 even when only one class
    # appears among the labels and predictions.
    cm = confusion_matrix(y_np, preds, labels=[0, 1])
    # Ranking metrics raise ValueError when y_np contains a single class.
    try:
        auc = roc_auc_score(y_np, probs)
    except ValueError:
        auc = float("nan")
    try:
        ap = average_precision_score(y_np, probs)
    except ValueError:
        ap = float("nan")
    # The thresholded metrics are best-effort: any failure is reported as NaN.
    try:
        bal_acc = balanced_accuracy_score(y_np, preds)
    except Exception:
        bal_acc = float("nan")
    try:
        f1 = f1_score(y_np, preds, zero_division=0)
    except Exception:
        f1 = float("nan")
    try:
        mcc = matthews_corrcoef(y_np, preds)
    except Exception:
        mcc = float("nan")
    return dict(
        probs=probs,
        y=y_np,
        preds=preds,
        cm=cm,
        auc=auc,
        ap=ap,
        bal_acc=bal_acc,
        f1=f1,
        mcc=mcc,
    )

def find_threshold_for_target_recall(y_true, probs, target_recall=0.95):
    """
    Find the highest threshold whose recall reaches the target.

    Predictions are taken as ``probs >= threshold``. The returned recall is
    measured with that same rule, so samples tied with the threshold score
    are counted.

    Args:
        y_true: Binary labels (0/1), one per sample.
        probs: Predicted scores, aligned with ``y_true``.
        target_recall: Minimum fraction of positives that must be recovered.

    Returns:
        (threshold, actual_recall). When there are no positives the result is
        ``(0.5, 0.0)``. When the target cannot be met by any sample score, the
        threshold is set just below the minimum score and recall is 1.0.
    """
    y_true = np.asarray(y_true)
    probs = np.asarray(probs)

    n_positive = y_true.sum()
    if n_positive == 0:
        return 0.5, 0.0

    # Walk the samples from highest to lowest score, counting true positives.
    order = np.argsort(-probs, kind="stable")
    cumsum_tp = np.cumsum(y_true[order])

    # Number of positives that must be recovered to meet the target.
    target_tp = int(np.ceil(target_recall * n_positive))

    # First rank at which enough positives have been collected.
    threshold_idx = np.searchsorted(cumsum_tp, target_tp)

    if threshold_idx >= len(probs):
        # Need all predictions to be positive
        return probs.min() - 0.01, 1.0

    threshold = probs[order[threshold_idx]]
    # Recompute recall from the threshold itself: with tied scores, more
    # positives than cumsum_tp[threshold_idx] can pass ``probs >= threshold``.
    actual_recall = y_true[probs >= threshold].sum() / n_positive

    return threshold, actual_recall

# Optional Focal Loss wrapper
class FocalBCEWithLogits(nn.Module):
    """Binary cross-entropy on logits, re-weighted by a focal modulation term.

    Each element's BCE loss is multiplied by ``(1 - p_t) ** gamma``, where
    ``p_t`` is the sigmoid probability mass placed on the (smoothed) target.
    The modulation term is computed without gradient tracking, so it acts as a
    constant per-element weight during backpropagation.

    Args:
        pos_weight: Forwarded to ``nn.BCEWithLogitsLoss``.
        gamma: Exponent of the focal modulation term.
        label_smoothing: Smoothing amount ``eps``; targets become
            ``t * (1 - eps) + 0.5 * eps`` (0 -> eps/2, 1 -> 1 - eps/2).
            Values ``<= 0`` disable smoothing.
    """

    def __init__(self, pos_weight=None, gamma=5.0, label_smoothing=0.01):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="none")
        self.gamma = gamma
        self.label_smoothing = label_smoothing

    def forward(self, logits, targets):
        """Return the mean focal-weighted BCE loss over all elements."""
        eps = self.label_smoothing
        if eps > 0:
            # Pull hard labels towards 0.5 by eps/2 on each side
            targets = (1 - eps) * targets + 0.5 * eps

        elementwise = self.bce(logits, targets)

        # The focal weight is treated as a constant: no gradient flows through it
        with torch.no_grad():
            prob_pos = torch.sigmoid(logits)
            prob_target = targets * prob_pos + (1 - targets) * (1 - prob_pos)
            modulation = (1 - prob_target) ** self.gamma

        return (elementwise * modulation).mean()

In [None]:
# Setup training with improvements


def _recall_from_cm(cm):
    """Return TP / (TP + FN) from a confusion matrix, or 0 if it is not 2x2.

    ``confusion_matrix`` yields a 1x1 matrix when only one label value occurs,
    in which case row/column 1 does not exist and indexing it would raise.
    """
    if cm.shape[0] < 2:
        return 0
    return cm[1, 1] / max(cm[1, 1] + cm[1, 0], 1)


criterion = (
    nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    if use_weighted_loss
    else nn.BCEWithLogitsLoss()
)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Learning rate scheduler - reduce LR when metric plateaus
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",  # maximize metric
    factor=0.5,
    patience=15,
    # verbose=True,
    min_lr=1e-6,
)

# Training history (one entry per epoch)
history = {
    "train_loss": [],
    "val_f1": [],
    "val_auc": [],
    "val_recall": [],
    "val_threshold": [],
    "learning_rate": [],
}

best_f1 = 0.0
best_state = None
best_threshold = 0.5
no_improve = 0

# Epochs and early stopping
epochs = 1000
patience = 50

print(f"Training for {epochs} epochs with patience={patience}")
print(f"Threshold optimization metric: {threshold_opt_metric}")
if threshold_opt_metric == "recall":
    print(f"Target recall: {target_recall:.1%}")

for epoch in range(1, epochs + 1):
    # Training
    model.train()
    running_loss = 0.0
    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so the epoch loss is a per-sample mean
        running_loss += loss.item() * xb.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)

    # Evaluation at the default cut; also provides the raw validation scores
    val_metrics = evaluate(model, val_loader, device, decision_threshold=0.5)
    probs_val = val_metrics["probs"]
    y_val_np = val_metrics["y"]

    # Threshold optimization based on chosen metric
    thr_opt = 0.5
    if optimize_threshold and probs_val.size > 0:
        if threshold_opt_metric == "f1":
            prec_curve, rec_curve, thr_pr = precision_recall_curve(y_val_np, probs_val)
            f1_curve = (
                2 * prec_curve * rec_curve / np.clip(prec_curve + rec_curve, 1e-9, None)
            )
            max_f1_idx = np.nanargmax(f1_curve)
            # The curves have one more point than thr_pr; that final point has
            # no associated threshold, so keep the default in that case
            if max_f1_idx < len(thr_pr):
                thr_opt = thr_pr[max_f1_idx]
        elif threshold_opt_metric == "recall":
            thr_opt, actual_recall = find_threshold_for_target_recall(
                y_val_np, probs_val, target_recall=target_recall
            )
            if VERBOSE and epoch % 5 == 0:
                print(f"  Target recall: {target_recall:.3f}, Achieved: {actual_recall:.3f} @ thr={thr_opt:.3f}")
        elif threshold_opt_metric == "precision":
            prec_curve, rec_curve, thr_pr = precision_recall_curve(y_val_np, probs_val)
            max_prec_idx = np.nanargmax(prec_curve)
            if max_prec_idx < len(thr_pr):
                thr_opt = thr_pr[max_prec_idx]

    # Re-evaluate at optimal threshold; when the threshold is still 0.5 the
    # first evaluation already holds the answer, so skip the second pass
    val_metrics_thr = (
        val_metrics
        if thr_opt == 0.5
        else evaluate(model, val_loader, device, decision_threshold=thr_opt)
    )

    current_f1 = val_metrics_thr["f1"]
    current_auc = val_metrics_thr["auc"]
    current_recall = _recall_from_cm(val_metrics_thr["cm"])
    # Read before scheduler.step so the logged LR is the one used this epoch
    current_lr = optimizer.param_groups[0]["lr"]

    # Update scheduler
    scheduler.step(current_f1)

    # Log history
    history["train_loss"].append(epoch_loss)
    history["val_f1"].append(current_f1)
    history["val_auc"].append(current_auc)
    history["val_recall"].append(current_recall)
    history["val_threshold"].append(thr_opt)
    history["learning_rate"].append(current_lr)

    # Print progress
    if epoch % 5 == 0 or current_f1 > best_f1:
        print(
            f"Epoch {epoch:03d} | Loss={epoch_loss:.4f} | "
            f"Val: AUC={current_auc:.4f} F1={current_f1:.4f} Recall={current_recall:.4f} @ thr={thr_opt:.3f} | "
            f"LR={current_lr:.2e}"
        )

    # Early stopping: only improvements larger than 1e-4 reset the counter
    if current_f1 > best_f1 + 1e-4:
        best_f1 = current_f1
        # Snapshot on CPU so the checkpoint does not alias live parameters
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        best_threshold = thr_opt
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= patience:
            print(f"\nEarly stopping at epoch {epoch}")
            break

# Load best model
if best_state is not None:
    model.load_state_dict(best_state)

print(f"Training complete. Best Val F1={best_f1:.4f} at threshold={best_threshold:.4f}")
if threshold_opt_metric == "recall":
    final_val_metrics = evaluate(model, val_loader, device, decision_threshold=best_threshold)
    # Same guarded recall computation as inside the loop (a 1x1 matrix gives 0)
    final_recall = _recall_from_cm(final_val_metrics["cm"])
    print(f"Final validation recall at best threshold: {final_recall:.4f} (target: {target_recall:.4f})")

## Evaluate on Test Set

In [None]:
# Score the test split at the threshold selected on the validation split
print(f"Evaluating on test set with threshold={best_threshold:.4f}")
test_metrics = evaluate(model, test_loader, device, decision_threshold=best_threshold)

# Unpack the pieces that the following cells read
cm, probs, y_true, preds, auc_final = (
    test_metrics[key] for key in ("cm", "probs", "y", "preds", "auc")
)

print("Confusion matrix:\n", cm)

report = classification_report(y_true, preds, digits=3, zero_division=0)
print(report)

# Positive-class precision / recall / F1 (the fourth value, support, is unused)
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true,
    preds,
    average="binary",
    zero_division=0,
)
print(f"Precision={precision:.3f} Recall={recall:.3f} F1={f1:.3f} AUC={auc_final:.3f}")

In [None]:
# Extended test-set metrics. `preds`, `probs` and `y_true` come from the test
# evaluation cell, where predictions were made at `best_threshold`.
acc = (preds == y_true).mean()
precision = (
    ((preds & (y_true == 1)).sum() / max((preds == 1).sum(), 1))
    if (preds == 1).any()
    else 0.0
)
recall = (preds & (y_true == 1)).sum() / max((y_true == 1).sum(), 1)
specificity = ((preds == 0) & (y_true == 0)).sum() / max((y_true == 0).sum(), 1)
balanced_accuracy = 0.5 * (recall + specificity)
# Each optional metric falls back to NaN instead of aborting the cell
try:
    mcc = matthews_corrcoef(y_true, preds)
except Exception:
    mcc = float("nan")
try:
    ap = average_precision_score(y_true, probs)
except Exception:
    ap = float("nan")
try:
    brier = brier_score_loss(y_true, probs)
except Exception:
    brier = float("nan")

# Precision-Recall threshold analysis
prec_curve, rec_curve, thr_pr = precision_recall_curve(y_true, probs)
f1_curve = 2 * prec_curve * rec_curve / np.clip(prec_curve + rec_curve, 1e-9, None)
max_f1_idx = np.nanargmax(f1_curve)
# Curve point i belongs to thr_pr[i]; only the final curve point has no
# threshold. This is the same indexing the training cell uses.
opt_pr_threshold = thr_pr[max_f1_idx] if max_f1_idx < len(thr_pr) else 0.5
# NOTE: this rebinds `best_f1` (the training cell's best validation F1) to the
# best F1 along the test PR curve
best_f1 = f1_curve[max_f1_idx]

# Youden J optimal ROC threshold
fpr_curve, tpr_curve, thr_roc = roc_curve(y_true, probs)
youden = tpr_curve - fpr_curve
j_idx = np.argmax(youden)
youden_thr = thr_roc[j_idx]

print("--- Extended Metrics ---")
print(f"Accuracy            : {acc:.4f}")
print(f"Precision           : {precision:.4f}")
print(f"Recall (TPR)        : {recall:.4f}")
print(f"Specificity (TNR)   : {specificity:.4f}")
print(f"Balanced Accuracy   : {balanced_accuracy:.4f}")
print(f"MCC                 : {mcc:.4f}")
print(f"Average Precision   : {ap:.4f}")
print(f"Brier Score         : {brier:.4f}")
print(f"Best F1             : {best_f1:.4f} at PR threshold ~ {opt_pr_threshold:.4f}")
print(f"Youden J threshold  : {youden_thr:.4f}")

# Plot Precision-Recall curve (left) and ROC curve (right)
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(rec_curve, prec_curve, label=f"AP={ap:.3f}")
plt.scatter(
    rec_curve[max_f1_idx],
    prec_curve[max_f1_idx],
    marker="o",
    color="red",
    label="Best F1",
)
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(fpr_curve, tpr_curve, label=f"AUC={auc_final:.3f}")
plt.plot([0, 1], [0, 1], "k--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve - DNN")
plt.legend()
# Apply the layout before saving so the file matches what is shown
plt.tight_layout()
plt.savefig("model_output/precision_recall_curve.png", dpi=300)
plt.show()

In [None]:
# Evaluate performance breakdown by source file
print("=" * 70)
print("PER-FILE PERFORMANCE ON TEST SET")
print("=" * 70)

for file_idx, fname in enumerate(files):
    # Test rows that originate from this input file
    mask = file_test == file_idx
    if mask.sum() == 0:
        continue

    probs_file = probs[mask]
    y_file = y_true[mask]

    # Find optimal threshold for THIS file to maximize recall.
    # labels=[0, 1] forces a 2x2 matrix even when the file's subset holds a
    # single class, so recall is not silently reported as 0 for such files.
    best_recall = 0
    best_thr_recall = best_threshold
    for t in np.linspace(0.1, 0.9, 50):
        preds_t = (probs_file >= t).astype(int)
        cm_t = confusion_matrix(y_file, preds_t, labels=[0, 1])
        recall_t = cm_t[1, 1] / max(cm_t[1, 1] + cm_t[1, 0], 1)
        # Strict ">" keeps the lowest threshold among equal recalls
        if recall_t > best_recall:
            best_recall = recall_t
            best_thr_recall = t

    preds_file = (probs_file >= best_threshold).astype(int)
    preds_file_optimized = (probs_file >= best_thr_recall).astype(int)

    cm_file = confusion_matrix(y_file, preds_file, labels=[0, 1])
    cm_file_opt = confusion_matrix(y_file, preds_file_optimized, labels=[0, 1])

    # AUC is undefined when only one class is present
    try:
        auc_file = roc_auc_score(y_file, probs_file)
    except ValueError:
        auc_file = float("nan")

    try:
        f1_file = f1_score(y_file, preds_file, zero_division=0)
    except Exception:
        f1_file = float("nan")

    # Denominators are clamped to 1 so empty rows/columns give 0 instead of NaN
    recall_file = cm_file[1, 1] / max(cm_file[1, 1] + cm_file[1, 0], 1)
    recall_file_opt = cm_file_opt[1, 1] / max(cm_file_opt[1, 1] + cm_file_opt[1, 0], 1)
    precision_file = cm_file[1, 1] / max(cm_file[1, 1] + cm_file[0, 1], 1)

    print(f"\n{os.path.basename(fname)}:")
    print(
        f"  Samples: {mask.sum()} (pos={y_file.sum()}, neg={len(y_file) - y_file.sum()})"
    )
    print(f"  AUC: {auc_file:.4f}")
    print(f"  F1: {f1_file:.4f}")
    print(f"  Precision: {precision_file:.4f}")
    print(f"  Recall @ global thr={best_threshold:.3f}: {recall_file:.4f}")
    # "+" format flag prints an explicit sign for both gains and losses
    print(
        f"  Recall @ optimal thr={best_thr_recall:.3f}: {recall_file_opt:.4f} ({(recall_file_opt - recall_file) * 100:+.1f}%)"
    )
    print(f"  Confusion Matrix (global thr):\n{cm_file}")

In [None]:
# Test-set score distributions, split by true label (1 = matched, 0 = fake)
fig, ax = plt.subplots(figsize=(6, 5))
for label_value, label_text in ((1, "matched"), (0, "fake")):
    ax.hist(
        probs[y_true == label_value],
        bins=40,
        histtype="step",
        label=label_text,
    )
ax.set_xlabel("Predicted Probability (matched)")
ax.set_ylabel("Tracks")
ax.legend()
fig.tight_layout()
fig.savefig("model_output/predicted_probability.png", dpi=300)
plt.show()

In [None]:
# Heatmap of the test-set confusion matrix (rows: true label, columns: prediction)
fig, ax = plt.subplots(figsize=(4, 3))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="YlOrRd",
    cbar=False,
    ax=ax,
)
ax.set(xlabel="Predicted", ylabel="True", title="Confusion Matrix")
fig.tight_layout()
fig.savefig("model_output/confusion_matrix.png", dpi=300)
plt.show()

## Model improvements analysis

In [None]:
%%skip_if FULL_TRAINING
# Analyze current model's strengths and weaknesses
print("=" * 70)
print("MODEL GENERALIZATION ANALYSIS")
print("=" * 70)

# 1. Check for overfitting
train_metrics = evaluate(model, train_loader, device, decision_threshold=best_threshold)
val_metrics = evaluate(model, val_loader, device, decision_threshold=best_threshold)
test_metrics_final = evaluate(
    model, test_loader, device, decision_threshold=best_threshold
)

print("\nPerformance across splits:")
print(f"{'Split':<10} {'AUC':>8} {'F1':>8} {'Precision':>10} {'Recall':>8}")
print("-" * 50)
for split_name, metrics in [
    ("Train", train_metrics),
    ("Val", val_metrics),
    ("Test", test_metrics_final),
]:
    prec = (
        metrics["cm"][1, 1] / max(metrics["cm"][1, 1] + metrics["cm"][0, 1], 1)
        if metrics["cm"].shape[0] > 1
        else 0
    )
    rec = (
        metrics["cm"][1, 1] / max(metrics["cm"][1, 1] + metrics["cm"][1, 0], 1)
        if metrics["cm"].shape[0] > 1
        else 0
    )
    print(
        f"{split_name:<10} {metrics['auc']:>8.4f} {metrics['f1']:>8.4f} {prec:>10.4f} {rec:>8.4f}"
    )

# Calculate generalization gap
auc_gap = train_metrics["auc"] - test_metrics_final["auc"]
f1_gap = train_metrics["f1"] - test_metrics_final["f1"]

print(f"\nGeneralization gap:")
print(
    f"  AUC gap (train-test): {auc_gap:.4f} {'(good)' if auc_gap < 0.01 else '(check overfitting)'}"
)
print(
    f"  F1 gap (train-test): {f1_gap:.4f} {'(good)' if f1_gap < 0.01 else '(check overfitting)'}"
)

# 2. Analyze misclassifications
print("\n" + "=" * 70)
print("MISCLASSIFICATION ANALYSIS")
print("=" * 70)

# Get feature values for misclassified samples
test_probs = test_metrics_final["probs"]
test_y = test_metrics_final["y"]
test_preds = (test_probs >= best_threshold).astype(int)

false_positives = (test_preds == 1) & (test_y == 0)
false_negatives = (test_preds == 0) & (test_y == 1)

print(f"\nMisclassification breakdown:")
print(
    f"  False Positives: {false_positives.sum()} ({100 * false_positives.mean():.2f}%)"
)
print(
    f"  False Negatives: {false_negatives.sum()} ({100 * false_negatives.mean():.2f}%)"
)

# Analyze probability distribution for misclassified samples
if false_positives.any():
    fp_probs = test_probs[false_positives]
    print(f"\nFalse Positive probabilities:")
    print(f"  Mean: {fp_probs.mean():.3f}, Median: {np.median(fp_probs):.3f}")
    print(f"  Min: {fp_probs.min():.3f}, Max: {fp_probs.max():.3f}")

if false_negatives.any():
    fn_probs = test_probs[false_negatives]
    print(f"\nFalse Negative probabilities:")
    print(f"  Mean: {fn_probs.mean():.3f}, Median: {np.median(fn_probs):.3f}")
    print(f"  Min: {fn_probs.min():.3f}, Max: {fn_probs.max():.3f}")

In [None]:
%%skip_if FULL_TRAINING
# Perform k-fold cross-validation to check model stability
from sklearn.model_selection import StratifiedKFold

print("=" * 70)
print("K-FOLD CROSS-VALIDATION (for robustness check)")
print("=" * 70)

n_folds = 5
kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

cv_scores = {"auc": [], "f1": [], "precision": [], "recall": []}

# Use a subset for faster CV
cv_sample_size = min(500000, len(X_train_model))
cv_indices = np.random.choice(len(X_train_model), size=cv_sample_size, replace=False)
X_cv = X_train_model[cv_indices]
y_cv = y_train[cv_indices]

print(f"Running {n_folds}-fold CV on {cv_sample_size} samples...")

for fold, (train_idx, val_idx) in enumerate(kfold.split(X_cv, y_cv)):
    print(f"\nFold {fold + 1}/{n_folds}...", end=" ")

    X_fold_train, X_fold_val = X_cv[train_idx], X_cv[val_idx]
    y_fold_train, y_fold_val = y_cv[train_idx], y_cv[val_idx]

    # Create fold model
    fold_model = ImprovedMLP(
        in_features=X_train_model.shape[1],
        layers=hidden_layers,
        dropout=dropout,
        use_residual=True,
    ).to(device)

    fold_optimizer = torch.optim.AdamW(
        fold_model.parameters(), lr=lr, weight_decay=weight_decay
    )

    # Quick training (fewer epochs for CV)
    fold_train_ds = NumpyDataset(X_fold_train, y_fold_train)
    fold_val_ds = NumpyDataset(X_fold_val, y_fold_val)

    fold_train_loader = DataLoader(fold_train_ds, batch_size=batch_size, shuffle=True)
    fold_val_loader = DataLoader(fold_val_ds, batch_size=batch_size, shuffle=False)

    best_fold_f1 = 0
    patience_fold = 10
    no_improve_fold = 0

    for epoch in range(50):  # Max 50 epochs per fold
        fold_model.train()
        for xb, yb in fold_train_loader:
            xb, yb = xb.to(device), yb.to(device)
            fold_optimizer.zero_grad()
            out = fold_model(xb)
            loss = criterion(out, yb)
            loss.backward()
            fold_optimizer.step()

        # Check validation
        val_metrics_fold = evaluate(
            fold_model, fold_val_loader, device, decision_threshold=0.5
        )
        if val_metrics_fold["f1"] > best_fold_f1:
            best_fold_f1 = val_metrics_fold["f1"]
            no_improve_fold = 0
        else:
            no_improve_fold += 1
            if no_improve_fold >= patience_fold:
                break

    # Final evaluation
    final_metrics = evaluate(
        fold_model, fold_val_loader, device, decision_threshold=best_threshold
    )
    cv_scores["auc"].append(final_metrics["auc"])
    cv_scores["f1"].append(final_metrics["f1"])

    prec = final_metrics["cm"][1, 1] / max(
        final_metrics["cm"][1, 1] + final_metrics["cm"][0, 1], 1
    )
    rec = final_metrics["cm"][1, 1] / max(
        final_metrics["cm"][1, 1] + final_metrics["cm"][1, 0], 1
    )
    cv_scores["precision"].append(prec)
    cv_scores["recall"].append(rec)

    print(f"AUC={final_metrics['auc']:.4f}, F1={final_metrics['f1']:.4f}")

print("\n" + "=" * 70)
print("CROSS-VALIDATION RESULTS")
print("=" * 70)
for metric_name, scores in cv_scores.items():
    mean_score = np.mean(scores)
    std_score = np.std(scores)
    print(f"{metric_name.capitalize():>12}: {mean_score:.4f} ¬± {std_score:.4f}")

print(
    f"\nModel stability: {'Excellent' if np.std(cv_scores['f1']) < 0.01 else 'Good' if np.std(cv_scores['f1']) < 0.02 else 'Needs improvement'}"
)

In [None]:
%%skip_if FULL_TRAINING
# Train an ensemble of models with different initializations
print("=" * 70)
print("ENSEMBLE MODEL TRAINING")
print("=" * 70)

n_ensemble = 3  # Number of models in ensemble
ensemble_models = []
ensemble_seeds = [42, 123, 456]

print(f"Training {n_ensemble} models with different initializations...")

for i, seed in enumerate(ensemble_seeds):
    print(f"\n--- Model {i+1}/{n_ensemble} (seed={seed}) ---")
    
    # Set seed for this model
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    
    # Create new model
    ensemble_model = ImprovedMLP(
        in_features=X_train_model.shape[1],
        layers=hidden_layers,
        dropout=dropout,
        use_residual=True
    ).to(device)
    
    ensemble_optimizer = torch.optim.AdamW(
        ensemble_model.parameters(), 
        lr=lr, 
        weight_decay=weight_decay
    )
    
    # Train for fewer epochs
    best_ensemble_f1 = 0
    best_ensemble_state = None
    patience_ensemble = 20
    no_improve_ensemble = 0
    
    for epoch in range(100):  # Max 100 epochs
        ensemble_model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb to(device)
            ensemble_optimizer.zero_grad()
            out = ensemble_model(xb)
            loss = criterion(out, yb)
            loss.backward()
            ensemble_optimizer.step()
        
        # Validation
        if epoch % 10 == 0 or epoch > 80:
            val_metrics_ens = evaluate(ensemble_model, val_loader, device, decision_threshold=0.5)
            
            # Optimize threshold
            probs_val = val_metrics_ens["probs"]
            y_val_np = val_metrics_ens["y"]
            prec_curve, rec_curve, thr_pr = precision_recall_curve(y_val_np, probs_val)
            f1_curve = 2 * prec_curve * rec_curve / np.clip(prec_curve + rec_curve, 1e-9, None)
            max_f1_idx = np.nanargmax(f1_curve)
            thr_opt = thr_pr[max_f1_idx] if max_f1_idx < len(thr_pr) else 0.5
            
            val_metrics_thr = evaluate(ensemble_model, val_loader, device, decision_threshold=thr_opt)
            current_f1 = val_metrics_thr["f1"]
            
            if current_f1 > best_ensemble_f1:
                best_ensemble_f1 = current_f1
                best_ensemble_state = {k: v.cpu().clone() for k, v in ensemble_model.state_dict().items()}
                no_improve_ensemble = 0
            else:
                no_improve_ensemble += 1
                if no_improve_ensemble >= patience_ensemble:
                    break
    
    # Load best state
    if best_ensemble_state is not None:
        ensemble_model.load_state_dict(best_ensemble_state)
    
    ensemble_models.append(ensemble_model)
    print(f"  Best val F1: {best_ensemble_f1:.4f}")

# Ensemble prediction on test set
print("\n" + "=" * 70)
print("ENSEMBLE EVALUATION")
print("=" * 70)

# Get predictions from each model
ensemble_probs = []
for i, ens_model in enumerate(ensemble_models):
    ens_model.eval()
    with torch.no_grad():
        test_tensor = torch.from_numpy(X_test_model.astype(np.float32)).to(device)
        logits = ens_model(test_tensor)
        probs = torch.sigmoid(logits).cpu().numpy().squeeze()
        ensemble_probs.append(probs)

# Average predictions
ensemble_probs_avg = np.mean(ensemble_probs, axis=0)
ensemble_preds = (ensemble_probs_avg >= best_threshold).astype(int)

# Evaluate ensemble
cm_ensemble = confusion_matrix(y_test, ensemble_preds)
auc_ensemble = roc_auc_score(y_test, ensemble_probs_avg)
f1_ensemble = f1_score(y_test, ensemble_preds, zero_division=0)

prec_ensemble = cm_ensemble[1,1] / max(cm_ensemble[1,1] + cm_ensemble[0,1], 1)
rec_ensemble = cm_ensemble[1,1] / max(cm_ensemble[1,1] + cm_ensemble[1,0], 1)

print("\nEnsemble vs Single Model Performance:")
print(f"{'Metric':<15} {'Single Model':>15} {'Ensemble':>15} {'Improvement':>15}")
print("-" * 60)
print(f"{'AUC':<15} {auc_final:>15.4f} {auc_ensemble:>15.4f} {'+' if auc_ensemble > auc_final else ''}{(auc_ensemble - auc_final)*100:>14.2f}%")
print(f"{'F1':<15} {f1:>15.4f} {f1_ensemble:>15.4f} {'+' if f1_ensemble > f1 else ''}{(f1_ensemble - f1)*100:>14.2f}%")
print(f"{'Precision':<15} {precision:>15.4f} {prec_ensemble:>15.4f} {'+' if prec_ensemble > precision else ''}{(prec_ensemble - precision)*100:>14.2f}%")
print(f"{'Recall':<15} {recall:>15.4f} {rec_ensemble:>15.4f} {'+' if rec_ensemble > recall else ''}{(rec_ensemble - recall)*100:>14.2f}%")

print(f"\n{'‚úì Ensemble improves generalization!' if f1_ensemble > f1 else 'Single model is already well-optimized'}")

In [None]:
%%skip_if FULL_TRAINING
# Analyze which features contribute to misclassifications
print("=" * 70)
print("FEATURE ANALYSIS FOR MISCLASSIFICATIONS")
print("=" * 70)

# Get feature values for correct vs incorrect predictions
correct_preds = test_preds == y_test
incorrect_preds = ~correct_preds

print(f"\nAnalyzing {len(selected_feature_names)} features...")
print("\nFeature statistics (correct vs incorrect predictions):")
print(f"{'Feature':<25} {'Correct Mean':>15} {'Incorrect Mean':>18} {'Difference':>12}")
print("-" * 75)

significant_features = []
for i, feat_name in enumerate(selected_feature_names):
    feat_short = feat_name.split("_")[-1]

    correct_mean = X_test_model[correct_preds, i].mean()
    incorrect_mean = X_test_model[incorrect_preds, i].mean()
    diff = abs(correct_mean - incorrect_mean)

    print(
        f"{feat_short:<25} {correct_mean:>15.3f} {incorrect_mean:>18.3f} {diff:>12.3f}"
    )

    if diff > 0.5:  # Significant difference (in standardized space)
        significant_features.append((feat_short, diff))

if significant_features:
    print(f"\n{len(significant_features)} features show significant differences:")
    for feat, diff in sorted(significant_features, key=lambda x: x[1], reverse=True):
        print(f"  - {feat}: {diff:.3f}")
    print("\nüí° Consider adding interaction terms between these features")
else:
    print("\n‚úì Features are well-balanced between correct and incorrect predictions")

In [None]:
%%skip_if FULL_TRAINING
print("=" * 70)
print("MODEL IMPROVEMENT RECOMMENDATIONS")
print("=" * 70)

# Calculate current performance metrics
generalization_quality = (
    "Excellent"
    if abs(train_metrics["f1"] - test_metrics_final["f1"]) < 0.01
    else "Good"
)
model_size = sum(p.numel() for p in model.parameters())

print(f"\nüìä Current Model Status:")
print(f"  Performance: F1={f1:.4f}, AUC={auc_final:.4f}")
print(f"  Generalization: {generalization_quality}")
print(f"  Model Size: {model_size:,} parameters")
print(
    f"  Features Used: {len(selected_feature_names)}/32 ({100 * len(selected_feature_names) / 32:.0f}%)"
)

print(f"\nüéØ Recommended Next Steps:")

if abs(train_metrics["f1"] - test_metrics_final["f1"]) > 0.02:
    print(
        "  1. ‚ö†Ô∏è  Model shows overfitting - increase regularization (dropout, weight_decay)"
    )
elif abs(train_metrics["f1"] - test_metrics_final["f1"]) > 0.01:
    print("  1. ‚ö†Ô∏è  Slight overfitting detected - consider slight increase in dropout")
else:
    print("  1. ‚úì Generalization is excellent - model is well-regularized")

if f1 < 0.95:
    print("  2. üìà Try ensemble approach to boost F1 score")
    print("  3. üîç Analyze misclassifications to add targeted features")
else:
    print("  2. ‚úì F1 score is already very high (95%+)")

if len(selected_feature_names) < 15:
    print(f"  3. üí° Consider adding 2-3 more features from importance ranking")
elif len(selected_feature_names) > 25:
    print(f"  3. üí° Model might benefit from more aggressive feature pruning")
else:
    print(f"  3. ‚úì Feature count ({len(selected_feature_names)}) is well-balanced")

print("\nüöÄ Advanced Improvements (if needed):")
print("  ‚Ä¢ Add feature interactions (multiplicative terms)")
print("  ‚Ä¢ Try deeper architecture with more regularization")
print("  ‚Ä¢ Use focal loss with gamma=2-5 for hard examples")
print("  ‚Ä¢ Implement calibration (Platt scaling) for better probabilities")
print("  ‚Ä¢ Try label smoothing (epsilon=0.01) to prevent overconfidence")

print("\nüíª For Production:")
print("  ‚Ä¢ Current model is production-ready if F1 > 0.95")
print("  ‚Ä¢ Consider ensemble of 3-5 models for maximum robustness")
print("  ‚Ä¢ Monitor per-file performance for physics process variations")
print("  ‚Ä¢ Set threshold based on physics requirements (precision vs recall trade-off)")

print("\n" + "=" * 70)

In [None]:
# Bundle weights, preprocessing and metadata into one .pt file for Python-side evaluation.
# NOTE(review): the bundle contains the fitted `scaler` object itself, so loading it
# presumably requires full unpickling rather than a weights-only load - confirm with consumers.
output_artifact = "model_output/simplified_DNN.pt"

artifact = dict(
    # Weights are moved to CPU before saving
    state_dict={name: tensor.cpu() for name, tensor in model.state_dict().items()},
    model_config=dict(
        in_features=X_train_model.shape[1],  # input width of the (pruned) model matrix
        layers=hidden_layers,
        dropout=dropout,
    ),
    scaler=scaler,
    scaler_indices=selected_feature_indices if use_feature_pruning else None,
    feature_names=feature_names_model,  # names of the features the model consumes
    original_feature_names=feature_names,  # full list kept for reference
    best_threshold=best_threshold,
    auc_test=float(auc_final),
    use_feature_pruning=use_feature_pruning,
)

torch.save(artifact, output_artifact)

# Report the on-disk size in MiB
size_mb = os.path.getsize(output_artifact) / 1024**2
print(f"Saved artifact to {output_artifact} ({size_mb:.2f} MB)")
print(f"Model uses {len(feature_names_model)}/{len(feature_names)} features")

In [None]:
# Save scaler parameters and feature info for CMSSW
# IMPORTANT: Save only the scaler params for the features we actually use
import json

if use_feature_pruning:
    # Extract scaler parameters only for selected features
    scaler_mean_pruned = scaler.mean_[selected_feature_indices].tolist()
    scaler_scale_pruned = scaler.scale_[selected_feature_indices].tolist()
    features_to_save = selected_feature_names
else:
    scaler_mean_pruned = scaler.mean_.tolist()
    scaler_scale_pruned = scaler.scale_.tolist()
    features_to_save = feature_names

# Sanity check BEFORE writing: a mismatched file must never reach disk
assert len(scaler_mean_pruned) == len(features_to_save), (
    "Scaler mean dimension mismatch!"
)
assert len(scaler_scale_pruned) == len(features_to_save), (
    "Scaler scale dimension mismatch!"
)
assert len(features_to_save) == X_train_model.shape[1], (
    "Feature count mismatch with model input!"
)

scaler_params = {
    "mean": scaler_mean_pruned,
    "scale": scaler_scale_pruned,
    "pruned_features_indices": selected_feature_indices.tolist()
    if use_feature_pruning
    else None,
    "feature_names": features_to_save,
    "threshold": float(best_threshold),
    "n_features": len(features_to_save),
}

with open("model_output/simplified_scaler_params.json", "w") as json_file:
    json.dump(scaler_params, json_file, indent=2)

print("Saved scaler parameters to model_output/simplified_scaler_params.json")
print(f"  Features saved: {len(features_to_save)}")
print(f"  Feature names: {[f.split('_')[-1] for f in features_to_save]}")
print("✓ Scaler parameters match model input dimensions")

In [None]:
## Export Model to ONNX for CMSSW integration
import torch.onnx

# Set model to evaluation mode.
# NOTE: this moves the shared `model` object to CPU; it stays there afterwards.
model.eval()
model.to("cpu")

# Create dummy input with the PRUNED feature shape (not original X.shape!)
dummy_input = torch.randn(1, X_train_model.shape[1], dtype=torch.float32)

# Export to ONNX
onnx_path = "model_output/simplified_dnn.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,  # CMSSW typically supports opset 11-14
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    # Axis 0 is declared dynamic so any number of rows can be passed at inference
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

print(f"Model exported to {onnx_path}")
print(f"  Input shape: (batch_size, {X_train_model.shape[1]})")
print(f"  Features: {[f.split('_')[-1] for f in feature_names_model]}")

# Verify the exported model
import onnx

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")

# Test inference with ONNX Runtime using PRUNED features
import onnxruntime as ort

# One session is created and reused for every check below
ort_session = ort.InferenceSession(onnx_path)
input_name = ort_session.get_inputs()[0].name

# Test with a few samples from PRUNED test set
test_input = X_test_model[:10].astype(np.float32)
ort_inputs = {input_name: test_input}
ort_outputs = ort_session.run(None, ort_inputs)

# Compare with PyTorch output
with torch.no_grad():
    torch_out = model(torch.from_numpy(test_input))
    torch_probs = torch.sigmoid(torch_out).numpy()

onnx_probs = 1.0 / (1.0 + np.exp(-ort_outputs[0]))  # sigmoid

print(
    "Max difference between PyTorch and ONNX:", np.abs(torch_probs - onnx_probs).max()
)
print("ONNX Runtime inference successful!")


# Test 1: Single track
print("Test 1: Single track")
single_input = X_test_model[:1].astype(np.float32)
single_output = ort_session.run(None, {input_name: single_input})
print(f"  Input shape: {single_input.shape}, Output shape: {single_output[0].shape}")

# Test 2: Multiple tracks (batch)
print("\nTest 2: Batch of 10 tracks")
batch_input = X_test_model[:10].astype(np.float32)
batch_output = ort_session.run(None, {input_name: batch_input})
print(f"  Input shape: {batch_input.shape}, Output shape: {batch_output[0].shape}")

# Test 3: What happens with different batch sizes.
# The loop variable is deliberately NOT named `batch_size`, which is the
# notebook-level DataLoader setting and must not be overwritten here.
print("\nTest 3: Various batch sizes")
for n_tracks in [1, 5, 100, 1000]:
    test_input = X_test_model[:n_tracks].astype(np.float32)
    test_output = ort_session.run(None, {input_name: test_input})
    # Back the success message below with an actual check: one output row per input row
    assert test_output[0].shape[0] == test_input.shape[0], (
        f"ONNX output rows {test_output[0].shape[0]} != input rows {test_input.shape[0]}"
    )
    print(
        f"  Batch size {n_tracks}: Input {test_input.shape} -> Output {test_output[0].shape}"
    )

print("\n✓ ONNX model handles all batch sizes correctly")
print("Note: Empty batches (0 tracks) should be handled in C++ before calling ONNX")