In [None]:
from __future__ import annotations

import pickle
from collections import OrderedDict
from pathlib import Path

import hist
import matplotlib.pyplot as plt
import numpy as np
import uproot
from hist import Hist

from boostedhh.hh_vars import data_key, years
from bbtautau.postprocessing.datacardHelpers import sum_templates
from bbtautau.postprocessing.postprocessing import shape_vars
from bbtautau.postprocessing import plotting
from bbtautau.postprocessing import utils as putils
from bbtautau.postprocessing.Samples import BGS, CHANNELS
from bbtautau.userConfig import SHAPE_VAR

In [None]:
# Reload edited modules automatically before each cell executes, so changes to the
# bbtautau / boostedhh code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# --- Run configuration ---------------------------------------------------------
# Cards/fit directory name and the `bmin_{bmin}` sub-directory to read from it.
folder = "26Jan6-vbf"
bmin = 10

# Signals to include; use ["vbfbbtt"] to restrict the plots to VBF only.
sig_keys = ["ggfbbtt", "vbfbbtt"]

# Sub-directory tag. `use_bdt` only selects between the alternative tags below.
use_bdt = True
tag = ""
# tag = ["nobbttpresel_noNormSig_noBkgConstraint_ParT", "lowregBDT"][use_bdt]

# Output directory for every post-fit plot produced with this configuration.
plot_dir = Path("/home/users/lumori/bbtautau/plots/PostFit").joinpath(folder, tag, f"bmin_{bmin}")
print(plot_dir)
plot_dir.mkdir(parents=True, exist_ok=True)

# csv_dir = Path("/home/users/lumori/bbtautau/plots/SensitivityStudy/2025-12-12/full_presel/")/

In [None]:
# ROOT file holding the fitted shapes for all signals, stored next to the datacards.
# Alternative (ggF-only fit): replace the file name with "ggfFitShapes.root".
file = uproot.open(
    "/home/users/lumori/bbtautau/src/bbtautau/cards/"
    f"{folder}/{tag}/bmin_{bmin}/allsigsFitShapes.root"
)

# Directory of the per-year pre-fit template pickles (used for the signal histograms).
# template_folder = "25Dec27-ggf-only"
template_folder = "25Dec27-vbf"
templates_dir = f"templates/{template_folder}/{tag}/bmin_{bmin}/"
print(templates_dir)

In [None]:
# Inspect the contents of the fit-shapes file; the histogram-filling cell reads
# directories named f"{region}_{shape}" from it.
file.keys()

In [None]:
def get_pre_templates(templates_dir, years_to_sum=None):
    """Load the per-year pre-fit template pickles in ``templates_dir`` and sum them over years.

    Args:
        templates_dir: directory (``str`` or ``Path``) containing one ``{year}_templates.pkl``
            file per year.
        years_to_sum: years to load and sum; defaults to ``years`` from ``boostedhh.hh_vars``.

    Returns:
        The result of ``sum_templates`` on the ``{year: templates}`` dict. Presumably a mapping
        keyed by region (e.g. "pass"/"fail") — TODO confirm against ``sum_templates``.

    Note:
        ``pickle.load`` can execute arbitrary code; only point this at trusted template files.
    """
    templates_dir = Path(templates_dir)  # accept plain strings as well as Path objects
    if years_to_sum is None:
        years_to_sum = years

    templates_dict = {}
    for year in years_to_sum:
        with (templates_dir / f"{year}_templates.pkl").open("rb") as f:
            templates_dict[year] = pickle.load(f)

    return sum_templates(templates_dict, years_to_sum)

In [None]:
workspace_data_key = "data_obs"

# Sample name in the templates -> process name in the datacards / fit output.
# The order is the order in which samples are filled later. ("dyjets" is intentionally excluded.)
_same_name_processes = ("ttbarsl", "ttbarll", "ttbarhad", "wjets", "zjets", "hbb")
hist_label_map_inverse = OrderedDict(
    [
        ("qcddy", "CMS_bbtautau_boosted_qcd_datadriven"),
        *((proc, proc) for proc in _same_name_processes),
        (data_key, workspace_data_key),
    ]
)

# Reverse lookup: card / fit-output name -> template sample name.
hist_label_map = dict(zip(hist_label_map_inverse.values(), hist_label_map_inverse.keys()))

# Backgrounds to draw (stack order), followed by the signals and the data.
pbg_keys = ["qcddy", "ttbarhad", "ttbarsl", "ttbarll", "wjets", "zjets", "hbb"]
samples = [*pbg_keys, *sig_keys, data_key]

In [None]:
# Fit-output directory suffix -> human-readable label.
# The S+B fit ("shapes_fit_s": "S+B Post-Fit") is currently not plotted.
shapes = dict(
    prefit="Pre-Fit",
    postfit="B-only Post-Fit",
)

In [None]:
# Production-mode label shown in the plots for each signal key.
_signal_labels = {"ggfbbtt": "ggF", "vbfbbtt": "VBF"}

selection_regions = {}  # region name -> display label
region_info = {}  # region name -> parsed components of that name

for signal in sig_keys:
    for channel in CHANNELS.values():
        for pass_fail in ("pass", "fail"):
            region_name = f"{signal}{channel.key}{pass_fail}"
            selection_regions[region_name] = f"{signal} {channel.label} {pass_fail.title()}"
            region_info[region_name] = dict(
                signal=signal,
                signal_label=_signal_labels[signal],
                channel=channel,
                channel_key=channel.key,
                pass_fail=pass_fail,
                is_pass=(pass_fail == "pass"),
            )

In [None]:
def _sample_index(h, sample):
    """Return the position of ``sample`` on the first (Sample) category axis of ``h``.

    Raises ValueError if ``sample`` is not one of the axis categories.
    """
    return list(h.axes[0]).index(sample)


# Pre-fit signal templates keyed by (signal, channel_key). Each combination is read from disk
# once per run of this cell, not once per (shape, region, signal key).
_sig_pre_templates_cache = {}


def _load_sig_pre_templates(signal, channel_key):
    """Return the year-summed pre-fit templates for one signal/channel, loading them on first use."""
    cache_key = (signal, channel_key)
    if cache_key not in _sig_pre_templates_cache:
        _sig_pre_templates_cache[cache_key] = get_pre_templates(
            Path(f"{templates_dir}/{signal}/{channel_key}")
        )
    return _sig_pre_templates_cache[cache_key]


hists = {}  # shape -> region -> Hist(Sample, *shape_vars)
bgerrs = {}  # shape -> region -> background uncertainty per bin

for shape in shapes:
    print(shape)
    hists[shape] = {
        region: Hist(
            hist.axis.StrCategory(samples, name="Sample"),
            *[shape_var.axis for shape_var in shape_vars],
            storage="double",
        )
        for region in selection_regions
    }
    bgerrs[shape] = {}

    for region in selection_regions:
        h = hists[shape][region]
        templates = file[f"{region}_{shape}"]

        # Backgrounds: copy the shapes stored in the fit output.
        for key, file_key in hist_label_map_inverse.items():
            if key == data_key:
                continue
            if file_key not in templates:
                print(f"No {key} in {region}")
                continue
            h.view(flow=False)[_sample_index(h, key), :] = templates[file_key].values()

        # Signals: if a key is not in the fit output, take it from the pre-fit templates.
        info = region_info[region]
        channel_key = info["channel_key"]
        region_pf = info["pass_fail"]  # pre-fit templates are keyed by "pass" / "fail"
        for key in sig_keys:
            if key in hist_label_map_inverse:
                continue
            sig_pre_templates = _load_sig_pre_templates(info["signal"], channel_key)
            # Signal samples in the templates carry the channel suffix: f"{key}{channel_key}".
            h.view(flow=False)[_sample_index(h, key), :] = sig_pre_templates[region_pf][
                key + channel_key, ...
            ].values()

        # Data: NaN bins in the observed data are set to 0.
        h.view(flow=False)[_sample_index(h, data_key), :] = np.nan_to_num(
            templates[hist_label_map_inverse[data_key]].values()
        )

        # Total-background uncertainty, capped at the yield so that (value - error) is never negative.
        total_bkg = templates["TotalBkg"]
        bgerrs[shape][region] = np.minimum(total_bkg.errors(), total_bkg.values())

In [None]:
# Legacy manual blinding, kept for reference only.
# NOTE(review): `unblinded` and `utils` are not defined in this notebook (utils is imported as
# `putils`); pass-region blinding is currently requested via `blind_region` in the plotting cell.
# if not unblinded:
#     for shapeh in hists.values():
#         for region, h in shapeh.items():
#             if region != "fail":
#                 utils.blindBins(h, [100, 150], data_key, axis=0)

In [None]:
# Signal scale factors applied when overlaying the signals in the pass regions.
sig_scale_dict = {"ggfbbtt": 100, "vbfbbtt": 500}

# Plot variants as (CMS label text, output sub-directory). By default only the "Preliminary"
# plots are produced; set `make_final_plots = True` to also make the unlabeled "final" plots.
make_final_plots = False
plot_variants = [("Preliminary", "preliminary")]
if make_final_plots:
    plot_variants.append(("", "final"))

for plabel, pplotdir in plot_variants:
    out_dir = plot_dir / pplotdir
    out_dir.mkdir(exist_ok=True, parents=True)

    for shape in shapes:
        for region in selection_regions:
            info = region_info[region]
            region_label = f"{info['signal_label']} {info['channel'].label} {info['pass_fail']}"

            for shape_var in shape_vars:
                plot_params = {
                    "hists": hists[shape][region],
                    "sig_keys": sig_keys,
                    "bg_keys": pbg_keys,
                    "bg_err": bgerrs[shape][region],
                    "data_err": True,
                    # Signals are only overlaid (scaled) in pass regions.
                    "sig_scale_dict": sig_scale_dict if info["is_pass"] else None,
                    "show": True,
                    "year": "2022-2023",
                    "region_label": region_label,
                    "name": f"{out_dir}/{pplotdir}_{shape}_{region}_{shape_var.var}.pdf",
                    "ratio_ylims": [0, 2],
                    "cmslabel": plabel,
                    "leg_args": {"fontsize": 22, "ncol": 2},
                    "channel": info["channel"],
                    # Pass regions are blinded in the configured window.
                    "blind_region": SHAPE_VAR["blind_window"] if info["is_pass"] else None,
                }

                plotting.ratioHistPlot(**plot_params)

## QCD Transfer Factor

In [None]:
# NOTE(review): these imports would normally live in the top import cell.
import matplotlib.ticker as mticker
import mplhep as hep

# Select the mplhep CMS plotting style (both calls apply the same CMS style).
plt.style.use(hep.style.CMS)
hep.style.use("CMS")
# Math-text scalar formatter that switches to scientific notation outside ~1e-3..1e3.
# NOTE(review): `formatter` is not passed to any axis in the visible cells — confirm whether it is needed.
formatter = mticker.ScalarFormatter(useMathText=True)
formatter.set_powerlimits((-3, 3))

In [None]:
# NOTE(review): `ylims` uses an older region naming ("passggf"/"passvbf") and is not read in
# this cell; every TF plot uses the fixed range below.
ylims = {"passggf": 1e-4, "passvbf": 1e-5}
tf_xlim = (50, 250)
tf_ymax = 1e-4

tfs = {}  # pass-region name -> post-fit QCD transfer factor (pass / fail) histogram

for region, region_label in selection_regions.items():
    info = region_info[region]

    # Transfer factors are defined per pass region, relative to the matching fail region.
    if not info["is_pass"]:
        continue
    fail_region = f"{info['signal']}{info['channel_key']}fail"

    postfit = hists["postfit"]
    # Empty fail bins yield inf/nan; silence numpy's division warnings for those bins.
    with np.errstate(divide="ignore", invalid="ignore"):
        tf = postfit[region]["qcddy", ...] / postfit[fail_region]["qcddy", ...]

    print(tf)
    tfs[region] = tf

    fig, ax = plt.subplots()
    hep.histplot(tf, ax=ax)
    ax.set_title(f"{region_label} Region")
    ax.set_ylabel("QCD Transfer Factor")
    ax.set_xlim(tf_xlim)
    ax.set_ylim(0, tf_ymax)
    ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
    # fig.savefig(f"{plot_dir}/{region}_QCDTF.pdf", bbox_inches="tight")
    plt.show()

In [None]:
# Straight-line approximation of each VBF pass-region transfer factor, drawn through the
# first and last bins. Regions are named f"{signal}{channel_key}pass", so there is no single
# "passvbf" entry in `tfs`; loop over the VBF pass regions instead.
for region, tf in tfs.items():
    if region_info[region]["signal"] != "vbfbbtt":
        continue
    centers = tf.axes[0].centers  # actual bin centres instead of hard-coded 55 / 245
    values = tf.view()
    slope = (values[-1] - values[0]) / (centers[-1] - centers[0])
    yint = values[0] - slope * centers[0]
    print(region, slope, yint)