In [None]:
# automatically reloads imported files on edits
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
from pathlib import Path

from HH4b.utils import ShapeVar, singleVarHist
from HH4b import plotting

In [None]:
# Tag identifying the BDT inference output; it is also the name of the
# sub-directory that holds the per-year score files.
tag = "nanov15_20251202_v15_signal"
INFERENCE_RESULTS_DIR = (
    Path("/ceph/cms/store/user/zichun/bbbb/signal_processed/bdt_inference") / tag
)

In [None]:
def read_pickle(file_path: "Path | str"):
    """Load and return the object stored in a pickle file.

    Args:
        file_path: Location of the pickle file. A ``Path``, a plain string or
            any other ``os.PathLike`` is accepted.

    Returns:
        The object produced by ``pickle.load`` for that file.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    # SECURITY: unpickling can execute arbitrary code, so only call this on
    # files that come from a trusted source.
    with Path(file_path).open("rb") as file:
        return pickle.load(file)

In [None]:
# Variables to draw as control plots, one ShapeVar per plot.
# Labels containing LaTeX are raw strings so that backslashes (\eta, \Delta,
# \tau, ...) reach matplotlib untouched and do not trigger invalid-escape
# warnings.
# NOTE(review): `bins` looks like [number of bins, lower edge, upper edge];
# confirm against the ShapeVar definition in HH4b.utils.
control_plot_vars = [
    # --- BDT outputs ---
    ShapeVar(
        var="bdt_score",
        label=r"$BDT$",
        bins=[50, 0, 1],
    ),
    ShapeVar(
        var="bdt_score_vbf",
        label=r"$BDT_{VBF}$",
        bins=[50, 0, 1],
    ),
    # --- di-jet system and event-level quantities ---
    ShapeVar(
        var="HHPt",
        label=r"$p_{T}^{jj}$ [GeV]",
        bins=[50, 200, 3000],
    ),
    ShapeVar(
        var="HHeta",
        label=r"$\eta_{jj}$",
        bins=[50, -2.5, 2.5],
    ),
    ShapeVar(
        var="HHmass",
        label=r"$m_{jj}$ [GeV]",
        bins=[50, 200, 2500],
    ),
    ShapeVar(
        var="MET",
        label="MET [GeV]",
        bins=[50, 0, 1000],
    ),
    # --- per-jet quantities ---
    ShapeVar(
        var="H1T32",
        label=r"$\tau_{32}^{j1}$",
        bins=[50, 0, 1.2],
    ),
    ShapeVar(
        var="H2T32",
        label=r"$\tau_{32}^{j2}$",
        bins=[50, 0, 1.2],
    ),
    ShapeVar(
        var="H1Pt",
        label=r"$p_{T}^{j1}$ [GeV]",
        bins=[50, 250, 2500],
    ),
    ShapeVar(
        var="H2Pt",
        label=r"$p_{T}^{j2}$ [GeV]",
        bins=[50, 250, 2500],
    ),
    ShapeVar(
        var="H1eta",
        label=r"$\eta^{j1}$",
        bins=[50, -2.5, 2.5],
    ),
    # --- ratios ---
    ShapeVar(
        var="H1Pt_HHmass",
        label=r"$p_{T}^{j1}$ / $m_{HH}$",
        bins=[50, 0, 2.5],
    ),
    ShapeVar(
        var="H2Pt_HHmass",
        label=r"$p_{T}^{j2}$ / $m_{HH}$",
        bins=[50, 0, 2.5],
    ),
    ShapeVar(
        var="H1Pt_H2Pt",
        label=r"$p_{T}^{j1}$ / $p_{T}^{j2}$",
        bins=[50, 0, 3],
    ),
    # --- VBF jet pair ---
    ShapeVar(
        var="VBFjjMass",
        label=r"$m^{jj}_{VBF}$ [GeV]",
        bins=[50, 0, 5000],
    ),
    ShapeVar(
        var="VBFjjDeltaEta",
        label=r"$\Delta \eta^{jj}_{VBF}$",
        bins=[50, 0, 9],
    ),
    # --- AK4 "away" jets ---
    ShapeVar(
        var="H1AK4JetAway1dR",
        label=r"$\Delta R(j1, AK4_{j1})$",
        bins=[50, 0.8, 6],
    ),
    ShapeVar(
        var="H2AK4JetAway2dR",
        label=r"$\Delta R(j2, AK4_{j2})$",
        bins=[50, 0.8, 6],
    ),
    ShapeVar(
        var="H1AK4JetAway1mass",
        label=r"$m^{j1 + AK4_{j1}}$ [GeV]",
        bins=[50, 0, 2000],
    ),
    ShapeVar(
        var="H2AK4JetAway2mass",
        label=r"$m^{j2 + AK4_{j2}}$ [GeV]",
        bins=[50, 0, 2000],
    ),
]

In [None]:
# Signal sample keys passed to the ratio plots.
sig_keys = ["hh4b", "vbfhh4b"]
# Further signal keys, currently disabled; add them to `sig_keys` to plot them:
#   'hh4b-kl0'
#   'hh4b-kl2p45'
#   'hh4b-kl5'
#   'vbfhh4b-kvm1p83-k2v3p57-klm3p39'
#   'vbfhh4b-kvm1p6-k2v2p72-klm1p36'
#   'vbfhh4b-kvm1p21-k2v1p94-klm0p94'
#   'vbfhh4b-kvm0p962-k2v0p959-klm1p43'
#   'vbfhh4b-kvm0p758-k2v1p44-klm19p3'
#   'vbfhh4b-kvm0p012-k2v0p03-kl10p2'
#   'vbfhh4b-kv1p74-k2v1p37-kl14p4'
#   'vbfhh4b-k2v0'
#   'vbfhh4b-kv2p12-k2v3p87-klm5p96'

# Background sample keys passed to the ratio plots, in plotting-call order.
bkg_keys = [
    "qcd",
    "ttbar",
    "tthhtobb",
    "novhhtobb",
    "nozzdiboson",
    "vhtobb",
    "vjets",
    "zz",
]

# Produce one ratio plot per control variable, separately for each year.
for year in ("2024", "2025"):
    # Per-year BDT score file inside the inference results directory.
    file_path = INFERENCE_RESULTS_DIR / f"{year}_bdt_scores.pkl"
    scores = read_pickle(file_path)
    hists = {}  # variable name -> histogram, reset for every year

    # Plots are written to bdt_control_plots/<year>/ (relative to the CWD).
    plot_dir = Path("bdt_control_plots") / year
    plot_dir.mkdir(parents=True, exist_ok=True)

    # Extra keyword arguments for ratioHistPlot; empty for the first plot of a
    # year, afterwards it carries the `qcd_norm` returned by that first plot.
    kwargs = {}

    for i, shape_var in enumerate(control_plot_vars):
        # A variable that was already histogrammed this year is not redrawn.
        if shape_var.var in hists:
            continue

        hists[shape_var.var] = singleVarHist(scores, shape_var, weight_key="finalWeight")

        qcd_norm = plotting.ratioHistPlot(
            hists[shape_var.var],
            year,
            sig_keys,
            bkg_keys,
            name=str(plot_dir / f"{shape_var.var}"),
            show=False,
            log=True,
            plot_significance=False,
            significance_dir=shape_var.significance_dir,
            ratio_ylims=[0.0, 2.0],  # a tighter alternative is [0.2, 1.8]
            bg_err_mcstat=True,
            reweight_qcd=True,
            # Flag set only for variables whose axis label carries a GeV unit.
            xbin_gev="[GeV]" in shape_var.label,
            **kwargs,
        )

        # Reuse the value returned for the first variable in all later plots of
        # this year (presumably a QCD normalisation factor - confirm in
        # plotting.ratioHistPlot).
        if i == 0:
            kwargs["qcd_norm"] = qcd_norm