In [1]:
import numpy as np
from HH4b import utils
from HH4b import postprocessing
import xgboost as xgb
import importlib
import hist
import os
import mplhep as hep
import matplotlib.pyplot as plt
from HH4b.postprocessing.PostProcess import add_bdt_scores
import HH4b

# Apply mplhep's ROOT-like plotting style to every figure in this notebook.
plt.style.use(hep.style.ROOT)
from HH4b.hh_vars import mreg_strings, txbb_strings
import json

# Installed location of the HH4b package; the BDT model files and TXbb SF JSONs
# used below are loaded from paths relative to it.
package_path = os.path.dirname(HH4b.__file__)

In [2]:
# Re-import edited modules (e.g. the HH4b package) automatically before each cell
# runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# --- Analysis configuration ---------------------------------------------------
# Tagger and BDT model used throughout the notebook.
txbb_version = "glopart-v2"
bdt_model_name = "25Feb5_v13_glopartv2_rawmass"
bdt_config = "v13_glopartv2"
# Category thresholds on (Jet 2 TXbb, BDT score); the VBF category cuts on the
# VBF BDT score instead.
bin1_txbb = 0.945
bin1_bdt = 0.94
bin2_txbb = 0.85
bin2_bdt = 0.755
vbf_txbb = 0.8
vbf_bdt = 0.9825
# Preselection threshold on Jet 1 TXbb.
presel_txbb = 0.3
data_dir = "24Sep25_v12v2_private_signal"
input_dir = f"/ceph/cms/store/user/cmantill/bbbb/skimmer/{data_dir}"

# Histogram binning with 0.001-wide bins. np.linspace includes the endpoint, so the
# last edge is exactly 1.0. np.arange(start, 1, 0.001) stops at 0.999, which sent
# every score in (0.999, 1] to the overflow bin, where it was never drawn.
bdt_axis = hist.axis.Variable(list(np.linspace(0.9, 1, 101)), name="BDT score")
txbb1_axis = hist.axis.Variable(list(np.linspace(0.9, 1, 101)), name=r"Jet 1 $T_{Xbb}$")
txbb2_axis = hist.axis.Variable(list(np.linspace(0, 1, 1001)), name=r"Jet 2 $T_{Xbb}$")

In [8]:
def get_dataframe(events_dict, year, bdt_model_name, bdt_config):
    """Evaluate the BDT on every sample and keep only preselected, triggered events.

    Parameters
    ----------
    events_dict : dict
        Sample name -> events DataFrame (as produced by
        ``postprocessing.load_run3_samples``); per-jet quantities are read as
        ``events[column][jet_index]``.
    year : str
        Key into ``postprocessing.HLTs`` selecting the trigger paths to OR together.
    bdt_model_name : str
        Sub-directory of ``boosted/bdt_trainings_run3`` holding ``trained_bdt.model``.
    bdt_config : str
        Module inside ``HH4b.boosted.bdt_trainings_run3`` whose ``bdt_dataframe``
        builds the BDT input features.

    Returns
    -------
    dict
        Sample name -> DataFrame restricted to events that fire at least one trigger,
        pass the preselection (Jet 1 msd > 40, Jet 1 pT > 300, Jet 2 pT > 250,
        Jet 1 TXbb > ``presel_txbb``) and have both regressed masses in (60, 220).
        Only the columns in ``columns`` below are kept.
    """
    bdt_model = xgb.XGBClassifier()
    bdt_model.load_model(
        fname=f"{package_path}/boosted/bdt_trainings_run3/{bdt_model_name}/trained_bdt.model"
    )
    make_bdt_dataframe = importlib.import_module(
        f".{bdt_config}", package="HH4b.boosted.bdt_trainings_run3"
    )

    # Loop-invariant: tagger-dependent input column names and the output columns.
    mass_col = mreg_strings[txbb_version]
    txbb_col = txbb_strings[txbb_version]
    columns = [
        "bdt_score",
        "bdt_score_vbf",
        "H1TXbb",
        "H2TXbb",
        "H1Msd",
        "H1PNetMass",
        "H2PNetMass",
        "weight",
        "H1Pt",
        "H2Pt",
    ]

    bdt_events_dict = {}
    for key, events in events_dict.items():
        # BDT features, then inference. add_bdt_scores presumably writes the score
        # columns (bdt_score, bdt_score_vbf, ...) into bdt_events -- TODO confirm.
        bdt_events = make_bdt_dataframe.bdt_dataframe(events)
        preds = bdt_model.predict_proba(bdt_events)
        add_bdt_scores(bdt_events, preds)

        # Extra variables. Assign plain arrays (as for the weight) so values are matched
        # by position. Assigning a Series would align on the pandas index and silently
        # produce NaN wherever bdt_events' index differs from events'.
        bdt_events["H1PNetMass"] = events[mass_col][0].to_numpy()
        bdt_events["H2PNetMass"] = events[mass_col][1].to_numpy()
        bdt_events["H1Msd"] = events["bbFatJetMsd"][0].to_numpy()
        bdt_events["H1TXbb"] = events[txbb_col][0].to_numpy()
        bdt_events["H2TXbb"] = events[txbb_col][1].to_numpy()
        bdt_events["weight"] = events["finalWeight"].to_numpy()

        # Logical OR of this year's HLT paths that are present in the sample.
        triggers = [trigger for trigger in postprocessing.HLTs[year] if trigger in events]
        if not triggers:
            import warnings

            # np.any over an empty array is False, so every event would be dropped.
            warnings.warn(
                f"get_dataframe: no {year} HLT path found in sample {key!r}; "
                "all of its events fail the trigger requirement.",
                stacklevel=2,
            )
        bdt_events["hlt"] = np.any(
            np.array([events[trigger][0] for trigger in triggers]),
            axis=0,
        )
        mask_hlt = bdt_events["hlt"] == 1

        # masks
        mask_presel = (
            (bdt_events["H1Msd"] > 40)
            & (bdt_events["H1Pt"] > 300)
            & (bdt_events["H2Pt"] > 250)
            & (bdt_events["H1TXbb"] > presel_txbb)
        )
        mask_mass = (
            (bdt_events["H1PNetMass"] > 60)
            & (bdt_events["H1PNetMass"] < 220)
            & (bdt_events["H2PNetMass"] > 60)
            & (bdt_events["H2PNetMass"] < 220)
        )
        bdt_events_dict[key] = bdt_events.loc[mask_mass & mask_hlt & mask_presel, columns]
    return bdt_events_dict

In [None]:
DATA_SAMPLES = ["JetMET", "Muon", "EGamma"]

# ggF HH->4b signal sample with kl=1, kt=1, c2=0 (as encoded in the sample name).
_HH4B_SIGNAL = "GluGlutoHHto4B_kl-1p00_kt-1p00_c2-0p00_TuneCP5_13p6TeV?"

# The same signal sample plus the three data streams for every Run-3 era.
# Fresh lists are built per era, so the entries do not share list objects.
samples_run3 = {
    era: {
        "hh4b": [_HH4B_SIGNAL],
        "data": [dataset + "_Run" for dataset in DATA_SAMPLES],
    }
    for era in ("2022EE", "2022", "2023", "2023BPix")
}

# Per era: load the skimmed samples, then run BDT inference and preselection.
# Results are keyed by era name.
bdt_events_dict_year = {}
for year in samples_run3.keys():
    events = HH4b.postprocessing.load_run3_samples(
        input_dir=input_dir,
        samples_run3=samples_run3,
        year=year,
        txbb_version=txbb_version,
        mass_str=mreg_strings[txbb_version],
        # Loader options, as named by its keyword arguments.
        reorder_txbb=True,
        scale_and_smear=True,
        load_systematics=True,
    )
    bdt_events_dict_year[year] = get_dataframe(
        events, year, bdt_model_name, bdt_config
    )

In [10]:
# Merge the per-era DataFrames into combined Run-3 samples; ["data"] is presumably
# the list of processes to combine -- TODO confirm against combine_run3_samples.
# NOTE(review): "ttbar" is passed as a background key, but samples_run3 only loads
# "hh4b" and "data"; confirm combine_run3_samples tolerates the missing key.
# `scaled_by` is returned alongside the merged dict and is not used below.
events_combined, scaled_by = postprocessing.combine_run3_samples(
    bdt_events_dict_year,
    ["data"],
    bg_keys=["ttbar"],
    scale_processes={},
    years_run3=bdt_events_dict_year.keys(),
)

In [None]:
# Data check: Jet-2 regressed-mass distribution of VBF-category events outside the
# 110-155 GeV window, excluding the "fail" region (Jet-2 TXbb below bin2_txbb with
# BDT score > 0.03).
events = events_combined["data"]

mask_bin1 = (events["H2TXbb"] > bin1_txbb) & (events["bdt_score"] > bin1_bdt)
passes_vbf_cuts = (events["H2TXbb"] > vbf_txbb) & (events["bdt_score_vbf"] > vbf_bdt)
mask_vbf = passes_vbf_cuts & ~mask_bin1

h2_mass = events["H2PNetMass"]
# Series.between is inclusive on both ends, i.e. [60, 110] or [155, 220].
mask_mass = h2_mass.between(60, 110) | h2_mass.between(155, 220)
mask_fail = (events["H2TXbb"] < bin2_txbb) & (events["bdt_score"] > 0.03)

vbf_sideband = mask_vbf & mask_mass & ~mask_fail
print(np.sum(vbf_sideband))
h2_mass[vbf_sideband].hist(bins=np.arange(60, 230, 10), histtype="step")

In [None]:
labels = {
    "hh4b": "HH(4b)",
}

from HH4b.postprocessing import corrections
from HH4b.hh_vars import txbbsfs_decorr_pt_bins, txbbsfs_decorr_txbb_wps

# Per-event TXbb scale-factor weight of the HH(4b) sample, split by analysis category.
# Categories are mutually exclusive and assigned in priority order:
# ggF category 1 -> VBF -> ggF category 2 -> ggF category 3.

# TXbb and pT ranges passed to corrections.restrict_SF; identical for every year.
txbb_range = [0.8, 1]
pt_range = [200, 1000]

txbb_sf = {}
txbb_sf_weight = {}
mask_bin1 = {}
mask_vbf = {}
mask_bin2 = {}
mask_bin3 = {}
for year in samples_run3:
    events = bdt_events_dict_year[year]["hh4b"]
    nevents = len(events)
    mask_bin1[year] = (events["H2TXbb"] > bin1_txbb) & (events["bdt_score"] > bin1_bdt)
    mask_vbf[year] = (
        (~mask_bin1[year]) & (events["H2TXbb"] > vbf_txbb) & (events["bdt_score_vbf"] > vbf_bdt)
    )
    mask_bin2[year] = (
        (~mask_bin1[year])
        & (~mask_vbf[year])
        & (
            ((events["H2TXbb"] > bin1_txbb) & (events["bdt_score"] > bin2_bdt))
            | ((events["H2TXbb"] > bin2_txbb) & (events["bdt_score"] > bin1_bdt))
        )
    )
    mask_bin3[year] = (
        (~mask_bin1[year])
        & (~mask_vbf[year])
        & (~mask_bin2[year])
        & ((events["H2TXbb"] > bin2_txbb) & (events["bdt_score"] > bin2_bdt))
    )
    txbb_sf[year] = corrections._load_txbb_sfs(
        year,
        "sf_glopart-v2_freezeSFs_trial20241011",
        txbbsfs_decorr_txbb_wps[txbb_version],
        txbbsfs_decorr_pt_bins[txbb_version],
        txbb_version,
    )
    # Event weight = product of the per-jet nominal SFs of jets 1 and 2
    # (as returned by restrict_SF for the ranges above).
    txbb_sf_weight[year] = np.ones(nevents)
    for ijet in [1, 2]:
        txbb_sf_weight[year] *= corrections.restrict_SF(
            txbb_sf[year]["nominal"],
            events[f"H{ijet}TXbb"].to_numpy(),
            events[f"H{ijet}Pt"].to_numpy(),
            txbb_range,
            pt_range,
        )

# The histogram edges span the full plotted x-range [0.5, 2.5] in steps of 0.05.
# np.arange(0.5, 2.0, 0.05) ended at 1.95, so np.histogram silently dropped every
# weight >= 1.95 even though the axis extends to 2.5. Weights outside the edges are
# still not drawn, but they do enter the means quoted in the legend.
bins = np.linspace(0.5, 2.5, 41)
data1 = np.concatenate([txbb_sf_weight[year][mask_bin1[year]] for year in txbb_sf_weight])
datavbf = np.concatenate([txbb_sf_weight[year][mask_vbf[year]] for year in txbb_sf_weight])
data2 = np.concatenate([txbb_sf_weight[year][mask_bin2[year]] for year in txbb_sf_weight])
data3 = np.concatenate([txbb_sf_weight[year][mask_bin3[year]] for year in txbb_sf_weight])

hist1 = np.histogram(data1, bins=bins)
histvbf = np.histogram(datavbf, bins=bins)
hist2 = np.histogram(data2, bins=bins)
hist3 = np.histogram(data3, bins=bins)

plt.figure()
hep.histplot(
    [hist1, histvbf, hist2, hist3],
    stack=True,
    histtype="fill",
    label=[
        f"ggF category 1, mean={np.mean(data1):.2f}",
        f"VBF category, mean={np.mean(datavbf):.2f}",
        f"ggF category 2, mean={np.mean(data2):.2f}",
        f"ggF category 3, mean={np.mean(data3):.2f}",
    ],
)
plt.xlabel("TXbb SF event weight")
plt.ylabel("Events")
plt.xlim(0.5, 2.5)
plt.legend(title="ggF HH(4b)")
plt.tight_layout()
plt.savefig("TXbb_SF_dist.pdf")
plt.show()

In [None]:
labels = {
    "hh4b": "HH (4b)",
}


# For every combined sample, draw:
#   1. the Jet-1 TXbb vs BDT-score plane, inclusive and per ggF category;
#   2. the efficiency of each ggF category relative to preselection, vs Jet-1 TXbb.
# Categories use the same priority order as the TXbb-SF cell
# (ggF 1 -> VBF -> ggF 2 -> ggF 3), so VBF-category events are removed from ggF
# categories 2 and 3.
# NOTE(review): the histograms are filled unweighted (the "weight" column is not
# used). That is fine for data; confirm before reusing this for MC.
for key, events in events_combined.items():
    mask_bin1 = (events["H2TXbb"] > bin1_txbb) & (events["bdt_score"] > bin1_bdt)
    mask_vbf = (
        (~mask_bin1) & (events["H2TXbb"] > vbf_txbb) & (events["bdt_score_vbf"] > vbf_bdt)
    )
    mask_bin2 = (
        (~mask_bin1)
        & (~mask_vbf)
        & (
            ((events["H2TXbb"] > bin1_txbb) & (events["bdt_score"] > bin2_bdt))
            | ((events["H2TXbb"] > bin2_txbb) & (events["bdt_score"] > bin1_bdt))
        )
    )
    mask_bin3 = (
        (~mask_bin1)
        & (~mask_vbf)
        & (~mask_bin2)
        & ((events["H2TXbb"] > bin2_txbb) & (events["bdt_score"] > bin2_bdt))
    )
    ggf_categories = [(1, mask_bin1), (2, mask_bin2), (3, mask_bin3)]

    # Inclusive (preselection) histograms: the 2D plane and the 1D efficiency denominator.
    h_xbb1_bdt = hist.Hist(txbb1_axis, bdt_axis, storage=hist.storage.Weight())
    h_xbb1_bdt.fill(events["H1TXbb"], events["bdt_score"])
    h_xbb1 = hist.Hist(txbb1_axis, storage=hist.storage.Weight())
    h_xbb1.fill(events["H1TXbb"])

    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    hep.hist2dplot(h_xbb1_bdt, ax=ax)
    ax.set_title(key)

    # Per-category 2D planes; the 1D numerators are kept for the efficiency plots.
    h_xbb1_by_category = {}
    for ncat, mask in ggf_categories:
        h_xbb1_bdt_cat = hist.Hist(txbb1_axis, bdt_axis, storage=hist.storage.Weight())
        h_xbb1_bdt_cat.fill(events["H1TXbb"][mask], events["bdt_score"][mask])
        h_xbb1_cat = hist.Hist(txbb1_axis, storage=hist.storage.Weight())
        h_xbb1_cat.fill(events["H1TXbb"][mask])
        h_xbb1_by_category[ncat] = h_xbb1_cat

        fig, ax = plt.subplots(1, 1, figsize=(6, 5))
        hep.hist2dplot(h_xbb1_bdt_cat, ax=ax)
        ax.set_title(f"{key}, ggF category {ncat}")

    # Efficiency of each category with respect to preselection, vs Jet-1 TXbb.
    for ncat, h_xbb1_cat in h_xbb1_by_category.items():
        fig = plt.figure(figsize=(10, 8))
        main_ax_artists, subplot_ax_artists = h_xbb1_cat.plot_ratio(
            h_xbb1,
            rp_ylabel=r"Efficiency",
            rp_num_label=f"ggF Category {ncat}",
            rp_denom_label="Preselection",
            rp_uncert_draw_type="line",  # line or bar
            rp_uncertainty_type="efficiency",
        )

In [None]:
# Jet-1 and Jet-2 TXbb pooled together, zoomed to [0.8, 1], with working-point
# thresholds overlaid as vertical lines reaching the tallest bin.
# NOTE(review): H1TXbb/H2TXbb come from txbb_strings[txbb_version] (glopart-v2 in the
# config cell), and `events` is the last sample of events_combined. The
# "ParticleNet-Legacy" label, these WPs and the "HH(4b)" title may therefore be stale.
legacy_wp_cuts = {
    "WP1": 0.998,
    "WP2": 0.995,
    "WP3": 0.99,
    "WP4": 0.975,
    "WP5": 0.95,
    "WP6": 0.92,
}
plt.figure()
pooled_txbb = np.concatenate((events["H2TXbb"], events["H1TXbb"]))
h, _, _ = plt.hist(
    pooled_txbb,
    bins=np.arange(0.8, 1.002, 0.002),
    histtype="step",
    label="ParticleNet-Legacy",
)
tallest_bin = np.max(h)
for wp_label, cut in legacy_wp_cuts.items():
    plt.plot([cut, cut], [0, tallest_bin], label=wp_label)
plt.legend(title="HH(4b), preselection")
plt.show()

In [None]:
# Jet-1 and Jet-2 TXbb pooled together over [0.3, 1], with the GloParT-v2
# working-point thresholds overlaid as vertical lines reaching the tallest bin.
glopart_wp_cuts = {
    "WP1": 0.99,
    "WP2": 0.97,
    "WP3": 0.94,
    "WP4": 0.9,
    "WP5": 0.8,
}
plt.figure()
pooled_txbb = np.concatenate((events["H2TXbb"], events["H1TXbb"]))
h, _, _ = plt.hist(
    pooled_txbb,
    bins=np.arange(0.3, 1.001, 0.001),
    histtype="step",
    label="GloParT-v2",
)
tallest_bin = np.max(h)
for wp_label, cut in glopart_wp_cuts.items():
    plt.plot([cut, cut], [0, tallest_bin], label=wp_label)
plt.legend(title="HH(4b), preselection")
plt.show()

In [None]:
# Jet-1 TXbb quantiles (16/33/50/84 %) of the events in ggF category 1, using
# `events` and `mask_bin1` as left by the last iteration of the efficiency loop.
category1_h1txbb = events["H1TXbb"][mask_bin1]
np.quantile(category1_h1txbb, q=[0.16, 0.33, 0.5, 0.84])

In [None]:
year = "2023BPix"
txbb_sf_dir = f"{package_path}/corrections/data/txbb_sfs/{year}"


def _load_sf_json(sf_dir, tag):
    """Read ``<sf_dir>/sf_txbbv11_<tag>.json``.

    The encoding is pinned to UTF-8 (the JSON standard's encoding) instead of
    relying on the platform/locale default used by a bare ``open``.
    """
    with open(f"{sf_dir}/sf_txbbv11_{tag}.json", encoding="utf-8") as f:
        return json.load(f)


# Successive txbbv11 SF derivations for this year, tagged by date and variant.
txbb_sf_old = _load_sf_json(txbb_sf_dir, "Jun14")
txbb_sf = _load_sf_json(txbb_sf_dir, "Jun26_freezeSFs")
txbb_sf_fine = _load_sf_json(txbb_sf_dir, "Jun27_freezeSFs_finerWPs")
txbb_sf_zoom = _load_sf_json(txbb_sf_dir, "Jun29_freezeSFs_zoomedInWPs")
txbb_sf_new_pt = _load_sf_json(txbb_sf_dir, "Jul3_freezeSFs_newPt")

In [None]:
# Assemble one SF dictionary from the separate derivations:
#   WP3       <- zoomed-in derivation, its WP2     (fine pT binning `ptbins`)
#   WP4..WP6  <- finer-WP derivation, its WP2..WP4 (fine pT binning `ptbins`)
#   WP1, WP2  <- new-pT derivation                 (coarse pT binning `ptbins_new`)
ptbins = np.array([200, 250, 300, 400, 500, 100000])
ptbins_new = np.array([200, 400, 100000])

# (target WP, source dictionary, source WP) for the fine-pT working points.
fine_pt_sources = [
    ("WP3", txbb_sf_zoom, "WP2"),
    ("WP4", txbb_sf_fine, "WP2"),
    ("WP5", txbb_sf_fine, "WP3"),
    ("WP6", txbb_sf_fine, "WP4"),
]

txbb_sf_new = {}
for pt_lo, pt_hi in zip(ptbins[:-1], ptbins[1:]):
    for target_wp, source, source_wp in fine_pt_sources:
        txbb_sf_new[f"{target_wp}_pt{pt_lo}to{pt_hi}"] = source[f"{source_wp}_pt{pt_lo}to{pt_hi}"]

for pt_lo, pt_hi in zip(ptbins_new[:-1], ptbins_new[1:]):
    for wp_name in ("WP1", "WP2"):
        txbb_sf_new[f"{wp_name}_pt{pt_lo}to{pt_hi}"] = txbb_sf_new_pt[
            f"{wp_name}_pt{pt_lo}to{pt_hi}"
        ]

# Uncomment to write the combined dictionary back to the corrections area:
# with open(
#     f"{package_path}/corrections/data/txbb_sfs/{year}/sf_txbbv11_Jul3_freezeSFs_combinedWPs.json",
#     "w",
# ) as f:
#     json.dump(txbb_sf_new, f, indent=4)

In [None]:
# SF vs jet pT for every working point of the combined SF dictionary.
# WP1-WP2 are defined on the coarse pT binning (ptbins_new), WP3-WP6 on the fine one
# (ptbins). Each WP is shifted by 5 GeV in x so the error bars do not overlap.
wps = {
    "WP1": [0.998, 1],
    "WP2": [0.995, 0.998],
    "WP3": [0.99, 0.995],
    "WP4": [0.975, 0.99],
    "WP5": [0.95, 0.975],
    "WP6": [0.92, 0.95],
}
# WP index (0 -> WP1, ..., 5 -> WP6) -> pT bin edges its SFs are defined on.
wp_pt_edges = {0: ptbins_new, 1: ptbins_new, 2: ptbins, 3: ptbins, 4: ptbins, 5: ptbins}


def _final_sf_series(sf_dict, wp_index, pt_edges):
    """Return (central, low, high) lists of the "final" SF of WP``wp_index+1`` per pT bin."""
    finals = [
        sf_dict[f"WP{wp_index+1}_pt{pt_lo}to{pt_hi}"]["final"]
        for pt_lo, pt_hi in zip(pt_edges[:-1], pt_edges[1:])
    ]
    return (
        [sf["central"] for sf in finals],
        [sf["low"] for sf in finals],
        [sf["high"] for sf in finals],
    )


y_new, yerr_low_new, yerr_high_new = {}, {}, {}
for wp, pt_edges in wp_pt_edges.items():
    y_new[wp], yerr_low_new[wp], yerr_high_new[wp] = _final_sf_series(txbb_sf_new, wp, pt_edges)

plt.figure()
# Reference line at SF = 1, plus a vertical line at each fine pT-bin lower edge.
plt.axhline(1, color="gray", linestyle="--", alpha=0.5)
for ptbin in ptbins[:-1]:
    plt.axvline(ptbin, color="gray", linestyle="-", alpha=0.5)
for wp, pt_edges in wp_pt_edges.items():
    wp_name = f"WP{wp+1}"
    plt.errorbar(
        y=y_new[wp],
        x=pt_edges[:-1] + (wp + 1) * 5,
        yerr=[yerr_low_new[wp], yerr_high_new[wp]],
        fmt="o",
        label=f"{wp_name} {wps[wp_name]}",
    )
plt.xlabel("$p_T (j)$ [GeV]")
plt.ylabel("SF (flvB)")
plt.ylim([0, 2])
plt.legend(title=year)
plt.savefig(f"new_txbb_sf_{year}.pdf")

In [None]:
# SF vs TXbb, one panel per fine pT bin.
# WP6..WP3 use the fine pT binning (ptbins). WP2 and WP1 exist only on the coarse
# binning (ptbins_new), so the coarse bin containing the panel's pT bin is used and
# drawn as a separate series. The WP1 band is extended to TXbb = 1, widening to three
# times its uncertainty there.
wps = {
    "WP1": [0.998, 1],
    "WP2": [0.995, 0.998],
    "WP3": [0.99, 0.995],
    "WP4": [0.975, 0.99],
    "WP5": [0.95, 0.975],
    "WP6": [0.92, 0.95],
}
# Working points in increasing-TXbb order, the order in which values are stored below.
wp_order = ["WP6", "WP5", "WP4", "WP3", "WP2", "WP1"]

y_new, yerr_low_new, yerr_high_new = [], [], []
for i in range(len(ptbins) - 1):
    y_new.append([])
    yerr_low_new.append([])
    yerr_high_new.append([])
    # WP6 -> WP3 from the fine pT bin i.
    for wp in reversed(range(2, 6)):
        sf_final = txbb_sf_new[f"WP{wp+1}_pt{ptbins[i]}to{ptbins[i+1]}"]["final"]
        y_new[i].append(sf_final["central"])
        yerr_low_new[i].append(sf_final["low"])
        yerr_high_new[i].append(sf_final["high"])
    # WP2 -> WP1 from the coarse pT bin that fully contains fine bin i.
    for wp in reversed(range(0, 2)):
        for j in range(len(ptbins_new) - 1):
            if ptbins[i] >= ptbins_new[j] and ptbins[i + 1] <= ptbins_new[j + 1]:
                sf_final = txbb_sf_new[f"WP{wp+1}_pt{ptbins_new[j]}to{ptbins_new[j+1]}"]["final"]
                y_new[i].append(sf_final["central"])
                yerr_low_new[i].append(sf_final["low"])
                yerr_high_new[i].append(sf_final["high"])
                break

# TXbb bin centres and half-widths of each WP (identical for every panel).
x = [np.mean(wps[wp]) for wp in wp_order]
xerr = [(wps[wp][1] - wps[wp][0]) / 2 for wp in wp_order]
# Lower TXbb edges of WP6..WP1, plus 1 for the end of the extrapolated WP1 band.
extended_wps = np.array([0.92, 0.95, 0.975, 0.99, 0.995, 0.998, 1])

n_pt_bins = len(ptbins) - 1
# squeeze=False keeps `axs` 2-D even when there is a single pT bin.
fig, axs = plt.subplots(n_pt_bins, 1, figsize=(8, 6 * n_pt_bins), squeeze=False)
for i in range(n_pt_bins):
    plt.sca(axs[i, 0])
    plt.axhline(1, color="gray", linestyle="--", alpha=0.5)
    # Fine-pT series: WP6..WP3 only. WP2 belongs to the coarse-pT series below;
    # slicing with [:-1] here drew it twice, once under the wrong pT label.
    plt.errorbar(
        y=y_new[i][:-2],
        x=x[:-2],
        xerr=xerr[:-2],
        yerr=[yerr_low_new[i][:-2], yerr_high_new[i][:-2]],
        fmt="o",
        label=f"$[{ptbins[i]}, {ptbins[i+1]}]$ GeV".replace("100000", "\\inf"),
    )
    # Index of the coarse pT bin that contains this fine bin.
    j = int(np.searchsorted(ptbins_new, ptbins[i], side="right")) - 1
    plt.errorbar(
        y=y_new[i][-2:],
        x=x[-2:],
        xerr=xerr[-2:],
        yerr=[yerr_low_new[i][-2:], yerr_high_new[i][-2:]],
        fmt="o",
        label=f"$[{ptbins_new[j]}, {ptbins_new[j+1]}]$ GeV".replace("100000", "\\inf"),
    )
    # Repeat WP1 at TXbb = 1 with 3x its uncertainty for the extrapolated band.
    extended_y = np.array(y_new[i] + [y_new[i][-1]])
    extended_yerr_low = np.array(yerr_low_new[i] + [yerr_low_new[i][-1] * 3])
    extended_yerr_high = np.array(yerr_high_new[i] + [yerr_high_new[i][-1] * 3])
    # Step band over WP6..WP3 (fine-pT series).
    plt.fill_between(
        extended_wps[:-2],
        extended_y[:-2] - extended_yerr_low[:-2],
        extended_y[:-2] + extended_yerr_high[:-2],
        alpha=0.2,
        step="post",
    )
    # Flat WP2 band over [0.995, 0.998] (default fill colour).
    plt.fill_between(
        extended_wps[-3:-1],
        extended_y[-3] - extended_yerr_low[-3],
        extended_y[-3] + extended_yerr_high[-3],
        alpha=0.2,
    )
    # WP1 band over [0.998, 1], widening to 3x the uncertainty at TXbb = 1.
    # "#f89c20" is presumably the coarse-series colour in this plotting style.
    plt.fill_between(
        extended_wps[-2:],
        extended_y[-2:] - extended_yerr_low[-2:],
        extended_y[-2:] + extended_yerr_high[-2:],
        alpha=0.2,
        color="#f89c20",
    )
    plt.xlabel("$T_{Xbb}$")
    plt.ylabel("SF (flvB)")
    plt.ylim([0, 2])
    plt.xlim([0.92, 1])
    plt.legend(title=year)
plt.savefig(f"new_txbb_sf_3x_{year}.pdf")

In [None]:
# SF vs jet pT for the earlier derivation stored in `txbb_sf` (three wide WPs on the
# fine pT binning). Each WP is shifted by 5 GeV in x so the error bars do not overlap.
wps = {
    "WP1": [0.975, 1],
    "WP2": [0.95, 0.975],
    "WP3": [0.92, 0.95],
}

y_old, yerr_low_old, yerr_high_old = [], [], []
for wp in range(len(wps)):
    final_sfs = [
        txbb_sf[f"WP{wp+1}_pt{pt_lo}to{pt_hi}"]["final"]
        for pt_lo, pt_hi in zip(ptbins[:-1], ptbins[1:])
    ]
    y_old.append([sf["central"] for sf in final_sfs])
    yerr_low_old.append([sf["low"] for sf in final_sfs])
    yerr_high_old.append([sf["high"] for sf in final_sfs])

plt.figure()
# Reference line at SF = 1, plus a vertical line at each pT-bin lower edge.
plt.axhline(1, color="gray", linestyle="--", alpha=0.5)
for ptbin in ptbins[:-1]:
    plt.axvline(ptbin, color="gray", linestyle="-", alpha=0.5)
for wp, (central, err_low, err_high) in enumerate(zip(y_old, yerr_low_old, yerr_high_old)):
    plt.errorbar(
        y=central,
        x=ptbins[:-1] + (wp + 3) * 5,
        yerr=[err_low, err_high],
        fmt="o",
        label=f"WP{wp+1} {wps[f'WP{wp+1}']}",
    )
plt.xlabel("$p_T (j)$ [GeV]")
plt.ylabel("SF (flvB)")
plt.ylim([0, 2])
plt.legend(title=year)
plt.savefig(f"old_txbb_sf_{year}.pdf")