In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
import argparse
import warnings
import logging
from os.path import join
from collections import defaultdict
sys.path.append('../')
warnings.filterwarnings("ignore")

from scipy.stats import ks_2samp
import numpy as np
from coffea import processor, util
import hist
import mplhep as hep
import seaborn as sns
from matplotlib import pyplot as plt
from cycler import cycler
from hist import Hist
from hist.intervals import ratio_uncertainty
from azh_analysis.utils.plotting import plot_data_vs_mc
from azh_analysis.utils.parameters import get_lumis
import warnings
warnings.filterwarnings('ignore')
#hep.style.use(["CMS", "fira", "firamath"])

sys.path.append("../")
from azh_analysis.utils.histograms import integrate

# Analysis configuration for this notebook run.
plot = True
year = "2018"
lumi = get_lumis()[year]

# Coffea outputs: opposite-sign ("OS_ub") and relaxed same-sign ("SS_ub_relaxed") selections.
OS = util.load(f"../output/data_UL_{year}_OS_ub.coffea")
SS = util.load(f"../output/data_UL_{year}_SS_ub_relaxed.coffea")
print(OS.keys())

# Flat arrays read out of the accumulators; the closure cell below indexes
# them entry-by-entry in parallel (presumably one entry per selected event).
ss_m4l = np.array(SS["reducible_m4l_reg"].value)
os_m4l = np.array(OS["reducible_m4l_reg"].value)
ss_cat = np.array(SS["reducible_cat"].value)
os_cat = np.array(OS["reducible_cat"].value)
ss_btag = np.array(SS["reducible_btag"].value)
os_btag = np.array(OS["reducible_btag"].value)

# LaTeX legend labels for each di-tau final state (the Z->ll leg is implied).
cat_labels = dict(
    tt=r'$ll\tau_h\tau_h$',
    et=r'$ll e\tau_h$',
    mt=r'$ll\mu\tau_h$',
    em=r'$ll e\mu$',
)

In [None]:
def plot_closure(
    os,
    ss,
    var,
    cat_label,
    var_label,
    btag_label,
    stats,
    logscale=False,
    outfile=None,
    year="1",
    lumi="1",
    blind=False,
):
    """Draw a closure plot comparing the OS application and SS validation histograms.

    The SS histogram is rescaled to the OS integral (via ``integrate``), so only
    the shapes are compared. The upper panel overlays both histograms. The lower
    panel shows the bin-by-bin ratio OS / (normalized SS) with Poisson
    uncertainties.

    Parameters
    ----------
    os, ss : hist.Hist
        One-dimensional histograms with a shared binning. Note that ``os``
        shadows the stdlib module inside this function; the name is kept for
        call compatibility.
    var : str
        Variable name. Currently unused.
    cat_label, btag_label : str
        Category and b-tag labels shown in the legend title.
    var_label : str
        x-axis label.
    stats : object
        Must expose ``statistic`` and ``pvalue`` (e.g. a ``ks_2samp`` result);
        both are shown in the legend title.
    logscale : bool
        If True, use log scale on both axes of the upper panel.
    outfile : str or None
        If given, the figure is saved there as PDF.
    year, lumi :
        Forwarded to ``hep.cms.label``.
    blind : bool
        Currently unused.
    """
    hep.style.use(["CMS", "fira", "firamath"])
    colors = {
        "OS Application": "#005F73",
        "SS Validation": "#0A9396",
    }

    fig, (ax, rax) = plt.subplots(
        nrows=2,
        ncols=1,
        figsize=(9, 12),
        dpi=120,
        gridspec_kw={"height_ratios": (4, 1)},
        sharex=True,
    )
    fig.subplots_adjust(hspace=0.07)
    ax.set_prop_cycle(cycler(color=list(colors.values())))

    # Normalize SS to the OS yield so the closure tests shape only. The sqrt(N)
    # errors are floored at 1e-8 (same as the OS side would be) so empty bins
    # do not yield 0 or NaN.
    os_norm = integrate(os)
    ss_norm = integrate(ss)
    ss_errs = np.sqrt(np.maximum(ss.values(), 1e-8))
    ss, ss_errs = ss * os_norm / ss_norm, ss_errs * os_norm / ss_norm

    os.plot1d(ax=ax, histtype="errorbar", xerr=25, color="#0A9396", marker="s", markersize=5, mfc='#0A9396', mec='#0A9396', capsize=2, label="OS Application", alpha=0.5)
    ss.plot1d(ax=ax, histtype="errorbar", xerr=25, yerr=ss_errs, color='#EE9B00', marker="o", markersize=5, mfc='#EE9B00', mec='#EE9B00', capsize=2, label="SS Validation", alpha=0.5)

    # Ratio of bin contents (SS is already normalized to OS). Counts, not
    # densities, must be passed to ratio_uncertainty: its "poisson" mode
    # treats the numerator as event counts. Densities (counts / total /
    # bin width) would inflate the variance by orders of magnitude.
    os_counts = os.values()
    ss_counts = ss.values()
    bins = os.axes[0].centers
    with np.errstate(divide="ignore", invalid="ignore"):
        y = os_counts / ss_counts
    yerr = ratio_uncertainty(os_counts, ss_counts, "poisson")

    rax.errorbar(
        x=bins,
        y=y,
        yerr=yerr,
        color="k",
        linestyle="none",
        marker="o",
        elinewidth=1,
    )

    if logscale:
        ax.set_yscale("log")
        ax.set_xscale("log")
    ax.set_ylabel("Counts")
    rax.set_xlabel(var_label)
    ax.set_xlabel("")
    rax.set_ylim([0, 2])
    ax.set_xlim([50, 800])
    rax.set_xlim([50, 800])
    ax.legend(loc="best", prop={"size": 16}, frameon=True, title_fontsize="small")
    ax.get_legend().set_title(f"{cat_label}, {btag_label}\nKS={(stats.statistic):.2f}, p={stats.pvalue:.3f}")
    hep.cms.label("Preliminary", data=True, lumi=lumi, year=year, ax=ax)
    if outfile is not None:
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(outfile, format="pdf", dpi=800)
    plt.show()

var = "m4l_reg"
outdir = f"../plots/data-mc/{year}"
closure_dir = f"../plots/fakes-closure/{year}"
os.makedirs(closure_dir, exist_ok=True)


def _channel_mask(cats, btags, cat, bcat):
    """Return a boolean mask selecting entries in the ee/mm + ``cat`` channel with b-tag ``bcat``.

    Uses substring matching on the category strings, as the original loop did.
    Always returns a ``dtype=bool`` array, so empty inputs combine cleanly with
    the m4l window cuts.
    """
    ee, mm = "ee" + cat, "mm" + cat
    return np.array(
        [((ee in c) or (mm in c)) and (b == bcat) for c, b in zip(cats, btags)],
        dtype=bool,
    )


# Summing over the per-dataset histograms is loop-invariant, so do it once.
os_total = sum(OS[var].values())
ss_total = sum(SS[var].values())

for cat, cat_label in cat_labels.items():
    for bcat in [0, 1]:
        for mass_type in ["cons"]:
            # Named os_hist / ss_hist (not os / ss) so the stdlib `os` module
            # is not shadowed for the rest of the notebook.
            os_hist = (
                os_total["reducible", "ee" + cat, bcat, "none", mass_type, :]
                + os_total["reducible", "mm" + cat, bcat, "none", mass_type, :]
            )
            ss_hist = (
                ss_total["data", "ee" + cat, bcat, "none", mass_type, :]
                + ss_total["data", "mm" + cat, bcat, "none", mass_type, :]
            )

            # Rebin by merging pairs of adjacent bins.
            ss_hist = ss_hist[::2j]
            os_hist = os_hist[::2j]

            # Unbinned two-sample KS test in the 100 < m4l < 600 window.
            ss_mask = _channel_mask(ss_cat, ss_btag, cat, bcat)
            os_mask = _channel_mask(os_cat, os_btag, cat, bcat)
            ss_vals = ss_m4l[ss_mask & (ss_m4l < 600) & (ss_m4l > 100)]
            os_vals = os_m4l[os_mask & (os_m4l < 600) & (os_m4l > 100)]
            if len(ss_vals) == 0 or len(os_vals) == 0:
                # ks_2samp raises on empty samples; skip instead of aborting the loop.
                print(f"Skipping {cat}, bcat={bcat}: empty KS sample")
                continue
            stats = ks_2samp(ss_vals, os_vals, alternative='two-sided', method='exact')

            blabel = "btag" if bcat else "0btag"
            plot_closure(
                os_hist, ss_hist, var, cat_label=cat_label, var_label=r"$m_{ll\tau\tau}^\mathrm{cons}$", stats=stats,
                logscale=False, outfile=f"{closure_dir}/{cat}_{blabel}.pdf", year=year, lumi=lumi, btag_label="btag" if bcat else "0-btag",
            )