In [None]:
import uproot
import awkward as ak
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    classification_report,
    precision_recall_fscore_support,
    balanced_accuracy_score,
    matthews_corrcoef,
    average_precision_score,
    f1_score,
    brier_score_loss,
)
import os

os.environ["HSA_OVERRIDE_GFX_VERSION"] = "11.0.0"

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

VERBOSE = False

main_branch = "Events"

# Pixel-track branches read from the tree, grouped by meaning:
# kinematics, fit quality, hit counts, impact parameters, curvature,
# and finally the truth-matching columns.
tk_branches = [
    "muon_pixel_tracks_p", "muon_pixel_tracks_pt", "muon_pixel_tracks_ptErr",
    "muon_pixel_tracks_eta", "muon_pixel_tracks_etaErr",
    "muon_pixel_tracks_phi", "muon_pixel_tracks_phiErr",
    "muon_pixel_tracks_chi2", "muon_pixel_tracks_normalizedChi2",
    "muon_pixel_tracks_nPixelHits", "muon_pixel_tracks_nTrkLays",
    "muon_pixel_tracks_nFoundHits", "muon_pixel_tracks_nLostHits",
    "muon_pixel_tracks_dsz", "muon_pixel_tracks_dszErr",
    "muon_pixel_tracks_dxy", "muon_pixel_tracks_dxyErr",
    "muon_pixel_tracks_dz", "muon_pixel_tracks_dzErr",
    "muon_pixel_tracks_qoverp", "muon_pixel_tracks_qoverpErr",
    "muon_pixel_tracks_lambdaErr",
    "muon_pixel_tracks_matched", "muon_pixel_tracks_duplicate",
    "muon_pixel_tracks_tpPdgId", "muon_pixel_tracks_tpPt",
    "muon_pixel_tracks_tpEta", "muon_pixel_tracks_tpPhi",
]

# Generator-level particle branches.
gen_branches = [
    "GenPart_pt", "GenPart_eta", "GenPart_phi", "GenPart_mass", "GenPart_pdgId",
    "GenPart_statusFlags",  # added to select last-copy muons
]

# L1 tracker-muon branches.
l1tkMuon_branches = ["L1TkMu_pt", "L1TkMu_eta", "L1TkMu_phi"]

# Switches choosing which set of ntuples is loaded.
allPixel = False
useSpring24 = True

# Configuration Parameters: base samples for the two ntuple flavours.
filesSelector = [
    "data/ntuples_TTbarCAExtensionFull.root",
    "data/ntuples_ZMMCAExtensionFull.root",
    "data/ntuples_WprimeCAExtensionFull.root",
]
filesAllPixel = [
    "data/ntuples_TTbarCAExtensionAllPixel.root",
    "data/ntuples_ZMMCAExtensionAllPixel.root",
    "data/ntuples_WprimeCAExtensionAllPixel.root",
]

# The spring24 samples are appended (in place) to both lists when enabled.
if useSpring24:
    filesSelector.extend(
        [
            "data/spring24/ntuplesExtNoHP_BsToMuMuG.root",
            "data/spring24/ntuplesExtNoHP_DYToLL.root",
            "data/spring24/ntuplesExtNoHP_TTTo2L2Nu.root",
            "data/spring24/ntuplesExtNoHP_TTToSemileptonic.root",
        ]
    )
    filesAllPixel.extend(
        [
            "data/spring24/ntuplesAllPixelNoHP_BsToMuMuG.root",
            "data/spring24/ntuplesAllPixelNoHP_DYToLL.root",
            "data/spring24/ntuplesAllPixelNoHP_TTTo2L2Nu.root",
            "data/spring24/ntuplesAllPixelNoHP_TTToSemileptonic.root",
        ]
    )

files = filesAllPixel if allPixel else filesSelector

print(files)

# ntuples selection
# Each file is read once and kept as its own chunk; everything is concatenated
# a single time after the loop. Growing `arrays` inside the loop would re-copy
# all previously loaded events on every iteration.
branches_to_read = tk_branches + gen_branches + l1tkMuon_branches
event_chunks = []
label_chunks = []  # per file: one integer file index per event
for file_idx, f in enumerate(files):
    with uproot.open(f) as file:
        arrays_f = file[main_branch].arrays(branches_to_read)
    n_events = len(arrays_f)
    event_chunks.append(arrays_f)
    label_chunks.append(np.full(n_events, file_idx, dtype=np.int64))  # Label events by file
    print(f"Done loading {f} ({n_events} events)")

# `arrays` holds all events; `file_labels[i]` is the index in `files` that event i came from.
arrays = ak.concatenate(event_chunks, axis=0) if event_chunks else ak.Array([])
file_labels = (
    np.concatenate(label_chunks) if label_chunks else np.array([], dtype=np.int64)
)
assert len(file_labels) == len(arrays), "every event needs exactly one file label"
print(f"Loaded {len(arrays)} events from {len(files)} files")

In [None]:
# Reproducibility
# NOTE: the seeding code below sits inside a string literal, so it is NOT executed;
# splits, sampling and weight initialisation therefore differ from run to run.
"""
import random
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
"""
use_gpu = True
cuda_available = torch.cuda.is_available()
print("PyTorch version:", torch.__version__)
print("CUDA available:", cuda_available)
print("Torch version HIP:", torch.version.hip)
# Querying the device name raises when no GPU is visible, so only ask when one exists.
print(
    "CUDA device name:",
    torch.cuda.get_device_name(0) if cuda_available else "n/a (no GPU visible)",
)
# Fall back to the CPU when no GPU is available or GPU use is switched off.
device = torch.device("cuda" if cuda_available and use_gpu else "cpu")
print("Using device:", device)

## Build Feature Matrix and Labels

In [None]:
# Per-track input columns fed to the network, in matrix-column order.
features = [
    # momentum and direction, each followed by its uncertainty
    "muon_pixel_tracks_p",
    "muon_pixel_tracks_pt", "muon_pixel_tracks_ptErr",
    "muon_pixel_tracks_eta", "muon_pixel_tracks_etaErr",
    "muon_pixel_tracks_phi", "muon_pixel_tracks_phiErr",
    # fit quality
    "muon_pixel_tracks_chi2", "muon_pixel_tracks_normalizedChi2",
    # hit / layer counts
    "muon_pixel_tracks_nPixelHits", "muon_pixel_tracks_nTrkLays",
    "muon_pixel_tracks_nFoundHits", "muon_pixel_tracks_nLostHits",
    # impact parameters with uncertainties
    "muon_pixel_tracks_dsz", "muon_pixel_tracks_dszErr",
    "muon_pixel_tracks_dxy", "muon_pixel_tracks_dxyErr",
    "muon_pixel_tracks_dz", "muon_pixel_tracks_dzErr",
    # curvature and remaining uncertainties
    "muon_pixel_tracks_qoverp", "muon_pixel_tracks_qoverpErr",
    "muon_pixel_tracks_lambdaErr",
]

# Column used as the binary training target.
LABEL_FIELD = "muon_pixel_tracks_matched"
# When True, build_dataset appends features derived from the closest L1TkMu object.
useL1TkMuFeatures = True


def wrap_phi(phi):
    """Map an angle (or array of angles) in radians onto the interval [-pi, pi)."""
    full_turn = 2 * np.pi
    shifted = phi + np.pi  # move the target interval to [0, 2*pi) so the modulo applies
    return (shifted % full_turn) - np.pi


def build_dataset(arr, file_labels_in):
    """Flatten the per-event jagged ntuple into a per-track feature matrix.

    Parameters
    ----------
    arr : ak.Array
        Event-level records holding the columns listed in ``features``,
        ``LABEL_FIELD`` and, when ``useL1TkMuFeatures`` is set, the ``L1TkMu_*``
        columns. The array is only read, never modified, so calling this
        function repeatedly on the same input gives the same result.
    file_labels_in : array-like of int
        One source-file index per event.

    Returns
    -------
    X : np.ndarray
        One row per kept track, one column per entry of ``feature_names``.
    y : np.ndarray of int8
        ``LABEL_FIELD`` value of each kept track.
    file_labels_masked : np.ndarray
        Source-file index of each kept track.
    feature_names : list of str
        Column names of ``X``: the base ``features`` followed by the derived ones.

    Notes
    -----
    Tracks whose first feature is negative are excluded, and rows containing any
    non-finite value are removed (with a printed count) before returning.
    """
    feature_names = list(features)
    # Track-level selection on the first entry of `features`.
    mask = arr[feature_names[0]] >= 0

    def _flat(per_track):
        """Apply the track selection and flatten to a 1-D NumPy array."""
        return ak.to_numpy(ak.flatten(per_track[mask]))

    # Expand file labels from event level to track level
    n_tracks_per_event = ak.num(arr[feature_names[0]])
    file_labels_jagged = ak.unflatten(
        np.repeat(file_labels_in, n_tracks_per_event), n_tracks_per_event
    )
    file_labels_masked = _flat(file_labels_jagged)

    # Untransformed columns used by the derived features below.
    trk_pt_original = arr["muon_pixel_tracks_pt"]
    trk_ptErr_original = arr["muon_pixel_tracks_ptErr"]
    trk_eta_original = arr["muon_pixel_tracks_eta"]
    trk_phi_original = arr["muon_pixel_tracks_phi"]
    trk_chi2 = arr["muon_pixel_tracks_chi2"]
    trk_nFoundHits = arr["muon_pixel_tracks_nFoundHits"]
    trk_nLostHits = arr["muon_pixel_tracks_nLostHits"]
    trk_dxy = arr["muon_pixel_tracks_dxy"]
    trk_dz = arr["muon_pixel_tracks_dz"]
    trk_dxyErr = arr["muon_pixel_tracks_dxyErr"]
    trk_dzErr = arr["muon_pixel_tracks_dzErr"]
    trk_qoverp_original = arr["muon_pixel_tracks_qoverp"]
    trk_qoverpErr_original = arr["muon_pixel_tracks_qoverpErr"]

    # Base features. Momenta and all uncertainties are log10-transformed into a
    # LOCAL variable: writing the result back into `arr` would alter the caller's
    # array and make a second call transform already-transformed values.
    cols = []
    for f in features:
        values = arr[f]
        if f in ["muon_pixel_tracks_p", "muon_pixel_tracks_pt"] or "Err" in f:
            selected = ak.flatten(values[mask])
            minimum = ak.min(selected)
            maximum = ak.max(selected)
            print(f"Feature {f} min {minimum:.2f} max {maximum:.2f} -> log10")
            values = np.log10(values + 1e-6)
        cols.append(_flat(values))

    # DERIVED FEATURES - Help model generalize across physics processes
    print("\nAdding derived features...")

    # 1. Relative momentum uncertainty, log10(ptErr / pt)
    sigmaPtOverPt = np.log10(trk_ptErr_original / (trk_pt_original + 1e-6))
    cols.append(_flat(sigmaPtOverPt))
    feature_names.append("muon_pixel_tracks_sigmaPtOverPt")

    # 2. Hit efficiency: found / (found + lost)
    hit_efficiency = trk_nFoundHits / (trk_nFoundHits + trk_nLostHits + 1e-6)
    cols.append(_flat(hit_efficiency))
    feature_names.append("muon_pixel_tracks_hitEfficiency")

    # 3. Chi2 per hit (quality per measurement), log10
    chi2_per_hit = trk_chi2 / (trk_nFoundHits + 1e-6)
    cols.append(_flat(np.log10(chi2_per_hit + 1e-6)))
    feature_names.append("muon_pixel_tracks_chi2PerHit")

    # 4. Squared impact parameter dxy^2 + dz^2 (not divided by any uncertainty), log10
    impact_param_3d = trk_dxy**2 + trk_dz**2
    cols.append(_flat(np.log10(impact_param_3d + 1e-6)))
    feature_names.append("muon_pixel_tracks_impact3D")

    # 5. Impact parameter significance: dxy and dz each divided by its uncertainty,
    #    combined in quadrature, log10
    dxy_significance = trk_dxy / (trk_dxyErr + 1e-6)
    dz_significance = trk_dz / (trk_dzErr + 1e-6)
    impact_significance_2d = np.sqrt(dxy_significance**2 + dz_significance**2)
    cols.append(_flat(np.log10(impact_significance_2d + 1e-6)))
    feature_names.append("muon_pixel_tracks_impactSignificance")

    # 6. Relative uncertainties product (captures overall measurement quality).
    #    |q/p| is used: a negative q/p would make the product negative, log10 would
    #    return NaN, and the finite-value filter below would silently drop the track.
    rel_uncertainty_product = (trk_ptErr_original / trk_pt_original) * (
        trk_qoverpErr_original / np.abs(trk_qoverp_original)
    )
    cols.append(_flat(np.log10(rel_uncertainty_product + 1e-6)))
    feature_names.append("muon_pixel_tracks_relUncertaintyProduct")

    if useL1TkMuFeatures:
        print("\nComputing L1TkMuon matching features...")

        l1_pt = arr["L1TkMu_pt"]
        l1_eta = arr["L1TkMu_eta"]
        l1_phi = arr["L1TkMu_phi"]

        trk_zip = ak.zip(
            {"pt": trk_pt_original, "eta": trk_eta_original, "phi": trk_phi_original}
        )
        l1_zip = ak.zip({"pt": l1_pt, "eta": l1_eta, "phi": l1_phi})

        # For every track, pair it with every L1TkMu object of the same event.
        pairs = ak.cartesian({"t": trk_zip, "l": l1_zip}, axis=1, nested=True)

        deta = pairs.t.eta - pairs.l.eta
        dphi = wrap_phi(pairs.t.phi - pairs.l.phi)
        dR2 = deta**2 + dphi**2

        # Pick, per track, the L1 object with the smallest dR^2. Tracks in events
        # without L1 objects yield None here and are filled with sentinels below.
        min_idx = ak.argmin(dR2, axis=2)
        dR2_min = ak.firsts(
            dR2[ak.local_index(dR2, axis=2) == min_idx[..., None]], axis=2
        )
        l1_pt_matched = ak.firsts(
            pairs.l.pt[ak.local_index(pairs.l.pt, axis=2) == min_idx[..., None]], axis=2
        )

        dPt_norm = np.abs(trk_pt_original - l1_pt_matched) / (l1_pt_matched + 1e-6)

        # L1 pT ratio
        pt_ratio = trk_pt_original / (l1_pt_matched + 1e-6)

        # Combined matching score (dR and pT)
        matching_score = dR2_min * (1 + dPt_norm)

        # Sentinel values for tracks without any L1 candidate.
        dR2_min = ak.fill_none(dR2_min, 999.0)
        dPt_norm = ak.fill_none(dPt_norm, 999.0)
        pt_ratio = ak.fill_none(pt_ratio, 0.0)
        matching_score = ak.fill_none(matching_score, 999.0)

        dR2_min_flat = _flat(dR2_min)
        dPt_norm_flat = _flat(dPt_norm)
        pt_ratio_flat = _flat(pt_ratio)
        matching_score_flat = _flat(matching_score)

        valid_matches = dPt_norm_flat < 999.0

        if VERBOSE:
            print(
                f"  ΔR2_min: min={dR2_min_flat.min():.4f}, max={dR2_min_flat.max():.4f}, mean={dR2_min_flat.mean():.4f}"
            )
            print(
                f"  ΔpT_norm: min={dPt_norm_flat.min():.4f}, max={dPt_norm_flat.max():.4f}, mean={dPt_norm_flat.mean():.4f}"
            )

            print(
                f"  Valid L1TkMu matches: {valid_matches.sum()}/{len(valid_matches)} ({valid_matches.mean() * 100:.2f}%)"
            )
            if valid_matches.any():
                print(
                    f"  ΔpT_norm (valid only): mean={dPt_norm_flat[valid_matches].mean():.4f}, "
                    f"median={np.median(dPt_norm_flat[valid_matches]):.4f}, "
                    f"95th percentile={np.percentile(dPt_norm_flat[valid_matches], 95):.4f}"
                )

        cols.append(np.log10(dR2_min_flat + 1e-6))
        cols.append(np.log10(dPt_norm_flat + 1e-6))
        cols.append(np.log10(pt_ratio_flat + 1e-6))
        cols.append(np.log10(matching_score_flat + 1e-6))

        feature_names.append("L1TkMu_dR2min")
        feature_names.append("L1TkMu_dPtNorm")
        feature_names.append("L1TkMu_ptRatio")
        feature_names.append("L1TkMu_matchingScore")

    # Stack the 1-D columns into a (tracks, features) matrix.
    X = np.vstack(cols).T
    y = ak.to_numpy(ak.flatten(arr[LABEL_FIELD][mask])).astype(np.int8)

    # Drop rows with NaN/inf in any column, keeping labels and file ids aligned.
    finite = np.isfinite(X).all(axis=1)
    if not finite.all():
        print(f"Removing {(~finite).sum()} non-finite samples")
        X = X[finite]
        y = y[finite]
        file_labels_masked = file_labels_masked[finite]

    return X, y, file_labels_masked, feature_names


X, y, file_labels_flat, feature_names = build_dataset(arrays, file_labels)
print("Feature matrix shape:", X.shape)
print("Labels shape:", y.shape, "Positives:", y.sum(), f"({100 * y.mean():.2f}%)")
print("Feature order:", feature_names)

# Per-file breakdown: how many tracks each input file contributes and how many are positive
print("\nSample distribution by file:")
for file_idx, fname in enumerate(files):
    from_this_file = file_labels_flat == file_idx
    n_samples = from_this_file.sum()
    n_pos = (from_this_file & (y == 1)).sum()
    short_name = fname.split('/')[-1]
    pos_percent = 100 * n_pos / max(n_samples, 1)  # max() guards against files with no tracks
    print(
        f"  File {file_idx} ({short_name}): {n_samples} samples ({n_pos} positive, {pos_percent:.2f}%)"
    )

## Train/Test Split and Scaling

In [None]:
train_frac = 0.7  # fraction of the full dataset used for train + validation; the rest is the test set
val_frac = 0.15   # fraction of the full dataset reserved for validation (taken out of the train+val part)

# rel_val below is val_frac / train_frac and must be a valid split fraction.
assert 0.0 < train_frac < 1.0, "train_frac must lie strictly between 0 and 1"
assert 0.0 <= val_frac < train_frac, "val_frac must be non-negative and smaller than train_frac"


def _has_both_classes(labels):
    """True when `labels` contains at least one positive and one negative entry."""
    n_pos = labels.sum()
    return 0 < n_pos < len(labels)


# Create composite stratification label: combine class label + file source
# This ensures both class balance AND file representation in each split.
# y is widened to int64 first so the product cannot overflow a narrow label dtype.
stratify_label = y.astype(np.int64) * len(files) + file_labels_flat  # Unique label per (class, file) combination

# First split train_val vs test
X_train_val, X_test, y_train_val, y_test, file_train_val, file_test = train_test_split(
    X,
    y,
    file_labels_flat,
    train_size=train_frac,
    stratify=stratify_label if _has_both_classes(y) else None,
)

# Derive actual validation fraction relative to train_val portion
if val_frac > 0:
    rel_val = val_frac / train_frac  # portion of train_val to carve out as validation
    stratify_train_val = y_train_val.astype(np.int64) * len(files) + file_train_val
    X_train, X_val, y_train, y_val, file_train, file_val = train_test_split(
        X_train_val,
        y_train_val,
        file_train_val,
        test_size=rel_val,
        stratify=stratify_train_val if _has_both_classes(y_train_val) else None,
    )
    assert len(X_train) + len(X_val) + len(X_test) == len(X), "splits must partition the dataset"
else:
    # No separate validation set: the test set doubles as validation, so any
    # model/threshold selection done on "validation" has already seen the test data.
    print("WARNING: val_frac == 0 -> validation set is the TEST set; test metrics will be optimistic")
    X_train, y_train, file_train = X_train_val, y_train_val, file_train_val
    X_val, y_val, file_val = X_test, y_test, file_test

print("Train size:", X_train.shape, "Val size:", X_val.shape, "Test size:", X_test.shape)

In [None]:
# Per-feature mean / standard deviation of the training set before standardisation
print("Pre-scaling feature stats:")
for i, column in enumerate(X_train.T):
    short_name = feature_names[i].split('_')[-1]
    print(f"{short_name}: mean={column.mean():.4f}, std={column.std():.4f}")

In [None]:
# Standardise every feature to zero mean / unit variance. The scaler is fitted on
# the training split only and then applied unchanged to validation and test.
# NOTE: X_train / X_val / X_test are overwritten here, so re-running this cell
# without re-running the split cell would fit the scaler on already-scaled data.
print("\nPost-scaling feature stats:")
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val, X_test = (scaler.transform(split) for split in (X_val, X_test))
for i, column in enumerate(X_train.T):
    short_name = feature_names[i].split('_')[-1]
    print(f"{short_name}: mean={column.mean():.4f}, std={column.std():.4f}")

In [None]:
# Check that every source file is represented in each of the three splits
print("\nFile Distribution Verification")
splits_to_check = {
    "Train": (file_train, y_train),
    "Val": (file_val, y_val),
    "Test": (file_test, y_test),
}
for split_name, (file_split, y_split) in splits_to_check.items():
    print(f"\n{split_name} set:")
    for file_idx, fname in enumerate(files):
        in_file = file_split == file_idx
        n_samples = in_file.sum()
        n_pos = (in_file & (y_split == 1)).sum()
        pct_of_split = 100 * n_samples / len(file_split)
        print(f"  {fname.split('/')[-1]}: {n_samples} samples ({pct_of_split:.1f}% of {split_name.lower()}, {n_pos} pos)")

In [None]:
# Class-imbalance handling switches.
# use_undersampling: subsample negatives per source file before building the training dataset.
use_undersampling = True  # Keep file-aware undersampling
# use_weighted_sampler: only takes effect when use_undersampling is False (see the DataLoader cell).
use_weighted_sampler = False

class NumpyDataset(Dataset):
    """Map-style dataset wrapping a NumPy feature matrix and label vector.

    Both arrays are copied to float32 tensors; the labels gain a trailing
    axis so that each item's target has shape ``(1,)``.
    """

    def __init__(self, X, y):
        features_f32 = X.astype(np.float32)
        labels_f32 = y.astype(np.float32)
        self.X = torch.from_numpy(features_f32)
        self.y = torch.from_numpy(labels_f32)[:, None]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# Build the training dataset, optionally undersampling negatives separately
# within each source file, then derive the positive-class weight for the loss.
if use_undersampling:
    balanced_indices = []
    
    print("\n=== File-Aware Undersampling (Optimized for Generalization) ===")
    for file_idx, fname in enumerate(files):
        # Indices (into the training arrays) of the samples from this file
        file_mask = file_train == file_idx
        file_indices = np.where(file_mask)[0]
        
        pos_mask = y_train[file_mask] == 1
        neg_mask = y_train[file_mask] == 0
        
        pos_indices = file_indices[pos_mask]
        neg_indices = file_indices[neg_mask]
        
        # Negatives per positive in this file; max() avoids dividing by zero
        file_imbalance_ratio = len(neg_indices) / max(len(pos_indices), 1)
        
        # Target negatives-per-positive after sampling: 4 for files with more than
        # 10 negatives per positive, 5 for more than 5, otherwise 6.
        if file_imbalance_ratio > 10:
            undersample_ratio = 4  # Slightly more data than before
        elif file_imbalance_ratio > 5:
            undersample_ratio = 5
        else:
            undersample_ratio = 6  # Slightly less for ZMM to maintain balance
        
        # Never request more negatives than the file has; a file without positives keeps none.
        n_neg_keep = min(len(neg_indices), len(pos_indices) * undersample_ratio)
        # Draws from NumPy's global RNG, so the selection changes between runs unless it is seeded.
        neg_indices_sampled = np.random.choice(neg_indices, size=n_neg_keep, replace=False)
        
        # All positives are kept; only negatives are subsampled.
        file_balanced = np.concatenate([pos_indices, neg_indices_sampled])
        balanced_indices.append(file_balanced)
        
        print(f"{fname.split('/')[-1]}:")
        print(f"  Before: pos={len(pos_indices)} neg={len(neg_indices)} (ratio={file_imbalance_ratio:.1f})")
        print(f"  After:  pos={len(pos_indices)} neg={n_neg_keep} (ratio={n_neg_keep/max(len(pos_indices),1):.1f})")
    
    # Merge the per-file selections and shuffle them in place
    balanced_indices = np.concatenate(balanced_indices)
    np.random.shuffle(balanced_indices)
    
    X_train_balanced = X_train[balanced_indices]
    y_train_balanced = y_train[balanced_indices]
    
    print(f"\nTotal: {len(y_train)} -> {len(y_train_balanced)} samples")
    print(f"  Class balance: pos={y_train_balanced.sum()} ({100*y_train_balanced.mean():.1f}%), neg={len(y_train_balanced)-y_train_balanced.sum()}")
    
    train_ds = NumpyDataset(X_train_balanced, y_train_balanced)
else:
    train_ds = NumpyDataset(X_train, y_train)

# Class counts of the data actually used for training (after undersampling, if enabled)
pos = y_train_balanced.sum() if use_undersampling else y_train.sum()
neg = (len(y_train_balanced) if use_undersampling else len(y_train)) - pos
if pos == 0:
    pos_weight_value = 1.0
else:
    pos_weight_value = neg / pos

# The loss weight for positives is the neg/pos ratio scaled by a fixed multiplier,
# applied on top of whatever undersampling was done above.
pos_weight_multiplier = 3  # Not too aggressive
pos_weight = torch.tensor([pos_weight_value * pos_weight_multiplier], dtype=torch.float32, device=device)
print(f"\nClass counts train: pos={pos} neg={neg} -> pos_weight={pos_weight_value:.3f} (x{pos_weight_multiplier} = {pos_weight.item():.3f})")

## Torch Dataset and DataLoader

In [None]:
batch_size = 1024

val_ds = NumpyDataset(X_val, y_val)
test_ds = NumpyDataset(X_test, y_test)

# Options shared by both ways of building the training loader
train_loader_options = dict(batch_size=batch_size, drop_last=False)

if use_weighted_sampler and not use_undersampling:
    # inverse frequency sampling to upweight minority
    class_sample_counts = np.array([(y_train == label).sum() for label in (0, 1)])
    weights = 1.0 / np.clip(class_sample_counts, 1, None)
    sample_weights = weights[y_train]  # per-sample weight looked up by class label
    sampler = torch.utils.data.WeightedRandomSampler(
        sample_weights, num_samples=len(sample_weights), replacement=True
    )
    train_loader = DataLoader(train_ds, sampler=sampler, **train_loader_options)
else:
    train_loader = DataLoader(train_ds, shuffle=True, **train_loader_options)

# Evaluation loaders keep dataset order (no shuffling)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

# Pull one batch as a smoke test of the pipeline
xb, yb = next(iter(train_loader))
print("First batch shapes:", xb.shape, yb.shape)
_y_used_for_training = y_train_balanced if use_undersampling else y_train
print("Class balance train: pos=", _y_used_for_training.sum(),
      "neg=", len(_y_used_for_training)-_y_used_for_training.sum())

## Define MLP Architecture

In [None]:
class MLP(nn.Module):
    """Fully connected binary classifier producing one raw logit per sample.

    Each hidden width in ``layers`` contributes Linear -> BatchNorm1d -> ReLU,
    followed by Dropout when ``dropout`` is positive. A final Linear layer maps
    to a single output; no sigmoid is applied.
    """

    def __init__(self, in_features, layers, dropout=0):
        super().__init__()
        widths = [in_features, *layers]
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.BatchNorm1d(fan_out), nn.ReLU()]
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
        modules.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        logits = self.net(x)
        return logits

# Network and optimisation hyperparameters
hidden_layers = [64,32] # hidden-layer widths; a deeper alternative was [512, 256, 128, 64, 32, 16]
dropout = 0.15  # dropout probability applied after each hidden block
lr = 2e-3        # Adam learning rate
weight_decay = 1e-4  # weight decay passed to the optimizer
epochs = 1000  # upper bound on epochs; early stopping normally ends training sooner
patience = 50    # epochs without validation improvement tolerated before stopping

# Input width is the number of columns of the feature matrix X
model = MLP(in_features=X.shape[1], layers=hidden_layers, dropout=dropout).to(device)
print(model)

## Training Loop with Early Stopping

In [None]:
# Multi-metric configuration
primary_metric = "f1"  # metric driving early stopping / best-model selection; options: 'auc','f1','ap','balanced_accuracy','mcc'
optimize_threshold = True  # search a decision threshold on the validation set each epoch
threshold_opt_metric = "f1"  # which metric to maximize when choosing threshold
min_improvement = 1e-4  # minimum gain counted as an improvement; scaled by max(|best|, 1.0) in the training loop
use_focal = True  # use the focal-loss wrapper instead of plain BCEWithLogitsLoss
focal_gamma = 3  # exponent of the focal modulation factor

def _metric_or_nan(metric_fn, *args, catch=ValueError, **kwargs):
    """Return ``metric_fn(*args, **kwargs)``, or NaN when it raises ``catch``."""
    try:
        return metric_fn(*args, **kwargs)
    except catch:
        return float("nan")


@torch.no_grad()
def evaluate(model, loader, device, decision_threshold=0.5):
    """Run ``model`` over ``loader`` and compute binary-classification metrics.

    Parameters
    ----------
    model : nn.Module
        Network returning one logit per sample. It is switched to eval mode
        and left in eval mode.
    loader : iterable of (features, labels) batches
        Labels are expected with a trailing axis of size 1.
    device : torch.device
        Device the feature batches are moved to for the forward pass.
    decision_threshold : float
        Probabilities ``>=`` this value are predicted as class 1.

    Returns
    -------
    dict with keys ``probs`` (sigmoid outputs), ``y`` (int labels), ``preds``,
    ``cm`` (always a 2x2 confusion matrix with rows/columns ordered [0, 1]),
    ``auc``, ``ap``, ``bal_acc``, ``f1`` and ``mcc``. A metric that cannot be
    computed is reported as NaN.
    """
    model.eval()
    logits_list = []
    y_list = []
    for xb, yb in loader:
        logits_list.append(model(xb.to(device)).cpu())
        # Labels are only used on the CPU, so they are not moved to the device.
        y_list.append(yb.cpu())
    logits = torch.cat(logits_list).squeeze(1)
    y_true = torch.cat(y_list).squeeze(1)
    probs = torch.sigmoid(logits).numpy()
    y_np = y_true.numpy().astype(int)
    preds = (probs >= decision_threshold).astype(int)
    # Fix the label set so the matrix stays 2x2 even when only one class occurs;
    # callers index cm[1, 1] etc. and would otherwise fail on a 1x1 matrix.
    cm = confusion_matrix(y_np, preds, labels=[0, 1])
    return dict(
        probs=probs,
        y=y_np,
        preds=preds,
        cm=cm,
        # Ranking metrics only tolerate ValueError (e.g. a single class in y).
        auc=_metric_or_nan(roc_auc_score, y_np, probs),
        ap=_metric_or_nan(average_precision_score, y_np, probs),
        # Thresholded metrics keep the original broad best-effort fallback.
        bal_acc=_metric_or_nan(balanced_accuracy_score, y_np, preds, catch=Exception),
        f1=_metric_or_nan(f1_score, y_np, preds, catch=Exception, zero_division=0),
        mcc=_metric_or_nan(matthews_corrcoef, y_np, preds, catch=Exception),
    )

# Optional Focal Loss wrapper
class FocalBCEWithLogits(nn.Module):
    """Binary cross-entropy on logits, down-weighted for well-classified samples.

    The per-sample BCE (optionally with ``pos_weight``) is multiplied by
    ``(1 - pt) ** gamma``, where ``pt`` is the probability the model assigns to
    the (possibly smoothed) target, and the mean over all elements is returned.
    The modulation factor is computed without gradient tracking.
    """

    def __init__(self, pos_weight=None, gamma=2.0, label_smoothing=0.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='none')
        self.gamma = gamma
        self.label_smoothing = label_smoothing

    def forward(self, logits, targets):
        eps = self.label_smoothing
        if eps > 0:
            # Shrink targets towards 0.5: 0 -> eps / 2, 1 -> 1 - eps / 2
            targets = targets * (1 - eps) + 0.5 * eps

        per_sample = self.bce(logits, targets)
        with torch.no_grad():
            probs = torch.sigmoid(logits)
            pt = probs*targets + (1-probs)*(1-targets)
        modulation = (1-pt).pow(self.gamma)
        return torch.mean(modulation * per_sample)

# Loss (focal wrapper or plain BCE, both using pos_weight) and optimizer
criterion = FocalBCEWithLogits(pos_weight=pos_weight, gamma=focal_gamma, label_smoothing=0.1) if use_focal else nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# Early-stopping bookkeeping
best_metric = -1.0  # best value of the primary validation metric so far
best_state = None  # CPU copy of the model weights at the best epoch
best_threshold = 0.5  # decision threshold chosen at the best epoch
no_improve = 0  # consecutive epochs without sufficient improvement

for epoch in range(1, epochs + 1):
    # ---- one pass over the training data ----
    model.train()
    running = 0.0
    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size to get a per-sample epoch average
        running += loss.item() * xb.size(0)
    epoch_loss = running / len(train_loader.dataset)

    # Evaluate on validation set (use threshold=0.5 first)
    val_metrics = evaluate(model, val_loader, device, decision_threshold=0.5)
    probs_val = val_metrics["probs"]
    # NOTE: this rebinds the global y_val to the integer label array returned by evaluate()
    y_val = val_metrics["y"]

    # Threshold search optimized for GENERALIZATION across files
    thr_opt = 0.5
    if optimize_threshold and probs_val.size > 0:
        # Candidate thresholds: unique values among 201 evenly spaced quantiles of the validation scores
        candidate_thr = np.unique(np.quantile(probs_val, np.linspace(0, 1, 201)))
        best_score_local = -1
        
        for t in candidate_thr:
            preds_t = (probs_val >= t).astype(int)
            
            if threshold_opt_metric == "f1":
                # Compute F1 per file and penalise its spread across files (for generalization)
                # NOTE(review): indexing with file_val assumes probs_val is in dataset order,
                # i.e. val_loader does not shuffle - confirm against the loader definition.
                f1_per_file = []
                for file_idx in range(len(files)):
                    file_mask = file_val == file_idx
                    if file_mask.sum() > 0:
                        y_file = y_val[file_mask]
                        preds_file = preds_t[file_mask]
                        f1_file = f1_score(y_file, preds_file, zero_division=0)
                        f1_per_file.append(f1_file)
                
                if len(f1_per_file) > 0:
                    # Reward high mean F1 AND low variance (generalization)
                    mean_f1 = np.mean(f1_per_file)
                    std_f1 = np.std(f1_per_file)
                    # Score = mean F1 minus half the standard deviation across files
                    metric_t = mean_f1 - 0.5 * std_f1  # Adjust weight as needed
                else:
                    metric_t = f1_score(y_val, preds_t, zero_division=0)
            elif threshold_opt_metric == "balanced_accuracy":
                metric_t = balanced_accuracy_score(y_val, preds_t)
            elif threshold_opt_metric == "mcc":
                try:
                    metric_t = matthews_corrcoef(y_val, preds_t)
                except Exception:
                    metric_t = -1
            else:
                # Unknown metric name: fall back to the global F1
                metric_t = f1_score(y_val, preds_t, zero_division=0)
            
            # Strict '>' keeps the first (lowest) threshold among ties
            if metric_t > best_score_local:
                best_score_local = metric_t
                thr_opt = t
    
    if optimize_threshold:
        current_threshold = thr_opt
    else:
        current_threshold = 0.5

    # Recompute metrics at chosen threshold
    # NOTE(review): this runs a second full forward pass over the validation set each epoch.
    val_metrics_thr = evaluate(model, val_loader, device, decision_threshold=current_threshold)

    auc_val = val_metrics_thr["auc"]
    ap_val = val_metrics_thr["ap"]
    bal_acc_val = val_metrics_thr["bal_acc"]
    f1_val = val_metrics_thr["f1"]
    mcc_val = val_metrics_thr["mcc"]

    # Select the metric that drives early stopping; unknown names fall back to F1
    metric_map = {
        "auc": auc_val,
        "f1": f1_val,
        "ap": ap_val,
        "balanced_accuracy": bal_acc_val,
        "mcc": mcc_val,
    }
    current_primary = metric_map.get(primary_metric, f1_val)

    print(
        f"Epoch {epoch:03d} loss={epoch_loss:.4f} AUC_val={auc_val:.4f} AP_val={ap_val:.4f} F1_val={f1_val:.4f} BalAcc_val={bal_acc_val:.4f} MCC_val={mcc_val:.4f} thr={current_threshold:.3f} primary({primary_metric})={current_primary:.4f}"  # noqa: E501
    )

    # Improvement test: the required margin is min_improvement * max(|best|, 1.0),
    # which equals min_improvement itself whenever |best_metric| <= 1.
    # A NaN metric never compares greater, so it counts as "no improvement".
    if current_primary > best_metric + min_improvement * max(abs(best_metric), 1.0):
        best_metric = current_primary
        # Snapshot the weights on the CPU so later epochs cannot alter them
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        best_threshold = current_threshold
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping triggered.")
            break

# Restore the best epoch's weights (if any epoch improved on the initial value)
if best_state is not None:
    model.load_state_dict(best_state)
print(f"Best {primary_metric}={best_metric:.4f} at threshold={best_threshold:.4f}")

## Evaluate on Test Set

In [None]:
# Score the held-out test set at the threshold selected on the validation set
print(f"Evaluating on test set with threshold={best_threshold:.4f}")
test_metrics = evaluate(model, test_loader, device, decision_threshold=best_threshold)
cm, probs, y_true, preds, auc_final = (
    test_metrics[key] for key in ("cm", "probs", "y", "preds", "auc")
)
print("Confusion matrix:\n", cm)
report = classification_report(y_true, preds, digits=3, zero_division=0)
print(report)
# Binary (positive-class) precision / recall / F1; the support entry is discarded
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, preds, average="binary", zero_division=0
)
print(f"Precision={precision:.3f} Recall={recall:.3f} F1={f1:.3f} AUC={auc_final:.3f}")

## Extended Metrics

In [None]:
# Confusion counts at the decision threshold used to produce `preds`
true_pos = ((preds == 1) & (y_true == 1)).sum()
true_neg = ((preds == 0) & (y_true == 0)).sum()
n_pred_pos = (preds == 1).sum()
n_actual_pos = (y_true == 1).sum()
n_actual_neg = (y_true == 0).sum()

acc = (preds == y_true).mean()
# max(..., 1) guards the empty-denominator cases
precision = true_pos / max(n_pred_pos, 1) if n_pred_pos > 0 else 0.0
recall = true_pos / max(n_actual_pos, 1)
specificity = true_neg / max(n_actual_neg, 1)
balanced_accuracy = 0.5 * (recall + specificity)
try:
    mcc = matthews_corrcoef(y_true, preds)
except Exception:
    mcc = float("nan")
try:
    ap = average_precision_score(y_true, probs)
except Exception:
    ap = float("nan")
try:
    brier = brier_score_loss(y_true, probs)
except Exception:
    brier = float("nan")

# Precision-Recall threshold analysis
prec_curve, rec_curve, thr_pr = precision_recall_curve(y_true, probs)
f1_curve = 2 * prec_curve * rec_curve / np.clip(prec_curve + rec_curve, 1e-9, None)
max_f1_idx = int(np.nanargmax(f1_curve))
# prec_curve[i] / rec_curve[i] belong to thr_pr[i]; only the final curve point
# (precision 1, recall 0) has no threshold, in which case 0.5 is reported.
opt_pr_threshold = thr_pr[max_f1_idx] if max_f1_idx < len(thr_pr) else 0.5
best_f1 = f1_curve[max_f1_idx]

# Youden J optimal ROC threshold
fpr_curve, tpr_curve, thr_roc = roc_curve(y_true, probs)
youden = tpr_curve - fpr_curve
j_idx = np.argmax(youden)
youden_thr = thr_roc[j_idx]

print("--- Extended Metrics ---")
print(f"Accuracy            : {acc:.4f}")
print(f"Precision           : {precision:.4f}")
print(f"Recall (TPR)        : {recall:.4f}")
print(f"Specificity (TNR)   : {specificity:.4f}")
print(f"Balanced Accuracy   : {balanced_accuracy:.4f}")
print(f"MCC                 : {mcc:.4f}")
print(f"Average Precision   : {ap:.4f}")
print(f"Brier Score         : {brier:.4f}")
print(f"Best F1             : {best_f1:.4f} at PR threshold ~ {opt_pr_threshold:.4f}")
print(f"Youden J threshold  : {youden_thr:.4f}")

# Precision-Recall curve (left) and ROC curve (right)
fig, (ax_pr, ax_roc) = plt.subplots(1, 2, figsize=(10, 5))

ax_pr.plot(rec_curve, prec_curve, label=f"AP={ap:.3f}")
ax_pr.scatter(
    rec_curve[max_f1_idx],
    prec_curve[max_f1_idx],
    marker="o",
    color="red",
    label="Best F1",
)
ax_pr.set_xlabel("Recall")
ax_pr.set_ylabel("Precision")
ax_pr.set_title("Precision-Recall Curve")
ax_pr.legend()

ax_roc.plot(fpr_curve, tpr_curve, label=f"AUC={auc_final:.3f}")
ax_roc.plot([0, 1], [0, 1], "k--")  # chance diagonal
ax_roc.set_xlabel("False Positive Rate")
ax_roc.set_ylabel("True Positive Rate")
ax_roc.set_title("ROC Curve - DNN")
ax_roc.legend()

fig.tight_layout()
plt.show()

## Per-File Performance Analysis

In [None]:
# Evaluate performance breakdown by source file
print("=" * 70)
print("PER-FILE PERFORMANCE ON TEST SET")
print("=" * 70)

for file_idx, fname in enumerate(files):
    mask = file_test == file_idx
    if mask.sum() == 0:
        continue

    probs_file = probs[mask]
    y_file = y_true[mask]
    is_pos_file = y_file == 1
    n_pos_file = int(is_pos_file.sum())

    # Scan thresholds for the one giving the highest recall on THIS file.
    # NOTE: recall cannot increase as the threshold rises, so the strict '>' below
    # selects the lowest scanned threshold whenever any positive passes it.
    best_recall = 0
    best_thr_recall = best_threshold
    for t in np.linspace(0.1, 0.9, 50):
        recall_t = ((probs_file >= t) & is_pos_file).sum() / max(n_pos_file, 1)
        if recall_t > best_recall:
            best_recall = recall_t
            best_thr_recall = t

    preds_file = (probs_file >= best_threshold).astype(int)
    preds_file_optimized = (probs_file >= best_thr_recall).astype(int)

    # labels=[0, 1] keeps both matrices 2x2 even if a file slice holds a single
    # class; otherwise the matrix collapses to 1x1 and recall/precision read as 0.
    cm_file = confusion_matrix(y_file, preds_file, labels=[0, 1])
    cm_file_opt = confusion_matrix(y_file, preds_file_optimized, labels=[0, 1])

    try:
        auc_file = roc_auc_score(y_file, probs_file)
    except ValueError:
        auc_file = float("nan")

    try:
        f1_file = f1_score(y_file, preds_file, zero_division=0)
    except Exception:
        f1_file = float("nan")

    # Rows are true classes, columns predicted classes; max(..., 1) avoids 0/0
    recall_file = cm_file[1, 1] / max(cm_file[1, 1] + cm_file[1, 0], 1)
    recall_file_opt = cm_file_opt[1, 1] / max(cm_file_opt[1, 1] + cm_file_opt[1, 0], 1)
    precision_file = cm_file[1, 1] / max(cm_file[1, 1] + cm_file[0, 1], 1)

    print(f"\n{fname.split('/')[-1]}:")
    print(f"  Samples: {mask.sum()} (pos={y_file.sum()}, neg={len(y_file)-y_file.sum()})")
    print(f"  AUC: {auc_file:.4f}")
    print(f"  F1: {f1_file:.4f}")
    print(f"  Precision: {precision_file:.4f}")
    print(f"  Recall @ global thr={best_threshold:.3f}: {recall_file:.4f}")
    print(f"  Recall @ optimal thr={best_thr_recall:.3f}: {recall_file_opt:.4f} (+{(recall_file_opt-recall_file)*100:.1f}%)")
    print(f"  Confusion Matrix (global thr):\n{cm_file}")

## Probability Distributions Plot

In [None]:
# Overlay the predicted-probability distributions of the two true classes
fig, ax = plt.subplots(figsize=(6, 5))
for class_value, class_label in ((1, "matched"), (0, "fake")):
    ax.hist(probs[y_true == class_value], bins=40, histtype="step", label=class_label)
ax.set_xlabel("Predicted Probability (matched)")
ax.set_ylabel("Tracks")
ax.legend()
fig.tight_layout()
plt.show()

## Confusion Matrix Heatmap

In [None]:
# Annotated heatmap of the test-set confusion matrix (rows: true class, columns: predicted class)
fig, ax = plt.subplots(figsize=(4, 3))
sns.heatmap(cm, annot=True, fmt="d", cmap="YlOrRd", cbar=False, ax=ax)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
fig.tight_layout()
plt.show()

## Save Model Artifact

In [None]:
# pt format for python evaluation
# The stored feature list must describe the columns the model and scaler were
# actually fitted on: `feature_names` (base + derived columns), not the shorter
# base-only `features` list.
assert len(feature_names) == X.shape[1], "feature_names must have one entry per model input"
artifact = {
    "state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
    "model_config": {
        "in_features": X.shape[1],
        "layers": hidden_layers,
        "dropout": dropout,
    },
    # The fitted scaler object is pickled inside the artifact.
    "scaler": scaler,
    "feature_names": list(feature_names),
    "best_threshold": best_threshold,
    "auc_test": float(auc_final),
}
output_artifact = "dnn_artifact.pt"
torch.save(artifact, output_artifact)
size_mb = os.path.getsize(output_artifact) / 1024**2
print(f"Saved artifact to {output_artifact} ({size_mb:.2f} MB)")

In [None]:
## Export Model to ONNX for CMSSW integration
import torch.onnx

# Set model to evaluation mode
# NOTE: the model is moved to the CPU here and stays there for any later cell.
model.eval()
model.to("cpu")

# Create dummy input with the correct shape
dummy_input = torch.randn(1, X.shape[1], dtype=torch.float32)

# Export to ONNX
onnx_path = "muon_pixeltrack_selector.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,  # CMSSW typically supports opset 11-14
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    # Leave the batch dimension dynamic so any number of tracks can be scored at once
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    }
)

print(f"Model exported to {onnx_path}")

# Verify the exported model
import onnx
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")

# Test inference with ONNX Runtime
import onnxruntime as ort
ort_session = ort.InferenceSession(onnx_path)

# Test with a few samples
test_input = X_test[:10].astype(np.float32)
ort_inputs = {ort_session.get_inputs()[0].name: test_input}
ort_outputs = ort_session.run(None, ort_inputs)

# Compare with PyTorch output
with torch.no_grad():
    torch_out = model(torch.from_numpy(test_input))
    torch_probs = torch.sigmoid(torch_out).numpy()

# The exported graph outputs logits, so the sigmoid is applied here as well
onnx_probs = 1.0 / (1.0 + np.exp(-ort_outputs[0]))  # sigmoid

max_abs_diff = np.abs(torch_probs - onnx_probs).max()
print("Max difference between PyTorch and ONNX:", max_abs_diff)
# Fail loudly if the exported graph does not reproduce the PyTorch model
np.testing.assert_allclose(onnx_probs, torch_probs, atol=1e-4, rtol=0)
print("ONNX Runtime inference successful!")

# Save scaler parameters and feature info for CMSSW
# `feature_names` (base + derived columns) is the list matching scaler.mean_ /
# scaler.scale_ entry by entry; the base-only `features` list is shorter.
assert len(feature_names) == len(scaler.mean_) == len(scaler.scale_), (
    "scaler parameters and feature names must have the same length"
)
scaler_params = {
    "mean": scaler.mean_.tolist(),
    "scale": scaler.scale_.tolist(),
    "feature_names": list(feature_names),
    "threshold": float(best_threshold),
}

import json
with open("scaler_params.json", "w") as f:
    json.dump(scaler_params, f, indent=2)

print("Saved scaler parameters to scaler_params.json")

In [None]:
# Test edge cases with ONNX model
#
# Runs the exported model on several batch sizes taken from the head of X_test
# and checks that the output keeps one row per input row, so the final
# success message is only printed when that actually holds.
import onnxruntime as ort

# Pin the CPU provider: ONNX Runtime builds with extra execution providers
# require `providers` to be passed explicitly.
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name

# Test 1: Single track
print("Test 1: Single track")
single_input = X_test[:1].astype(np.float32)
single_output = ort_session.run(None, {input_name: single_input})
print(f"  Input shape: {single_input.shape}, Output shape: {single_output[0].shape}")
assert single_output[0].shape[0] == single_input.shape[0], "single-track output row count mismatch"

# Test 2: Multiple tracks (batch)
print("\nTest 2: Batch of 10 tracks")
batch_input = X_test[:10].astype(np.float32)
batch_output = ort_session.run(None, {input_name: batch_input})
print(f"  Input shape: {batch_input.shape}, Output shape: {batch_output[0].shape}")
assert batch_output[0].shape[0] == batch_input.shape[0], "batch output row count mismatch"

# Test 3: What happens with different batch sizes
# Slicing caps the batch at len(X_test), so the printed input shape (not the
# requested batch_size) is the size that was really evaluated.
print("\nTest 3: Various batch sizes")
for batch_size in [1, 5, 100, 1000]:
    test_input = X_test[:batch_size].astype(np.float32)
    test_output = ort_session.run(None, {input_name: test_input})
    print(f"  Batch size {batch_size}: Input {test_input.shape} -> Output {test_output[0].shape}")
    assert test_output[0].shape[0] == test_input.shape[0], (
        f"output row count mismatch for requested batch size {batch_size}"
    )

print("\n✓ ONNX model handles all batch sizes correctly")
print("Note: Empty batches (0 tracks) should be handled in C++ before calling ONNX")

## Visualize L1TkMuon Matching Features

In [None]:
# Visualize the new L1TkMuon features
#
# 2x2 figure comparing label-1 ("matched") and label-0 ("fake") rows of X for
# the two L1TkMuon matching features: 1D densities on the top row, 2D
# log-scaled occupancy maps on the bottom row, followed by summary statistics.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Column positions of the two features inside X
dR2_idx = features.index("L1TkMu_dR2min")
dPt_idx = features.index("L1TkMu_dPtNorm")

# Slice every (class, feature) combination once and reuse it everywhere below
dR2_matched = X[y == 1, dR2_idx]
dR2_fake = X[y == 0, dR2_idx]
dPt_matched = X[y == 1, dPt_idx]
dPt_fake = X[y == 0, dPt_idx]

dR2_axis_label = "ΔR2 to closest L1TkMuon"
dPt_axis_label = "|pT_track - pT_L1| / pT_L1"

# Top row: per-class 1D densities (ΔR2 on the left, ΔpT/pT on the right)
density_panels = [
    (dR2_matched, dR2_fake, (0, 0.5), dR2_axis_label, "ΔR2 Distribution"),
    (dPt_matched, dPt_fake, (0, 1.0), dPt_axis_label, "Normalized pT Difference"),
]
for ax, (matched_vals, fake_vals, value_range, panel_xlabel, panel_title) in zip(
    axes[0], density_panels
):
    step_kwargs = dict(
        bins=50, range=value_range, histtype="step", linewidth=2, density=True
    )
    ax.hist(matched_vals, label="Matched tracks", color="blue", **step_kwargs)
    ax.hist(fake_vals, label="Fake tracks", color="red", **step_kwargs)
    ax.set_xlabel(panel_xlabel)
    ax.set_ylabel("Density")
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(alpha=0.3)

# Bottom row: 2D correlation ΔR2 vs ΔpT (matched on the left, fake on the right)
correlation_panels = [
    (dR2_matched, dPt_matched, "Blues", "Matched tracks", "Matched Tracks: ΔR2 vs ΔpT"),
    (dR2_fake, dPt_fake, "Reds", "Fake tracks", "Fake Tracks: ΔR2 vs ΔpT"),
]
for ax, (dR2_vals, dPt_vals, cmap_name, cbar_label, panel_title) in zip(
    axes[1], correlation_panels
):
    h = ax.hist2d(
        dR2_vals,
        dPt_vals,
        bins=[50, 50],
        range=[[0, 0.5], [0, 1.0]],
        cmap=cmap_name,
        norm=plt.matplotlib.colors.LogNorm(),
    )
    # h[3] is the image returned by hist2d; it drives the colour bar
    plt.colorbar(h[3], ax=ax, label=cbar_label)
    ax.set_xlabel(dR2_axis_label)
    ax.set_ylabel(dPt_axis_label)
    ax.set_title(panel_title)

plt.tight_layout()
plt.show()

# Print statistics
banner = "=" * 60
print(banner)
print("L1TkMuon Matching Feature Statistics")
print(banner)

# Mean and standard deviation of both features, one section per class
for section_heading, dR2_vals, dPt_vals in (
    (f"\nMatched tracks (n={(y==1).sum()}):", dR2_matched, dPt_matched),
    (f"\nFake tracks (n={(y==0).sum()}):", dR2_fake, dPt_fake),
):
    print(section_heading)
    print(f"  ΔR2_min:     mean={dR2_vals.mean():.4f}, std={dR2_vals.std():.4f}")
    print(f"  ΔpT_norm:   mean={dPt_vals.mean():.4f}, std={dPt_vals.std():.4f}")

# Ratio of the fake-class mean to the matched-class mean for each feature
print("\nSeparation power:")
print(f"  ΔR2_min:     ratio={dR2_fake.mean() / dR2_matched.mean():.2f}x")
print(f"  ΔpT_norm:   ratio={dPt_fake.mean() / dPt_matched.mean():.2f}x")