In [1]:
import os
import sys
import dask
import uproot4
import numpy as np
import pandas as pd
import dask.dataframe as dd
import dask_histogram as dh
import multiprocessing as mp
from grid.samples_ufo_newvars import samples_DXAOD
from hist.axis import Variable, Regular
from termcolor import colored, cprint
from itertools import product
import hist
from hist import Hist
import pickle as pkl
from dask.delayed import delayed
import dask.bag as db
from functools import reduce

# NOTE(review): a `global` statement at module scope has no effect. FORMULAS is
# assigned by the driver cell (FORMULAS = load_formulas()) and read by fill_histo.
global FORMULAS

def add_hists(a, b):
    '''
    Combine two histograms, tolerating missing (falsy) operands
    Inputs
        - a: First histogram, or a falsy placeholder
        - b: Second histogram, or a falsy placeholder
    Return:
        a + b when both are truthy, the single truthy operand when only one
        is, and False when neither is
    '''
    if a:
        return a + b if b else a
    return b if b else False


def find_file(kws, folder):
    '''
    Walk a folder tree and collect the files whose name contains every keyword
    Inputs
        - kws -> list[str]: Substrings that must all appear in the file name
        - folder -> str: Root folder of the search (sub-folders are included)
    Return:
        list[str]: Bare file names (no directory component) of the matches
    '''
    return [
        fname
        for _root, _subdirs, fnames in os.walk(folder)
        for fname in fnames
        if all(kw in fname for kw in kws)
    ]

def get_formula(formula):
    '''
    Extract the expression part of a "Key: expression" config line
    Inputs
        - formula -> str: Raw config line
    Return:
        str: Everything after the first ':' of the stripped line (empty if
             there is no ':'), with every 'TMath::Power(x,' rewritten to
             '(x**' so the text reads as a Python power expression
    '''
    _key, _sep, expression = formula.strip().partition(':')
    return expression.replace('TMath::Power(x,', '(x**')

def load_formulas(path='/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/BoostedJetTaggers/', config=None):
    '''
    Build the tagger cut functions from the tagger config files
    Inputs:
        - path -> str: Folder containing the tagger config files
        - config -> dict[str, str] | None: Map of tagger name to config file
          path relative to `path`. Defaults to the six built-in working points
    Return:
        formulas -> dict: For names containing 'w', a dict with keys
          'm_lo', 'm_hi', 'd2', 'ntrk', each a callable f(x); otherwise a single
          callable f(x) built from the first 'ScoreCut' line
    Raises:
        ValueError: If a config file lacks one of the cut lines its tagger needs
          (previously this silently reused the cut of the preceding tagger or
          failed with an unrelated NameError)
    '''
    if config is None:
        config = {
            'dnn_cont_50': 'JSSWTopTaggerDNN/Rel21/UFO_tests/DNNTagger_AntiKt10UFOSD_TopContained50_TauRatios_Oct30.dat',
            'dnn_cont_80': 'JSSWTopTaggerDNN/Rel21/UFO_tests/DNNTagger_AntiKt10UFOSD_TopContained80_TauRatios_Oct30.dat',
            'dnn_incl_50': 'JSSWTopTaggerDNN/Rel21/UFO_tests/DNNTagger_AntiKt10UFOSD_TopInclusive50_TauRatios_Oct30.dat',
            'dnn_incl_80': 'JSSWTopTaggerDNN/Rel21/UFO_tests/DNNTagger_AntiKt10UFOSD_TopInclusive80_TauRatios_Oct30.dat',
            'w_50': 'SmoothedWZTaggers/Rel21/UFO_tests/SmoothedContainedWTagger_AntiKt10VanillaSD_FixedSignalEfficiency50_July21.dat',
            'w_80': 'SmoothedWZTaggers/Rel21/UFO_tests/SmoothedContainedWTagger_AntiKt10VanillaSD_FixedSignalEfficiency80_2021May10.dat',
        }

    def compile_cut(expression):
        # NOTE(review): this evaluates text read from the config files as
        # Python code. Only point `path`/`config` at trusted files.
        return eval('lambda x: ' + expression)

    formulas = {}
    for name, rel_path in config.items():
        is_dnn = 'dnn' in name
        is_w = 'w' in name
        # Fresh per file so a missing line can never inherit another tagger's cut
        cuts = {}
        with open(os.path.join(path, rel_path)) as cfg:
            for line in cfg:
                if is_dnn and line.startswith('ScoreCut'):
                    # Only the first ScoreCut line is used
                    cuts['score'] = get_formula(line)
                    break
                if not is_w:
                    continue
                if line.startswith('MassCut'):
                    cuts['m_lo' if 'Low' in line else 'm_hi'] = get_formula(line)
                elif line.startswith('D2Cut'):
                    cuts['d2'] = get_formula(line)
                elif line.startswith('NtrkCut'):
                    cuts['ntrk'] = get_formula(line)

        required = ('m_lo', 'm_hi', 'd2', 'ntrk') if is_w else ('score',)
        missing = [key for key in required if key not in cuts]
        if missing:
            raise ValueError('Config for '+name+' ('+rel_path+') is missing cut(s): '+', '.join(missing))

        if is_w:
            formulas[name] = {key: compile_cut(cuts[key]) for key in required}
        else:
            formulas[name] = compile_cut(cuts['score'])
    return formulas


def get_weight(sample, dsid, xsec_file='sample_xsections_newvars.txt', nevt_file='data/weights_dijet.root'):
    '''
    Function which derives the sample normalisation
    Special consideration for selecting where the denominator comes from.
    Inputs:
        sample -> str: Name of sample (used to identify the MC campaign)
        dsid   -> int: DSID of sample (used to check if JZX slice dijet sample)
        xsec_file -> str: Space-delimited file with columns [dsid filt xsec shower]
        nevt_file -> str: Path of root file whose 'sumWeights' tree holds the
                          total events (weighted or not) for the denominator
    Return:
        float: (Lumi)*(Cross-section)*(Filter efficiency)/(Total weighted events)
    Raises:
        TypeError: If a lookup does not select exactly one row
    '''
    xsec_table = pd.read_csv(xsec_file, names=['dsid', 'filt', 'xsec', 'shower'], header=None, delimiter=' ')

    # uproot4.iterate yields the tree in chunks; keep all of them rather than
    # only the last one so large trees are not silently truncated.
    nevt_table = pd.concat(list(uproot4.iterate(nevt_file+':sumWeights', library='pd')), ignore_index=True)

    # DSIDs for which the weighted event count is used as denominator
    jzx_slice = set(range(364700, 364713)) | \
                set(range(364902, 364910)) | \
                set(range(364922, 364930)) | \
                set(range(364681, 364686)) | \
                set(range(364690, 364695)) | \
                set(range(364443, 364455)) | \
                set(range(426131, 426143))

    # Campaign code as stored in the 'Type' column; -1 when no campaign matches
    mc = 1 if 'mc16a' in sample else 2 if 'mc16d' in sample else 3 if 'mc16e' in sample else -1

    def total(column, mc_type):
        # float() on a Series raises TypeError unless exactly one row matches
        rows = nevt_table[(nevt_table['dsid'] == dsid) & (nevt_table['Type'] == mc_type)]
        return float(rows[column])

    if 'dijet' in sample and dsid not in jzx_slice:
        nevts = total('totalEvents', mc)
    elif 'dijet' in sample:
        try:
            nevts = total('totalEventsWeighted', mc)
        except (TypeError, ValueError):
            # No unique entry for this campaign: fall back to the Type==3 entry
            nevts = total('totalEventsWeighted', 3)
    else:
        nevts = total('totalEventsWeighted', mc)

    xsec_row = xsec_table[xsec_table['dsid'] == dsid]
    # NOTE(review): the 1e6 factor is presumably a unit conversion of the
    # cross-section -- confirm against the units used in xsec_file
    xsec = float(xsec_row['xsec']) * 1e6
    filt = float(xsec_row['filt'])

    # Luminosity per campaign code; `==`-style lookup rather than `is` on ints
    lumi = {1: 36.2, 2: 40.5, 3: 58.5}.get(mc, 1.0)
    return xsec * filt / nevts * lumi

def load_files(sample, folder, vnames=False, filt=None, debug=False):
    '''
    Load the pickled ntuples of a sample as dask dataframes, grouped by DSID
    Inputs:
        - sample -> str: Key into samples_DXAOD
        - folder -> str: Folder holding the pickles (joined by plain string
          concatenation, so it should end with a path separator)
        - vnames -> list[str] | False: Extra columns to load on top of the
          tagger inputs
        - filt -> list[str] | None: Extra keywords a file name must contain
        - debug -> bool: Print the colour key of the log lines
    Return:
        files  -> dict[str, list]: dask dataframes per DSID (empty list if none)
        fail   -> list[str]: Dataset names with no matching file in folder
        broken -> list[(str, Exception)]: (dataset name, error) for every file
                  that could not be loaded
    '''
    if debug:
        cprint('Colour key:', 'blue')
        cprint('File successfully processed', 'green')
        cprint('Problem with opening file', 'red', attrs=['bold'])
        cprint('Specific DSID/mc is missing from folder '+folder, 'magenta', attrs=['bold'])

    extra_kws = list(filt) if filt else []
    fnames = samples_DXAOD[sample]
    cprint('Loading '+sample, 'blue')

    # The column selection does not depend on the file, so build it once
    cols = ['rljet_m_comb[:,0]',
            'rljet_topTag_DNN20_TausRatio_qqb_score',
            'rljet_topTag_DNN20_TausRatio_inclusive_score',
            'rljet_D2',
            'rljet_ungroomed_ntrk500']
    if vnames:
        cols += vnames
    if 'data' not in sample:
        cols += ['weight_mc', 'weight_pileup']
    # De-duplicate while keeping a deterministic column order
    cols = list(dict.fromkeys(cols))

    all_f = 0
    fail = []
    broken = []
    files = {}
    for fname in fnames:
        dsid = fname.split('.')[1]
        tags = fname.split('.')[-1]
        matches = find_file([dsid, tags, 'dijet']+extra_kws, folder=folder)
        all_f += len(matches)
        files[dsid] = []
        if not matches:
            cprint(fname, 'magenta', attrs = ['bold'])
            fail.append(fname)
            continue

        n_broken_before = len(broken)
        for pkl_name in matches:
            try:
                # NOTE(review): unpickling executes code from the file; only
                # use with trusted inputs
                frame = pd.read_pickle(folder+pkl_name)[cols].droplevel(1)
                files[dsid].append(dd.from_pandas(frame, npartitions=20))
            except Exception as e:
                cprint(fname, 'red', attrs = ['blink'])
                cprint('\tFailed '+pkl_name, 'red', attrs=['bold'])
                print('\t', type(e), e)
                # Record every failure; previously these were dropped
                broken.append((fname, e))
        # Green only when every file of this dataset loaded
        if len(broken) == n_broken_before:
            cprint(fname, 'green')
    print(all_f)
    return files, fail, broken

def fill_histos(trees, histo, vnames, taggers):
    '''
    Fill every (variable, tagger) histogram from a list of trees
    Inputs:
        - trees  -> list: Dataframes to process. The list is consumed from the
          end and is empty on successful return (frees memory as it goes)
        - histo  -> dict: histo[vname][tag] histograms, updated in place
        - vnames -> list[str]: Variable names to histogram
        - taggers: Unused; the tagger names come from the module-level TAGGERS
    Return:
        histo -> dict: The same dictionary, holding the filled histograms
    '''
    # Nothing to fill: return untouched instead of failing on unset locals
    if not trees:
        return histo

    combos = list(product(vnames, TAGGERS))
    for step in range(len(trees)):
        print(step)
        # One delayed task per (variable, tagger) pair for the last tree
        tasks = [fill_histo(trees[-1], histo, v, tag) for v, tag in combos]
        results = dask.compute(*tasks)
        for (v, tag), filled in zip(combos, results):
            histo[v][tag] = filled
        # Drop the tree only once its results are stored
        del trees[-1]
    return histo

@dask.delayed
def fill_histo(tree, histo, vname, taggers):
    '''
    Fill one histogram with the jets passing one tagger (dask-delayed)
    Inputs:
        - tree    : Dataframe with the jet columns
        - histo   -> dict: histo[vname][tag] histograms
        - vname   -> str: Variable name in ntuple
        - taggers -> str: Name of a single tagger; names without 'dnn',
          'w_50' or 'w_80' apply no tagger selection
    Return:
        The filled histogram histo[vname][taggers]
    '''
    # Preselection on the leading combined-mass column
    tree = tree[tree['rljet_m_comb[:,0]']>50000.]
    target = tree[vname]

    # Calculate weight
    # NOTE(review): `dsid` is not a parameter; it is read from module scope
    # (the driver loop variable). Names containing 'period' get unit weights.
    if 'period' in dsid:
        weight = np.ones_like(tree['rljet_m_comb[:,0]'])
    else:
        weight = (tree['weight_mc']*tree['weight_pileup'])

    tag = taggers
    passed = None

    # Calculate pass based on pt (cut functions come from the global FORMULAS)
    if 'dnn' in tag:
        if 'cont' in tag:
            score = tree['rljet_topTag_DNN20_TausRatio_qqb_score']
        else:
            score = tree['rljet_topTag_DNN20_TausRatio_inclusive_score']
        pt = tree['rljet_pt_comb']/1000.
        passed = score > FORMULAS[tag](pt)

    elif 'w_50' in tag or 'w_80' in tag:
        f = FORMULAS[tag]
        pt   = tree['rljet_pt_comb']/1000.
        mass = tree['rljet_m_comb[:,0]']/1000.
        d2   = tree['rljet_D2']
        ntrk = tree['rljet_ungroomed_ntrk500']
        passed = (d2 < f['d2'](pt)) & (mass > f['m_lo'](pt)) & (mass < f['m_hi'](pt)) & (ntrk < f['ntrk'](pt))

    # Apply cuts (the mask is computed once and shared by values and weights)
    if passed is None:
        target_pass = target
        weight_pass = weight
    else:
        target_pass = target[passed]
        weight_pass = weight[passed]

    # Create histos
    histo[vname][tag].fill(target_pass, weight=weight_pass)

    return histo[vname][tag]



In [None]:
# Start a dask.distributed client with default settings; the captured output
# below reports it as backed by a distributed.LocalCluster.
from dask.distributed import Client

client = Client()
# Bare expression so the notebook displays the client/cluster summary
client

Perhaps you already have a cluster running?
Hosting the HTTP server on port 44583 instead
  f"Port {expected} is already in use.\n"


OSError: Timed out trying to connect to tcp://127.0.0.1:42201 after 30 s

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 5
Total threads: 10,Total memory: 28.60 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:42177,Workers: 5
Dashboard: http://127.0.0.1:8787/status,Total threads: 10
Started: Just now,Total memory: 28.60 GiB

0,1
Comm: tcp://137.138.55.242:43632,Total threads: 2
Dashboard: http://137.138.55.242:46027/status,Memory: 5.72 GiB
Nanny: tcp://127.0.0.1:45142,
Local directory: /afs/cern.ch/work/b/brle/private/jetetmiss/ab_21_2_183/source/DataMCDijetTopology/dask-worker-space/worker-yw1d9gkb,Local directory: /afs/cern.ch/work/b/brle/private/jetetmiss/ab_21_2_183/source/DataMCDijetTopology/dask-worker-space/worker-yw1d9gkb

0,1
Comm: tcp://137.138.55.242:32976,Total threads: 2
Dashboard: http://137.138.55.242:38925/status,Memory: 5.72 GiB
Nanny: tcp://127.0.0.1:37255,
Local directory: /afs/cern.ch/work/b/brle/private/jetetmiss/ab_21_2_183/source/DataMCDijetTopology/dask-worker-space/worker-_dv5fubz,Local directory: /afs/cern.ch/work/b/brle/private/jetetmiss/ab_21_2_183/source/DataMCDijetTopology/dask-worker-space/worker-_dv5fubz

0,1
Comm: tcp://137.138.55.242:42421,Total threads: 2
Dashboard: http://137.138.55.242:38040/status,Memory: 5.72 GiB
Nanny: tcp://127.0.0.1:41874,
Local directory: /afs/cern.ch/work/b/brle/private/jetetmiss/ab_21_2_183/source/DataMCDijetTopology/dask-worker-space/worker-170augjw,Local directory: /afs/cern.ch/work/b/brle/private/jetetmiss/ab_21_2_183/source/DataMCDijetTopology/dask-worker-space/worker-170augjw

0,1
Comm: tcp://137.138.55.242:36868,Total threads: 2
Dashboard: http://137.138.55.242:44615/status,Memory: 5.72 GiB
Nanny: tcp://127.0.0.1:40513,
Local directory: /afs/cern.ch/work/b/brle/private/jetetmiss/ab_21_2_183/source/DataMCDijetTopology/dask-worker-space/worker-epmm0ars,Local directory: /afs/cern.ch/work/b/brle/private/jetetmiss/ab_21_2_183/source/DataMCDijetTopology/dask-worker-space/worker-epmm0ars

0,1
Comm: tcp://137.138.55.242:36102,Total threads: 2
Dashboard: http://137.138.55.242:33749/status,Memory: 5.72 GiB
Nanny: tcp://127.0.0.1:33293,
Local directory: /afs/cern.ch/work/b/brle/private/jetetmiss/ab_21_2_183/source/DataMCDijetTopology/dask-worker-space/worker-4kba596g,Local directory: /afs/cern.ch/work/b/brle/private/jetetmiss/ab_21_2_183/source/DataMCDijetTopology/dask-worker-space/worker-4kba596g


In [None]:
# Histogram axis per ntuple variable. Labels containing LaTeX commands are raw
# strings so that '\t' in '\tau' is not turned into a TAB character.
var_main = {
        'rljet_pt_comb':      Variable([ 4.5e5, 5e5, 5.5e5, 6e5, 6.5e5, 7e5, 7.5e5, 8e5, 8.5e5, 9e5, 9.5e5, 1e6, 1.1e6, 1.2e6, 1.3e6, 1.4e6, 1.5e6, 1.7e6, 2.5e6], name='x', label=r'$p_{T}$[MeV]'  ),
        'rljet_Angularity'        : Regular(20, 0., 0.5,  name='x', label='Angularity'),
        'rljet_Aplanarity'        : Regular(20, 0., 1,    name='x', label='Aplanarity'),
        'rljet_C2'                : Regular(20, 0., 1,    name='x', label='$C_{2}$'        ),
        'rljet_D2'                : Regular(20, 0., 6,    name='x', label='$D_{2}$'        ),
        'rljet_Dip12'             : Regular(20, 0., 2,    name='x', label='Dip12'       ),
        'rljet_ECF1'              : Regular(20, 0., 1e7,  name='x', label='ECF1'        ),
        'rljet_ECF2'              : Regular(20, 0., 1e12, name='x', label='ECF2'        ),
        'rljet_ECF3'              : Regular(20, 0., 1e17, name='x', label='ECF3'        ),
        'rljet_FoxWolfram0'       : Regular(20, 0., 1,    name='x', label='FoxWolfram0' ),
        'rljet_FoxWolfram2'       : Regular(20, 0., 1,    name='x', label='FoxWolfram2' ),
        'rljet_KtDR'              : Regular(20, 0., 6,    name='x', label='KtDR'        ),
        'rljet_L1'                : Regular(20, 0., 1,    name='x', label='$L_{1}$'     ),
        'rljet_L2'                : Regular(20, 0., 1,    name='x', label='$L_{2}$'     ),
        'rljet_L3'                : Regular(20, 0., 1,    name='x', label='$L_{3}$'     ),
        'rljet_L4'                : Regular(20, 0., 5,    name='x', label='$L_{4}$'     ),
        'rljet_L5'                : Regular(20, 0., 1,    name='x', label='$L_{5}$'     ),
        'rljet_M2'                : Regular(20, 0., 1,    name='x', label='$M_{2}$'     ),
        'rljet_Mu12'              : Regular(20, 0., 1,    name='x', label='Mu12'        ),
        'rljet_N2'                : Regular(20, 0., 1,    name='x', label='$N_{2}$'          ),
        'rljet_N3'                : Regular(20, 0., 6,    name='x', label='$N_{3}$'          ),
        'rljet_PlanarFlow'        : Regular(20, 0., 2,    name='x', label='PlanarFlow'  ),
        'rljet_Qw'                : Regular(20, 0., 1e6,  name='x', label='$Q_{w}$'          ),
        'rljet_Sphericity'        : Regular(20, 0., 1,    name='x', label='Sphericity'  ),
        'rljet_Split12'           : Regular(20, 0., 1e6,  name='x', label='Split12'     ),
        'rljet_Split23'           : Regular(20, 0., 1e6,  name='x', label='Split23'     ),
        'rljet_Split34'           : Regular(20, 0., 1e6,  name='x', label='Split34'     ),
        'rljet_Tau1_wta'          : Regular(20, 0., 1,    name='x', label=r'$\tau_{1}$'   ),
        'rljet_Tau2_wta'          : Regular(20, 0., 1,    name='x', label=r'$\tau_{2}$'   ),
        'rljet_Tau32_wta'         : Regular(20, 0., 1,    name='x', label=r'$\tau_{32}$'  ),
        'rljet_Tau3_wta'          : Regular(20, 0., 1,    name='x', label=r'$\tau_{3}$'   ),
        'rljet_Tau42_wta'         : Regular(20, 0., 1,    name='x', label=r'$\tau_{42}$'  ),
        'rljet_Tau4_wta'          : Regular(20, 0., 1,    name='x', label=r'$\tau_{4}$'   ),
        'rljet_ThrustMaj'         : Regular(20, 0., 1,    name='x', label='ThrustMaj'   ),
        'rljet_ThrustMin'         : Regular(20, 0., 1,    name='x', label='ThrustMin'   ),
        'rljet_ZCut12'            : Regular(20, 0., 1,    name='x', label='ZCut12'      ),
        'rljet_n_constituents'    : Regular(20, 0., 200,  name='x', label='$n_{const.}$'),
        'rljet_ungroomed_ntrk500' : Regular(20, 0., 100,  name='x', label='$N_{trk,500}$' ),
    }
# Tagger working points; read as a global by fill_histos
TAGGERS = ['dnn_cont_50', 'dnn_cont_80', 'dnn_incl_50', 'dnn_incl_80', 'w_50', 'w_80']
# The full sample list is immediately narrowed to a single sample for this run
samples = list(samples_DXAOD.keys())
samples = ['pythia_dijet_mc16e']
# Cut functions per tagger; read as a global by fill_histo
FORMULAS = load_formulas()
folder = '/eos/atlas/atlascerngroupdisk/perf-jets/JSS/WTopBackgroundSF2019/UFO_test/slimmed_SEP/'

weights = {}
histos  = {}

for samp in samples:
    _files, _fail, _broken = load_files(samp, folder, vnames=list(var_main.keys()))
    histos[samp] = {}
    # `dsid` must stay a module-level name: fill_histo reads it as a global
    for dsid in list(_files.keys()):
        files = _files[dsid]
        # One empty histogram per variable and per tagger (plus untagged 'all')
        histos[samp][dsid] = {
            vname: {tag: dh.Histogram(var, storage=dh.storage.Double()) for tag in TAGGERS+['all']}
            for vname, var in var_main.items()
        }
        try:
            print('Initialised', samp, dsid)
            histos[samp][dsid] = fill_histos(files, histos[samp][dsid], list(var_main.keys()), FORMULAS)
            print('Saving sample histograms')
            with open('dask_test_'+samp+'_'+dsid+'.pkl', 'wb') as out:
                pkl.dump(histos[samp][dsid], out, protocol=pkl.HIGHEST_PROTOCOL)
        except Exception as exc:
            # Keep going with the next DSID, but report why this one failed
            print('Failed', samp, dsid, type(exc), exc)
        finally:
            # Release the inputs and histograms of this DSID whether or not
            # it succeeded, so memory does not pile up across the loop
            del files
            del _files[dsid]
            del histos[samp][dsid]


[34mLoading pythia_dijet_mc16e[0m
[32mmc16_13TeV.364703.Pythia8EvtGen_A14NNPDF23LO_jetjet_JZ3WithSW.deriv.DAOD_JETM6.e7142_s3126_r10724_p4308[0m
[32mmc16_13TeV.364704.Pythia8EvtGen_A14NNPDF23LO_jetjet_JZ4WithSW.deriv.DAOD_JETM6.e7142_s3126_r10724_p4308[0m
[32mmc16_13TeV.364705.Pythia8EvtGen_A14NNPDF23LO_jetjet_JZ5WithSW.deriv.DAOD_JETM6.e7142_s3126_r10724_p4308[0m
[32mmc16_13TeV.364706.Pythia8EvtGen_A14NNPDF23LO_jetjet_JZ6WithSW.deriv.DAOD_JETM6.e7142_s3126_r10724_p4308[0m
[1m[35mmc16_13TeV.364707.Pythia8EvtGen_A14NNPDF23LO_jetjet_JZ7WithSW.deriv.DAOD_JETM6.e7142_s3126_r10724_p4308[0m
[1m[35mmc16_13TeV.364708.Pythia8EvtGen_A14NNPDF23LO_jetjet_JZ8WithSW.deriv.DAOD_JETM6.e7142_s3126_r10724_p4308[0m
[32mmc16_13TeV.364709.Pythia8EvtGen_A14NNPDF23LO_jetjet_JZ9WithSW.deriv.DAOD_JETM6.e7142_s3126_r10724_p4308[0m
[1m[35mmc16_13TeV.364710.Pythia8EvtGen_A14NNPDF23LO_jetjet_JZ10WithSW.deriv.DAOD_JETM6.e7142_s3126_r10724_p4308[0m
[1m[35mmc16_13TeV.364711.Pythia8EvtGen_A14NNP