In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
import argparse
import warnings
import logging
from os.path import join
from collections import defaultdict
sys.path.append('../')
warnings.filterwarnings("ignore")

from scipy.stats import ks_2samp
import numpy as np
from coffea import processor, util
import hist
import mplhep as hep
import seaborn as sns
from matplotlib import pyplot as plt
from cycler import cycler
from hist import Hist
from hist.intervals import ratio_uncertainty
from azh_analysis.utils.plotting import plot_data_vs_mc, plot_systematic
from azh_analysis.utils.parameters import get_lumis
import warnings
warnings.filterwarnings('ignore')
#hep.style.use(["CMS", "fira", "firamath"])

sys.path.append("../")
from azh_analysis.utils.histograms import integrate, norm_to


# --- Run configuration ------------------------------------------------------
# Data-taking year and the matching integrated luminosity from get_lumis().
plot = True
year = "2018"
lumi = get_lumis()[year]

# --- Inputs -----------------------------------------------------------------
# Coffea outputs for the opposite-sign (OS) selection and the same-sign (SS)
# selections (relaxed and not-relaxed variants), all for the chosen year.
OS = util.load(f"../output/data_UL_{year}_OS_ub.coffea")
SSr = util.load(f"../output/data_UL_{year}_SS_ub_relaxed.coffea")
SS = util.load(f"../output/data_UL_{year}_SS_ub_not-relaxed.coffea")
print(OS.keys())

# LaTeX legend labels keyed by final-state category (the "ll" prefix is the
# dilepton pair that later cells build from the "ee"/"mm" channels).
cat_labels = dict(
    tt=r'$ll\tau_h\tau_h$',
    et=r'$ll e\tau_h$',
    mt=r'$ll\mu\tau_h$',
    em=r'$ll e\mu$',
)

In [None]:
def plot_closure(
    h1,
    h1_label,
    h2,
    h2_label,
    h3=None,
    h3_label=None,
    var=None,
    cat_label="",
    var_label="",
    btag_label="",
    stats=None,
    logscale=False,
    outfile=None,
    year="1",
    lumi="1",
    blind=False,
    xerr=None,
):
    """Overlay two (optionally three) 1D histograms and draw h1/h2 below.

    Parameters
    ----------
    h1, h2 : 1D hist histograms
        Histograms to compare; the lower panel shows the bin-wise ratio h1 / h2.
    h1_label, h2_label : str
        Legend labels; also combined into the ratio-panel y-axis label.
    h3 : 1D hist histogram, optional
        Extra histogram drawn in the upper panel only; skipped when None.
    h3_label : str, optional
        Legend label for ``h3``.
    var : optional
        Not used by the function; kept for call compatibility.
    cat_label, btag_label : str
        Joined as ``"{cat_label}, {btag_label}"`` for the legend title.
    var_label : str
        x-axis label of the ratio panel.
    stats : optional
        Not used (the KS-test legend annotation is disabled).
    logscale : bool
        If True, log-scale both upper-panel axes; otherwise fix x to [50, 1200].
    outfile : str or path-like, optional
        If given, save the figure there as a PDF (parent dirs are created).
    year, lumi :
        Forwarded to ``hep.cms.label``.
    blind : bool
        Not used.
    xerr : float or bool, optional
        Horizontal error-bar setting forwarded to ``plot1d``; defaults to 100,
        the value previously hard-coded.
    """
    # Local import: later notebook cells rebind the global name `os` to a
    # histogram, so file-system helpers here must not rely on it.
    from pathlib import Path

    hep.style.use(["CMS", "fira", "firamath"])
    bar_xerr = 100 if xerr is None else xerr
    colors = {
        "h1": "#005F73",
        "h2": "#0A9396",
    }

    fig, (ax, rax) = plt.subplots(
        nrows=2,
        ncols=1,
        figsize=(9, 12),
        dpi=120,
        gridspec_kw={"height_ratios": (4, 1)},
        sharex=True,
    )
    fig.subplots_adjust(hspace=0.07)
    ax.set_prop_cycle(cycler(color=list(colors.values())))

    # (histogram, legend label, colour, marker) per overlaid series; h3 is
    # optional so two-way comparisons can leave it out.
    series = [
        (h1, h1_label, "#0A9396", "s"),
        (h2, h2_label, "#EE9B00", "o"),
    ]
    if h3 is not None:
        series.append((h3, h3_label, "#005F73", "o"))
    for h, label, color, marker in series:
        h.plot1d(
            ax=ax,
            histtype="errorbar",
            xerr=bar_xerr,
            yerr=np.sqrt(h.variances()),
            color=color,
            marker=marker,
            markersize=5,
            mfc=color,
            mec=color,
            capsize=2,
            label=label,
            alpha=0.5,
        )

    # Lower panel: bin-by-bin h1 / h2. Bins with an empty denominator become
    # NaN (not drawn) instead of +/-inf.
    h1_vals = h1.values()
    h2_vals = h2.values()
    ratio = np.divide(
        h1_vals,
        h2_vals,
        out=np.full(np.shape(h1_vals), np.nan),
        where=h2_vals != 0,
    )
    ratio_err = ratio_uncertainty(h1_vals, h2_vals, "poisson")
    rax.errorbar(
        x=h1.axes[0].centers,
        y=ratio,
        yerr=ratio_err,
        color="k",
        linestyle="none",
        marker="o",
        elinewidth=1,
    )

    if logscale:
        ax.set_yscale("log")
        ax.set_xscale("log")
    else:
        # Fixed x window for linear plots.
        ax.set_xlim([50, 1200])
        rax.set_xlim([50, 1200])
    ax.set_ylabel("Counts")
    ax.set_xlabel("")
    rax.set_xlabel(var_label)
    rax.set_ylabel(f"{h1_label} / {h2_label}", fontsize=10)
    rax.set_ylim([0, 4])
    rax.axhline(1, color="black", linestyle="--")

    ax.legend(
        loc="best",
        prop={"size": 16},
        frameon=True,
        title=f"{cat_label}, {btag_label}",
        title_fontsize="small",
    )
    hep.cms.label("Preliminary", data=True, lumi=lumi, year=year, ax=ax)

    if outfile is not None:
        Path(outfile).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(outfile, format="pdf", dpi=800)
    plt.show()
    # Release the figure; this function is called inside loops.
    plt.close(fig)

In [None]:
from scipy.optimize import minimize
from scipy.stats import chisquare

# Closure check on m4l: OS-application reducible estimate vs SS-relaxed data,
# with the SS-relaxed histogram scaled so its yield matches the OS yield.
var = "m4l"
outdir = f"../plots/data-mc/{year}"

# Summing over datasets is independent of the category loop, so do it once.
# Names deliberately avoid `os`, which would shadow the imported `os` module.
os_total = sum(OS[var].values())
ssr_total = sum(SSr[var].values())

for cat, cat_label in cat_labels.items():
    for bcat in [0, 1]:
        for mass_type in ["cons"]:
            # Merge the ee and mm dilepton channels for this final state.
            os_hist = (
                os_total["reducible", "ee" + cat, bcat, "none", mass_type, :]
                + os_total["reducible", "mm" + cat, bcat, "none", mass_type, :]
            )
            ssr_hist = (
                ssr_total["data", "ee" + cat, bcat, "none", mass_type, :]
                + ssr_total["data", "mm" + cat, bcat, "none", mass_type, :]
            )

            os_norm = integrate(os_hist)
            ssr_norm = integrate(ssr_hist)

            def yield_mismatch(x):
                # |sum(OS) - sum(SSr * x / ssr_norm)|: minimised by the scale x
                # that brings the SS-relaxed yield onto the OS yield.
                diff = (os_hist + -1 * ssr_hist * x / ssr_norm).values().sum()
                return abs(diff)

            res = minimize(yield_mismatch, os_norm, method='Nelder-Mead', tol=1e-6)
            print(os_norm, res.x[0])
            ssr_hist = ssr_hist * res.x[0] / ssr_norm

            blabel = "btag" if bcat else "0btag"
            plot_closure(
                h1=os_hist, h1_label="OS Application", h2=ssr_hist, h2_label="SS Relaxed", var=var, cat_label=cat_label, var_label=r"$m_{ll\tau\tau}^\mathrm{cons}$",
                logscale=False, outfile=f"../plots/fakes-closure/{year}/{cat}_{blabel}.pdf", year=year, lumi=lumi, btag_label="btag" if bcat else "0-btag",
            )

In [None]:
# Re-declare the run configuration (same values as the setup cell).
plot = True
year = "2018"
lumi = get_lumis()[year]

# Pull raw per-event arrays out of the outputs via `.value`
# (presumably coffea column accumulators — TODO confirm): SS-relaxed data vs
# the SS reducible estimate. Each pair is built fully before assignment.
ssr_m4l, ss_m4l = (
    np.array(SSr["data_m4l_reg"].value),
    np.array(SS["reducible_m4l_reg"].value),
)
ssr_cat, ss_cat = (
    np.array(SSr["data_cat"].value),
    np.array(SS["reducible_cat"].value),
)
ssr_btag, ss_btag = (
    np.array(SSr["data_btag"].value),
    np.array(SS["reducible_btag"].value),
)

# LaTeX legend labels keyed by final-state category.
cat_labels = dict(
    tt=r'$ll\tau_h\tau_h$',
    et=r'$ll e\tau_h$',
    mt=r'$ll\mu\tau_h$',
    em=r'$ll e\mu$',
)

In [None]:
# Closure check on the rebinned m4l_reg distribution: overlay SS application,
# SS relaxed and OS application, with both SS histograms normalised to OS.
var = "m4l_reg"

# Dataset sums do not depend on the category loop; compute them once.
# Names deliberately avoid `os`, which would shadow the imported `os` module.
os_total = sum(OS[var].values())
ss_total = sum(SS[var].values())
ssr_total = sum(SSr[var].values())

# Not used below. The f-prefix was missing before, so "{year}" was never
# substituted. NOTE(review): hard-coded absolute user path — make configurable.
outdir = f"/uscms_data/d3/jdezoort/AZh_columnar/CMSSW_10_2_9/src/azh_coffea/src/corrections/closure/UL_{year}"

for cat, cat_label in cat_labels.items():
    for bcat in [0, 1]:
        for mass_type in ["cons"]:
            # Merge the ee and mm dilepton channels for this final state.
            os_hist = (
                os_total["reducible", "ee" + cat, bcat, "none", mass_type, :]
                + os_total["reducible", "mm" + cat, bcat, "none", mass_type, :]
            )
            ss_hist = (
                ss_total["reducible", "ee" + cat, bcat, "none", mass_type, :]
                + ss_total["reducible", "mm" + cat, bcat, "none", mass_type, :]
            )
            ssr_hist = (
                ssr_total["data", "ee" + cat, bcat, "none", mass_type, :]
                + ssr_total["data", "mm" + cat, bcat, "none", mass_type, :]
            )

            # Rebin by merging groups of 4 bins (hist `4j` slice shorthand).
            os_hist = os_hist[::4j]
            ss_hist = ss_hist[::4j]
            ssr_hist = ssr_hist[::4j]
            ssr_hist = norm_to(os_hist, ssr_hist, simple=False)
            ss_hist = norm_to(os_hist, ss_hist, simple=False)

            # Per bin: std. dev. of SS application divided by the SS-relaxed
            # value; bins whose SS-relaxed value is not positive get 0.
            ssr_vals = ssr_hist.values()
            ss_std = np.sqrt(ss_hist.variances())
            sigmas = np.zeros(len(ssr_vals))
            np.divide(ss_std, ssr_vals, out=sigmas, where=ssr_vals > 0)
            print(sigmas.tolist())

            blabel = "btag" if bcat else "0btag"
            plot_closure(
                h1=ss_hist, h1_label="SS Application", h2=ssr_hist, h2_label="SS Relaxed", h3=os_hist, h3_label="OS Application", var=var, cat_label=cat_label, var_label=r"$m_{ll\tau\tau}^\mathrm{cons}$", stats=None,
                logscale=False, outfile=f"../plots/fakes-closure/{year}/{var}_{cat}_{blabel}.pdf", year=year, lumi=lumi, btag_label="btag" if bcat else "0-btag",
            )

In [None]:
# Closure check on m4l: build a shape envelope from the SS application
# ("down") and its mirror about SS relaxed ("up"), then draw OS vs SS relaxed.
var = "m4l"
outdir = f"../plots/data-mc/{year}"

# Dataset sums do not depend on the category loop; compute them once.
# Names deliberately avoid `os`, which would shadow the imported `os` module.
os_total = sum(OS[var].values())
ss_total = sum(SS[var].values())
ssr_total = sum(SSr[var].values())

# Positive floor applied to negative template bins.
TEMPLATE_FLOOR = 1e-5

for cat, cat_label in cat_labels.items():
    for bcat in [0, 1]:
        for mass_type in ["cons"]:
            # Merge the ee and mm dilepton channels for this final state.
            os_hist = (
                os_total["reducible", "ee" + cat, bcat, "none", mass_type, :]
                + os_total["reducible", "mm" + cat, bcat, "none", mass_type, :]
            )
            ss_hist = (
                ss_total["reducible", "ee" + cat, bcat, "none", mass_type, :]
                + ss_total["reducible", "mm" + cat, bcat, "none", mass_type, :]
            )
            ssr_hist = (
                ssr_total["data", "ee" + cat, bcat, "none", mass_type, :]
                + ssr_total["data", "mm" + cat, bcat, "none", mass_type, :]
            )
            centers = ss_hist.axes[0].centers

            # Normalise both SS shapes to the OS histogram.
            ss_hist = norm_to(os_hist, ss_hist)
            ssr_hist = norm_to(os_hist, ssr_hist)

            ss_vals = ss_hist.values()
            ssr_vals = ssr_hist.values()

            # `.values()` is a view into the histogram storage: build new arrays
            # so ss_hist is not clipped in place and `up` mirrors the raw
            # (unclipped) SS-application values about SS relaxed.
            nom = ssr_vals
            down = np.where(ss_vals < 0, TEMPLATE_FLOOR, ss_vals)
            up = 2 * ssr_vals - ss_vals
            up = np.where(up < 0, TEMPLATE_FLOOR, up)

            btag_label = "btag" if bcat else "0-btag"
            fig, ax = plt.subplots(dpi=200)
            ax.plot(centers, down, '.-', label="down")
            ax.plot(centers, nom, '.-', label="nom")
            ax.plot(centers, up, '.-', label="up")
            ax.set_xlabel(r"$m_{ll\tau\tau}^c$ [GeV]")
            ax.set_ylabel("Counts")
            ax.set_title(f"{cat_label}, {btag_label}")
            ax.legend(loc="best")
            plt.show()
            plt.close(fig)

            blabel = "btag" if bcat else "0btag"
            plot_closure(
                h1=os_hist, h1_label="OS Application", h2=ssr_hist, h2_label="SS Relaxed", var=var, cat_label=cat_label, var_label=r"$m_{ll\tau\tau}^\mathrm{cons}$", stats=None,
                logscale=True, outfile=f"../plots/fakes-closure/{year}/{var}_{cat}_{blabel}.pdf", year=year, lumi=lumi, btag_label=btag_label,
            )