In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
from os.path import join
import warnings
warnings.filterwarnings('ignore')

import hist
import mplhep as hep
import seaborn as sns
from coffea import util
from matplotlib import pyplot as plt
from cycler import cycler
from hist import Hist
from hist.intervals import ratio_uncertainty

sys.path.append("../")
from azh_analysis.utils.plotting import (
    get_category_labels, get_color_list, get_category_labels
)
from azh_analysis.utils.parameters import (
    get_lumis, get_categories, 
)

In [None]:
def separate_production_modes(signal_var):
    """Split signal samples into ggA and bbA production modes by sample name.

    Parameters
    ----------
    signal_var : mapping or iterable of (name, value) pairs
        e.g. ``signal[var]``, a dict keyed by sample name. Mappings are read
        through ``.items()``; iterating a dict directly would only yield keys.

    Returns
    -------
    tuple[list, list]
        ``(gga, bba)``: values whose sample name contains "GluGluToA" and
        "BBA" respectively, in iteration order.
    """
    # Materialize once so one-shot iterables are not exhausted by the first pass.
    pairs = list(signal_var.items() if hasattr(signal_var, "items") else signal_var)
    gga = [v for k, v in pairs if "GluGluToA" in k]
    bba = [v for k, v in pairs if "BBA" in k]
    return gga, bba

# Set True to write the figures produced below to `outdir` as PDFs.
save_plots = False

year = "2018"
lumis = get_lumis(as_picobarns=False)
source = "signal"

# Coffea output for this source and year (path built from `year` instead of a
# hard-coded "2018" so the two cannot drift apart).
signal = util.load(f"../output/{source}_UL_{year}_None_OS.coffea")
cat_labels = get_category_labels()
colors = get_color_list()

outdir = f"../plots/signal/{year}"
if save_plots:
    # savefig does not create missing directories.
    os.makedirs(outdir, exist_ok=True)

In [None]:
var = "m4l"
mass_label = r"$m_{ll\tau\tau}$ [GeV]"
year = "2018"
lumi = 59.7  # NOTE(review): hard-coded; lumis[year] may be the intended source — confirm
logscale = False
mass = "400"

labels = {"raw": "Raw", "corr": "Corrected", "cons": "Constrained"}
mass_types = ["raw", "corr", "cons"]


def _select_signals(hists, production_tag, cat, mass_type, btag, mass):
    """Map a legend label to the ee+mm summed histogram of each matching sample.

    Keeps samples whose name contains both `production_tag` and f"M{mass}".
    Axis order: group, cat, sign, mass_type, btags, syst, values; the group axis
    is summed, sign index 0 and syst "nom" are selected.
    """
    selected = {}
    for name, h in hists.items():
        if production_tag not in name or f"M{mass}" not in name:
            continue
        label = rf"$m_A={name.split('TauTauM')[-1]}$ GeV"
        selected[label] = (
            h[::sum, "ee" + cat, 0, mass_type, btag, "nom", :]
            + h[::sum, "mm" + cat, 0, mass_type, btag, "nom", :]
        )
    return selected


for cat, cat_label in cat_labels.items():
    # One figure per category: reusing a single figure across the loop made
    # every category after the first draw into an already shown/closed figure.
    fig, axs = plt.subplots(
        nrows=2,
        ncols=3,
        figsize=(15, 10),
        dpi=200,
        sharex=True,
        sharey=True,
    )
    for i, mass_type in enumerate(mass_types):
        # Row 0: ggA samples in b-tag bin 0; row 1: bbA samples in b-tag bin 1.
        rows = [
            _select_signals(signal[var], "GluGluToA", cat, mass_type, btag=0, mass=mass),
            _select_signals(signal[var], "BBA", cat, mass_type, btag=1, mass=mass),
        ]
        for j, data in enumerate(rows):
            ax = axs[j, i]
            hist.Stack.from_dict(data).plot(ax=ax, stack=True, histtype="step")
            if logscale:
                ax.set_yscale("log")
                ax.set_xscale("log")
            ax.set_ylabel("")
            hep.cms.label("Preliminary", data=False, lumi=lumi, year=year, ax=ax)
            ax.set_xlabel(f"{labels[mass_type]} " + mass_label)

    # Without this every category's figure would be indistinguishable.
    fig.suptitle(cat_label)
    fig.tight_layout()
    outfile = join(outdir, f"{source}_{cat}_{var}_M{mass}.pdf") if save_plots else None
    if outfile is not None:
        # Save this figure explicitly rather than pyplot's implicit current figure.
        fig.savefig(outfile, format="pdf", dpi=800)
    plt.show()
    plt.close(fig)

In [None]:
var = "m4l"
sign = 0  # OS
# NOTE(review): `nom`, `shift`, `syst` and `plot_m4l_systematic` are not defined
# in the visible notebook; they must exist in the kernel before this cell runs.
# The sums over samples do not depend on the category, so compute them once.
nom_sum = sum(nom[var].values())
shift_s = sum(shift[var].values())
for cat, cat_label in cat_labels.items():
    # group, cat, sign, mass_type, btags, syst, values
    nom_s = (
        nom_sum[::sum, 'ee' + cat, sign, :, ::sum, 'nom', :]
        + nom_sum[::sum, 'mm' + cat, sign, :, ::sum, 'nom', :]
    )
    shift_up = (
        shift_s[::sum, 'ee' + cat, sign, :, ::sum, syst + "_up", :]
        + shift_s[::sum, 'mm' + cat, sign, :, ::sum, syst + "_up", :]
    )
    shift_down = (
        shift_s[::sum, 'ee' + cat, sign, :, ::sum, syst + "_down", :]
        + shift_s[::sum, 'mm' + cat, sign, :, ::sum, syst + "_down", :]
    )
    plot_m4l_systematic(
        nom_s, shift_up, shift_down,
        syst, cat_label, mass_label=r"$m_{ll\tau\tau}$ [GeV]",
        logscale=True,
        year=int(year), lumi=lumis[year],
        # `save_plots` is the flag defined in the config cell (`plot` was undefined).
        outfile=join(outdir, f"{source}_{cat}_{syst}_{var}.pdf") if save_plots else None,
    )

In [None]:
var = "mtt"
sign = 0  # OS
# NOTE(review): `nom`, `shift`, `syst` and `plot_m4l_systematic` are not defined
# in the visible notebook; they must exist in the kernel before this cell runs.
# The sums over samples do not depend on the category, so compute them once.
nom_sum = sum(nom[var].values())
shift_s = sum(shift[var].values())
for cat, cat_label in cat_labels.items():
    # Indexing assumes axes (group, cat, sign, mass_type, btags, syst, values) — TODO confirm for mtt
    nom_s = (
        nom_sum[::sum, 'ee' + cat, sign, :, ::sum, 'nom', :]
        + nom_sum[::sum, 'mm' + cat, sign, :, ::sum, 'nom', :]
    )
    shift_up = (
        shift_s[::sum, 'ee' + cat, sign, :, ::sum, syst + "_up", :]
        + shift_s[::sum, 'mm' + cat, sign, :, ::sum, syst + "_up", :]
    )
    shift_down = (
        shift_s[::sum, 'ee' + cat, sign, :, ::sum, syst + "_down", :]
        + shift_s[::sum, 'mm' + cat, sign, :, ::sum, syst + "_down", :]
    )
    plot_m4l_systematic(
        nom_s, shift_up, shift_down,
        syst, cat_label, mass_label=r"$m_{\tau\tau}$ [GeV]",
        year=int(year), lumi=lumis[year],
        # `save_plots` is the flag defined in the config cell (`plot` was undefined).
        outfile=join(outdir, f"{source}_{cat}_{syst}_{var}.pdf") if save_plots else None,
    )

In [None]:
var = "mll"
sign = 0  # OS
# NOTE(review): `nom`, `shift`, `syst` and `plot_systematic` are not defined
# in the visible notebook; they must exist in the kernel before this cell runs.
# The sums over samples do not depend on the category, so compute them once.
nom_sum = sum(nom[var].values())
shift_s = sum(shift[var].values())
for cat, cat_label in cat_labels.items():
    # group, cat, sign, btag, syst, vals
    nom_s = (
        nom_sum[::sum, 'ee' + cat, sign, ::sum, 'nom', :]
        + nom_sum[::sum, 'mm' + cat, sign, ::sum, 'nom', :]
    )
    shift_up = (
        shift_s[::sum, 'ee' + cat, sign, ::sum, syst + "_up", :]
        + shift_s[::sum, 'mm' + cat, sign, ::sum, syst + "_up", :]
    )
    shift_down = (
        shift_s[::sum, 'ee' + cat, sign, ::sum, syst + "_down", :]
        + shift_s[::sum, 'mm' + cat, sign, ::sum, syst + "_down", :]
    )
    plot_systematic(
        nom_s, shift_up, shift_down,
        syst, cat_label, r"$m_{ll}$ [GeV]",
        year=int(year), lumi=lumis[year],
        # `save_plots` is the flag defined in the config cell (`plot` was undefined).
        outfile=join(outdir, f"{source}_{cat}_{syst}_{var}.pdf") if save_plots else None,
    )

In [None]:
var = "pt"
sign = 0  # OS
# NOTE(review): `nom`, `shift`, `syst` and `plot_systematic` are not defined
# in the visible notebook; they must exist in the kernel before this cell runs.
# The sums over samples do not depend on category or leg, so compute them once.
nom_sum = sum(nom[var].values())
shift_s = sum(shift[var].values())
for cat, cat_label in cat_labels.items():
    for leg in ['3', '4']:
        # Indexing assumes axes (group, cat, sign, leg, btag, syst, values) — TODO confirm
        nom_s = (
            nom_sum[::sum, 'ee' + cat, sign, leg, ::sum, 'nom', :]
            + nom_sum[::sum, 'mm' + cat, sign, leg, ::sum, 'nom', :]
        )
        shift_up = (
            shift_s[::sum, 'ee' + cat, sign, leg, ::sum, syst + "_up", :]
            + shift_s[::sum, 'mm' + cat, sign, leg, ::sum, syst + "_up", :]
        )
        shift_down = (
            shift_s[::sum, 'ee' + cat, sign, leg, ::sum, syst + "_down", :]
            + shift_s[::sum, 'mm' + cat, sign, leg, ::sum, syst + "_down", :]
        )
        plot_systematic(
            nom_s, shift_up, shift_down,
            syst, cat_label, f"Leg {leg} " + r"$p_T$ [GeV]",
            year=int(year), lumi=lumis[year],
            # `save_plots` is the flag defined in the config cell (`plot` was undefined).
            outfile=join(outdir, f"{source}_{cat}_{syst}_{var}{leg}.pdf") if save_plots else None,
        )

In [None]:
var = "met"
sign = 0  # OS
# NOTE(review): `nom`, `shift`, `syst` and `plot_systematic` are not defined
# in the visible notebook; they must exist in the kernel before this cell runs.
# The sums over samples do not depend on the category, so compute them once.
nom_sum = sum(nom[var].values())
shift_s = sum(shift[var].values())
for cat, cat_label in cat_labels.items():
    # Indexing assumes axes (group, cat, sign, btag, syst, values) — TODO confirm for met
    nom_s = (
        nom_sum[::sum, 'ee' + cat, sign, ::sum, 'nom', :]
        + nom_sum[::sum, 'mm' + cat, sign, ::sum, 'nom', :]
    )
    shift_up = (
        shift_s[::sum, 'ee' + cat, sign, ::sum, syst + "_up", :]
        + shift_s[::sum, 'mm' + cat, sign, ::sum, syst + "_up", :]
    )
    shift_down = (
        shift_s[::sum, 'ee' + cat, sign, ::sum, syst + "_down", :]
        + shift_s[::sum, 'mm' + cat, sign, ::sum, syst + "_down", :]
    )
    plot_systematic(
        nom_s, shift_up, shift_down,
        syst, cat_label, r"$E_T^\mathrm{miss}$ [GeV]",
        year=int(year), lumi=lumis[year],
        # Previously saved unconditionally; now honours the config-cell save flag.
        outfile=join(outdir, f"{source}_{cat}_{syst}_{var}.pdf") if save_plots else None,
    )