In [None]:
# automatically reloads imported files on edits
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np

from pathlib import Path

from HH4b import utils
from HH4b import postprocessing
import itertools

In [None]:
# Pickle file used as a cache for the combined event tables; when it exists
# (and REPROCESS is False) the later cells load it instead of rebuilding.
processed_path: Path = Path("Zbb_events_combined.pkl")
# Set to True to force reprocessing from the skimmed ntuples even if the cache exists.
REPROCESS: bool = False

In [None]:
# Switch for the Z->2Q corrections taken from the ZMuMu measurement.
# When enabled, one correction set per year is read into `corr_dict`.
APPLY_Zto2Q_CORR: bool = False

if APPLY_Zto2Q_CORR:
    # Imported here so correctionlib is only needed when the switch is on.
    import correctionlib

    corr_dir = Path("ZMuMu_corrs/pT")
    corr_dict = {}

    for year in ("2022", "2023"):
        corr_file = corr_dir / f"corr_{year}.json"
        if corr_file.exists():
            # Keyed by year string so later cells can look the set up directly.
            corr_dict[year] = correctionlib.CorrectionSet.from_file(str(corr_file))
        else:
            raise FileNotFoundError(f"Correction file {corr_file} does not exist.")

In [None]:
# Build `events_dict[year][sample]`: one flat DataFrame per (year, sample label),
# read from the skimmed ntuples. Skipped entirely when the cache file exists and
# REPROCESS is False.
if REPROCESS or not processed_path.exists():
    # Directory of skimmed ntuples handed to utils.load_samples.
    path_dir = "/ceph/cms/store/user/zichun/bbbb/skimmer/ZbbHT25May28_v12v2_private_zbb/"

    # Analysis label -> list of sample names passed to utils.load_samples.
    samples_run3 = {
        "data": [f"{key}_Run" for key in ["JetMET"]],
        "ttbar": ["TTto4Q", "TTtoLNu2Q"],
        "qcd": ["QCD_HT"],
        "hbb": ["GluGluHto2B_M-125"],
        "Zto2Q": ["Zto2Q-4Jets"],
        "Wto2Q": ["Wto2Q-3Jets"],
    }

    # Weight systematics; used both for the weight_<var>Up/Down columns and for
    # the weight_shifts argument of utils.load_samples.
    sys_vars = ["FSRPartonShower", "ISRPartonShower", "pileup"]

    # (column name, number of entries per event) pairs loaded for every sample.
    base_columns = [
        ("bbFatJetPt", 2),
        ("bbFatJetEta", 2),
        ("bbFatJetMsd", 2),
        ("bbFatJetParTmassVis", 2),
        ("bbFatJetPNetMassLegacy", 2),
        ("bbFatJetParTTXbb", 2),
        ("weight", 1),
    ]

    # Trigger columns requested per data-taking era.
    triggers = {
        "2022": [
            "AK8PFJet500",
            "AK8PFJet420_MassSD30",
            "AK8PFJet425_SoftDropMass40",
            "AK8PFJet250_SoftDropMass40_PFAK8ParticleNetBB0p35",
        ],
        "2022EE": [
            "AK8PFJet500",
            "AK8PFJet420_MassSD30",
            "AK8PFJet425_SoftDropMass40",
            "AK8PFJet250_SoftDropMass40_PFAK8ParticleNetBB0p35",
        ],
        "2023": [
            "AK8PFJet500",
            "AK8PFJet420_MassSD30",
            "AK8PFJet425_SoftDropMass40",
            "AK8PFJet250_SoftDropMass40_PFAK8ParticleNetBB0p35",
            "AK8PFJet230_SoftDropMass40_PNetBB0p06",
        ],
        "2023BPix": [
            "AK8PFJet500",
            "AK8PFJet420_MassSD30",
            "AK8PFJet425_SoftDropMass40",
            "AK8PFJet230_SoftDropMass40_PNetBB0p06",
        ],
    }

    # The two era-dependent triggers that get zero-filled where they were not requested.
    hlt_pnetbb_0p35 = "AK8PFJet250_SoftDropMass40_PFAK8ParticleNetBB0p35"
    hlt_pnetbb_0p06 = "AK8PFJet230_SoftDropMass40_PNetBB0p06"

    # Jet pT columns under JES/JER up/down variations.
    load_columns_pt_var = [
        (f"bbFatJetPt_{jesr}_{ud}", 2)
        for jesr, ud in itertools.product(["JES", "JER"], ["up", "down"])
    ]

    # Mass columns under JMS/JMR up/down variations (three mass definitions each).
    load_columns_mass_var = [
        (f"{mass}_{jmsr}_{ud}", 2)
        for jmsr, ud in itertools.product(["JMS", "JMR"], ["up", "down"])
        for mass in ["bbFatJetMsd", "bbFatJetParTmassVis", "bbFatJetPNetMassLegacy"]
    ]

    # Shifted event weights, one Up/Down pair per entry of sys_vars.
    load_weight_shifts = [
        (f"weight_{var}{ud}", 1) for var, ud in itertools.product(sys_vars, ["Up", "Down"])
    ]

    MC_common_extra_columns = load_columns_mass_var + load_columns_pt_var + load_weight_shifts

    ZQQ_extra_columns = [("GenZPt", 1), ("GenZBB", 1), ("GenZCC", 1), ("bbFatJetVQQMatch", 2)]
    WQQ_extra_columns = [("GenWPt", 1), ("GenWCS", 1), ("GenWUD", 1), ("bbFatJetVQQMatch", 2)]

    # Additional columns per sample label, on top of triggers + base_columns.
    extra_columns_dict = {
        "data": [],
        "qcd": load_weight_shifts,
        "ttbar": MC_common_extra_columns,
        "hbb": MC_common_extra_columns,
        "Zto2Q": MC_common_extra_columns + ZQQ_extra_columns,
        "Wto2Q": MC_common_extra_columns + WQQ_extra_columns,
    }

    events_dict = {}
    for year in ["2022", "2022EE", "2023", "2023BPix"]:
        events_dict[year] = {}

        # Have to load the samples separately because branches vary
        for sample, sample_list in samples_run3.items():
            print(f"Loading {sample} for {year}...")
            triggers_col = [(trigger, 1) for trigger in triggers[year]]

            columns = triggers_col + base_columns + extra_columns_dict.get(sample, [])
            dataframes = dict(
                utils.load_samples(
                    data_dir=path_dir,
                    samples={sample: sample_list},
                    year=year,
                    columns=utils.format_columns(columns),
                    variations=True,
                    weight_shifts=list(sys_vars),
                )
            )

            # Collect the per-key frames, then concatenate them once below.
            sample_frames = []
            for key, df in dataframes.items():
                df["sample"] = key
                # Copy each (name, index) column to a single-level name:
                # name0, name1, ... for multi-entry columns, plain name otherwise.
                # NOTE(review): assumes load_samples returns two-level (name, index)
                # columns, as implied by the df[(col, i)] lookups -- confirm.
                for col, n in columns:
                    if n > 1:
                        for i in range(n):
                            df[f"{col}{i}"] = df[(col, i)]
                    else:
                        df[col] = df[(col, 0)]

                # Fill non-existing HLT columns with 0
                if year in ("2023", "2023BPix"):
                    # 2023BPix does not request the 0p35 trigger: add it as zeros.
                    if hlt_pnetbb_0p35 not in df.columns:
                        df[hlt_pnetbb_0p35] = np.zeros(len(df), dtype=int)
                    # Also replace missing values with 0 and cast to int.
                    df[hlt_pnetbb_0p35] = df[hlt_pnetbb_0p35].fillna(0).astype(int)
                elif year in ("2022", "2022EE"):
                    # The 0p06 trigger is not requested for the 2022 eras, so add it
                    # as an all-zero column. (The original tested `in` instead of
                    # `not in`, so the column was never actually added.)
                    if hlt_pnetbb_0p06 not in df.columns:
                        df[hlt_pnetbb_0p06] = np.zeros(len(df), dtype=int)
                sample_frames.append(df)

            # concatenate all dataframes for this sample
            events_dict[year][sample] = pd.concat(sample_frames, ignore_index=True)

In [None]:
# Build (or load from cache) `events_combined[era][sample]`, where the eras
# "2022All" and "2023All" merge 2022+2022EE and 2023+2023BPix respectively.
if REPROCESS or not processed_path.exists():
    events_combined = {
        "2022All": {},
        "2023All": {},
    }
    for key in samples_run3:
        events_combined["2022All"][key] = pd.concat(
            [events_dict[year][key] for year in ["2022", "2022EE"] if key in events_dict[year]]
        )
        events_combined["2023All"][key] = pd.concat(
            [events_dict[year][key] for year in ["2023", "2023BPix"] if key in events_dict[year]]
        )

    # store events_combined as a pickle file
    with processed_path.open("wb") as f:
        pd.to_pickle(events_combined, f)
    print(f"Events combined and saved to {processed_path}")
else:
    # directly load the processed file
    # NOTE: unpickling can execute arbitrary code -- only load a cache file
    # that was produced by this notebook.
    print(f"Loading events from {processed_path}...")
    with processed_path.open("rb") as f:
        events_combined = pd.read_pickle(f)
    print(f"Loaded events from {processed_path}")

# The corrections are applied after the cache is written/read, so the pickle
# always holds uncorrected weights.
if APPLY_Zto2Q_CORR:
    print("Applying Zto2Q corrections...")
    for year in ["2022All", "2023All"]:
        # "2022All" -> "2022", "2023All" -> "2023": key into corr_dict
        corr = corr_dict[year.replace("All", "")]["GenZPtWeight"]
        zto2q = events_combined[year]["Zto2Q"]
        GenZ_pt = zto2q["GenZPt"].values[:, 0]
        sf_nom = corr.evaluate(GenZ_pt, "nominal")
        sf_up = corr.evaluate(GenZ_pt, "stat_up")
        sf_down = corr.evaluate(GenZ_pt, "stat_down")

        # Snapshot the uncorrected weight before finalWeight is overwritten, so
        # the up/down variations are not multiplied by sf_nom as well (the
        # original computed them from the already-corrected finalWeight).
        # NOTE(review): assumes "stat_up"/"stat_down" are absolute scale factors,
        # not ratios to nominal -- confirm against the correction file.
        weight_uncorr = zto2q["finalWeight"].copy()

        zto2q["SF_GenZPt"] = sf_nom
        zto2q["SF_GenZPt_up"] = sf_up
        zto2q["SF_GenZPt_down"] = sf_down
        zto2q["finalWeight"] = weight_uncorr * sf_nom
        zto2q["weight_GenZPtUp"] = weight_uncorr * sf_up
        zto2q["weight_GenZPtDown"] = weight_uncorr * sf_down
    print("Zto2Q corrections applied")

In [None]:
# Split the Zto2Q sample into four categories per era: matched BB, matched CC,
# matched other (u, d, s quarks) and unmatched. A row counts as matched when
# bbFatJetVQQMatch0 equals 1.
for year in ("2022All", "2023All"):
    zjets = events_combined[year]["Zto2Q"]
    flag_bb = zjets[("GenZBB", 0)]
    flag_cc = zjets[("GenZCC", 0)]
    is_matched = zjets["bbFatJetVQQMatch0"] == 1

    # category label -> boolean row mask
    category_masks = {
        "Zto2Q_BB": flag_bb & is_matched,
        "Zto2Q_CC": flag_cc & is_matched,
        "Zto2Q_QQ": ~(flag_bb | flag_cc) & is_matched,  # neither BB nor CC
        "Zto2Q_unmatched": ~is_matched,
    }
    for label, mask in category_masks.items():
        events_combined[year][label] = zjets[mask]

In [None]:
# Bin edges are kept under their own names so that the (low, high) pair lists
# below are always rebuilt from the raw edges; rebinding the same name made the
# cell produce nested tuples when it was run a second time.
txbb_edges = [0.95, 0.975, 0.99, 1.0]  # TXbb edges of the pass regions
pt_edges = [350, 550, 10000]  # pT edges

# Consecutive edges paired into (low, high) bins.
txbb_bins = list(zip(txbb_edges[:-1], txbb_edges[1:]))
pt_bins = list(zip(pt_edges[:-1], pt_edges[1:]))

# Mass window and binning
m_low, m_high = 50, 150
bins = 5  # width of one mass bin (name kept for compatibility)
n_mass_bins = int((m_high - m_low) / bins)

In [None]:
# Build templates for every era and every (TXbb bin, pT bin) combination via
# postprocessing.get_templates, creating the output directories on the way.

# placeholders
bkg_keys = ["Zto2Q_BB", "Zto2Q_CC", "Zto2Q_QQ", "Zto2Q_unmatched", "Wto2Q", "hbb", "qcd", "ttbar"]
sig_keys = ["Zto2Q_BB"]

# "" is the nominal (unshifted) configuration.
jshift_keys = [""]
# TODO: add JES and JER shifts
# for var, ud in itertools.product(["JES", "JER"], ["up", "down"]):
#     jshift_keys.append(f"{var}_{ud}")

for year, events in events_combined.items():

    # NOTE(review): reset for every era, so after the loop only the entries of
    # the last era remain -- confirm this is intended.
    templates = {}

    for txbb_bin, (pt_low, pt_high) in itertools.product(txbb_bins, pt_bins):
        # NOTE(review): templ_dir does not contain the era, yet it is passed as
        # template_dir for every era -- check that outputs do not overwrite.
        templ_dir = Path(f"templates_zbb_pt{pt_low}to{pt_high}/TXbb{txbb_bin[0]}to{txbb_bin[1]}")
        templ_dir.mkdir(parents=True, exist_ok=True)

        cutflows_dir = templ_dir / "cutflows" / year
        cutflows_dir.mkdir(parents=True, exist_ok=True)

        plot_dir = templ_dir / year
        plot_dir.mkdir(parents=True, exist_ok=True)

        for jshift in jshift_keys:
            # Match case-insensitively: the shift keys sketched in the TODO above
            # are upper-case ("JES_up"), which the lower-case checks never matched.
            jshift_tag = jshift.lower()
            if jshift == "":
                mass_branch = "bbFatJetParTmassVis0"
                pt_branch = "bbFatJetPt0"
            elif "jms" in jshift_tag or "jmr" in jshift_tag:
                # NOTE(review): verify these shifted column names exist in `events`.
                mass_branch = f"bbFatJetParTmassVis0_{jshift}"
                pt_branch = "bbFatJetPt0"
            elif "jes" in jshift_tag or "jer" in jshift_tag:
                mass_branch = "bbFatJetParTmassVis0"
                pt_branch = f"bbFatJetPt0_{jshift}"
            else:
                # Fail loudly instead of silently reusing the branches chosen
                # for the previous shift.
                raise ValueError(f"Unrecognised jshift key: {jshift!r}")

            selection_regions = {
                "pass": postprocessing.Region(
                    cuts={
                        # edit these cuts to match your selection
                        pt_branch: [pt_low, pt_high],
                        mass_branch: [m_low, m_high],
                        "bbFatJetParTTXbb0": txbb_bin,
                    },
                    label="pass",
                ),
                "fail": postprocessing.Region(
                    cuts={
                        # edit these cuts to match your selection
                        pt_branch: [pt_low, pt_high],
                        mass_branch: [m_low, m_high],
                        # upper edge capped at 0.9
                        "bbFatJetParTTXbb0": [0.1, min(0.9, txbb_bin[1])],
                    },
                    label="fail",
                ),
            }

            fit_shape_var = postprocessing.ShapeVar(
                mass_branch,
                r"$m_\mathrm{reg}$ (GeV)",
                [n_mass_bins, m_low, m_high],
                reg=True,
            )
            ttemps = postprocessing.get_templates(
                events,
                year=year,
                sig_keys=sig_keys,
                plot_sig_keys=sig_keys,
                selection_regions=selection_regions,
                shape_vars=[fit_shape_var],
                systematics={},
                template_dir=templ_dir,
                bg_keys=bkg_keys,
                plot_dir=plot_dir,
                weight_key="finalWeight",
                weight_shifts={},  # skip systematics for now
                # NOTE(review): the original comment said "skip for time" while
                # the flag is True -- confirm which is intended.
                plot_shifts=True,
                show=False,
                energy=13.6,
                jshift=jshift,
                blind=False,
            )
            # Later results override earlier ones on key collisions.
            templates = {**templates, **ttemps}