# Imports and setup

In [None]:
import uproot
import awkward as ak
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    classification_report,
    precision_recall_fscore_support,
    balanced_accuracy_score,
    matthews_corrcoef,
    average_precision_score,
    f1_score,
    brier_score_loss,
)
import os

os.environ["HSA_OVERRIDE_GFX_VERSION"] = "11.0.0"

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# Key of the object read from each input ROOT file (used as file[main_branch]).
main_branch = "Events"

# Per-track branches (all prefixed "muon_pixel_tracks_") requested from the
# file: kinematics, fit quality, hit counts, impact parameters and their
# errors, followed by the matching/duplicate flags and the "tp*" branches.
tk_branches = [
    "muon_pixel_tracks_p",
    "muon_pixel_tracks_pt",
    "muon_pixel_tracks_ptErr",
    "muon_pixel_tracks_eta",
    "muon_pixel_tracks_etaErr",
    "muon_pixel_tracks_phi",
    "muon_pixel_tracks_phiErr",
    "muon_pixel_tracks_chi2",
    "muon_pixel_tracks_normalizedChi2",
    "muon_pixel_tracks_nPixelHits",
    "muon_pixel_tracks_nTrkLays",
    "muon_pixel_tracks_nFoundHits",
    "muon_pixel_tracks_nLostHits",
    "muon_pixel_tracks_dsz",
    "muon_pixel_tracks_dszErr",
    "muon_pixel_tracks_dxy",
    "muon_pixel_tracks_dxyErr",
    "muon_pixel_tracks_dz",
    "muon_pixel_tracks_dzErr",
    "muon_pixel_tracks_qoverp",
    "muon_pixel_tracks_qoverpErr",
    "muon_pixel_tracks_lambdaErr",
    "muon_pixel_tracks_matched",
    "muon_pixel_tracks_duplicate",
    "muon_pixel_tracks_tpPdgId",
    "muon_pixel_tracks_tpPt",
    "muon_pixel_tracks_tpEta",
    "muon_pixel_tracks_tpPhi",
]

# "GenPart_*" branches requested alongside the track branches.
gen_branches = [
    "GenPart_pt",
    "GenPart_eta",
    "GenPart_phi",
    "GenPart_mass",
    "GenPart_pdgId",
    "GenPart_statusFlags",  # added to select last-copy muons
]

# "L1TkMu_*" branches requested alongside the track branches.
l1tkMuon_branches = [
    "L1TkMu_pt",
    "L1TkMu_eta",
    "L1TkMu_phi",
]


# Dataset variant switches:
#   legacy   -> swap "CAExtension" for "Legacy" in every selected file name
#   allPixel -> use the "AllPixel" file list instead of the "Full" one
legacy = False
allPixel = False

# Configuration Parameters
filesSelector = [
    "data/ntuples_TTbarCAExtensionFull.root",
    "data/ntuples_ZMMCAExtensionFull.root",
    "data/ntuples_WprimeCAExtensionFull.root",
]

filesAllPixel = [
    "data/ntuples_TTbarCAExtensionAllPixel.root",
    "data/ntuples_ZMMCAExtensionAllPixel.root",
    "data/ntuples_WprimeCAExtensionAllPixel.root",
]

# Work on a copy so that the legacy renaming below never rewrites the
# filesSelector / filesAllPixel configuration lists themselves.
files = list(filesAllPixel if allPixel else filesSelector)

if legacy:
    files = [f.replace("CAExtension", "Legacy") for f in files]
print(files)

# ntuples selection
# Read the requested branches from every input file and join the per-file
# event arrays once at the end, rather than re-concatenating the growing
# array (which started out as a plain Python list) after each file.
branches_to_read = tk_branches + gen_branches + l1tkMuon_branches
per_file_arrays = []
for f in files:
    with uproot.open(f) as file:
        per_file_arrays.append(file[main_branch].arrays(branches_to_read))
    print(f"Done loading {f}")

# With no input files, keep the previous result of an empty list.
arrays = ak.concatenate(per_file_arrays, axis=0) if per_file_arrays else []

print(f"Loaded {len(arrays)} events")

In [None]:
# Reproducibility
# NOTE: the seeding code below sits inside a bare string literal, so it is not
# executed; runs are currently unseeded.
"""
import random
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
"""
use_gpu = True
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Torch version HIP:", torch.version.hip)
# Querying the device name raises when no GPU is visible, which would abort
# the cell before the CPU fallback below is reached.
if torch.cuda.is_available():
    print("CUDA device name:", torch.cuda.get_device_name(0))
else:
    print("CUDA device name:", "unavailable")
# Use the GPU only when one is visible and use_gpu is set; otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
print("Using device:", device)

## Build Feature Matrix and Labels

In [None]:
# Input features of the classifier, in column order of the feature matrix
# built by build_dataset (one column per entry, same order as this list).
features = [
    "muon_pixel_tracks_p",
    "muon_pixel_tracks_pt",
    "muon_pixel_tracks_ptErr",
    "muon_pixel_tracks_eta",
    "muon_pixel_tracks_etaErr",
    "muon_pixel_tracks_phi",
    "muon_pixel_tracks_phiErr",
    "muon_pixel_tracks_chi2",
    "muon_pixel_tracks_normalizedChi2",
    "muon_pixel_tracks_nPixelHits",
    "muon_pixel_tracks_nTrkLays",
    "muon_pixel_tracks_nFoundHits",
    "muon_pixel_tracks_nLostHits",
    "muon_pixel_tracks_dsz",
    "muon_pixel_tracks_dszErr",
    "muon_pixel_tracks_dxy",
    "muon_pixel_tracks_dxyErr",
    "muon_pixel_tracks_dz",
    "muon_pixel_tracks_dzErr",
    "muon_pixel_tracks_qoverp",
    "muon_pixel_tracks_qoverpErr",
    "muon_pixel_tracks_lambdaErr",
]

# Branch used as the binary training target (cast to int8 in build_dataset).
LABEL_FIELD = "muon_pixel_tracks_matched"

def build_dataset(arr):
    """Flatten the per-event track branches into a feature matrix and labels.

    Tracks are kept when the first feature (``features[0]``) is >= 0
    (presumably negative values mark placeholder entries; confirm against
    the ntuple producer). The momentum features and every feature whose name
    contains "Err" are replaced by their log10. Rows that contain any
    non-finite value after the transform are dropped.

    Unlike the previous implementation, ``arr`` is not modified, so calling
    this function twice on the same array gives the same result.

    Args:
        arr: awkward record array holding every branch in ``features`` plus
            ``LABEL_FIELD``, jagged per event.

    Returns:
        Tuple ``(X, y)``: ``X`` has one row per kept track and one column per
        entry of ``features`` (same order); ``y`` is the int8 label vector.
    """
    mask = arr[features[0]] >= 0
    cols = []
    for f in features:
        selected = ak.flatten(arr[f][mask])
        flat = ak.to_numpy(selected)
        if f in ("muon_pixel_tracks_p", "muon_pixel_tracks_pt") or "Err" in f:
            # The range is only reported for log-transformed features; skip it
            # when nothing was selected (min/max of an empty array is None).
            if len(selected) > 0:
                minimum = ak.min(selected)
                maximum = ak.max(selected)
                print(f"Feature {f} min {minimum:.2f} max {maximum:.2f} -> log10")
            # Zero or negative inputs become -inf/nan here; those rows are
            # removed by the finite filter below, so the warnings are noise.
            with np.errstate(divide="ignore", invalid="ignore"):
                flat = np.log10(flat)
        cols.append(flat)
    X = np.vstack(cols).T
    y = ak.to_numpy(ak.flatten(arr[LABEL_FIELD][mask])).astype(np.int8)
    finite = np.isfinite(X).all(axis=1)
    if not finite.all():
        X = X[finite]
        y = y[finite]
    return X, y


# Build the model inputs from the loaded events, then report their shape
# and the share of positive labels.
X, y = build_dataset(arrays)
n_positive = y.sum()
positive_pct = 100 * y.mean()
print("Feature matrix shape:", X.shape)
print("Labels shape:", y.shape, "Positives:", n_positive, f"({positive_pct:.2f}%)")
print("Feature order:", features)

## Train/Validation/Test Split and Scaling

In [None]:
train_frac = 0.7  # fraction for train+val
val_frac = 0.15   # fraction (of full dataset) reserved for validation (inside train+val)


def _stratify_target(labels):
    """Return labels for stratified splitting, or None if only one class is present."""
    n_pos = labels.sum()
    return labels if 0 < n_pos < len(labels) else None


def _print_feature_stats(matrix):
    """Print mean and std of every column, labelled by the feature's short name."""
    for col in range(matrix.shape[1]):
        short_name = features[col].split('_')[-1]
        values = matrix[:, col]
        print(f"{short_name}: mean={values.mean():.4f}, std={values.std():.4f}")


# Stage 1: hold out the test set; the rest is train+validation.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X,
    y,
    train_size=train_frac,
    stratify=_stratify_target(y),
)

# Stage 2: carve the validation set out of the train+validation portion.
if val_frac <= 0:
    # No separate validation split requested: the test set doubles as the
    # validation set.
    X_train, y_train = X_train_val, y_train_val
    X_val, y_val = X_test, y_test
else:
    # val_frac is relative to the full dataset, so rescale it to the
    # train+validation portion.
    rel_val = val_frac / train_frac
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val,
        y_train_val,
        test_size=rel_val,
        stratify=_stratify_target(y_train_val),
    )

print("Train size:", X_train.shape, "Val size:", X_val.shape, "Test size:", X_test.shape)

# Standardise with statistics fitted on the training split only, printing
# the per-feature stats before and after.
print("Pre-scaling feature stats:")
_print_feature_stats(X_train)
print("--------------------------------------------------------------------\n Post-scaling feature stats:")
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)
_print_feature_stats(X_train)

In [None]:
# Weight of the positive class in the loss: the negative-to-positive ratio of
# the training labels, or 1.0 when the training set has no positives.
pos = y_train.sum()
neg = len(y_train) - pos
pos_weight_value = neg / pos if pos != 0 else 1.0
pos_weight = torch.tensor([pos_weight_value], dtype=torch.float32, device=device)
print(f"Class counts train: pos={pos} neg={neg} -> pos_weight={pos_weight_value:.3f}")

## Torch Dataset and DataLoader

In [None]:
# Number of samples per batch for the train, validation and test loaders.
batch_size = 1024
# When True the training loader draws samples with inverse-class-frequency
# weights (with replacement) instead of plain shuffling.
use_weighted_sampler = False

class NumpyDataset(Dataset):
    """In-memory dataset over a NumPy feature matrix and label vector.

    Both arrays are converted to float32 tensors once at construction time;
    the labels get a trailing axis of size 1 inserted at position 1.
    """

    def __init__(self, X, y):
        features_f32 = X.astype(np.float32)
        labels_f32 = y.astype(np.float32)
        self.X = torch.from_numpy(features_f32)
        self.y = torch.from_numpy(labels_f32)[:, None]

    def __len__(self):
        # One sample per row of the feature tensor.
        return self.X.size(0)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# Wrap every split as a torch dataset.
train_ds = NumpyDataset(X_train, y_train)
val_ds = NumpyDataset(X_val, y_val)
test_ds = NumpyDataset(X_test, y_test)

if not use_weighted_sampler:
    # Default: plain shuffling of the training samples each epoch.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
else:
    # inverse frequency sampling to upweight minority
    class_sample_counts = np.array([np.count_nonzero(y_train == label) for label in (0, 1)])
    # Clip to 1 so an absent class cannot cause a division by zero.
    weights = 1.0 / np.clip(class_sample_counts, 1, None)
    sample_weights = weights[y_train]
    sampler = torch.utils.data.WeightedRandomSampler(
        sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, drop_last=False)

# Evaluation loaders keep the dataset order.
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

# Sanity check: pull one training batch and report the class balance.
xb, yb = next(iter(train_loader))
print("First batch shapes:", xb.shape, yb.shape)
print("Class balance train (orig): pos=", y_train.sum(), "neg=", len(y_train)-y_train.sum())

## Define MLP Architecture

In [None]:
class MLP(nn.Module):
    """Fully connected binary classifier that outputs one raw logit per sample.

    Each entry of ``layers`` adds a Linear -> BatchNorm1d -> ReLU stage,
    followed by Dropout when ``dropout`` is positive. A final Linear layer
    maps the last hidden width to a single output.

    Args:
        in_features: number of input columns.
        layers: hidden layer widths, in order.
        dropout: dropout probability after each hidden stage (0 disables it).
    """

    def __init__(self, in_features, layers, dropout=0):
        super().__init__()
        widths = [in_features, *layers]
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU()]
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
        # Output head: a single logit, no activation.
        modules.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)

# Widths of the hidden layers, in order; the MLP adds a final 1-unit output.
hidden_layers = [512, 128, 64, 32, 16]
dropout = 0.3  # dropout probability applied after every hidden ReLU
lr = 2e-3        # Adam learning rate
weight_decay = 1e-4  # weight decay passed to the Adam optimizer
epochs = 1000  # upper bound on training epochs; early stopping may end sooner
patience = 50    # epochs without validation improvement before early stopping

# One input unit per column of the feature matrix X.
model = MLP(in_features=X.shape[1], layers=hidden_layers, dropout=dropout).to(device)
print(model)

## Training Loop with Early Stopping

In [None]:
# Multi-metric configuration
primary_metric = "f1"  # options: 'auc','f1','ap','balanced_accuracy','mcc'
optimize_threshold = True
threshold_opt_metric = "f1"  # which metric to maximize when choosing threshold
min_improvement = 1e-4  # required relative improvement for early stopping reset
use_focal = True
focal_gamma = 2

@torch.no_grad()
def evaluate(model, loader, device, decision_threshold=0.5):
    """Run the model over a loader and compute classification metrics.

    The model is switched to eval mode and left in it; callers that continue
    training must call ``model.train()`` again.

    Args:
        model: module returning one logit per sample with shape (batch, 1).
        loader: iterable of ``(inputs, labels)`` batches, labels shaped (batch, 1).
        device: device the inputs are moved to before the forward pass.
        decision_threshold: probabilities >= this value are predicted positive.

    Returns:
        dict with keys ``probs`` (sigmoid outputs), ``y`` (integer labels),
        ``preds`` (0/1 predictions), ``cm`` (2x2 confusion matrix with rows and
        columns ordered [0, 1]) and the scalar metrics ``auc``, ``ap``,
        ``bal_acc``, ``f1`` and ``mcc``. A metric that scikit-learn rejects
        with ValueError is reported as NaN.
    """

    def _score(metric, *args, **kwargs):
        # Report an undefined metric as NaN rather than aborting the evaluation.
        try:
            return metric(*args, **kwargs)
        except ValueError:
            return float("nan")

    model.eval()
    logits_list = []
    y_list = []
    for xb, yb in loader:
        # Only the inputs are needed on the model's device; the labels are
        # consumed on the CPU below.
        logits_list.append(model(xb.to(device)).cpu())
        y_list.append(yb.cpu())
    logits = torch.cat(logits_list).squeeze(1)
    y_true = torch.cat(y_list).squeeze(1)
    probs = torch.sigmoid(logits).numpy()
    y_np = y_true.numpy().astype(int)
    preds = (probs >= decision_threshold).astype(int)
    # Fix the label set so the matrix is always 2x2, even when a class is
    # missing from both the labels and the predictions.
    cm = confusion_matrix(y_np, preds, labels=[0, 1])
    return dict(
        probs=probs,
        y=y_np,
        preds=preds,
        cm=cm,
        auc=_score(roc_auc_score, y_np, probs),
        ap=_score(average_precision_score, y_np, probs),
        bal_acc=_score(balanced_accuracy_score, y_np, preds),
        f1=_score(f1_score, y_np, preds, zero_division=0),
        mcc=_score(matthews_corrcoef, y_np, preds),
    )

# Optional Focal Loss wrapper
class FocalBCEWithLogits(nn.Module):
    """Binary cross-entropy on logits, down-weighted for well-classified samples.

    The per-sample BCE (optionally class-weighted through ``pos_weight``) is
    multiplied by ``(1 - p_t) ** gamma``, where ``p_t`` is the probability the
    model assigns to the true class, and the result is averaged. The
    probabilities behind the modulating factor are computed without gradient
    tracking.
    """

    def __init__(self, pos_weight=None, gamma=2.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='none')
        self.gamma = gamma

    def forward(self, logits, targets):
        per_sample = self.bce(logits, targets)
        with torch.no_grad():
            p = logits.sigmoid()
            # Probability assigned to the true class of each sample.
            p_true = p * targets + (1 - p) * (1 - targets)
        modulator = (1 - p_true) ** self.gamma
        return torch.mean(modulator * per_sample)

# Loss: focal-modulated BCE when use_focal is set, plain class-weighted BCE otherwise.
if use_focal:
    criterion = FocalBCEWithLogits(pos_weight=pos_weight, gamma=focal_gamma)
else:
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)


def _f1_zero_safe(y_ref, y_hat):
    """F1 score, with undefined cases scored as 0."""
    return f1_score(y_ref, y_hat, zero_division=0)


def _mcc_or_floor(y_ref, y_hat):
    """MCC, or -1 when scikit-learn rejects the inputs."""
    try:
        return matthews_corrcoef(y_ref, y_hat)
    except ValueError:
        return -1


# Metrics selectable through threshold_opt_metric; unknown names fall back to F1.
_THRESHOLD_SCORERS = {
    "f1": _f1_zero_safe,
    "balanced_accuracy": balanced_accuracy_score,
    "mcc": _mcc_or_floor,
}


def _best_threshold(y_ref, probs_ref, metric_name):
    """Return the first quantile-based cut that maximises the chosen metric.

    Candidates are the distinct values among 201 evenly spaced quantiles of
    ``probs_ref``; 0.5 is returned if no candidate scores above -1.
    """
    scorer = _THRESHOLD_SCORERS.get(metric_name, _f1_zero_safe)
    chosen, top_score = 0.5, -1
    for t in np.unique(np.quantile(probs_ref, np.linspace(0, 1, 201))):
        score = scorer(y_ref, (probs_ref >= t).astype(int))
        if score > top_score:
            top_score, chosen = score, t
    return chosen


def _scores_at_threshold(y_ref, probs_ref, threshold):
    """Return (F1, balanced accuracy, MCC) of ``probs_ref`` cut at ``threshold``.

    A metric that scikit-learn rejects with ValueError is reported as NaN.
    """
    y_hat = (probs_ref >= threshold).astype(int)
    scores = []
    for scorer in (_f1_zero_safe, balanced_accuracy_score, matthews_corrcoef):
        try:
            scores.append(scorer(y_ref, y_hat))
        except ValueError:
            scores.append(float("nan"))
    return tuple(scores)


best_metric = -1.0
best_state = None
best_threshold = 0.5
no_improve = 0

for epoch in range(1, epochs + 1):
    # ---- one optimisation pass over the training set ----
    model.train()
    running = 0.0
    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by the batch size so the epoch loss is a
        # per-sample average even when the last batch is smaller.
        running += loss.item() * xb.size(0)
    epoch_loss = running / len(train_loader.dataset)

    # ---- single validation pass ----
    # The probabilities do not depend on the decision threshold, so every
    # thresholded metric below is derived from this one pass. The labels are
    # kept under their own name so the y_val array of the split is untouched.
    val_metrics = evaluate(model, val_loader, device, decision_threshold=0.5)
    probs_val = val_metrics["probs"]
    y_val_epoch = val_metrics["y"]

    # Threshold search on validation probabilities
    if optimize_threshold and probs_val.size > 0:
        current_threshold = _best_threshold(y_val_epoch, probs_val, threshold_opt_metric)
    else:
        current_threshold = 0.5

    # AUC and AP are threshold-free; the rest is evaluated at the chosen cut.
    auc_val = val_metrics["auc"]
    ap_val = val_metrics["ap"]
    f1_val, bal_acc_val, mcc_val = _scores_at_threshold(y_val_epoch, probs_val, current_threshold)

    metric_map = {
        "auc": auc_val,
        "f1": f1_val,
        "ap": ap_val,
        "balanced_accuracy": bal_acc_val,
        "mcc": mcc_val,
    }
    current_primary = metric_map.get(primary_metric, f1_val)

    print(
        f"Epoch {epoch:03d} loss={epoch_loss:.4f} AUC_val={auc_val:.4f} AP_val={ap_val:.4f} F1_val={f1_val:.4f} BalAcc_val={bal_acc_val:.4f} MCC_val={mcc_val:.4f} thr={current_threshold:.3f} primary({primary_metric})={current_primary:.4f}"  # noqa: E501
    )

    # ---- early stopping ----
    # An epoch counts as an improvement only if it beats the best score by
    # the configured margin; a NaN score never does.
    if current_primary > best_metric + min_improvement * max(abs(best_metric), 1.0):
        best_metric = current_primary
        # Snapshot on the CPU so the best weights do not occupy device memory.
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        best_threshold = current_threshold
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping triggered.")
            break

# Restore the best weights seen during training (if any epoch improved).
if best_state is not None:
    model.load_state_dict(best_state)
print(f"Best {primary_metric}={best_metric:.4f} at threshold={best_threshold:.4f}")

## Evaluate on Test Set

In [None]:
# Evaluate using best threshold from validation
print(f"Evaluating on test set with threshold={best_threshold:.4f}")
test_metrics = evaluate(model, test_loader, device, decision_threshold=best_threshold)

# Unpack the pieces reused by the following cells.
cm, probs, y_true, preds, auc_final = (
    test_metrics[key] for key in ("cm", "probs", "y", "preds", "auc")
)

print("Confusion matrix:\n", cm)
report = classification_report(y_true, preds, digits=3, zero_division=0)
print(report)

# Scores of the positive class only.
prf_kwargs = dict(average="binary", zero_division=0)
precision, recall, f1, _ = precision_recall_fscore_support(y_true, preds, **prf_kwargs)
print(f"Precision={precision:.3f} Recall={recall:.3f} F1={f1:.3f} AUC={auc_final:.3f}")

## Extended Metrics

In [None]:
# Extended test-set metrics, computed from the preds / probs / y_true of the
# test evaluation. preds were cut at the validation-tuned threshold, so the
# thresholded numbers below refer to that cut, not to 0.5.
is_pos = y_true == 1
is_neg = y_true == 0
pred_pos = preds == 1
true_pos = (pred_pos & is_pos).sum()

acc = (preds == y_true).mean()
# Precision is defined as 0 when nothing is predicted positive.
precision = true_pos / max(pred_pos.sum(), 1) if pred_pos.any() else 0.0
recall = true_pos / max(is_pos.sum(), 1)
specificity = ((preds == 0) & is_neg).sum() / max(is_neg.sum(), 1)
balanced_accuracy = 0.5 * (recall + specificity)
# A metric that scikit-learn rejects with ValueError is reported as NaN.
try:
    mcc = matthews_corrcoef(y_true, preds)
except ValueError:
    mcc = float("nan")
try:
    ap = average_precision_score(y_true, probs)
except ValueError:
    ap = float("nan")
try:
    brier = brier_score_loss(y_true, probs)
except ValueError:
    brier = float("nan")

# Precision-Recall threshold analysis
prec_curve, rec_curve, thr_pr = precision_recall_curve(y_true, probs)
f1_curve = 2 * prec_curve * rec_curve / np.clip(prec_curve + rec_curve, 1e-9, None)
max_f1_idx = np.nanargmax(f1_curve)
# prec_curve[i] / rec_curve[i] belong to thr_pr[i]; only the final point of
# the curve (precision 1, recall 0) has no threshold, hence the 0.5 fallback.
opt_pr_threshold = thr_pr[max_f1_idx] if max_f1_idx < len(thr_pr) else 0.5
best_f1 = f1_curve[max_f1_idx]

# Youden J optimal ROC threshold
fpr_curve, tpr_curve, thr_roc = roc_curve(y_true, probs)
youden = tpr_curve - fpr_curve
j_idx = np.argmax(youden)
youden_thr = thr_roc[j_idx]

print("--- Extended Metrics ---")
print(f"Accuracy            : {acc:.4f}")
print(f"Precision           : {precision:.4f}")
print(f"Recall (TPR)        : {recall:.4f}")
print(f"Specificity (TNR)   : {specificity:.4f}")
print(f"Balanced Accuracy   : {balanced_accuracy:.4f}")
print(f"MCC                 : {mcc:.4f}")
print(f"Average Precision   : {ap:.4f}")
print(f"Brier Score         : {brier:.4f}")
print(f"Best F1             : {best_f1:.4f} at PR threshold ~ {opt_pr_threshold:.4f}")
print(f"Youden J threshold  : {youden_thr:.4f}")

# Plot Precision-Recall curve (left) and ROC curve (right) side by side.
fig, (ax_pr, ax_roc) = plt.subplots(1, 2, figsize=(10, 5))

ax_pr.plot(rec_curve, prec_curve, label=f"AP={ap:.3f}")
# Mark the operating point with the highest F1 on the PR curve.
ax_pr.scatter(
    rec_curve[max_f1_idx],
    prec_curve[max_f1_idx],
    marker="o",
    color="red",
    label="Best F1",
)
ax_pr.set_xlabel("Recall")
ax_pr.set_ylabel("Precision")
ax_pr.set_title("Precision-Recall Curve")
ax_pr.legend()

ax_roc.plot(fpr_curve, tpr_curve, label=f"AUC={auc_final:.3f}")
# Diagonal reference line of a random classifier.
ax_roc.plot([0, 1], [0, 1], "k--")
ax_roc.set_xlabel("False Positive Rate")
ax_roc.set_ylabel("True Positive Rate")
ax_roc.set_title("ROC Curve - DNN")
ax_roc.legend()

fig.tight_layout()
plt.show()

## Probability Distributions Plot

In [None]:
# Overlay the predicted-probability distributions of the two true classes.
fig, ax = plt.subplots(figsize=(6, 5))
for class_value, class_label in ((1, "matched"), (0, "fake")):
    ax.hist(probs[y_true == class_value], bins=40, histtype="step", label=class_label)
ax.set_xlabel("Predicted Probability (matched)")
ax.set_ylabel("Tracks")
ax.legend()
fig.tight_layout()
plt.show()

## Confusion Matrix Heatmap

In [None]:
# Draw the test-set confusion matrix as an annotated heatmap.
fig, ax = plt.subplots(figsize=(4, 3))
sns.heatmap(cm, annot=True, fmt="d", cmap="YlOrRd", cbar=False, ax=ax)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
fig.tight_layout()
plt.show()

## Save Model Artifact

In [None]:
# Bundle everything needed to rebuild and apply the classifier in one file.
output_artifact = "dnn_artifact.pt"

# Store the weights on the CPU so the file does not depend on the training device.
cpu_weights = {}
for param_name, tensor in model.state_dict().items():
    cpu_weights[param_name] = tensor.cpu()

# Constructor arguments needed to re-create the MLP.
model_config = {
    "in_features": X.shape[1],
    "layers": hidden_layers,
    "dropout": dropout,
}

# NOTE(review): the fitted scaler is stored as a Python object, so loading this
# file presumably needs scikit-learn importable and non-weights-only loading;
# confirm with the consumers of the artifact.
artifact = {
    "state_dict": cpu_weights,
    "model_config": model_config,
    "scaler": scaler,
    "feature_names": features,
    "best_threshold": best_threshold,
    "auc_test": float(auc_final),
}
torch.save(artifact, output_artifact)

size_bytes = os.path.getsize(output_artifact)
size_mb = size_bytes / 1024**2
print(f"Saved artifact to {output_artifact} ({size_mb:.2f} MB)")