# Misc Checks

In [None]:
from __future__ import annotations

import importlib
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
from xgboost import XGBClassifier

import HH4b.plotting as plotting
import HH4b.postprocessing as postprocessing
from HH4b.hh_vars import samples, samples_run3, years
from HH4b.postprocessing import PostProcess

# Tick formatter using math-text scientific notation; set_powerlimits((-3, 3)) switches
# to scientific notation outside roughly 1e-3 .. 1e3.
# NOTE(review): `formatter` is not referenced by any visible cell — confirm it is still needed.
formatter = mticker.ScalarFormatter(useMathText=True)
formatter.set_powerlimits((-3, 3))

In [None]:
# IPython magics: automatically reload imported modules (e.g. HH4b.*) when their
# source files change, so edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Load samples

In [None]:
# Output location for plots (created if missing), relative to this notebook.
MAIN_DIR = Path("../../../")
plot_dir = MAIN_DIR.joinpath("../plots/PostProcess/24Apr24Legacy")
plot_dir.mkdir(exist_ok=True, parents=True)

# Skimmer output directory, mapped to the sample definitions imported from hh_vars.
data_dir = "/ceph/cms/store/user/rkansal/bbbb/skimmer/24Apr19LegacyFixes_v12_private_signal/"
dirs = {data_dir: samples}

# BDT: trained-model directory name and the config module (under
# HH4b.boosted.bdt_trainings_run3) that builds the BDT input DataFrame.
bdt_model_name = bdt_config = "24Apr21_legacy_vbf_vars"

In [None]:
def load_process_run3_samples(data_dir, year, samples_run3):
    """Load one year's Run 3 samples, run BDT inference, and slim each sample down.

    Args:
        data_dir: skimmer output directory passed to ``postprocessing.load_run3_samples``.
        year: year key to load (e.g. ``"2022EE"``).
        samples_run3: ``{year: {process: ...}}`` mapping of the samples to load.

    Returns:
        The dict returned by ``postprocessing.load_run3_samples``, with each value
        replaced by a DataFrame containing the BDT input features, the scores added
        by ``PostProcess.add_bdt_scores``, and the columns ``weight`` (from
        ``finalWeight``), ``H2TXbb`` and ``H2PNetMass`` (index 1 of the legacy
        ``bbFatJetPNetTXbb`` / ``bbFatJetPNetMass`` columns).

    Relies on the notebook globals ``bdt_model_name`` and ``bdt_config``.
    """
    events_dict = postprocessing.load_run3_samples(data_dir, year, True, samples_run3)
    legacy_label = "Legacy"

    # define BDT model
    bdt_model = XGBClassifier()
    bdt_model.load_model(fname=f"../boosted/bdt_trainings_run3/{bdt_model_name}/trained_bdt.model")
    # config module providing `bdt_dataframe(events)` -> BDT input features
    make_bdt_dataframe = importlib.import_module(
        f".{bdt_config}", package="HH4b.boosted.bdt_trainings_run3"
    )

    # inference and assign score; each full events DataFrame is replaced in place by
    # the slim BDT DataFrame (values only are reassigned, so iterating is safe)
    for key, events in events_dict.items():
        bdt_events = make_bdt_dataframe.bdt_dataframe(events)
        preds = bdt_model.predict_proba(bdt_events)
        PostProcess.add_bdt_scores(bdt_events, preds)
        bdt_events["weight"] = events["finalWeight"].to_numpy()
        bdt_events["H2TXbb"] = events[f"bbFatJetPNetTXbb{legacy_label}"].to_numpy()[:, 1]
        bdt_events["H2PNetMass"] = events[f"bbFatJetPNetMass{legacy_label}"].to_numpy()[:, 1]
        events_dict[key] = bdt_events

    return events_dict

In [None]:
processes = ["data", "hh4b", "ttbar"]

# Prune samples_run3 in place so only these processes are passed to
# load_run3_samples below.
for year in samples_run3:
    unwanted_keys = [key for key in samples_run3[year] if key not in processes]
    for key in unwanted_keys:
        samples_run3[year].pop(key)

In [None]:
# Training keys for the same BDT model used for inference (previously a duplicated
# hard-coded string that could drift from `bdt_model_name`).
bdt_training_keys = PostProcess.get_bdt_training_keys(bdt_model_name)

# {year: {process: slim BDT DataFrame}}
events_dict_postprocess = {}
cutflows = {}  # NOTE(review): not filled by any visible cell
for year in years:
    print(f"\n{year}")
    events_dict_postprocess[year] = load_process_run3_samples(data_dir, year, samples_run3)

print("Loaded all years")

In [None]:
# Combine the per-year dictionaries across years, with "qcd" and "ttbar" given as
# the background keys.
# NOTE(review): "qcd" was filtered out of samples_run3 above (only data/hh4b/ttbar
# are kept) — confirm combine_run3_samples tolerates a missing background key.
events_combined = PostProcess.combine_run3_samples(
    events_dict_postprocess,
    processes,
    ["qcd", "ttbar"],
)

## S/B optimization using the ABCD method

In [None]:
def get_nevents_sidebands(events, cut, mass, mass_window):
    """Return the weighted yield of events passing ``cut`` in the two mass sidebands.

    With ``lo, hi = mass_window`` and ``w = hi - lo``, the sidebands are the open
    intervals ``(lo - w/2, lo)`` and ``(hi, hi + w/2)``.
    """
    lo = mass_window[0]
    hi = mass_window[1]
    half_width = (hi - lo) / 2
    masses = events[mass]

    in_low_sideband = (masses > lo - half_width) & (masses < lo)
    in_high_sideband = (masses > hi) & (masses < hi + half_width)
    selected = cut & (in_low_sideband | in_high_sideband)
    return np.sum(events["weight"][selected])


def get_nevents_signal(events, cut, mass, mass_window):
    """Return the weighted yield of events passing ``cut`` with mass in the closed window."""
    lo = mass_window[0]
    hi = mass_window[1]
    in_window = (events[mass] >= lo) & (events[mass] <= hi)
    return np.sum(events["weight"][in_window & cut])


def get_nevents_nosignal(events, cut, mass, mass_window):
    """Return the weighted yield of events passing ``cut`` with mass OUTSIDE the window.

    This is the complement of the closed window ``[mass_window[0], mass_window[1]]``.
    Events with a NaN mass fail the in-window test, so they are counted here.
    """
    lo = mass_window[0]
    hi = mass_window[1]
    in_window = (events[mass] >= lo) & (events[mass] <= hi)
    outside_window = ~in_window
    return np.sum(events["weight"][outside_window & cut])


def get_s_b(events_dict, cut_dict, mass, mass_window):
    """Return ``(s, b)``: signal and background yields for the mass window.

    ``s`` is the HH4b MC yield inside ``mass_window``.
    ``b`` is the data sideband yield with the ttbar MC sideband contribution removed,
    plus the ttbar MC yield inside the window: ``b = data_sb - ttbar_sb + ttbar_window``.
    Each process uses its own selection from ``cut_dict``.
    """
    s = get_nevents_signal(events_dict["hh4b"], cut_dict["hh4b"], mass, mass_window)
    bd = get_nevents_sidebands(events_dict["data"], cut_dict["data"], mass, mass_window)
    bt = get_nevents_sidebands(events_dict["ttbar"], cut_dict["ttbar"], mass, mass_window)
    # ttbar inside the mass window. This was previously the sideband yield again, so
    # `- bt + ts` cancelled and b reduced to the raw data sideband yield.
    ts = get_nevents_signal(events_dict["ttbar"], cut_dict["ttbar"], mass, mass_window)
    b = bd - bt + ts
    return s, b

In [None]:
def data_tt(events_dict, cut_dict, mass, mass_window):
    """Return ``(data yield inside the mass window, ttbar MC yield in the sidebands)``."""
    data_events, data_cut = events_dict["data"], cut_dict["data"]
    ttbar_events, ttbar_cut = events_dict["ttbar"], cut_dict["ttbar"]

    data_in_window = get_nevents_signal(data_events, data_cut, mass, mass_window)
    ttbar_in_sidebands = get_nevents_sidebands(ttbar_events, ttbar_cut, mass, mass_window)
    return data_in_window, ttbar_in_sidebands


def abcd(
    events_dict,
    txbb_cut,
    bdt_cut,
    mass,
    mass_window,
    bdt_cut_fail=0.6,
    txbb_cut_fail=0.8,
):
    """Estimate signal and background yields in the signal region with the ABCD method.

    The pass region is ``bdt_score > bdt_cut`` and ``H2TXbb > txbb_cut``.
    The fail region is ``bdt_score < bdt_cut_fail`` and ``H2TXbb < txbb_cut_fail``;
    the defaults (0.6 / 0.8) are the values that used to be hard-coded.

    For data and ttbar MC the four regions are:
      A = pass, mass inside ``mass_window``;   B = pass, mass outside;
      C = fail, mass inside;                   D = fail, mass outside.
    The QCD yield in A is predicted from ttbar-subtracted data as ``B * C / D``.

    Args:
        events_dict: DataFrames keyed by "hh4b", "data" and "ttbar". Each needs
            the columns "bdt_score", "H2TXbb", ``mass`` and "weight".
        txbb_cut, bdt_cut: pass-region thresholds (strict ``>``).
        mass: name of the mass column.
        mass_window: ``[low, high]`` signal mass window (inclusive).
        bdt_cut_fail, txbb_cut_fail: fail-region thresholds (strict ``<``).

    Returns:
        ``(s, b, bt)``:
          - ``s``: HH4b yield in region A.
          - ``b``: total background in A (QCD estimate plus ttbar MC).
          - ``bt``: ttbar MC yield in A.
        A zero ttbar-subtracted D yield follows NumPy division semantics
        (inf/nan plus a RuntimeWarning).
    """

    def pass_cut(events):
        return (events["bdt_score"] > bdt_cut) & (events["H2TXbb"] > txbb_cut)

    # signal: HH4b MC yield in region A only
    sig_events = events_dict["hh4b"]
    s = get_nevents_signal(sig_events, pass_cut(sig_events), mass, mass_window)

    # region yields [A, B, C, D] for data and ttbar MC
    yields = {}
    for key in ["data", "ttbar"]:
        events = events_dict[key]
        cut_pass = pass_cut(events)
        cut_fail = (events["bdt_score"] < bdt_cut_fail) & (events["H2TXbb"] < txbb_cut_fail)
        yields[key] = np.array(
            [
                get_nevents_signal(events, cut_pass, mass, mass_window),  # A
                get_nevents_nosignal(events, cut_pass, mass, mass_window),  # B
                get_nevents_signal(events, cut_fail, mass, mass_window),  # C
                get_nevents_nosignal(events, cut_fail, mass, mass_window),  # D
            ]
        )

    # QCD = data - ttbar in every region; predict QCD in A as B * C / D
    _, qcd_b, qcd_c, qcd_d = yields["data"] - yields["ttbar"]
    bqcd = qcd_c * qcd_b / qcd_d
    bt = yields["ttbar"][0]

    return s, bqcd + bt, bt

## Run the optimization (scan over Txbb and BDT cuts)

In [None]:
mass = "H2PNetMass"
mass_window = [115, 135]

# Cut grids. np.linspace gives exactly the intended points. With a float step,
# np.arange(0.96, 1.0, 0.005) returns 9 values whose last is ~1.0; that cut empties
# the pass region and makes s / b a 0 / 0.
txbb_cuts = np.linspace(0.96, 0.995, 8)  # 0.960, 0.965, ..., 0.995
bdt_cuts = np.linspace(0.90, 0.99, 10)  # 0.90, 0.91, ..., 0.99

for txbb_cut in txbb_cuts:
    for bdt_cut in bdt_cuts:
        s, b, bt = abcd(events_combined, txbb_cut, bdt_cut, mass, mass_window)
        s_over_b = s / b if b > 0 else np.nan
        print(
            f"txbb>{txbb_cut:.3f} bdt>{bdt_cut:.2f} "
            f"s={s:.4g} b={b:.4g} bt={bt:.4g} s/b={s_over_b:.4g}"
        )

# Single working point, e.g.: abcd(events_combined, 0.99, 0.97, mass, mass_window)

## Old stuff (legacy cells; may not run as-is)

In [None]:
# NOTE(review): `year` here is whatever the loading loop above left behind (its last
# entry of `years`) — confirm that is the intended year.
sig_samples = dict(hh4b=samples[year]["hh4b"])

In [None]:
# Read the 2022EE GluGlutoHHto4B (kl=1) signal parquet directly from the skimmer output.
events = pd.read_parquet(
    Path(
        data_dir,
        "2022EE",
        "GluGlutoHHto4B_kl-1p00_kt-1p00_c2-0p00_TuneCP5_13p6TeV",
        "parquet",
    )
)

In [None]:
# Mass-sculpting check: normalised jet-mass shape of each background for a series of
# Txbb cuts. Uses its own list name so the imported `samples` dict (used in an earlier
# cell) is no longer overwritten with a list.
# NOTE(review): `events_dict` is not defined at notebook top level (it only exists
# inside load_process_run3_samples) — this legacy cell needs it created first.
sculpting_samples = ["qcd", "ttbar"]
mass = "bbFatJetMsd"
tagger = "bbFatJetPNetTXbbLegacy"
i = 1  # jet index into the per-jet columns
mass_bins = np.arange(60, 251, 10)  # loop-invariant histogram binning

for sample in sculpting_samples:
    plt.figure(figsize=(10, 10))
    plt.title(sample)
    for cut in [0, 0.8, 0.9, 0.95]:
        cut_mask = events_dict[sample][tagger][i] >= cut
        plt.hist(
            events_dict[sample][mass][i][cut_mask],
            mass_bins,
            weights=events_dict[sample]["finalWeight"][cut_mask],
            histtype="step",
            label=rf"$T_{{Xbb}} \geq {cut}$",
            density=True,
        )

    plt.xlabel(f"Jet {i+1} {mass} (GeV)")
    plt.legend()
    plt.savefig(plot_dir / f"{sample}_{mass}{i}_{tagger}_sculpting.pdf", bbox_inches="tight")
    plt.show()

## BDT ROC Curve

## Txbb tagger ROC curves (signal vs. QCD and ttbar)

In [None]:
# Tagger score of jet index `jet` for the signal and for each background.
jet = 1
tagger = "bbFatJetPNetTXbbLegacy"
sig_jets_score = events_dict["hh4b"][tagger][jet]
bg_jets_score = {bg: events_dict[bg][tagger][jet] for bg in ("qcd", "ttbar")}

In [None]:
from sklearn.metrics import roc_curve

bg_skip = 1  # keep every `bg_skip`-th background jet
sig_key = "hh4b"
weight_key = "finalWeight"
rocs = {}

for bg_key in ["qcd", "ttbar"]:
    print(bg_key)
    bg_scores = bg_jets_score[bg_key][::bg_skip]

    # truth labels: 1 for signal jets, 0 for the (subsampled) background jets
    y_true = np.concatenate([np.ones(len(sig_jets_score)), np.zeros(len(bg_scores))])
    sig_weights = events_dict[sig_key][weight_key].to_numpy()
    bg_weights = events_dict[bg_key][weight_key].to_numpy()[::bg_skip]
    weights = np.concatenate([sig_weights, bg_weights])
    scores = np.concatenate((sig_jets_score, bg_scores))

    fpr, tpr, thresholds = roc_curve(y_true, scores, sample_weight=weights)
    rocs[bg_key] = dict(
        fpr=fpr,
        tpr=tpr,
        thresholds=thresholds,
        label=plotting.label_by_sample[bg_key],
    )

In [None]:
def find_nearest(array, value):
    """Return the (flat) index of the element of ``array`` closest to ``value``.

    Ties resolve to the first such index.
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()

In [None]:
# Draw the Txbb ROC curves vs each background and save them under plot_dir.
# NOTE(review): [0.2, 0.5] is presumably a list of signal-efficiency reference
# points — confirm in plotting.multiROCCurveGrey.
plotting.multiROCCurveGrey(
    {"test": rocs},
    [0.2, 0.5],
    name=f"{tagger}_ROCs",
    plot_dir=plot_dir,
    xlim=[0, 0.8],
    ylim=[1e-5, 1],
    show=True,
)