In [None]:
import os
import tqdm
import glob
import numpy as np
import mplhep as hep
import awkward as ak
import boost_histogram as bh
import matplotlib.pyplot as plt
from enreg.tools.models import HPS
from enreg.tools import general as g
from enreg.tools.metrics import regression_evaluator as re
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the Hydra "benchmarking" configuration from the ../enreg/config/ directory.
# `cfg` is read further down for the pT bin edges.
with initialize(version_base=None, config_path="../enreg/config/", job_name="test_app"):
    cfg = compose(config_name="benchmarking")

# Use the CMS style preset from mplhep for the figures in this notebook.
hep.style.use(hep.styles.CMS)

In [None]:
# Load the ZH test sample.
# NOTE(review): absolute, machine-specific path — must be adjusted to run this notebook elsewhere.
ML_data = g.load_all_data("/scratch/persistent/laurits/ml-tau/20240924_lowered_recoPtCut/zh_test.parquet")

In [None]:
# pT of the gen_jet_tau_p4s four-vectors after passing them through g.reinitialize_p4
# (presumably this rebuilds them as vector objects exposing `.pt` — confirm in enreg.tools.general).
gen_pt = g.reinitialize_p4(ML_data.gen_jet_tau_p4s).pt

In [None]:
# Per-entry decay-mode code derived from gen_jet_tau_decaymode by g.get_reduced_decaymodes
# (the exact reduction is defined in enreg.tools.general, not in this notebook).
DMs = g.get_reduced_decaymodes(ML_data.gen_jet_tau_decaymode)

In [None]:
# pT bin edges for the ZH sample, taken from the regression ratio-plot section of the config.
bin_edges = cfg.metrics.regression.ratio_plot.bin_edges.zh

In [None]:
# Bin index of every gen-tau pT. For increasing edges np.digitize returns 0 for values below
# the first edge and len(bin_edges) for values at or above the last edge, so the in-range
# bins are numbered 1 .. len(bin_edges) - 1.
binned_gen_tau_pt = np.digitize(gen_pt, bins=np.array(bin_edges))

In [None]:
# Split the decay modes by gen-tau pT bin: one entry per in-range bin, in bin order.
# Indices 0 and len(bin_edges) (under-/overflow from np.digitize) are deliberately not visited.
binned_dms = []
for bin_idx in range(1, len(bin_edges)):
    in_bin = binned_gen_tau_pt == bin_idx
    binned_dms.append(DMs[in_bin])

In [None]:
def calculate_bin_centers(edges: np.ndarray) -> tuple:
    """Compute the center and half-width of every bin described by `edges`.

    Args:
        edges: 1-D sequence of N bin edges (list, config list or array).

    Returns:
        A tuple ``(bin_centers, bin_half_widths)`` of two float arrays with
        N - 1 entries each. Note that the second array holds *half* the bin
        width (suitable as a symmetric x-error); multiply by 2 to get the full
        width. Both arrays are empty when fewer than two edges are given.
    """
    # list() first so that any sequence-like container converts cleanly.
    edges = np.asarray(list(edges), dtype=float)
    half_widths = np.diff(edges) / 2
    # Center = lower edge + half of the bin width.
    centers = edges[:-1] + half_widths
    return centers, half_widths

# NOTE: despite its name, `bin_widths` holds the *half*-widths returned by
# calculate_bin_centers; the bar plot below multiplies it by 2 to get full widths.
bin_centers, bin_widths = calculate_bin_centers(bin_edges)

In [None]:
# Relative fraction of every decay mode inside each gen-tau pT bin.
# dm_fracs[dm][i] is the share of decay mode `dm` among the entries of pT bin i.
#
# The distinct decay modes are sorted so that every later loop over `all_dms`
# (stacking order, legend order, printout) is deterministic; plain set iteration
# order is not guaranteed.
all_dms = sorted(set(DMs))
dm_fracs = {dm: [] for dm in all_dms}
for bin_values in binned_dms:
    total = len(bin_values)
    for dm in all_dms:
        # Vectorised count instead of Python-level sum() over every element.
        n_matching = np.count_nonzero(np.asarray(bin_values == dm))
        # An empty pT bin would raise ZeroDivisionError; record a fraction of 0.0 instead.
        dm_fracs[dm].append(n_matching / total if total > 0 else 0.0)
    

In [None]:
# Show the distinct decay-mode codes found in DMs (each needs an entry in COLORS below).
all_dms

In [None]:
# Bar colour per decay-mode code. The plotting cell looks up COLORS[dm] for every
# value in all_dms, so a decay mode missing from this mapping raises KeyError there.
COLORS = {
    0: "magenta",
    1: "orange",
    2: "blue",
    10: "green",
    11: "gray",
    15: "red"
}

In [None]:
# Stacked bar chart of the decay-mode composition per gen-tau pT bin:
# one coloured layer per decay mode, each stacked on top of the previous ones.
fig, ax = plt.subplots()
bottom = np.zeros_like(bin_centers)
for dm in all_dms:
    layer = dm_fracs[dm]
    ax.bar(
        bin_centers,
        layer,
        width=2*bin_widths,  # bin_widths holds half-widths, hence the factor 2
        bottom=bottom,
        align='center',
        color=COLORS[dm],
        label=f"DM{dm}",
    )
    # Raise the baseline so the next decay mode is drawn on top of this one.
    bottom += layer
ax.legend(frameon=True)
ax.set_xlabel(r"$p_T^{gen-\tau}$")
ax.set_ylabel("Relative fraction")
ax.set_xlim(0, 175)
ax.set_ylim(0, 1)
fig.savefig("/home/laurits/dm_vs_pTbin.pdf", format='pdf', bbox_inches='tight')

In [None]:
# Overall share of each decay mode in the full sample, scaled by a constant 0.648.
# NOTE(review): 0.648 is presumably the hadronic tau branching fraction — confirm.
n_entries = len(DMs)
for dm in all_dms:
    scaled_share = (sum(DMs == dm) / n_entries) * 0.648
    print(f"DM{dm}", f"{scaled_share:.4f}")