In [None]:
# default_exp modules.data_trackCalo

In [None]:
#export
import numpy as np
import glob
import os
import uproot as ur
import time
from multiprocessing import Process, Queue, set_start_method
import compress_pickle as pickle
from scipy.stats import circmean
import random
import itertools

# GraphDataGenerator

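> Multiprocess data generator that converts pion event files and the calorimeter cell geometry into batches of graphs (nodes, edges, globals) with regression targets and track/cluster metadata.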

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# Print NumPy arrays with 5 decimal places and without scientific notation so the
# example outputs further down in this notebook stay compact and readable.
np.set_printoptions(precision=5, suppress=True)

In [None]:
#export
class CaloTrackGraphDataGenerator:
    """Data generator that extracts and formats graph data from a list of event files.

    Two modes are supported:

    * ``preprocess=True`` with an ``output_dir``: the constructor immediately
      converts every input ``.npy`` event file into a gzip-compressed pickle
      ``data_XXX.p`` inside ``output_dir`` (one graph per selected event) and
      afterwards serves batches from those pickles.
    * ``preprocess=True`` with ``output_dir=None``: the given files are served
      directly by the batch workers, i.e. they are read with the same
      compressed-pickle loader that the preprocessed files use.

    With ``preprocess=False`` the batch workers raise, because only
    preprocessed input is supported by :meth:`worker`.

    For every selected event only the highest-energy cluster is turned into a
    graph; its cells are the nodes.
    """

    def __init__(self,
                 pion_file_list: list,
                 cellGeo_file: str,
                 batch_size: int,
                 use_geo_edges: bool = False,
                 shuffle: bool = True,
                 num_procs: int = 32,
                 preprocess: bool = False,
                 output_dir: str = None):
        """Read the cell geometry, store the configuration and optionally preprocess.

        Args:
            pion_file_list: Paths of the input files (raw ``.npy`` event files
                when preprocessing, otherwise already preprocessed pickles).
            cellGeo_file: Path of the ROOT file containing the ``CellGeo`` tree.
            batch_size: Number of graphs per emitted batch.
            use_geo_edges: If True, edges come from the neighbour branches of
                the geometry tree; otherwise each cluster is fully connected.
            shuffle: If True, the order of the input files is randomised.
            num_procs: Number of worker processes for preprocessing and for
                batch generation.
            preprocess: See the class docstring.
            output_dir: Directory receiving the preprocessed files.
        """
        self.preprocess = preprocess
        self.output_dir = output_dir

        # Work on a private copy: shuffling must not reorder the caller's
        # sequence (which may be a view into a larger NumPy array).
        # Shuffling happens here, before the list is stored, because in
        # preprocessing mode `self.file_list` does not exist yet; shuffling it
        # there used to raise AttributeError.
        input_files = list(pion_file_list)
        if shuffle:
            np.random.shuffle(input_files)

        if self.preprocess and self.output_dir is not None:
            self.pion_file_list = input_files
        else:
            self.file_list = input_files
        self.num_files = len(input_files)

        self.cellGeo_file = cellGeo_file
        self.use_geo_edges = use_geo_edges

        # NOTE(review): the slices below assume a fixed branch order in the
        # CellGeo tree (ID first, then eight geometry features, then the
        # neighbour branches) -- confirm against the geometry file.
        self.cellGeo_data = ur.open(self.cellGeo_file)['CellGeo']
        self.geoFeatureNames = self.cellGeo_data.keys()[1:9]
        self.nodeFeatureNames = ['cluster_cell_E', *self.geoFeatureNames[:-2]]
        self.num_nodeFeatures = len(self.nodeFeatureNames)

        self.edgeFeatureNames = self.cellGeo_data.keys()[9:] if self.use_geo_edges else []
        self.num_edgeFeatures = len(self.edgeFeatureNames)

        self.cellGeo_data = self.cellGeo_data.arrays(library='np')
        self.cellGeo_ID = self.cellGeo_data['cell_geo_ID'][0]
        # Permutation that sorts the geometry IDs; used to translate cell IDs
        # into positions inside the geometry arrays via searchsorted.
        self.sorter = np.argsort(self.cellGeo_ID)

        self.track_feature_names = ['trackPt','trackD0','trackZ0', 'trackEta_EMB2','trackPhi_EMB2',
                                    'trackEta','trackPhi','truthPartE', 'truthPartPt']
        self.cluster_feature_names = ['cluster_E', 'cluster_Eta', 'cluster_Phi', 'cluster_ENG_CALIB_TOT',
                                      'cluster_EM_PROBABILITY','cluster_E_LCCalib','cluster_HAD_WEIGHT', 'dR']

        # Selection applied to the leading cluster of each event:
        # it is rejected when dR > dr_thresh or cluster_E < clusterThresh.
        self.dr_thresh = 1.2
        self.clusterThresh = .5

        self.batch_size = batch_size
        self.shuffle = shuffle

        self.num_procs = num_procs
        self.procs = []

        if self.preprocess and self.output_dir is not None:
            os.makedirs(self.output_dir, exist_ok=True)
            self.preprocess_data()

    def _output_path(self, file_num):
        """Return the path of the preprocessed file with index `file_num`."""
        # os.path.join keeps the file inside output_dir even when the
        # directory was given without a trailing separator.
        return os.path.join(self.output_dir, f'data_{file_num:03d}.p')

    @staticmethod
    def _fully_connected_edges(num_nodes):
        """Return senders, receivers and empty edge features of a complete directed graph.

        Every ordered pair (i, j) with i != j is an edge; pairs are listed in
        row-major order (all edges of sender 0 first, then sender 1, ...).
        The edge feature array has shape (num_edges, 0).
        """
        off_diagonal = ~np.eye(num_nodes, dtype=bool)
        senders, receivers = np.nonzero(off_diagonal)
        edges = np.zeros(shape=[len(senders), 0], dtype=np.float32)
        return senders, receivers, edges

    def get_meta(self, event_data, event_ind, c_inds):
        """Collect track-level and cluster-level metadata of one event.

        Args:
            event_data: Mapping from branch name to per-event entries.
            event_ind: Index of the event.
            c_inds: Indices of the clusters whose metadata is wanted.

        Returns:
            Tuple ``(track_meta, cluster_meta)`` of float32 arrays; the first
            has one row per name in ``track_feature_names``, the second one row
            per cluster with one column per name in ``cluster_feature_names``.
        """
        track_meta_data = [event_data[f][event_ind] for f in self.track_feature_names]

        cluster_meta_data = [
            [event_data[f][event_ind][c] for f in self.cluster_feature_names]
            for c in c_inds
        ]

        return np.array(track_meta_data, dtype=np.float32), np.array(cluster_meta_data, dtype=np.float32)

    def get_nodes(self, event_data, event_ind, cluster_ind):
        """Build the node features and the global feature vector of one cluster.

        Node features are, per cell: log10 of the cell energy followed by the
        geometry features named in ``nodeFeatureNames[1:]``, with
        ``cell_geo_sampling`` divided by 28 and ``cell_geo_rPerp`` by 3000.

        Returns:
            Tuple ``(nodes, global_node, cluster_num_nodes, cell_IDmap)`` where
            ``nodes`` has shape (num_cells, num_node_features), ``global_node``
            is ``[log10(cluster_E), log10(trackPt), trackEta]`` as float32 and
            ``cell_IDmap`` holds each cell's position in the geometry arrays.
        """
        cell_IDs = event_data['cluster_cell_ID'][event_ind][cluster_ind]
        cell_IDmap = self.sorter[np.searchsorted(self.cellGeo_ID, cell_IDs, sorter=self.sorter)]

        geo = self.cellGeo_data
        # NOTE(review): the positional slices [2:4] and [5:] assume that
        # nodeFeatureNames[1] is cell_geo_sampling and nodeFeatureNames[4] is
        # cell_geo_rPerp -- confirm against the geometry file's branch order.
        columns = [np.log10(event_data['cluster_cell_E'][event_ind][cluster_ind])]
        # Scaling the cell_geo_sampling by 28
        columns.append(geo['cell_geo_sampling'][0][cell_IDmap]/28.)
        columns.extend(geo[f][0][cell_IDmap] for f in self.nodeFeatureNames[2:4])
        # Scaling the cell_geo_rPerp by 3000
        columns.append(geo['cell_geo_rPerp'][0][cell_IDmap]/3000.)
        columns.extend(geo[f][0][cell_IDmap] for f in self.nodeFeatureNames[5:])

        # Features were collected one after another (feature-major); reshape to
        # (num_features, num_cells) and transpose to one row per cell.
        flat = np.concatenate([np.ravel(col) for col in columns])
        nodes = np.reshape(flat, (len(self.nodeFeatureNames), -1)).T
        cluster_num_nodes = len(nodes)

        cluster_E = np.log10(event_data['cluster_E'][event_ind][cluster_ind])
        trackPt = np.log10(event_data['trackPt'][event_ind][0])
        trackEta = event_data['trackEta'][event_ind][0]

        global_node = np.array([cluster_E, trackPt, trackEta], dtype=np.float32)

        return nodes, global_node, cluster_num_nodes, cell_IDmap

    def get_edges(self, cluster_num_nodes, cell_IDmap):
        """Build edges from the neighbour branches of the cell geometry.

        For every cell and every neighbour branch in ``edgeFeatureNames`` an
        edge is created when the neighbour is itself part of the cluster.

        Returns:
            Tuple ``(senders, receivers, edges)``: ``senders`` and ``receivers``
            are node indices within the cluster, ``edges`` is a one-hot array of
            shape (num_edges, num_edge_features) marking which neighbour branch
            produced the edge.
        """
        edge_inds = np.zeros((cluster_num_nodes, self.num_edgeFeatures))
        for i, f in enumerate(self.edgeFeatureNames):
            edge_inds[:, i] = self.cellGeo_data[f][0][cell_IDmap]

        # Neighbours outside the cluster are replaced by NaN; NaN sorts past
        # every real value, which is what the rank filter below relies on.
        in_cluster = np.isin(edge_inds, cell_IDmap)
        edge_inds[np.logical_not(in_cluster)] = np.nan

        senders, edge_on_inds = in_cluster.nonzero()
        cluster_num_edges = len(senders)
        edges = np.zeros((cluster_num_edges, self.num_edgeFeatures))
        edges[np.arange(cluster_num_edges), edge_on_inds] = 1

        # Translate neighbour geometry positions into node indices. NaN entries
        # get rank == cluster_num_nodes and are dropped; the remaining ranks
        # are in the same row-major order as `senders`.
        cell_IDmap_sorter = np.argsort(cell_IDmap)
        rank = np.searchsorted(cell_IDmap, edge_inds , sorter=cell_IDmap_sorter)
        receivers = cell_IDmap_sorter[rank[rank!=cluster_num_nodes]]

        return senders, receivers, edges

    def preprocessor(self, worker_id):
        """Convert raw event files into preprocessed graph files.

        Worker `worker_id` handles the files with index ``worker_id``,
        ``worker_id + num_procs``, ... For each file the selected events are
        converted to ``(graph, target, track_meta, cluster_meta)`` tuples,
        shuffled and written to ``data_XXX.p`` in ``output_dir``.
        """
        file_num = worker_id
        while file_num < self.num_files:
            file = self.pion_file_list[file_num]
            print(f"Proceesing file {os.path.basename(file)}")
            # NOTE(review): allow_pickle=True executes pickled content; only
            # load files from a trusted source.
            event_data = np.load(file, allow_pickle=True).item()
            first_key = list(event_data.keys())[0]
            num_events = len(event_data[first_key])

            preprocessed_data = []

            for event_ind in range(num_events):
                truth_particle_E = np.log10(event_data['truthPartE'][event_ind][0]) # first one is the pion!
                trackPt = event_data['trackPt'][event_ind][0]
                if trackPt>5000:
                    continue

                # Only the highest-energy cluster of the event is used.
                c_ind = np.argsort(event_data['cluster_E'][event_ind])[-1]
                dR_cond = event_data['dR'][event_ind][c_ind]>self.dr_thresh
                clusterE_cond = event_data['cluster_E'][event_ind][c_ind]<self.clusterThresh
                if dR_cond or clusterE_cond:
                    continue

                nodes, global_node, cluster_num_nodes, cell_IDmap = self.get_nodes(event_data, event_ind, c_ind)
                if self.use_geo_edges:
                    senders, receivers, edges = self.get_edges(cluster_num_nodes, cell_IDmap)
                else:
                    senders, receivers, edges = self._fully_connected_edges(cluster_num_nodes)

                track_meta_data, cluster_meta_data = self.get_meta(event_data, event_ind, [c_ind])

                graph = {'nodes': nodes.astype(np.float32),
                         'globals': global_node.astype(np.float32),
                         'senders': senders.astype(np.int32),
                         'receivers': receivers.astype(np.int32),
                         'edges': edges}
                target = truth_particle_E.astype(np.float32)

                preprocessed_data.append((graph, target, track_meta_data, cluster_meta_data))

            random.shuffle(preprocessed_data)

            # The context manager guarantees the file is flushed and closed
            # even if serialisation fails.
            with open(self._output_path(file_num), 'wb') as out_file:
                pickle.dump(preprocessed_data, out_file, compression='gzip')

            print(f"Finished processing {file_num} files")
            file_num += self.num_procs

    def preprocess_data(self):
        """Run :meth:`preprocessor` in `num_procs` processes and wait for them.

        Afterwards ``self.file_list`` points at the preprocessed files.
        """
        print('\nPreprocessing and saving data to {}'.format(self.output_dir))
        for i in range(self.num_procs):
            p = Process(target=self.preprocessor, args=(i,), daemon=True)
            p.start()
            self.procs.append(p)

        for p in self.procs:
            p.join()

        self.file_list = [self._output_path(i) for i in range(self.num_files)]

    def preprocessed_worker(self, worker_id, batch_queue):
        """Read preprocessed files and push batches onto `batch_queue`.

        Worker `worker_id` reads the files with index ``worker_id``,
        ``worker_id + num_procs``, ... Batches hold `batch_size` entries and may
        span several files; a final, smaller batch is emitted for the rest.
        Each batch is ``(graphs, targets, track_meta, cluster_meta)`` with
        ``targets`` as a float32 array and the others as lists.
        """
        batch_graphs = []
        batch_targets = []
        batch_track_meta = []
        batch_cluster_meta = []

        file_num = worker_id
        while file_num < self.num_files:
            # NOTE(review): unpickling executes arbitrary code; only read
            # files written by `preprocessor` or from a trusted source.
            with open(self.file_list[file_num], 'rb') as in_file:
                file_data = pickle.load(in_file, compression='gzip')

            for graph, target, track_meta, cluster_meta in (entry[:4] for entry in file_data):
                batch_graphs.append(graph)
                batch_targets.append(target)
                batch_track_meta.append(track_meta)
                batch_cluster_meta.append(cluster_meta)

                if len(batch_graphs) == self.batch_size:
                    batch_targets = np.array(batch_targets).astype(np.float32)
                    batch_queue.put((batch_graphs, batch_targets, batch_track_meta, batch_cluster_meta))

                    batch_graphs = []
                    batch_targets = []
                    batch_track_meta = []
                    batch_cluster_meta = []

            file_num += self.num_procs

        if len(batch_graphs) > 0:
            batch_targets = np.array(batch_targets).astype(np.float32)
            batch_queue.put((batch_graphs, batch_targets, batch_track_meta, batch_cluster_meta))

    def worker(self, worker_id, batch_queue):
        """Entry point of a batch-producing process.

        Raises:
            Exception: If the generator was created with ``preprocess=False``.
        """
        if self.preprocess:
            self.preprocessed_worker(worker_id, batch_queue)
        else:
            raise Exception('Preprocessing is required for combined classification/regression models.')

    def check_procs(self):
        """Return True while at least one tracked process is still alive."""
        return any(p.is_alive() for p in self.procs)

    def kill_procs(self):
        """Kill every tracked process and forget about them."""
        for p in self.procs:
            p.kill()

        self.procs = []

    def generator(self):
        """Yield batches produced by `num_procs` worker processes.

        The generator ends once all workers have exited and the queue is
        drained. Started processes are appended to ``self.procs``; use
        :meth:`kill_procs` to stop them when the generator is abandoned early.
        """
        # multiprocessing.Queue.get signals a timeout with queue.Empty.
        from queue import Empty

        # Bounded queue: workers block on put() once 2 * num_procs batches wait.
        batch_queue = Queue(2 * self.num_procs)

        for i in range(self.num_procs):
            p = Process(target=self.worker, args=(i, batch_queue), daemon=True)
            p.start()
            self.procs.append(p)

        while self.check_procs() or not batch_queue.empty():
            try:
                batch = batch_queue.get(True, 0.0001)
            except Empty:
                # Nothing available yet; poll again. Any other exception
                # (including KeyboardInterrupt) propagates to the caller.
                continue

            yield batch

        for p in self.procs:
            p.join()

Test the data generation step...

In [None]:
# Input locations for the example run below.
# NOTE(review): these are hard-coded absolute paths on a specific filesystem;
# adjust them (or derive them from a configurable data directory) before running elsewhere.
pion_dir = '/usr/workspace/hip/ML4Jets/regression_images/graphs.v01-45-gaa27bcb/onetrack_multicluster/pion_files/'
# All .npy files in pion_dir, sorted by name so the selection below is deterministic.
pion_files = np.sort(glob.glob(pion_dir+"*.npy"))
# Number of input files used for this quick test.
n_files = 3

# ROOT file that is opened by the generator to read the 'CellGeo' tree.
cell_geo_file = '/usr/workspace/hip/ML4Jets/regression_images/graph_examples/cell_geo.root'

In [None]:
# Because preprocess=True and output_dir is set, constructing the generator
# immediately preprocesses the first n_files inputs and writes one
# data_XXX.p file per input into output_dir (here: the current directory).
# shuffle=False keeps the input file order as given.
data_gen = CaloTrackGraphDataGenerator(pion_file_list=pion_files[:n_files],
                                       cellGeo_file=cell_geo_file,
                                       batch_size=32,
                                       use_geo_edges=True,
                                       shuffle=False,
                                       num_procs=32,
                                       preprocess=True,
                                       output_dir='./')


Preprocessing and saving data to ./
Proceesing file 001.npyProceesing file 002.npy

Proceesing file 003.npy
Finished processing 2 files
Finished processing 1 files
Finished processing 0 files


In [None]:
# for graph, target, track_meta_data, cluster_meta_data in data_gen.generator():
#     print(np.array(cluster_meta_data)[:9,:, -1].squeeze())

[0.11789 0.03721 0.04428 0.13931 0.0963  0.02879 0.0257  0.10718 0.02651]
[0.02323 0.02434 0.02469 0.04276 0.03371 0.00788 0.0223  0.02623 0.06079]
[0.20967 0.01322 0.02411 0.02681 0.00914 0.03443 0.02262 0.00528 0.01584]
[0.05612 0.01514 0.1008  0.35566 0.17044 0.03246 0.02859 0.02413 0.00969]
[0.00844 0.00909 0.0061  0.23195 0.0274  0.0159  0.04065 0.03215 0.02915]
[0.3733  0.02073 0.13588 0.02572 0.1063  0.01644 0.01957 0.0143  0.07276]
[0.02334 0.02661 0.02863 0.05473 0.14176 0.02253 0.00883 0.00991 0.10053]
[0.00425 0.04959 0.00882 0.02357 0.01947 0.00447 0.0724  0.33375 0.0312 ]
[0.03362 0.0234  0.15624 0.01263 0.05806 0.02148 0.0708  0.03063 0.00912]
[0.05164 0.02031 0.19447 0.03546 0.02809 0.01701 0.00312 0.00617 0.00613]
[0.0131  0.03485 0.03844 0.09205 0.00761 0.01147 0.01221 0.03354 0.03452]
[0.01527 0.05492 0.1071  0.01391 0.01045 0.02161 0.0153  0.02591 0.02872]
[0.15676 0.17011 0.01252 0.27592 0.01879 0.14303 0.03055 0.0212  0.02699]
[0.01342 0.2101  0.01936 0.00447 0.011

[0.01135 0.02001 0.03437 0.01211 0.02713 0.09277 0.02678 0.03562 0.02472]
[0.20372 0.0304  0.03026 0.18594 0.00285 0.10769 0.03257 0.02076 0.03289]
[0.02069 0.01464 0.00758 0.03941 0.00825 0.02761 0.01482 0.05989 0.01642]
[0.06716 0.04143 0.0389  0.03138 0.02187 0.05003 0.01617 0.01848 0.19764]
[0.01756 0.04117 0.65495 0.00742 0.02229 0.04701 0.00641 0.02214 0.00833]
[0.00159 0.01319 0.01581 0.04125 0.00657 0.06455 0.01231 0.02579 0.00672]
[0.02113 0.01545 0.03791 0.02023 0.01444 0.00192 0.06179 0.00894 0.00487]
[0.02905 0.03591 0.00565 0.0159  0.0018  0.04734 0.00755 0.0341  0.02746]
[0.02401 0.06864 0.02609 0.01982 0.04607 0.01455 0.01076 0.00287 0.01501]
[0.21066 0.00165 0.02519 0.02502 0.07394 0.04887 0.37167 0.0598  0.06417]
[0.02667 0.08787 0.1051  0.02561 0.0137  0.03986 0.02092 0.04577 0.00199]
[0.02503 0.01442 0.09584 0.00743 0.04246 0.01825 0.08749 0.04627 0.0195 ]
[0.05778 0.0192  0.00471 0.01628 0.01837 0.00783 0.15763 0.0147  0.03023]
[0.11403 0.71842 0.09343 0.17051 0.016

[0.05034 0.05108 0.01089 0.01115 0.06083 0.0453  0.0204  0.01052 0.01735]
[0.23882 0.09898 0.0042  0.02327 0.21387 0.03195 0.01376 0.01436 0.16922]
[0.02023 0.04517 0.02184 0.00768 0.05647 0.04075 0.01114 0.01235 0.03618]
[0.02434 0.03426 0.02313 0.01471 0.17597 0.001   0.01727 0.03744 0.01128]
[0.06001 0.01145 0.21857 0.01134 0.04432 0.00372 0.02353 0.05696 0.02849]
[0.0649  0.0055  0.01161 0.01139 0.02888 0.01145 0.02681 0.41645 0.00778]
[0.00839 0.26011 0.02135 0.0141  0.00657 0.00454 0.00533 0.00431 0.02086]
[0.00196 0.03562 0.0068  0.01208 0.02034 0.02805 0.22141 0.02786 0.03336]
[0.03465 0.0435  0.00505 0.01115 0.04111 0.02262 0.11523 0.05591 0.02046]
[0.00684 0.02    0.00982 0.01292 0.00755 0.11602 0.0267  0.02169 0.02479]
[0.0055  0.01223 0.00347 0.03739 0.25653 0.02666 0.68654 0.03024 0.0097 ]
[0.00913 0.03364 0.029   0.0102  0.01448 0.02809 0.0117  0.02228 0.03217]
[0.01419 0.05686 0.03336 0.01961 0.04507 0.00541 0.10304 0.05315 0.05079]
[0.04371 0.21008 0.02519 0.23039 0.022

[0.04647 0.03012 0.00562 0.01728 0.0263  0.20275 0.00661 0.19174 0.01987]
[0.1024  0.0133  0.01367 0.01152 0.01154 0.27553 0.01009 0.01237 0.04129]
[0.4436  0.00706 0.0126  0.03404 0.00586 0.07092 0.01637 0.01136 0.05377]
[0.03346 0.05605 0.01584 0.03408 0.00249 0.01334 0.00944 0.00407 0.03556]
[0.02578 0.08679 0.01351 0.01202 0.01403 0.19659 0.04484 0.16333 0.0215 ]
[0.40215 0.01822 0.03001 0.02322 0.02906 0.07668 0.00356 0.04455 0.02256]
[0.00933 0.04817 0.00776 0.00376 0.05111 0.04198 0.04232 0.03036 0.0134 ]
[0.09306 0.02989 0.02064 0.08203 0.01688 0.01512 0.02345 0.00438 0.01809]
[0.17266 0.02536 0.01573 0.07969 0.03428 0.01655 0.06935 0.01358 0.01037]
[0.04084 0.57332 0.00129 0.01566 0.01514 0.13872 0.00808 0.00875 0.00947]
[0.02452 0.0228  0.01343 0.02022 0.07163 0.11034 0.10899 0.03005 0.01928]
[0.00641 0.05631 0.56548 0.00607 0.20991 0.00649 0.01924 0.01815 0.02584]
[0.02177 0.03639 0.06211 0.02698 0.02164 0.01777 0.02798 0.00586 0.06072]
[0.04989 0.01376 0.01876 0.05609 0.013

[0.00658 0.01468 0.0447  0.05541 0.01465 0.00827 0.03789 0.02163 0.05753]
[0.00935 0.06112 0.01047 0.03603 0.03678 0.15497 0.09763 0.01142 0.02118]
[0.01775 0.01585 0.00742 0.01625 0.02    0.00454 0.01354 0.08655 0.11232]
[0.04788 0.01138 0.0119  0.08978 0.0501  0.03784 0.01638 0.01184 0.01896]
[0.03251 0.09436 0.0253  0.02759 0.02753 0.03256 0.0181  0.03088 0.0047 ]
[0.00725 0.02685 0.00361 0.01106 0.03439 0.01492 0.00954 0.17832 0.01115]
[0.05003 0.05347 0.0302  0.01047 0.057   0.01223 0.02932 0.09006 0.00921]
[0.03026 0.00634 0.63336 0.01298 0.03503 0.04374 0.01694 0.01715 0.19636]
[0.01583 0.02509 0.02688 0.00389 0.00664 0.3037  0.03973 0.01633 0.00648]
[0.30033 0.00624 0.04272 0.03324 0.00326 0.07504 0.01744 0.01137 0.04552]
[0.00995 0.07247 0.05062 0.03387 0.10336 0.00316 0.01228 0.02292 0.04287]
[0.25489 0.03888 0.01561 0.0147  0.01747 0.0304  0.1218  0.02349 0.01535]
[0.017   0.08359 0.01161 0.01917 0.02271 0.02103 0.23972 0.03167 0.09022]
[0.00665 0.02494 0.00648 0.01981 0.023

[0.0426  0.04154 0.01834 0.02227 0.01231 0.0262  0.02571 0.01959 0.00775]
[0.00999 0.12209 0.02858 0.00881 0.02084 0.00669 0.00142 0.09534 0.00822]
[0.1865  0.00792 0.01582 0.02    0.12078 0.01814 0.05483 0.0166  0.01333]
[0.02143 0.03368 0.04789 0.01792 0.04092 0.0312  0.00942 0.03516 0.03409]
[0.01784 0.002   0.03251 0.01987 0.00907 0.01671 0.01243 0.01786 0.0168 ]
[0.00566 0.0442  0.08954 0.01832 0.03908 0.0193  0.00299 0.01828 0.0278 ]
[0.02208 0.02311 0.0433  0.01423 0.02537 0.01559 0.01693 0.01241 0.01034]
[0.12108 0.10962 0.02022 0.02611 0.01168 0.02719 0.00742 0.07699 0.03903]
[0.17401 0.04639 0.01129 0.0224  0.0464  0.10919 0.03068 0.0288  0.00842]
[0.04237 0.0981  0.02293 0.01689 0.02415 0.19355 0.01282 0.0021  0.00306]
[0.00808 0.00829 0.03684 0.0331  0.01464 0.08476 0.02228 0.11881 0.02177]
[0.01732 0.02981 0.02281 0.03059 0.01728 0.02039 0.0112  0.04101 0.02367]
[0.02016 0.02754 0.05042 0.04812 0.38942 0.01788 0.03358 0.04908 0.00624]
[0.05245 0.07266 0.04208 0.01549 0.011

[0.02983 0.65014 0.01399 0.02974 0.13989 0.06927 0.01041 0.0055  0.01114]
[0.00426 0.00837 0.27008 0.00582 0.02272 0.15105 0.02439 0.02753 0.0231 ]
[0.01853 0.01082 0.01897 0.176   0.05579 0.02563 0.04054 0.17635 0.01652]
[0.03705 0.00035 0.03134 0.00395 0.00854 0.02438 0.01376 0.02298 0.06513]
[0.02664 0.013   0.01083 0.03884 0.02067 0.01968 0.0126  0.0158  0.02133]
[0.01371 0.02197 0.01921 0.02104 0.00289 0.0209  0.03711 0.038   0.02358]
[0.04744 0.00311 0.19312 0.28056 0.26141 0.02249 0.00706 0.01718 0.00857]
[0.04274 0.01625 0.00169 0.01051 0.02318 0.01879 0.02844 0.09788 0.05197]
[0.02905 0.01452 0.00361 0.01934 0.39999 0.01705 0.00142 0.21024 0.03728]
[0.02066 0.03605 0.05358 0.00417 0.01002 0.00741 0.0278  0.1978  0.20708]
[0.01149 0.07044 0.01327 0.02187 0.02322 0.10827 0.038   0.03244 0.1602 ]
[0.04928 0.07328 0.02785 0.04228 0.02943 0.00679 0.0578  0.0349  0.01295]
[0.07751 0.02676 0.02516 0.16489 0.0093  0.01392 0.0106  0.03694 0.02218]
[0.00733 0.01177 0.00255 0.02632 0.158

In [None]:
# Pull only the first batch. Calling generator() starts num_procs daemon worker
# processes; since the generator is not exhausted here, call
# data_gen.kill_procs() afterwards to stop workers that are still running.
graph, target, track_meta_data, cluster_meta_data = next(data_gen.generator())

In [None]:
graph

[{'nodes': array([[ 1.03284,  0.07143,  0.96153, -1.63583,  0.5751 ,  0.025  ,
           0.02454],
         [ 0.4771 ,  0.07143,  0.96152, -1.66031,  0.5751 ,  0.025  ,
           0.02454],
         [-0.11807,  0.07143,  0.96154, -1.61134,  0.57509,  0.025  ,
           0.02454],
         [ 0.28765,  0.07143,  0.93657, -1.63583,  0.5751 ,  0.025  ,
           0.02454],
         [-0.07765,  0.07143,  0.98649, -1.63583,  0.57178,  0.025  ,
           0.02454],
         [-0.15055,  0.07143,  0.93656, -1.66031,  0.5751 ,  0.025  ,
           0.02454],
         [-0.27177,  0.07143,  0.98648, -1.66031,  0.57179,  0.025  ,
           0.02454],
         [-0.77892,  0.07143,  0.93658, -1.61134,  0.57509,  0.025  ,
           0.02454],
         [-0.51553,  0.07143,  0.9865 , -1.61134,  0.57177,  0.025  ,
           0.02454],
         [-1.13155,  0.03571,  0.95355, -1.62366,  0.50956,  0.00313,
           0.09817],
         [-1.88569,  0.03571,  0.95667, -1.62366,  0.50956,  0.00313,
           