In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.append('../')
from os.path import join
import uproot
from coffea import util
import numpy as np
from matplotlib import pyplot as plt
from cycler import cycler
from hist.intervals import ratio_uncertainty
from azh_analysis.utils.plotting import plot_fake_rates_data, plot_fake_rate_measurements
from azh_analysis.utils.parameters import get_lumis

plot = True
year = "2016postVFP"
# get_lumis() is indexed by "2016" for both the preVFP and postVFP eras.
lumi = get_lumis()[year if "2016" not in year else "2016"]


def load_pt_hist(sample, era):
    """Return the sum of all histograms under the "pt" key of one output file.

    The file read is ``../output/fake_rates_<sample>_UL_<era>.coffea``, where
    ``sample`` is "data" or "MC".
    """
    return sum(util.load(f"../output/fake_rates_{sample}_UL_{era}.coffea")["pt"].values())


if "2016" in year:
    # Both VFP halves are merged into a single measurement; each file is
    # loaded exactly once.
    data_2016preVFP = load_pt_hist("data", "2016preVFP")
    data_2016postVFP = load_pt_hist("data", "2016postVFP")
    MC_2016preVFP = load_pt_hist("MC", "2016preVFP")
    MC_2016postVFP = load_pt_hist("MC", "2016postVFP")
    data = data_2016preVFP + data_2016postVFP
    MC = MC_2016preVFP + MC_2016postVFP
else:
    data = load_pt_hist("data", year)
    MC = load_pt_hist("MC", year)
    
# Category labels for the pT and |eta| bins. The eta labels are used directly
# as histogram indices, so they must match the stored categories exactly.
pt_bins = [
    '$10<p_T<20$ GeV',
    '$20<p_T<30$ GeV',
    '$30<p_T<40$ GeV',
    '$40<p_T<60$ GeV',
    '$60<p_T<1000000$ GeV',
]
eta_bins = ['$|\\eta|<1.479$', '$|\\eta|>1.479$']
# Keys: positions along the histograms' decay-mode axis.
# Values: the decay-mode numbers used in plot labels and output histogram names.
decay_modes = {0: -1, 1: 0, 2: 1, 3: 2, 4: 10, 5: 11, 6: 15}

# y-axis limits per plot, keyed by (mode, eta-bin label or None, decay mode or None).
ranges = {
    ("e", eta_bins[0], None): [0, 0.014],
    ("e", eta_bins[1], None): [0, 0.035],
    ("m", eta_bins[0], None): [0, 0.1],
    ("m", eta_bins[1], None): [0, 0.14],
    # Tau modes have no eta split; one upper limit per decay mode.
    **{
        (tau_mode, None, dm): [0, ymax]
        for tau_mode, ymax_by_dm in {
            "et": {0: 0.18, 1: 0.2, 10: 0.18, 11: 0.1},
            "mt": {0: 0.35, 1: 0.25, 10: 0.18, 11: 0.1},
            "tt": {0: 0.35, 1: 0.25, 10: 0.18, 11: 0.1},
        }.items()
        for dm, ymax in ymax_by_dm.items()
    },
}

# Output ROOT file for each fake-rate measurement, under the per-year corrections dir.
filenames = {
    mode: join(f"../corrections/fake_rates/UL_{year}", basename)
    for mode, basename in {
        "e": f"JetEleFakeRate_Fall17MVAv2WP90_noIso_Iso0p15_UL{year}_coffea.root",
        "m": f"JetMuFakeRate_Medium_Iso0p15_UL{year}_coffea.root",
        "et": f"JetTauFakeRate_Medium_Tight_VLoose_UL{year}_coffea.root",
        "mt": f"JetTauFakeRate_Medium_VLoose_Tight_UL{year}_coffea.root",
        "tt": f"JetTauFakeRate_Medium_VLoose_VLoose_UL{year}_coffea.root",
    }.items()
}

# Plot output directories.
ratio_dir = f"../plots/fakes-ratio/{year}"
meas_dir = f"../plots/fakes-measurement/{year}"
# Index value equivalent to writing `::sum` inside a histogram index expression.
sum_axis = slice(None, None, sum)
# Year label passed to the plotting helpers: both 2016 eras are labelled "2016".
plot_year = "2016" if "2016" in year else year
# Working-point label appended to each tau plot label.
tau_wp_labels = {
    "et": "(Tight, VLoose, Medium)",
    "mt": "(VLoose, Tight, Medium)",
    "tt": "(VLoose, VLoose, Medium)",
}


def _prompt_subtracted(level, mode, dm_sel, eta_sel):
    """Return data minus prompt MC, summed over the ee+<mode> and mm+<mode> regions.

    ``level`` is "Numerator" or "Denominator". ``dm_sel`` and ``eta_sel`` index
    the decay-mode and eta axes; pass ``sum_axis`` to integrate an axis out.
    Reads the module-level ``data`` and ``MC`` histograms. Terms are added in
    the same order as the original expanded expression.
    """
    return (
        data[::sum, "ee" + mode, ::sum, level, dm_sel, eta_sel, :]
        + data[::sum, "mm" + mode, ::sum, level, dm_sel, eta_sel, :]
        + -1 * MC[::sum, "ee" + mode, "Prompt", level, dm_sel, eta_sel, :]
        + -1 * MC[::sum, "mm" + mode, "Prompt", level, dm_sel, eta_sel, :]
    )


def _plot_and_measure(data_d, data_n, label, tag, **plot_kwargs):
    """Draw the ratio plot and the measurement plot; return (edges, centers, ratios).

    When ``plot`` is True, the plots are written as ``<tag>.pdf`` under
    ``ratio_dir`` and ``meas_dir``. ``plot_kwargs`` are forwarded to both
    plotting helpers. The y/x errors returned by
    ``plot_fake_rate_measurements`` are not used.
    """
    plot_fake_rates_data(
        data_d, data_n, label,
        outfile=join(ratio_dir, f"{tag}.pdf") if plot else None,
        **plot_kwargs,
    )
    edges, centers, ratios, _yerr, _xerr = plot_fake_rate_measurements(
        data_d, data_n, label,
        outfile=join(meas_dir, f"{tag}.pdf") if plot else None,
        plot_fit=False,
        **plot_kwargs,
    )
    return edges, centers, ratios


if plot:
    os.makedirs(ratio_dir, exist_ok=True)
    os.makedirs(meas_dir, exist_ok=True)

for mode in ["e", "m", "et", "mt", "tt"]:
    # uproot.recreate does not create missing parent directories.
    os.makedirs(os.path.dirname(filenames[mode]), exist_ok=True)
    with uproot.recreate(filenames[mode]) as f:
        print("Mode:", mode)

        #### jet faking electron or muon ####
        if "t" not in mode:
            for eta_bin in eta_bins:
                label = f"Electron Fakes\n{eta_bin}" if "e" in mode else f"Muon Fakes\n{eta_bin}"
                data_d = _prompt_subtracted("Denominator", mode, sum_axis, eta_bin)
                data_n = _prompt_subtracted("Numerator", mode, sum_axis, eta_bin)

                eta_label = "barrel" if "<" in eta_bin else "endcaps"
                edges, centers, ratios = _plot_and_measure(
                    data_d, data_n, label, f"{mode}_{eta_label}",
                    xlim=[10, 100], ylim=ranges[(mode, eta_bin, None)],
                    year=plot_year, lumi=lumi,
                )
                # One bin per measured point, with the fake rate as its content.
                h = np.histogram(centers, bins=edges, weights=ratios)
                print(h)
                if "<" in eta_bin:
                    f["POL2FitFR_Central_barrel"] = h
                elif ">" in eta_bin:
                    f["POL2FitFR_Central_endcap"] = h

        #### jet faking hadronic taus ####
        else:
            for decay_mode, dm in decay_modes.items():
                label = r"$\tau_h$ Fakes" + f"\n Decay Mode {dm}\n" + tau_wp_labels[mode]
                data_d = _prompt_subtracted("Denominator", mode, decay_mode, sum_axis)
                data_n = _prompt_subtracted("Numerator", mode, decay_mode, sum_axis)
                if np.sum(data_d.values()) == 0:
                    continue

                # Decay modes with no entry in `ranges` are skipped instead of
                # aborting the whole run with a KeyError mid-write.
                ylim = ranges.get((mode, None, dm))
                if ylim is None:
                    print(f"Skipping {mode} decay mode {dm}: no y-range configured")
                    continue

                # NOTE(review): plot files are named after the decay-mode axis
                # index (e.g. DM 10 -> "<mode>_4.pdf"), not the decay-mode number;
                # kept as-is so existing plot paths do not change.
                edges, centers, ratios = _plot_and_measure(
                    data_d, data_n, label, f"{mode}_{decay_mode}",
                    xlim=[20, 100], ylim=ylim, combine_bins=True,
                    year=plot_year, lumi=lumi,
                )
                h = np.histogram(centers, bins=edges, weights=ratios)
                print(h)
                f[f"POL2FitFR_Central_DM{dm}"] = h
        print(f)