# JMSR (Jet Mass Scale and Resolution) Ratio Plot

In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import vector

import HH4b.plotting as plotting
import HH4b.utils as utils
from HH4b.utils import ShapeVar

In [None]:
def make_vector(events: pd.DataFrame, obj: str):
    """Build a ``vector`` array for ``obj`` out of its kinematic dataframe columns.

    Parameters
    ----------
    events
        Dataframe containing the ``{obj}Pt``, ``{obj}Phi``, ``{obj}Eta`` columns
        and a mass column (``{obj}PNetMass`` for ``"ak8FatJet"``, ``{obj}Mass`` otherwise).
    obj
        Column-name prefix of the object, e.g. ``"ak8FatJet"``.

    Returns
    -------
    The object returned by ``vector.array`` for the pt/phi/eta/M components.
    """
    # AK8 fat jets take their mass from the "PNetMass" columns; all other objects use "Mass".
    if obj == "ak8FatJet":
        mass_suffix = "PNetMass"
    else:
        mass_suffix = "Mass"

    # (vector component name, dataframe column suffix), in the order passed to vector.array.
    component_suffixes = (
        ("pt", "Pt"),
        ("phi", "Phi"),
        ("eta", "Eta"),
        ("M", mass_suffix),
    )
    kinematics = {}
    for component, suffix in component_suffixes:
        kinematics[component] = events[f"{obj}{suffix}"]

    return vector.array(kinematics)

## Load Dataset

In [None]:
# Data-taking era and the ttSkimmer output directory to read from.
year = "2022EE"
dir_name = "24Apr18_v12_signal"
path_to_dir = "/eos/uscms/store/user/haoyang/bbbb/ttSkimmer/" + dir_name

In [None]:
# Load your dataset
# Sample group -> list of sample names to load for that group.
samples = {
    "muon": [
        "Muon_Run2022E",
        "Muon_Run2022F",
        "Muon_Run2022G",
    ],
    "tt": ["TTto2L2Nu", "TTto4Q", "TTtoLNu2Q"],
}

# Input directory -> samples to load from it.
dirs = {path_to_dir: samples}

filters = None

# Columns to load, as (column name, number of sub-columns).
# The parquet files are too big, so we only load a few columns at a time to avoid consuming much memory.
# NOTE(review): "finalWeight" requests 0 sub-columns, so nothing is read for it here;
# presumably utils.load_samples computes it from "weight" — confirm.
load_columns = [
    ("weight", 1),
    ("ak8FatJetMsd", 2),
    ("ak8FatJetPNetMass", 2),
    ("ak8FatJetEta", 2),
    ("ak8FatJetPhi", 2),
    ("ak8FatJetPt", 2),
    ("finalWeight", 0),
]
# Reformat into "('column name', 'idx')" strings for reading multiindex columns.
columns = [
    f"('{key}', '{i}')" for key, num_columns in load_columns for i in range(num_columns)
]

events_dict = {}
# Use a distinct loop name so the global `samples` dict above is not shadowed.
for input_dir, dir_samples in dirs.items():
    # Loads files (only the selected columns), applies filters and computes a weight per event.
    events_dict.update(
        utils.load_samples(
            input_dir, dir_samples, year, filters=filters, columns=columns, reorder_legacy_txbb=False
        )
    )

samples_loaded = list(events_dict.keys())
# Column keys of the first loaded sample (every sample is read with the same `columns`).
keys_loaded = list(events_dict[samples_loaded[0]].keys())
print("Keys in events_dict")
for i in keys_loaded:
    print(i)

## Event cuts

In [None]:
# Stack the muon data and ttbar MC frames; the outer row-index level ("muon" / "ttbar")
# records which sample each event came from.
events_raw = pd.concat({"muon": events_dict["muon"], "ttbar": events_dict["tt"]})

In [None]:
# AK4OutsideJet pt cut
# jets_outside_raw = make_vector(events_raw, "ak4JetOutside")
# j3_raw = jets_outside_raw[:, 0]
# j4_raw = jets_outside_raw[:, 1]
# j3j4_pt_cut = (j3_raw.pt > 20) & (j4_raw.pt > 20)

In [None]:
# No event selection is applied yet. To apply the AK4 outside-jet pT cut, re-enable the
# commented-out j3j4_pt_cut cell and select with:
#     events = events_raw[j3j4_pt_cut]
events = events_raw

## Save and Reset Index

In [None]:
# Save the ("muon"/"ttbar", event) row MultiIndex so it can be restored with set_index later.
multiIndex = events.index
# drop=True: the old index is restored from `multiIndex` anyway. Without it, reset_index
# would insert the index levels as extra data columns that survive into the plotted frames.
events = events.reset_index(drop=True)

## Derive W jet mass

In [None]:
fatjets = make_vector(events, "ak8FatJet")
# Per-event column indices ordering the fat jets from highest to lowest pT
# (ascending argsort reversed, so tie-breaking matches a reversed ascending sort).
pt_order = np.argsort(fatjets.pt, axis=1)[:, ::-1]
fj_sorted = np.take_along_axis(fatjets, pt_order, axis=1)
# The W-jet candidate is the leading-pT fat jet; its mass comes from the PNetMass columns.
W_jets = fj_sorted[:, 0]
W_jet_PNetMass = W_jets.m

In [None]:
# Attach the leading-pT fat jet's PNet mass (computed in the previous cell) as a new column.
events["WJetPNetMass"] = W_jet_PNetMass

## Set Index Back

In [None]:
# Restore the ("muon"/"ttbar", ...) row MultiIndex saved before reset_index,
# so the samples can be selected again with .loc by their outer index label.
events = events.set_index(multiIndex)

In [None]:
# Split the combined frame back into per-sample frames keyed the way the plotting utils accept.
# .copy() makes each frame independent: later in-place edits (e.g. overwriting "finalWeight"
# for data) would otherwise act on a .loc slice of `events` and trigger pandas'
# chained-assignment / SettingWithCopy behavior.
events_dict = {}
events_dict["data"] = events.loc["muon"].copy()
events_dict["ttbar"] = events.loc["ttbar"].copy()

## Plot Mass

In [None]:
# W-jet PNet mass, with bins from np.arange(20, 250, 5), i.e. 20, 25, ..., 245 GeV.
w_jet_mass_var = ShapeVar(
    var="WJetPNetMass",
    label=r"W Jet PNet Mass (GeV)",
    bins=list(np.arange(20, 250, 5)),
    reg=False,
)
# Variables drawn by the plotting cell below.
control_plot_vars = [w_jet_mass_var]

In [None]:
# Per-era y-axis upper limits for the mass plot.
# NOTE(review): the plotting cell passes ylim=4e5 directly and never reads this dict —
# confirm whether it was meant to use ylims[year].
ylims = dict(
    [
        ("2022", 5e4),
        ("2022EE", 4e3),
        ("2023-pre-BPix", 4e5),
    ]
)

In [None]:
# Data is not reweighted: set "finalWeight" to 1.0 for every data event.
events_dict["data"]["finalWeight"] = 1.0

In [None]:
# Display the data weights to check the override above took effect.
events_dict["data"]["finalWeight"]

In [None]:
# IPython shell escape: install the local HH4b package in editable mode from ../../../
# (presumably the repository root — confirm relative to this notebook's location).
!pip install -e ../../../.

In [None]:
# HH4b.plotting was already imported in the first cell, so a plain `import` here only returns
# the module cached in sys.modules. Reload it so that source changes to HH4b.plotting take
# effect without restarting the kernel.
import importlib

import HH4b.plotting as plotting

plotting = importlib.reload(plotting)

In [None]:
# Draw a data vs. ttbar-MC ratio plot for every control variable, per year.
for year in ["2022EE"]:
    hists = {}
    for shape_var in control_plot_vars:
        print(shape_var)

        # Fill the histogram for this variable only once per year, weighted by "finalWeight".
        if shape_var.var not in hists:
            hists[shape_var.var] = utils.singleVarHist(events_dict, shape_var, weight_key="finalWeight")

        # ttbar MC is the only background; no signal samples are drawn.
        bkgs = ["ttbar"]
        sigs = []

        # Log-scale plot with data and a ratio panel, no significance panel.
        plot_options = {
            "name": "test",
            "show": True,
            "log": True,
            "bg_err": None,
            "plot_data": True,
            "plot_significance": False,
            "significance_dir": shape_var.significance_dir,
            "ratio_ylims": [0.2, 1.8],
            "ylim": 4e5,
            "ylim_low": 10,
        }
        plotting.ratioHistPlot(hists[shape_var.var], year, sigs, bkgs, **plot_options)

In [None]:
# Scratch check (appears to be left over from interactive use): float64 array [1., 2., 3., 4.].
np.arange(1.0, 5.0)