In [112]:
import pandas as pd
import torch
import numpy as np

In [120]:
# Load the saved prediction batches for the TT1L test sample.
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# not run elsewhere until it is replaced by a configurable data directory.
# NOTE(review): torch.load unpickles the file, so only load files from a
# trusted source.
# The result is later unpacked with `*data` and each element is indexed like a
# dict, so it is presumably a sequence of per-batch dicts -- confirm.
data = torch.load("/Users/avencastmini/PycharmProjects/EveNet/workspace/test_data/TT1L/prediction.pt")
# data_var = torch.load("/Users/avencastmini/PycharmProjects/EveNet/workspace/test_data/TT1L/prediction.var.pt")

In [116]:
def extract_batch_assignments(batch):
    """Flatten one prediction batch into a per-event DataFrame.

    Parameters
    ----------
    batch : dict
        Must provide ``'assignment_prediction'``, ``'assignment_target'`` and
        ``'assignment_target_mask'`` (each keyed by process name), plus the four
        ``'EXTRA/num_*'`` object-count tensors.

    Returns
    -------
    pandas.DataFrame
        The four object-count columns cast to int32, followed by, for every
        process and particle index ``i``:

        * ``{process}_{i}``      -- True where the predicted indices equal the
          target indices for every entry along axis 1.
        * ``{process}_{i}_mask`` -- the target mask for that particle, unchanged.
    """
    predictions = batch['assignment_prediction']
    targets = batch['assignment_target']
    target_masks = batch['assignment_target_mask']

    # Object multiplicities come first so they lead the column order.
    columns = {}
    for column, source_key in (
        ('num_electron', 'EXTRA/num_electron'),
        ('num_muon', 'EXTRA/num_muon'),
        ('num_bjet', 'EXTRA/num_bjet'),
        ('num_jet', 'EXTRA/num_jet'),
    ):
        columns[column] = batch[source_key].numpy().astype(np.int32)

    for process, process_targets in targets.items():
        process_predictions = predictions[process]['best_indices']
        process_masks = target_masks[process]

        per_particle = zip(process_targets, process_predictions, process_masks)
        for slot, (truth, guess, valid) in enumerate(per_particle):
            truth_np = truth.numpy()
            guess_np = guess.numpy()
            valid_np = valid.numpy()

            # A particle counts as matched only if every index in its group agrees.
            columns[f"{process}_{slot}"] = (truth_np == guess_np).all(axis=1)
            columns[f"{process}_{slot}_mask"] = valid_np

    return pd.DataFrame(columns)

def compute_efficiency_general(df, process_info, base_selection, extra_selection):
    """Compute assignment efficiencies per process inside one selection region.

    Events are first filtered with ``(base_selection) and (extra_selection)``.
    For every process two result rows are produced, in this order:

    * ``"all"``  -- every selected event is used. Each particle's efficiency is
      measured only over events where that particle's mask equals 1. For
      "All Matched", a particle whose mask is not 1 is treated as correct and
      the denominator is the number of selected events.
    * ``"full"`` -- only selected events where every particle mask of the
      process equals 1.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns referenced by the selection strings and, for
        each process and particle index ``i``, the columns ``{process}_{i}``
        (match flag) and ``{process}_{i}_mask``.
    process_info : dict
        Maps process name to its number of particles.
    base_selection, extra_selection : str
        Expressions understood by ``DataFrame.query``.

    Returns
    -------
    list of dict
        Two rows per process with the keys ``"Process"``, ``"Selection"``,
        ``"Selection Label"``, ``"Event Portion"``, ``"All Matched"`` and one
        ``{process}_Particle_{i}_Efficiency`` entry per particle. Any ratio
        whose denominator is empty is ``np.nan``; this includes
        ``"Event Portion"`` when ``df`` itself has no rows.
    """
    n_total = df.shape[0]
    selected_df = df.query(f"({base_selection}) and ({extra_selection})")
    n_selected = selected_df.shape[0]

    def _portion(n_events):
        # Guard against an empty input frame instead of raising ZeroDivisionError.
        return n_events / n_total if n_total else np.nan

    results = []

    for process, n_particles in process_info.items():
        match_cols = [f"{process}_{i}" for i in range(n_particles)]
        mask_cols = [f"{process}_{i}_mask" for i in range(n_particles)]

        # Validity per particle, computed once and shared by both regions.
        # Comparing with 1 accepts boolean as well as 0/1 integer masks.
        valid_masks = [selected_df[col] == 1 for col in mask_cols]

        # ---- All Event Region: each particle uses its own valid-event pool ----
        eff_per_particle = []
        for match_col, valid in zip(match_cols, valid_masks):
            n_valid = valid.sum()
            if n_valid == 0:
                eff_per_particle.append(np.nan)
            else:
                eff_per_particle.append(selected_df.loc[valid, match_col].sum() / n_valid)

        if np.any(np.isnan(eff_per_particle)):
            eff_all = np.nan
        else:
            # Invalid particles are replaced by True so they cannot veto an event.
            # NOTE(review): an event with no valid particle therefore counts as
            # fully matched -- confirm this is the intended definition.
            correct_flags = [
                selected_df[match_col].where(valid, True)
                for match_col, valid in zip(match_cols, valid_masks)
            ]
            correct_stack = pd.concat(correct_flags, axis=1)
            eff_all = correct_stack.all(axis=1).sum() / n_selected

        row_all_event = {
            "Process": process,
            "Selection": extra_selection,
            "Selection Label": "all",
            "Event Portion": _portion(n_selected),
            "All Matched": eff_all,
        }
        for i, eff in enumerate(eff_per_particle):
            row_all_event[f"{process}_Particle_{i}_Efficiency"] = eff

        results.append(row_all_event)

        # ---- Full Event Region: every particle mask of the process equals 1 ----
        # Direct column indexing (rather than a query/eval string) keeps the
        # mask test identical to the one above and works for any column name.
        fully_valid = (selected_df[mask_cols] == 1).all(axis=1)
        full_event_df = selected_df.loc[fully_valid]
        n_full = full_event_df.shape[0]

        if n_full == 0:
            eff_per_particle_full = [np.nan] * n_particles
            eff_all_full = np.nan
            # NOTE(review): reported as NaN, whereas the "all" row reports 0.0
            # for an empty selection -- kept as in the original.
            event_portion_full = np.nan
        else:
            eff_per_particle_full = [
                full_event_df[match_col].sum() / n_full for match_col in match_cols
            ]
            eff_all_full = full_event_df[match_cols].all(axis=1).sum() / n_full
            event_portion_full = _portion(n_full)

        row_full_event = {
            "Process": process,
            "Selection": extra_selection,
            "Selection Label": "full",
            "Event Portion": event_portion_full,
            "All Matched": eff_all_full,
        }
        for i, eff in enumerate(eff_per_particle_full):
            row_full_event[f"{process}_Particle_{i}_Efficiency"] = eff

        results.append(row_full_event)

    return results

In [123]:
# Convert every prediction batch into a per-event DataFrame.
dfs = [
    extract_batch_assignments(batch)
    for batch in [
        *data,
        # *data_var
    ]
]

# Stack the per-batch frames into one table with a fresh index.
df = pd.concat(dfs, ignore_index=True)

# Number of particles evaluated for each process name.
process_info = {"TT1L": 2}

# Selection applied in every region.
base_selection = "(num_electron + num_muon == 1) & (num_bjet >= 2)"

# Jet-multiplicity regions: readable label -> query expression.
# Only the expression is passed on; it is what appears in the "Selection" column.
extra_selections = {
    "n_jet == 4": "num_bjet + num_jet == 4",
    "n_jet == 5": "num_bjet + num_jet == 5",
    "n_jet >= 6": "num_bjet + num_jet >= 6",
    "all jets": "num_bjet + num_jet >= 4",
}

# Evaluate each region and gather its result rows in region order.
all_results = []
for extra_sel in extra_selections.values():
    all_results.extend(
        compute_efficiency_general(df, process_info, base_selection, extra_sel)
    )

# One row per (region, process, selection label).
efficiency_table = pd.DataFrame(all_results)

# Show the "all" rows first, then the "full" rows.
print(efficiency_table.query("`Selection Label` == 'all'").to_string(index=False, float_format="%.2f"))
print(efficiency_table.query("`Selection Label` == 'full'").to_string(index=False, float_format="%.2f"))


Process               Selection Selection Label  Event Portion  All Matched  TT1L_Particle_0_Efficiency  TT1L_Particle_1_Efficiency
   TT1L num_bjet + num_jet == 4             all           0.46         0.82                        0.86                        0.79
   TT1L num_bjet + num_jet == 5             all           0.32         0.74                        0.84                        0.71
   TT1L num_bjet + num_jet >= 6             all           0.22         0.65                        0.81                        0.59
   TT1L num_bjet + num_jet >= 4             all           1.00         0.76                        0.84                        0.71
Process               Selection Selection Label  Event Portion  All Matched  TT1L_Particle_0_Efficiency  TT1L_Particle_1_Efficiency
   TT1L num_bjet + num_jet == 4            full           0.20         0.84                        0.90                        0.84
   TT1L num_bjet + num_jet == 5            full           0.18         0.72 

Process    Selection Selection Label  Event Portion  All Matched  TT1L_Particle_0_Efficiency  TT1L_Particle_1_Efficiency
   TT1L num_jet == 2             all           0.46         0.81                        0.86                        0.78
   TT1L num_jet == 3             all           0.30         0.74                        0.84                        0.71
   TT1L num_jet >= 4             all           0.19         0.66                        0.82                        0.60
   TT1L num_jet >= 2             all           0.95         0.76                        0.84                        0.71
Process    Selection Selection Label  Event Portion  All Matched  TT1L_Particle_0_Efficiency  TT1L_Particle_1_Efficiency
   TT1L num_jet == 2            full           0.20         0.82                        0.89                        0.83
   TT1L num_jet == 3            full           0.17         0.71                        0.86                        0.73
   TT1L num_jet >= 4            