In [None]:
%load_ext autoreload
%autoreload 2

import argparse
import logging
import time
from os.path import join
import sys
sys.path.append('../')
import warnings
warnings.filterwarnings("ignore")

from coffea import processor, util
from coffea.lumi_tools import LumiMask
from coffea.nanoevents import NanoAODSchema

from azh_analysis.processors.fake_rate_processor import FakeRateProcessor
from azh_analysis.utils.corrections import (
    dyjets_stitch_weights,
    get_electron_ID_weights,
    get_electron_trigger_SFs,
    get_muon_ID_weights,
    get_muon_trigger_SFs,
    get_muon_ES_weights,
    get_tau_ID_weights,
    get_pileup_weights,
)
from azh_analysis.utils.sample import get_fileset, get_nevts_dict, get_sample_info

In [None]:
# ---------------------------------------------------------------------------
# Logging: timestamped INFO-level messages on the root logger.
# ---------------------------------------------------------------------------
log_format = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=log_format, level=logging.INFO)
logging.info("Initializing")

# ---------------------------------------------------------------------------
# Run configuration: which era and which sample list to process.
# ---------------------------------------------------------------------------
year, source = "2018", "MC_UL"

# ---------------------------------------------------------------------------
# Golden JSONs -> one LumiMask per era. Both 2016 eras point to the same
# data_cert_2016.json file (the era key is truncated to its 4-digit year).
# ---------------------------------------------------------------------------
golden_json_dir = "../samples/data_certification"
golden_jsons = {
    era: join(golden_json_dir, f"data_cert_{era[:4]}.json")
    for era in ("2018", "2017", "2016postVFP", "2016preVFP")
}
lumi_masks = {era: LumiMask(json_path) for era, json_path in golden_jsons.items()}

# ---------------------------------------------------------------------------
# Electron / muon / tau ID scale factors for the selected era.
# ---------------------------------------------------------------------------
eID_base = f"../corrections/electron_ID/UL_{year}"
eID_name = f"Electron_RunUL{year}_IdIso_AZh_IsoLt0p15_IdFall17MVA90noIsov2.root"
eID_file = join(eID_base, eID_name)
eIDs = get_electron_ID_weights(eID_file)
logging.info(f"Using eID_SFs:\n{eID_file}")

mID_base = f"../corrections/muon_ID/UL_{year}"
mID_name = f"Muon_RunUL{year}_IdIso_AZh_IsoLt0p15_IdLoose.root"
mID_file = join(mID_base, mID_name)
mIDs = get_muon_ID_weights(mID_file)
logging.info(f"Using mID_SFs:\n{mID_file}")

tID_base = f"../corrections/tau_ID/UL_{year}"
tID_file = join(tID_base, "tau.corr.json")
tIDs = get_tau_ID_weights(tID_file)
logging.info(f"Using tID_SFs:\n{tID_file}")

# ---------------------------------------------------------------------------
# Electron / muon trigger scale factors. The trigger name that appears in the
# correction file name depends on the era.
# ---------------------------------------------------------------------------
e_trigs = {
    **dict.fromkeys(("2016preVFP", "2016postVFP"), "Ele25_EtaLt2p1"),
    **dict.fromkeys(("2017", "2018"), "Ele35"),
}
e_trig_base = f"../corrections/electron_trigger/UL_{year}"
e_trig_name = f"Electron_RunUL{year}_{e_trigs[year]}.root"
e_trig_file = join(e_trig_base, e_trig_name)
e_trig_SFs = get_electron_trigger_SFs(e_trig_file)

m_trigs = {
    **dict.fromkeys(("2016preVFP", "2016postVFP"), "IsoMu24orIsoTkMu24"),
    **dict.fromkeys(("2017", "2018"), "IsoMu27"),
}
m_trig_base = f"../corrections/muon_trigger/UL_{year}"
m_trig_name = f"Muon_RunUL{year}_{m_trigs[year]}.root"
m_trig_file = join(m_trig_base, m_trig_name)
m_trig_SFs = get_muon_trigger_SFs(m_trig_file)

# Muon energy-scale weights are looked up by directory + era.
mES_SFs = get_muon_ES_weights("../corrections/muon_ES/", year)

# Load the sample metadata (csv) and the per-sample file lists (yaml) that
# belong to the chosen source / era combination.
fset_string = f"{source}_{year}"
sample_info = get_sample_info(join("../samples", fset_string + ".csv"))
fileset = get_fileset(join("../samples/filesets", fset_string + ".yaml"))

# Pileup weights are only loaded when the source is not data.
pileup_weights = None
if "data" not in source:
    pileup_weights = get_pileup_weights("../corrections/pileup", year)

# Quick-test restriction: keep at most this many files per sample.
# Set to None to run over every file listed in the yaml.
MAX_FILES_PER_SAMPLE = 1

# Truncate each sample's file list first, then keep only entries whose
# extension is "root" (same order of operations as before: slice, then filter).
fileset = {
    sample: [
        path
        for path in files[:MAX_FILES_PER_SAMPLE]
        if path.split(".")[-1] == "root"
    ]
    for sample, files in fileset.items()
}
logging.info(f"running on\n {fileset.keys()}")

# Extract the sum_of_weights from the ntuples (MC only). The success messages
# are emitted only when the corresponding object was actually built, so a data
# run no longer logs "Successfully built ... None".
nevts_dict, dyjets_weights = None, None
if "MC" in source:
    nevts_dict = get_nevts_dict(fileset, year)
    logging.info(f"Successfully built sum_of_weights dict:\n {nevts_dict}")
    # Stitch weights are computed only when the inclusive DY sample is present.
    if f"DYJetsToLLM-50_{year}" in fileset:
        dyjets_weights = dyjets_stitch_weights(sample_info, nevts_dict, year)
        logging.info(f"Successfully built dyjets stitch weights:\n {dyjets_weights}")

# Start the wall-clock timer for the processing step.
tic = time.time()

# Instantiate the processor with every correction / weight loaded above.
proc_instance = FakeRateProcessor(
    year=year,
    sample_info=sample_info,
    fileset=fileset,
    pileup_weights=pileup_weights,
    lumi_masks=lumi_masks,
    nevts_dict=nevts_dict,
    eleID_SFs=eIDs,
    muID_SFs=mIDs,
    muES_SFs=mES_SFs,
    tauID_SFs=tIDs,
    e_trig_SFs=e_trig_SFs,
    m_trig_SFs=m_trig_SFs,
    dyjets_weights=dyjets_weights,
)

# Run locally through a single-worker futures executor.
futures_run = processor.Runner(
    executor=processor.FuturesExecutor(compression=None, workers=1),
    schema=NanoAODSchema,
)

out = futures_run(
    fileset,
    "Events",
    processor_instance=proc_instance,
)

# Report the elapsed time; the timer above was previously started but never read.
logging.info(f"Finished processing in {time.time() - tic:.1f} s")

In [None]:
# Gather the "mT" histogram of every dataset in the processor output.
mll_hists = list(out["mT"].values())

# Add them up and display the "eee" / "Fake" / "Denominator" projection,
# summing over every remaining axis except the last one.
mT_total = sum(mll_hists)
mT_total[::sum, "eee", "Fake", "Denominator", ::sum, ::sum, ::sum, :]

In [None]:
import sys
sys.path.append('../')
import uproot
from azh_analysis.utils.corrections import get_fake_rates, make_evaluator
from coffea.lookup_tools import extractor

# Inspect one object of a coffea-derived fake-rate ROOT file in the current
# working directory.
# NOTE(review): a file with this name is written by the fake-rate export loop
# further down in this notebook, so on a fresh checkout this cell presumably
# has to be run after that loop -- confirm the intended cell order.
f = "JetEleFakeRate_Fall17MVAv2WP90_noIso_Iso0p15_UL2018_coffea.root"
# Open through a context manager so the file handle is closed again.
with uproot.open(f) as fake_rate_file:
    print(fake_rate_file["POL2FitFR_Central_barrel"])

# NOTE(review): this path is assigned but not passed to get_fake_rates below,
# which is called with an empty base string -- confirm which one is intended.
f = "../corrections/fake_rates/UL_2018/JetEleFakeRate_Fall17MVAv2WP90_noIso_Iso0p15_UL2018.root"
get_fake_rates("", "2018", origin="_coffea")

In [None]:
from matplotlib import pyplot as plt

# `mll_hist` was used below without ever being defined (only the list
# `mll_hists` exists), so this cell failed with a NameError on a fresh kernel.
# Build it from the processor output: add up all datasets and sum over the
# leading axis, so that the 7-index selections below line up with the 8-index
# selection used on the same histogram earlier in the notebook.
# NOTE(review): assumes out["mT"] has exactly one extra leading axis -- confirm.
mll_hist = sum(out["mT"].values())[::sum, ...]


def _project(category, prompt, selection):
    """Return the 1D projection of `mll_hist` for one (category, prompt, selection).

    The three axes that follow the selection axis are summed over and the last
    axis is kept in full.
    """
    return mll_hist[category, prompt, selection, ::sum, ::sum, ::sum, :]


# Quick look at a single projection.
mll = _project("eett", "Prompt", "Numerator")
mll.plot1d()
plt.show()

# Prompt / fake x denominator / numerator, each summed over the eee and mme
# categories. (Names kept from the original cell; note that `pd` would shadow
# the conventional pandas alias if pandas were imported in this notebook.)
pd_eee = _project("eee", "Prompt", "Denominator")
pd_mme = _project("mme", "Prompt", "Denominator")
pd = pd_eee + pd_mme
fd_eee = _project("eee", "Fake", "Denominator")
fd_mme = _project("mme", "Fake", "Denominator")
fd = fd_eee + fd_mme
pn_eee = _project("eee", "Prompt", "Numerator")
pn_mme = _project("mme", "Prompt", "Numerator")
pn = pn_eee + pn_mme
fn_eee = _project("eee", "Fake", "Numerator")
fn_mme = _project("mme", "Fake", "Numerator")
fn = fn_eee + fn_mme

print(fn / fd)

# Create an empty figure only: plot_ratio adds its own main/ratio axes, so a
# pre-made axes from plt.subplots would be left behind as a stray empty panel.
fig = plt.figure(dpi=200, figsize=(5, 5))
main_ax_artists, subplot_ax_artists = fn.plot_ratio(
    fd,
    rp_ylabel=r"Ratio",
    rp_num_label="hist1",
    rp_denom_label="hist2",
    rp_uncert_draw_type="line",  # line or bar
)
plt.tight_layout()
plt.show()

In [None]:
import os
from os.path import join
from coffea import util
import numpy as np
import mplhep as hep
from matplotlib import pyplot as plt
from cycler import cycler
from hist.intervals import ratio_uncertainty
from azh_analysis.utils.plotting import plot_fake_rates_data, plot_fake_rate_measurements

year = "2018"
date = "02-16"

# These paths must be f-strings: without the `f` prefix the literal text
# "{year}_{date}" was used as the file name instead of the values above.
data = util.load(f"../fake_rates_data_UL_{year}_{date}.coffea")
MC = util.load(f"../fake_rates_MC_UL_{year}_{date}.coffea")

# Sum the per-dataset histograms stored under the "pt" key.
# NOTE(review): the variables are called mT_* but are filled from the "pt"
# histograms -- confirm which observable is intended.
mT_data = sum(data["pt"].values())
mT_MC = sum(MC["pt"].values())

# axis order: category, prompt, numerator, decay_mode, pt_bin, eta_bin, vals
mT_MC_d = mT_MC["eett", "Fake", "Denominator", ::sum, ::sum, ::sum, :]
mT_MC_n = mT_MC["eett", "Fake", "Numerator", ::sum, ::sum, ::sum, :]

# Bin labels as LaTeX strings.
pt_bins = [
    '$10<p_T<20$ GeV',
    '$20<p_T<30$ GeV',
    '$30<p_T<40$ GeV',
    '$40<p_T<60$ GeV',
    '$60<p_T<1000000$ GeV',
]
eta_bins = ['$|\\eta|<1.479$', '$|\\eta|>1.479$']

# Maps the key used to index the decay-mode axis -> decay-mode number used in
# labels, in the `ranges` lookup and in the output object names.
decay_modes = {0: -1, 1: 0, 2: 1, 3: 2, 4: 10, 5: 11, 6: 15}

# y-axis limits for the plots, keyed by (mode, eta bin label, decay mode).
# Light-lepton modes are keyed by eta bin, tau modes by decay mode.
ranges = {
    ("e", '$|\\eta|<1.479$', None): [0, 0.014],
    ("e", '$|\\eta|>1.479$', None): [0, 0.035],
    ("m", '$|\\eta|<1.479$', None): [0, 0.06],
    ("m", '$|\\eta|>1.479$', None): [0, 0.1],
    ("et", None, 0): [0, 0.18],
    ("et", None, 1): [0, 0.2],
    ("et", None, 10): [0, 0.18],
    ("et", None, 11): [0, 0.1],
    ("mt", None, 0): [0, 0.35],
    ("mt", None, 1): [0, 0.25],
    ("mt", None, 10): [0, 0.18],
    ("mt", None, 11): [0, 0.1],
    ("tt", None, 0): [0, 0.35],
    ("tt", None, 1): [0, 0.25],
    ("tt", None, 10): [0, 0.18],
    ("tt", None, 11): [0, 0.1],
}

# Output ROOT file written for each mode.
filenames = {
    "e": f"JetEleFakeRate_Fall17MVAv2WP90_noIso_Iso0p15_UL{year}_coffea.root",
    "m": f"JetMuFakeRate_Medium_Iso0p15_UL{year}_coffea.root",
    "et": f"JetTauFakeRate_Medium_Tight_VLoose_UL{year}_coffea.root",
    "mt": f"JetTauFakeRate_Medium_VLoose_Tight_UL{year}_coffea.root",
    "tt": f"JetTauFakeRate_Medium_VLoose_VLoose_UL{year}_coffea.root",
}

# uproot is needed for uproot.recreate below; it was previously only imported
# in an earlier, unrelated cell, so this cell could not run on its own.
import uproot

# Index object equivalent to writing `::sum` inside the brackets.
_SUM = slice(None, None, sum)

# Working-point text appended to the label of each tau mode.
_TAU_WP_LABELS = {
    "et": "(Tight, VLoose, Medium)",
    "mt": "(VLoose, Tight, Medium)",
    "tt": "(VLoose, VLoose, Medium)",
}


def _prompt_subtracted(mode, selection, decay_mode=_SUM, eta_bin=_SUM):
    """Return data minus prompt MC for one mode, summed over the ee and mm categories.

    mode       -- "e", "m", "et", "mt" or "tt"; appended to "ee"/"mm" to form
                  the category names.
    selection  -- "Denominator" or "Numerator".
    decay_mode -- index on the decay-mode axis (default: sum over the axis).
    eta_bin    -- index on the eta axis (default: sum over the axis).

    The prompt axis is summed for data and fixed to "Prompt" for MC; the pt-bin
    axis is always summed and the last axis is kept in full.
    """
    tail = (selection, decay_mode, _SUM, eta_bin, slice(None))
    return (
        mT_data[("ee" + mode, _SUM) + tail]
        + mT_data[("mm" + mode, _SUM) + tail]
        - mT_MC[("ee" + mode, "Prompt") + tail]
        - mT_MC[("mm" + mode, "Prompt") + tail]
    )


def _measure_fake_rate(denominator, numerator, label, xlim, ylim, **plot_kwargs):
    """Draw both fake-rate plots and return the measured ratios as an np.histogram tuple."""
    plot_fake_rates_data(
        denominator, numerator, label, xlim=xlim, ylim=ylim, **plot_kwargs
    )
    bins, ratios, yerr, xerr = plot_fake_rate_measurements(
        denominator, numerator, label, xlim=xlim, ylim=ylim, **plot_kwargs
    )
    # NOTE(review): the histogram is filled with the points `bins - 1` weighted
    # by `ratios`; this presumably places each ratio in the bin just below the
    # corresponding entry of `bins` -- confirm against what
    # plot_fake_rate_measurements returns.
    h = np.histogram(bins - 1, bins=bins, weights=ratios)
    print(h)
    return h


for mode in ["e", "m", "et", "mt", "tt"]:
    with uproot.recreate(filenames[mode]) as root_file:
        if "t" not in mode:
            # Light leptons: one fake-rate histogram per eta bin.
            for eta_bin in eta_bins:
                label = f"Jet faking {mode[-1]}\n" + f"{eta_bin}"

                mT_data_d = _prompt_subtracted(mode, "Denominator", eta_bin=eta_bin)
                mT_data_n = _prompt_subtracted(mode, "Numerator", eta_bin=eta_bin)

                h = _measure_fake_rate(
                    mT_data_d,
                    mT_data_n,
                    label,
                    xlim=[10, 80],
                    ylim=ranges[(mode, eta_bin, None)],
                )

                # The "<" label is stored as barrel, the ">" label as endcap.
                if "<" in eta_bin:
                    root_file["POL2FitFR_Central_barrel"] = h
                elif ">" in eta_bin:
                    root_file["POL2FitFR_Central_endcap"] = h
        else:
            # Taus: one fake-rate histogram per decay mode.
            for decay_mode in decay_modes.keys():
                label = (
                    f"Jet faking {mode[-1]}\n"
                    + f"Decay Mode {decay_modes[decay_mode]}\n"
                    + _TAU_WP_LABELS[mode]
                )

                mT_data_d = _prompt_subtracted(
                    mode, "Denominator", decay_mode=decay_mode
                )
                mT_data_n = _prompt_subtracted(
                    mode, "Numerator", decay_mode=decay_mode
                )

                # Skip decay modes whose denominator sums to zero.
                if np.sum(mT_data_d.values()) == 0:
                    continue

                # NOTE(review): `ranges` only has entries for decay modes
                # 0, 1, 10 and 11; any other decay mode that survives the
                # check above raises KeyError here.
                h = _measure_fake_rate(
                    mT_data_d,
                    mT_data_n,
                    label,
                    xlim=[20, 100],
                    ylim=ranges[(mode, None, decay_modes[decay_mode])],
                    combine_bins=True,
                )
                root_file[f"POL2FitFR_Central_DM{decay_modes[decay_mode]}"] = h
        print(root_file)