In [None]:
import random
import copy

import numpy as np
import awkward as ak
import uproot
import torch

from numba import njit
# from concurrent.futures import ThreadPoolExecutor


def superbatch_iterator(files, keys, superbatch_size=1024*100):
    """
    Reduces the overhead caused by reading data in small batches.
    Now the batch size and data load size are independent.
    WARNING: data load size should be a multiple of batch size.

    Chunks delivered by ``uproot.iterate`` are re-cut so that every yielded
    array holds exactly ``superbatch_size`` entries; only the very last
    array may be shorter. An empty array is never yielded.

    Parameters
    ----------
    files: file paths as in uproot.iterator
    keys: TBranch names in list
    superbatch_size: batch_size * multiplet

    Yields
    ------
    awkward array with the requested branches as fields.

    Raises
    ------
    ValueError
        If ``superbatch_size`` is not positive.
    """
    if superbatch_size <= 0:
        raise ValueError(f"superbatch_size must be positive, got {superbatch_size}")

    pending = []           # pieces collected for the batch under construction
    pending_entries = 0    # total number of entries held in `pending`

    # NOTE: decompression could be parallelised by passing
    # decompression_executor=ThreadPoolExecutor(max_workers=2) (xz/zstd).
    chunks = uproot.iterate(files, keys, step_size=superbatch_size, library="ak")

    for chunk in chunks:
        chunk_length = len(chunk)
        start = 0

        # Carve as many complete batches as this chunk can finish. The loop
        # (instead of a single cut) also copes with chunks that are larger
        # than one batch.
        while pending_entries + (chunk_length - start) >= superbatch_size:
            stop = start + (superbatch_size - pending_entries)
            pending.append(chunk[start:stop])
            yield ak.concatenate(pending)  # execution resumes right here on the next request
            pending = []
            pending_entries = 0
            start = stop

        # Keep the leftover for the next batch; skip empty leftovers so a
        # batch that ends exactly on a chunk boundary does not produce a
        # trailing empty array.
        if start < chunk_length:
            pending.append(chunk[start:])
            pending_entries += chunk_length - start

    if pending:
        yield ak.concatenate(pending)



class ModifiedUprootIterator(torch.utils.data.IterableDataset):
    """
    Iterable dataset that reads signal (and optionally background) events
    in steps of ``step_size`` via ``superbatch_iterator`` and returns, per
    step, the dict of tensors built by ``pad_and_fill``.

    Background files are split among the workers; every worker gets its own
    full copy of the signal file list, and an exhausted signal iterator is
    restarted so that the background defines the length of an epoch.
    """

    # Track four-momentum components computed from (pt, eta, phi).
    _FOUR_VECTOR_BRANCHES = ('SDVTrack_E', 'SDVTrack_px', 'SDVTrack_py', 'SDVTrack_pz')

    def __init__(self, files, branches, shuffle=False, nWorkers=1, step_size=100):
        """
        Parameters
        ----------
        files : dict
                keys: "sig", "bkg"
                values: ['path_to_file:Events', ...]
                The "bkg" value may be None/empty for signal-only running.
                With shuffle=True the given lists are shuffled in place.
        branches : dict
                dict of branches in TTree.
                keys should be ev, sv, tk.
                values should be of type list.
                The 'tk' list is extended in place with the derived
                four-vector branch names once they are computed.
        shuffle : bool
                Shuffle the file lists and the entries of every batch.
        nWorkers : int
                Files will be divided among the workers.
                Therefore, nWorkers determines number of divisions.
                nWorkers=0 will be treated as nWorkers=1.
        step_size : int
                number of Events to be read from the files at each iteration.
        """
        super().__init__()

        print('Initialize iterable dataset')
        self.files = files
        self.branches = branches

        # Flat list of every branch to read; groups set to None are skipped.
        self.branchList = [b for value in branches.values() if value is not None for b in value]

        self.step_size = step_size
        self.nWorkers = max(nWorkers, 1)
        self.shuffle = shuffle
        print('nWorkers: ', self.nWorkers)

        if self.shuffle:
            random.shuffle(self.files['sig'])
            if self.files['bkg']:
                random.shuffle(self.files['bkg'])

        if self.files['bkg']:
            self.workerBkgList = self._distribute_files(self.files['bkg'])
        else:
            self.workerBkgList = None
        # NOTE: _refresh_iterators() below rebuilds this list again.
        self.workerSigList = self._copy_files(self.files['sig'], shuffle=self.shuffle)

        self.SigIteratorList = None
        self.BkgIteratorList = None
        self._refresh_iterators()

        # Most recent batch: combined, signal part and background part.
        self.x = None
        self.xSig = None
        self.xBkg = None

    def _distribute_files(self, files):
        """Split `files` round-robin into one list per worker."""
        return [files[worker_id::self.nWorkers] for worker_id in range(self.nWorkers)]

    def _copy_files(self, files, shuffle=True):
        """Return one (optionally independently shuffled) shallow copy of `files` per worker."""
        workerList = []
        for _ in range(self.nWorkers):
            files_copy = copy.copy(files)
            if shuffle:
                random.shuffle(files_copy)
            workerList.append(files_copy)
        return workerList

    def _make_iterator(self, workerFiles):
        """Create a fresh superbatch iterator over `workerFiles` with the current step size."""
        return superbatch_iterator(workerFiles, self.branchList, superbatch_size=self.step_size)

    def _refresh_iterators(self, shuffle=False):
        """Rebuild the per-worker signal file lists and all per-worker iterators."""
        self.workerSigList = self._copy_files(self.files['sig'], shuffle=shuffle)
        self.SigIteratorList = [self._make_iterator(workerFiles) for workerFiles in self.workerSigList]
        if self.workerBkgList:
            self.BkgIteratorList = [self._make_iterator(workerFiles) for workerFiles in self.workerBkgList]
        else:
            self.BkgIteratorList = None

    def __iter__(self):
        """
        Start a new pass over the data.

        Each call grows step_size by 25 while it is still below 200 and then
        recreates the iterators with the (possibly enlarged) step size.
        """
        print('__iter__ is called.')
        if self.step_size < 200:
            self.step_size += 25
            print('step_size is increased to ', self.step_size)
        self._refresh_iterators(shuffle=self.shuffle)
        return self

    def update_step_size(self, new_step_size):
        """Set the step size used by iterators created from now on."""
        self.step_size = new_step_size

    def __next__(self):
        """
        Read the next batch for the calling worker and return it as tensors.

        StopIteration propagates when the background iterator (or, without
        background, the signal iterator) is exhausted.
        """
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0

        if self.BkgIteratorList:
            self.xBkg = next(self.BkgIteratorList[worker_id])
            try:
                self.xSig = next(self.SigIteratorList[worker_id])
            except StopIteration:
                # Signal ran out before background: start over on the same files.
                print(f'Worker {worker_id}s SigIteratorList is exhausted. Loading again.')
                self.SigIteratorList[worker_id] = self._make_iterator(self.workerSigList[worker_id])
                self.xSig = next(self.SigIteratorList[worker_id])
            self.x = ak.concatenate([self.xBkg, self.xSig])
        else:
            self.xSig = next(self.SigIteratorList[worker_id])
            self.x = self.xSig

        if self.shuffle:
            self.x = self._shuffle_akArr(self.x)
        self._add_four_vector_branches()

        return self._prepare_output()

    def _shuffle_akArr(self, x):
        """ Shuffle awkward array. """
        idx = np.arange(len(x))
        np.random.shuffle(idx)
        return x[idx]

    def _register_four_vector_branches(self):
        """Append the derived four-vector names to branches['tk'], skipping names already listed."""
        tk_branches = self.branches['tk']
        for name in self._FOUR_VECTOR_BRANCHES:
            if name not in tk_branches:
                tk_branches.append(name)

    def _add_four_vector_branches(self):
        """
        Add E, px, py, pz of the tracks to the current batch.

        Runs only when pt, eta and phi are read from file and at least one
        four-vector component is not. It is called for every batch, so the
        branch registration has to be idempotent; otherwise branches['tk']
        would gain four duplicate names per batch.
        """
        has_kinematics = all(x in self.branchList for x in ['SDVTrack_pt', 'SDVTrack_eta', 'SDVTrack_phi'])
        lacks_component = any(x not in self.branchList for x in self._FOUR_VECTOR_BRANCHES)
        if not (has_kinematics and lacks_component):
            return

        self._register_four_vector_branches()
        E, px, py, pz = ptetaphim_to_epxpypz(self.x['SDVTrack_pt'], self.x['SDVTrack_eta'], self.x['SDVTrack_phi'])
        self.x['SDVTrack_E'] = E
        self.x['SDVTrack_px'] = px
        self.x['SDVTrack_py'] = py
        self.x['SDVTrack_pz'] = pz

    def _prepare_output(self):
        """Convert the current batch into the dict of tensors handed to the consumer."""
        return pad_and_fill(self.x, self.branches, svDim=12, tkDim=10, fillValue=0.)


def ptetaphim_to_epxpypz(pt, eta, phi, m=0.13957):
    """
    Convert (pt, eta, phi, m) into Cartesian four-momentum components.

    Works element-wise on anything numpy's ufuncs accept.

    Returns
    -------
    tuple (E, px, py, pz), with E = sqrt(px^2 + py^2 + pz^2 + m^2).
    """
    x_comp = pt * np.cos(phi)
    y_comp = pt * np.sin(phi)
    z_comp = pt * np.sinh(eta)
    momentum_sq = x_comp*x_comp + y_comp*y_comp + z_comp*z_comp
    energy = np.sqrt(momentum_sq + m*m)
    return energy, x_comp, y_comp, z_comp


def pad_and_fill(X, branchDict, svDim=12, tkDim=10, fillValue=-9e10):
    """
    Convert an awkward record array into a dict of torch tensors.

    Every field of `X` is handled according to the branch group listing it
    (checked in the order ev, sv, label, tk); fields listed nowhere are
    skipped.

    * 'ev'    : fields starting with 'n' are cast to int32, fields starting
                with 'Jet' are reduced to their first entry per event; the
                field is then broadcast against the first 'sv' branch and
                flattened along axis 1.
    * 'sv', 'label' : flattened along axis 1.
    * 'tk'    : regrouped per secondary vertex with deepTable, flattened
                along axis 1, then clipped/padded to `tkDim` entries.

    Parameters
    ----------
    X : awkward array with the branches as fields. The 'n*' and 'Jet*'
        event-level fields are overwritten in place.
    branchDict : dict
        Branch names per group ('ev', 'sv', 'label', 'tk'). A missing key,
        None or an empty list all mean that the group is absent.
    svDim : int
        Currently unused.
    tkDim : int
        Number of track slots per secondary vertex.
    fillValue : float
        Padding value for track branches, except 'SDVTrack_E' (1e3) and
        'SDVTrack_pz' (0).

    Returns
    -------
    dict mapping field name to torch.Tensor.
    """
    # Treat missing/None groups as empty instead of raising KeyError.
    ev_fields = branchDict.get('ev') or ()
    sv_fields = branchDict.get('sv') or ()
    label_fields = branchDict.get('label') or ()
    tk_fields = branchDict.get('tk') or ()

    # Per-branch padding values that override `fillValue`.
    tk_fill_overrides = {
        'SDVTrack_E': 1e3,
        'SDVTrack_pz': 0
    }

    def flatten_to_tensor(ak_arr):
        # Drop the per-event nesting and hand the data over to torch.
        return torch.tensor(ak.flatten(ak_arr, axis=1).to_numpy())

    X_dict = {}
    for field in X.fields:
        if field in ev_fields:
            if field.startswith('n'):
                X[field] = ak.values_astype(X[field], np.int32)
            elif field.startswith('Jet'):
                # NOTE(review): presumably every event has at least one jet - confirm.
                X[field] = X[field][:, 0]

            # Repeat the event-level value once per secondary vertex
            # (any sv branch provides the required shape).
            per_sv, _ = ak.broadcast_arrays(X[field], X[sv_fields[0]])
            X_dict[field] = flatten_to_tensor(per_sv)

        elif field in sv_fields or field in label_fields:
            X_dict[field] = flatten_to_tensor(X[field])

        elif field in tk_fields:
            builder = ak.ArrayBuilder()
            deepTable(X[field], X.SDVIdxLUT_TrackIdx, X.SDVIdxLUT_SecVtxIdx, X.nSDVSecVtx, builder)
            deepX = builder.snapshot()

            field_fillValue = tk_fill_overrides.get(field, fillValue)

            # event x sv x track  ->  sv x track, then force exactly tkDim tracks
            almost_flat = ak.flatten(deepX, axis=1)
            padded = ak.pad_none(almost_flat, target=tkDim, clip=True, axis=1)
            filled = ak.fill_none(padded, field_fillValue, axis=1)

            X_dict[field] = torch.tensor(filled.to_numpy())

    return X_dict


@njit
def deepTable(tkBranch, trIdx, svIdx, n_sv, builder):
    """
    Takes the track level branch and converts its shape
    from: (nEvent * var * float32) to (nEvent * var * var * float32)
    representing the association of each tk with sv.

    The result is written into `builder`; nothing is returned.

    Parameters
    ----------
    tkBranch : per-event list of track values.
    trIdx    : per-event list of track indices, parallel to `svIdx`.
    svIdx    : per-event list of secondary-vertex indices, parallel to `trIdx`.
    n_sv     : number of secondary vertices per event.
    builder  : ak.ArrayBuilder receiving one list per event, each holding
               one list per vertex with the values of its tracks.

    Example inputs:
    svIdx: [[0, 0, 1, 1,  2, 2,  3,  3, ...], ...]
    trIdx: [[1, 3, 3, 28, 7, 8, 10, 11, ...], ...]
    n_sv:  [4, 0, 6, 1, 1, 7, ...]
    """
    # The builder starts at event level, so every list opened directly
    # inside the outer loop corresponds to one event.
    for ev in range(len(n_sv)):                   # loop over events
        builder.begin_list()                      # open the list of vertices of this event
        for sv in range(n_sv[ev]):                # for each vertex of the event ...
            builder.begin_list()                  # ... open its list of tracks
            for i2, col in enumerate(svIdx[ev]):  # scan all (sv, track) associations of the event
                if col == sv:                     # association belongs to the current vertex
                    builder.append(tkBranch[ev][trIdx[ev][i2]])
            builder.end_list()
        builder.end_list()

# @njit
# def deepTable(tkBranch, trIdx, svIdx, n_sv, builder):
#     """
#     Highly optimized version assuming svIdx is sorted per event.
#     Avoids intermediate lists for maximum performance and minimum memory overhead.
#     """
#     for ev in range(len(n_sv)):
#         builder.begin_list()  # Event-level list
#         
#         num_sv_in_event = n_sv[ev]
#         if num_sv_in_event == 0:
#             builder.end_list()
#             continue
# 
#         association_idx = 0
#         event_svIdx = svIdx[ev]
#         num_associations = len(event_svIdx)
# 
#         # For each expected SV, collect its tracks
#         for sv_target in range(num_sv_in_event):
#             builder.begin_list()  # SV-level list
#             
#             # Scan forward through the associations for the current sv_target
#             while association_idx < num_associations and event_svIdx[association_idx] == sv_target:
#                 track_idx = trIdx[ev][association_idx]
#                 builder.append(tkBranch[ev][track_idx])
#                 association_idx += 1 # Move to the next association
#             
#             builder.end_list()
#         builder.end_list()




In [2]:
import datetime
import warnings
import random
import glob
import gc

import math
import numpy as np
import awkward as ak
import uproot
from sklearn.metrics import confusion_matrix
from numba import jit, njit

import torch
import torch.nn as nn

# import ParT
# from vtxLevelDataset2 import ModifiedUprootIterator

import matplotlib.pyplot as plt


# Suppress all warnings of category UserWarning from here on.
warnings.filterwarnings("ignore", category=UserWarning)


# Report the CPU count as seen by torch.multiprocessing.
print('CPU count: ', torch.multiprocessing.cpu_count())
# torch.set_num_threads(18)
# torch.set_num_threads(torch.multiprocessing.cpu_count())

def significance(s,b,b_err):
    """
    Median discovery significance for `s` signal events on top of `b`
    background events with absolute background uncertainty `b_err`.

    Definition at slide 33:
    https://www.pp.rhul.ac.uk/~cowan/stat/cowan_munich16.pdf

    The 1e-20 terms only protect the divisions against a zero denominator.
    """
    total = s+b
    variance = b_err*b_err

    first_log = np.log((total*(b+variance))/(b*b+total*b_err*b_err+1e-20))
    second_log = np.log(1+(variance*s)/(b*(b+variance)+1e-20))

    first_term = total*first_log
    second_term = (b*b/(variance + 1e-20))*second_log
    return np.sqrt(2*(first_term - second_term))




# Root directory that is searched recursively for signal ROOT files.
MLDATADIR = '/scratch-cbe/users/alikaan.gueven/ML_KAAN/MC2018/all/'
# glob gives no ordering guarantee (it depends on the file system); sort the
# result so the train/validation split below is the same in every session.
tmpSigList = sorted(glob.glob(f'{MLDATADIR}/stop*/**/*.root', recursive=True))
# Append the TTree name, as expected by uproot ('path:tree').
tmpSigList = [sig + ':Events' for sig in tmpSigList]
# First 70% of the files for training, the remainder for validation.
maxTrain = round(len(tmpSigList)*0.70)
maxVal   = round(len(tmpSigList)*1)

trainSigList = tmpSigList[:maxTrain]
valSigList   = tmpSigList[maxTrain:maxVal]

# trainBkgList = glob.glob(f'{MLDATADIR}/training_set/bkg_mix*.root')
# valBkgList = glob.glob(f'{MLDATADIR}/val_set/bkg_mix*.root')

# trainBkgList = glob.glob('/scratch-cbe/users/alikaan.gueven/ML_KAAN/train/training_set/bkg_mix*.root')
# valBkgList = glob.glob('/scratch-cbe/users/alikaan.gueven/ML_KAAN/train/val_set/bkg_mix*.root')


# trainBkgList = [elm + ':Events' for elm in trainBkgList]
# valBkgList = [elm + ':Events' for elm in valBkgList]


# trainDict = {
#     'sig': trainSigList,
#     'bkg': trainBkgList
# }
# 
# valDict = {
#     'sig': valSigList,
#     'bkg': valBkgList
# }

# Signal-only file dictionaries; 'bkg' is None, i.e. no background files are used.
trainDict = dict(sig=trainSigList, bkg=None)

valDict = dict(sig=valSigList, bkg=None)



# Branches to read from the TTree, grouped by the level they describe.
branchDict = {
    # Event-level branches.
    'ev': [
        'MET_phi',
        'nSDVSecVtx',
        'Jet_phi', 'Jet_pt', 'Jet_eta',
    ],
    # Secondary-vertex branches. The two SDVIdxLUT_* branches are the
    # vertex/track association lists consumed by deepTable.
    'sv': [
        'SDVSecVtx_pt',
        'SDVSecVtx_pAngle',
        'SDVSecVtx_charge',
        'SDVSecVtx_ndof',
        'SDVSecVtx_chi2',
        'SDVSecVtx_tracksSize',
        'SDVSecVtx_sum_tkW',
        'SDVSecVtx_LxySig',
        'SDVSecVtx_L_phi',
        'SDVSecVtx_L_eta',
        'SDVIdxLUT_SecVtxIdx',
        'SDVIdxLUT_TrackIdx',
    ],
    # Track branches; every name is listed exactly once
    # ('SDVTrack_eta' used to appear twice).
    'tk': [
        'SDVTrack_pt', 'SDVTrack_ptError',
        'SDVTrack_eta', 'SDVTrack_etaError',
        'SDVTrack_dxy', 'SDVTrack_dxyError',
        'SDVTrack_dz', 'SDVTrack_dzError',
        'SDVTrack_normalizedChi2', 'SDVTrack_phi',
    ],
    # Training target.
    'label': ['SDVSecVtx_matchedLLPnDau_bydau'],
}


# Data loading settings.
shuffle = True
nWorkers = 2
step_size = 150   # events read per iteration

# trainDataset = ModifiedUprootIterator(trainDict, branchDict, shuffle=shuffle, nWorkers=min(nWorkers, len(trainBkgList)), step_size=step_size)
# trainLoader = torch.utils.data.DataLoader(trainDataset, num_workers=max(min(nWorkers, len(trainBkgList)),1), prefetch_factor=1, persistent_workers= True)
# 
# 
# valDataset = ModifiedUprootIterator(valDict, branchDict, shuffle=shuffle, nWorkers=min(nWorkers, len(valBkgList)), step_size=step_size*5)
# valLoader = torch.utils.data.DataLoader(valDataset, num_workers=max(min(nWorkers, len(valBkgList)),1), prefetch_factor=1, persistent_workers= True)

CPU count:  76


In [3]:
# Single-worker training dataset over the signal-only file dictionary.
trainDataset = ModifiedUprootIterator(trainDict, branchDict, shuffle=shuffle, nWorkers=1, step_size=step_size)
# DataLoader wrapper with one persistent worker process that prefetches one batch.
# Note that the cells below pull batches from trainDataset directly rather than through this loader.
trainLoader = torch.utils.data.DataLoader(trainDataset, num_workers=1, prefetch_factor=1, persistent_workers= True)

Initialize iterable dataset
nWorkers:  1


In [4]:
# %%timeit
# tensor_dict = trainDataset.__next__()

In [5]:
# Pull a single batch straight from the dataset (bypassing the DataLoader).
tensor_dict = next(trainDataset)

In [6]:
%prun -s cumulative [trainDataset.__next__() for _ in range(50)]

 

         28617731 function calls (28219062 primitive calls) in 29.540 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   29.549   29.549 {built-in method builtins.exec}
        1    0.005    0.005   29.549   29.549 <string>:1(<module>)
        1    0.000    0.000   29.544   29.544 <string>:1(<listcomp>)
       50    0.005    0.000   29.544    0.591 1058244011.py:147(__next__)
22050/14300    0.018    0.000   25.526    0.002 {built-in method builtins.next}
       50    0.007    0.000   20.257    0.405 1058244011.py:13(superbatch_iterator)
       54    0.003    0.000   18.936    0.351 TBranch.py:49(iterate)
182329/561    2.033    0.000   17.004    0.030 model.py:754(read)
        8    0.000    0.000   16.329    2.041 _util.py:954(regularize_object_path)
        8    0.000    0.000   16.290    2.036 reading.py:2072(__getitem__)
        8    0.000    0.000   16.290    2.036 reading.py:2459(get)
  279

In [None]:
tensor_dict

In [None]:
# List every field of the batch together with the shape of its tensor.
for name, tensor in tensor_dict.items():
    print(name, tensor.shape)

MET_phi torch.Size([570])
nSDVSecVtx torch.Size([570])
Jet_phi torch.Size([570])
Jet_pt torch.Size([570])
Jet_eta torch.Size([570])
SDVSecVtx_pt torch.Size([570])
SDVSecVtx_pAngle torch.Size([570])
SDVSecVtx_charge torch.Size([570])
SDVSecVtx_ndof torch.Size([570])
SDVSecVtx_chi2 torch.Size([570])
SDVSecVtx_tracksSize torch.Size([570])
SDVSecVtx_sum_tkW torch.Size([570])
SDVSecVtx_LxySig torch.Size([570])
SDVSecVtx_L_phi torch.Size([570])
SDVSecVtx_L_eta torch.Size([570])
SDVIdxLUT_SecVtxIdx torch.Size([1280])
SDVIdxLUT_TrackIdx torch.Size([1280])
SDVTrack_pt torch.Size([570, 10])
SDVTrack_ptError torch.Size([570, 10])
SDVTrack_eta torch.Size([570, 10])
SDVTrack_etaError torch.Size([570, 10])
SDVTrack_dxy torch.Size([570, 10])
SDVTrack_dxyError torch.Size([570, 10])
SDVTrack_dz torch.Size([570, 10])
SDVTrack_dzError torch.Size([570, 10])
SDVTrack_normalizedChi2 torch.Size([570, 10])
SDVTrack_phi torch.Size([570, 10])
SDVSecVtx_matchedLLPnDau_bydau torch.Size([570])
SDVTrack_E torch.Siz