In [None]:
import os
import sys
from coffea import util
import uproot
import numpy as np
import mplhep as hep
import hist
from hist import Hist
from matplotlib import pyplot as plt
from cycler import cycler
from hist.axis import Variable
sys.path.append("../")

from azh_analysis.utils.histograms import integrate

def get_empty_hist():
    """Return a new, unfilled 1D histogram with a single variable-width "mass" axis.

    Bin edges run from 200 to 2400. In this notebook the result is written in place of any
    (category, systematic) shape that is not present in an input histogram.
    """
    edges = [
        200, 220, 240, 260, 280, 300, 320, 340, 360, 380, 400,
        450, 550, 700, 1000, 2400,
    ]
    mass_axis = Variable(edges, name="mass")
    return Hist(mass_axis)

year = "2018"
# Processed coffea outputs for `year`: opposite-sign (OS) MC, signal and data, plus a
# relaxed same-sign (SS) data selection. Unpacked in the same order they are loaded.
mc, signal, data, data_ss = (
    util.load(path)
    for path in (
        f'../output/MC_UL_{year}_all_OS.coffea',
        f'../output/signal_UL_{year}_all_OS.coffea',
        f'../output/data_UL_{year}_OS_ub.coffea',
        f'../output/data_UL_{year}_SS_ub_relaxed.coffea',
    )
)
print(mc["m4l"].keys())

# Map: year -> process group name -> list of MC dataset names (keys of mc["m4l"]) whose
# histograms are summed into that group. The combine-input cell prints an ERROR when the
# number of listed datasets found in the loaded MC output does not match.
# NOTE(review): several names differ between years (flagged inline); confirm each one
# matches the actual dataset name for that year rather than being a typo.
group_labels = { 
    "2018": {
        "TT": [
            "TTToSemiLeptonic",
            "TTToHadronic",
            "TTTo2L2Nu",
        ],
        "TTZ": [
            "ttZJets",
        ],    
        "TTW": [
            "TTWJetsToLNu",
        ],
        "ggZZ": [
            "GluGluToContinToZZTo2e2tau",
            "GluGluToContinToZZTo2mu2tau",
            "GluGluToContinToZZTo4e",
            "GluGluToContinToZZTo4mu",
            "GluGluToContinToZZTo4tau",
        ],
        "ZZ": [
            "ZZTo4L",
            "ZZTo2Q2Lmllmin4p0",
        ],
        "WZ": [
            "WZTo2Q2Lmllmin4p0",
            "WZTo3LNu",
        ],
        "VVV": [
            "WWW4F",
            "WWW4F_ext1",
            "WWZ4F",
            # NOTE(review): 2018 lists "WZZ", "WZZTuneCP5_ext1" and "ZZZTuneCP5_ext1";
            # the other years use "WZZ_ext1" and comment the TuneCP5 names out — confirm.
            "WZZ",
            "WZZTuneCP5_ext1",
            "ZZZ",
            "ZZZTuneCP5_ext1",
        ],
        "ggHtt": [
            "GluGluHToTauTauM125",
        ],
        "VBFHtt": [
            "VBFHToTauTauM125",
        ],
        "WHtt": [
            "WminusHToTauTauM125",
            "WplusHToTauTauM125",
        ],
        "ZHtt": [
            "ZHToTauTauM125_ext1",
        ],
        "TTHtt": [
            "ttHToTauTauM125",
        ],
        "ggHWW": [
            "GluGluHToWWTo2L2NuM-125",
        ],
        "VBFHWW": [
            "VBFHToWWTo2L2NuM-125",
        ], 
        "ggZHWW": [
            "GluGluZHHToWW",
        ],
        "ggHZZ": [
            "GluGluHToZZTo4LM125",
        ],
        "WHWW": [
            "HWminusJHToWW",
            # NOTE(review): the other years list "HWplusJHToWW" (no "To2L2Nu") — confirm.
            "HWplusJHToWWTo2L2Nu",
        ],
        "ZHWW": [
            # NOTE(review): 2017 and 2016postVFP list "HZJHToWWTo2L2Nu" — confirm.
            "HZJHToWW",
        ]
    },
    "2017": {
        "TT": [
            "TTToSemiLeptonic",
            "TTToHadronic",
            "TTTo2L2Nu",
        ],
        "TTZ": [
            "ttZJets",
        ],    
        "TTW": [
            "TTWJetsToLNu",
        ],
        "ggZZ": [
            "GluGluToContinToZZTo2e2tau",
            "GluGluToContinToZZTo2mu2tau",
            "GluGluToContinToZZTo4e",
            "GluGluToContinToZZTo4mu",
            "GluGluToContinToZZTo4tau",
        ],
        "ZZ": [
            "ZZTo4L",
            "ZZTo2Q2Lmllmin4p0",
        ],
        "WZ": [
            "WZTo2Q2Lmllmin4p0",
            "WZTo3LNu",
        ],
        "VVV": [
            "WWW4F",
            "WWW4F_ext1",
            "WWZ4F",
            #"WZZTuneCP5",
            "WZZ_ext1",
            "ZZZ",
            #"ZZZTuneCP5_ext1",
        ],
        "ggHtt": [
            "GluGluHToTauTauM125",
        ],
        "VBFHtt": [
            "VBFHToTauTauM125",
        ],
        "WHtt": [
            "WminusHToTauTauM125",
            "WplusHToTauTauM125",
        ],
        "ZHtt": [
            "ZHToTauTauM125_ext1",
        ],
        "TTHtt": [
            "ttHToTauTauM125",
        ],
        "ggHWW": [
            "GluGluHToWWTo2L2NuM-125",
        ],
        "VBFHWW": [
            "VBFHToWWTo2L2NuM-125",
        ], 
        "ggZHWW": [
            "GluGluZHHToWW",
        ],
        "ggHZZ": [
            "GluGluHToZZTo4LM125",
        ],
        "WHWW": [
            "HWminusJHToWW",
            "HWplusJHToWW",
        ],
        "ZHWW": [
            "HZJHToWWTo2L2Nu",
        ]
    },
    "2016postVFP": {
        "TT": [
            "TTToSemiLeptonic",
            "TTToHadronic",
            "TTTo2L2Nu",
        ],
        "TTZ": [
            "ttZJets",
        ],    
        "TTW": [
            "TTWJetsToLNu",
        ],
        "ggZZ": [
            "GluGluToContinToZZTo2e2tau",
            "GluGluToContinToZZTo2mu2tau",
            "GluGluToContinToZZTo4e",
            "GluGluToContinToZZTo4mu",
            "GluGluToContinToZZTo4tau",
        ],
        "ZZ": [
            "ZZTo4L",
            "ZZTo2Q2Lmllmin4p0",
        ],
        "WZ": [
            # NOTE(review): the other years list "WZTo2Q2Lmllmin4p0" — confirm.
            "WZTo2Q2L",
            "WZTo3LNu",
        ],
        "VVV": [
            "WWW4F",
            "WWW4F_ext1",
            "WWZ4F",
            #"WZZTuneCP5",
            "WZZ_ext1",
            "ZZZ",
            #"ZZZTuneCP5_ext1",
        ],
        "ggHtt": [
            "GluGluHToTauTauM125",
        ],
        "VBFHtt": [
            "VBFHToTauTauM125",
        ],
        "WHtt": [
            "WminusHToTauTauM125",
            "WplusHToTauTauM125",
        ],
        "ZHtt": [
            "ZHToTauTauM125_ext1",
        ],
        "TTHtt": [
            "ttHToTauTauM125",
        ],
        "ggHWW": [
            "GluGluHToWWTo2L2NuM-125",
        ],
        "VBFHWW": [
            "VBFHToWWTo2L2NuM-125",
        ], 
        "ggZHWW": [
            "GluGluZHHToWW",
        ],
        "ggHZZ": [
            "GluGluHToZZTo4LM125",
        ],
        "WHWW": [
            "HWminusJHToWW",
            "HWplusJHToWW",
        ],
        "ZHWW": [
            "HZJHToWWTo2L2Nu",
        ]
    },
    "2016preVFP": {
        "TT": [
            "TTToSemiLeptonic",
            "TTToHadronic",
            "TTTo2L2Nu",
        ],
        "TTZ": [
            # NOTE(review): suffix is spelled "_preFVP" (not "_preVFP") here and for TTW — confirm.
            "ttZJets_preFVP",
        ],    
        "TTW": [
            "TTWJetsToLNu_preFVP",
        ],
        "ggZZ": [
            "GluGluToContinToZZTo2e2tau",
            "GluGluToContinToZZTo2mu2tau",
            "GluGluToContinToZZTo4e",
            "GluGluToContinToZZTo4mu",
            "GluGluToContinToZZTo4tau",
        ],
        "ZZ": [
            "ZZTo4L",
            # NOTE(review): the other years list "ZZTo2Q2Lmllmin4p0" — confirm.
            "ZZTo2Q2L",
        ],
        "WZ": [
            "WZTo2Q2Lmllmin4p0",
            "WZTo3LNu",
        ],
        "VVV": [
            "WWW4F",
            "WWW4F_ext1",
            "WWZ4F",
            #"WZZTuneCP5",
            "WZZ_ext1",
            "ZZZ",
            #"ZZZTuneCP5_ext1",
        ],
        "ggHtt": [
            "GluGluHToTauTauM125",
        ],
        "VBFHtt": [
            "VBFHToTauTauM125",
        ],
        "WHtt": [
            "WminusHToTauTauM125",
            "WplusHToTauTauM125",
        ],
        "ZHtt": [
            # NOTE(review): the other years list "ZHToTauTauM125_ext1" — confirm.
            "ZHToTauTau_ext1",
        ],
        "TTHtt": [
            "ttHToTauTauM125",
        ],
        "ggHWW": [
            # NOTE(review): the other years list "GluGluHToWWTo2L2NuM-125" (with "-") — confirm.
            "GluGluHToWWTo2L2NuM125",
        ],
        "VBFHWW": [
            "VBFHToWWTo2L2NuM-125",
        ], 
        "ggZHWW": [
            "GluGluZHHToWW",
        ],
        "ggHZZ": [
            "GluGluHToZZTo4LM125",
        ],
        "WHWW": [
            "HWminusJHToWW",
            "HWplusJHToWW",
        ],
        "ZHWW": [
            "HZJHToWW",
        ]
    }
}

# Final-state categories written to every output file, in this order.
cats = [
    "eeet", "eemt", "eett", "eeem",
    "mmet", "mmmt", "mmtt", "mmem",
]

# Shapes written per category: the nominal one followed by "<source>_up"/"<source>_down"
# variations (the write-out turns e.g. "tauES_up" into a "tauESUp" key suffix).
systs = ["nom"] + [
    "l1prefire_up", "l1prefire_down",
    "pileup_up", "pileup_down",
    "tauES_down", "tauES_up",
    "efake_down", "efake_up",
    "mfake_down", "mfake_up",
    "eleES_down", "eleES_up",
    "eleSmear_down", "eleSmear_up",
    "muES_down", "muES_up",
    "unclMET_down", "unclMET_up",
]

In [None]:
def _shape_key(cat, process, syst):
    """Return the ROOT key for `process`'s shape under systematic `syst` in category `cat`.

    The nominal shape goes to "<cat>/<process>"; a shifted shape such as "tauES_up" goes to
    "<cat>/<process>_tauESUp" (trailing "_up"/"_down" becomes "Up"/"Down").
    """
    if "nom" in syst:
        return f"{cat}/{process}"
    # rsplit strips only the trailing shift; str.replace would also remove any earlier
    # occurrence of "_up"/"_down" inside the source name.
    base, shift = syst.rsplit("_", 1)
    return f"{cat}/{process}_{base}{shift.capitalize()}"


def _write_shapes(out_file, h, process, b):
    """Write `process`'s shapes for every category in `cats` and systematic in `systs`.

    out_file: writable uproot directory.
    h: summed histogram for the process, or None when none of its datasets were found.
       Axis usage as in this cell: axis 0 is summed, axis 1 = category, axis 2 = btag bin,
       axis 3 = systematic, axis 4 is sliced at "cons" (TODO confirm its meaning).
    b: btag bin index (0 -> "0btag", 1 -> "btag").
    Any (category, systematic) pair absent from `h` is written as an empty histogram so
    every expected key exists.
    """
    h_cats = set() if h is None else set(h.axes[1])
    h_systs = set() if h is None else set(h.axes[3])
    for cat in cats:
        for syst in systs:
            if "btag" in syst:
                continue
            if cat in h_cats and syst in h_systs:
                shape = h[::sum, cat, b, syst, "cons", :]
            else:
                shape = get_empty_hist()
            out_file[_shape_key(cat, process, syst)] = shape.to_numpy()


def _signal_for_mass(m, tag):
    """Return the first signal histogram whose key contains `tag` and whose text after
    the last "TauM" equals `m`, or None if there is none."""
    for k, v in signal["m4l"].items():
        if k.split("TauM")[-1] == m and tag in k:
            return v
    return None


os.makedirs("root_for_combine", exist_ok=True)

# None of these depend on the btag bin, so build them once rather than per iteration.
data_group = sum(v for v in data["m4l"].values())
data_ss_group = sum(v for v in data_ss["m4l"].values())
data_cats = set(data_group.axes[1])
data_ss_cats = set(data_ss_group.axes[1])

unique_masses = np.unique([k.split("TauM")[-1] for k in signal["m4l"].keys()])
ggA_masses = {m: _signal_for_mass(m, "Glu") for m in unique_masses}
bbA_masses = {m: _signal_for_mass(m, "BB") for m in unique_masses}

for b in [0, 1]:
    btag_label = "btag" if (b == 1) else "0btag"
    print(btag_label)

    #############################################
    # MC backgrounds + reducible -> MC ROOT file #
    #############################################
    with uproot.recreate(f"root_for_combine/MC_{btag_label}_{year}.root") as mc_file:
        for group, datasets in group_labels[year].items():
            found = [k for k in mc["m4l"] if k in datasets]
            if len(found) != len(datasets):
                print(f"ERROR: found {found} expected {datasets}")
            # sum() over an empty selection would be the int 0, which has no .axes
            mc_group = sum(mc["m4l"][k] for k in found) if found else None
            _write_shapes(mc_file, mc_group, group, b)

        # Reducible: the SS data shape, scaled so that its integral (per the imported
        # `integrate` helper) matches the OS "reducible" histogram.
        for cat in cats:
            if cat not in data_cats or cat not in data_ss_cats:
                print(f"WARNING: {cat} missing from OS or SS data; writing empty reducible")
                reducible = get_empty_hist()
            else:
                os_hist = data_group["reducible", cat, b, ::sum, "cons", :]
                ss_hist = data_ss_group["data", cat, b, ::sum, "cons", :]
                os_norm = integrate(os_hist)
                ss_norm = integrate(ss_hist)
                if ss_norm == 0:
                    print(f"WARNING: zero SS normalization in {cat} ({btag_label}); not rescaling")
                    reducible = ss_hist
                else:
                    reducible = ss_hist * os_norm / ss_norm
            mc_file[f"{cat}/reducible"] = reducible.to_numpy()

    #####################
    # observed data file #
    #####################
    with uproot.recreate(f"root_for_combine/data_{btag_label}_{year}.root") as data_file:
        for cat in cats:
            if cat in data_cats:
                observed = data_group["data", cat, b, ::sum, "cons", :]
            else:
                observed = get_empty_hist()
            data_file[f"{cat}/data"] = observed.to_numpy()

    ###############################
    # one signal file per mass point #
    ###############################
    for m in unique_masses:
        signal_path = f"root_for_combine/signal_{m}_{btag_label}_{year}.root"
        with uproot.recreate(signal_path) as signal_file:
            for process, h in (("ggA", ggA_masses[m]), ("bbA", bbA_masses[m])):
                if h is None:
                    print("WARNING: SKIPPING", process, m)
                    continue
                _write_shapes(signal_file, h, process, b)

In [None]:
def integrate_bin_width(h):
    """Return sum_i(width_i * value_i) over the first axis of histogram `h` (flow bins excluded).

    Deliberately NOT named ``integrate``: that name is imported from
    ``azh_analysis.utils.histograms`` in the first cell, and shadowing it here would make a
    later re-run of the combine-input cell silently use this normalization instead.
    """
    widths = np.asarray(h.axes[0].widths)
    return float(np.sum(widths * h.values()))


# Visual check of the reducible estimate for btag bin b = 0 and the first category present in
# the OS data: overlay the OS "reducible" shape with the SS shape scaled to the same integral.
b = 0
data_group = sum(v for v in data["m4l"].values())
data_ss_group = sum(v for v in data_ss["m4l"].values())

cats_in_data = [c for c in cats if c in list(data_group.axes[1])]
if not cats_in_data:
    print("WARNING: none of", cats, "found in data; nothing to plot")
else:
    cat = cats_in_data[0]
    group_hist = data_group["reducible", cat, b, ::sum, "cons", :]
    # NOTE(review): this sums axis 0 of the SS data, whereas the combine-input cell selects
    # "data" there — confirm which one the reducible estimate is meant to use.
    group_ss_hist = data_ss_group[::sum, cat, b, ::sum, "cons", :]
    os_norm = integrate_bin_width(group_hist)
    ss_norm = integrate_bin_width(group_ss_hist)
    if ss_norm != 0:
        group_ss_hist = group_ss_hist * os_norm / ss_norm

    fig, ax = plt.subplots()
    group_hist.plot1d(ax=ax, histtype="step", label="OS reducible")
    group_ss_hist.plot1d(ax=ax, histtype="step", label="SS data, scaled to OS")
    ax.set_title(f"{cat}, b = {b}")
    ax.legend()
    plt.show()

In [None]:
# Check that one 2017 bbA (M1400) input file on EOS opens, and list its contents.
# The file is opened in a `with` block so the handle is closed instead of leaked.
with uproot.open(
    "/eos/uscms/store/group/lpcsusyhiggs/ntuples/AZh/nAODv9/2017/BBAToZhToLLTauTauM1400/all_BBAToZhToLLTauTau_M1400_file001_part_1of3_Electrons.root"
) as sample_file:
    sample_keys = sample_file.keys()
sample_keys

In [None]:
# List the objects in the combine tutorial shapes file (reference for the expected layout).
# Opened in a `with` block so the handle is closed after reading the keys.
with uproot.open(
    '~/nobackup/combine/CMSSW_11_3_4/src/HiggsAnalysis/CombinedLimit/data/tutorials/longexercise/datacard_part3.shapes.root'
) as tutorial_file:
    tutorial_keys = tutorial_file.keys()
tutorial_keys

In [None]:
import uproot
base = "/eos/uscms/store/group/lpcsusyhiggs/ntuples/AZh/nAODv9/2017/BBAToZhToLLTauTauM1400/all_BBAToZhToLLTauTau_M1400_file001_part_1of3_Electrons.root"
# Reads every branch of the "Events" tree into memory; the file is then closed rather than leaked.
with uproot.open(base) as events_file:
    events = events_file["Events"].arrays()
events