In [None]:
import json
import uproot
import awkward as ak
import vector
import numpy as np
import matplotlib.pyplot as plt
import warnings

# Global matplotlib defaults for every figure created in this file:
# 10x6 inch canvas, dashed semi-transparent grid, 14 pt font, 200 dpi.
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'axes.grid': True,
    'grid.alpha': 0.6,
    'grid.linestyle': '--',
    'font.size': 14,
    "figure.dpi": 200,
})

# Suppress a harmless warning from the vector library with awkward arrays.
# Only warnings whose message starts with this text are silenced.
warnings.filterwarnings("ignore", message="Passing an awkward array to a ufunc")

# Register the vector library's behaviors with awkward array so that records
# zipped with with_name="Momentum4D" expose 4-vector methods (e.g. deltaR).
ak.behavior.update(vector.backends.awkward.behavior)

# --- 1. CONFIGURATION ---
# All user-changeable settings live in this JSON file. Keys read in this file
# include "offline", "l1" and "matching_cone_size".
with open("hh-bbbb-obj-config.json", "r") as config_file:
    CONFIG = json.load(config_file)

# --- 2. DATA LOADING & PREPARATION FUNCTIONS ---
from data_loading_helpers import load_and_prepare_data, select_gen_b_quarks_from_higgs
def load_and_prepare_data(file_pattern, tree_name, collections_to_load, max_events, CONFIG=None):
    """
    Load a ROOT tree, regroup its flat branches into per-collection records,
    and attach 4-vector representations.

    Parameters
    ----------
    file_pattern : str
        File path / pattern handed to ``uproot.concatenate`` as
        ``"<file_pattern>:<tree_name>"``.
    tree_name : str
        Name of the tree to read.
    collections_to_load : iterable of str
        Branch prefixes. Every branch called ``<prefix>_<field>`` is zipped
        into ``events[prefix]`` with the prefix stripped from the field name.
    max_events : int or None
        Forwarded to ``uproot.concatenate`` as ``entry_stop``.
    CONFIG : dict, optional
        Analysis configuration. When ``None`` it is read from
        ``hh-bbbb-obj-config.json``. Keys used here:
        ``CONFIG["offline"]["collection_name"]``,
        ``CONFIG["offline"]["tagger_name"]`` and
        ``CONFIG["l1"]["collection_name"]``.

    Returns
    -------
    awkward.Array
        The events. Each loaded collection that has a ``pt`` field gains the
        fields ``vector`` (``Momentum4D``), ``et`` and ``e``, all built from
        the (possibly corrected) pt and mass.

    Notes
    -----
    If no input file is found an error is printed and ``exit()`` is called.
    """
    print(f"Loading data from {file_pattern}...")
    if CONFIG is None:
        with open("hh-bbbb-obj-config.json", "r") as config_file:
            CONFIG = json.load(config_file)

    try:
        events = uproot.concatenate(
            f"{file_pattern}:{tree_name}",
            library="ak",
            entry_stop=max_events,
        )
    except FileNotFoundError:
        print(f"Error: No files found matching '{file_pattern}'. Please update the path.")
        exit()

    print("Reshaping data into nested objects...")
    for prefix in collections_to_load:
        branch_prefix = prefix + "_"
        prefixed_fields = [field for field in events.fields if field.startswith(branch_prefix)]
        if not prefixed_fields:
            print(f"Warning: No fields found with prefix '{prefix}_'. Skipping.")
            continue
        # Strip only the leading "<prefix>_"; str.replace would also remove
        # any later occurrence of the same text inside the field name.
        field_map = {field[len(branch_prefix):]: events[field] for field in prefixed_fields}
        events[prefix] = ak.zip(field_map)

    print("Creating 4-vector objects...")
    for prefix in collections_to_load:
        if prefix in events.fields and "pt" in events[prefix].fields:
            # Snapshot of the collection as loaded; only fields that exist
            # before this iteration adds new ones are read from it.
            coll = events[prefix]

            # Default to using the raw pt
            pt_field = coll.pt

            # Handle mass:
            if "mass" in coll.fields:
                mass_field = coll.mass
            elif "et" in coll.fields:
                # Calculate L1 mass from et, pt, and eta; negative m^2 is
                # clamped to zero before taking the square root.
                m2 = (coll.et**2 - coll.pt**2) * (np.cosh(coll.eta)**2)
                m2_positive = ak.where(m2 < 0, 0, m2)
                mass_field = np.sqrt(m2_positive)
            else:
                mass_field = ak.zeros_like(pt_field)

            # Apply pT Corrections if this is the offline jet
            if prefix == CONFIG["offline"]["collection_name"]:
                tagger_name = CONFIG["offline"]["tagger_name"]
                print(f"Applying pT regression corrections to {prefix} {tagger_name}...")
                # The regression branches must match the configured tagger.
                if tagger_name.startswith("btagPNet"):
                    pt_corrected = (
                        coll.pt
                        * coll.PNetRegPtRawCorr
                        * coll.PNetRegPtRawCorrNeutrino
                    )
                elif tagger_name.startswith("btagUParTAK4"):
                    pt_corrected = (
                        coll.pt
                        * coll.UParTAK4RegPtRawCorr
                        * coll.UParTAK4RegPtRawCorrNeutrino
                    )
                else:
                    pt_corrected = coll.pt  # No correction if unknown tagger

                # Scale mass by the same correction factor
                correction_factor = ak.where(coll.pt > 0, pt_corrected / coll.pt, 1.0)
                mass_field = mass_field * correction_factor
                pt_field = pt_corrected

            elif prefix == CONFIG["l1"]["collection_name"] and "ptCorrection" in coll.fields:
                pt_corrected = coll.pt * coll.ptCorrection
                correction_factor = ak.where(coll.pt > 0, pt_corrected / coll.pt, 1.0)
                mass_field = mass_field * correction_factor
                pt_field = pt_corrected

            # getting the softmaxed scores for the next gen L1 jets
            if prefix == CONFIG["l1"]["collection_name"] and CONFIG["l1"]["collection_name"].endswith("NG"):
                # Re-attach every "*Score" field under its own name (values unchanged).
                l1_tag_scores = {field: coll[field] for field in coll.fields if field.endswith("Score")}
                for score_name, score_values in l1_tag_scores.items():
                    events[prefix, score_name] = score_values

                # Discriminants built as ratios of the per-flavour scores.
                b_v_udscg_score = coll["bTagScore"] / (
                    coll["bTagScore"] + coll["cTagScore"] + coll["udsTagScore"] + coll["gTagScore"]
                )
                c_v_b_score = coll["cTagScore"] / (coll["cTagScore"] + coll["bTagScore"])

                events[prefix, "b_v_udscg_score"] = b_v_udscg_score
                events[prefix, "c_v_b_score"] = c_v_b_score

            events[prefix, "vector"] = ak.zip(
                {"pt": pt_field, "eta": coll.eta, "phi": coll.phi, "mass": mass_field},
                with_name="Momentum4D",
            )
            # NOTE(review): this "et" is not the algebraic inverse of the
            # et -> mass conversion used above; confirm the intended definition.
            et_field = np.sqrt(pt_field**2 + mass_field**2) * np.cosh(coll.eta)
            events[prefix, "et"] = et_field

            # E = sqrt(|p|^2 + m^2) with |p| = pt * cosh(eta)
            e_field = np.sqrt((pt_field * np.cosh(coll.eta))**2 + mass_field**2)
            events[prefix, "e"] = e_field

    print(f"Loaded and restructured {len(events)} events.")
    return events

def select_gen_b_quarks_from_higgs(events):
    """
    Return the generator-level b-quarks whose mother is a Higgs boson.

    A particle is kept when ``abs(pdgId) == 5`` and its ``genPartIdxMother``
    equals the in-event position of any ``GenPart`` entry with ``pdgId == 25``.
    The result is the filtered ``events.GenPart`` (one sub-list per event).
    """
    print("Selecting gen-level b-quarks...")
    gen_parts = events.GenPart
    pdg_id = gen_parts.pdgId

    # In-event positions of the Higgs bosons.
    higgs_positions = ak.local_index(gen_parts)[pdg_id == 25]

    # Broadcast (event, particle, 1) against (event, 1, higgs) so each
    # particle's mother index is compared with every Higgs position.
    mother_matches_higgs = gen_parts.genPartIdxMother[:, :, None] == higgs_positions[:, None, :]
    mother_is_higgs = ak.any(mother_matches_higgs, axis=2)

    b_from_higgs = gen_parts[(abs(pdg_id) == 5) & mother_is_higgs]

    print(f"Found {ak.sum(ak.num(b_from_higgs))} b-quarks from Higgs decays.")
    return b_from_higgs


# --- 3. ANALYSIS FUNCTIONS ---
def get_efficiency_mask(gen_particles, reco_objects, CONFIG=None):
    """
    Flag each gen particle that has a reco object within the matching cone.

    The cone size is ``CONFIG["matching_cone_size"]``; the configuration is
    read from ``hh-bbbb-obj-config.json`` when ``CONFIG`` is ``None``.
    Gen particles with no reco object to compare against are reported False.
    """
    if CONFIG is None:
        with open("hh-bbbb-obj-config.json", "r") as config_file:
            CONFIG = json.load(config_file)

    # (event, gen, 1) against (event, 1, reco) gives every gen/reco deltaR.
    pairwise_dr = gen_particles.vector[:, :, None].deltaR(reco_objects.vector[:, None, :])

    # Smallest distance over the reco axis; None when that axis is empty.
    nearest_dr = ak.min(pairwise_dr, axis=2)
    within_cone = nearest_dr < CONFIG["matching_cone_size"]
    return ak.fill_none(within_cone, False)

def get_purity_mask(gen_particles, reco_objects, CONFIG=None):
    """
    Flag each reco object that has a gen particle within the matching cone.

    The cone size is ``CONFIG["matching_cone_size"]``; the configuration is
    read from ``hh-bbbb-obj-config.json`` when ``CONFIG`` is ``None``.
    Reco objects with no gen particle to compare against are reported False.
    """
    if CONFIG is None:
        with open("hh-bbbb-obj-config.json", "r") as config_file:
            CONFIG = json.load(config_file)

    # (event, reco, 1) against (event, 1, gen) gives every reco/gen deltaR.
    pairwise_dr = reco_objects.vector[:, :, None].deltaR(gen_particles.vector[:, None, :])

    # Smallest distance over the gen axis; None when that axis is empty.
    nearest_dr = ak.min(pairwise_dr, axis=2)
    within_cone = nearest_dr < CONFIG["matching_cone_size"]
    return ak.fill_none(within_cone, False)

def get_efficiency_mask_hungarian(gen_particles, reco_objects, CONFIG=None):
    """
    One-to-one gen/reco matching with the Hungarian algorithm.

    For every event the deltaR matrix is passed to ``linear_sum_assignment``;
    an assigned pair counts as a match only if its deltaR is below
    ``CONFIG["matching_cone_size"]``. Returns one boolean per gen particle.
    The configuration is read from ``hh-bbbb-obj-config.json`` when
    ``CONFIG`` is ``None``.
    """
    if CONFIG is None:
        with open("hh-bbbb-obj-config.json", "r") as config_file:
            CONFIG = json.load(config_file)

    per_event_masks = []
    for event_idx in range(len(gen_particles)):
        gen_p4 = gen_particles[event_idx].vector
        reco_p4 = reco_objects[event_idx].vector
        n_gen = len(gen_p4)

        # Start from "nothing matched"; stays that way if either side is empty.
        gen_matched = np.zeros(n_gen, dtype=bool)
        if n_gen != 0 and len(reco_p4) != 0:
            # Cost matrix: rows index reco objects, columns index gen particles.
            dr_cost = reco_p4[:, None].deltaR(gen_p4[None, :])
            reco_idx, gen_idx = linear_sum_assignment(dr_cost)

            # Keep only the assigned pairs that are inside the cone.
            in_cone = dr_cost[reco_idx, gen_idx] < CONFIG["matching_cone_size"]
            gen_matched[gen_idx[in_cone]] = True

        per_event_masks.append(gen_matched)

    return ak.Array(per_event_masks)

def get_purity_mask_hungarian(gen_particles, reco_objects, CONFIG=None):
    """
    One-to-one gen/reco matching with the Hungarian algorithm.

    For every event the deltaR matrix is passed to ``linear_sum_assignment``;
    an assigned pair counts as a match only if its deltaR is below
    ``CONFIG["matching_cone_size"]``. Returns one boolean per reco object.
    The configuration is read from ``hh-bbbb-obj-config.json`` when
    ``CONFIG`` is ``None``.
    """
    if CONFIG is None:
        with open("hh-bbbb-obj-config.json", "r") as config_file:
            CONFIG = json.load(config_file)

    per_event_masks = []
    for event_idx in range(len(gen_particles)):
        gen_p4 = gen_particles[event_idx].vector
        reco_p4 = reco_objects[event_idx].vector
        n_reco = len(reco_p4)

        # Start from "nothing matched"; stays that way if either side is empty.
        reco_matched = np.zeros(n_reco, dtype=bool)
        if len(gen_p4) != 0 and n_reco != 0:
            # Cost matrix: rows index reco objects, columns index gen particles.
            dr_cost = reco_p4[:, None].deltaR(gen_p4[None, :])
            reco_idx, gen_idx = linear_sum_assignment(dr_cost)

            # Keep only the assigned pairs that are inside the cone.
            in_cone = dr_cost[reco_idx, gen_idx] < CONFIG["matching_cone_size"]
            reco_matched[reco_idx[in_cone]] = True

        per_event_masks.append(reco_matched)

    return ak.Array(per_event_masks)

def calculate_pur_eff_vs_variable(gen_particles, reco_objects, mask, variable, bins, is_purity_plot=False):
    """
    Compute a binned matched fraction (purity or efficiency) and its error.

    With ``is_purity_plot`` true the fraction is taken over ``reco_objects``
    (purity: matched reco / all reco); otherwise over ``gen_particles``
    (efficiency: matched gen / all gen). ``mask`` selects the matched entries
    of that population and ``variable`` names the attribute to histogram in
    ``bins``.

    Returns ``(fraction, error)``, one value per bin. Empty bins get a
    fraction of 0. The error uses the Ullrich and Xu formula.
    """
    # The population being counted defines both numerator and denominator.
    population = reco_objects if is_purity_plot else gen_particles

    every_value = ak.to_numpy(ak.flatten(getattr(population, variable)))
    passing_value = ak.to_numpy(ak.flatten(getattr(population[mask], variable)))

    n_all, _ = np.histogram(every_value, bins=bins)
    n_pass, _ = np.histogram(passing_value, bins=bins)

    fraction = np.divide(n_pass, n_all, out=np.zeros_like(n_all, dtype=float), where=n_all != 0)

    # Ullrich and Xu
    variance_numerator = (n_pass + 1) * (n_all - n_pass + 1)
    variance_denominator = (n_all + 2)**2 * (n_all + 3)
    error = np.sqrt(variance_numerator / variance_denominator)

    return fraction, error

def calculate_roc_points(reco_jets, is_pure_mask, tagger_name):
    """
    Scan tagger thresholds and build the points of a ROC curve.

    Jets selected by ``is_pure_mask`` are the signal, the rest are mistags.
    400 thresholds are spread linearly over a range that covers at least
    [0, 1] and is widened to the observed score range when scores fall
    outside it.

    Returns ``(mistag_points, efficiency_points, auc_score, thresholds)``.
    """
    all_scores = getattr(reco_jets, tagger_name)
    lowest_cut = min(0, ak.min(all_scores))
    highest_cut = max(1, ak.max(all_scores))
    thresholds = np.linspace(lowest_cut, highest_cut, 400)

    signal_jets = reco_jets[is_pure_mask]
    mistag_jets = reco_jets[~is_pure_mask]
    signal_scores = getattr(signal_jets, tagger_name)
    mistag_scores = getattr(mistag_jets, tagger_name)

    n_total_signal = ak.sum(ak.num(signal_jets))
    n_total_mistag = ak.sum(ak.num(mistag_jets))

    eff_points = []
    mistag_points = []
    for cut in thresholds:
        # Fraction of signal jets above the cut (0 when there is no signal).
        n_signal_above = ak.sum(signal_scores > cut)
        eff_points.append(n_signal_above / n_total_signal if n_total_signal > 0 else 0)

        # Fraction of mistag jets above the cut (0 when there are none).
        n_mistag_above = ak.sum(mistag_scores > cut)
        mistag_points.append(n_mistag_above / n_total_mistag if n_total_mistag > 0 else 0)

    auc_score = auc(mistag_points, eff_points)
    return mistag_points, eff_points, auc_score, thresholds


# --- 4. PLOTTING FUNCTIONS ---
def plot_signal_background_histogram(reco_jets, is_pure_mask, bins, variable, xlabel, title):
    """
    Plot step histograms of signal and background jets versus one variable.

    Parameters
    ----------
    reco_jets : awkward.Array
        Jet collection; ``variable`` is read from it as an attribute.
    is_pure_mask : awkward.Array of bool
        True for signal jets; its negation selects the background.
    bins : int or sequence
        Binning forwarded unchanged to ``plt.hist``.
    variable : str
        Name of the jet attribute to histogram.
    xlabel, title : str
        Axis label and figure title.

    A new 10x6 figure is created and shown; nothing is returned.
    """
    plt.figure(figsize=(10, 6))

    signal_data = ak.to_numpy(ak.flatten(getattr(reco_jets[is_pure_mask], variable)))
    background_data = ak.to_numpy(ak.flatten(getattr(reco_jets[~is_pure_mask], variable)))

    # plt.hist does the binning itself, so no separate np.histogram pass
    # (or bin-centre computation) is needed.
    plt.hist(signal_data, bins=bins, histtype="step", label='Signal', color='blue')
    plt.hist(background_data, bins=bins, histtype="step", label='Background', color='red')

    plt.xlabel(xlabel)
    plt.ylabel("Counts")
    plt.title(title)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()
    plt.show()

def plot_kinematic_comparison(bins, variable, xlabel, title,
                              gen_particles,
                              objects,
                              is_purity_plot=False,
                              fmt="o-", new_fig=True,
                              legend_postfix=""):
    """
    Draw efficiency (default) or purity versus a kinematic variable.

    ``objects`` is a list of ``(label, object_collection, mask)`` tuples; one
    error-bar curve is drawn per entry, labelled ``label + legend_postfix``.
    With ``new_fig`` true a fresh figure is opened and shown at the end;
    otherwise the curves are added to the current figure and it is left open.
    """
    if new_fig:
        plt.figure(figsize=(10, 6))

    centers = 0.5 * (bins[1:] + bins[:-1])

    for label, collection, matched_mask in objects:
        fractions, fraction_errors = calculate_pur_eff_vs_variable(
            gen_particles, collection, matched_mask, variable, bins,
            is_purity_plot=is_purity_plot,
        )
        plt.errorbar(
            centers, fractions, yerr=fraction_errors, fmt=fmt,
            label=f"{label}{legend_postfix}",
        )

    y_axis_name = "Purity" if is_purity_plot else "Efficiency"
    plt.xlabel(xlabel)
    plt.ylabel(y_axis_name)
    plt.title(title)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.ylim(0, 1.05)
    plt.legend()

    if new_fig:
        plt.show()

def plot_roc_comparison(roc_results, working_point=None):
    """
    Plot multiple ROC curves on the same axes.

    Parameters
    ----------
    roc_results : list of tuple
        ``[(label, (mistag_points, eff_points, auc_score, thresholds)), ...]``.
        The fourth element of the inner tuple is unpacked but not drawn.
    working_point : float, optional
        Mistag rate at which a dashed vertical line is drawn. ``None``
        (the default) draws no line.

    The mistag axis is logarithmic, limited to [1e-4, 1]. The figure is shown
    and nothing is returned.
    """
    plt.figure(figsize=(8, 8))

    for label, (mistag_pts, eff_pts, auc_score, _) in roc_results:
        plt.plot(mistag_pts, eff_pts, 'o-',
                 label=f'{label} (AUC = {auc_score:.3f})',
                 markersize=2)

    # Identity check: a working point of 0 is still a valid value.
    if working_point is not None:
        plt.vlines(working_point, ymin=0, ymax=1, colors="black", linestyles="dashed", label=f"WP = {working_point*100}% Mistag rate")

    plt.ylabel("B-Tagging Efficiency")
    plt.xlabel("Mistag Rate")
    plt.xscale('log')
    plt.xlim(1e-4, 1.0)
    plt.ylim(1e-4, 1.05)
    plt.title("ROC Curve Comparison: Offline vs. L1 B-Tagging")
    plt.grid(True, linestyle='--', which='both', alpha=0.6)
    plt.legend(fontsize="small")
    plt.show()

def plot_btag_map(jets, tagger_name, pt_bins, eta_bins):
    """
    Draw a heatmap of the mean tagger score in bins of jet eta (x) and pT (y).

    The mean is the score-weighted 2D histogram divided by the jet-count
    histogram; bins without jets are set to 0. The colour scale is fixed to
    [0, 1]. Drawing happens on the current figure, which is then shown.
    """
    print(f"Plotting 2D b-tag map for {tagger_name}...")

    # One flat entry per jet.
    pt_flat = ak.to_numpy(ak.flatten(jets.vector.pt))
    eta_flat = ak.to_numpy(ak.flatten(jets.vector.eta))
    score_flat = ak.to_numpy(ak.flatten(getattr(jets, tagger_name)))

    binning = [eta_bins, pt_bins]

    # Per-bin sum of scores, then per-bin jet count.
    score_sum, eta_edges, pt_edges = np.histogram2d(
        eta_flat, pt_flat, bins=binning, weights=score_flat
    )
    jet_count, _, _ = np.histogram2d(eta_flat, pt_flat, bins=binning)

    # Mean score; the `where` guard leaves empty bins at the 0 from `out`.
    has_jets = jet_count != 0
    mean_score = np.divide(score_sum, jet_count, out=np.zeros_like(score_sum), where=has_jets)

    # histogram2d puts eta on axis 0, pcolormesh expects it on axis 1.
    mesh = plt.pcolormesh(
        eta_edges,
        pt_edges,
        mean_score.T,
        cmap="jet",
        norm=colors.Normalize(vmin=0.0, vmax=1.0),
    )

    plt.ylabel(r"Corrected Jet $p_T$ [GeV]")
    plt.xlabel(r"Jet $\eta$")
    plt.title(f"Average b-tag score ({tagger_name}) vs. $p_T$ and $\\eta$")

    colorbar = plt.colorbar(mesh)
    colorbar.set_label("Average b-tag score")

    plt.show()

def plot_cvb_map(jets, tagger_name, pt_bins, eta_bins):
    """
    Plots a 2D heatmap of the average CvB score vs. eta and jet pT.

    Parameters
    ----------
    jets : awkward.Array
        Jet collection with a ``vector`` field (pt, eta) and an attribute
        named ``tagger_name`` holding the CvB score.
    tagger_name : str
        Name of the score attribute to average.
    pt_bins, eta_bins : sequence
        Bin edges for the y (pT) and x (eta) axes.

    Bins without jets are shown with an average of 0. The colour scale is
    fixed to [0, 1]. Drawing happens on the current figure, which is then
    shown; nothing is returned.
    """
    print(f"Plotting 2D CvB map for {tagger_name}...")

    # Flatten the jet properties into simple numpy arrays
    jet_pt = ak.to_numpy(ak.flatten(jets.vector.pt))
    jet_eta = ak.to_numpy(ak.flatten(jets.vector.eta))
    jet_cvb_tag = ak.to_numpy(ak.flatten(getattr(jets, tagger_name)))

    # --- Create the 2D Profile ---

    # 1. Sum of CvB scores per bin (scores are passed as histogram weights)
    h_sum_cvb_tag, xedges, yedges = np.histogram2d(
        jet_eta, jet_pt, bins=[eta_bins, pt_bins], weights=jet_cvb_tag
    )

    # 2. Number of jets per bin
    h_count_jets, _, _ = np.histogram2d(
        jet_eta, jet_pt, bins=[eta_bins, pt_bins]
    )

    # 3. Average score per bin; empty bins keep the 0 provided through `out`
    h_avg_cvb_tag = np.divide(
        h_sum_cvb_tag,
        h_count_jets,
        out=np.zeros_like(h_sum_cvb_tag),
        where=(h_count_jets != 0)
    )

    # --- Plotting ---
    # Transpose (T) is needed because numpy histogram and pcolormesh have
    # different axis conventions.
    im = plt.pcolormesh(
        xedges,
        yedges,
        h_avg_cvb_tag.T,
        cmap="jet",
        norm=colors.Normalize(vmin=0.0, vmax=1.0)  # Keep color scale 0-1
    )

    plt.ylabel(r"Corrected Jet $p_T$ [GeV]")
    plt.xlabel(r"Jet $\eta$")
    # "\\eta" keeps the backslash without relying on an invalid "\e" escape.
    plt.title(f"Average CvB score ({tagger_name}) vs. $p_T$ and $\\eta$")

    # Colour bar for the z-axis (the averaged CvB score)
    cbar = plt.colorbar(im)
    cbar.set_label("Average CvB score")

    plt.show()

def plot_matching_criteria(gen_particles, reco_objects, CONFIG=None):
    """
    Heatmap of pT response (reco pT / gen pT) versus deltaR, using for each
    gen particle the reco object closest to it in deltaR.

    A dashed red line marks ``CONFIG["matching_cone_size"]``; the
    configuration is read from ``hh-bbbb-obj-config.json`` when ``CONFIG`` is
    ``None``. Entries whose response is NaN (gen pT not positive) are dropped.
    """
    if CONFIG is None:
        with open("hh-bbbb-obj-config.json", "r") as config_file:
            CONFIG = json.load(config_file)

    print("Plotting pT response vs. dR matching criteria...")

    gen_p4 = gen_particles.vector
    reco_p4 = reco_objects.vector

    # (event, gen, 1) against (event, 1, reco): all gen/reco distances.
    pairwise_dr = gen_p4[:, :, None].deltaR(reco_p4[:, None, :])

    # Per gen particle: which reco object is nearest, and how far away it is.
    nearest_reco = ak.argmin(pairwise_dr, axis=2)
    nearest_dr = ak.min(pairwise_dr, axis=2)

    # pT of the nearest reco object divided by the gen pT; NaN marks entries
    # whose gen pT is not positive.
    gen_pt = gen_p4.pt
    nearest_reco_pt = reco_p4.pt[nearest_reco]
    response = ak.where(gen_pt > 0, nearest_reco_pt / gen_pt, np.nan)

    # Flatten to 1D and drop the NaN entries.
    dr_flat = ak.to_numpy(ak.flatten(nearest_dr))
    response_flat = ak.to_numpy(ak.flatten(response))
    is_finite_response = ~np.isnan(response_flat)
    dr_flat = dr_flat[is_finite_response]
    response_flat = response_flat[is_finite_response]

    # --- Plotting ---
    plt.figure(figsize=(10, 8))

    # Same 0..2 edges on both axes; log colour scale to reveal sparse regions.
    axis_edges = np.linspace(0, 2, 50)
    plt.hist2d(
        dr_flat,
        response_flat,
        bins=[axis_edges, axis_edges],
        cmap='viridis',
        norm=colors.LogNorm(),
    )

    # Mark the configured matching cone.
    cone_size = CONFIG["matching_cone_size"]
    plt.axvline(x=cone_size, color='red', linestyle='--', label=f'Matching Cut (ΔR={cone_size})')

    plt.xlabel("ΔR (gen b-quark, closest reco jet)")
    plt.ylabel(r"p$_T$ Response (reco p$_T$ / gen p$_T$)")
    plt.title("p$_T$ Response vs. ΔR for b-quark to Jet Matching")
    plt.legend()

    colorbar = plt.colorbar()
    colorbar.set_label("Number of Gen b-quarks")

    plt.show()

def plot_attr_vs_var(events, obj_collection, attr, variable, bins_attr, bins_var, xlabel, ylabel, title, mask=None):
    """
    Draw a log-coloured 2D histogram of ``attr`` (y) against ``variable`` (x)
    for the objects in ``events[obj_collection]``.

    When ``mask`` is given it is applied to both quantities before they are
    flattened. A new 10x8 figure is created and shown.
    """
    print(f"Plotting {attr} vs. {variable} for {obj_collection}...")

    collection = events[obj_collection]
    y_data = getattr(collection, attr)
    x_data = getattr(collection, variable)

    if mask is not None:
        y_data, x_data = y_data[mask], x_data[mask]

    # One flat entry per selected object.
    y_data = ak.to_numpy(ak.flatten(y_data))
    x_data = ak.to_numpy(ak.flatten(x_data))

    plt.figure(figsize=(10, 8))
    plt.hist2d(x_data, y_data, bins=[bins_var, bins_attr], cmap='viridis', norm=colors.LogNorm())

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)

    colorbar = plt.colorbar()
    colorbar.set_label("Counts")

    plt.show()

def plot_attr_vs_var_proj(events, obj_collection, attr, variable, bins_attr, bins_var, xlabel, ylabel, title, mask=None):
    """
    Draw the 2D counts of ``attr`` (y-axis) against ``variable`` (x-axis) for
    ``events[obj_collection]``, together with two side panels showing the
    projection of the counts onto each axis.

    When ``mask`` is given it is applied to both quantities before they are
    flattened. A new 10x8 figure is created and shown.
    """
    collection = events[obj_collection]
    y_data = getattr(collection, attr)
    x_data = getattr(collection, variable)

    if mask is not None:
        y_data, x_data = y_data[mask], x_data[mask]

    y_data = ak.to_numpy(ak.flatten(y_data))
    x_data = ak.to_numpy(ak.flatten(x_data))

    # Axis ranges and the 11 evenly spaced ticks shared between panels.
    x_low, x_high = min(bins_var), max(bins_var)
    y_low, y_high = min(bins_attr), max(bins_attr)
    x_ticks = np.linspace(x_low, x_high, 11)
    y_ticks = np.linspace(y_low, y_high, 11)

    # --- Plotting ---
    plt.figure(figsize=(10, 8))
    layout = gridspec.GridSpec(2, 2, height_ratios=[3, 1], hspace=0.05, width_ratios=[3, 1], wspace=0.05)

    # Main panel: 2D counts on a log colour scale.
    main_ax = plt.subplot(layout[0, 0])
    main_ax.hist2d(x_data, y_data, bins=[bins_var, bins_attr], cmap='viridis', norm=colors.LogNorm())
    main_ax.set_ylabel(ylabel)
    main_ax.set_title(title)
    main_ax.set_xticklabels([])  # x labels are carried by the panel below
    main_ax.set_yticks(y_ticks)
    main_ax.set_xticks(x_ticks)
    main_ax.grid(True, linestyle='--', alpha=0.6)

    # Bottom panel: counts projected onto the x-axis.
    bottom_ax = plt.subplot(layout[1, 0])
    x_counts, _ = np.histogram(x_data, bins=bins_var)
    x_centres = 0.5 * (bins_var[1:] + bins_var[:-1])
    bottom_ax.step(x_centres, x_counts, label="Counts", color="black")
    bottom_ax.set_ylabel("Counts")
    bottom_ax.set_xlabel(xlabel)
    bottom_ax.set_xticks(x_ticks)
    bottom_ax.set_yticks(np.linspace(0, max(x_counts), 3))
    bottom_ax.grid(True, linestyle='--', alpha=0.6)
    bottom_ax.set_xlim(x_low, x_high)
    bottom_ax.legend(fontsize='small')

    # Side panel: counts projected onto the y-axis.
    side_ax = plt.subplot(layout[0, 1])
    y_counts, _ = np.histogram(y_data, bins=bins_attr)
    y_centres = 0.5 * (bins_attr[1:] + bins_attr[:-1])
    side_ax.step(y_counts, y_centres, label="Counts", color='black')
    side_ax.set_xlabel("Counts")
    side_ax.set_yticklabels([])  # y labels are carried by the main panel
    side_ax.set_xticks(np.linspace(0, max(y_counts), 2))
    side_ax.set_yticks(y_ticks)
    side_ax.legend(fontsize='small')
    side_ax.grid(True, linestyle='--', alpha=0.6)
    side_ax.set_ylim(y_low, y_high)

    # Colour bar for the main panel, placed next to the side panel.
    colorbar = plt.colorbar(main_ax.collections[0], ax=side_ax, pad=0.15)
    colorbar.set_label("Counts")

    plt.show()

def plot_avg_attr_vs_var(reco_jets, is_signal_mask, attr, variable, bins, xlabel, ylabel, title):
    """
    Plot the binned mean of ``attr`` versus ``variable`` separately for signal
    jets (``is_signal_mask`` true) and background jets (mask false).

    Bins are half-open ``[low, high)``. Error bars are the standard error of
    the mean (std / sqrt(N)); bins with no entries are drawn at 0 with zero
    error. The curves are drawn on the current figure, which is then shown.
    """
    def _flat(jets, field):
        # One flat numpy entry per jet for the requested attribute.
        return ak.to_numpy(ak.flatten(getattr(jets, field)))

    def _binned_mean_and_error(values, coords):
        # Mean of `values` and its standard error in each bin of `coords`.
        means, errors = [], []
        for low, high in zip(bins[:-1], bins[1:]):
            in_bin = (coords >= low) & (coords < high)
            n_in_bin = np.sum(in_bin)
            if n_in_bin > 0:
                selected = values[in_bin]
                means.append(np.mean(selected))
                errors.append(np.std(selected) / np.sqrt(n_in_bin))
            else:
                means.append(0)
                errors.append(0)
        return means, errors

    signal_attr = _flat(reco_jets[is_signal_mask], attr)
    signal_var = _flat(reco_jets[is_signal_mask], variable)
    background_attr = _flat(reco_jets[~is_signal_mask], attr)
    background_var = _flat(reco_jets[~is_signal_mask], variable)

    signal_mean, signal_err = _binned_mean_and_error(signal_attr, signal_var)
    background_mean, background_err = _binned_mean_and_error(background_attr, background_var)

    centers = 0.5 * (bins[:-1] + bins[1:])
    plt.errorbar(centers, signal_mean, yerr=signal_err, marker='o', label="Signal Avg")
    plt.errorbar(centers, background_mean, yerr=background_err, marker='o', label="Background Avg")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.show()
  
def plot_resolution_vs_var(gen_var, resolution_values, bins, y_label, x_label, title):
    """
    Bin the data in ``gen_var`` and plot the width (standard deviation) and
    mean of ``resolution_values`` in each bin.

    Parameters
    ----------
    gen_var : array
        Value used to assign each entry to a bin (via ``np.digitize``).
    resolution_values : array
        Quantity whose per-bin mean and standard deviation are plotted;
        indexed with a boolean mask built from ``gen_var``.
    bins : numpy.ndarray
        Bin edges. Entries outside the edges are ignored.
    y_label, x_label, title : str
        Text for the axis labels and the figure title.

    Bins with 10 or fewer entries are drawn as NaN (a gap in the curve).
    A new two-panel figure is created and shown; nothing is returned.
    """
    bin_centers = 0.5 * (bins[1:] + bins[:-1])

    means = []
    widths = []
    errors = []  # Error on the width calculation

    # Digitize: indices 1 to len(bins)-1 are the valid bins; 0 and len(bins)
    # mark entries below / above the binned range.
    bin_indices = np.digitize(gen_var, bins)

    for i in range(1, len(bins)):
        # Select data for this specific bin
        data_in_bin = resolution_values[bin_indices == i]

        if len(data_in_bin) > 10:  # Require minimum stats
            # Mean (bias) and standard deviation (resolution) of the bin
            mu = np.mean(data_in_bin)
            sigma = np.std(data_in_bin)

            # Error on std dev estimate approx: sigma / sqrt(2N)
            err = sigma / np.sqrt(2 * len(data_in_bin))

            means.append(mu)
            widths.append(sigma)
            errors.append(err)
        else:
            means.append(np.nan)
            widths.append(np.nan)
            errors.append(0)

    # --- Plotting ---
    plt.figure(figsize=(10, 8))

    # Top Panel: Resolution (Width)
    plt.subplot(2, 1, 1)
    # Raw string so that "\sigma" reaches matplotlib's mathtext unchanged
    # instead of relying on an invalid "\s" escape sequence.
    plt.errorbar(bin_centers, widths, yerr=errors, fmt='o-', capsize=5, label=r'Resolution ($\sigma$)')
    plt.ylabel(f"{y_label} Resolution")
    plt.title(title)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()

    # Bottom Panel: Scale/Bias (Mean)
    plt.subplot(2, 1, 2)
    plt.errorbar(bin_centers, means, fmt='s--', color='red', capsize=5, label='Scale (Mean)')
    plt.axhline(0, color='black', linestyle='-', linewidth=1)
    plt.xlabel(x_label)
    plt.ylabel(f"{y_label} Scale (Bias)")
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()

    plt.tight_layout()
    plt.show()


# --- 5. SELECTION CUT FUNCTIONS ---
def apply_custom_cuts(reco_jets, config, key):
    """
    Apply the kinematic and tagger cuts configured under ``config[key]``.

    Parameters
    ----------
    reco_jets : awkward.Array
        Jet collection (offline or L1). It must expose ``pt``, ``eta`` and
        the tagger score fields used below as attributes.
    config : dict
        Global CONFIG dict.
    key : str
        "offline", "l1" or "l1ext"; selects the sub-config to read.
        Every sub-config must provide ``pt_cut``, ``eta_cut``, ``b_tag_cut``
        and ``tagger_name``. Offline veto thresholds are only read for the
        tagger family that actually uses them.

    Returns
    -------
    The jets of ``reco_jets`` that pass all applicable cuts.
    """
    subcfg = config[key]

    pt_cut = subcfg["pt_cut"]
    eta_cut = subcfg["eta_cut"]
    b_tag_cut = subcfg["b_tag_cut"]
    tagger_name = subcfg["tagger_name"]

    print(f"\nApplying custom pT cut of {pt_cut} GeV for {key} jets...")
    final_mask = (reco_jets.pt > pt_cut) & (abs(reco_jets.eta) < eta_cut)

    print(f"Applying custom cuts for {tagger_name} ({key})...")

    if key == "offline":
        if tagger_name.startswith("btagPNet"):
            b_jet_mask = reco_jets.btagPNetB > b_tag_cut
            charm_veto_mask = reco_jets.btagPNetCvB < subcfg["charm_veto_cut"]
            final_mask = final_mask & charm_veto_mask & b_jet_mask

        elif tagger_name.startswith("btagUParTAK4"):
            b_jet_mask = reco_jets.btagUParTAK4probb > b_tag_cut
            charm_veto_mask = reco_jets.btagUParTAK4CvB < subcfg["charm_veto_cut"]
            electron_veto_mask = reco_jets.btagUParTAK4Ele < subcfg["electron_veto_cut"]
            muon_veto_mask = reco_jets.btagUParTAK4Mu < subcfg["muon_veto_cut"]
            final_mask = (
                final_mask
                & charm_veto_mask
                & electron_veto_mask
                & muon_veto_mask
                & b_jet_mask
            )

        else:
            # Keep the historical best-effort behaviour (kinematic cuts only),
            # but make it visible instead of silently skipping the tagger cut.
            print(
                f"Warning: unrecognised offline tagger '{tagger_name}'; "
                "only the pT/eta cuts are applied."
            )

    elif key in ("l1", "l1ext"):
        # Both L1 flavours cut generically on the configured score field.
        tag_mask = getattr(reco_jets, tagger_name) > b_tag_cut
        final_mask = final_mask & tag_mask

    else:
        print(
            f"Warning: no tagger selection defined for key '{key}'; "
            "only the pT/eta cuts are applied."
        )

    return reco_jets[final_mask]



In [None]:
# for comparing the L1NG to the L1ExtJet, offline collections

import copy

with open("hh-bbbb-obj-config.json", "r") as config_file:
    CONFIG = json.load(config_file)

pt_bins = np.linspace(0, 500, 201)
eta_bins = np.linspace(-3, 3, 201)
b_tag_bins = np.linspace(0, 1, 101)

events = load_and_prepare_data(
    CONFIG["file_pattern"],
    CONFIG["tree_name"],
    [
        "GenPart",
        CONFIG["offline"]["collection_name"],
        CONFIG["l1"]["collection_name"],
        # Taken from the config rather than hard-coded, so the collection that
        # is loaded is the same one indexed via CONFIG["l1ext"] further down.
        CONFIG["l1ext"]["collection_name"],
    ],
    CONFIG["max_events"]
)

# Same configuration, with the offline tagger switched to UParT.
upart_CONFIG = copy.deepcopy(CONFIG)
upart_CONFIG["offline"]["tagger_name"] = "btagUParTAK4B"

upart_events = load_and_prepare_data(
    upart_CONFIG["file_pattern"],
    upart_CONFIG["tree_name"],
    [
        upart_CONFIG["offline"]["collection_name"],
    ],
    upart_CONFIG["max_events"],
    CONFIG=upart_CONFIG
)

# Gen-level b-quarks from Higgs decays, restricted to the configured acceptance.
gen_b_quarks = select_gen_b_quarks_from_higgs(events)
gen_b_quarks = gen_b_quarks[
    (gen_b_quarks.pt > CONFIG["gen"]["pt_cut"])
    & (abs(gen_b_quarks.eta) < CONFIG["gen"]["eta_cut"])
]

# Reco collections after the per-collection cuts.
reco_jets_offline = apply_custom_cuts(events[CONFIG["offline"]["collection_name"]], CONFIG, "offline")
reco_jets_offline_upart = apply_custom_cuts(upart_events[upart_CONFIG["offline"]["collection_name"]], upart_CONFIG, "offline")
reco_jets_l1ng = apply_custom_cuts(events[CONFIG["l1"]["collection_name"]], CONFIG, "l1")
reco_jets_l1ext = apply_custom_cuts(events[CONFIG["l1ext"]["collection_name"]], CONFIG, "l1ext")

# Labels and reco collections compared throughout this cell (same order everywhere).
_collection_labels = ("Offline", "Offline UParT", "L1NG", "L1ExtJet")
_reco_collections = (reco_jets_offline, reco_jets_offline_upart, reco_jets_l1ng, reco_jets_l1ext)

# efficiency masks (used later to index gen_b_quarks)
(
    b_quarks_is_matched_offline,
    b_quarks_is_matched_offline_upart,
    b_quarks_is_matched_l1ng,
    b_quarks_is_matched_l1ext,
) = (get_efficiency_mask(gen_b_quarks, jets) for jets in _reco_collections)
eff_objects = list(zip(
    _collection_labels,
    _reco_collections,
    (
        b_quarks_is_matched_offline,
        b_quarks_is_matched_offline_upart,
        b_quarks_is_matched_l1ng,
        b_quarks_is_matched_l1ext,
    ),
))

# purity masks (used later to index the corresponding reco collection)
(
    is_reco_jet_pure_offline,
    is_reco_jet_pure_offline_upart,
    is_reco_jet_pure_l1ng,
    is_reco_jet_pure_l1ext,
) = (get_purity_mask(gen_b_quarks, jets) for jets in _reco_collections)
purity_objects = list(zip(
    _collection_labels,
    _reco_collections,
    (
        is_reco_jet_pure_offline,
        is_reco_jet_pure_offline_upart,
        is_reco_jet_pure_l1ng,
        is_reco_jet_pure_l1ext,
    ),
))


def _compare_vs_pt_and_eta(objects, particle_label, quantity, **extra_kwargs):
    """Draw the pT comparison and then the eta comparison for one set of objects."""
    plot_kinematic_comparison(
        bins=np.linspace(0, 500, 51), variable="pt",
        xlabel=f"{particle_label} b-quark $p_T$ [GeV]",
        title=f"Reco {quantity} vs. $p_T$",
        gen_particles=gen_b_quarks,
        objects=objects,
        **extra_kwargs
    )
    plot_kinematic_comparison(
        bins=np.linspace(-3, 3, 51), variable="eta",
        xlabel=f"{particle_label} b-quark $\\eta$",
        title=f"Reco {quantity} vs. $\\eta$",
        gen_particles=gen_b_quarks,
        objects=objects,
        **extra_kwargs
    )


print("Plotting reconstruction efficiencies for L1NG vs L1ExtJet...")
_compare_vs_pt_and_eta(eff_objects, "Generated", "Efficiency")
_compare_vs_pt_and_eta(purity_objects, "Reconstructed", "Purity", is_purity_plot=True)


# plot_btag_map(reco_jets_l1ext, "btagScore", pt_bins, eta_bins)
# plot_btag_map(reco_jets_offline, CONFIG["offline"]["tagger_name"], pt_bins, eta_bins)
# plot_btag_map(reco_jets_offline_upart, "btagUParTAK4B", pt_bins, eta_bins)
# plot_btag_map(reco_jets_l1ng, CONFIG["l1"]["tagger_name"], pt_bins, eta_bins)

# Overlay the b-tag score spectra of the four selected jet collections.
for _jets, _score_field, _name in (
    (reco_jets_l1ext, CONFIG["l1ext"]["tagger_name"], "L1ExtJet"),
    (reco_jets_l1ng, CONFIG["l1"]["tagger_name"], "L1NG"),
    (reco_jets_offline, CONFIG["offline"]["tagger_name"], "Offline"),
    (reco_jets_offline_upart, "btagUParTAK4B", "Offline UParT"),
):
    plt.hist(
        ak.flatten(getattr(_jets, _score_field)),
        bins=b_tag_bins,
        histtype="step",
        label=f"{_name} B-Tag Score",
    )
plt.legend()
plt.show()

# plot_signal_background_histogram(reco_jets_offline, is_reco_jet_pure_offline, pt_bins, "pt", r"Offline Jet $p_T$ [GeV]", "Offline Jet $p_T$ Distribution")
# plot_signal_background_histogram(reco_jets_offline_upart, is_reco_jet_pure_offline_upart, pt_bins, "pt", r"Offline UParT Jet $p_T$ [GeV]", "Offline UParT Jet $p_T$ Distribution")
# plot_signal_background_histogram(reco_jets_l1ng, is_reco_jet_pure_l1ng, pt_bins, "pt", r"L1NG Jet $p_T$ [GeV]", "L1NG Jet $p_T$ Distribution")
# plot_signal_background_histogram(reco_jets_l1ext, is_reco_jet_pure_l1ext, pt_bins, "pt", r"L1ExtJet Jet $p_T$ [GeV]", "L1ExtJet Jet $p_T$ Distribution")

# plot_signal_background_histogram(reco_jets_offline, is_reco_jet_pure_offline, np.linspace(-3, 3, 51), "eta", "Offline Jet $\\eta$", "Offline Jet $\\eta$ Distribution")
# plot_signal_background_histogram(reco_jets_offline_upart, is_reco_jet_pure_offline_upart, np.linspace(-3, 3, 51), "eta", "Offline UParT Jet $\\eta$", "Offline UParT Jet $\\eta$ Distribution")
# plot_signal_background_histogram(reco_jets_l1ng, is_reco_jet_pure_l1ng, np.linspace(-3, 3, 51), "eta", "L1NG Jet $\\eta$", "L1NG Jet $\\eta$ Distribution")
# plot_signal_background_histogram(reco_jets_l1ext, is_reco_jet_pure_l1ext, np.linspace(-3, 3, 51), "eta", "L1ExtJet Jet $\\eta$", "L1ExtJet Jet $\\eta$ Distribution")

# plot_signal_background_histogram(reco_jets_offline, is_reco_jet_pure_offline, np.linspace(0, 1, 51), CONFIG["offline"]["tagger_name"], "Offline Jet B-Tag Score", "Offline Jet B-Tag Score Distribution")
# plot_signal_background_histogram(reco_jets_offline_upart, is_reco_jet_pure_offline_upart, np.linspace(0, 1, 51), "btagUParTAK4B", "Offline UParT Jet B-Tag Score", "Offline UParT Jet B-Tag Score Distribution")
# plot_signal_background_histogram(reco_jets_l1ng, is_reco_jet_pure_l1ng, np.linspace(0, 1, 51), CONFIG["l1"]["tagger_name"], "L1NG Jet B-Tag Score", "L1NG Jet B-Tag Score Distribution")
# plot_signal_background_histogram(reco_jets_l1ext, is_reco_jet_pure_l1ext, np.linspace(0, 1, 51), CONFIG["l1ext"]["tagger_name"], "L1ExtJet Jet B-Tag Score", "L1ExtJet Jet B-Tag Score Distribution")

# ROC inputs: (selected jets, purity mask, name of the score field) per collection.
offline_roc = calculate_roc_points(
    reco_jets_offline,
    is_reco_jet_pure_offline,
    CONFIG["offline"]["tagger_name"]
)

offline_upart_roc = calculate_roc_points(
    reco_jets_offline_upart,
    is_reco_jet_pure_offline_upart,
    "btagUParTAK4B"
)

l1ng_roc = calculate_roc_points(
    reco_jets_l1ng,
    is_reco_jet_pure_l1ng,
    CONFIG["l1"]["tagger_name"]
)
l1ext_roc = calculate_roc_points(
    reco_jets_l1ext,
    is_reco_jet_pure_l1ext,
    CONFIG["l1ext"]["tagger_name"]
)

plot_roc_comparison([
    # Single quotes inside the f-string: reusing the outer double quotes is a
    # SyntaxError on Python versions before 3.12.
    (f"Offline {CONFIG['offline']['tagger_name']}", offline_roc),
    ("Offline UParT", offline_upart_roc),
    ("L1NG", l1ng_roc),
    ("L1ExtJet", l1ext_roc)
    ],
    working_point=0.001
)


# Each ROC result unpacks into (mistag, efficiency, AUC, thresholds).
mistag_offline, eff_offline, auc_offline, thresh_offline = offline_roc
mistag_offline_upart, eff_offline_upart, auc_offline_upart, thresh_offline_upart = offline_upart_roc
mistag_l1, eff_l1, auc_l1, thresh_l1 = l1ng_roc
mistag_l1ext, eff_l1ext, auc_l1ext, thresh_l1ext = l1ext_roc

def get_roc_point_at_efficiency(mistag, eff, thresh, target_eff):
    """
    Return the last ``(mistag, eff, thresh)`` triple, in input order, whose
    efficiency is >= ``target_eff``.

    Raises IndexError when no point reaches ``target_eff``.
    NOTE(review): which working point this picks depends on the ordering of
    the ROC arrays passed in - confirm against calculate_roc_points.
    """
    qualifying = []
    for point in zip(mistag, eff, thresh):
        if point[1] >= target_eff:
            qualifying.append(point)
    return qualifying[-1]

def get_roc_point_at_mistag(mistag, eff, thresh, target_mistag):
    """
    Return the last ``(mistag, eff, thresh)`` triple, in input order, whose
    mistag rate is >= ``target_mistag``.

    Raises IndexError when no point reaches ``target_mistag``.
    NOTE(review): which working point this picks depends on the ordering of
    the ROC arrays passed in - confirm against calculate_roc_points.
    """
    qualifying = []
    for point in zip(mistag, eff, thresh):
        if point[0] >= target_mistag:
            qualifying.append(point)
    return qualifying[-1]

def _report_working_points(heading, mistag, eff, thresh, auc):
    """
    Look up the ROC points at mistag targets 0.001 / 0.01 / 0.1, print the
    AUC and one summary line per working point, and return the three
    (fpr, tpr, threshold) triples in tight, medium, loose order.
    """
    wp_points = [
        get_roc_point_at_mistag(mistag, eff, thresh, target)
        for target in (0.001, 0.01, 0.1)
    ]
    print(f"\n{heading} \nAUC: {auc:.4f}")
    for wp_name, (fpr, tpr, cut) in zip(("Tight", "Medium", "Loose"), wp_points):
        print(f"{wp_name} WP: TPR: {tpr * 100:.4f}%, 1/FPR: {1/fpr:.4f}, Threshold: {cut:.4f}")
    return wp_points


(
    (fpr_offline_tight, tpr_offline_tight, thresh_offline_tight),
    (fpr_offline_medium, tpr_offline_medium, thresh_offline_medium),
    (fpr_offline_loose, tpr_offline_loose, thresh_offline_loose),
) = _report_working_points("Offline", mistag_offline, eff_offline, thresh_offline, auc_offline)

(
    (fpr_offline_upart_tight, tpr_offline_upart_tight, thresh_offline_upart_tight),
    (fpr_offline_upart_medium, tpr_offline_upart_medium, thresh_offline_upart_medium),
    (fpr_offline_upart_loose, tpr_offline_upart_loose, thresh_offline_upart_loose),
) = _report_working_points(
    "Offline UParT", mistag_offline_upart, eff_offline_upart, thresh_offline_upart, auc_offline_upart
)

(
    (fpr_l1_tight, tpr_l1_tight, thresh_l1_tight),
    (fpr_l1_medium, tpr_l1_medium, thresh_l1_medium),
    (fpr_l1_loose, tpr_l1_loose, thresh_l1_loose),
) = _report_working_points("L1", mistag_l1, eff_l1, thresh_l1, auc_l1)

(
    (fpr_l1ext_tight, tpr_l1ext_tight, thresh_l1ext_tight),
    (fpr_l1ext_medium, tpr_l1ext_medium, thresh_l1ext_medium),
    (fpr_l1ext_loose, tpr_l1ext_loose, thresh_l1ext_loose),
) = _report_working_points("L1ExtJet", mistag_l1ext, eff_l1ext, thresh_l1ext, auc_l1ext)

In [None]:
# Question about the hwPt plot -> solved.
# Answer: hwPt is the raw pT from the FPGAs; pt holds the continuous values obtained after a potentially non-linear mapping from the raw integer values to a continuous pT space. Can be ignored for now.
reco_jets_l1 = events[CONFIG["l1"]["collection_name"]]

# hwPt and pt of the L1 jets, each multiplied by ptCorrection, on common bins.
_corrected_pt_edges = np.linspace(0, 1000, 101)
for _pt_field in ("hwPt", "pt"):
    plt.hist(
        ak.flatten(getattr(reco_jets_l1, _pt_field) * reco_jets_l1.ptCorrection),
        bins=_corrected_pt_edges,
        histtype="step",
        label=_pt_field,
    )
plt.legend()
plt.show()


# Per-flavour tag scores of the L1 jets; disabled entries are kept for reference.
_score_edges = np.linspace(0, 1, 101)
for _score_field, _score_label in (
    # ("c_v_b_score", "CvB Score"),
    ("cTagScore", "C Score"),
    ("bTagScore", "B Score"),
    # ("b_v_udscg_score", "BvUDSCG Score"),
    ("udsTagScore", "UDS Score"),
    # ("eTagScore", "E Score"),
):
    plt.hist(
        ak.flatten(getattr(reco_jets_l1, _score_field)),
        bins=_score_edges,
        histtype="step",
        label=_score_label,
    )
plt.legend()
plt.show()


In [None]:
# to see and plot the horns
_pt_vs_eta_axis_labels = {"xlabel": "Jet $\\eta$", "ylabel": "Jet $p_T$ [GeV]"}


def _offline_pt_vs_eta(title, **extra_kwargs):
    """Project offline-jet pT against eta with the fine binning shared by the offline plots."""
    return plot_attr_vs_var_proj(
        events, CONFIG["offline"]["collection_name"], "pt", "eta",
        bins_attr=np.linspace(0, 500, 201), bins_var=np.linspace(-4, 4, 201),
        title=title,
        **_pt_vs_eta_axis_labels,
        **extra_kwargs
    )


plot_attr_vs_var_proj(
    events, CONFIG["gen"]["collection_name"], "pt", "eta",
    bins_attr=np.linspace(0, 500, 101), bins_var=np.linspace(-3, 3, 101),
    title="Gen Jet $p_T$ vs. $\\eta$",
    **_pt_vs_eta_axis_labels
)
_offline_pt_vs_eta("Offline Jet $p_T$ vs. $\\eta$")
# NOTE(review): the purity mask is computed elsewhere on a specific jet
# selection - confirm it lines up with the jets plot_attr_vs_var_proj reads
# from `events`.
_offline_pt_vs_eta("Offline Pure Jet $p_T$ vs. $\\eta$", mask=is_reco_jet_pure_offline)
_offline_pt_vs_eta("Offline Impure Jet $p_T$ vs. $\\eta$", mask=~is_reco_jet_pure_offline)
# plot_attr_vs_var_proj(events, CONFIG["l1"]["collection_name"], "pt", "eta", 
#     bins_attr=np.linspace(0, 500, 101), bins_var=np.linspace(-3, 3, 101),
#     xlabel="Jet $\\eta$", ylabel="Jet $p_T$ [GeV]",
#     title="L1NG Jet $p_T$ vs. $\\eta$"
# )
# plot_attr_vs_var_proj(events, CONFIG["l1ext"]["collection_name"], "pt", "eta",
#     bins_attr=np.linspace(0, 500, 101), bins_var=np.linspace(-3, 3, 101),
#     xlabel="Jet $\\eta$", ylabel="Jet $p_T$ [GeV]",
#     title="L1ext Jet $p_T$ vs. $\\eta$"
# )

In [None]:
# avg b_tag scores
# (display name, selected jets, purity mask, score field) for each collection
_btag_sources = [
    ("Offline", reco_jets_offline, is_reco_jet_pure_offline, CONFIG["offline"]["tagger_name"]),
    ("Offline UParT", reco_jets_offline_upart, is_reco_jet_pure_offline_upart, "btagUParTAK4B"),
    ("L1NG", reco_jets_l1ng, is_reco_jet_pure_l1ng, CONFIG["l1"]["tagger_name"]),
    ("L1ExtJet", reco_jets_l1ext, is_reco_jet_pure_l1ext, CONFIG["l1ext"]["tagger_name"]),
]

# Argument tuples for plot_avg_attr_vs_var, binned in jet pT.
plot_objs_pt = [
    (
        jets, purity_mask, tagger, "pt", np.linspace(0, 500, 51),
        "Jet $p_T$ [GeV]", "B-Tag Score",
        f"{name} Jet B-Tag Score vs. $p_T$",
    )
    for name, jets, purity_mask, tagger in _btag_sources
]

for obj in plot_objs_pt:
    plot_avg_attr_vs_var(*obj)


# Same plots, binned in jet eta.
plot_objs_eta = [
    (
        jets, purity_mask, tagger, "eta", np.linspace(-3, 3, 51),
        "Jet $\\eta$", "B-Tag Score",
        f"{name} Jet B-Tag Score vs. $\\eta$",
    )
    for name, jets, purity_mask, tagger in _btag_sources
]
for obj in plot_objs_eta:
    plot_avg_attr_vs_var(*obj)


# reco_offline_hun_mask = get_purity_mask_hungarian(gen_b_quarks, reco_jets_offline)
# plot_kinematic_comparison(
#     bins=np.linspace(0, 500, 51), variable="pt",
#     xlabel=r"Generated b-quark $p_T$ [GeV]", title="Reco Purity vs. $p_T$ (Hungarian Matching)",
#     gen_particles=gen_b_quarks,
#     objects=[
#         ("Offline Jet (Hungarian Matching)", reco_jets_offline, reco_offline_hun_mask),
#         ("Offline Jet", reco_jets_offline, is_reco_jet_pure_offline)
#     ],
#     is_purity_plot=True
# )

In [None]:
# all histograms together for offline jets

def _overlay_step_hist(jagged_values, edges, label):
    """Flatten a per-event array and add it to the current axes as a step histogram."""
    plt.hist(ak.flatten(jagged_values), bins=edges, histtype="step", label=label)


def _finish_overlay(x_label, x_limits=None, y_limits=None):
    """Add legend and axis labels, optionally restrict the axis ranges, then show the figure."""
    plt.legend()
    plt.xlabel(x_label)
    plt.ylabel("Counts")
    if x_limits is not None:
        plt.xlim(*x_limits)
    if y_limits is not None:
        plt.ylim(*y_limits)
    plt.show()


# eta: gen b-quarks vs. offline jets, with matched / unmatched gen split
_eta_edges_fine = np.linspace(-4, 4, 81)
_overlay_step_hist(gen_b_quarks.eta, _eta_edges_fine, "Gen b-quarks")
_overlay_step_hist(reco_jets_offline.eta, _eta_edges_fine, "Reco Offline Jets")
_overlay_step_hist(reco_jets_offline[is_reco_jet_pure_offline].eta, _eta_edges_fine, "Reco signal Offline Jets")
_overlay_step_hist(gen_b_quarks[b_quarks_is_matched_offline].eta, _eta_edges_fine, "Matched Gen Jets")
_overlay_step_hist(gen_b_quarks[~b_quarks_is_matched_offline].eta, _eta_edges_fine, "Unmatched Gen Jets")
# plt.hist(ak.flatten((reco_jets_offline[~is_reco_jet_pure_offline].eta)), bins=np.linspace(-4, 4, 51), histtype="step", label="Reco background Offline Jets")
_finish_overlay("Eta")

# pT: same five populations
_pt_edges = np.linspace(0, 500, 101)
_overlay_step_hist(gen_b_quarks.pt, _pt_edges, "Gen b-quarks")
_overlay_step_hist(reco_jets_offline.pt, _pt_edges, "Reco Offline Jets")
_overlay_step_hist(reco_jets_offline[is_reco_jet_pure_offline].pt, _pt_edges, "Reco signal Offline Jets")
_overlay_step_hist(gen_b_quarks[b_quarks_is_matched_offline].pt, _pt_edges, "Matched Gen Jets")
_overlay_step_hist(gen_b_quarks[~b_quarks_is_matched_offline].pt, _pt_edges, "Unmatched Gen Jets")
_finish_overlay("pt")

# eta: signal vs. background offline jets (coarser binning)
_eta_edges_coarse = np.linspace(-4, 4, 51)
_overlay_step_hist(gen_b_quarks.eta, _eta_edges_coarse, "Gen b-quarks")
_overlay_step_hist(reco_jets_offline.eta, _eta_edges_coarse, "Reco Offline Jets")
_overlay_step_hist(reco_jets_offline[is_reco_jet_pure_offline].eta, _eta_edges_coarse, "Reco signal Offline Jets")
_overlay_step_hist(reco_jets_offline[~is_reco_jet_pure_offline].eta, _eta_edges_coarse, "Reco background Offline Jets")
_finish_overlay("Eta")

# pT: signal vs. background offline jets, zoomed to 150-400 on x and 0-2000 on y
_overlay_step_hist(gen_b_quarks.pt, _pt_edges, "Gen b-quarks")
_overlay_step_hist(reco_jets_offline.pt, _pt_edges, "Reco Offline Jets")
_overlay_step_hist(reco_jets_offline[is_reco_jet_pure_offline].pt, _pt_edges, "Reco signal Offline Jets")
_overlay_step_hist(reco_jets_offline[~is_reco_jet_pure_offline].pt, _pt_edges, "Reco background Offline Jets")
_finish_overlay("Pt", x_limits=(150, 400), y_limits=(0, 2000))


# PNetRegPtRawRes field of the selected offline jets
_overlay_step_hist(reco_jets_offline.PNetRegPtRawRes, np.linspace(-1, 1, 101), "Offline PNetRegPtRawRes")
plt.show()

# Matching-criteria plots, one per reco collection.
plot_matching_criteria(gen_b_quarks, reco_jets_offline)
plot_matching_criteria(gen_b_quarks, reco_jets_offline_upart)
plot_matching_criteria(gen_b_quarks, reco_jets_l1ng)
plot_matching_criteria(gen_b_quarks, reco_jets_l1ext)

In [None]:
# Resolution and Scale Plots
# TODO: look at Prijith's work on jet energy scale and resolution
def calculate_jet_resolutions(gen_particles, reco_objects, CONFIG=None):
    """
    Calculates pT and energy resolution using vectorized
    cross-mutual nearest-neighbour matching (1-to-1) in deltaR.

    Parameters
    ----------
    gen_particles, reco_objects : awkward.Array
        Per-event collections exposing a ``vector`` field.
    CONFIG : dict, optional
        Must provide "matching_cone_size". When omitted it is read from
        "hh-bbbb-obj-config.json".

    Returns
    -------
    tuple of four flat numpy arrays, one entry per matched (gen, reco) pair:
        gen_eta : eta of the matched gen particles
        gen_pt  : pT of the matched gen particles
        pt_res  : (reco - gen) / gen pT
        e_res   : (reco - gen) / gen energy
    """
    if CONFIG is None:
        with open("hh-bbbb-obj-config.json", "r") as config_file:
            CONFIG = json.load(config_file)

    # 1. Prepare 4-vectors for broadcasting:
    #    gen gets a trailing axis, reco an inserted middle axis.
    gen_vec = gen_particles.vector[:, :, None]
    reco_vec = reco_objects.vector[:, None, :]

    # 2. DeltaR between every gen particle and every reco object of the event
    delta_r_matrix = gen_vec.deltaR(reco_vec)

    # 3. Nearest neighbours in both directions
    # For every gen particle: index of the closest reco object
    idx_closest_reco_to_gen = ak.argmin(delta_r_matrix, axis=2)
    # For every reco object: index of the closest gen particle
    idx_closest_gen_to_reco = ak.argmin(delta_r_matrix, axis=1)

    # 4. Cross-match condition: the reco object closest to a gen particle
    #    must itself point back at that same gen particle.
    back_check_idx = idx_closest_gen_to_reco[idx_closest_reco_to_gen]
    gen_indices = ak.local_index(gen_particles, axis=1)

    # DeltaR of each gen particle to its closest reco object
    min_dr_values = ak.min(delta_r_matrix, axis=2)

    is_mutual_match = (
        (back_check_idx == gen_indices) &
        (min_dr_values < CONFIG["matching_cone_size"])
    )
    # argmin/min over an empty list yields a missing value; treat such
    # entries as "not matched" instead of letting them leak into the outputs.
    is_mutual_match = ak.fill_none(is_mutual_match, False)

    # 5. Extract the matched objects
    matched_gen = gen_particles[is_mutual_match]
    # Mask the reco indices first, then pick those reco objects.
    matched_reco_indices = idx_closest_reco_to_gen[is_mutual_match]
    matched_reco = reco_objects[matched_reco_indices]

    # 6. Resolutions, defined as (reco - gen) / gen
    pt_res = (matched_reco.vector.pt - matched_gen.vector.pt) / matched_gen.vector.pt

    # Reco energy from (pt, eta, mass). Held in a local variable rather than
    # assigned as an "energy" field onto the object returned by
    # `matched_reco.vector`, so the value used below is the one computed here.
    reco_energy = np.sqrt(
        (matched_reco.vector.pt * np.cosh(matched_reco.vector.eta)) ** 2
        + (matched_reco.vector.mass) ** 2
    )
    e_res = (reco_energy - matched_gen.vector.energy) / matched_gen.vector.energy

    # Flatten for plotting; gen eta and gen pT serve as x-axis variables.
    return (
        ak.to_numpy(ak.flatten(matched_gen.vector.eta)),
        ak.to_numpy(ak.flatten(matched_gen.vector.pt)),
        ak.to_numpy(ak.flatten(pt_res)),
        ak.to_numpy(ak.flatten(e_res))
    )

def _binned_mean_and_width(gen_var, resolution_values, bins, min_entries=10):
    """Return per-bin lists (means, widths, width_errors) of ``resolution_values`` binned in ``gen_var``."""
    # np.digitize: indices 1 .. len(bins)-1 are the bins inside the edges
    bin_indices = np.digitize(gen_var, bins)

    means, widths, errors = [], [], []
    for i in range(1, len(bins)):
        data_in_bin = resolution_values[bin_indices == i]
        n_entries = len(data_in_bin)

        if n_entries > min_entries:  # Require minimum stats
            # Standard deviation is simple, but IQR/2 is more robust against tails
            sigma = np.std(data_in_bin)
            means.append(np.mean(data_in_bin))
            widths.append(sigma)
            # Error on the std-dev estimate, approx. sigma / sqrt(2N)
            errors.append(sigma / np.sqrt(2 * n_entries))
        else:
            # Sparse bins are drawn as gaps
            means.append(np.nan)
            widths.append(np.nan)
            errors.append(0)

    return means, widths, errors


def plot_resolution_vs_var(objects):
    """
    Overlay binned resolution (top panel) and scale (bottom panel) curves.

    Parameters
    ----------
    objects : iterable of tuples
        Each entry is ``(gen_var, resolution_values, bins, y_label, x_label, title)``:
        ``gen_var`` is the flat array used for binning, ``resolution_values`` the
        flat array summarised per bin, ``bins`` the bin edges (numpy array) and
        ``y_label`` the legend prefix of the curve. The x label and title of the
        last entry are the ones left on the figure.

    For every bin with more than 10 entries the mean (scale) and standard
    deviation (resolution) are plotted; other bins are left empty.
    """
    objects = list(objects)
    fig, (ax_width, ax_scale) = plt.subplots(2, 1, figsize=(10, 8))

    for gen_var, resolution_values, bins, y_label, x_label, title in objects:
        bin_centers = 0.5 * (bins[1:] + bins[:-1])
        means, widths, errors = _binned_mean_and_width(gen_var, resolution_values, bins)

        # Top panel: resolution (width). Raw f-string keeps the LaTeX backslash
        # without relying on an invalid "\s" escape.
        ax_width.errorbar(bin_centers, widths, yerr=errors, fmt='o-', capsize=5,
                          label=rf'{y_label} Resolution ($\sigma$)')
        ax_width.set_title(title)

        # Bottom panel: scale/bias (mean)
        ax_scale.errorbar(bin_centers, means, fmt='s--', capsize=5,
                          label=f'{y_label} Scale (Mean)')
        ax_scale.set_xlabel(x_label)

    ax_width.set_ylabel("Resolution")
    ax_width.grid(True, linestyle='--', alpha=0.6)

    # Zero-bias reference line, drawn once
    ax_scale.axhline(0, color='black', linestyle='-', linewidth=1)
    ax_scale.set_ylabel("Scale (Bias)")
    ax_scale.grid(True, linestyle='--', alpha=0.6)

    if objects:
        ax_width.legend()
        ax_scale.legend()

    fig.tight_layout()
    plt.show()

# Matched gen eta / gen pT / pT residual / energy residual per reco collection.
(flat_gen_eta, flat_gen_pt,
 flat_pt_res, flat_e_res) = calculate_jet_resolutions(gen_b_quarks, reco_jets_offline)
(flat_gen_eta_l1ng, flat_gen_pt_l1ng,
 flat_pt_res_l1ng, flat_e_res_l1ng) = calculate_jet_resolutions(gen_b_quarks, reco_jets_l1ng)
(flat_gen_eta_l1ext, flat_gen_pt_l1ext,
 flat_pt_res_l1ext, flat_e_res_l1ext) = calculate_jet_resolutions(gen_b_quarks, reco_jets_l1ext)

# pT residuals binned in generated pT
pt_res_objects = [
    (
        gen_pt, pt_residual,
        np.linspace(0, 500, 51),
        f"{name} Jet $p_T$",
        "Generated $p_T$ [GeV]",
        "Jet $p_T$ Resolution and Scale vs. Generated $p_T$",
    )
    for name, gen_pt, pt_residual in (
        ("Offline", flat_gen_pt, flat_pt_res),
        ("L1NG", flat_gen_pt_l1ng, flat_pt_res_l1ng),
        ("L1Ext", flat_gen_pt_l1ext, flat_pt_res_l1ext),
    )
]
plot_resolution_vs_var(pt_res_objects)

# Residuals binned in generated eta.
# NOTE(review): these entries pass the energy residuals (flat_e_res*) while the
# labels and title say $p_T$ - confirm which of the two is intended.
eta_res_objects = [
    (
        gen_eta, energy_residual,
        np.linspace(-3, 3, 51),
        f"{name} Jet $p_T$",
        "Generated $\\eta$",
        "Jet $p_T$ Resolution and Scale vs. Generated $\\eta$",
    )
    for name, gen_eta, energy_residual in (
        ("PNet", flat_gen_eta, flat_e_res),
        ("L1NG", flat_gen_eta_l1ng, flat_e_res_l1ng),
        ("L1Ext", flat_gen_eta_l1ext, flat_e_res_l1ext),
    )
]
plot_resolution_vs_var(eta_res_objects)

# Average regression-resolution field vs. jet pT and eta, PNet offline jets.
plot_avg_attr_vs_var(
    reco_jets_offline, is_reco_jet_pure_offline,
    "PNetRegPtRawRes", "pt", np.linspace(0, 500, 51),
    "Jet $p_T$ [GeV]", "PNetRegPtRawRes",
    "Offline Jet PNetRegPtRawRes vs. $p_T$"
)
plot_avg_attr_vs_var(
    reco_jets_offline, is_reco_jet_pure_offline,
    "PNetRegPtRawRes", "eta", np.linspace(-3, 3, 51),
    "Jet $\\eta$", "PNetRegPtRawRes",
    "Offline Jet PNetRegPtRawRes vs. $\\eta$"
)
# UParT offline jets: pair them with their own purity mask (the PNet mask was
# computed on a different jet selection).
plot_avg_attr_vs_var(
    reco_jets_offline_upart, is_reco_jet_pure_offline_upart,
    "UParTAK4RegPtRawRes", "pt", np.linspace(0, 500, 51),
    "Jet $p_T$ [GeV]", "UParTAK4RegPtRawRes",
    "Offline UParT Jet UParTAK4RegPtRawRes vs. $p_T$"
)
plot_avg_attr_vs_var(
    reco_jets_offline_upart, is_reco_jet_pure_offline_upart,
    "UParTAK4RegPtRawRes", "eta", np.linspace(-3, 3, 51),
    "Jet $\\eta$", "UParTAK4RegPtRawRes",
    "Offline UParT Jet UParTAK4RegPtRawRes vs. $\\eta$"
)

In [None]:
# Storing hungarian masks to disk to speed up future analysis

import copy

# Load gen particles plus the offline, L1NG and L1Ext jet collections.
events = load_and_prepare_data(
    CONFIG["file_pattern"],
    CONFIG["tree_name"],
    [
        "GenPart",
        CONFIG["offline"]["collection_name"],
        CONFIG["l1"]["collection_name"],
        "L1puppiExtJetSC4",
    ],
    CONFIG["max_events"]
)

# Second configuration that only differs in the offline tagger name.
upart_CONFIG = copy.deepcopy(CONFIG)
upart_CONFIG["offline"]["tagger_name"] = "btagUParTAK4B"

upart_events = load_and_prepare_data(
    upart_CONFIG["file_pattern"],
    upart_CONFIG["tree_name"],
    [
        upart_CONFIG["offline"]["collection_name"],
    ],
    upart_CONFIG["max_events"],
    CONFIG=upart_CONFIG
)

# Gen-level b-quarks from Higgs decays inside the configured pT/eta acceptance.
gen_b_quarks = select_gen_b_quarks_from_higgs(events)
gen_b_quarks = gen_b_quarks[
    (gen_b_quarks.pt > CONFIG["gen"]["pt_cut"])
    & (abs(gen_b_quarks.eta) < CONFIG["gen"]["eta_cut"])
]

reco_jets_offline = events[CONFIG["offline"]["collection_name"]]
reco_jets_l1ng = events[CONFIG["l1"]["collection_name"]]
reco_jets_l1ext = events["L1puppiExtJetSC4"]
reco_jets_offline_upart = upart_events[upart_CONFIG["offline"]["collection_name"]]


print("Generating efficiency masks")
b_quarks_is_matched_offline = get_efficiency_mask_hungarian(gen_b_quarks, reco_jets_offline)
b_quarks_is_matched_offline_upart = get_efficiency_mask_hungarian(gen_b_quarks, reco_jets_offline_upart)
b_quarks_is_matched_l1ng = get_efficiency_mask_hungarian(gen_b_quarks, reco_jets_l1ng)
b_quarks_is_matched_l1ext = get_efficiency_mask_hungarian(gen_b_quarks, reco_jets_l1ext)

print("Generating purity masks")
is_reco_jet_pure_offline = get_purity_mask_hungarian(gen_b_quarks, reco_jets_offline)
is_reco_jet_pure_offline_upart = get_purity_mask_hungarian(gen_b_quarks, reco_jets_offline_upart)
is_reco_jet_pure_l1 = get_purity_mask_hungarian(gen_b_quarks, reco_jets_l1ng)
is_reco_jet_pure_l1ext = get_purity_mask_hungarian(gen_b_quarks, reco_jets_l1ext)


# Keep each mask next to the key it is stored under so the two cannot drift apart.
masks_by_key = {
    "b_quarks_is_matched_offline": b_quarks_is_matched_offline,
    "b_quarks_is_matched_offline_upart": b_quarks_is_matched_offline_upart,
    "b_quarks_is_matched_l1ng": b_quarks_is_matched_l1ng,
    "b_quarks_is_matched_l1ext": b_quarks_is_matched_l1ext,
    "is_reco_jet_pure_offline": is_reco_jet_pure_offline,
    "is_reco_jet_pure_offline_upart": is_reco_jet_pure_offline_upart,
    "is_reco_jet_pure_l1": is_reco_jet_pure_l1,
    "is_reco_jet_pure_l1ext": is_reco_jet_pure_l1ext,
}
# Parallel lists kept for cells that index masks/keys by position.
key_list = list(masks_by_key)
mask_list = list(masks_by_key.values())

mask_file_name = "mask_data.root"
# Each mask is written as a TTree with one explicitly named branch, so it can be
# read back through that branch's .array().
mask_branch_name = "mask"

with uproot.recreate(mask_file_name) as file:
    for key, mask in masks_by_key.items():
        n_true = ak.sum(mask)
        n_total = len(ak.flatten(mask))
        # Guard the percentage against an empty mask.
        percent_true = n_true / n_total * 100 if n_total else 0.0
        print(f"Storing mask with {n_true} true entries out of {n_total} total entries ({percent_true:.2f}%)")
        file[key] = {mask_branch_name: mask}

# Reloading
print("Reloading masks from disk and verifying integrity")
with uproot.open(mask_file_name) as file:
    for key, mask in masks_by_key.items():
        # .array() on the branch converts it back to an Awkward Array
        loaded_mask = file[key][mask_branch_name].array()
        if not ak.all(loaded_mask == mask):
            raise RuntimeError(f"Loaded mask {key} does not match the original mask!")

In [None]:
# Scratch cell: inspect what uproot stored under the first mask key.
mask, key = mask_list[0], key_list[0]
with uproot.open("mask_data.root") as file:
    # Look the object up by its string key while the file is still open;
    # no .array() conversion is applied here.
    loaded_mask = file[key]

# Display the retrieved object as the cell output.
loaded_mask

In [None]:
# Per-jet `et` of the offline jets (despite the "mass" in the variable names).
reco_mass_offline = reco_jets_offline.et
# Pairwise sums within each event: entry [event][i][j] is et[i] + et[j].
mass_mat_offline = (
    reco_mass_offline[:, :, np.newaxis] + reco_mass_offline[:, np.newaxis, :]
)
# mass_mat_offline = mass_mat_offline - 120
mass_mat_offline

In [None]:
def gen_b_masks(events):
    """
    Builds purity masks and purity-filtered jets for the offline and L1 jet
    collections, once ordered by pT and once ordered by b-tag score.

    Gen b-quarks from Higgs decays are first restricted to the pT/eta cuts in
    CONFIG["gen"]. Collection and tagger names are read from CONFIG["offline"]
    and CONFIG["l1"].

    Returns a dict with:
      - "gen_b_from_higgs": the gen b-quarks passing the cuts
      - "ordered": (pT-ordered offline, pT-ordered L1, b-score-ordered offline,
        b-score-ordered L1) jets, each already filtered by its own mask
      - "masks": the get_purity_mask results in the same order, computed on the
        sorted but unfiltered collections
    """
    higgs_b = select_gen_b_quarks_from_higgs(events)

    gen_cuts = CONFIG["gen"]
    min_pt = gen_cuts["pt_cut"]
    max_abs_eta = gen_cuts["eta_cut"]
    in_acceptance = (higgs_b.pt > min_pt) & (abs(higgs_b.eta) < max_abs_eta)
    higgs_b = higgs_b[in_acceptance]

    def _sorted_descending(jets, sort_values):
        # Reorder the jets of every event from largest to smallest sort value.
        return jets[ak.argsort(sort_values, ascending=False)]

    offline_jets = events[CONFIG["offline"]["collection_name"]]
    l1_jets = events[CONFIG["l1"]["collection_name"]]

    # Order matters: (pT offline, pT L1, b-score offline, b-score L1).
    sorted_collections = (
        _sorted_descending(offline_jets, offline_jets.vector.pt),
        _sorted_descending(l1_jets, l1_jets.vector.pt),
        _sorted_descending(offline_jets, getattr(offline_jets, CONFIG["offline"]["tagger_name"])),
        _sorted_descending(l1_jets, getattr(l1_jets, CONFIG["l1"]["tagger_name"])),
    )

    purity_masks = tuple(
        get_purity_mask(higgs_b, jets) for jets in sorted_collections
    )
    pure_jets = tuple(
        jets[mask] for jets, mask in zip(sorted_collections, purity_masks)
    )

    return {"gen_b_from_higgs": higgs_b,
            "ordered": pure_jets,
            "masks": purity_masks
            }

def find_n_jets_rolling(gen_b_from_higgs, ordered, masks, n, k):
    """
    Evaluates top-k efficiencies for the four jet orderings.

    `ordered` and `masks` are indexed 0..3 as (pT offline, pT L1, b-tag offline,
    b-tag L1); they are expected to be the "ordered" and "masks" entries
    returned by gen_b_masks.

    Returns a dict of (offline, L1) pairs:
      - "pt_eff" / "btag_eff": number of jets in the first k entries of each
        `ordered` collection divided by the total number of gen b-quarks
      - "more_than_n_pt" / "more_than_n_b": number of events with at least n
        true entries among the first k mask entries, divided by the flattened
        length of gen_b_from_higgs

    NOTE(review): the "more_than_n" ratios count events in the numerator but
    gen quarks in the denominator -- confirm this normalisation is intended.
    """
    gen_b_total = ak.sum(ak.num(gen_b_from_higgs))
    gen_b_flat_count = len(ak.flatten(gen_b_from_higgs))

    # Jets kept in the first k slots, relative to all gen b-quarks.
    jet_fractions = [
        ak.sum(ak.num(ordered[idx][:, :k])) / gen_b_total
        for idx in range(4)
    ]

    # Events whose first k mask entries contain at least n matches.
    at_least_n_fractions = []
    for idx in range(4):
        matches_in_top_k = ak.sum(masks[idx][:, :k], axis=1)
        has_at_least_n = matches_in_top_k >= n
        at_least_n_fractions.append(ak.sum(has_at_least_n) / gen_b_flat_count)

    pt_offline, pt_l1, btag_offline, btag_l1 = at_least_n_fractions
    return {"more_than_n_pt": (pt_offline, pt_l1),
            "more_than_n_b": (btag_offline, btag_l1),
            "pt_eff": (jet_fractions[0], jet_fractions[1]),
            "btag_eff": (jet_fractions[2], jet_fractions[3])
           }


# Scan k = 1..max_k and record, for each ordering, the efficiency of finding at
# least n matched jets among the first k jets.
more_than_n_eff_pt_offline = []
more_than_n_eff_pt_l1 = []

more_than_n_eff_btag_offline = []
more_than_n_eff_btag_l1 = []

ordered_and_masks = gen_b_masks(events)
gen_b_from_higgs = ordered_and_masks["gen_b_from_higgs"]
ordered = ordered_and_masks["ordered"]
masks = ordered_and_masks["masks"]

n = 2
max_k = 20
# The x axis must carry the real k values (1..max_k); range(max_k) would shift
# every point one unit to the left.
k_values = range(1, max_k + 1)

for k in k_values:
    out_dict = find_n_jets_rolling(gen_b_from_higgs, ordered, masks, n, k)

    eff_pt_offline_k, eff_pt_l1_k = out_dict["more_than_n_pt"]
    more_than_n_eff_pt_offline.append(eff_pt_offline_k)
    more_than_n_eff_pt_l1.append(eff_pt_l1_k)

    eff_btag_offline_k, eff_btag_l1_k = out_dict["more_than_n_b"]
    more_than_n_eff_btag_offline.append(eff_btag_offline_k)
    more_than_n_eff_btag_l1.append(eff_btag_l1_k)

print(f"\nPlotting Top-N Jet Efficiencies for N = {n}...")
# One figure per ordering, each comparing the offline and L1 curves.
for offline_curve, l1_curve, offline_label, l1_label in (
    (more_than_n_eff_pt_offline, more_than_n_eff_pt_l1, 'Offline pT', 'L1 pT'),
    (more_than_n_eff_btag_offline, more_than_n_eff_btag_l1, 'Offline BTag', 'L1 BTag'),
):
    plt.figure(figsize=(10, 6))
    plt.step(k_values, offline_curve, where='mid', label=offline_label)
    plt.step(k_values, l1_curve, where='mid', label=l1_label)
    plt.xlabel("k (Number of Top Jets Considered)")
    plt.ylabel(f"Efficiency of Finding at least {n} b-jets")
    plt.title(f"Efficiency of Finding at least {n} b-jets vs. Top k Jets Considered")
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()
    plt.show()

In [None]:
# TODO: fix this optimisation code

from scipy.optimize import minimize_scalar
import warnings

def calculate_efficiency_for_cut(reco_jets, gen_b_quarks, tagger_name, cut_value):
    """
    Calculates the b-tagging efficiency for a given cut on the tagger.
    Efficiency = (Number of true b-jets passing the cut) / (Total number of true b-jets)

    Args:
        reco_jets: per-event jet collection carrying the `tagger_name` field.
        gen_b_quarks: gen-level b-quarks handed to get_purity_mask.
        tagger_name: name of the jet attribute holding the b-tag score.
        cut_value: jets count as passing when their score is strictly greater.

    Returns:
        The efficiency as a float, or 0.0 when there are no true b-jets.
    """
    # Keep only the reco jets flagged by the purity mask.
    is_pure_mask = get_purity_mask(gen_b_quarks, reco_jets)
    true_b_jets = reco_jets[is_pure_mask]

    # Count jets, not events: ak.num must run along axis=1 (jets per event).
    # With axis=0 both counts equal the number of events and the ratio is
    # always 1 regardless of the cut.
    n_total = ak.sum(ak.num(true_b_jets, axis=1))
    if n_total == 0:
        return 0.0

    # Summing the boolean pass flags over all jets gives the passing-jet count.
    passes_cut = getattr(true_b_jets, tagger_name) > cut_value
    n_passing = ak.sum(passes_cut)

    return float(n_passing / n_total)

def objective_function(cut_value, reco_jets, gen_b_quarks, tagger_name, target_efficiency):
    """
    Cost minimised by the cut search: the squared distance between the
    efficiency obtained at `cut_value` and `target_efficiency`.
    """
    efficiency_gap = (
        calculate_efficiency_for_cut(reco_jets, gen_b_quarks, tagger_name, cut_value)
        - target_efficiency
    )
    return efficiency_gap**2

def find_cut_for_efficiency(reco_jets, gen_b_quarks, tagger_name, target_efficiency):
    """
    Searches the interval [0, 1] for the b-tag cut whose efficiency is closest
    to `target_efficiency`, using scipy's bounded scalar minimiser on
    objective_function.

    Prints a short summary and returns (optimal_cut, final_efficiency), where
    final_efficiency is re-evaluated at the cut that was found.
    """
    print(f"Optimizing cut for '{tagger_name}' to achieve {target_efficiency:.1%} efficiency...")

    def _squared_miss(cut):
        # Single-argument view of the objective; everything else is fixed.
        return objective_function(cut, reco_jets, gen_b_quarks, tagger_name, target_efficiency)

    # RuntimeWarnings raised while the minimiser runs are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        search_result = minimize_scalar(
            _squared_miss,
            method='bounded',
            bounds=(0.0, 1.0),  # search window for the tagger score cut
        )

    optimal_cut = search_result.x
    final_efficiency = calculate_efficiency_for_cut(
        reco_jets, gen_b_quarks, tagger_name, optimal_cut
    )

    summary_lines = (
        "Optimization complete:",
        f"  - Target Efficiency: {target_efficiency:.3f}",
        f"  - Optimal Cut Value: {optimal_cut:.4f}",
        f"  - Resulting Efficiency: {final_efficiency:.3f}",
        "-" * 30,
    )
    for line in summary_lines:
        print(line)

    return optimal_cut, final_efficiency

# --- Run the Optimization ---

# Fall back to reloading the configuration and events when this cell runs in a
# session where `events` has not been created yet.
if 'events' not in locals():
    print("Reloading data as 'events' was not found in the environment.")
    with open("hh-bbbb-obj-config.json", "r") as config_file:
        CONFIG = json.load(config_file)
    events = run_analysis(CONFIG)

# Gen b-quarks from Higgs decays, restricted to the configured pT/eta window.
gen_b_quarks_for_opt = select_gen_b_quarks_from_higgs(events)
passes_gen_pt_for_opt = gen_b_quarks_for_opt.pt > CONFIG["gen"]["pt_cut"]
passes_gen_eta_for_opt = abs(gen_b_quarks_for_opt.eta) < CONFIG["gen"]["eta_cut"]
gen_b_quarks_for_opt = gen_b_quarks_for_opt[passes_gen_pt_for_opt & passes_gen_eta_for_opt]

# Reconstructed jet collections to tune.
reco_jets_offline_for_opt = events[CONFIG["offline"]["collection_name"]]
reco_jets_l1_for_opt = events[CONFIG["l1"]["collection_name"]]

# --- Define Target Efficiencies and Run ---
# Both collections currently aim for a 70% b-tagging efficiency.
target_eff_offline = 0.70
target_eff_l1 = 0.70

# Offline jets first ...
optimal_cut_offline, final_eff_offline = find_cut_for_efficiency(
    reco_jets_offline_for_opt, gen_b_quarks_for_opt,
    CONFIG["offline"]["tagger_name"], target_eff_offline,
)

# ... then the L1 jets.
optimal_cut_l1, final_eff_l1 = find_cut_for_efficiency(
    reco_jets_l1_for_opt, gen_b_quarks_for_opt,
    CONFIG["l1"]["tagger_name"], target_eff_l1,
)
