In [None]:
from __future__ import annotations

import math
import os
from os import listdir
from os.path import exists

import matplotlib.pyplot as plt
import mplhep as hep
import numpy as np
import pandas as pd
from hist import Hist
from hist.intervals import clopper_pearson_interval

# Global matplotlib styling for every figure in this notebook.
# NOTE(review): the CMS style is applied *after* the font-size override, so if
# hep.style.CMS defines "font.size" it will replace the 11 set here — TODO confirm
# which size is intended and swap the two lines if the override should win.
plt.rcParams.update({"font.size": 11})
plt.style.use(hep.style.CMS)

In [None]:
def check_selector(sample: str, selector: str | list[str]):
    if isinstance(selector, (list, tuple)):
        for s in selector:
            if s.startswith("*"):
                if s[1:] in sample:
                    return True
            else:
                if sample.startswith(s):
                    return True
    else:
        if selector.startswith("*"):
            if selector[1:] in sample:
                return True
        else:
            if sample.startswith(selector):
                return True

    return False

In [None]:
# Era to process, and the luminosity string passed as `lumi=` to `hep.cms.label` in the plots.
# year = "2022"
# yearlabel = "7.97"
# year = "2022EE-noE"
# yearlabel = "20"
year = "2022EE"
yearlabel = "26.3"

# data_dir = f"/eos/uscms/store/user/cmantill/bbbb/trigger_boosted/Aug15/"
# data_dir = f"/eos/uscms/store/user/cmantill/bbbb/trigger_boosted/Nov7_v12/"
data_dir = "/eos/uscms/store/user/cmantill/bbbb/trigger_boosted/23Nov9_v11_v11/"

# Per era: process label -> sample-directory selectors (prefixes, or "*substring" patterns,
# as interpreted by check_selector).
samples = {
    "2022EE": {
        "data": ["Run2022E", "Run2022F", "Run2022G"],
        "ttbar": ["TTtoLNu2Q"],
    },
    "2022EE-noE": {
        "data": ["Run2022F", "Run2022G"],
        "ttbar": ["TTtoLNu2Q"],
    },
    "2022": {
        "data": ["Run2022C_single", "Run2022C", "Run2022D"],
        "ttbar": ["TTtoLNu2Q"],
    },
}[year]

# "2022EE-noE" only drops Run2022E from the selection; its inputs live in the 2022EE directory.
y = "2022EE" if year == "2022EE-noE" else year

# Sorted so the loading order (and the log below) is reproducible; listdir order is arbitrary.
full_samples_list = sorted(listdir(f"{data_dir}/{y}"))
events_dict = {}
for label, selector in samples.items():
    print(f"{label}: {selector}")
    frames = []
    for sample in full_samples_list:
        if not check_selector(sample, selector):
            continue
        parquet_path = f"{data_dir}/{y}/{sample}/parquet"
        if not exists(parquet_path):
            print(f"No parquet file for {sample}")
            continue

        events = pd.read_parquet(parquet_path, columns=None)
        # Empty inputs are logged but not kept.
        if len(events) > 0:
            frames.append(events)

        print(f"Loaded {sample: <50}: {len(events)} entries")

    if frames:
        # Every file restarts its row index at 0; ignore_index gives the merged frame a
        # unique index so later boolean masks never align against duplicated labels.
        events_dict[label] = pd.concat(frames, ignore_index=True)
    else:
        # Labels with no events are left out of events_dict.
        print(f"No events loaded for {label}; skipping")

In [None]:
# List every column (branch) available in the merged data frame.
list(events_dict["data"].columns)

In [None]:
# Peek at the resolved (QuadPFJet) trigger decision column in the ttbar sample.
events_dict["ttbar"]["QuadPFJet70_50_40_35_PFBTagParticleNet_2BTagSum0p65"]

In [None]:
# Samples to measure efficiencies in, mapped to their legend labels.
to_loop = {
    "data": f"Data {year}",
    "ttbar": r"$t\bar{t}$ (semi-lep)",
}

# Binning: tuples are (n_bins, low, high) for regular axes; lists are variable edges.
mreg_bins_fine = (15, 45, 300)
mreg_bins = [45, 60, 75, 90, 105, 120, 135, 150, 165, 180]
msd_bins_fine = (20, 30, 250)
msd_bins = mreg_bins

ht_bins_fine = (25, 200, 3000)

pt_bins_fine = (25, 200, 1000)
pt_bins = [250, 275, 300, 325, 350, 375, 400, 425, 450, 475, 500, 550, 600]

xbb_bins = [0.0, 0.8, 0.9, 0.95, 0.98, 1.0]
xbb_bins_fine = (20, 0, 1)
arctanh_xbb_bins_fine = (20, 0, 4.5)

# Empty histogram templates; the fill loop copies them for each numerator/denominator.
hmreg = Hist.new.Var(mreg_bins, name="jet0mreg", label="fj$^0$ $m_{reg}$ (GeV)").Double()
hmsd = Hist.new.Reg(*msd_bins_fine, name="jet0msd", label="fj$^0$ $m_{SD}$ (GeV)").Double()
hpt = Hist.new.Var(pt_bins, name="jet0pt", label="fj$^0$ $p_T$ (GeV)").Double()
hptmsd = (
    Hist.new.Var(pt_bins, name="jet0pt", label="fj$^0$ $p_T$ (GeV)")
    .Var(msd_bins, name="jet0msd", label="fj$^0$ $m_{SD}$ (GeV)")
    .Double()
)
hht = Hist.new.Reg(*ht_bins_fine, name="ht", label="HT (GeV)").Double()
hxbb = (
    Hist.new.Reg(*xbb_bins_fine, name="jet0txbb", label="fj$^0$ $T_{Xbb}$ Score")
    .Reg(*arctanh_xbb_bins_fine, name="jet0txbbtan", label="fj$^0$ atanh($T_{Xbb}$ Score)")
    .Double()
)

# trigger group key -> ([HLT path column names, OR-ed together], legend label).
# Legend labels abbreviate the HLT path names (e.g. AK8PFJet250_SoftDropMass40_... -> PFJet250_MSD40_...).
trigger_dict = {
    "Resolved": (
        [
            "QuadPFJet70_50_40_35_PFBTagParticleNet_2BTagSum0p65",
        ],
        "QuadPFJet",
    ),
    "HT": (["PFHT1050"], "HT1050"),
    "BoostedJet": (
        [
            "AK8PFJet425_SoftDropMass40",
        ],
        "PFJet425_MSD40",
    ),
    "BoostedDiJet": (
        [
            "AK8DiPFJet250_250_MassSD50",
            "AK8DiPFJet260_260_MassSD30",
        ],
        "DiPFJet250-MSD50 |\n DiPFJet260-MSD30",
    ),
    "BoostedHbb": (
        [
            "AK8PFJet250_SoftDropMass40_PFAK8ParticleNetBB0p35",
        ],
        "PFJet250_MSD40_Xbb0p35",
    ),
    "Combined": (
        [
            "AK8PFJet250_SoftDropMass40_PFAK8ParticleNetBB0p35",
            "AK8PFJet425_SoftDropMass40",
            "AK8DiPFJet250_250_MassSD50",
            "AK8DiPFJet260_260_MassSD30",
        ],
        "PFJet250_MSD40_Xbb0p35 |\n PFJet425_MSD40 |\n DiPFJet250-MSD50 |\n DiPFJet260-MSD30",
    ),
    "Combined_noquad": (
        [
            "PFHT1050",
            "AK8DiPFJet250_250_MassSD50",
            "AK8DiPFJet260_260_MassSD30",
            "AK8PFJet250_SoftDropMass40_PFAK8ParticleNetBB0p35",
            "AK8PFJet425_SoftDropMass40",
        ],
        "PFJet250_MSD40_Xbb0p35 |\n PFJet425_MSD40 |\n DiPFJet250-MSD50 |\n DiPFJet260-MSD30 |\n HT1050",
    ),
    "Combined_all": (
        [
            "QuadPFJet70_50_40_35_PFBTagParticleNet_2BTagSum0p65",
            "PFHT1050",
            "AK8DiPFJet250_250_MassSD50",
            "AK8DiPFJet260_260_MassSD30",
            "AK8PFJet250_SoftDropMass40_PFAK8ParticleNetBB0p35",
            "AK8PFJet425_SoftDropMass40",
        ],
        "PFJet250_MSD40_Xbb0p35 |\n PFJet425_MSD40 |\n DiPFJet250-MSD50 |\n DiPFJet260-MSD30 |\n HT1050 |\n QuadPFJet",
    ),
    "Combined_ht": (
        [
            "PFHT1050",
            "AK8PFJet250_SoftDropMass40_PFAK8ParticleNetBB0p35",
            "AK8PFJet425_SoftDropMass40",
        ],
        "PFJet250_MSD40_Xbb0p35 |\n PFJet425_MSD40 |\n HT1050",
    ),
    "Combined_nodijet": (
        [
            "AK8PFJet250_SoftDropMass40_PFAK8ParticleNetBB0p35",
            "AK8PFJet425_SoftDropMass40",
        ],
        "PFJet250_MSD40_Xbb0p35 |\n PFJet425_MSD40",
    ),
}

# Each entry: (selection key, legend label, figure-name suffix, histogram template, axis to project).
configurations = [
    (
        "one_jet",  # selection
        r"$p_T^0>250$ & mass$^0>40$",  # label
        "ptmsd_onejet",  # label to save figure
        "hptmsd",  # histogram to fill
        "jet0pt",  # variable to project
    ),
]
add_ht = True
if add_ht:
    configurations.extend(
        [
            (
                "one_jet",
                r"$p_T^0>250$ & mass$^0>40$",
                "ht_onejet",
                "hht",
                None,
            ),
            (
                "one_jet_08",
                r"$p_T^0>250$ & mass$^0>40$ & Xbb>0.8",
                "ht_onejet_xbb08",
                "hht",
                None,
            ),
            (
                "one_jet_09",
                r"$p_T^0>400$ & Xbb>0.9",
                "msd_onejet_xbb09",
                "hmsd",
                None,
            ),
            (
                "one_jet_300_09",
                r"$p_T^0>300$ & Xbb>0.9",
                "msd_onejet_xbb09_pt300",
                "hmsd",
                None,
            ),
        ]
    )

# for triggers that include Xbb in leg
configurations_bb = [
    # vary Xbb cut; these four fill the 2D p_T vs m_SD map (there is no p_T vs m_reg histogram)
    (
        "one_jet_xbb08",  # selection
        r"pre-sel & Xbb<0.8",  # label
        "jet0_onejet",  # label to save figure
        "hptmsd",  # histogram to fill
        "jet0pt",  # variable to project
    ),
    (
        "one_jet_xbb08-095",
        r"pre-sel & Xbb[0.8,0.95]",
        "jet0_xbb08-095",
        "hptmsd",
        "jet0pt",
    ),
    (
        "one_jet_xbb095-098",
        r"pre-sel & Xbb[0.95,0.98]",
        "jet0_xbb095-098",
        "hptmsd",
        "jet0pt",
    ),
    (
        "one_jet_xbb098",
        r"pre-sel & Xbb[0.98,1]",
        "jet0_xbb098",
        "hptmsd",
        "jet0pt",
    ),
    (
        "one_jet_xbb08-msd60",  # selection
        r"pre-sel & Xbb<0.8 & $m^0>60$",  # label
        "jet0_onejetxbb08_pt",  # label to save figure
        "hpt",  # histogram to fill
        None,  # variable to project
    ),
    (
        "one_jet_xbb08-09-msd60",
        r"pre-sel & Xbb[0.8,0.9] & $m^0>60$",
        "jet0_xbb08-09_pt",
        "hpt",
        None,
    ),
    (
        "one_jet_xbb09-095-msd60",
        r"pre-sel & Xbb[0.9,0.95] & $m^0>60$",
        "jet0_xbb09-095_pt",
        "hpt",
        None,
    ),
    (
        "one_jet_xbb095-098-msd60",
        r"pre-sel & Xbb[0.95,0.98] & $m^0>60$",
        "jet0_xbb095-098_pt",
        "hpt",
        None,
    ),
    (
        "one_jet_xbb098-msd60",
        r"pre-sel & Xbb[0.98,1] & $m^0>60$",
        "jet0_xbb098_pt",
        "hpt",
        None,
    ),
    (
        "pt400msd60",
        r"$p_T^0>400 & m_{SD}^0>60$",
        "jet0txbb_pt400msd60",
        "hxbb",
        "jet0txbb",
    ),
    (
        "pt400msd60",
        r"$p_T^0>400 & m_{SD}^0>60$",
        "jet0txbbtan_pt400msd60",
        "hxbb",
        "jet0txbbtan",
    ),
]

# trigger_info[sample][trigger legend label] -> list of (num, den, legend label, title) 1D entries;
# trigger_info_2d holds the same for 2D maps. Titles are "<sample label>_<trigger key>_<figure suffix>".
trigger_info = {}
trigger_info_2d = {}
for key, ev_label in to_loop.items():
    trigger_info[key] = {}
    trigger_info_2d[key] = {}
    print(key)

    events = events_dict[key]
    # Index 0 of each AK8 jet branch ("fj^0" in the axis labels).
    xbb_0 = events["ak8FatJetPNetXbb"][0]
    pt_0 = events["ak8FatJetPt"][0]
    msd_0 = events["ak8FatJetMsd"][0]
    mreg_0 = events["ak8FatJetPNetMass"][0]
    if add_ht:
        ht = events["ht"][0]

    # A jet passes the mass requirement if either its soft-drop or its regressed mass does.
    msd_cut = (msd_0 > 40) | (mreg_0 > 40)
    m60_cut = (msd_0 > 60) | (mreg_0 > 60)
    # Cuts must agree with the Xbb ranges quoted in the configuration labels.
    selection_dict = {
        "one_jet": (pt_0 > 250) & (msd_cut),
        "one_jet_09": (pt_0 > 400) & (xbb_0 > 0.9),
        "one_jet_300_09": (pt_0 > 300) & (xbb_0 > 0.9),
        "pt400msd60": (pt_0 > 400) & (msd_0 > 60),
        "one_jet_xbb08": (pt_0 > 250) & (msd_cut) & (xbb_0 < 0.8),
        "one_jet_xbb08-095": (pt_0 > 250) & (msd_cut) & (xbb_0 > 0.8) & (xbb_0 < 0.95),
        "one_jet_xbb095-098": (pt_0 > 250) & (msd_cut) & (xbb_0 > 0.95) & (xbb_0 < 0.98),
        "one_jet_xbb098": (pt_0 > 250) & (msd_cut) & (xbb_0 > 0.98),
        "one_jet_xbb08-msd60": (pt_0 > 250) & (m60_cut) & (xbb_0 < 0.8),
        "one_jet_xbb08-095-msd60": (pt_0 > 250) & (m60_cut) & (xbb_0 > 0.8) & (xbb_0 < 0.95),
        "one_jet_xbb08-09-msd60": (pt_0 > 250) & (m60_cut) & (xbb_0 > 0.8) & (xbb_0 < 0.9),
        "one_jet_xbb09-095-msd60": (pt_0 > 250) & (m60_cut) & (xbb_0 > 0.9) & (xbb_0 < 0.95),
        "one_jet_xbb095-098-msd60": (pt_0 > 250) & (m60_cut) & (xbb_0 > 0.95) & (xbb_0 < 0.98),
        "one_jet_xbb098-msd60": (pt_0 > 250) & (m60_cut) & (xbb_0 > 0.98),
        "one_jet_08": (pt_0 > 250) & (msd_cut) & (xbb_0 > 0.8),
    }

    # 1D histograms that are filled identically: tag -> (template, axis name, per-event values).
    hists_1d = {
        "hmsd": (hmsd, "jet0msd", msd_0),
        "hpt": (hpt, "jet0pt", pt_0),
    }
    if add_ht:
        hists_1d["hht"] = (hht, "ht", ht)

    for trigger_title, (triggers, trigger_label) in trigger_dict.items():
        title = f"{ev_label}_{trigger_title}"

        trigger_info[key][trigger_label] = []
        trigger_info_2d[key][trigger_label] = []

        # OR of all HLT paths in the group. It does not depend on the offline selection,
        # so it is computed once per trigger group rather than once per configuration.
        trigger_selection = np.zeros(len(events), dtype=bool)
        for hlt in triggers:
            trigger_selection |= (events[hlt].values == 1).squeeze()

        # The Xbb-binned configurations only apply to triggers with an Xbb leg.
        if "BoostedHbb" in trigger_title or "Combined" in trigger_title:
            selections_to_loop = configurations + configurations_bb
        else:
            selections_to_loop = configurations

        print(trigger_title, [config[0] for config in selections_to_loop])

        for sel, label, lab, h_toclone, to_project in selections_to_loop:
            selection = selection_dict[sel]
            num_selection = selection & trigger_selection
            entry_title = f"{title}_{lab}"

            if h_toclone == "hxbb":
                den = hxbb.copy().fill(
                    jet0txbb=xbb_0[selection],
                    jet0txbbtan=np.arctanh(xbb_0[selection]),
                )
                num = hxbb.copy().fill(
                    jet0txbb=xbb_0[num_selection],
                    jet0txbbtan=np.arctanh(xbb_0[num_selection]),
                )
                # Keep only the requested Xbb representation (raw score or atanh).
                trigger_info[key][trigger_label].append(
                    (num.project(to_project), den.project(to_project), label, entry_title)
                )
            elif h_toclone in hists_1d:
                template, axis_name, values = hists_1d[h_toclone]
                den = template.copy().fill(**{axis_name: values[selection]})
                num = template.copy().fill(**{axis_name: values[num_selection]})
                trigger_info[key][trigger_label].append((num, den, label, entry_title))
            else:
                # 2D p_T vs m_SD map ("hptmsd", and any other unrecognised tag).
                den = hptmsd.copy().fill(
                    jet0pt=pt_0[selection],
                    jet0msd=msd_0[selection],
                )
                num = hptmsd.copy().fill(
                    jet0pt=pt_0[num_selection],
                    jet0msd=msd_0[num_selection],
                )
                trigger_info_2d[key][trigger_label].append((num, den, label, entry_title))

## Plot 1D

In [None]:
# One figure per (trigger group, selection) entry, overlaying every sample in trigger_info.
for trigger_label, entries in trigger_info["data"].items():
    for i in range(len(entries)):
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        for key in trigger_info:
            numerator, denominator, label, title = trigger_info[key][trigger_label][i]
            num_vals = numerator.view()
            den_vals = denominator.view()
            hep.histplot(
                numerator / denominator,
                # yerr is the distance from the efficiency to each Clopper-Pearson bound
                yerr=abs(clopper_pearson_interval(num_vals, den_vals) - num_vals / den_vals),
                ax=ax,
                label=to_loop[key],
                histtype="errorbar",
                capsize=4,
                elinewidth=1,
            )
        hep.cms.label(r"", data=True, lumi=yearlabel, year=year, fontsize=14)
        leg = ax.legend(loc="lower right", fontsize=10)
        leg.set_title(trigger_label + f"\n{label}", prop={"size": 10})
        leg.get_title().set_multialignment("center")
        ax.set_ylabel("Efficiency")
        ax.axhline(y=1.0, color="k", linestyle="dashdot")
        ax.set_ylim(0, 1.1)

        # title is "<sample label>_<trigger key>_<figure suffix>": the second "_"-token picks
        # the output sub-directory and everything after it names the file.
        idir = title.split("_")[1]
        fig_name = "_".join(title.split("_")[2:])
        fig.tight_layout()
        out_dir = f"trigger_plots/{year}/{idir}"
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f"{out_dir}/1d_{fig_name}.png")
        plt.close(fig)
        # plt.show()

## Plot 2D

In [None]:
def plot_2d(ax, numerator, denominator, label, title):
    """Draw the 2D efficiency map ``numerator / denominator`` on ``ax``.

    Each bin whose efficiency is not NaN is annotated with its value rounded to
    two decimals. The CMS label uses the notebook-level ``year`` and ``yearlabel``,
    and the axes title is ``title`` with ``label`` on a second line.
    """
    ratio = numerator / denominator
    eff, bins_x, bins_y = ratio.to_numpy()

    artists = hep.hist2dplot(ratio, ax=ax, cmin=0.1, cmax=1)
    artists.cbar.set_label(r"Efficiency", size=18)
    artists.cbar.ax.get_yaxis().labelpad = 15

    # Bin centres along each axis; text is placed at the centre of each cell.
    centers_x = (bins_x[:-1] + bins_x[1:]) / 2
    centers_y = (bins_y[:-1] + bins_y[1:]) / 2
    for ix, xc in enumerate(centers_x):
        for iy, yc in enumerate(centers_y):
            value = eff[ix, iy]
            if math.isnan(value):
                continue
            ax.text(
                xc,
                yc,
                value.round(2),
                color="black",
                ha="center",
                va="center",
                fontsize=12,
            )

    hep.cms.label(r"", data=True, lumi=yearlabel, year=year, fontsize=14)
    ax.set_title(f"{title}\n{label}", fontsize=10, y=1.0, pad=25)


# One 2D efficiency map per (sample, trigger group, selection) entry.
for key, per_trigger in trigger_info_2d.items():
    for trigger_label, info in per_trigger.items():
        for numerator, denominator, label, title in info:
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
            plot_2d(
                ax,
                numerator,
                denominator,
                label,
                trigger_label,
            )
            # title is "<sample label>_<trigger key>_<figure suffix>": the second "_"-token
            # picks the output sub-directory and everything after it names the file.
            idir = title.split("_")[1]
            fig_name = "_".join(title.split("_")[2:])
            fig.tight_layout()
            out_dir = f"trigger_plots/{year}/{idir}"
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(f"{out_dir}/2d_{key}_{fig_name}.png", bbox_inches="tight")
            plt.close(fig)

## Plot statistical uncertainty

In [None]:
def plot_2d_error(ax, numerator, denominator, label, title, datalabel=None):
    """Draw the per-bin statistical uncertainty of a 2D efficiency map on ``ax``.

    The value shown in each bin is the distance from the efficiency
    (``numerator / denominator``) down to the lower Clopper-Pearson bound.

    Args:
        ax: matplotlib axes to draw on.
        numerator, denominator: 2D hist histograms sharing the same binning.
        label: selection description, shown on the second title line.
        title: first title line (e.g. the trigger legend label).
        datalabel: optional luminosity string; when given, a CMS label is drawn with
            ``lumi=datalabel``. Previously this argument was accepted but ignored.
    """
    num_vals = numerator.view()
    den_vals = denominator.view()
    # Row 0 of the interval difference is the lower bound's distance from the efficiency.
    err = abs(clopper_pearson_interval(num_vals, den_vals) - num_vals / den_vals)[0]

    x_axis, y_axis = numerator.axes[0], numerator.axes[1]
    # Pass the real bin edges so the axes are in physical units rather than bin indices.
    cbar = hep.hist2dplot(
        err,
        xbins=x_axis.edges,
        ybins=y_axis.edges,
        ax=ax,
        cmin=0.0,
        cmax=0.5,
    )
    cbar.cbar.set_label(r"Stat Error", size=18)
    cbar.cbar.ax.get_yaxis().labelpad = 15
    ax.set_xlabel(x_axis.label)
    ax.set_ylabel(y_axis.label)
    if datalabel is not None:
        hep.cms.label(r"", data=True, lumi=datalabel, ax=ax, fontsize=14)
    ax.set_title(f"{title}\n{label}", fontsize=10, y=1.0, pad=25)


# Per-bin statistical-uncertainty maps for every 2D efficiency. Disabled by default;
# set to True to write "<...>_stat_error.png" next to the corresponding 2D efficiency plots.
plot_stat_error = False
if plot_stat_error:
    for key, per_trigger in trigger_info_2d.items():
        for trigger_label, info in per_trigger.items():
            for numerator, denominator, label, title in info:
                fig, ax = plt.subplots(1, 1, figsize=(8, 6))
                plot_2d_error(
                    ax,
                    numerator,
                    denominator,
                    label,
                    trigger_label,
                    # Only data gets a luminosity label.
                    datalabel=yearlabel if key == "data" else None,
                )
                # title is "<sample label>_<trigger key>_<figure suffix>".
                idir = title.split("_")[1]
                fig_name = "_".join(title.split("_")[2:])
                fig.tight_layout()
                out_dir = f"trigger_plots/{year}/{idir}"
                os.makedirs(out_dir, exist_ok=True)
                fig.savefig(f"{out_dir}/2d_{key}_{fig_name}_stat_error.png", bbox_inches="tight")
                plt.close(fig)