In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
from coffea import util
from collections import defaultdict
import uproot
import numpy as np
import mplhep as hep
import hist
from hist import Hist
from matplotlib import pyplot as plt
from cycler import cycler
from hist.axis import Variable
sys.path.append("../")

from azh_analysis.utils.histograms import integrate, norm_to
from azh_analysis.utils.plotting import plot_closure

def get_empty_hist():
    """Return an unfilled 1D histogram on the variable-width "mass" axis.

    The axis has 16 edges (15 bins) running from 200 to 2400. The histogram
    is created with the default storage of `Hist` and holds no entries.
    """
    mass_edges = [
        200, 220, 240, 260, 280, 300, 320, 340, 360, 380, 400,
        450, 550, 700, 1000, 2400,
    ]
    mass_axis = Variable(mass_edges, name="mass")
    return Hist(mass_axis)

# Data-taking period to process; it is interpolated into every input file name.
year = "2018"

# Directory (one level above this notebook) that holds the coffea outputs.
indir = "output_unblind"

# Load the five processor outputs in a fixed order. Judging by the file names
# these are: simulated backgrounds, signal, opposite-sign data, and same-sign
# data with the "not-relaxed" and "relaxed" selections.
mc, signal, data, data_ss, data_ssr = map(
    util.load,
    (
        f'../{indir}/MC_UL_{year}_all_OS.coffea',
        f'../{indir}/signal_UL_{year}_all_OS.coffea',
        f'../{indir}/data_UL_{year}_OS_ub.coffea',
        f'../{indir}/data_UL_{year}_SS_ub_not-relaxed.coffea',
        f"../{indir}/data_UL_{year}_SS_ub_relaxed.coffea",
    ),
)

# Show which datasets are present in the "m4l" histograms.
print(mc["m4l"].keys())

def _build_group_labels():
    """Build the year -> {process group -> [dataset names]} lookup table.

    Most groups use the same dataset names in every year, so they are written
    once in `common`; `overrides` lists only the groups whose dataset names
    differ for a given year. Every year receives its own list objects, so
    editing one year's list never affects another year.
    """
    common = {
        "TT": ["TTToSemiLeptonic", "TTToHadronic", "TTTo2L2Nu"],
        "TTZ": ["ttZJets"],
        "TTW": ["TTWJetsToLNu"],
        "ggZZ": [
            "GluGluToContinToZZTo2e2tau",
            "GluGluToContinToZZTo2mu2tau",
            "GluGluToContinToZZTo4e",
            "GluGluToContinToZZTo4mu",
            "GluGluToContinToZZTo4tau",
        ],
        "ZZ": ["ZZTo4L", "ZZTo2Q2Lmllmin4p0"],
        "WZ": ["WZTo2Q2Lmllmin4p0", "WZTo3LNu"],
        # "WZZTuneCP5" and "ZZZTuneCP5_ext1" are disabled for 2017 and both
        # 2016 periods; 2018 uses its own list (see overrides).
        "VVV": ["WWW4F", "WWW4F_ext1", "WWZ4F", "WZZ_ext1", "ZZZ"],
        "ggHtt": ["GluGluHToTauTauM125"],
        "VBFHtt": ["VBFHToTauTauM125"],
        "WHtt": ["WminusHToTauTauM125", "WplusHToTauTauM125"],
        "ZHtt": ["ZHToTauTauM125_ext1"],
        "TTHtt": ["ttHToTauTauM125"],
        "ggHWW": ["GluGluHToWWTo2L2NuM-125"],
        "VBFHWW": ["VBFHToWWTo2L2NuM-125"],
        "ggZHWW": ["GluGluZHHToWW"],
        "ggHZZ": ["GluGluHToZZTo4LM125"],
        "WHWW": ["HWminusJHToWW", "HWplusJHToWW"],
        "ZHWW": ["HZJHToWW", "HZJHToWW_ext1"],
    }

    # Only the groups that deviate from `common`, per year. The order of the
    # years here is the order of the keys in the resulting dictionary.
    overrides = {
        "2018": {
            "VVV": [
                "WWW4F",
                "WWW4F_ext1",
                "WWZ4F",
                "WZZ",
                "WZZ_ext1",
                "ZZZ",
                "ZZZ_ext1",
            ],
            "WHWW": ["HWminusJHToWW", "HWplusJHToWWTo2L2Nu"],
            "ZHWW": ["HZJHToWW"],
        },
        "2017": {},
        "2016postVFP": {
            "WZ": ["WZTo2Q2L", "WZTo3LNu"],
        },
        "2016preVFP": {
            "ZZ": ["ZZTo4L", "ZZTo2Q2L"],
            "ZHtt": ["ZHToTauTau_ext1"],
            "ggHWW": ["GluGluHToWWTo2L2NuM125"],
            "ZHWW": ["HZJHToWW"],
        },
    }

    # Iterating over `common` keeps the group order identical for every year;
    # list(...) copies so that no list is shared between years.
    return {
        period: {
            group: list(changed.get(group, samples))
            for group, samples in common.items()
        }
        for period, changed in overrides.items()
    }


group_labels = _build_group_labels()

# Category labels: a two-letter prefix ("ee" or "mm") followed by a two-letter
# suffix ("et", "mt", "tt", "em"), prefix-major order.
cats = [
    lead + pair
    for lead in ("ee", "mm")
    for pair in ("et", "mt", "tt", "em")
]

# Systematic variation labels: "nom" first, then "<source>_<shift>" pairs.
# The up/down order inside each pair differs between the three runs of sources
# and is kept exactly, since it fixes the order in which templates are written.
systs = [
    "nom",
    *(
        f"{source}_{shift}"
        for source in ("l1prefire", "pileup")
        for shift in ("up", "down")
    ),
    *(
        f"{source}_{shift}"
        for source in (
            "tauES",
            "efake",
            "mfake",
            "eleES",
            "eleSmear",
            "muES",
            "unclMET",
            "tauID_0",
            "tauID_1",
            "tauID_10",
            "tauID_11",
        )
        for shift in ("down", "up")
    ),
    *(
        f"{source}_{shift}"
        for source in ("JES", "JER")
        for shift in ("up", "down")
    ),
]

In [None]:
var = "m4l_binopt"
def get_hist_var(hist_dict, datasets, cat, btag, syst):
    """Sum the per-bin variances of the selected datasets.

    Parameters
    ----------
    hist_dict : mapping
        Indexed with the module-level `var`; the result maps dataset name
        to histogram.
    datasets : container of str
        Dataset names to include; every other dataset is skipped.
    cat, btag, syst
        Selections applied on histogram axes 1, 2 and 3. A histogram whose
        axis 1 does not contain `cat`, or whose axis 3 does not contain
        `syst`, contributes nothing.

    Returns
    -------
    numpy.ndarray
        Array of length 15 holding the summed variances (all zeros when no
        histogram qualifies). The length is hard-coded; presumably it matches
        the number of mass bins -- confirm if the binning changes.
    """
    total = np.zeros(15)
    for name, dataset_hist in hist_dict[var].items():
        usable = (
            name in datasets
            and cat in list(dataset_hist.axes[1])
            and syst in list(dataset_hist.axes[3])
        )
        if not usable:
            continue
        # Sum over axis 0, fix category / b-tag / systematic / "cons", and
        # keep the last axis in full.
        projected = dataset_hist[::sum, cat, btag, syst, "cons", :]
        total += np.array(projected.variances())
    return total

# Scratch pass: write the grouped MC templates of each b-tag bin to a dummy file.
for b in [0, 1]:
    # NOTE: both passes write to the same path, so the b=1 pass overwrites
    # the file produced by the b=0 pass.
    # The context manager closes the output file even if writing fails.
    with uproot.recreate("root_for_combine/dummy.root") as mc_file:
        for group, datasets in group_labels[year].items():
            # Collect the datasets of this group that are actually available.
            found = [k for k in mc[var] if k in datasets]
            if len(found) != len(datasets):
                print(f"ERROR: found {found} expected {datasets}")
            mc_group = sum(mc[var][k] for k in found)

            # Axis 1 / axis 3 entries are looked up once per group instead of
            # once per (category, systematic) pair.
            known_cats = list(mc_group.axes[1])
            known_systs = list(mc_group.axes[3])

            for cat in cats:
                for syst in systs:
                    if "btag" in syst:
                        continue

                    # Fall back to an empty template when the category or
                    # the systematic is missing from the grouped histogram.
                    if cat in known_cats and syst in known_systs:
                        group_hist = mc_group[::sum, cat, b, syst, "cons", :]
                    else:
                        group_hist = get_empty_hist()

                    # Template key: "<cat>/<group>" for the nominal shape and
                    # "<cat>/<group>_<source><Up|Down>" otherwise; tauID
                    # sources drop their inner underscore
                    # ("tauID_10_up" -> "tauID10Up").
                    if "nom" in syst:
                        fname = f"{cat}/{group}"
                    else:
                        source, shift = syst.rsplit("_", 1)
                        if "tauID" in source:
                            source = source.replace("_", "")
                        fname = f"{cat}/{group}_{source}{shift.capitalize()}"
                    mc_file[fname] = group_hist

In [None]:
var = "m4l_binopt"
outdir = "unblinded/datacard_templates"

# Closure-uncertainty map: index of an "m4l_reg" bin (after the ::4j
# rebinning) -> indices of the `var` bins that are shifted for that
# closure nuisance. Only the indices listed here produce templates.
_closure_bins = {1: range(0, 10), 2: (10, 11), 3: (12, 13), 4: (14,)}


def _template_key(cat, process, syst):
    """Return the in-file path of one template.

    The nominal shape is stored as "<cat>/<process>"; a variation
    "<source>_<shift>" is stored as "<cat>/<process>_<source><Up|Down>".
    tauID sources drop their inner underscore ("tauID_10_up" -> "tauID10Up").
    """
    if "nom" in syst:
        return f"{cat}/{process}"
    source, shift = syst.rsplit("_", 1)
    if "tauID" in source:
        source = source.replace("_", "")
    return f"{cat}/{process}_{source}{shift.capitalize()}"


def _select_template(grouped, cat, b, syst):
    """Return the template of `grouped` for one category / b-tag / systematic.

    Axis 0 is summed over, axes 1-3 are fixed to `cat`, `b` and `syst`,
    axis 4 is fixed to "cons" and the last axis is kept. When `cat` is not on
    axis 1 or `syst` is not on axis 3, an empty histogram from
    `get_empty_hist()` is returned instead.
    """
    if cat in list(grouped.axes[1]) and syst in list(grouped.axes[3]):
        return grouped[::sum, cat, b, syst, "cons", :]
    return get_empty_hist()


# uproot.recreate does not create missing parent directories.
os.makedirs(outdir, exist_ok=True)

# None of the sums and lookups below depends on the b-tag bin, so they are
# built once instead of once per pass.
data_group = sum(data[var].values())
data_reg_group = sum(data["m4l_reg"].values())
data_ss_group = sum(data_ss[var].values())
data_ss_reg_group = sum(data_ss["m4l_reg"].values())
data_ssr_group = sum(data_ssr[var].values())
data_ssr_reg_group = sum(data_ssr["m4l_reg"].values())

# Signal mass points are whatever follows "TauM" in the dataset names; for each
# one keep the first dataset whose name contains "Glu" (ggA) or "BB" (bbA), or
# None when there is no such dataset.
unique_masses = np.unique([f"{k.split('TauM')[-1]}" for k in signal[var].keys()])
ggA_masses = {
    m: next(
        (v for k, v in signal[var].items()
         if k.split("TauM")[-1] == m and "Glu" in k),
        None,
    )
    for m in unique_masses
}
bbA_masses = {
    m: next(
        (v for k, v in signal[var].items()
         if k.split("TauM")[-1] == m and "BB" in k),
        None,
    )
    for m in unique_masses
}

for b in [0, 1]:
    btag_label = "btag" if (b==1) else "0btag"

    # Each output file is opened in a `with` block so that it is closed (and
    # its handle released) as soon as its templates are written, even on error.
    with uproot.recreate(f"{outdir}/MC_{btag_label}_{year}.root") as mc_file:
        ############################
        # fill MC output ROOT file #
        ############################
        for group, datasets in group_labels[year].items():

            # Per-dataset scale factors; all 1.0 unless the block below is
            # re-enabled.
            factors = {dataset: 1.0 for dataset in datasets}
            #if "ggZZ" in group and "output_unblind" in indir and "2018" not in year:
            #    factors = {
            #        dataset: 0.0027/0.00159 if "4" in dataset else 0.0054/0.00319
            #        for dataset in datasets
            #    }
            #print(group, datasets, factors)

            mc_group = sum(
                factors[k] * v for k, v in mc[var].items()
                if k in datasets
            )

            found = [k for k in mc[var] if k in datasets]
            if len(found) != len(datasets):
                print(f"ERROR: found {found} expected {datasets}")

            for cat in cats:
                for syst in systs:
                    if "btag" in syst:
                        continue
                    fname = _template_key(cat, group, syst)
                    mc_file[fname] = _select_template(mc_group, cat, b, syst)

        ###################################
        # fill reducible into the MC file #
        ###################################
        for cat in cats:
            group_hist = data_group["reducible", cat, b, ::sum, "cons", :]
            # `::4j` is the hist shorthand for rebinning by a factor of 4.
            group_reg_hist = data_reg_group["reducible", cat, b, ::sum, "cons", ::4j]
            group_ss_hist = data_ss_group["reducible", cat, b, ::sum, "cons", :]
            group_ss_reg_hist = data_ss_reg_group["reducible", cat, b, ::sum, "cons", ::4j]
            group_ssr_hist = data_ssr_group["data", cat, b, ::sum, "cons", :]
            group_ssr_reg_hist = data_ssr_reg_group["data", cat, b, ::sum, "cons", ::4j]

            # m4l
            group_ssr_hist = norm_to(group_hist, group_ssr_hist, simple=False)
            # Not read again below; the call is kept in case norm_to is
            # relied upon for more than its return value -- TODO confirm.
            group_ss_hist = norm_to(group_hist, group_ss_hist, simple=False)

            # m4l_reg
            group_ssr_reg_hist = norm_to(group_reg_hist, group_ssr_reg_hist, simple=False)
            group_ss_reg_hist = norm_to(group_reg_hist, group_ss_reg_hist, simple=False)

            for i, ax in enumerate(group_ssr_reg_hist.axes[0]):
                print(f"closure{i}")
                if i not in _closure_bins:
                    continue
                group_ssr_hist_up = group_ssr_hist.copy()
                group_ssr_hist_down = group_ssr_hist.copy()
                print(f"{ax}")

                # Relative shift for this closure nuisance.
                # NOTE(review): this divides the uncertainty of the
                # "not-relaxed" SS bin by the content of the "relaxed" SS bin;
                # confirm that mixing the two selections is intended.
                ssr_val = group_ssr_reg_hist[i].value
                ss_std = np.sqrt(group_ss_reg_hist[i].variance)
                sigma = ss_std / ssr_val if ssr_val > 0 else 0

                for j, bx in enumerate(group_ssr_hist.axes[0]):
                    print(f"  {bx}")
                    if j in _closure_bins[i]:
                        # Scale the bin content up by (1 + sigma) and down by
                        # 1 / (1 + sigma); the bin variance is left unchanged.
                        group_ssr_hist_up[j] = (group_ssr_hist_up[j].value * (1 + sigma), group_ssr_hist_up[j].variance)
                        group_ssr_hist_down[j] = (group_ssr_hist_down[j].value * 1/(1+sigma), group_ssr_hist_down[j].variance)
                        print(f"       in bin, {group_ssr_hist_up[j]}, {group_ssr_hist_down[j]}")

                mc_file[f"{cat}/reducible_closure{i}Up"] = group_ssr_hist_up
                mc_file[f"{cat}/reducible_closure{i}Down"] = group_ssr_hist_down

            mc_file[f"{cat}/reducible"] = group_ssr_hist

    ######################
    # now fill the data  #
    ######################
    with uproot.recreate(f"{outdir}/data_{btag_label}_{year}.root") as data_file:
        for cat in cats:
            if cat not in list(data_group.axes[1]):
                group_hist = get_empty_hist()
            else:
                group_hist = data_group["data", cat, b, ::sum, "cons", :]
            data_file[f"{cat}/data"] = group_hist

    #######################
    # now fill for signal #
    #######################
    for m in unique_masses:
        with uproot.recreate(f"{outdir}/signal_{m}_{btag_label}_{year}.root") as signal_file:
            for k, v in {"ggA": ggA_masses[m], "bbA": bbA_masses[m]}.items():
                if v is None:
                    print("WARNING: SKIPPING", k, m)
                    continue
                for cat in cats:
                    for syst in systs:
                        if "btag" in syst:
                            continue
                        fname = _template_key(cat, k, syst)
                        signal_file[fname] = _select_template(v, cat, b, syst)

In [None]:
def integrate(hist):
    """Return the sum over bins of (bin width * bin content) for a 1D histogram.

    `hist.axes[0]` must yield one (low edge, high edge) pair per bin and
    `hist.values()` the matching bin contents.

    Note: defining this function here replaces the `integrate` imported from
    `azh_analysis.utils.histograms` at the top of the notebook.
    """
    edges = np.array(hist.axes[0])
    low_edges, high_edges = edges[:, 0], edges[:, 1]
    bin_widths = high_edges - low_edges
    contents = hist.values()
    return sum(bin_widths * contents)

# Quick visual check: overlay the OS "reducible" m4l shape with the SS shape
# scaled to the same integral. Both loops stop after their first useful pass,
# so only b=0 and the first category present in the data are drawn.
for b in [0, 1]:
    btag_label = "btag" if (b==1) else "0btag"
    data_group = sum(data["m4l"].values())
    data_ss_group = sum(data_ss["m4l"].values())

    for cat in cats:
        if cat not in list(data_group.axes[1]):
            # Category missing from the data: nothing to draw, try the next.
            group_hist = get_empty_hist()
            continue

        group_hist = data_group["reducible", cat, b, ::sum, "cons", :]
        group_ss_hist = data_ss_group[::sum, cat, b, ::sum, "cons", :]

        # Scale the SS histogram so that its integral matches the OS one.
        os_norm = integrate(group_hist)
        ss_norm = integrate(group_ss_hist)
        group_ss_hist = group_ss_hist * os_norm / ss_norm

        group_hist.plot1d(histtype="step")
        group_ss_hist.plot1d(histtype="step")
        break
    break

In [None]:
# Inspection cell: open one ntuple file from a hard-coded absolute path.
# As the last expression of the cell, the opened file object is what gets
# displayed. The file is not explicitly closed.
uproot.open(
    "/eos/uscms/store/group/lpcsusyhiggs/ntuples/AZh/nAODv9/2017/BBAToZhToLLTauTauM1400/all_BBAToZhToLLTauTau_M1400_file001_part_1of3_Electrons.root"
)

In [None]:
# Inspection cell: list the keys stored in a shapes ROOT file located under
# the user's home directory (hard-coded path).
# NOTE(review): the path starts with "~"; confirm that it is expanded to the
# home directory in the environment where this runs.
uproot.open(
    '~/nobackup/combine/CMSSW_11_3_4/src/HiggsAnalysis/CombinedLimit/data/tutorials/longexercise/datacard_part3.shapes.root'
).keys()

In [None]:
import uproot
# Inspection cell: read every branch of the "Events" object of one ntuple file
# into arrays and display the result. `base` is a hard-coded absolute path.
base = "/eos/uscms/store/group/lpcsusyhiggs/ntuples/AZh/nAODv9/2017/BBAToZhToLLTauTauM1400/all_BBAToZhToLLTauTau_M1400_file001_part_1of3_Electrons.root"
uproot.open(base)["Events"].arrays()

In [None]:
import uproot

# Signal mass points for which template files are expected on disk.
mass_points = [
    '225','250','275','300','325','350','375','400','450',
    '500','600','700','750','800','900','1000',
    '1200','1400','1600','1800','2000'
]

# For every mass point and b-tag bin, print the sum of the 2016postVFP and
# 2016preVFP signal templates, histogram by histogram.
for mass in mass_points:
    for btag in ['0btag', 'btag']:
        path_post = f"root_for_combine/signal_{mass}_{btag}_2016postVFP.root"
        path_pre = f"root_for_combine/signal_{mass}_{btag}_2016preVFP.root"
        # Both inputs are closed when the block exits, even on error.
        with uproot.open(path_post) as dir_1, uproot.open(path_pre) as dir_2:
            # recursive=False restricts the listing to the top-level
            # directories; the default also returns the nested histogram
            # paths, which cannot be listed as directories themselves.
            for j in dir_1.keys(recursive=False):
                # The original referenced an undefined name `tree` here and
                # in the two lookups below; the directory key `j` is meant.
                print(j)
                for k in dir_1[j].keys():
                    print(k)
                    h1 = dir_1[j][k]
                    h2 = dir_2[j][k]
                    print(h1.to_hist() + h2.to_hist())
                