In [None]:
import uproot
import awkward as ak
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, HTML

# Set the output box size for images
display(
    HTML(
        "<style>.output_png, .output_jpeg, .output_svg {height: 500px; overflow-y: scroll;}</style>"
    )
)

# Tree read from every input file.
main_branch = "Events"

# Branches for muon pixel tracks: reco quantities plus the tracking-particle
# (tp*) truth association used later for efficiency / fake-rate studies.
tk_branches = [
    "muon_pixel_tracks_p",
    "muon_pixel_tracks_pt",
    "muon_pixel_tracks_eta",
    "muon_pixel_tracks_etaErr",
    "muon_pixel_tracks_phi",
    "muon_pixel_tracks_phiErr",
    "muon_pixel_tracks_chi2",
    "muon_pixel_tracks_normalizedChi2",
    "muon_pixel_tracks_nPixelHits",
    "muon_pixel_tracks_nTrkLays",
    "muon_pixel_tracks_nFoundHits",
    "muon_pixel_tracks_nLostHits",
    "muon_pixel_tracks_dsz",
    "muon_pixel_tracks_dszErr",
    "muon_pixel_tracks_dxy",
    "muon_pixel_tracks_dxyErr",
    "muon_pixel_tracks_dz",
    "muon_pixel_tracks_dzErr",
    "muon_pixel_tracks_qoverp",
    "muon_pixel_tracks_qoverpErr",
    "muon_pixel_tracks_lambdaErr",
    "muon_pixel_tracks_matched",
    "muon_pixel_tracks_duplicate",
    "muon_pixel_tracks_tpPdgId",
    "muon_pixel_tracks_tpPt",
    "muon_pixel_tracks_tpEta",
    "muon_pixel_tracks_tpPhi",
]
# Generator-level particles (used as the efficiency denominator).
gen_branches = [
    "GenPart_pt",
    "GenPart_eta",
    "GenPart_phi",
    "GenPart_mass",
    "GenPart_pdgId",
    "GenPart_statusFlags",  # added to select last-copy muons
]

# True -> read the "_legacy" ntuples instead of the "_ext" ones.
legacy = False

files = [
    "data/10k_ZMM200PU_ext.root",
    "data/9k_TTbar200PU_ext.root",
]

if legacy:
    files = [f.replace("_ext", "_legacy") for f in files]
print(files)

if not files:
    raise RuntimeError("No input files configured.")

# Load the requested branches from each file, then concatenate all events once
# (instead of growing an array inside the loop from an untyped [] seed).
per_file_arrays = []
for f in files:
    with uproot.open(f) as file:
        arrays_f = file[main_branch].arrays(tk_branches + gen_branches)
        per_file_arrays.append(arrays_f)
arrays = ak.concatenate(per_file_arrays, axis=0)
print(f"Loaded {len(arrays)} events")

### Utilities

In [None]:
# Helpers
def wrap_phi(phi):
    """Wrap azimuthal angle(s) into [-pi, pi).

    Element-wise for anything supporting scalar arithmetic and ``%``
    (the notebook passes both numpy and awkward arrays).
    """
    full_turn = 2 * np.pi
    shifted = phi + np.pi
    return (shifted % full_turn) - np.pi


def binned(mask, vals, bins):
    """Histogram ``vals`` restricted to ``mask`` and unrestricted, on the same bins.

    Returns
    -------
    (num, den) : tuple of numpy arrays
        Per-bin counts of ``vals[mask]`` (numerator) and of all ``vals`` (denominator).
    """
    selected = vals[mask]
    numerator = np.histogram(selected, bins=bins)[0]
    denominator = np.histogram(vals, bins=bins)[0]
    return numerator, denominator


def wilson(num, den, z=1.0):
    """Per-bin Wilson score interval for the binomial ratio num/den.

    Parameters
    ----------
    num, den : numpy arrays
        Pass counts and total counts per bin (cast to float).
    z : float
        Interval width in standard deviations (1.0 -> ~68% coverage).

    Returns
    -------
    center, half : numpy float arrays
        Interval midpoint and half-width. Bins with ``den <= 0`` stay at 0, so
        empty bins never divide by zero or emit warnings.
    """
    k = num.astype(float)
    n = den.astype(float)
    filled = n > 0

    ratio = np.zeros_like(k, dtype=float)
    ratio[filled] = k[filled] / n[filled]
    center = np.zeros_like(ratio)
    half = np.zeros_like(ratio)

    # Nothing to fill: every bin is empty.
    if not np.any(filled):
        return center, half

    n_f = n[filled]
    p_f = ratio[filled]
    z2 = z**2
    shrink = 1 + z2 / n_f
    center[filled] = (p_f + z2 / (2 * n_f)) / shrink
    half[filled] = z * np.sqrt(p_f * (1 - p_f) / n_f + z2 / (4 * n_f**2)) / shrink
    return center, half


# --- Gen / tracking-particle selection (same as TP selector for muons) ---
MUON_ABS_PDGID = 13  # |pdgId| of a muon
PT_MIN = 0.9  # GeV, minimum pT
ETA_MAX = 2.4  # maximum |eta|
DR_MATCH = 1e-2  # max Delta R for gen <-> TP matching
RELPT_MATCH = 0.1  # max |pT_gen - pT_tp| / pT_gen for matching

# --- MTV-like histogram binning: (min, max, number of bins) per axis ---
# pT in GeV; logarithmic binning when USE_LOG_PT is True.
MTV_MIN_PT, MTV_MAX_PT, MTV_N_PT = 0.9, 2000.0, 50
USE_LOG_PT = True
MTV_MIN_ETA, MTV_MAX_ETA, MTV_N_ETA = -2.5, 2.5, 50
# phi range rounded slightly outside [-pi, pi].
MTV_MIN_PHI, MTV_MAX_PHI, MTV_N_PHI = -3.1416, 3.1416, 36

In [None]:
# Track quality parameters plotter
def plot_track_quality_parameters(arrays=arrays, selection=None):
    """Plot the distribution of each muon pixel-track quality variable.

    Every panel overlays three step histograms of the selected tracks:
    all of them ("All"), those with ``muon_pixel_tracks_matched == 1``
    ("Matched") and those with ``muon_pixel_tracks_matched == 0`` ("Fake").

    Parameters
    ----------
    arrays : awkward record array
        Events holding the ``muon_pixel_tracks_*`` branches. Defaults to the
        notebook-level ``arrays`` (bound when this cell is executed).
    selection : jagged boolean mask, optional
        Per-track mask with the same structure as ``arrays.muon_pixel_tracks_pt``.
        When omitted, every track with pT >= 0 is kept; the mask is built from
        the ``arrays`` actually passed in so its structure always matches.

    Returns
    -------
    matplotlib.figure.Figure
    """
    if selection is None:
        selection = arrays.muon_pixel_tracks_pt >= 0

    # (branch suffix, n bins, histogram range, axis label). Keeping data,
    # binning and label in one row prevents them from drifting out of order.
    specs = [
        ("pt", 200, (0, 200), r"$p_{T}$"),
        ("eta", 50, (-2.5, 2.5), r"$\eta$"),
        ("chi2", 15, (0, 15), r"$\chi^{2}$"),
        ("normalizedChi2", 15, (0, 15), r"normalized $\chi^{2}$"),
        ("nPixelHits", 15, (0, 15), "nPixelHits"),
        ("nTrkLays", 15, (0, 15), "nTrkLays"),
        ("nFoundHits", 15, (0, 15), "nFoundHits"),
        ("nLostHits", 10, (0, 10), "nLostHits"),
        ("dxy", 100, (-0.1, 0.1), r"$d_{xy}$"),
        ("dxyErr", 100, (0, 0.05), "dxyErr"),
        ("dz", 100, (-20, 20), r"$d_{z}$"),
        ("dzErr", 100, (0, 0.05), "dzErr"),
        ("etaErr", 100, (0, 0.006), "etaErr"),
        ("phiErr", 100, (0, 0.02), "phiErr"),
        ("qoverp", 100, (-0.5, 0.5), "qoverp"),
        ("qoverpErr", 100, (0, 0.07), "qoverpErr"),
        ("dsz", 100, (-15, 15), "dsz"),
        ("dszErr", 100, (0, 0.03), "dszErr"),
        ("lambdaErr", 100, (0, 0.003), "lambdaErr"),
    ]

    matched = arrays.muon_pixel_tracks_matched[selection]
    is_matched = matched == 1
    is_fake = matched == 0

    n_cols = 2
    n_rows = -(-len(specs) // n_cols)  # ceil division: 10 rows for 19 panels
    fig = plt.figure(figsize=(12, 28))
    for idx, (name, nbins, hist_range, label) in enumerate(specs, start=1):
        values = arrays[f"muon_pixel_tracks_{name}"][selection]
        ax = fig.add_subplot(n_rows, n_cols, idx)
        for subset, legend in (
            (values, "All"),
            (values[is_matched], "Matched"),
            (values[is_fake], "Fake"),
        ):
            ax.hist(
                ak.to_numpy(ak.flatten(subset)),
                bins=nbins,
                range=hist_range,
                histtype="step",
                label=legend,
            )
        if name == "pt":
            ax.set_xscale("log")
        ax.set_xlabel(f"Muon PixelTracks {label}")
        ax.set_ylabel("Entries")
        ax.legend()
        ax.set_title(f"Muon PixelTracks {label} distribution")
    fig.tight_layout()
    return fig

In [None]:
# Efficiency and fake-rate plotter
def plot_efficiency_and_fake(arrays=arrays, selection=None):
    """Plot MTV-like efficiency and fake rate of muon pixel tracks vs pT, eta and phi.

    Efficiency: fraction of selected gen muons (last copy, |pdgId| ==
    MUON_ABS_PDGID, pT > PT_MIN, |eta| < ETA_MAX) that match, within DR_MATCH
    and RELPT_MATCH, the tracking particle of at least one kept track.
    Fake rate: fraction of kept reco tracks (pT > PT_MIN, |eta| < ETA_MAX and
    passing ``selection``) with ``muon_pixel_tracks_matched == 0``.

    Plotted points are the observed ratios num/den. Error bars are the
    asymmetric distances to the Wilson score interval computed by ``wilson``.

    Parameters
    ----------
    arrays : awkward record array
        Events with the ``muon_pixel_tracks_*`` and ``GenPart_*`` branches.
        Defaults to the notebook-level ``arrays`` (bound when this cell runs).
    selection : jagged boolean mask, optional
        Per-track mask shaped like ``arrays.muon_pixel_tracks_pt``. When
        omitted, every track with pT >= 0 is kept; the mask is built from the
        ``arrays`` actually passed in.

    Returns
    -------
    matplotlib.figure.Figure
        3x2 grid: rows pT / eta / phi, columns efficiency / fake rate.
    """
    if selection is None:
        selection = arrays.muon_pixel_tracks_pt >= 0

    def make_log_edges(min_pt, max_pt, n_bins):
        """n_bins bins equally spaced in log10 between min_pt and max_pt (> 0)."""
        # Emulate BinLogX logic (no extra clamp needed; min_pt=0.9 > 0.01)
        log_min = np.log10(min_pt)
        log_max = np.log10(max_pt)
        return np.logspace(log_min, log_max, n_bins + 1, base=10.0)

    def make_linear_edges(a, b, n):
        """n equal-width bins between a and b."""
        return np.linspace(a, b, n + 1)

    def geometric_centers(edges):
        """Bin centres suited to a log axis: sqrt(low * high)."""
        return np.sqrt(edges[:-1] * edges[1:])

    def linear_centers(edges):
        """Arithmetic bin centres."""
        return 0.5 * (edges[:-1] + edges[1:])

    def ratio_with_errors(num, den):
        """Observed ratio num/den with asymmetric Wilson error bars.

        Returns (ratio, yerr, valid): yerr has shape (2, nbins) with the
        distance from the ratio down to / up to the Wilson bounds; valid flags
        bins with den > 0 (only those are drawn).
        """
        valid = den > 0
        ratio = np.zeros(len(den), dtype=float)
        ratio[valid] = num[valid] / den[valid]
        center, half = wilson(num, den)
        # The Wilson interval contains the observed ratio; clipping only guards
        # against floating-point round-off producing tiny negative lengths.
        err_low = np.clip(ratio - (center - half), 0.0, None)
        err_high = np.clip((center + half) - ratio, 0.0, None)
        return ratio, np.vstack([err_low, err_high]), valid

    def draw_panel(ax, centers, ratio, yerr, valid, xlabel, ylabel, title, color=None):
        """Draw one ratio-vs-variable panel using only non-empty bins."""
        style = {"markersize": 2, "fmt": "s", "capsize": 2}
        if color is not None:
            style["color"] = color
        ax.errorbar(centers[valid], ratio[valid], yerr=yerr[:, valid], **style)
        ax.set_ylim(0, 1.05)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid(alpha=0.3)
        ax.set_title(title)

    pt_bins = (
        make_log_edges(MTV_MIN_PT, MTV_MAX_PT, MTV_N_PT)
        if USE_LOG_PT
        else make_linear_edges(MTV_MIN_PT, MTV_MAX_PT, MTV_N_PT)
    )
    eta_bins = make_linear_edges(MTV_MIN_ETA, MTV_MAX_ETA, MTV_N_ETA)
    phi_bins = make_linear_edges(MTV_MIN_PHI, MTV_MAX_PHI, MTV_N_PHI)

    pt_centers = geometric_centers(pt_bins) if USE_LOG_PT else linear_centers(pt_bins)
    eta_centers = linear_centers(eta_bins)
    phi_centers = linear_centers(phi_bins)

    # Gen selection (denominator) -> last-copy muons within eta and pT ranges used by TP selector
    LAST_COPY_BIT = 13
    statusFlags = arrays.GenPart_statusFlags
    is_last_copy = (statusFlags & (1 << LAST_COPY_BIT)) != 0

    gen_base = (abs(arrays.GenPart_pdgId) == MUON_ABS_PDGID) & is_last_copy
    gen_sel = (
        gen_base & (arrays.GenPart_pt > PT_MIN) & (abs(arrays.GenPart_eta) < ETA_MAX)
    )

    gen_pt = arrays.GenPart_pt[gen_sel]
    gen_eta = arrays.GenPart_eta[gen_sel]
    gen_phi = wrap_phi(arrays.GenPart_phi[gen_sel])

    # Track-level truth info
    tp_pdg = arrays.muon_pixel_tracks_tpPdgId
    tp_pt = arrays.muon_pixel_tracks_tpPt
    tp_eta = arrays.muon_pixel_tracks_tpEta
    tp_phi = wrap_phi(arrays.muon_pixel_tracks_tpPhi)

    # Apply selection here so that only kept tracks can be matched with Gen
    tp_sel = (
        (abs(tp_pdg) == MUON_ABS_PDGID)
        & (tp_pt > PT_MIN)
        & (abs(tp_eta) < ETA_MAX)
        & selection
    )
    eff_tp_pt = tp_pt[tp_sel]
    eff_tp_eta = tp_eta[tp_sel]
    eff_tp_phi = tp_phi[tp_sel]

    # ------ Efficiency ------
    # Match TPs with Gen (geometric + relative pT): every gen muon is paired
    # with every kept TP of the same event.
    gen_zip = ak.zip({"pt": gen_pt, "eta": gen_eta, "phi": gen_phi})
    tp_zip = ak.zip({"pt": eff_tp_pt, "eta": eff_tp_eta, "phi": eff_tp_phi})

    pairs = ak.cartesian({"g": gen_zip, "t": tp_zip}, axis=1, nested=True)
    dphi = wrap_phi(pairs.g.phi - pairs.t.phi)
    deta = pairs.g.eta - pairs.t.eta
    dr = np.sqrt(deta**2 + dphi**2)
    relpt = np.abs(pairs.g.pt - pairs.t.pt) / pairs.g.pt

    match_matrix = (dr < DR_MATCH) & (relpt < RELPT_MATCH)
    gen_matched_mask = ak.any(match_matrix, axis=2)  # shape: (events, NgenSelected)

    # Flatten all values and mask
    gen_pt_all = ak.to_numpy(ak.flatten(gen_pt))
    gen_eta_all = ak.to_numpy(ak.flatten(gen_eta))
    gen_phi_all = ak.to_numpy(ak.flatten(gen_phi))
    gen_match_flat = ak.to_numpy(ak.flatten(gen_matched_mask))

    # Efficiency = (number of matched Gen) / (total selected Gen)
    num_pt, den_pt = binned(gen_match_flat, gen_pt_all, pt_bins)
    num_eta, den_eta = binned(gen_match_flat, gen_eta_all, eta_bins)
    num_phi, den_phi = binned(gen_match_flat, gen_phi_all, phi_bins)

    eff_pt, eff_pt_err, valid_eff_pt = ratio_with_errors(num_pt, den_pt)
    eff_eta, eff_eta_err, valid_eff_eta = ratio_with_errors(num_eta, den_eta)
    eff_phi, eff_phi_err, valid_eff_phi = ratio_with_errors(num_phi, den_phi)

    global_eff = gen_match_flat.sum() / max(len(gen_match_flat), 1)
    print(f"Global efficiency (last-copy muons): {global_eff * 100:.2f}%")

    # ------ Fake rate ------
    trk_pt_all = arrays.muon_pixel_tracks_pt
    trk_eta_all = arrays.muon_pixel_tracks_eta
    trk_phi_all = wrap_phi(arrays.muon_pixel_tracks_phi)

    # Select tracks within eta and pT ranges used by TP selector + selection cuts (denominator)
    trk_sel = (trk_pt_all > PT_MIN) & (abs(trk_eta_all) < ETA_MAX) & selection

    fake_track_mask = trk_sel & (arrays.muon_pixel_tracks_matched == 0)

    # Flatten for histograms
    keep_flat = ak.to_numpy(ak.flatten(trk_sel))
    fake_flat = ak.to_numpy(ak.flatten(fake_track_mask))[keep_flat]

    trk_pt_kept = ak.to_numpy(ak.flatten(trk_pt_all[trk_sel]))
    trk_eta_kept = ak.to_numpy(ak.flatten(trk_eta_all[trk_sel]))
    trk_phi_kept = ak.to_numpy(ak.flatten(trk_phi_all[trk_sel]))

    # Fake rate = (number of fake tracks) / (total selected tracks)
    fake_num_pt, fake_den_pt = binned(fake_flat, trk_pt_kept, pt_bins)
    fake_num_eta, fake_den_eta = binned(fake_flat, trk_eta_kept, eta_bins)
    fake_num_phi, fake_den_phi = binned(fake_flat, trk_phi_kept, phi_bins)

    fake_pt, fake_pt_err, valid_fake_pt = ratio_with_errors(fake_num_pt, fake_den_pt)
    fake_eta, fake_eta_err, valid_fake_eta = ratio_with_errors(
        fake_num_eta, fake_den_eta
    )
    fake_phi, fake_phi_err, valid_fake_phi = ratio_with_errors(
        fake_num_phi, fake_den_phi
    )

    global_fake = fake_flat.sum() / max(len(fake_flat), 1)
    print(f"Global fake rate: {global_fake * 100:.2f}%")

    # ------ Make plots ------
    # Column 0: efficiency (default colour); column 1: fake rate (red).
    fig, axes = plt.subplots(3, 2, figsize=(12, 12))

    draw_panel(
        axes[0, 0], pt_centers, eff_pt, eff_pt_err, valid_eff_pt,
        "Gen muon pT [GeV]", "Efficiency", "Efficiency vs pT",
    )
    draw_panel(
        axes[0, 1], pt_centers, fake_pt, fake_pt_err, valid_fake_pt,
        "Reco track pT [GeV]", "Fake rate", "Fake Rate vs pT", color="tab:red",
    )
    draw_panel(
        axes[1, 0], eta_centers, eff_eta, eff_eta_err, valid_eff_eta,
        "Gen muon η", "Efficiency", "Efficiency vs η",
    )
    draw_panel(
        axes[1, 1], eta_centers, fake_eta, fake_eta_err, valid_fake_eta,
        "Reco track η", "Fake rate", "Fake Rate vs η", color="tab:red",
    )
    draw_panel(
        axes[2, 0], phi_centers, eff_phi, eff_phi_err, valid_eff_phi,
        "Gen muon φ", "Efficiency", "Efficiency vs φ",
    )
    draw_panel(
        axes[2, 1], phi_centers, fake_phi, fake_phi_err, valid_fake_phi,
        "Reco track φ", "Fake rate", "Fake Rate vs φ", color="tab:red",
    )

    for ax in axes[0]:
        ax.set_xscale("log")
    for ax in axes[1]:
        ax.set_xlim(MTV_MIN_ETA, MTV_MAX_ETA)
    for ax in axes[2]:
        ax.set_xlim(MTV_MIN_PHI, MTV_MAX_PHI)

    # Set axes ticks
    for ax in axes.flatten():
        if not ax.get_xscale() == "log":
            ax.xaxis.set_major_locator(plt.MultipleLocator(0.5))
            ax.xaxis.set_minor_locator(plt.MultipleLocator(0.1))
        ax.yaxis.set_major_locator(plt.MultipleLocator(0.1))
        ax.yaxis.set_minor_locator(plt.MultipleLocator(0.02))

    fig.tight_layout()
    return fig

# Visual inspection

## No cuts

### Track quality

In [None]:
# Track-quality distributions for all tracks (no quality cuts applied).
tqNoSel = plot_track_quality_parameters()
if legacy:
    outname = "track_parameters_legacy_all.png"
else:
    outname = "track_parameters_ext_all.png"
tqNoSel.savefig(outname, dpi=300)

### Efficiency and fake rate

In [None]:
# Efficiency and fake rate without any track-quality cuts.
effFakeNoSel = plot_efficiency_and_fake()
if legacy:
    outname = "eff_and_fake_legacy_all.png"
else:
    outname = "eff_and_fake_ext_all.png"
effFakeNoSel.savefig(outname, dpi=300)

## Simple cuts

In [None]:
# Simple rectangular quality cuts on muon pixel tracks; thresholds depend on
# which ntuple flavour is loaded (legacy vs _ext).
# More aggressive alternatives noted previously: dzErr < 0.006, |qoverp| < 0.05.
if legacy:
    dzErrCut = arrays.muon_pixel_tracks_dzErr < 0.018
    phiErrCut = arrays.muon_pixel_tracks_phiErr < 0.0040
    qoverpCut = np.abs(arrays.muon_pixel_tracks_qoverp) < 1
    qoverpErrCut = arrays.muon_pixel_tracks_qoverpErr < 0.035
    lambdaErrCut = arrays.muon_pixel_tracks_lambdaErr < 1
    etaErrCut = arrays.muon_pixel_tracks_etaErr < 1
else:
    dzErrCut = arrays.muon_pixel_tracks_dzErr < 0.015
    phiErrCut = arrays.muon_pixel_tracks_phiErr < 0.0020
    qoverpCut = np.abs(arrays.muon_pixel_tracks_qoverp) < 0.075
    qoverpErrCut = arrays.muon_pixel_tracks_qoverpErr < 0.004
    lambdaErrCut = arrays.muon_pixel_tracks_lambdaErr < 0.0006
    etaErrCut = arrays.muon_pixel_tracks_etaErr < 0.0006

# Active selection. phiErrCut, qoverpErrCut and lambdaErrCut are computed above
# but not applied; AND them in here to enable them.
sel = (arrays.muon_pixel_tracks_pt >= 0) & dzErrCut & qoverpCut & etaErrCut

### Track quality

In [None]:
# Track-quality distributions after the simple cuts.
tqSel = plot_track_quality_parameters(selection=sel)
if legacy:
    outname = "track_parameters_legacy_cut_selection.png"
else:
    outname = "track_parameters_ext_cut_selection.png"
tqSel.savefig(outname, dpi=300)

### Efficiency and fake rate

In [None]:
# Efficiency and fake rate after the simple cuts.
effFakeSel = plot_efficiency_and_fake(selection=sel)
if legacy:
    outname = "efficiency_and_fake_legacy_cut_selection.png"
else:
    outname = "efficiency_and_fake_ext_cut_selection.png"
effFakeSel.savefig(outname, dpi=300)

# ML model performance

## BDT

In [None]:
# Load the trained BDT artifact: a mapping with the fitted model under
# "pipeline" (presumably an sklearn Pipeline including scaling -- TODO confirm)
# and the training-time feature ordering under "feature_names".
# NOTE: joblib.load unpickles the file, which can execute arbitrary code;
# only load artifacts you produced or otherwise trust.
import joblib

pipeline_artifact = joblib.load("bdt_pipeline.pkl")
try:
    bdt_model = pipeline_artifact["pipeline"]
    feature_order = pipeline_artifact["feature_names"]
except (KeyError, TypeError) as err:
    raise RuntimeError(
        "bdt_pipeline.pkl must contain a mapping with 'pipeline' and "
        "'feature_names' entries."
    ) from err
print("Loaded bdt_pipeline.pkl (pipeline + feature ordering).")

In [None]:
# BDT application: score every muon pixel track with bdt_model and derive a
# per-track keep mask. Requires bdt_pipeline.pkl for consistent scaling.
from sklearn.metrics import confusion_matrix

if not hasattr(bdt_model, "predict_proba"):
    raise RuntimeError(
        "Loaded model has no predict_proba; please load bdt_pipeline.pkl artifact."
    )

# Working point on P(class 1), and the minimum probability spread below which the
# model is treated as non-discriminating (then every track is kept).
min_proba = 0.8
MIN_PROBA_SPREAD = 1e-6

# Build flat (n_tracks, n_features) matrix in training order
feature_map = {
    "pt": arrays.muon_pixel_tracks_pt,
    "eta": arrays.muon_pixel_tracks_eta,
    "phi": arrays.muon_pixel_tracks_phi,
    "qoverp": arrays.muon_pixel_tracks_qoverp,
    "qoverpErr": arrays.muon_pixel_tracks_qoverpErr,
    "dzErr": arrays.muon_pixel_tracks_dzErr,
    "etaErr": arrays.muon_pixel_tracks_etaErr,
    "lambdaErr": arrays.muon_pixel_tracks_lambdaErr,
    "dxyErr": arrays.muon_pixel_tracks_dxyErr,
    "phiErr": arrays.muon_pixel_tracks_phiErr,
    "normalizedChi2": arrays.muon_pixel_tracks_normalizedChi2,
    "nPixelHits": arrays.muon_pixel_tracks_nPixelHits,
    "nTrkLays": arrays.muon_pixel_tracks_nTrkLays,
    "dszErr": arrays.muon_pixel_tracks_dszErr,
}
unknown_features = [name for name in feature_order if name not in feature_map]
if unknown_features:
    raise KeyError(f"No branch mapped for BDT feature(s): {unknown_features}")
cols = [ak.to_numpy(ak.flatten(feature_map[n])) for n in feature_order]
X_raw = np.vstack(cols).T
if X_raw.shape[0] == 0:
    raise RuntimeError("No muon pixel tracks available to score with the BDT.")

# Rows with NaN/inf features are not scored: they keep probability 0 and are
# therefore always rejected by the threshold below.
finite_mask = np.isfinite(X_raw).all(axis=1)
if not finite_mask.all():
    print(f"Non-finite rows dropped: {(~finite_mask).sum()}")
X_use = X_raw[finite_mask]

positive_col = list(bdt_model.classes_).index(1)
proba_all = np.zeros(len(X_raw), dtype=float)
if len(X_use):
    proba_valid = bdt_model.predict_proba(X_use)[:, positive_col]
    proba_all[finite_mask] = proba_valid
else:
    proba_valid = np.zeros(0, dtype=float)

matched_truth = ak.to_numpy(ak.flatten(arrays.muon_pixel_tracks_matched)).astype(int)
assert len(matched_truth) == len(proba_all)

# Basic probability diagnostics
spread = proba_all.max() - proba_all.min()
print(
    "Prob stats: min={:.2f} max={:.2f} mean={:.2f} var={:.2e} spread={:.2f}".format(
        proba_all.min(), proba_all.max(), proba_all.mean(), proba_all.var(), spread
    )
)
print(
    "Quantiles (0,10,25,50,75,90,100)%:",
    np.quantile(proba_all, [0, 0.1, 0.25, 0.5, 0.75, 0.9, 1]),
)

if spread < MIN_PROBA_SPREAD:
    keep_mask = np.ones_like(proba_all, dtype=bool)
    print("WARNING: probabilities constant -> cannot discriminate; keeping all.")
else:
    keep_mask = proba_all >= min_proba
    print(f"Using default threshold {min_proba}.")

# Final diagnostics (n_tot > 0 is guaranteed by the guard above)
n_tot = keep_mask.size
n_keep = keep_mask.sum()
n_matched = matched_truth.sum()
purity = matched_truth[keep_mask].mean() * 100 if n_keep else 0.0
eff_true = matched_truth[keep_mask].sum() / n_matched * 100 if n_matched else 0.0
print(
    f"Selection: kept {n_keep}/{n_tot} ({n_keep / n_tot * 100:.2f}%) Purity={purity:.2f}% Matched-eff={eff_true:.2f}%"
)

# Predicted class is exactly the keep decision.
y_pred_classes = keep_mask.astype(int)

# With labels=[0, 1] the confusion matrix is always 2x2.
cm = confusion_matrix(matched_truth, y_pred_classes, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
acc = (tn + tp) / n_tot
print(
    f"Confusion @{min_proba} (rows truth fake/matched, cols pred fake/matched):\n{cm}"
)
print(
    f"Acc={acc * 100:.2f}% TPR={tp / max(tp + fn, 1):.3f} FPR={fp / max(fp + tn, 1):.3f}"
)

# Build jagged mask with the original per-event track multiplicities
counts = ak.to_numpy(ak.num(arrays.muon_pixel_tracks_pt))
bdt_selection = ak.unflatten(keep_mask, counts)

In [None]:
# Diagnostic figures for the tracks kept by the BDT-derived mask.
tqBDT = plot_track_quality_parameters(selection=bdt_selection)
tqBDT.savefig("track_parameters_bdt_selection.png", dpi=300)

effFakeBDT = plot_efficiency_and_fake(selection=bdt_selection)
effFakeBDT.savefig("efficiency_and_fake_bdt_selection.png", dpi=300)