In [None]:
# automatically reloads imported files on edits
from __future__ import annotations

%load_ext autoreload
%autoreload 2

In [None]:
import math
import os
from copy import deepcopy

import numpy as np
import pandas as pd

from HH4b import utils
from HH4b.matching_study import predict_spanet_hhh

In [None]:
import hist
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import mplhep as hep
from matplotlib.lines import Line2D

# Notebook-wide plot styling: CMS style from mplhep plus a handful of matplotlib defaults.
hep.style.use(["CMS", "firamath"])

# Math-text scalar tick formatter with power limits (-3, 3).
formatter = mticker.ScalarFormatter(useMathText=True)
formatter.set_powerlimits((-3, 3))

# All rcParams overrides in a single update call.
plt.rcParams.update(
    {
        "font.size": 12,
        "lines.linewidth": 2,
        "grid.color": "#CCCCCC",
        "grid.linewidth": 0.5,
        "figure.edgecolor": "none",
    }
)

In [None]:
# Locations of the input ntuples and the data-taking year to process.
MAIN_DIR = "../../../"
path_to_dir = f"{MAIN_DIR}/../data/matching/23Nov18_WSel_v9_private/"
year = "2018"

# make plot and template directory
date = "23Nov17"
plot_dir = f"{MAIN_DIR}/plots/PostProcessing/{date}/{year}"
template_dir = f"templates/{date}/"
# os.makedirs instead of shelling out to `mkdir -p`: no shell is involved and failures raise
# instead of being silently ignored.
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(f"{template_dir}/cutflows/{year}", exist_ok=True)

from HH4b.hh_vars import samples

# Work on a copy so the imported per-year sample table is never modified.
samples = deepcopy(samples[year])
samples = {key: samples[key] for key in ["hh4b", "qcd", "ttbar", "vhtobb"]}
# only use hadronic
samples["ttbar"] = ["TTToHadronic_13TeV"]
sample_dirs = {path_to_dir: samples}

filters = None
events_dict = {}
# The loop variable is deliberately not called `samples`, so the dict built above is not rebound.
for input_dir, dir_samples in sample_dirs.items():
    # this function will load files (only the columns selected), apply filters and compute a weight per event
    events_dict.update(
        utils.load_samples(
            input_dir,
            dir_samples,
            year,
            filters=filters,
        )
    )

In [None]:
import onnxruntime

# Shared ONNX Runtime options (also reused by the second session created in a later cell).
sess_options = onnxruntime.SessionOptions()
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
sess_options.intra_op_num_threads = 23

# First ONNX model; its outputs are decoded by get_values_from_output / get_pairs_hhh below.
session = onnxruntime.InferenceSession(
    f"{MAIN_DIR}/../data/spanet-inference/spanet_pnet_all_vars_v0.onnx",
    sess_options,
)

# Names of every output node (passed to session.run) plus the first input/output name.
output_nodes = session.get_outputs()
output_names = [node.name for node in output_nodes]
input_name = session.get_inputs()[0].name
output_name = output_nodes[0].name

In [None]:
# Second ONNX model, sharing sess_options with the first session. Its outputs are decoded with
# predict_spanet_hhh.get_values_from_assignment_output in the inference cells below.
session_assignment = onnxruntime.InferenceSession(
    f"{MAIN_DIR}/../data/spanet-inference/spanet_categorisation_v6.onnx", sess_options
)
# Names of all output nodes, in the order they are requested from session_assignment.run.
output_nodes_assignment = session_assignment.get_outputs()
output_names_assignment = [node.name for node in output_nodes_assignment]

In [None]:
# Result caches, keyed by configuration name (e.g. "hh4b", "hh4b-pt300").
# pairs_dict: output of predict_spanet_hhh.get_pairs_hhh (has higgs_1_mass / higgs_2_mass columns).
# probs_dict: output of predict_spanet_hhh.get_values_from_output.
pairs_dict = {}
probs_dict = {}

In [None]:
# The two values returned by predict_spanet_hhh.get_values_from_assignment_output.
assign_probs_dict = {}
max_probs_dict = {}

In [None]:
key = "vhtobb-pt300-m50"
events_key = "vhtobb"
events = events_dict[events_key]
input_dict = predict_spanet_hhh.build_inputs(events, MIN_FJPT=300, MIN_FJMASS=50)
output_values = session.run(output_names, input_dict)
probs_dict[key] = predict_spanet_hhh.get_values_from_output(output_values)
pairs_dict[key] = predict_spanet_hhh.get_pairs_hhh(output_values, events)
output_values_assignment = session_assignment.run(output_names_assignment, input_dict)
assign_probs_dict[key], max_probs_dict[key] = predict_spanet_hhh.get_values_from_assignment_output(
    output_values_assignment
)

In [None]:
key = "vhtobb"
events_key = "vhtobb"
events = events_dict[events_key]
input_dict = predict_spanet_hhh.build_inputs(events)
output_values = session.run(output_names, input_dict)
probs_dict[key] = predict_spanet_hhh.get_values_from_output(output_values)
pairs_dict[key] = predict_spanet_hhh.get_pairs_hhh(output_values, events)
output_values_assignment = session_assignment.run(output_names_assignment, input_dict)
assign_probs_dict[key], max_probs_dict[key] = predict_spanet_hhh.get_values_from_assignment_output(
    output_values_assignment
)

In [None]:
key = "hh4b"
events_key = "hh4b"
events = events_dict[events_key]
input_dict = predict_spanet_hhh.build_inputs(events, MIN_FJPT=200)
output_values = session.run(output_names, input_dict)
probs_dict[key] = predict_spanet_hhh.get_values_from_output(output_values)
pairs_dict[key] = predict_spanet_hhh.get_pairs_hhh(output_values, events)
output_values_assignment = session_assignment.run(output_names_assignment, input_dict)
assign_probs_dict[key], max_probs_dict[key] = predict_spanet_hhh.get_values_from_assignment_output(
    output_values_assignment
)

In [None]:
key = "hh4b-pt300"
events_key = "hh4b"
events = events_dict[events_key]
input_dict = predict_spanet_hhh.build_inputs(events, MIN_FJPT=300)
output_values = session.run(output_names, input_dict)
probs_dict[key] = predict_spanet_hhh.get_values_from_output(output_values)
pairs_dict[key] = predict_spanet_hhh.get_pairs_hhh(output_values, events)
output_values_assignment = session_assignment.run(output_names_assignment, input_dict)
assign_probs_dict[key], max_probs_dict[key] = predict_spanet_hhh.get_values_from_assignment_output(
    output_values_assignment
)

In [None]:
key = "hh4b-pt300-m50"
events_key = "hh4b"
events = events_dict[events_key]
input_dict = predict_spanet_hhh.build_inputs(events, MIN_FJPT=300, MIN_FJMASS=50)
output_values = session.run(output_names, input_dict)
probs_dict[key] = predict_spanet_hhh.get_values_from_output(output_values)
pairs_dict[key] = predict_spanet_hhh.get_pairs_hhh(output_values, events)
output_values_assignment = session_assignment.run(output_names_assignment, input_dict)
assign_probs_dict[key], max_probs_dict[key] = predict_spanet_hhh.get_values_from_assignment_output(
    output_values_assignment
)

In [None]:
# takes 13 min!
# QCD inference is split over three cells so the slow session.run does not have to be repeated
# when only the decoding steps change.
key = "qcd"
input_dict = predict_spanet_hhh.build_inputs(events_dict[key])
output_values = session.run(output_names, input_dict)

In [None]:
# Decode the outputs of the first model computed in the previous cell.
key = "qcd"
probs = predict_spanet_hhh.get_values_from_output(output_values)
pairs_pd = predict_spanet_hhh.get_pairs_hhh(output_values, events_dict[key])

pairs_dict[key] = pairs_pd
probs_dict[key] = probs

In [None]:
# Run the second model on QCD.
# NOTE(review): input_dict is rebuilt here although the first QCD cell already built it from
# the same events; this keeps the cell runnable on its own at the cost of repeating build_inputs.
key = "qcd"
input_dict = predict_spanet_hhh.build_inputs(events_dict[key])
output_values_assignment = session_assignment.run(output_names_assignment, input_dict)
assign_probs_dict[key], max_probs_dict[key] = predict_spanet_hhh.get_values_from_assignment_output(
    output_values_assignment
)

In [None]:
# save to csv
odir = "Nov27"
spanet_out_dir = f"{MAIN_DIR}/../data/spanet/{odir}"
# os.makedirs instead of os.system("mkdir -p ..."): no shell, and failures raise.
os.makedirs(spanet_out_dir, exist_ok=True)
for key in ["qcd", "hh4b-pt300"]:
    pairs_dict[key].to_csv(f"{spanet_out_dir}/pairs_spanet_hhh_{key}.csv")

In [None]:
import pickle

odir = "Nov27"
spanet_out_dir = f"{MAIN_DIR}/../data/spanet/{odir}"
# Create the output directory here as well so this cell does not depend on the csv cell.
os.makedirs(spanet_out_dir, exist_ok=True)

# File-name prefix -> result cache; one pickle is written per prefix and sample.
pickled_results = {
    "probs": probs_dict,
    "assign_probs": assign_probs_dict,
    "max_assign_probs": max_probs_dict,
}

# The "Load inference from multiple samples" section reads these files for vhtobb, hh4b, qcd
# and ttbar. Previously only hh4b was written here, so loading vhtobb/qcd failed on a fresh
# run. ttbar is written by its own cells further down.
for key in ["hh4b", "vhtobb", "qcd"]:
    pairs_dict[key].to_csv(f"{spanet_out_dir}/pairs_spanet_hhh_{key}.csv")
    for prefix, results in pickled_results.items():
        with open(f"{spanet_out_dir}/{prefix}_{key}.pkl", "wb") as f:
            pickle.dump(results[key], f)

In [None]:
def slice_it(li, cols=2):
    """Yield ``cols`` consecutive slices of ``li`` that together cover it in order.

    The size of the k-th slice equals ``len(li[k::cols])``, so slice sizes differ by at most
    one and the earlier slices receive the extra items. ``li`` only needs to support ``len``
    and slicing (lists, ranges, strings, ...); each yielded piece is ``li[lo:hi]``.
    """
    chunk_sizes = [len(li[offset::cols]) for offset in range(cols)]
    lower = 0
    for size in chunk_sizes:
        upper = lower + size
        yield li[lower:upper]
        lower = upper


key = "ttbar"
# do this in batches
nevents = len(events_dict[key])
N_BATCHES = 15  # number of contiguous chunks the sample is split into
MAX_BATCHES = 1  # only the first chunk is actually processed

# One entry per processed batch. The cells below read element [0] of each list, so all four
# must exist and be filled. Previously two of them were commented out and the loop body was
# disabled inside a string literal, which made the following cells fail.
pairs_pd_all = []
probs_all = []
assign_probs_all = []
max_probs_all = []
for j, i in enumerate(slice_it(range(nevents), N_BATCHES)):
    if j >= MAX_BATCHES:
        break
    print(i)
    # NOTE(review): the slice stop is exclusive, so the event at index i[-1] (the last one of
    # the batch) is not included. Left unchanged so batch sizes stay consistent with outputs
    # that were already produced and saved with this slicing - confirm before changing.
    events = events_dict[key][i[0] : i[-1]]

    input_dict = predict_spanet_hhh.build_inputs(events)
    output_values = session.run(output_names, input_dict)
    probs = predict_spanet_hhh.get_values_from_output(output_values)
    pairs_pd = predict_spanet_hhh.get_pairs_hhh(output_values, events)
    output_values_assignment = session_assignment.run(output_names_assignment, input_dict)
    assign_probs, max_probs = predict_spanet_hhh.get_values_from_assignment_output(
        output_values_assignment
    )

    probs_all.append(probs)
    pairs_pd_all.append(pairs_pd)
    assign_probs_all.append(assign_probs)
    max_probs_all.append(max_probs)

In [None]:
# Register the first (and only processed) ttbar batch under the "ttbar" key.
probs_dict["ttbar"] = probs_all[0]
pairs_dict["ttbar"] = pairs_pd_all[0]

import pickle

odir = "Nov27"
spanet_out_dir = f"{MAIN_DIR}/../data/spanet/{odir}"
# os.makedirs instead of os.system("mkdir -p ..."): no shell, and failures raise.
os.makedirs(spanet_out_dir, exist_ok=True)
for key in ["ttbar"]:
    pairs_dict[key].to_csv(f"{spanet_out_dir}/pairs_spanet_hhh_{key}.csv")
    with open(f"{spanet_out_dir}/probs_{key}.pkl", "wb") as f:
        pickle.dump(probs_dict[key], f)

In [None]:
# Register the first ttbar batch of the second model's outputs and pickle them.
# NOTE(review): relies on `odir` and `pickle` defined in the previous cell.
assign_probs_dict["ttbar"] = assign_probs_all[0]
max_probs_dict["ttbar"] = max_probs_all[0]

for key in ["ttbar"]:
    with open(f"{MAIN_DIR}/../data/spanet/{odir}/assign_probs_{key}.pkl", "wb") as f:
        pickle.dump(assign_probs_dict[key], f)
    with open(f"{MAIN_DIR}/../data/spanet/{odir}/max_assign_probs_{key}.pkl", "wb") as f:
        pickle.dump(max_probs_dict[key], f)

# Compare mass reconstruction by pT cut

In [None]:
# Histogram axes shared by the plotting cells below.
# 40 bins from 0 to 250 for the reconstructed Higgs candidate mass.
spanet_higgsmass_axis = hist.axis.Regular(40, 0, 250, name="mass", label="SPANET Higgs mass")
# Growing string categories: which Higgs candidate / which sample an entry belongs to.
higgs_axis = hist.axis.StrCategory([], name="higgs", growth=True)
sample_axis = hist.axis.StrCategory([], name="sample", growth=True)
# Integer categories 0-9; filled with max_probs_dict values in the cells below.
probs_axis = hist.axis.IntCategory(
    range(10), name="prob", label="SPANET max probability", growth=True
)
# Integer categories 0-3; filled with the GEN-matching category index (cat_matched).
cat_axis = hist.axis.IntCategory(range(4), name="cat", label="GEN matching", growth=True)
# 40 bins on [0, 1] for probability-like scores, plus a growing class-name category.
spanet_discr_axis = hist.axis.Regular(40, 0, 1, name="discr", label="SPANET Prob Discriminator")
class_axis = hist.axis.StrCategory([], name="class", growth=True)

In [None]:
max_probs_dict["hh4b"]

In [None]:
max_probs_dict["hh4b-pt300"]

In [None]:
max_probs_dict["hh4b-pt300"]

In [None]:
# Higgs candidate masses for two hh4b configurations, split by the value stored in
# max_probs_dict (the "prob" axis).
h = hist.Hist(spanet_higgsmass_axis, higgs_axis, probs_axis, sample_axis)
for key in [
    "hh4b",
    # "hh4b-pt300",
    "hh4b-pt300-m50",
]:
    h.fill(pairs_dict[key]["higgs_1_mass"], "h1", max_probs_dict[key], key)
    h.fill(pairs_dict[key]["higgs_2_mass"], "h2", max_probs_dict[key], key)

# Line style distinguishes the two Higgs candidates within a panel.
linestyle_by_class = {
    "h1": "solid",
    "h2": "dashed",
}
# Colour per "prob" category; only categories 5 and 6 are drawn.
color_by_class = {
    5: {
        "h1": "green",
        "h2": "green",
    },
    6: {
        "h1": "b",
        "h2": "b",
    },
}
# Display names of the plotted "prob" categories (same names as entries 5 and 6 of probs_str
# used in the category histogram cells).
probs_name = {5: "2bh0h", 6: "1bh1h"}
titles = {
    "hh4b": r"FatJet p$_T$>200",
    "hh4b-pt300": r"FatJet p$_T$>300",
    "hh4b-pt300-m50": r"FatJet p$_T$>300 m>50",
}
# Rows: configuration, columns: "prob" category.
fig, ax = plt.subplots(2, 2, figsize=(12, 5), sharex=True)
for i, sample in enumerate(["hh4b", "hh4b-pt300-m50"]):
    for j, prob in enumerate(color_by_class.keys()):
        legend_elements = []
        for key in linestyle_by_class:
            hep.histplot(
                h[{"higgs": key, "prob": prob, "sample": sample}],
                density=True,
                lw=2,
                ax=ax[i, j],
                ls=linestyle_by_class[key],
                color=color_by_class[prob][key],
            )
            # Proxy artist so the legend shows the same line style and colour.
            legend_elements.append(
                Line2D(
                    [0],
                    [0],
                    ls=linestyle_by_class[key],
                    lw=2,
                    label=key,
                    color=color_by_class[prob][key],
                )
            )
        ax[i, j].legend(handles=legend_elements, title=f"{probs_name[prob]} {titles[sample]}")
        ax[i, j].set_ylabel("Density")
        ax[i, j].set_xlabel("")
# Only the bottom row carries the x-axis label.
ax[1, 0].set_xlabel("SPANet Higgs mass")
ax[1, 1].set_xlabel("SPANet Higgs mass")

In [None]:
# Tick labels for the ten integer categories of probs_axis, in index order.
probs_str = [
    "0bh0h",
    "3bh0h",
    "2bh1h",
    "1bh2h",
    "0bh3h",
    "2bh0h",
    "1bh1h",
    "0bh2h",
    "1bh0h",
    "0bh1h",
]

# Event counts per category for the hh4b-pt300 configuration.
h = hist.Hist(probs_axis)
h.fill(max_probs_dict["hh4b-pt300"])

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
hep.histplot(h, lw=2)
ax.set_yscale("log")
ax.set_ylabel("Events")
ax.set_title("HH4b sample")
# Put one label at the centre of each unit-width category bin.
xticks = [i + 0.5 for i in range(10)]
ax.set_xticks(xticks, probs_str, size="small", rotation="vertical")

In [None]:
# Dedicated axis for this plot. The shared `spanet_discr_axis` used to be overwritten here,
# which silently relabelled the class-probability plots further down as "SPANET Assignment".
spanet_assign_axis = hist.axis.Regular(40, 0, 1, name="discr", label="SPANET Assignment")
class_axis = hist.axis.StrCategory([], name="class", growth=True)
h = hist.Hist(spanet_assign_axis, class_axis)
# One "class" category per entry of the hh4b assignment-probability dict.
for key, values in assign_probs_dict["hh4b"].items():
    h.fill(values, key)

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
legend_elements = []
# Colour per category; several categories intentionally share a colour.
color_by_prob = {
    "3bh0h": "grey",
    "2bh1h": "grey",
    "1bh2h": "grey",
    "0bh3h": "grey",
    "2bh0h": "blue",
    "1bh1h": "red",
    "0bh2h": "green",
    "1bh0h": "black",
    "0bh1h": "black",
    "0bh0h": "black",
}
for key, color in color_by_prob.items():
    hep.histplot(
        h[{"class": key}],
        density=True,
        lw=2,
        color=color,
        ax=ax,
    )
    # Proxy artist for the legend.
    legend_elements.append(Line2D([0], [0], color=color, lw=2, label=key))
ax.legend(handles=legend_elements)
ax.set_ylabel("Density")
ax.set_yscale("log")
ax.set_title("HH4b sample")

# Confusion matrix

- First, we need to define each of the categories above.

In [None]:
# Build per-event boolean masks for the GEN-matching categories of the hh4b sample.
events = events_dict["hh4b"]

# Per-jet match columns converted to 2D arrays (the masks below are summed over axis=1).
# NOTE(review): presumably index 0 / 1 means "matched to the first / second Higgs" - confirm.
indexak8 = events["ak8FatJetHiggsMatchIndex"].to_numpy()
indexak4 = events["ak4JetHiggsMatchIndex"].to_numpy()
# indexak4 = events["ak4JetOutsideHiggsMatchIndex"].to_numpy()
nbh1ak8 = events["ak8FatJetNumBMatchedH1"].to_numpy()
nbh2ak8 = events["ak8FatJetNumBMatchedH2"].to_numpy()

# Fat jets with match index 0 (1) and NumBMatchedH1 (H2) equal to 2 ...
h1ak8nb2 = (indexak8 == 0) & (nbh1ak8 == 2)
h2ak8nb2 = (indexak8 == 1) & (nbh2ak8 == 2)
# ... and events in which exactly one fat jet satisfies that.
h1m1ak8b2 = h1ak8nb2.sum(axis=1) == 1
h2m1ak8b2 = h2ak8nb2.sum(axis=1) == 1

# Same selection with NumBMatched equal to 1.
h1ak8nb1 = (indexak8 == 0) & (nbh1ak8 == 1)
h2ak8nb1 = (indexak8 == 1) & (nbh2ak8 == 1)
h1m1ak8b1 = h1ak8nb1.sum(axis=1) == 1
h2m1ak8b1 = h2ak8nb1.sum(axis=1) == 1

# Events with exactly two AK4 jets carrying match index 0 (1).
h1ak4 = indexak4 == 0
h2ak4 = indexak4 == 1
num_ak4m2h1 = h1ak4.sum(axis=1)
num_ak4m2h2 = h2ak4.sum(axis=1)
h1m2ak4 = num_ak4m2h1 == 2
h2m2ak4 = num_ak4m2h2 == 2

# Plain per-fat-jet match masks (not used in the category definitions below).
h1ak8 = indexak8 == 0
h2ak8 = indexak8 == 1

# Categories are made mutually exclusive by vetoing the previously defined ones.
# 2bh0h: both index-0 and index-1 selections have exactly one fat jet with 2 matched b's.
gen_2bh0h = h1m1ak8b2 & h2m1ak8b2
# 1bh1h: one side has the 2-b fat jet, the other side has a 1-b fat jet or two matched AK4 jets.
gen_1bh1h = ((h1m1ak8b2 & (h2m1ak8b1 | h2m2ak4)) | (h2m1ak8b2 & (h1m1ak8b1 | h1m2ak4))) & ~(
    gen_2bh0h
)
# 0bh2h: both sides have exactly two matched AK4 jets.
gen_0bh2h = (h1m2ak4 & h2m2ak4) & ~(gen_2bh0h) & ~(gen_1bh1h)
# Everything else.
gen_others = ~(gen_2bh0h | gen_1bh1h | gen_0bh2h)

In [None]:
# One row per event, one boolean column per category (2bh0h, 1bh1h, 0bh2h, others).
np.stack([gen_2bh0h, gen_1bh1h, gen_0bh2h, gen_others], axis=1)

In [None]:
# Category index per event: argmax returns the position of the first True column.
cat_matched = np.argmax(np.stack([gen_2bh0h, gen_1bh1h, gen_0bh2h, gen_others], axis=1), axis=1)

In [None]:
# Tick labels for the four GEN-matching category indices stored in cat_matched.
cat_str = ["2bh0h", "1bh1h", "0bh2h", "other"]

# Event counts per GEN-matching category.
h = hist.Hist(cat_axis)
h.fill(cat_matched)

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
hep.histplot(h, lw=2)
ax.set_ylabel("Events")
ax.set_title("HH4b sample")
# One label at the centre of each unit-width category bin.
xticks = [i + 0.5 for i in range(4)]
ax.set_xticks(xticks, cat_str, size="small", rotation="vertical")

In [None]:
# 2D histogram of GEN-matching category (x) against the max_probs_dict value of the
# hh4b-pt300-m50 configuration (y), normalised to the total number of entries.
h = hist.Hist(cat_axis, probs_axis)
h.fill(cat_matched, max_probs_dict["hh4b-pt300-m50"])

# Tick labels, in category-index order.
cat_str = ["2bh0h", "1bh1h", "0bh2h", "other"]
probs_str = [
    "0bh0h",
    "3bh0h",
    "2bh1h",
    "1bh2h",
    "0bh3h",
    "2bh0h",
    "1bh1h",
    "0bh2h",
    "1bh0h",
    "0bh1h",
]

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
hep.hist2dplot(
    h / h.sum(),
)
# Write the fraction into the centre of every cell that is not NaN.
values, bins_x, bins_y = (h / h.sum()).to_numpy()
for i in range(len(bins_x) - 1):
    for j in range(len(bins_y) - 1):
        if not math.isnan(values[i, j]):
            ax.text(
                (bins_x[i] + bins_x[i + 1]) / 2,
                (bins_y[j] + bins_y[j + 1]) / 2,
                values[i, j].round(2),
                color="black",
                ha="center",
                va="center",
                fontsize=12,
            )
# Category labels at the bin centres.
xticks = [i + 0.5 for i in range(4)]
yticks = [i + 0.5 for i in range(10)]
ax.set_xticks(xticks, cat_str, size="small", rotation="vertical")
ax.set_yticks(yticks, probs_str, size="small", rotation="horizontal")
ax.set_title("HH4b sample")

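Normalizing in the vertical direction: each GEN-matching category (column) is normalized separately, so the fractions in a column sum to 1.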

In [None]:
# Same 2D histogram as the previous cell, but each GEN-matching category is normalised
# separately before plotting.
h = hist.Hist(cat_axis, probs_axis)
h.fill(cat_matched, max_probs_dict["hh4b-pt300-m50"])

# Tick labels, in category-index order.
cat_str = ["2bh0h", "1bh1h", "0bh2h", "other"]
probs_str = [
    "0bh0h",
    "3bh0h",
    "2bh1h",
    "1bh2h",
    "0bh3h",
    "2bh0h",
    "1bh1h",
    "0bh2h",
    "1bh0h",
    "0bh1h",
]

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
# For every GEN category, divide its 1D "prob" distribution by its own sum.
values = []
for cat in range(4):
    x = h[{"cat": cat}] / h[{"cat": cat}].sum()
    val, bins_x = x.to_numpy()
    values.append(val)

# Rows: GEN category, columns: "prob" category.
all_values = np.stack(values)
print(all_values)
hep.hist2dplot(all_values, cmin=1e-2, cmax=1)
# The globally normalised histogram is only used for its bin edges (and a debug print);
# the text annotations below use all_values.
values, bins_x, bins_y = (h / h.sum()).to_numpy()
print(values)
for i in range(len(bins_x) - 1):
    for j in range(len(bins_y) - 1):
        # Annotate only cells above the same 1e-2 threshold passed as cmin.
        if not math.isnan(all_values[i, j]) and (all_values[i, j] > 1e-2):
            ax.text(
                (bins_x[i] + bins_x[i + 1]) / 2,
                (bins_y[j] + bins_y[j + 1]) / 2,
                all_values[i, j].round(2),
                color="black",
                ha="center",
                va="center",
                fontsize=12,
            )
# Category labels at the bin centres.
xticks = [i + 0.5 for i in range(4)]
yticks = [i + 0.5 for i in range(10)]
ax.set_xticks(xticks, cat_str, size="small", rotation="vertical")
ax.set_yticks(yticks, probs_str, size="small", rotation="horizontal")
ax.set_title("HH4b sample")

# Mass reconstruction comparison

- min dHH method
- min chi2
- pair btags

first_bb_j1 = jets_outside[:, 0]
first_bb_j2 = jets_outside[:, 1]
first_bb_dijet = first_bb_j1 + first_bb_j2

fatjet_0 = fatjets[:, 0]

In [None]:
# Build reference (non-SPANet) Higgs candidates for the hh4b sample from the leading fat jet
# and the two leading "outside" AK4 jets.
events = events_dict["hh4b"]

# jets outside the fatjet - sorted by b-score
# also, abs(eta) < 2.5
jets_outside = utils.make_vector(events, "ak4JetOutside")
jets = utils.make_vector(events, "ak4Jet")
# fatjets sorted by xbb
fatjets = utils.make_vector(events, "ak8FatJet", mstring="PNetMass")

# H1 candidate
h1 = fatjets[:, 0]
# NOTE(review): [0] selects the entry labelled 0 of the ak8FatJetPNetXbb column group
# (presumably the first fat jet's score) - confirm against the events schema.
h1_xbb = events.ak8FatJetPNetXbb[0]

# H2 candidate
# ak4 jet (outside) with highest b-tagging score (btagDeepFlavB)
j3 = jets_outside[:, 0]
j4 = jets_outside[:, 1]
j3_btag = events.ak4JetOutsidebtagDeepFlavB[0]
j4_btag = events.ak4JetOutsidebtagDeepFlavB[1]
# Four-vector sum of the two jets.
h2_btag = j3 + j4

# HH candidate
hh_btag = h2_btag + h1

# Second fatjet
fj2 = fatjets[:, 1]
fj2_xbb = events.ak8FatJetPNetXbb[1]

# H2 candidate
# ak4 jets with min dHH (large dR)
# first_bb_pair = events.ak4JetPair0.to_numpy()
# second_bb_pair = events.ak4JetPair1.to_numpy()
# first_bb_j1 = jets[np.arange(len(jets.pt)), first_bb_pair[:, 0]]
# first_bb_j2 = jets[np.arange(len(jets.pt)), first_bb_pair[:, 1]]
# first_bb_dijet = first_bb_j1 + first_bb_j2
# h2_mindHH = j3 + j4

In [None]:
# Mass of the H2 candidate built from the two leading "outside" AK4 jets.
h2_btag.mass

In [None]:
# Compare three Higgs-candidate reconstructions on hh4b, split by the hh4b-pt300 max_probs
# value: SPANet pairs ("s"), fat jet + two outside AK4 jets ("r"), and two fat jets ("b").
h = hist.Hist(spanet_higgsmass_axis, higgs_axis, probs_axis)
h.fill(pairs_dict["hh4b-pt300"]["higgs_1_mass"], "h1s", max_probs_dict["hh4b-pt300"])
h.fill(pairs_dict["hh4b-pt300"]["higgs_2_mass"], "h2s", max_probs_dict["hh4b-pt300"])
h.fill(h1.mass, "h1r", max_probs_dict["hh4b-pt300"])
h.fill(h2_btag.mass, "h2r", max_probs_dict["hh4b-pt300"])
# The "b" method uses the same leading fat jet for H1 and the second fat jet for H2.
h.fill(h1.mass, "h1b", max_probs_dict["hh4b-pt300"])
h.fill(fj2.mass, "h2b", max_probs_dict["hh4b-pt300"])

linestyle_by_class = {
    "h1s": "solid",  # spanet
    "h2s": "solid",
    "h1r": "dashed",  # reco
    "h2r": "dashed",
    "h1b": "dotted",  # boosted fatjet
    "h2b": "dotted",
}
# Display names of the two "prob" categories that are plotted.
probs_name = {5: "2bh0h", 6: "1bh1h"}
labels = {
    "s": "SPANet",
    "r": "2 b-tag",
    "b": "Boosted FatJet",
}
color_by_class = {
    "s": "blue",
    "r": "orange",
    "b": "green",
}
# Rows: Higgs candidate (h1, h2), columns: "prob" category.
fig, ax = plt.subplots(2, 2, figsize=(12, 5), sharex=True)
for i in range(2):
    for j, prob in enumerate([5, 6]):
        legend_elements = []
        for k, rec in enumerate(["s", "r", "b"]):
            # Histogram category name, e.g. "h1s".
            key = f"h{i+1}{rec}"
            hep.histplot(
                h[{"higgs": key, "prob": prob}],
                density=True,
                lw=2,
                ax=ax[i, j],
                ls=linestyle_by_class[key],
                color=color_by_class[rec],
            )
            # Proxy artist for the legend.
            legend_elements.append(
                Line2D(
                    [0],
                    [0],
                    ls=linestyle_by_class[key],
                    lw=2,
                    label=labels[rec],
                    color=color_by_class[rec],
                )
            )
        ax[i, j].legend(handles=legend_elements, title=f"{probs_name[prob]}")
        ax[i, j].set_ylabel("Density")
        ax[i, j].set_xlabel("")
# Only the bottom row carries the x-axis label.
ax[1, 0].set_xlabel("Higgs mass")
ax[1, 1].set_xlabel("Higgs mass")

In [None]:
# SPANet Higgs candidate masses: hh4b (pT > 300 configuration) against QCD.
h = hist.Hist(spanet_higgsmass_axis, higgs_axis)
# (legend label, pairs_dict key); categories are named "h<n>-<label>".
for label, pairs_key in [("hh4b", "hh4b-pt300"), ("qcd", "qcd")]:
    for n in (1, 2):
        h.fill(pairs_dict[pairs_key][f"higgs_{n}_mass"], f"h{n}-{label}")

linestyle_by_class = {"qcd": "solid", "hh4b": "dashed"}
color_by_class = {"qcd": "orange", "hh4b": "blue"}

# One panel per Higgs candidate.
fig, ax = plt.subplots(1, 2, figsize=(12, 5), sharex=True)
for i in range(2):
    legend_elements = []
    for sample in ["qcd", "hh4b"]:
        line_style = linestyle_by_class[sample]
        line_color = color_by_class[sample]
        hep.histplot(
            h[{"higgs": f"h{i+1}-{sample}"}],
            ax=ax[i],
            ls=line_style,
            color=line_color,
            density=True,
        )
        # Proxy artist for the legend.
        legend_elements.append(
            Line2D([0], [0], ls=line_style, lw=2, label=sample, color=line_color)
        )
    ax[i].legend(handles=legend_elements, title=f"H{i+1}")
    ax[i].set_ylabel("Density")

# Load inference from multiple samples

In [None]:
# Start from empty caches; the next cell refills them from the files saved on disk.
probs_dict = {}
assign_probs_dict = {}
max_probs_dict = {}
pairs_dict = {}

In [None]:
import pickle
from pathlib import Path

# Defined here (instead of relying on the save cells) so this section also runs on a fresh
# kernel when the inference cells above are skipped.
odir = "Nov27"
spanet_dir = Path(f"{MAIN_DIR}/../data/spanet/{odir}")

# File-name prefix -> cache that receives the unpickled object.
pickled_results = {
    "assign_probs": assign_probs_dict,
    "probs": probs_dict,
    "max_assign_probs": max_probs_dict,
}

# NOTE: pickle.load can execute arbitrary code - only load files produced by this notebook.
for key in ["vhtobb", "hh4b", "qcd", "ttbar"]:
    print(key)
    pairs_dict[key] = pd.read_csv(spanet_dir / f"pairs_spanet_hhh_{key}.csv")
    for prefix, results in pickled_results.items():
        with (spanet_dir / f"{prefix}_{key}.pkl").open("rb") as file:
            results[key] = pickle.load(file)

In [None]:
# Sanity check of the loaded ttbar pairs: first Higgs candidate mass column.
pairs_dict["ttbar"]["higgs_1_mass"]

In [None]:
# SPANet Higgs candidate masses for all four samples; categories are named "h<n>-<sample>".
h = hist.Hist(spanet_higgsmass_axis, higgs_axis)
for sample in ["hh4b", "qcd", "ttbar", "vhtobb"]:
    for n in (1, 2):
        h.fill(pairs_dict[sample][f"higgs_{n}_mass"], f"h{n}-{sample}")

linestyle_by_class = {
    "qcd": "solid",
    "hh4b": "dashed",
    "vhtobb": "dashdot",
    "ttbar": "dotted",
}
color_by_class = {"qcd": "orange", "hh4b": "red", "vhtobb": "teal", "ttbar": "blue"}

# One panel per Higgs candidate.
fig, ax = plt.subplots(1, 2, figsize=(12, 5), sharex=True)
for i in range(2):
    legend_elements = []
    for sample in ["ttbar", "qcd", "hh4b", "vhtobb"]:
        line_style = linestyle_by_class[sample]
        line_color = color_by_class[sample]
        hep.histplot(
            h[{"higgs": f"h{i+1}-{sample}"}],
            ax=ax[i],
            ls=line_style,
            color=line_color,
            density=True,
        )
        # Proxy artist for the legend.
        legend_elements.append(
            Line2D([0], [0], ls=line_style, lw=2, label=sample, color=line_color)
        )
    ax[i].legend(handles=legend_elements, title=f"H{i+1}")
    ax[i].set_ylabel("Density")

In [None]:
# Higgs candidate masses per sample, split by the max_probs_dict value: SPANet pairs
# ("h?s-<sample>") and the fat jet + two outside AK4 jets reference ("h?r-<sample>").
h = hist.Hist(spanet_higgsmass_axis, higgs_axis, probs_axis)
for key in ["hh4b", "vhtobb", "qcd", "ttbar"]:
    h.fill(pairs_dict[key]["higgs_1_mass"], f"h1s-{key}", max_probs_dict[key])
    h.fill(pairs_dict[key]["higgs_2_mass"], f"h2s-{key}", max_probs_dict[key])

    # The stored predictions may cover only the first events of a sample (ttbar was run on a
    # single batch). Truncate the events to the number of predictions instead of hard-coding
    # the ttbar batch length; for fully processed samples this is a no-op.
    events = events_dict[key][: len(max_probs_dict[key])]
    jets_outside = utils.make_vector(events, "ak4JetOutside")
    fatjets = utils.make_vector(events, "ak8FatJet", mstring="PNetMass")

    # H1 candidate
    h1 = fatjets[:, 0]
    # H2 candidate
    # ak4 jet (outside) with highest b-tagging score (btagDeepFlavB)
    j3 = jets_outside[:, 0]
    j4 = jets_outside[:, 1]
    h2_btag = j3 + j4

    # fill histogram
    h.fill(h1.mass, f"h1r-{key}", max_probs_dict[key])
    h.fill(h2_btag.mass, f"h2r-{key}", max_probs_dict[key])

linestyle_by_class = {
    "qcd": "solid",
    "hh4b": "dashed",
    "vhtobb": "dashdot",
    "ttbar": "dotted",
}
# Display names of the two "prob" categories that are plotted.
probs_name = {5: "2bh0h", 6: "1bh1h"}
labels = {
    "s": "SPANet",
    "r": "2 b-tag",
}
color_by_class = {
    "s": "blue",
    "r": "orange",
}

alpha_by_sample = {
    "qcd": 0.1,
    "ttbar": 0.3,
    "hh4b": 0.9,
    "vhtobb": 0.5,
}
# Only the reference ("r") reconstruction is drawn; switch to "s" to plot the SPANet masses.
rec = "r"
# Rows: Higgs candidate (h1, h2), columns: "prob" category.
fig, ax = plt.subplots(2, 2, figsize=(12, 5), sharex=True)
for i in range(2):
    for j, prob in enumerate([5, 6]):
        legend_elements = []
        for sample in ["ttbar", "qcd", "hh4b", "vhtobb"]:
            key = f"h{i+1}{rec}-{sample}"
            hep.histplot(
                h[{"higgs": key, "prob": prob}],
                density=True,
                lw=2,
                ax=ax[i, j],
                ls=linestyle_by_class[sample],
                color=color_by_class[rec],
                alpha=alpha_by_sample[sample],
            )
            # Proxy artist for the legend.
            legend_elements.append(
                Line2D(
                    [0],
                    [0],
                    ls=linestyle_by_class[sample],
                    lw=2,
                    label=sample,
                    color=color_by_class[rec],
                    alpha=alpha_by_sample[sample],
                )
            )
        ax[i, j].legend(handles=legend_elements, title=f"{probs_name[prob]}")
        ax[i, j].set_ylabel("Density")
        ax[i, j].set_xlabel("")
# Only the bottom row carries the x-axis label.
ax[1, 0].set_xlabel("Higgs mass")
ax[1, 1].set_xlabel("Higgs mass")

In [None]:
# Sum of the "hh4b" and "hhh" class outputs for the hh4b sample (the same combination that
# is histogrammed below as "SPANet Prob HH + Prob HHH").
probs_dict["hh4b"]["hh4b"] + probs_dict["hh4b"]["hhh"]

In [None]:
# Distribution of every class output of the first model for a single sample.
sample = "vhtobb"
prob_classes = ["hhh", "qcd", "tt", "vv", "vjets", "hhh4b2tau", "hh2b2tau", "hh4b"]

h = hist.Hist(spanet_discr_axis, class_axis)
for prob in prob_classes:
    h.fill(probs_dict[sample][prob], prob)

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
legend_elements = []
linestyles = {
    "hh4b": "solid",
    "hhh": "dashed",
    "hhh4b2tau": "dashdot",
    "hh2b2tau": "dashed",
    "qcd": "dashdot",
    "tt": "dotted",
    "vv": "dotted",
    "vjets": "dashed",
}
color_by_prob = {
    "hh4b": "red",
    "hhh": "green",
    "hhh4b2tau": "grey",
    "hh2b2tau": "black",
    "qcd": "orange",
    "tt": "blue",
    "vv": "teal",
    "vjets": "violet",
}
for prob in prob_classes:
    hep.histplot(
        h[{"class": prob}],
        density=True,
        lw=2,
        ls=linestyles[prob],
        color=color_by_prob[prob],
        ax=ax,
    )
    # Proxy artist for the legend.
    legend_elements.append(
        Line2D([0], [0], color=color_by_prob[prob], lw=2, label=prob, ls=linestyles[prob])
    )
ax.legend(handles=legend_elements)
ax.set_ylabel("Density")
ax.set_yscale("log")
# The title used to be hard-coded to "HH4b sample" although the vhtobb sample is plotted.
ax.set_title(f"{sample} sample")

In [None]:
# HH discriminator per sample: sum of the "hh4b" and "hhh" class outputs.
h = hist.Hist(spanet_discr_axis, sample_axis)
for key in ["hh4b", "vhtobb", "qcd", "ttbar"]:
    sample_probs = probs_dict[key]
    prob_HH = sample_probs["hh4b"] + sample_probs["hhh"]
    h.fill(prob_HH, key)

linestyle_by_class = {
    "qcd": "solid",
    "hh4b": "dashed",
    "vhtobb": "dashdot",
    "ttbar": "dotted",
}
color_by_class = {"qcd": "orange", "hh4b": "red", "vhtobb": "teal", "ttbar": "blue"}

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
legend_elements = []
for key in ["hh4b", "vhtobb", "qcd", "ttbar"]:
    line_style = linestyle_by_class[key]
    line_color = color_by_class[key]
    hep.histplot(h[{"sample": key}], density=True, lw=2, ls=line_style, color=line_color)
    # Proxy artist for the legend.
    legend_elements.append(Line2D([0], [0], color=line_color, lw=2, label=key, ls=line_style))
ax.legend(handles=legend_elements)
ax.set_ylabel("Density")
# ax.set_yscale("log")
ax.set_xlabel("SPANet Prob HH + Prob HHH")

In [None]:
# Colour per assignment category; several categories intentionally share a colour.
color_by_prob = {
    "3bh0h": "grey",
    "2bh1h": "grey",
    "1bh2h": "grey",
    "0bh3h": "grey",
    "2bh0h": "blue",
    "1bh1h": "red",
    "0bh2h": "green",
    "1bh0h": "black",
    "0bh1h": "black",
    "0bh0h": "black",
}

# Assignment-probability distributions of the ttbar sample, one "class" category each.
h = hist.Hist(spanet_discr_axis, class_axis)
key = "ttbar"
for prob in color_by_prob:
    h.fill(assign_probs_dict[key][prob], prob)

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
legend_elements = []
for key in color_by_prob:
    line_color = color_by_prob[key]
    hep.histplot(
        h[{"class": key}],
        # density=True,
        lw=2,
        color=line_color,
    )
    # Proxy artist for the legend.
    legend_elements.append(Line2D([0], [0], color=line_color, lw=2, label=key))
ax.legend(handles=legend_elements)
ax.set_ylabel("Events")
ax.set_yscale("log")
ax.set_title("TThad sample")