In [None]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

import awkward as ak
import numpy as np
import torch
import vector
from omegaconf import OmegaConf

import gabbro.plotting.utils as plot_utils
from gabbro.data.loading import read_jetclass_file
from gabbro.models.vqvae import VQVAELightning
from gabbro.plotting.feature_plotting import plot_features
from gabbro.utils.arrays import ak_select_and_preprocess, ak_subtract

# Apply the matplotlib style provided by gabbro.plotting.utils to the figures below.
plot_utils.set_mpl_style()

# Register vector's behaviors with awkward so that arrays built with
# `vector.awk` expose kinematic properties/methods (pt, mass, deltaeta, ...).
vector.register_awkward()

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# %config InlineBackend.figure_format='retina'

# torch.cuda.empty_cache()
# free, avail = torch.cuda.mem_get_info()
# clear cuda memory
# print(free / avail)

In [None]:
def to_p4(p4_obj):
    """Rebuild ``p4_obj`` as a vector-aware awkward array.

    The ``x``, ``y`` and ``z`` fields are copied unchanged and the ``tau``
    field is stored under the name ``mass``.
    """
    components = {
        "mass": p4_obj.tau,
        "x": p4_obj.x,
        "y": p4_obj.y,
        "z": p4_obj.z,
    }
    zipped = ak.zip(components)
    return vector.awk(zipped)


def p4s_relative_to_axis(p4s):
    """Express each four-vector relative to the summed four-vector of its group.

    The reference axis is the sum of ``p4s`` along ``axis=1`` (presumably the
    per-jet sum of the constituents -- confirm against the caller).

    Returns an awkward array with the fields ``part_pt``, ``part_etarel``,
    ``part_phirel`` and ``part_mass``.
    """
    reference_axis = ak.sum(p4s, axis=1)
    relative_features = {}
    relative_features["part_pt"] = p4s.pt
    relative_features["part_etarel"] = p4s.deltaeta(reference_axis)
    relative_features["part_phirel"] = p4s.deltaphi(reference_axis)
    relative_features["part_mass"] = p4s.mass
    return ak.Array(relative_features)


def future_file_to_ak_arrays(
    file_path: str,
    n_load: int,
    pp_dict: dict,
    pp_dict_cuts: dict,
    vqvae_model,
    batch_size: int = 512,
    pad_length: int = 128,
):
    """Load a parquet file of the future dataset and return original and reconstructed particles.

    Args:
        file_path: Path of the parquet file to read.
        n_load: Number of leading entries to use.
        pp_dict: Preprocessing dictionary handed to the tokenizer model.
        pp_dict_cuts: Selection cuts handed to ``ak_select_and_preprocess``.
        vqvae_model: Model providing ``tokenize_ak_array`` and
            ``reconstruct_ak_tokens``.
        batch_size: Batch size used for tokenization and reconstruction.
        pad_length: Padding length handed to the tokenizer. The selected
            particles are truncated to this many per entry so that the
            original and the reconstructed arrays stay aligned, exactly as
            ``jetclass_file_to_ak_arrays`` does.

    Returns:
        Tuple ``(x_ak_with_cuts, x_ak_with_cuts_reco)``: the particle features
        after the cuts, and their reconstruction from the token ids.
    """
    data = ak.from_parquet(file_path)
    p4s = to_p4(data.reco_cand_p4s[:n_load])
    p4s_rel = p4s_relative_to_axis(p4s)
    charge = data.reco_cand_charge[:n_load]
    # Only the absolute PDG code is needed for the flags below; compute it once.
    abs_pdg = abs(data.reco_cand_pdg[:n_load])

    x_ak = ak.Array(
        {
            "part_etarel": p4s_rel.part_etarel,
            "part_phirel": p4s_rel.part_phirel,
            "part_pt": p4s_rel.part_pt,
            "part_mass": p4s_rel.part_mass,
            "part_charge": charge,
            "part_isElectron": ak.where(abs_pdg == 11, 1, 0),
            "part_isMuon": ak.where(abs_pdg == 13, 1, 0),
            "part_isPhoton": ak.where(abs_pdg == 22, 1, 0),
            # |pdg| 211, 321 and 2212 are all flagged as charged hadrons
            "part_isChargedHadron": ak.where(abs_pdg == 211, 1, 0)
            + ak.where(abs_pdg == 321, 1, 0)
            + ak.where(abs_pdg == 2212, 1, 0),
            # |pdg| 130, 2112 and the code 0 are all flagged as neutral hadrons
            "part_isNeutralHadron": ak.where(abs_pdg == 130, 1, 0)
            + ak.where(abs_pdg == 2112, 1, 0)
            + ak.where(abs_pdg == 0, 1, 0),
        }
    )
    # apply the cuts that were used in the VQVAE training and keep at most
    # `pad_length` particles per entry (same as in the JetClass loader)
    x_ak_with_cuts = ak_select_and_preprocess(x_ak, pp_dict_cuts)[:, :pad_length]

    # tokenize
    x_ak_tokenized = vqvae_model.tokenize_ak_array(
        ak_arr=x_ak_with_cuts,
        pp_dict=pp_dict,
        batch_size=batch_size,
        pad_length=pad_length,
    )
    # reconstruct
    x_ak_with_cuts_reco = vqvae_model.reconstruct_ak_tokens(
        tokens_ak=x_ak_tokenized["part_token_id"],
        pp_dict=pp_dict,
        batch_size=batch_size,
        pad_length=pad_length,
    )
    return x_ak_with_cuts, x_ak_with_cuts_reco


def jetclass_file_to_ak_arrays(
    jetclass_file_path: str,
    n_load: int,
    pp_dict: dict,
    pp_dict_cuts: dict,
    vqvae_model,
):
    """Read a JetClass file and return the original and the reconstructed particles.

    The particle features named by the keys of ``pp_dict`` are read, the cuts
    in ``pp_dict_cuts`` are applied, and only the first 128 particles of each
    jet are kept. The result is tokenized with ``vqvae_model`` and decoded
    back from the token ids.

    Returns:
        Tuple ``(x_ak_with_cuts, x_ak_with_cuts_reco)``.
    """
    particles, _, _ = read_jetclass_file(
        filepath=jetclass_file_path,
        particle_features=pp_dict.keys(),
        jet_features=None,
        labels=None,
        n_load=n_load,
    )

    # same selection as in the VQVAE training, limited to 128 particles per jet
    selected = ak_select_and_preprocess(particles, pp_dict_cuts)
    x_ak_with_cuts = selected[:, :128]

    # settings shared by the encoding and the decoding step
    codec_kwargs = {"pp_dict": pp_dict, "batch_size": 512, "pad_length": 128}

    tokenized = vqvae_model.tokenize_ak_array(ak_arr=x_ak_with_cuts, **codec_kwargs)
    token_ids = tokenized["part_token_id"]
    x_ak_with_cuts_reco = vqvae_model.reconstruct_ak_tokens(tokens_ak=token_ids, **codec_kwargs)

    return x_ak_with_cuts, x_ak_with_cuts_reco

In [None]:
# --- Restore the tokenizer model from its checkpoint and read the feature_dict from its config ---


# ckpt_path = "/data/dust/user/birkjosc/beegfs/datasets/jetclass_tokenized/2024-02-19_20-54-01_nonfissile_defect_a56f_all_types/model_ckpt.ckpt"
ckpt_path = "/home/laurits/ml-tau-en-reg/enreg/omnijet_alpha/checkpoints/vqvae_32000_tokens_p3_mass_pid/model_ckpt.ckpt"
vqvae_model = VQVAELightning.load_from_checkpoint(ckpt_path)
vqvae_model.eval()  # important: switch the model to evaluation mode

# the training config is expected next to the checkpoint file
cfg = OmegaConf.load(Path(ckpt_path).parent / "config.yaml")
pp_dict = OmegaConf.to_container(cfg.data.dataset_kwargs_common.feature_dict)
for feat_name, feat_cfg in pp_dict.items():
    print(feat_name, feat_cfg)

# keep only the selection bounds of every feature; a feature without any
# settings (None) gets both bounds set to None
pp_dict_cuts = {
    feat_name: {
        bound: None if feat_cfg is None else feat_cfg.get(bound, None)
        for bound in ("larger_than", "smaller_than")
    }
    for feat_name, feat_cfg in pp_dict.items()
}

for feat_name, feat_cuts in pp_dict_cuts.items():
    print(feat_name, feat_cuts)

In [None]:
# Read the Future-dataset parquet files and the JetClass ROOT files, and turn
# each of them into a pair of (original, reconstructed) awkward arrays for plotting

n_load = 1000
# positional arguments shared by all loader calls below
_loader_args = (n_load, pp_dict, pp_dict_cuts, vqvae_model)

data_dir = "/scratch/persistent/laurits/ml-tau/20240924_lowered_recoPtCut/"
zh_path = os.path.join(data_dir, "zh_train.parquet")
z_path = os.path.join(data_dir, "z_train.parquet")
qq_path = os.path.join(data_dir, "qq_train.parquet")

zh_ak_particles_original, zh_ak_particles_reco = future_file_to_ak_arrays(zh_path, *_loader_args)
z_ak_particles_original, z_ak_particles_reco = future_file_to_ak_arrays(z_path, *_loader_args)
qq_ak_particles_original, qq_ak_particles_reco = future_file_to_ak_arrays(qq_path, *_loader_args)

jetclass_file_path_qg = "/scratch/persistent/joosep/jetclass/test_20M/ZJetsToNuNu_100.root"
jetclass_file_path_tbqq = "/scratch/persistent/joosep/jetclass/test_20M/TTBar_100.root"

jetclass_qg_ak_particles_original, jetclass_qg_ak_particles_reco = jetclass_file_to_ak_arrays(
    jetclass_file_path_qg, *_loader_args
)
jetclass_tbqq_ak_particles_original, jetclass_tbqq_ak_particles_reco = jetclass_file_to_ak_arrays(
    jetclass_file_path_tbqq, *_loader_args
)

In [None]:
# check that each particle has exactly one pid
# NOTE: the variable name says "z" but the ZH sample is the one inspected here;
# the name is kept so that any other cell using it keeps working.
z_ak_particles_pid_sum = (
    zh_ak_particles_original.part_isElectron
    + zh_ak_particles_original.part_isMuon
    + zh_ak_particles_original.part_isPhoton
    + zh_ak_particles_original.part_isChargedHadron
    + zh_ak_particles_original.part_isNeutralHadron
)
# Test every particle individually: comparing per-jet sums with per-jet particle
# counts would let a particle with two flags cancel one with none, and comparing
# against n_load fails when the file holds fewer than n_load jets.
ak.all(z_ak_particles_pid_sum == 1)

In [None]:
# Plot the per-particle difference (reconstructed - original) of the kinematic
# features for the Future and JetClass samples.
labels_diff = {
    "part_pt": "Particle $p_T^\\mathrm{reco} - p_T^\\mathrm{original}$ [GeV]",
    "part_etarel": "Particle $\\Delta\\eta^\\mathrm{reco} - \\Delta\\eta^\\mathrm{original}$ ",
    "part_phirel": "Particle $\\Delta\\phi^\\mathrm{reco} - \\Delta\\phi^\\mathrm{original}$ ",
    # unit placed outside of the math environment, consistent with the pT label
    "part_mass": "Particle $m^\\mathrm{reco} - m^\\mathrm{original}$ [GeV]",
    # "part_charge": "Particle charge difference",
    # "part_isElectron": "isElectron difference",
    # "part_isMuon": "isMuon difference",
    # "part_isPhoton": "isPhoton difference",
    # "part_isChargedHadron": "isChargedHadron difference",
    # "part_isNeutralHadron": "isNeutralHadron difference",
}
fig, axarr = plot_features(
    ak_array_dict={
        "$\\mathrm{Fu}\\tau\\mathrm{ure}$ $ZH$": ak_subtract(
            zh_ak_particles_reco, zh_ak_particles_original
        ),
        # "$\\mathrm{Fu}\\tau\\mathrm{ure}$ $Z$": ak_subtract(
        #     z_ak_particles_reco, z_ak_particles_original
        # ),
        "$\\mathrm{Fu}\\tau\\mathrm{ure}$ $qq$": ak_subtract(
            qq_ak_particles_reco, qq_ak_particles_original
        ),
        "JetClass $q/g$": ak_subtract(
            jetclass_qg_ak_particles_reco, jetclass_qg_ak_particles_original
        ),
        "JetClass $t\\to bqq'$": ak_subtract(
            jetclass_tbqq_ak_particles_reco, jetclass_tbqq_ak_particles_original
        ),
    },
    names=labels_diff,
    ax_rows=2,
    # histogram binning of the difference, one entry per plotted feature
    bins_dict={
        "part_pt": np.linspace(-5, 5, 100),
        "part_etarel": np.linspace(-0.1, 0.1, 100),
        "part_phirel": np.linspace(-0.1, 0.1, 100),
        "part_mass": np.linspace(-0.05, 0.05, 100),
        # "part_charge": np.linspace(-1.5, 1.5, 4),
        # "part_isElectron": np.linspace(-1.5, 1.5, 4),
        # "part_isMuon": np.linspace(-1.5, 1.5, 4),
        # "part_isPhoton": np.linspace(-1.5, 1.5, 4),
        # "part_isChargedHadron": np.linspace(-1.5, 1.5, 4),
        # "part_isNeutralHadron": np.linspace(-1.5, 1.5, 4),
    },
    legend_kwargs={
        "loc": "upper left",
        "ncol": 2,
        "fontsize": 10,
    },
    ax_size=(4.2, 2),
    decorate_ax_kwargs={"yscale": 2.1},
    ylabel="Normalized",
    # legend_only_on=0,
)
fig.savefig('/home/laurits/tmp/feature_differences.pdf')
fig


In [None]:
from sklearn import metrics
import matplotlib.pyplot as plt
import mplhep as hep


def visualize_confusion_matrix(
        histogram: np.ndarray,
        categories: list,
        cmap: str = "Greys",
        bin_text_color: str = "r",
        y_label: str = "Predicted decay modes",
        x_label: str = "True decay modes",
        figsize: tuple = (12, 12),
):
    """Plots the confusion matrix for the classification task.

    The first index of ``histogram`` runs along the x-axis and the second one
    along the y-axis, so a matrix from ``sklearn.metrics.confusion_matrix``
    (rows = truth, columns = prediction) is drawn with the truth on the x-axis.

    Args:
        histogram : np.ndarray
            Histogram produced by the sklearn.metrics.confusion_matrix.
        categories : list
            Category labels in the correct order.
        cmap : str
            [default: "Greys"] The colormap to be used. It is forwarded to
            the 2D histogram plot; ``None`` selects the matplotlib default.
        bin_text_color : str
            [default: "r"] The color of the text on bins.
        y_label : str
            [default: "Predicted decay modes"] The label for the y-axis.
        x_label : str
            [default: "True decay modes"] The label for the x-axis.
        figsize : tuple
            [default: (12, 12)] The size of the figure drawn.

    Returns:
        The created ``(fig, ax)`` pair.
    """
    fig, ax = plt.subplots(figsize=figsize)
    n_categories = len(categories)
    xbins = ybins = np.arange(n_categories + 1)
    tick_values = np.arange(n_categories) + 0.5
    # `cmap` was previously accepted but never used; forward it to the plot
    hep.hist2dplot(histogram, xbins, ybins, flow=None, ax=ax, cbar=False, cmap=cmap)
    # use the explicit axes interface instead of the pyplot state machine
    ax.set_xticks(tick_values)
    ax.set_xticklabels(categories, fontsize=14, rotation=0)
    # the small offset visually centers the rotated y tick labels on their bin
    ax.set_yticks(tick_values + 0.2)
    ax.set_yticklabels(categories, fontsize=14, rotation=90, va="center")
    ax.set_xlabel(x_label, fontdict={"size": 14})
    ax.set_ylabel(y_label, fontdict={"size": 14})
    ax.tick_params(axis="both", which="both", length=0)
    # write the value of every bin at the center of the bin
    for y_idx in range(n_categories):
        for x_idx in range(n_categories):
            bin_value = histogram[x_idx, y_idx]
            ax.text(
                float(xbins[x_idx] + 0.5),
                float(ybins[y_idx] + 0.5),
                f"{bin_value:.2f}",
                color=bin_text_color,
                ha="center",
                va="center",
                fontweight="bold",
            )
    return fig, ax

def get_pid_classes(original_particles, reco_particles):
    """Return the PID class index of every particle, for the original and the reco particles.

    For each particle the five PID fields are placed side by side in the order
    electron, muon, photon, charged hadron, neutral hadron, and the position
    of the largest value is taken as the class index.

    Returns:
        Tuple ``(original_classes, reco_classes)`` of flat (one entry per
        particle) class-index arrays.
    """
    pid_fields = (
        "part_isElectron",
        "part_isMuon",
        "part_isPhoton",
        "part_isChargedHadron",
        "part_isNeutralHadron",
    )

    def _class_index(particles):
        # one length-1 list per particle and field, joined into 5 columns per particle
        columns = [
            ak.unflatten(ak.flatten(getattr(particles, field)), counts=1)
            for field in pid_fields
        ]
        scores = ak.concatenate(columns, axis=-1)
        return ak.argmax(scores, axis=-1)

    reco_classes = _class_index(reco_particles)
    original_classes = _class_index(original_particles)
    return original_classes, reco_classes



# Tau dataset

In [None]:
ft_charges_original = ak.flatten(zh_ak_particles_original.part_charge).to_numpy()
# The reconstructed charge is continuous: round it to the nearest integer and
# clamp it on BOTH sides to the range [-1, +1] covered by the categories below.
ft_charges_reco = np.clip(
    ak.round(ak.flatten(zh_ak_particles_reco.part_charge)).to_numpy(), -1, 1
)

# Passing the labels explicitly keeps the matrix 3x3 and aligned with
# `categories`, even if one of the charges does not occur in the sample.
normalized_confusion_matrix = metrics.confusion_matrix(
    ft_charges_original, ft_charges_reco, labels=[-1, 0, 1], normalize="pred"
)
fig, ax = visualize_confusion_matrix(
    histogram=normalized_confusion_matrix,
    categories=["-1", "0", "+1"],
    cmap=None,
    bin_text_color="r",
    y_label="Predicted charge",
    x_label="True charge",
    figsize=(5, 5)
)
fig.savefig('/home/laurits/tmp/ft_charges.pdf')
fig

In [None]:
original_classes_ft, reco_classes_ft = get_pid_classes(zh_ak_particles_original, zh_ak_particles_reco)
# Passing all five class indices explicitly keeps the matrix 5x5 and aligned
# with `categories`, even if a class never occurs in the sample.
normalized_confusion_matrix = metrics.confusion_matrix(
    original_classes_ft, reco_classes_ft, labels=[0, 1, 2, 3, 4], normalize="pred"
)
fig, ax = visualize_confusion_matrix(
    histogram=normalized_confusion_matrix,
    categories=[r"$e^{\pm}$",r"$\mu^{\pm}$", r"$\gamma$", r"$h^{\pm}$", r"$h^{0}$"],
    cmap=None,
    bin_text_color="r",
    y_label="Predicted PID",
    x_label="True PID",
    figsize=(5, 5)
)
fig.savefig('/home/laurits/tmp/ft_pid.pdf')
fig

# JetClass dataset

In [None]:
jc_charges_original = ak.flatten(jetclass_tbqq_ak_particles_original.part_charge).to_numpy()
# The reconstructed charge is continuous: round it to the nearest integer and
# clamp it on BOTH sides to the range [-1, +1] (previously only values below
# -1 were clamped, so a value rounding to +2 would create a fourth class).
jc_charges_reco = np.clip(
    ak.round(ak.flatten(jetclass_tbqq_ak_particles_reco.part_charge)).to_numpy(), -1, 1
)

In [None]:
# Passing the labels explicitly keeps the matrix 3x3 and aligned with
# `categories`, even if one of the charges does not occur in the sample.
normalized_confusion_matrix = metrics.confusion_matrix(
    jc_charges_original, jc_charges_reco, labels=[-1, 0, 1], normalize="pred"
)
fig, ax = visualize_confusion_matrix(
    histogram=normalized_confusion_matrix,
    categories=["-1", "0", "+1"],
    cmap=None,
    bin_text_color="r",
    y_label="Predicted charge",
    x_label="True charge",
    figsize=(5, 5)
)
fig.savefig('/home/laurits/tmp/jc_charges.pdf')
fig

In [None]:
original_classes_jc, reco_classes_jc = get_pid_classes(jetclass_tbqq_ak_particles_original, jetclass_tbqq_ak_particles_reco)
# Passing all five class indices explicitly keeps the matrix 5x5 and aligned
# with `categories`, even if a class never occurs in the sample.
jc_normalized_confusion_matrix = metrics.confusion_matrix(
    original_classes_jc, reco_classes_jc, labels=[0, 1, 2, 3, 4], normalize="pred"
)
fig, ax = visualize_confusion_matrix(
    histogram=jc_normalized_confusion_matrix,
    categories=[r"$e^{\pm}$",r"$\mu^{\pm}$", r"$\gamma$", r"$h^{\pm}$", r"$h^{0}$"],
    cmap=None,
    bin_text_color="r",
    y_label="Predicted PID",
    x_label="True PID",
    figsize=(5, 5)
)
fig.savefig('/home/laurits/tmp/jc_pid.pdf')
fig

# Plot plain distributions of features

In [None]:
import matplotlib.pyplot as plt
# Particle-mass distribution of the ZH sample: reconstructed vs. original
f, ax = plt.subplots(figsize=(16, 9))
zh_mass_reco = ak.flatten(zh_ak_particles_reco.part_mass)
zh_mass_original = ak.flatten(zh_ak_particles_original.part_mass)
ax.hist(zh_mass_reco, density=True, label="Reco")
ax.hist(zh_mass_original, density=True, label="Original")
ax.legend()
ax.set_xscale('symlog')
f

In [None]:
# Particle-mass distribution of the qq sample: reconstructed vs. original
f, ax = plt.subplots(figsize=(16, 9))
qq_mass_reco = ak.flatten(qq_ak_particles_reco.part_mass)
qq_mass_original = ak.flatten(qq_ak_particles_original.part_mass)
ax.hist(qq_mass_reco, density=True, label="Reco")
ax.hist(qq_mass_original, density=True, label="Original")
ax.legend()
ax.set_xscale('symlog')
f


In [None]:
# Number of particles in the JetClass q/g sample with a negative mass
negative_mass_mask = jetclass_qg_ak_particles_original.part_mass < 0
ak.sum(negative_mass_mask)

In [None]:
# Particle-mass distribution of the JetClass q/g sample: reconstructed vs. original
f, ax = plt.subplots(figsize=(16,9))
ax.hist(ak.flatten(jetclass_qg_ak_particles_reco.part_mass), density=True, label="Reco")
# `nii` was never defined anywhere in the notebook (NameError on a fresh
# kernel); the original q/g particles are the counterpart of the reco array.
ax.hist(ak.flatten(jetclass_qg_ak_particles_original.part_mass), density=True, label="Original")
plt.legend()
plt.xscale('symlog')
f

In [None]:
# Particle-mass distribution of the JetClass t->bqq' sample: reconstructed vs. original
f, ax = plt.subplots(figsize=(16, 9))
tbqq_mass_reco = ak.flatten(jetclass_tbqq_ak_particles_reco.part_mass)
tbqq_mass_original = ak.flatten(jetclass_tbqq_ak_particles_original.part_mass)
ax.hist(tbqq_mass_reco, density=True, label="Reco")
ax.hist(tbqq_mass_original, density=True, label="Original")
ax.legend()
ax.set_xscale('symlog')
f