In [None]:
import os
from coffea import util
import uproot
import numpy as np
import mplhep as hep
import hist
from hist import Hist
from matplotlib import pyplot as plt
from cycler import cycler
from hist.axis import Variable

def get_empty_hist():
    """Return an unfilled 1D histogram with a single variable-width "mass" axis.

    The bin edges run from 200 to 400 in steps of 20, followed by the wider
    edges 450, 550, 700, 1000 and 2400.
    """
    fine_edges = range(200, 401, 20)  # 200, 220, ..., 400
    coarse_edges = (450, 550, 700, 1000, 2400)
    mass_edges = [*fine_edges, *coarse_edges]
    return Hist(Variable(mass_edges, name="mass"))


# Load the saved outputs for simulated backgrounds, signal samples and data.
# Each object is indexed with the "m4l" key further down in this notebook.
# NOTE(review): coffea's util.load deserializes pickled Python objects, so
# only point these paths at files that come from a trusted source.
mc = util.load('../output/MC_UL_2018_all_OS.coffea')
signal = util.load('../output/signal_UL_2018_all_OS.coffea')
data = util.load('../output/data_UL_2018_OS_ub.coffea')

# Maps the name of each background group (used as the histogram name in the
# output files) to the list of dataset names whose histograms are summed
# to build that group.
group_labels = {
    "TT": ["TTToSemiLeptonic", "TTToHadronic", "TTTo2L2Nu"],
    "TTZ": ["ttZJets"],
    "TTW": ["TTWJetsToLNu"],
    "ggZZ": [
        "GluGluToContinToZZTo2e2tau",
        "GluGluToContinToZZTo2mu2tau",
        "GluGluToContinToZZTo4e",
        "GluGluToContinToZZTo4mu",
        "GluGluToContinToZZTo4tau",
    ],
    "ZZ": ["ZZTo4L", "ZZTo2Q2Lmllmin4p0"],
    "WZ": ["WZTo2Q2Lmllmin4p0", "WZTo3LNu"],
    "VVV": [
        "WWW4F", "WWW4F_ext1",
        "WWZ4F",
        "WZZ", "WZZTuneCP5_ext1",
        "ZZZ", "ZZZTuneCP5_ext1",
    ],
    "ggHtt": ["GluGluHToTauTauM125"],
    "VBFHtt": ["VBFHToTauTauM125"],
    "WHtt": ["WminusHToTauTauM125", "WplusHToTauTauM125"],
    "ZHtt": ["ZHToTauTauM125_ext1"],
    "TTHtt": ["ttHToTauTauM125"],
    "ggHWW": ["GluGluHToWWTo2L2NuM-125"],
    "VBFHWW": ["VBFHToWWTo2L2NuM-125"],
    "ggZHWW": ["GluGluZHHToWW"],
    "ggHZZ": ["GluGluHToZZTo4LM125"],
    "WHWW": ["HWminusJHToWW", "HWplusJHToWWTo2L2Nu"],
    "ZHWW": ["HZJHToWW"],
}

# Category names: every two-letter prefix combined with every two-letter
# suffix, prefix-major.
cats = [
    f"{prefix}{suffix}"
    for prefix in ("ee", "mm")
    for suffix in ("et", "mt", "tt", "em")
]

# Systematic variations: "nom" first, then one up/down pair per source. The
# order inside each pair is kept exactly as before (up first for the first two
# sources, down first for the remaining ones).
systs = ["nom"]
systs += [
    f"{source}_{shift}"
    for source in ("l1prefire", "pileup")
    for shift in ("up", "down")
]
systs += [
    f"{source}_{shift}"
    for source in ("tauES", "efake", "mfake", "eleES", "eleSmear", "muES", "unclMET")
    for shift in ("down", "up")
]

In [None]:
# Simulated backgrounds: write one ROOT file per b-tag category, holding
# "<cat>/<group>" for the nominal shape and "<cat>/<group>_<syst>Up|Down" for
# every shifted shape.
os.makedirs("root_for_combine", exist_ok=True)  # make sure the output directory exists
for b in [0, 1]:
    btag_label = "btag" if (b == 1) else "0btag"
    # The context manager closes the file once all shapes have been written,
    # even if one of the writes raises.
    with uproot.recreate(f"root_for_combine/MC_2018_{btag_label}.root") as file:
        for group, datasets in group_labels.items():
            print(f"{group}")
            # NOTE(review): if none of `datasets` is present in mc["m4l"],
            # sum() returns the int 0 and the `.axes` lookups below raise.
            mc_group = sum(
                v for k, v in mc["m4l"].items()
                if k in datasets
            )
            # Bin position of the requested b-tag value on axis 2, or None if
            # the axis has no such entry. The previous `sum(bidx) == 0` test
            # was also true for a match at position 0, which replaced every
            # b == 0 shape by an empty histogram.
            btag_matches = np.flatnonzero(np.array(mc_group.axes[2]) == b)
            btag_bin = int(btag_matches[0]) if btag_matches.size else None
            known_cats = list(mc_group.axes[1])
            known_systs = list(mc_group.axes[3])

            for cat in cats:
                for syst in systs:
                    if "btag" in syst:
                        continue
                    if (
                        btag_bin is None
                        or cat not in known_cats
                        or syst not in known_systs
                    ):
                        # Nothing available for this combination: write zeros.
                        group_hist = get_empty_hist()
                    else:
                        # A plain integer selects a single bin of the b-tag axis.
                        group_hist = mc_group[::sum, cat, btag_bin, syst, "cons", :]

                    if "nom" in syst:
                        key = f"{cat}/{group}"
                    else:
                        # e.g. "tauES_down" -> "tauESDown"; a separate name is
                        # used so the loop variable is not overwritten.
                        shift = syst.split("_")[-1]
                        syst_name = syst.replace(f"_{shift}", "") + shift.capitalize()
                        key = f"{cat}/{group}_{syst_name}"
                    file[key] = group_hist.to_numpy()

# Reducible background and observed data: one ROOT file per (group, b-tag
# category), with one "<cat>/<group>" shape per category.
os.makedirs("root_for_combine", exist_ok=True)  # make sure the output directory exists

# The sum over all data histograms does not depend on the loop variables, so
# it is computed once.
data_group = sum(data["m4l"].values())
data_cats = list(data_group.axes[1])

for b in [0, 1]:
    btag_label = "btag" if (b == 1) else "0btag"
    # Look the b-tag value up on the *data* histogram. The previous code read
    # the axis from `mc_group`, a leftover of the MC loop. Its
    # `sum(bidx) == 0` test was also true for a match at position 0, which
    # replaced every b == 0 shape by an empty histogram.
    btag_matches = np.flatnonzero(np.array(data_group.axes[2]) == b)
    btag_bin = int(btag_matches[0]) if btag_matches.size else None

    for group in ["reducible", "data"]:
        # The context manager closes the file once all shapes are written.
        with uproot.recreate(f"root_for_combine/{group}_2018_{btag_label}.root") as file:
            for cat in cats:
                if btag_bin is None or cat not in data_cats:
                    # Nothing available for this combination: write zeros.
                    group_hist = get_empty_hist()
                else:
                    # A plain integer selects a single bin of the b-tag axis;
                    # axis 3 is summed over.
                    group_hist = data_group[group, cat, btag_bin, ::sum, "cons", :]

                file[f"{cat}/{group}"] = group_hist.to_numpy()

# Signal samples: one directory and ROOT file per sample. Samples whose name
# contains "GluGlu" are read from b-tag value 0 and labelled "ggA"; samples
# whose name contains "BBA" are read from b-tag value 1 and labelled "bbA".
for k, v in signal["m4l"].items():
    known_cats = list(v.axes[1])
    known_systs = list(v.axes[3])
    for b in [0, 1]:
        if (b == 0) and "GluGlu" not in k:
            continue
        if (b > 0) and "BBA" not in k:
            continue
        label = "ggA" if b == 0 else "bbA"

        # Bin position of the requested b-tag value on axis 2, or None if the
        # axis has no such entry. The previous `sum(bidx) == 0` test was also
        # true for a match at position 0, which replaced every ggA shape by an
        # empty histogram.
        btag_matches = np.flatnonzero(np.array(v.axes[2]) == b)
        btag_bin = int(btag_matches[0]) if btag_matches.size else None

        # Keep the dataset name `k` intact: it is tested again on the next
        # pass of the `b` loop.
        process = label + f"_{k.split('TauM')[-1]}"
        print(process)
        path = f"root_for_combine/{process}"
        os.makedirs(path, exist_ok=True)

        # The context manager closes the file once all shapes are written.
        with uproot.recreate(f"{path}/{label}_2018.root") as file:
            for cat in cats:
                for syst in systs:
                    if "btag" in syst:
                        continue
                    if (
                        btag_bin is None
                        or cat not in known_cats
                        or syst not in known_systs
                    ):
                        # Nothing available for this combination: write zeros.
                        group_hist = get_empty_hist()
                    else:
                        # A plain integer selects a single bin of the b-tag axis.
                        group_hist = v[::sum, cat, btag_bin, syst, "cons", :]

                    if "nom" in syst:
                        key = f"{cat}/{process}"
                    else:
                        # e.g. "tauES_down" -> "tauESDown"
                        shift = syst.split("_")[-1]
                        syst_name = syst.replace(f"_{shift}", "") + shift.capitalize()
                        # NOTE(review): shifted shapes are keyed by `label`
                        # ("ggA"/"bbA") while the nominal shape is keyed by the
                        # sample-specific name; confirm this is what the
                        # datacards expect.
                        key = f"{cat}/{label}_{syst_name}"
                    file[key] = group_hist.to_numpy()

In [None]:
# Combined export: one ROOT file per b-tag category that holds the simulated
# backgrounds, the reducible estimate, the observed data and the signal
# samples selected for that category.

def _btag_bin(axis, value):
    """Return the bin position of `value` on `axis`, or None if it is absent.

    This replaces the `sum(np.argwhere(...)) == 0` emptiness test, which was
    also true for a match at position 0.
    """
    matches = np.flatnonzero(np.array(axis) == value)
    return int(matches[0]) if matches.size else None


def _shape_key(cat, process, syst):
    """Return the output key for `process` in `cat` under variation `syst`.

    The nominal shape is stored as "<cat>/<process>"; a shifted one as
    "<cat>/<process>_<source><Shift>", e.g. "tauES_down" -> "tauESDown".
    """
    if "nom" in syst:
        return f"{cat}/{process}"
    shift = syst.split("_")[-1]
    source = syst.replace(f"_{shift}", "")
    return f"{cat}/{process}_{source}{shift.capitalize()}"


os.makedirs("root_for_combine", exist_ok=True)  # make sure the output directory exists

# The sum over all data histograms does not depend on the loop variables.
data_group = sum(data["m4l"].values())
data_cats = list(data_group.axes[1])

for b in [0, 1]:
    btag_label = "btag" if (b == 1) else "0btag"
    # The context manager closes the file once all shapes have been written,
    # even if one of the writes raises.
    with uproot.recreate(f"root_for_combine/2018_{btag_label}.root") as file:

        # --- simulated backgrounds -------------------------------------
        for group, datasets in group_labels.items():
            print(f"{group}")
            # NOTE(review): if none of `datasets` is present in mc["m4l"],
            # sum() returns the int 0 and the `.axes` lookups below raise.
            mc_group = sum(
                v for k, v in mc["m4l"].items()
                if k in datasets
            )
            btag_bin = _btag_bin(mc_group.axes[2], b)
            known_cats = list(mc_group.axes[1])
            known_systs = list(mc_group.axes[3])
            for cat in cats:
                for syst in systs:
                    if "btag" in syst:
                        continue
                    if (
                        btag_bin is None
                        or cat not in known_cats
                        or syst not in known_systs
                    ):
                        # Nothing available for this combination: write zeros.
                        group_hist = get_empty_hist()
                    else:
                        # A plain integer selects a single b-tag bin.
                        group_hist = mc_group[::sum, cat, btag_bin, syst, "cons", :]
                    file[_shape_key(cat, group, syst)] = group_hist.to_numpy()

        # --- reducible background and observed data --------------------
        # The b-tag axis is read from the data histogram itself; the previous
        # code read it from `mc_group`, a leftover of the loop above.
        btag_bin = _btag_bin(data_group.axes[2], b)
        for group in ["reducible", "data"]:
            for cat in cats:
                if btag_bin is None or cat not in data_cats:
                    group_hist = get_empty_hist()
                else:
                    # Axis 3 is summed over.
                    group_hist = data_group[group, cat, btag_bin, ::sum, "cons", :]

                if "data" in group:
                    file[f"{cat}/{group}_obs"] = group_hist.to_numpy()
                else:
                    file[f"{cat}/reducible"] = group_hist.to_numpy()

        # --- signal samples ----------------------------------------------
        # "GluGlu" samples go into the b == 0 file, "BBA" samples into the
        # b == 1 file.
        for k, v in signal["m4l"].items():
            if (b == 0) and "GluGlu" not in k:
                continue
            if (b > 0) and "BBA" not in k:
                continue
            label = "ggA" if b == 0 else "bbA"

            btag_bin = _btag_bin(v.axes[2], b)
            known_cats = list(v.axes[1])
            known_systs = list(v.axes[3])

            # A separate name is used so the loop variable `k` keeps the
            # dataset name.
            process = label + f"{k.split('TauM')[-1]}"
            print(b, process)
            # NOTE(review): this per-sample directory is created but nothing
            # is written into it here; kept in case a later step relies on it.
            path = f"root_for_combine/{process}"
            os.makedirs(path, exist_ok=True)
            for cat in cats:
                for syst in systs:
                    if "btag" in syst:
                        continue
                    if (
                        btag_bin is None
                        or cat not in known_cats
                        or syst not in known_systs
                    ):
                        group_hist = get_empty_hist()
                    else:
                        group_hist = v[::sum, cat, btag_bin, syst, "cons", :]
                    file[_shape_key(cat, process, syst)] = group_hist.to_numpy()

In [None]:
# List what was written for the "eeem" category of the b-tag file. The file
# is closed as soon as the key list has been read.
with uproot.open(
    "root_for_combine/2018_btag.root"
) as combined_btag:
    eeem_keys = combined_btag["eeem"].keys()
eeem_keys

In [None]:
# List the contents of a shapes file from the combine tutorial, for comparison
# with the layout written above. The file is closed as soon as the key list
# has been read.
# NOTE(review): this path is specific to one user's working area.
with uproot.open(
    '~/nobackup/combine/CMSSW_11_3_4/src/HiggsAnalysis/CombinedLimit/data/tutorials/longexercise/datacard_part3.shapes.root'
) as tutorial_shapes:
    tutorial_keys = tutorial_shapes.keys()
tutorial_keys