In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import vector
from matplotlib import pyplot as plt

from downstreams.plotting.kinematic_comparison import plot_kinematics_comparison

In [None]:
# Location of the nu2flow prediction dump. The default is the original author's local
# path; set NU2FLOW_PREDICTION_PATH to point at the file on another machine.
DATA_PATH = Path(os.environ.get(
    "NU2FLOW_PREDICTION_PATH",
    "/Users/avencastmini/PycharmProjects/EveNet/workspace/test_data/nu2flow/prediction-mg5.pt",
))

# map_location="cpu": the processing below calls `.numpy()`, which only works on CPU
# tensors, so load everything onto the CPU regardless of the device it was saved from.
data = torch.load(DATA_PATH, map_location="cpu")

In [None]:
def process_batch(batch):
    """Flatten one batch of predicted/target kinematics into a DataFrame with two-level columns.

    Parameters
    ----------
    batch : dict
        Mapping with ``'predict'`` and ``'target'`` entries, each a dict of tensors keyed
        by feature name. Only ``log_mass``, ``log_pt``, ``eta`` and ``phi`` are used; any
        other key is ignored. Each used tensor is indexed as ``tensor[:, i]`` over its
        second axis, so it must be at least 2-D (presumably ``(num_events, num_particles)``
        -- confirm against the producer of the file).

    Returns
    -------
    pandas.DataFrame
        One column per (label, feature, particle index). Columns form a MultiIndex of
        ``(label, f"{feature}_{i}")``, where ``feature`` is the key with ``log_`` stripped.
        ``log_pt`` is mapped back with ``exp``, ``log_mass`` with ``expm1`` (the inverse
        of ``log1p``), and ``eta``/``phi`` are copied unchanged.
    """
    wanted = ('log_mass', 'log_pt', 'eta', 'phi')
    columns = {}

    for label in ('predict', 'target'):
        for key, tensor in batch[label].items():
            if key not in wanted:
                continue

            # detach + cpu so `.numpy()` also works for grad-tracking or GPU tensors.
            values = tensor.detach().cpu().numpy()
            feature = key.replace('log_', '')

            for i in range(values.shape[1]):
                column = values[:, i]
                if key == 'log_pt':
                    column = np.exp(column)
                elif key == 'log_mass':
                    # expm1 stays accurate for masses near zero (as expected for neutrinos),
                    # where exp(x) - 1 cancels and rounds small values to 0 in float32.
                    column = np.expm1(column)
                columns[(label, f"{feature}_{i}")] = column

    df = pd.DataFrame(columns)
    # Explicit two-level names keep this working when no matching keys were found
    # (from_tuples cannot infer the number of levels from an empty list).
    df.columns = pd.MultiIndex.from_tuples(list(columns), names=[None, None])
    return df


def build_vector_array(df, label, n_particles=2):
    """Build a per-event array of four-vectors from one column group of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with MultiIndex columns ``(label, f"{field}_{i}")`` for ``field`` in
        ``pt``/``eta``/``phi``/``mass`` and ``i`` in ``range(n_particles)`` (the layout
        produced by ``process_batch``).
    label : str
        Top-level column key to read, e.g. ``"predict"`` or ``"target"``.
    n_particles : int, default 2
        Number of particles per event. The default matches the two neutrinos this
        notebook was written for.

    Returns
    -------
    The result of ``vector.array`` built from pt/eta/phi/mass arrays of shape
    ``(num_events, n_particles)``; particle ``i`` is at ``[..., i]``.

    Raises
    ------
    KeyError
        If any expected ``(label, f"{field}_{i}")`` column is missing.
    """
    def stack_field(field):
        # Column i of the stacked array holds particle i.
        per_particle = [df[(label, f"{field}_{i}")].to_numpy() for i in range(n_particles)]
        return np.stack(per_particle, axis=1)

    return vector.array({field: stack_field(field) for field in ("pt", "eta", "phi", "mass")})


def process_data(data):
    """Convert a sequence of prediction batches into predicted and target neutrino vectors.

    Parameters
    ----------
    data : iterable of dict
        Batches as loaded from the prediction file. Each batch must have a
        ``'neutrinos'`` entry in the format accepted by ``process_batch``.

    Returns
    -------
    dict
        ``{"predict": ..., "target": ...}``: the arrays returned by ``build_vector_array``
        for the predicted and the true (target) neutrinos over all batches.

    Raises
    ------
    ValueError
        If ``data`` contains no batches.
    """
    # Build every per-batch frame first and concatenate once (no quadratic growth).
    frames = [process_batch(batch['neutrinos']) for batch in data]
    if not frames:
        # pd.concat would otherwise fail with the less helpful "No objects to concatenate".
        raise ValueError("process_data: received no batches to process")

    events = pd.concat(frames, ignore_index=True)

    return {
        "predict": build_vector_array(events, "predict"),
        "target": build_vector_array(events, "target"),
    }

In [None]:
nu = process_data(data)

# Plot configuration: one figure per entry in "variables"; the per-particle lists
# ("labels", "colors", "file_keys") are indexed by particle (top+ / top- neutrino).
cfg = {
    "variables": ["pt", "eta", "phi"],
    "x_labels": [r"$p_T^{\nu}$ [GeV]", r"$\eta^{\nu}$", r"$\phi^{\nu}$"],
    "kin_range": {"pt": (0, 350), "eta": (-np.pi * 1.5, np.pi * 1.5), "phi": (-np.pi , np.pi)},
    "labels": [r"$\nu$ from $(top^+)$", r"$\nu$ from $(top^-)$"],
    "colors": ['#5bb5ac', '#de526c'],
    "file_keys": ["nu_p", "nu_m"],
    "log_y": [True, False, False],
}

# One histogram series per labelled particle.
n_particles = len(cfg["labels"])

for i, var in enumerate(cfg["variables"]):
    fig, axs = plt.subplots(
        3, 1, figsize=(10, 16),
        gridspec_kw={'height_ratios': [3, 1, 2], 'hspace': 0.0},
        sharex=True
    )

    plot_kinematics_comparison(
        axs=axs,
        kin=[getattr(nu['predict'][..., j], var) for j in range(n_particles)],
        truth_kin=[getattr(nu['target'][..., j], var) for j in range(n_particles)],
        bins=50,
        kin_range=cfg["kin_range"][var],
        labels=cfg["labels"],
        colors=cfg["colors"],
        xlabel=cfg["x_labels"][i],
        normalize_col=cfg.get("normalize_col", False),
        log_z=cfg.get("log_z", True),
        log_y=cfg.get("log_y", [False, False, False])[i],
    )

    # Each iteration creates a new figure, so layout/show must happen per figure;
    # closing afterwards keeps figures from accumulating in memory.
    fig.tight_layout()
    plt.show()
    plt.close(fig)