In [None]:
import argparse
import os
from pathlib import Path
from collections import OrderedDict

import hist
import numpy as np
import uproot

from HH4b import plotting
from HH4b.utils import ShapeVar
from HH4b.hh_vars import data_key

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# --- Run configuration ---------------------------------------------------
nTF = 1  # not read by any of the cells visible in this notebook section

# Channel / observable switches.
vbf = False  # True: use the "vbfhh4b-k2v0" signal and add the "passvbf" region
mreg = True  # True: fit variable is H2PNetMass, otherwise H2Msd

# Either "all" or the name of a single region to plot.
regions = "all"

# Output directory for the plots; created (with parents) if it is missing.
MAIN_DIR = Path("../../../")
plot_dir = MAIN_DIR / "plots/PostFit/24Apr21_legacy_bdt_ggf_tighter"
plot_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Open the FitShapes.root file belonging to the chosen datacard directory.
cards_dir = "24Apr21_legacy_bdt_ggf_tighter"

# NOTE(review): absolute, user-specific path; it only resolves on a machine
# where this home area is mounted.
# Alternative location (disabled):
#   /uscms/home/rkansal/eos/bbVV/cards/{cards_dir}/FitShapes.root
fit_shapes_path = f"/uscms/home/rkansal/hhcombine/hh4b/cards/{cards_dir}/FitShapes.root"
file = uproot.open(fit_shapes_path)

In [None]:
# Sample name used for the histograms in this notebook -> name of the
# corresponding shape in the cards / fit file. Insertion order is kept.
hist_label_map_inverse = OrderedDict()
hist_label_map_inverse["qcd"] = "CMS_bbbb_hadronic_qcd_datadriven"
hist_label_map_inverse["others"] = "others"
hist_label_map_inverse["ttbar"] = "ttbar"
hist_label_map_inverse["vhtobb"] = "VH_hbb"
hist_label_map_inverse["tthtobb"] = "ttH_hbb"
hist_label_map_inverse["data"] = "data_obs"

# Exactly one signal sample is appended, chosen by the `vbf` switch; its name
# is the same on both sides of the mapping.
signal_sample = "vbfhh4b-k2v0" if vbf else "hh4b"
hist_label_map_inverse[signal_sample] = signal_sample

# Reverse lookup (card name -> sample name) and the ordered list of sample names.
hist_label_map = {
    card_name: sample_name for sample_name, card_name in hist_label_map_inverse.items()
}
samples = list(hist_label_map.values())

# Two candidate fit observables, built with the same binning argument and blind window.
fit_shape_var_msd = ShapeVar(
    "H2Msd", r"$m^{j2}_\mathrm{SD}$ (GeV)", [16, 60, 220], reg=True, blind_window=[110, 140]
)
fit_shape_var_mreg = ShapeVar(
    "H2PNetMass", r"$m^{j2}_\mathrm{reg}$ (GeV)", [16, 60, 220], reg=True, blind_window=[110, 140]
)

# Only the observable selected by `mreg` is used downstream.
shape_vars = [fit_shape_var_mreg] if mreg else [fit_shape_var_msd]

In [None]:
# Fit stages to process: key -> label used in the plot titles.
# (Alternative label for "postfit": "S+B Post-Fit".)
shapes = {"prefit": "Pre-Fit", "postfit": "B-only Post-Fit"}

# Region key -> human-readable region label.
selection_regions_labels = {
    "passbin1": "Pass Bin1",
    "passbin2": "Pass Bin2",
    "passbin3": "Pass Bin3",
    "fail": "Fail",
}

# The VBF pass region only exists when the VBF channel is enabled.
if vbf:
    selection_regions_labels.update({"passvbf": "Pass VBF"})

In [None]:
# Resolve the `regions` setting into the list of signal regions: a single
# named region, or all pass bins (with "passvbf" first when `vbf` is set).
if regions != "all":
    signal_regions = [regions]
else:
    signal_regions = (["passvbf"] if vbf else []) + ["passbin1", "passbin2", "passbin3"]

# The fail region is always processed in addition to the signal regions.
bins = signal_regions + ["fail"]
selection_regions = {name: selection_regions_labels[name] for name in bins}

In [None]:
# For every fit stage and region, create a (Sample x shape variable) histogram
# and fill it from the "<region>_<shape>" entry of the opened fit file.
hists = {}
for shape in shapes:
    hists[shape] = {}
    for region in selection_regions:
        h = hist.Hist(
            hist.axis.StrCategory(samples, name="Sample"),
            *(shape_var.axis for shape_var in shape_vars),
            storage="double",
        )
        hists[shape][region] = h

        templates = file[f"{region}_{shape}"]
        sample_names = list(h.axes[0])  # row order of the Sample axis
        contents = h.view(flow=False)  # bin contents without flow bins; written in place

        # All non-data samples: a sample missing from the file is reported and
        # its row is left at zero.
        for key, file_key in hist_label_map_inverse.items():
            if key == data_key:
                continue
            if file_key not in templates:
                print(f"No {key} in {region}")
                continue
            contents[sample_names.index(key), :] = templates[file_key].values()

        # Data is filled unconditionally, with NaNs replaced by finite numbers.
        contents[sample_names.index(data_key), :] = np.nan_to_num(
            templates[hist_label_map_inverse[data_key]].values()
        )

In [None]:
# Post-fit signal and total background yields of the "passbin1" region,
# summed over a fixed range of mass bins.
# Bin indices 5:8 of the mass axis; with the [16, 60, 220] binning given to the
# shape variables this presumably corresponds to the 110-140 GeV blind
# window - TODO confirm.
mass_window = slice(5, 8)
window_hist = hists["postfit"]["passbin1"]

# The signal sample stored in the histograms depends on the `vbf` switch, so the
# matching key is looked up instead of always assuming "hh4b" (which is not a
# category of the Sample axis when `vbf` is True).
sig_key = "vbfhh4b-k2v0" if vbf else "hh4b"
print("Signal in mass window:", np.sum(window_hist[sig_key, mass_window].values()))

# Background = every sample that is neither a signal nor data.
bg_tot = np.sum(
    [
        np.sum(window_hist[key, mass_window].values())
        for key in hist_label_map_inverse
        if key not in ["hh4b", "vbfhh4b-k2v0", data_key]
    ]
)
print("BG in mass window:", bg_tot)

In [None]:
# Per-sample breakdown of the background yield in bins 5:8 of the post-fit
# "passbin1" histogram: first the list of background samples, then their sums.
background_samples = [
    name for name in hist_label_map_inverse if name not in ["hh4b", "vbfhh4b-k2v0", "data"]
]
print(background_samples)
{
    name: np.sum(hists["postfit"]["passbin1"][name, 5:8].values())
    for name in background_samples
}

In [None]:
# Settings shared by all figures.
year = "2022-2023"
signal_scale = 5.0  # scale applied to the hh4b signal in pass regions (1.0 in fail)
pass_ratio_ylims = [0, 2]
fail_ratio_ylims = [0, 2]

# Per-region value handed to the plotting function as `ylim`.
ylims = {
    "passvbf": 15,
    "passbin1": 10,
    "passbin2": 50,
    "passbin3": 800,
    "fail": 100000,
}

# One figure per (fit stage, region, shape variable), saved as a PDF in `plot_dir`.
for shape, shape_label in shapes.items():
    for region, region_label in selection_regions.items():
        pass_region = region.startswith("pass")
        ratio_ylims = pass_ratio_ylims if pass_region else fail_ratio_ylims

        for shape_var in shape_vars:
            plotting.ratioHistPlot(
                hists=hists[shape][region],
                sig_keys=["vbfhh4b-k2v0"] if vbf else ["hh4b"],
                # No scale dictionary is passed for the VBF signal.
                sig_scale_dict=(
                    None if vbf else {"hh4b": signal_scale if pass_region else 1.0}
                ),
                bg_keys=["qcd", "ttbar", "vhtobb", "tthtobb", "others"],
                show=True,
                year=year,
                ylim=ylims[region],
                xlim=220,
                xlim_low=60,  # 50 was used as an alternative lower edge
                ratio_ylims=ratio_ylims,
                title=f"{shape_label} {region_label} Region",
                name=f"{plot_dir}/{shape}_{region}_{shape_var.var}.pdf",
                # NOTE(review): bg_order lists "diboson" and "vjets", which are not in
                # bg_keys, and omits "tthtobb" and "others" - confirm how
                # ratioHistPlot combines the two lists.
                bg_order=["diboson", "vjets", "vhtobb", "ttbar", "qcd"],
                energy=13.6,
            )