In [1]:
import pandas as pd
import torch
import numpy as np
import vector
from pathlib import Path
import os
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.pyplot as plt
from downstreams.plotting.kinematic_comparison import plot_kinematics_comparison

In [2]:
def process_batch(batch):
    """Flatten a batch of per-particle tensors into a DataFrame with two-level columns.

    ``batch`` must hold a ``'predict'`` and a ``'target'`` mapping from feature
    name to a 2-D tensor whose second axis indexes the particle.  Column
    ``(label, "<feature>_<i>")`` receives ``tensor[:, i]``.  Every ``log_``
    substring is dropped from the feature name, and keys containing ``'log_pt'``
    are exponentiated back to linear pT.
    """
    columns = {}
    for label in ('predict', 'target'):
        for key, tensor in batch[label].items():
            feature = key.replace('log_', '')
            undo_log = 'log_pt' in key
            for idx in range(tensor.shape[1]):
                values = tensor[:, idx].numpy()
                columns[(label, f"{feature}_{idx}")] = np.exp(values) if undo_log else values

    # Tuple keys -> MultiIndex columns (label, feature_i).
    frame = pd.DataFrame(columns)
    frame.columns = pd.MultiIndex.from_tuples(frame.columns)
    return frame


def extract_neutrinos(df, label):
    """Return a ``(num_events, 2)`` massless vector array for neutrinos 0 and 1.

    Reads the ``(label, "pt_<i>")``, ``(label, "eta_<i>")`` and
    ``(label, "phi_<i>")`` columns for i in {0, 1}; in this notebook ``label`` is
    ``"predict"`` or ``"target"``.  The mass component is fixed to zero.
    """
    def two_columns(component):
        # Column i of the stacked array is neutrino i.
        return np.stack([df[(label, f"{component}_{idx}")].values for idx in (0, 1)], axis=1)

    pt = two_columns("pt")
    return vector.array({
        "pt": pt,
        "eta": two_columns("eta"),
        "phi": two_columns("phi"),
        "mass": np.zeros_like(pt),
    })


def extract_particles(df, prefix1, prefix2):
    """Build a ``(num_events, 2)`` vector array from two particle column groups.

    Each prefix must provide ``<prefix>/pt``, ``<prefix>/eta`` and ``<prefix>/phi``
    columns plus either ``<prefix>/mass`` or ``<prefix>/energy`` (mass wins when
    both exist).  Column 0 of the result comes from ``prefix1`` and column 1 from
    ``prefix2``.

    Raises:
        ValueError: if a prefix has neither a mass nor an energy column, or if the
            two prefixes would use different fourth components (the result must
            share a single coordinate system).
    """

    def fourth_component(prefix):
        # Same precedence as before: mass first, then energy.
        for name in ("mass", "energy"):
            if f"{prefix}/{name}" in df.columns:
                return name
        raise ValueError(f"Missing mass or energy columns for prefix: {prefix}")

    type1 = fourth_component(prefix1)
    type2 = fourth_component(prefix2)
    if type1 != type2:
        raise ValueError(f"Inconsistent 4-momentum components: {prefix1} uses {type1}, {prefix2} uses {type2}")

    def stacked(component):
        # Shape (num_events, 2): [prefix1 values, prefix2 values].
        return np.stack([df[f"{prefix}/{component}"].values for prefix in (prefix1, prefix2)], axis=1)

    # vector.array matches the constructor used by extract_neutrinos.
    return vector.array({
        "pt": stacked("pt"),
        "eta": stacked("eta"),
        "phi": stacked("phi"),
        type1: stacked(type1),
    })


def process_data(data, baseline_selections):
    """Merge prediction batches, apply a selection and build per-event particle vectors.

    Each batch contributes its ``'neutrinos'`` entry (flattened by
    ``process_batch``) plus every key containing ``'EXTRA/'`` as a column with that
    substring removed.  The concatenated frame is filtered with
    ``DataFrame.query(baseline_selections)``.

    Returns a dict of two-particle vector arrays: ``predict`` / ``target``
    neutrinos, reconstructed ``b`` and ``lepton``, ``truth_top`` / ``truth_W`` /
    ``truth_lepton``, and the derived sums ``W``, ``top``, ``plot_truth_W`` and
    ``plot_truth_top``.
    """
    per_batch = []
    for batch in data:
        neutrino_df = process_batch(batch['neutrinos'])
        extra_df = pd.DataFrame({
            key.replace('EXTRA/', ''): batch[key]
            for key in batch.keys()
            if 'EXTRA/' in key
        })
        per_batch.append(pd.concat([neutrino_df, extra_df], axis=1))

    selected = pd.concat(per_batch, ignore_index=True).query(baseline_selections)

    # Output key -> (column prefix of particle 0, column prefix of particle 1).
    paired_prefixes = {
        "b": ("t1/b", "t2/b"),
        "lepton": ("t1/l", "t2/l"),
        "truth_top": ("truth_t1/t", "truth_t2/t"),
        "truth_W": ("truth_t1/W", "truth_t2/W"),
        "truth_lepton": ("truth_t1/l", "truth_t2/l"),
    }
    particles = {
        "predict": extract_neutrinos(selected, "predict"),
        "target": extract_neutrinos(selected, "target"),
    }
    for name, (first, second) in paired_prefixes.items():
        particles[name] = extract_particles(selected, first, second)

    lepton = particles['lepton']
    b = particles['b']
    # Reconstructed W / top built with the predicted neutrinos.
    particles['W'] = lepton + particles['predict']
    particles['top'] = b + particles['W']
    # Reference W / top: same reconstructed lepton (and b) but with the target neutrinos.
    particles['plot_truth_W'] = lepton + particles['target']
    particles['plot_truth_top'] = b + lepton + particles['target']
    return particles

In [11]:
p_dir = Path(os.getcwd()) / "aux"

# Prediction file to analyse. Set EVENET_PREDICTION_PATH to run on another machine
# instead of editing this machine-specific default.
PREDICTION_PATH = Path(os.environ.get(
    "EVENET_PREDICTION_PATH",
    "/Users/avencastmini/PycharmProjects/EveNet/workspace/test_data/nu2flow/prediction-mg5-300.pt",
))

# Baseline: num_bjet == 2 and non-zero pT for both b and both lepton candidates.
# Alternatives that were tried:
#   "(num_bjet >= 0)"
#   "`t1/b/pt` > 25 and `t2/b/pt` > 25 and `t1/l/pt` > 15 and `t2/l/pt` > 15"
BASELINE_SELECTIONS = "(num_bjet == 2) and `t1/b/pt` > 0 and `t2/b/pt` > 0 and `t1/l/pt` > 0 and `t2/l/pt` > 0"

# NOTE: torch.load unpickles the file (unless weights_only is in effect); only load trusted files.
data = torch.load(PREDICTION_PATH)
nu = process_data(data, baseline_selections=BASELINE_SELECTIONS)

print(len(nu["predict"]))

26729


In [12]:
# calculate observables
def build_observables(top: "vector.MomentumNumpy4D", lepton: "vector.MomentumNumpy4D"):
    """Compute per-event ttbar observables from two-particle top and lepton arrays.

    Args:
        top: ``(num_events, 2)`` top four-vectors; column 0 and 1 are the two tops.
        lepton: ``(num_events, 2)`` lepton four-vectors in the same column order.

    Returns:
        DataFrame with columns ``m_tt``, ``pt_tt``, ``y_tt`` (of the summed ttbar
        system), ``pt_t1``, ``pt_t2`` and ``dphi_ll`` = |Δφ(ℓ0, ℓ1)| / π in [0, 1].
    """
    # Annotations are quoted so defining this function does not evaluate `vector`.
    ttbar = top.sum(axis=1)

    # vector's deltaphi is signed in [-pi, pi); the response binning for dphi_ll is
    # [0, 1], so the magnitude is required or negative-Δφ events are dropped.
    dphi_ll = np.abs(lepton[:, 0].deltaphi(lepton[:, 1])) / np.pi

    return pd.DataFrame({
        "m_tt": ttbar.mass,
        "pt_tt": ttbar.pt,
        "y_tt": ttbar.rapidity,
        "pt_t1": top[:, 0].pt,
        "pt_t2": top[:, 1].pt,
        "dphi_ll": dphi_ll,
    })


# Observables at three levels, suffixed so they can share one frame:
#   _truth         truth tops and truth leptons
#   _reco_truthnu  reconstructed b + lepton combined with the target neutrinos
#   _reco_prednu   reconstructed b + lepton combined with the predicted neutrinos
df_truth = build_observables(nu["truth_top"], nu["truth_lepton"]).add_suffix("_truth")
df_reco_truthnu = build_observables(nu["plot_truth_top"], nu["lepton"]).add_suffix("_reco_truthnu")
df_reco_prednu = build_observables(nu["top"], nu["lepton"]).add_suffix("_reco_prednu")

# Every frame is derived from the same selected events, so rows align one-to-one.
df_all = pd.concat([df_truth, df_reco_truthnu, df_reco_prednu], axis=1)

In [13]:
# Generalized response matrix builder
def build_16x16_response(
        df,
        mtt_truth_col,
        mtt_reco_col,
        var_truth_col,
        var_reco_col,
        mtt_bins,
        var_bins
):
    """Build a block response matrix in (m_tt bin x variable bin) space.

    Row index    = reco m_tt bin  * n_var + reco variable bin.
    Column index = truth m_tt bin * n_var + truth variable bin.
    Each entry counts the events falling into that (reco, truth) cell.  The shape is
    ``(n_mtt * n_var, n_mtt * n_var)``; it is 16x16 only for 4 m_tt and 4 variable bins.

    m_tt bins are half-open ``[lo, hi)``; events outside every m_tt bin are dropped.
    Variable values follow ``np.histogram2d`` binning (last bin closed on the right,
    out-of-range values dropped).

    Args:
        df: DataFrame holding the four columns below.
        mtt_truth_col, mtt_reco_col: truth / reco m_tt column names.
        var_truth_col, var_reco_col: truth / reco variable column names.
        mtt_bins: m_tt bin edges (n_mtt + 1 values).
        var_bins: variable bin edges (n_var + 1 values).

    Returns:
        2-D float ndarray of counts.
    """
    # The m_tt and variable bin counts are independent; deriving the m_tt loop
    # range from var_bins would mis-shape the matrix when they differ.
    n_mtt = len(mtt_bins) - 1
    n_var = len(var_bins) - 1
    response = np.zeros((n_mtt * n_var, n_mtt * n_var))

    def mtt_masks(col):
        # One boolean mask per m_tt bin, computed once instead of in the nested loop.
        return [(df[col] >= mtt_bins[k]) & (df[col] < mtt_bins[k + 1]) for k in range(n_mtt)]

    reco_masks = mtt_masks(mtt_reco_col)
    truth_masks = mtt_masks(mtt_truth_col)

    for i, reco_mask in enumerate(reco_masks):  # reco mtt bin
        for j, truth_mask in enumerate(truth_masks):  # truth mtt bin
            df_sel = df[reco_mask & truth_mask]

            # h2d[reco_var_bin, truth_var_bin]
            h2d, _, _ = np.histogram2d(
                df_sel[var_reco_col],
                df_sel[var_truth_col],
                bins=[var_bins, var_bins]
            )

            row_start = i * n_var
            col_start = j * n_var
            response[row_start:row_start + n_var, col_start:col_start + n_var] = h2d

    return response


# Generalized plotting function
def plot_block_response(
        response,
        var_labels,
        mtt_labels,
        title=None,
        xlabel="Truth $m_{tt}$ bin",
        ylabel="Reco variable bin",
        save_name=None,
):
    """Draw a column-normalised block response matrix and optionally save it.

    Rows of ``response`` are drawn along y and columns along x
    (``imshow(origin='lower')``).  Each column is normalised to 100 %, so a cell's
    colour is the percentage of that column's entries.  Columns with no entries
    are shown as 0 %.  Dashed lines separate blocks of ``len(var_labels)`` bins.
    The x ticks name each block with ``mtt_labels``, and the y ticks repeat
    ``var_labels`` for every block.  The trace fraction (diagonal / total) is
    printed above the axes.

    Args:
        response: 2-D array, e.g. from ``build_16x16_response``.
        var_labels: tick labels for the variable bins inside one block.
        mtt_labels: one label per m_tt block.
        title, xlabel, ylabel: axis text.
        save_name: if given, the figure is written to ``p_dir / "response" / save_name``.

    The figure is always closed before returning.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    try:
        truth_sums = response.sum(axis=0, keepdims=True)
        # `out=` is required: with only `where=`, np.divide leaves the masked
        # (empty-column) entries uninitialised and they render as garbage.
        normed = 100 * np.divide(
            response, truth_sums,
            out=np.zeros(np.shape(response)),
            where=truth_sums != 0,
        )

        im = ax.imshow(normed, origin='lower', cmap='Blues', vmin=0, vmax=100)

        # Annotate cells above 1 %.
        for i in range(normed.shape[0]):
            for j in range(normed.shape[1]):
                val = normed[i, j]
                if val > 1:
                    ax.text(j, i, f"{val:.0f}", ha='center', va='center', fontsize=7)

        # Block grid lines.
        block_size = len(var_labels)
        for i in range(0, response.shape[0], block_size):
            ax.axhline(i - 0.5, color='k', linestyle='--', lw=1)
            ax.axvline(i - 0.5, color='k', linestyle='--', lw=1)
        ax.axhline(response.shape[0] - 0.5, color='k', linestyle='--', lw=1)
        ax.axvline(response.shape[1] - 0.5, color='k', linestyle='--', lw=1)

        # X ticks: one per mtt block, placed at the block centre.
        xticks = [i * block_size + block_size / 2 - 0.5 for i in range(len(mtt_labels))]
        ax.set_xticks(xticks)
        ax.set_xticklabels(mtt_labels, fontsize=10)

        # Y ticks: one per variable bin, repeated for every block.
        yticks = list(range(block_size * len(mtt_labels)))
        ytick_labels = var_labels * len(mtt_labels)
        ax.set_yticks(yticks)
        ax.set_yticklabels(ytick_labels, fontsize=8)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)

        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="3%", pad=0.05)  # shrink width to 3%
        cbar = fig.colorbar(im, cax=cax)
        cbar.set_label("Migration [%]", fontsize=9)

        # Trace fraction: share of all entries on the diagonal.
        trace = np.trace(response)
        total = response.sum()
        trace_frac = trace / total if total > 0 else 0
        ax.text(1.0, 1.02, f"trace fraction = {trace_frac:.2f}", transform=ax.transAxes,
                ha='right', fontsize=10)

        fig.tight_layout()
        if save_name:
            out_dir = p_dir / "response"
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(out_dir / save_name)
    finally:
        # Close this specific figure even if drawing or saving fails.
        plt.close(fig)


# Define bin edges
bins_mtt = [0, 400, 500, 800, np.inf]
mtt_labels = [
    r"$m_{t\bar{t}} < 400$",
    r"$400 < m_{t\bar{t}} < 500$",
    r"$500 < m_{t\bar{t}} < 800$",
    r"$m_{t\bar{t}} \geq 800$"
]

for scenario, s_name in zip(["reco_truthnu", "reco_prednu"], ["Truth_Nu", "Pred_Nu"]):
    variable_configs = [
        {
            "name": r"$\Delta\phi(\ell^+,\ell^-) / \pi$ [rad/$\pi$]",
            "truth_col": "dphi_ll_truth",
            "reco_col": f"dphi_ll_{scenario}",
            "bins": [0.0, 0.25, 0.5, 0.75, 1.0],
            "labels": ["0–0.25", "0.25–0.5", "0.5–0.75", "0.75–1.0"]
        },
        {
            "name": r"$p_T^t$ [GeV]",
            "truth_col": "pt_t1_truth",
            "reco_col": f"pt_t1_{scenario}",
            "bins": [0, 75, 125, 175, np.inf],
            "labels": ["<75", "75–125", "125–175", "≥175"]
        },
        {
            "name": r"$p_{T}^{t\bar{t}}$ [GeV]",
            "truth_col": "pt_tt_truth",
            "reco_col": f"pt_tt_{scenario}",
            "bins": [0, 70, 140, 200, np.inf],
            "labels": ["<70", "70–140", "140–200", "≥200"]
        },
        {
            "name": r"$y_{t\bar{t}}$",
            "truth_col": "y_tt_truth",
            "reco_col": f"y_tt_{scenario}",
            "bins": [-np.inf, -1.0, 0.0, 1.0, np.inf],
            "labels": ["<–1", "–1–0", "0–1", ">1"]
        }
    ]

    # Build and plot all variables
    for var in variable_configs:
        response = build_16x16_response(
            df_all,
            mtt_truth_col="m_tt_truth",
            # Reco m_tt must come from the same scenario as the reco variable,
            # otherwise the Truth_Nu matrices are binned with predicted-nu m_tt.
            mtt_reco_col=f"m_tt_{scenario}",
            var_truth_col=var["truth_col"],
            var_reco_col=var["reco_col"],
            mtt_bins=bins_mtt,
            var_bins=var["bins"]
        )

        # Matrix columns (x) are truth = parton level; rows (y) are reco = detector level.
        plot_block_response(
            response,
            title=f"EveNet: {s_name}",
            var_labels=var["labels"],
            mtt_labels=mtt_labels,
            xlabel=f"Parton-level {var['name']}",
            ylabel=f"Detector-level {var['name']}",
            save_name=f"{s_name}_{var['truth_col'].replace('_truth', '')}.pdf"
        )

In [14]:
named_configs = {
    "neutrino": {
        "variables": ["pt", "eta", "phi"],
        "x_labels": [r"$p_T^{\nu}$ [GeV]", r"$\eta^{\nu}$", r"$\phi^{\nu}$"],
        "kin_range": {"pt": (0, 350), "eta": (-np.pi * 1.5, np.pi * 1.5), "phi": (-np.pi, np.pi)},
        "labels": [r"$\nu$ from $(top^+)$", r"$\nu$ from $(top^-)$"],
        "colors": ['#5bb5ac', '#de526c'],
        "columns": ['predict', 'target'],
        "log_y": [True, False, False],
    },
    "top": {
        "variables": ["pt", "eta", "phi", "mass"],
        "x_labels": [r"$p_T^{t}$ [GeV]", r"$\eta^{t}$", r"$\phi^{t}$", r"$mass^{t}$ [GeV]"],
        "kin_range": {"pt": (0, 600), "eta": (-np.pi * 1.5, np.pi * 1.5), "phi": (-np.pi, np.pi), "mass": (100, 240)},
        "labels": [r"$(top^+)$", r"$(top^-)$"],
        "colors": ['#5bb5ac', '#de526c'],
        "columns": ['top', 'plot_truth_top'],
        "log_y": [True, False, False, False],
    },
    "W": {
        "variables": ["pt", "eta", "phi", "mass"],
        "x_labels": [r"$p_T^{W}$ [GeV]", r"$\eta^{W}$", r"$\phi^{W}$", r"$mass^{W}$ [GeV]"],
        "kin_range": {"pt": (0, 350), "eta": (-np.pi * 1.5, np.pi * 1.5), "phi": (-np.pi, np.pi), "mass": (40, 120)},
        "labels": [r"$(W^+)$", r"$(W^-)$"],
        "colors": ['#5bb5ac', '#de526c'],
        "columns": ['W', 'plot_truth_W'],
        "log_y": [True, False, False, False],
    }
}

# The output directory is the same for every plot; create it once.
kinematics_dir = p_dir / "kinematics"
os.makedirs(kinematics_dir, exist_ok=True)


def _per_particle(arr, var):
    """Return ``[var of particle 0, var of particle 1]`` from a two-particle vector array."""
    # To compare both particles merged into one distribution, np.concatenate these two.
    return [getattr(arr[..., 0], var), getattr(arr[..., 1], var)]


for particle, cfg in named_configs.items():
    # columns[0] is the reconstructed/predicted key, columns[1] the reference key.
    model_key, reference_key = cfg["columns"]

    for i, var in enumerate(cfg["variables"]):
        fig, axs = plt.subplots(
            3, 1, figsize=(10, 16),
            gridspec_kw={'height_ratios': [3, 1, 2], 'hspace': 0.0},
            sharex=True
        )
        try:
            plot_kinematics_comparison(
                axs=axs,
                kin=_per_particle(nu[model_key], var),
                truth_kin=_per_particle(nu[reference_key], var),
                bins=100,
                kin_range=cfg["kin_range"][var],
                labels=cfg["labels"],
                colors=cfg["colors"],
                xlabel=cfg["x_labels"][i],
                normalize_col=cfg.get("normalize_col", False),
                log_z=cfg.get("log_z", True),
                log_y=cfg.get("log_y", [False, False, False, False])[i],
                c_percent=np.array([10, 100])
            )

            fig.tight_layout()
            fig.savefig(kinematics_dir / f"{particle}_{var}.pdf")
        finally:
            # Free each figure even on failure; many figures are created in this loop.
            plt.close(fig)

  ax.contourf(X, Y, Z, levels=levels, cmap=contour_colors[i], alpha=0.5, norm=mcolors.LogNorm())
  ax.contourf(X, Y, Z, levels=levels, cmap=contour_colors[i], alpha=0.5, norm=mcolors.LogNorm())
  ax.contourf(X, Y, Z, levels=levels, cmap=contour_colors[i], alpha=0.5, norm=mcolors.LogNorm())
  ax.contourf(X, Y, Z, levels=levels, cmap=contour_colors[i], alpha=0.5, norm=mcolors.LogNorm())
