In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
from collections import defaultdict
from coffea import util
import uproot
import numpy as np
#import mplhep as hep
import hist
from hist import Hist
from matplotlib import pyplot as plt
from cycler import cycler
from hist.axis import Variable
sys.path.append("../")

from azh_analysis.utils.histograms import integrate, norm_to
#from azh_analysis.utils.plotting import plot_closure

In [None]:
def get_empty_hist():
    """Build an unfilled 1D histogram on the fixed ``mass`` binning.

    Used as a placeholder wherever a requested category/systematic slice is
    missing, so that every expected template key is still written out.
    NOTE(review): these edges are presumably meant to match the binning of the
    ``m4l_binopt`` histograms -- confirm against the processor.

    Edges: 20-unit steps from 200 up to 380, then 400, 450, 550, 700, 1000, 2400.
    """
    fine_edges = list(range(200, 400, 20))
    coarse_edges = [400, 450, 550, 700, 1000, 2400]
    return Hist(Variable(fine_edges + coarse_edges, name="mass"))

# Analysis year and directory holding the coffea outputs.
year = "2016"
indir = "../output_unblind"

# Each pair is (preVFP, postVFP); loaded in the same order as before.
# MC backgrounds and signal (opposite-sign selection).
mc_pre, mc_post = (
    util.load(f"{indir}/MC_UL_{year}{era}_all_OS.coffea") for era in ("preVFP", "postVFP")
)
signal_pre, signal_post = (
    util.load(f"{indir}/signal_UL_{year}{era}_all_OS.coffea") for era in ("preVFP", "postVFP")
)
# Unblinded data: opposite-sign, same-sign (not relaxed), same-sign (relaxed).
data_pre, data_post = (
    util.load(f"{indir}/data_UL_{year}{era}_OS_ub.coffea") for era in ("preVFP", "postVFP")
)
data_ss_pre, data_ss_post = (
    util.load(f"{indir}/data_UL_{year}{era}_SS_ub_not-relaxed.coffea")
    for era in ("preVFP", "postVFP")
)
data_ssr_pre, data_ssr_post = (
    util.load(f"{indir}/data_UL_{year}{era}_SS_ub_relaxed.coffea")
    for era in ("preVFP", "postVFP")
)

# Per-era mapping: process-group name -> list of MC dataset names whose histograms
# are summed into that group when the templates are written.
# The template-writing loop iterates over the postVFP groups and looks up the same
# group name in preVFP, so both eras must define the same set of group keys.
group_labels = { 
    "2016postVFP": {
        "TT": [
            "TTToSemiLeptonic",
            "TTToHadronic",
            "TTTo2L2Nu",
        ],
        "TTZ": [
            "ttZJets",
        ],    
        "TTW": [
            "TTWJetsToLNu",
        ],
        "ggZZ": [
            "GluGluToContinToZZTo2e2tau",
            "GluGluToContinToZZTo2mu2tau",
            "GluGluToContinToZZTo4e",
            "GluGluToContinToZZTo4mu",
            "GluGluToContinToZZTo4tau",
        ],
        # NOTE(review): postVFP uses ZZTo2Q2Lmllmin4p0 here and WZTo2Q2L below,
        # while preVFP uses ZZTo2Q2L / WZTo2Q2Lmllmin4p0 -- confirm this is the real
        # per-era dataset naming and not a swap.
        "ZZ": [
            "ZZTo4L",
            "ZZTo2Q2Lmllmin4p0",
        ],
        "WZ": [
            "WZTo2Q2L",
            "WZTo3LNu",
        ],
        "VVV": [
            "WWW4F",
            "WWW4F_ext1",
            "WWZ4F",
            "WZZ_ext1",
            "ZZZ",
        ],
        "ggHtt": [
            "GluGluHToTauTauM125",
        ],
        "VBFHtt": [
            "VBFHToTauTauM125",
        ],
        "WHtt": [
            "WminusHToTauTauM125",
            "WplusHToTauTauM125",
        ],
        # NOTE(review): preVFP spells this "ZHToTauTau_ext1" (no "M125") -- confirm.
        "ZHtt": [
            "ZHToTauTauM125_ext1",
        ],
        "TTHtt": [
            "ttHToTauTauM125",
        ],
        # NOTE(review): preVFP spells this "GluGluHToWWTo2L2NuM125" (no dash) -- confirm.
        "ggHWW": [
            "GluGluHToWWTo2L2NuM-125",
        ],
        "VBFHWW": [
            "VBFHToWWTo2L2NuM-125",
        ], 
        "ggZHWW": [
            "GluGluZHHToWW",
        ],
        "ggHZZ": [
            "GluGluHToZZTo4LM125",
        ],
        "WHWW": [
            "HWminusJHToWW",
            "HWplusJHToWW",
        ],
        # NOTE(review): only postVFP lists the "_ext1" sample for this group.
        "ZHWW": [
            "HZJHToWW",
            "HZJHToWW_ext1"
        ]
    },
    "2016preVFP": {
        "TT": [
            "TTToSemiLeptonic",
            "TTToHadronic",
            "TTTo2L2Nu",
        ],
        "TTZ": [
            "ttZJets",
        ],    
        "TTW": [
            "TTWJetsToLNu",
        ],
        "ggZZ": [
            "GluGluToContinToZZTo2e2tau",
            "GluGluToContinToZZTo2mu2tau",
            "GluGluToContinToZZTo4e",
            "GluGluToContinToZZTo4mu",
            "GluGluToContinToZZTo4tau",
        ],
        # NOTE(review): see the postVFP ZZ/WZ note -- the 2Q2L names differ between eras.
        "ZZ": [
            "ZZTo4L",
            "ZZTo2Q2L",
        ],
        "WZ": [
            "WZTo2Q2Lmllmin4p0",
            "WZTo3LNu",
        ],
        "VVV": [
            "WWW4F",
            "WWW4F_ext1",
            "WWZ4F",
            #"WZZTuneCP5",
            "WZZ_ext1",
            "ZZZ",
            #"ZZZTuneCP5_ext1",
        ],
        "ggHtt": [
            "GluGluHToTauTauM125",
        ],
        "VBFHtt": [
            "VBFHToTauTauM125",
        ],
        "WHtt": [
            "WminusHToTauTauM125",
            "WplusHToTauTauM125",
        ],
        "ZHtt": [
            "ZHToTauTau_ext1",
        ],
        "TTHtt": [
            "ttHToTauTauM125",
        ],
        "ggHWW": [
            "GluGluHToWWTo2L2NuM125",
        ],
        "VBFHWW": [
            "VBFHToWWTo2L2NuM-125",
        ], 
        "ggZHWW": [
            "GluGluZHHToWW",
        ],
        "ggHZZ": [
            "GluGluHToZZTo4LM125",
        ],
        "WHWW": [
            "HWminusJHToWW",
            "HWplusJHToWW",
        ],
        "ZHWW": [
            "HZJHToWW",
        ]
    }
}

# Final-state categories: an "ee"/"mm" prefix combined with an "et"/"mt"/"tt"/"em"
# suffix, in the order eeet, eemt, eett, eeem, mmet, mmmt, mmtt, mmem.
cats = [prefix + suffix for prefix in ("ee", "mm") for suffix in ("et", "mt", "tt", "em")]

# Systematic variations to export: "nom" first, then "<source>_<shift>" pairs.
# The up/down order differs between sources and is preserved exactly, since it
# fixes the order in which templates are written.
systs = ["nom"] + [
    f"{source}_{shift}"
    for source, shifts in (
        ("l1prefire", ("up", "down")),
        ("pileup", ("up", "down")),
        ("tauES", ("down", "up")),
        ("efake", ("down", "up")),
        ("mfake", ("down", "up")),
        ("eleES", ("down", "up")),
        ("eleSmear", ("down", "up")),
        ("muES", ("down", "up")),
        ("unclMET", ("down", "up")),
        ("tauID_0", ("down", "up")),
        ("tauID_1", ("down", "up")),
        ("tauID_10", ("down", "up")),
        ("tauID_11", ("down", "up")),
        ("JES", ("up", "down")),
        ("JER", ("up", "down")),
    )
    for shift in shifts
]

In [None]:
var = "m4l_binopt"
outdir = "for_alexis/towards_unblinding/datacard_templates"

# Reducible closure nuisances: bin i of the coarse ``m4l_reg`` histogram -> indices j
# of the ``var`` bins scaled by that bin's closure shift. ``m4l_reg`` bins not listed
# here (bin 0 and anything above 4) get no closure nuisance.
# NOTE(review): hard-coded for a 15-bin ``var`` layout; keep in sync with the binning.
REDUCIBLE_CLOSURE_BINS = {1: range(10), 2: (10, 11), 3: (12, 13), 4: (14,)}
# Substring of a signal dataset name that selects each signal process.
SIGNAL_TAGS = {"ggA": "Glu", "bbA": "BB"}


def _template_key(cat, sample, syst):
    """Return the ROOT key for one template: ``cat/sample`` or ``cat/sample_<Source><Shift>``.

    "nom" -> "cat/sample"; "tauES_up" -> "cat/sample_tauESUp";
    "tauID_0_down" -> "cat/sample_tauID0Down" (the tauID index is glued onto the
    source name so that the shift is the only underscore-separated suffix).
    """
    if syst == "nom":
        return f"{cat}/{sample}"
    if "tauID" in syst:
        syst = syst.replace("_", "", 1)
    source, shift = syst.rsplit("_", 1)
    return f"{cat}/{sample}_{source}{shift.capitalize()}"


def _write_templates(out_file, h, sample, b):
    """Write every (category, systematic) template of ``h`` for btag bin ``b``.

    ``h`` is sliced as [<summed axis>, category, btag, syst, "cons", observable].
    Missing categories/systematics, or ``h is None``, produce an empty histogram so
    every expected key exists in the output file.
    """
    for cat in cats:
        for syst in systs:
            if "btag" in syst:  # guard kept from the original; `systs` has no btag entries
                continue
            if h is None or cat not in list(h.axes[1]) or syst not in list(h.axes[3]):
                template = get_empty_hist()
            else:
                template = h[::sum, cat, b, syst, "cons", :]
            out_file[_template_key(cat, sample, syst)] = template


def _sum_datasets(hists_by_dataset, datasets, era):
    """Sum the histograms of ``datasets``; print an ERROR if any is missing, None if all are."""
    found = [k for k in hists_by_dataset if k in datasets]
    if len(found) != len(datasets):
        print(f"ERROR: ({era}) found {found} expected {datasets}")
    return sum(hists_by_dataset[k] for k in found) if found else None


def _add_optional(first, second):
    """Add two histograms, treating None as absent (None + None -> None)."""
    if first is None:
        return second
    if second is None:
        return first
    return first + second


def _sum_all(coffea_out, key):
    """Sum the ``key`` histogram over every dataset in one coffea output."""
    return sum(coffea_out[key].values())


def _first_signal(hists_by_dataset, mass, tag):
    """Return the first histogram whose name contains ``tag`` and ends with ``TauM<mass>``, else None."""
    for name, h in hists_by_dataset.items():
        if name.split("TauM")[-1] == mass and tag in name:
            return h
    return None


def _combine_signal_eras(mass, sample, tag):
    """Sum preVFP + postVFP for one signal process and mass; warn if only one era exists."""
    pre = _first_signal(signal_pre[var], mass, tag)
    post = _first_signal(signal_post[var], mass, tag)
    if (pre is None) != (post is None):
        missing = "preVFP" if pre is None else "postVFP"
        print(f"WARNING: {sample} m={mass} missing in {missing}; using one era only")
    return _add_optional(pre, post)


def _write_reducible(out_file, cat, b, data, data_reg, data_ss_reg, data_ssr, data_ssr_reg):
    """Write ``cat/reducible`` and its ``reducible_closure<i>Up/Down`` shapes.

    The nominal shape is the relaxed same-sign data ("data" group of ``data_ssr``)
    normalised with ``norm_to`` against the opposite-sign "reducible" histogram. The
    closure shifts use the same normalisation on the coarse ``m4l_reg`` histograms.
    If ``cat`` is absent from any input, empty templates are written instead.
    """
    inputs = (data, data_reg, data_ss_reg, data_ssr, data_ssr_reg)
    if any(cat not in list(h.axes[1]) for h in inputs):
        print(f"WARNING: {cat} missing from data outputs; writing empty reducible templates")
        for i in REDUCIBLE_CLOSURE_BINS:
            out_file[f"{cat}/reducible_closure{i}Up"] = get_empty_hist()
            out_file[f"{cat}/reducible_closure{i}Down"] = get_empty_hist()
        out_file[f"{cat}/reducible"] = get_empty_hist()
        return

    os_hist = data["reducible", cat, b, ::sum, "cons", :]
    ssr_hist = data_ssr["data", cat, b, ::sum, "cons", :]
    os_reg = data_reg["reducible", cat, b, ::sum, "cons", :]
    ss_reg = data_ss_reg["reducible", cat, b, ::sum, "cons", :]
    ssr_reg = data_ssr_reg["data", cat, b, ::sum, "cons", :]

    ssr_hist = norm_to(os_hist, ssr_hist, simple=False)
    ssr_reg = norm_to(os_reg, ssr_reg, simple=False)
    ss_reg = norm_to(os_reg, ss_reg, simple=False)

    n_bins = len(ssr_hist.axes[0])
    for i, edges in enumerate(ssr_reg.axes[0]):
        if i not in REDUCIBLE_CLOSURE_BINS:
            continue
        ssr_val = ssr_reg[i].value
        # NOTE(review): relative shift = SS (not-relaxed) std / SS-relaxed value; confirm intent.
        sigma = np.sqrt(ss_reg[i].variance) / ssr_val if ssr_val > 0 else 0
        print(i, edges, sigma)
        shifted_up = ssr_hist.copy()
        shifted_down = ssr_hist.copy()
        for j in REDUCIBLE_CLOSURE_BINS[i]:
            if j >= n_bins:
                continue
            # Scale the value only; the variance is carried over unchanged.
            shifted_up[j] = (shifted_up[j].value * (1 + sigma), shifted_up[j].variance)
            shifted_down[j] = (shifted_down[j].value / (1 + sigma), shifted_down[j].variance)
        out_file[f"{cat}/reducible_closure{i}Up"] = shifted_up
        out_file[f"{cat}/reducible_closure{i}Down"] = shifted_down

    out_file[f"{cat}/reducible"] = ssr_hist


# --- Inputs that do not depend on the btag bin: build them once. ---------------
mc_groups = {
    group: _add_optional(
        _sum_datasets(mc_pre[var], group_labels[f"{year}preVFP"][group], "preVFP"),
        _sum_datasets(mc_post[var], datasets_post, "postVFP"),
    )
    for group, datasets_post in group_labels[f"{year}postVFP"].items()
}

data_group = _sum_all(data_pre, var) + _sum_all(data_post, var)
data_reg_group = _sum_all(data_pre, "m4l_reg") + _sum_all(data_post, "m4l_reg")
data_ss_reg_group = _sum_all(data_ss_pre, "m4l_reg") + _sum_all(data_ss_post, "m4l_reg")
data_ssr_group = _sum_all(data_ssr_pre, var) + _sum_all(data_ssr_post, var)
data_ssr_reg_group = _sum_all(data_ssr_pre, "m4l_reg") + _sum_all(data_ssr_post, "m4l_reg")

# Masses are collected from BOTH eras so that a mass present in only one era is not dropped.
signal_masses = sorted(
    {k.split("TauM")[-1] for k in (*signal_pre[var].keys(), *signal_post[var].keys()) if "TauM" in k}
)
signal_hists = {
    (m, sample): _combine_signal_eras(m, sample, tag)
    for m in signal_masses
    for sample, tag in SIGNAL_TAGS.items()
}

# A commented-out prototype of per-bin MC statistical shapes used to live here;
# recover it from version control if needed.

# --- Write one set of files per btag bin. `with` closes each output file. ------
for b in [0, 1]:
    btag_label = "btag" if b == 1 else "0btag"
    print(btag_label)

    # MC backgrounds plus the data-driven reducible estimate share one file.
    with uproot.recreate(f"{outdir}/MC_{btag_label}_{year}.root") as mc_file:
        for group, mc_group in mc_groups.items():
            _write_templates(mc_file, mc_group, group, b)
        for cat in cats:
            _write_reducible(
                mc_file, cat, b,
                data_group, data_reg_group, data_ss_reg_group,
                data_ssr_group, data_ssr_reg_group,
            )

    with uproot.recreate(f"{outdir}/data_{btag_label}_{year}.root") as data_file:
        for cat in cats:
            if cat not in list(data_group.axes[1]):
                data_hist = get_empty_hist()
            else:
                data_hist = data_group["data", cat, b, ::sum, "cons", :]
            data_file[f"{cat}/data"] = data_hist

    for m in signal_masses:
        with uproot.recreate(f"{outdir}/signal_{m}_{btag_label}_{year}.root") as signal_file:
            for sample in SIGNAL_TAGS:
                signal_hist = signal_hists[m, sample]
                if signal_hist is None:
                    print("WARNING: SKIPPING", sample, m)
                    continue
                _write_templates(signal_file, signal_hist, sample, b)

In [None]:
def integrate_width_weighted(h):
    """Return sum(bin_width * bin_value) over the first axis of ``h``.

    Named distinctly so that it does not shadow the ``integrate`` imported from
    ``azh_analysis.utils.histograms`` at the top of the notebook.
    """
    edges = np.array(h.axes[0])  # one (lower, upper) row per bin
    widths = edges[:, 1] - edges[:, 0]
    return np.sum(widths * h.values())


# Quick visual check (0-btag, first category present in data): the OS "reducible"
# m4l shape vs. the same-sign data shape scaled to the same width-weighted integral.
# Built from the loaded preVFP + postVFP outputs. The previous `data`/`data_ss`
# names are never defined in this notebook.
# NOTE(review): assumes these outputs still contain an "m4l" histogram.
check_b = 0
os_data_m4l = sum(data_pre["m4l"].values()) + sum(data_post["m4l"].values())
ss_data_m4l = sum(data_ss_pre["m4l"].values()) + sum(data_ss_post["m4l"].values())

check_cat = next((cat for cat in cats if cat in list(os_data_m4l.axes[1])), None)
if check_cat is None:
    print("No category present in the OS data output; nothing to plot.")
else:
    os_hist = os_data_m4l["reducible", check_cat, check_b, ::sum, "cons", :]
    # NOTE(review): the SS side sums over ALL groups (::sum); the template cell uses
    # only the "reducible" group -- confirm which is intended.
    ss_hist = ss_data_m4l[::sum, check_cat, check_b, ::sum, "cons", :]
    ss_norm = integrate_width_weighted(ss_hist)
    if ss_norm != 0:
        ss_hist = ss_hist * integrate_width_weighted(os_hist) / ss_norm
    else:
        print(f"SS integral is zero in {check_cat}; plotting it unscaled.")
    fig, ax = plt.subplots()
    os_hist.plot1d(ax=ax, histtype="step", label="OS reducible")
    ss_hist.plot1d(ax=ax, histtype="step", label="SS data (scaled)")
    ax.set_title(f"{check_cat}, {'btag' if check_b == 1 else '0btag'}")
    ax.legend();

In [None]:
# Open one 2017 bbA signal ntuple on EOS (mass point 1400, per the file name) and
# display the resulting file object.
uproot.open(
    "/eos/uscms/store/group/lpcsusyhiggs/ntuples/AZh/nAODv9/2017/"
    "BBAToZhToLLTauTauM1400/"
    "all_BBAToZhToLLTauTau_M1400_file001_part_1of3_Electrons.root"
)

In [None]:
# List the keys of the combine long-exercise shapes file, presumably as a reference
# for the template layout written above. `with` closes the file after reading keys.
with uproot.open(
    '~/nobackup/combine/CMSSW_11_3_4/src/HiggsAnalysis/CombinedLimit/data/tutorials/longexercise/datacard_part3.shapes.root'
) as tutorial_shapes:
    tutorial_keys = tutorial_shapes.keys()
tutorial_keys

In [None]:
# Read every branch of the "Events" tree from one 2017 bbA signal ntuple into memory.
# `uproot` is already imported at the top of the notebook; `with` closes the file
# once the arrays have been read.
base = "/eos/uscms/store/group/lpcsusyhiggs/ntuples/AZh/nAODv9/2017/BBAToZhToLLTauTauM1400/all_BBAToZhToLLTauTau_M1400_file001_part_1of3_Electrons.root"
with uproot.open(base) as signal_ntuple:
    events = signal_ntuple["Events"].arrays()
events

In [None]:
# Cross-check the per-era signal templates: for each mass point, btag bin, category
# directory and template, print the preVFP + postVFP sum.
# `uproot` is already imported at the top of the notebook.
mass_points = [
    '225','250','275','300','325','350','375','400','450',
    '500','600','700','750','800','900','1000',
    '1200','1400','1600','1800','2000'
]
for mass in mass_points:
    for btag in ['0btag', 'btag']:
        post_path = f"root_for_combine/signal_{mass}_{btag}_2016postVFP.root"
        pre_path = f"root_for_combine/signal_{mass}_{btag}_2016preVFP.root"
        with uproot.open(post_path) as post_file, uproot.open(pre_path) as pre_file:
            # keys() is recursive by default and would also yield "cat/template"
            # histogram entries (which have no .keys()), so list only top-level dirs.
            for cat in post_file.keys(recursive=False, cycle=False):
                print(cat)
                post_dir, pre_dir = post_file[cat], pre_file[cat]
                for name in post_dir.keys(cycle=False):
                    print(name)
                    if name not in pre_dir:
                        print(f"WARNING: {cat}/{name} missing from {pre_path}")
                        continue
                    print(post_dir[name].to_hist() + pre_dir[name].to_hist())
                