In [None]:
# default_exp modules.data

In [None]:
#export
import numpy as np
import glob
import os
import uproot as ur
import time
from multiprocessing import Process, Queue, set_start_method
import compress_pickle as pickle
from scipy.stats import circmean
import random

# GraphDataGenerator

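> Multiprocess data generator that converts calorimeter-cluster ROOT files into graph samples (nodes, edges, globals) with a combined energy-regression / pion-classification target.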

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
class GraphDataGenerator:
    """
    DataGenerator class for extracting and formatting data from lists of ROOT files.
    This data generator uses the cell_geo file to create the input graph structure.

    Every usable cluster becomes one sample ``(graph, target)``:

    * ``graph`` is a dict with the keys ``nodes``, ``globals``, ``senders``,
      ``receivers`` and ``edges``;
    * ``target`` has shape ``(1, 2)`` and holds
      ``[log10(cluster_ENG_CALIB_TOT), label]`` where ``label`` is 1 for
      clusters read from ``pion_file_list`` and 0 for clusters read from
      ``pi0_file_list``.

    Modes (selected in ``__init__``):

    * ``preprocess=True`` with an ``output_dir``: the ROOT files are converted
      to gzip-compressed pickles in ``output_dir`` during construction and
      ``generator()`` reads those pickles.
    * ``preprocess=True`` without ``output_dir``: ``pi0_file_list`` is used as
      the list of already preprocessed pickles; ``pion_file_list`` is unused.
    * ``preprocess=False``: every worker started by ``generator()`` raises,
      so no batches are produced.
    """
    def __init__(self,
                 pi0_file_list: list,
                 pion_file_list: list,
                 cellGeo_file: str,
                 batch_size: int,
                 shuffle: bool = True,
                 num_procs = 32,
                 preprocess = False,
                 output_dir = None):
        """
        Read the cell geometry, store the configuration and, when both
        ``preprocess`` and ``output_dir`` are set, preprocess the ROOT files.

        Args:
            pi0_file_list: pi0 ROOT files (preprocessing mode) or the files
                handed to the workers of ``generator()`` (otherwise).
            pion_file_list: pion ROOT files; only used in preprocessing mode,
                where it must have the same length as ``pi0_file_list``.
            cellGeo_file: ROOT file containing the ``CellGeo`` tree.
            batch_size: number of samples per yielded batch.
            shuffle: shuffle the order of ``self.file_list`` once at the end
                of construction.
            num_procs: upper bound on the number of worker processes; capped
                at the number of files.
            preprocess: see the class docstring.
            output_dir: directory receiving the preprocessed pickles.
        """

        self.preprocess = preprocess
        self.output_dir = output_dir

        # Preprocessing only runs when there is a directory to write into.
        run_preprocessing = bool(self.preprocess and self.output_dir is not None)

        if run_preprocessing:
            self.pi0_file_list = pi0_file_list
            self.pion_file_list = pion_file_list
            assert len(pi0_file_list) == len(pion_file_list), \
                'pi0_file_list and pion_file_list must have the same length'
            self.num_files = len(self.pi0_file_list)
        else:
            # Copy, so the shuffle below never reorders the caller's sequence.
            self.file_list = list(pi0_file_list)
            self.num_files = len(self.file_list)

        self.cellGeo_file = cellGeo_file

        # Read branch names and arrays, then release the file handle.
        with ur.open(self.cellGeo_file) as geo_file:
            geo_tree = geo_file['CellGeo']
            branch_names = geo_tree.keys()
            self.cellGeo_data = geo_tree.arrays(library='np')

        # Branch 0 is skipped, branches 1-8 are geometry features (the last
        # two are not used as node features) and the rest are neighbour tables.
        self.geoFeatureNames = branch_names[1:9]
        self.nodeFeatureNames = ['cluster_cell_E', *self.geoFeatureNames[:-2]]
        self.edgeFeatureNames = branch_names[9:]
        self.num_nodeFeatures = len(self.nodeFeatureNames)
        self.num_edgeFeatures = len(self.edgeFeatureNames)

        self.cellGeo_ID = self.cellGeo_data['cell_geo_ID'][0]
        # argsort of the IDs, used by get_nodes for searchsorted lookups.
        self.sorter = np.argsort(self.cellGeo_ID)

        self.batch_size = batch_size
        self.shuffle = shuffle

        # Never start more workers than there are files; keep a plain int.
        self.num_procs = int(min(num_procs, self.num_files))
        self.procs = []

        if run_preprocessing:
            os.makedirs(self.output_dir, exist_ok=True)
            self.preprocess_data()

        # Shuffle last: in preprocessing mode self.file_list only exists once
        # preprocess_data() has run.
        if self.shuffle:
            np.random.shuffle(self.file_list)

    def _output_path(self, file_num):
        """Return the path of the preprocessed pickle for input file `file_num`."""
        return os.path.join(self.output_dir, f'data_{file_num:03d}.p')

    def get_cluster_calib(self, event_data, event_ind, cluster_ind):
        """
        Read the cluster calibration energy.

        Returns log10 of ``cluster_ENG_CALIB_TOT`` for the given event and
        cluster, or ``None`` when that energy is not strictly positive.
        """
        cluster_calib_E = event_data['cluster_ENG_CALIB_TOT'][event_ind][cluster_ind]

        if cluster_calib_E <= 0:
            return None

        return np.log10(cluster_calib_E)

    def get_nodes(self, event_data, event_ind, cluster_ind):
        """
        Read the node features of one cluster.

        Returns:
            nodes: array of shape (number of cells, number of node features);
                column 0 is log10 of the cell energy, the remaining columns
                are geometry features of the cell.
            global_node: 1-element array with log10 of ``cluster_E``.
            cluster_num_nodes: number of cells in the cluster.
            cell_IDmap: position of every cell in the geometry arrays.
        """
        cell_IDs = event_data['cluster_cell_ID'][event_ind][cluster_ind]
        # Translate cell IDs into positions in the geometry arrays.
        # NOTE(review): assumes every cell ID occurs in cell_geo_ID.
        cell_IDmap = self.sorter[np.searchsorted(self.cellGeo_ID, cell_IDs, sorter=self.sorter)]

        global_node = np.log10(event_data['cluster_E'][event_ind][cluster_ind])

        # NOTE(review): the slices below assume nodeFeatureNames[1] is
        # cell_geo_sampling and nodeFeatureNames[4] is cell_geo_rPerp --
        # confirm against the branch order of the CellGeo tree.
        columns = [np.log10(event_data['cluster_cell_E'][event_ind][cluster_ind])]
        # Scaling the cell_geo_sampling by 28
        columns.append(self.cellGeo_data['cell_geo_sampling'][0][cell_IDmap]/28.)
        for f in self.nodeFeatureNames[2:4]:
            columns.append(self.cellGeo_data[f][0][cell_IDmap])
        # Scaling the cell_geo_rPerp by 3000
        columns.append(self.cellGeo_data['cell_geo_rPerp'][0][cell_IDmap]/3000.)
        for f in self.nodeFeatureNames[5:]:
            columns.append(self.cellGeo_data[f][0][cell_IDmap])

        # One concatenation instead of a re-allocating np.append per feature.
        nodes = np.concatenate([np.ravel(column) for column in columns])
        nodes = np.reshape(nodes, (len(self.nodeFeatureNames), -1)).T
        cluster_num_nodes = len(nodes)

        return nodes, np.array([global_node]), cluster_num_nodes, cell_IDmap

    def get_edges(self, cluster_num_nodes, cell_IDmap):
        """
        Read the edge features of one cluster.

        Returns senders, receivers, and edges: an edge runs from node
        ``senders[k]`` to node ``receivers[k]`` (both are row numbers of the
        node array) and ``edges[k]`` is one-hot over the neighbour branches,
        marking the branch through which the sender references the receiver.
        """
        # NOTE(review): assumes the neighbour branches hold positions in the
        # geometry arrays, i.e. the same index space as cell_IDmap.
        edge_inds = np.zeros((cluster_num_nodes, self.num_edgeFeatures))
        for i, f in enumerate(self.edgeFeatureNames):
            edge_inds[:, i] = self.cellGeo_data[f][0][cell_IDmap]

        # Keep only neighbours that belong to this cluster.
        in_cluster = np.isin(edge_inds, cell_IDmap)
        edge_inds[np.logical_not(in_cluster)] = np.nan

        senders, edge_on_inds = in_cluster.nonzero()
        cluster_num_edges = len(senders)
        edges = np.zeros((cluster_num_edges, self.num_edgeFeatures))
        edges[np.arange(cluster_num_edges), edge_on_inds] = 1

        # NaN entries sort past the end (rank == cluster_num_nodes) and are
        # dropped; the remaining ranks are in the same row-major order as
        # `senders`.
        cell_IDmap_sorter = np.argsort(cell_IDmap)
        rank = np.searchsorted(cell_IDmap, edge_inds , sorter=cell_IDmap_sorter)
        receivers = cell_IDmap_sorter[rank[rank!=cluster_num_nodes]]

        return senders, receivers, edges

    def _clusters_from_file(self, file, label):
        """Turn every cluster with positive calibration energy in one ROOT file into a (graph, target) sample."""
        with ur.open(file) as root_file:
            event_tree = root_file['EventTree']
            num_events = event_tree.num_entries
            event_data = event_tree.arrays(library='np')

        samples = []
        for event_ind in range(num_events):
            num_clusters = event_data['nCluster'][event_ind]

            for i in range(num_clusters):
                cluster_calib_E = self.get_cluster_calib(event_data, event_ind, i)

                if cluster_calib_E is None:
                    continue

                nodes, global_node, cluster_num_nodes, cell_IDmap = self.get_nodes(event_data, event_ind, i)
                senders, receivers, edges = self.get_edges(cluster_num_nodes, cell_IDmap)

                graph = {'nodes': nodes.astype(np.float32), 'globals': global_node.astype(np.float32),
                    'senders': senders.astype(np.int32), 'receivers': receivers.astype(np.int32),
                    'edges': edges.astype(np.float32)}
                target = np.reshape([cluster_calib_E.astype(np.float32), label], [1,2])

                samples.append((graph, target))

        return samples

    def preprocessor(self, worker_id):
        """
        Preprocessing root file data for faster data
        generation during multiple training epochs.

        Worker ``worker_id`` handles the files ``worker_id``,
        ``worker_id + num_procs``, ... For each of them the pion and pi0
        samples are merged, shuffled and written to one gzip-compressed pickle.
        """
        file_num = worker_id
        while file_num < self.num_files:
            print(f"Proceesing file number {file_num}")

            # Label 1: clusters from the pion file, label 0: from the pi0 file.
            preprocessed_data = self._clusters_from_file(self.pion_file_list[file_num], 1)
            preprocessed_data.extend(self._clusters_from_file(self.pi0_file_list[file_num], 0))

            random.shuffle(preprocessed_data)

            with open(self._output_path(file_num), 'wb') as out_file:
                pickle.dump(preprocessed_data, out_file, compression='gzip')

            # file_num is the index of the file just written, not a count.
            print(f"Finished processing file number {file_num}")
            file_num += self.num_procs

    def preprocess_data(self):
        """
        Run ``preprocessor`` in ``num_procs`` processes, wait for all of them
        and point ``self.file_list`` at the pickles they were asked to write.
        """
        print('\nPreprocessing and saving data to {}'.format(self.output_dir))
        for i in range(self.num_procs):
            p = Process(target=self.preprocessor, args=(i,), daemon=True)
            p.start()
            self.procs.append(p)

        for p in self.procs:
            p.join()

        # All preprocessing workers have exited; drop their handles so they
        # do not pile up next to the workers started by generator().
        self.procs = []

        self.file_list = [self._output_path(i) for i in range(self.num_files)]

    def preprocessed_worker(self, worker_id, batch_queue):
        """
        Load the preprocessed files assigned to this worker (``worker_id``,
        ``worker_id + num_procs``, ...) and put ``(graphs, targets)`` batches
        of ``batch_size`` samples on ``batch_queue``. Batches may span file
        boundaries; a final, smaller batch is emitted if samples remain.
        """
        def put_batch(graphs, targets):
            targets = np.reshape(np.array(targets), [-1,2]).astype(np.float32)
            batch_queue.put((graphs, targets))

        batch_graphs = []
        batch_targets = []

        file_num = worker_id
        while file_num < self.num_files:
            # NOTE(review): unpickling runs arbitrary code; only load files
            # written by this class or another trusted source.
            with open(self.file_list[file_num], 'rb') as in_file:
                file_data = pickle.load(in_file, compression='gzip')

            for i in range(len(file_data)):
                batch_graphs.append(file_data[i][0])
                batch_targets.append(file_data[i][1])

                if len(batch_graphs) == self.batch_size:
                    put_batch(batch_graphs, batch_targets)

                    batch_graphs = []
                    batch_targets = []

            file_num += self.num_procs

        if len(batch_graphs) > 0:
            put_batch(batch_graphs, batch_targets)

    def worker(self, worker_id, batch_queue):
        """Entry point of a generator worker; raises unless ``preprocess`` is set."""
        if self.preprocess:
            self.preprocessed_worker(worker_id, batch_queue)
        else:
            raise Exception('Preprocessing is required for combined classification/regression models.')

    def check_procs(self):
        """Return True while at least one process in ``self.procs`` is alive."""
        return any(p.is_alive() for p in self.procs)

    def kill_procs(self):
        """Kill every process in ``self.procs`` and forget them."""
        for p in self.procs:
            p.kill()

        self.procs = []

    def generator(self):
        """
        Generator that returns processed batches during training.

        Starts ``num_procs`` worker processes and yields the
        ``(graphs, targets)`` batches they produce until all of them have
        exited and the queue is drained. Workers that are still running when
        the generator is closed or fails are killed.
        """
        # multiprocessing.Queue.get signals a timeout with queue.Empty.
        from queue import Empty

        batch_queue = Queue(2 * self.num_procs)

        # Track this run's workers separately: processes left in self.procs
        # by an earlier, abandoned generator must not keep this loop alive.
        workers = []
        for i in range(self.num_procs):
            p = Process(target=self.worker, args=(i, batch_queue), daemon=True)
            p.start()
            workers.append(p)
            self.procs.append(p)  # lets kill_procs() reach them

        try:
            while any(p.is_alive() for p in workers) or not batch_queue.empty():
                try:
                    batch = batch_queue.get(True, 0.0001)
                except Empty:
                    # Nothing ready yet; re-check worker liveness.
                    continue

                yield batch
        finally:
            for p in workers:
                if p.is_alive():
                    p.kill()
                p.join()
            self.procs = [p for p in self.procs if p not in workers]

In [None]:
# Where the ROOT inputs live and where the preprocessed training pickles go.
data_dir = '/usr/workspace/hip/ML4Jets/regression_images/'
out_dir = '/p/vast1/karande1/heavyIon/data/preprocessed_data/gn4pions/geo/train/'

# Entries 10-19 of the sorted pi0 / pion file lists are used for training.
pi0_files, pion_files = (
    np.sort(glob.glob(data_dir + 'graphs.v01-45-gaa27bcb/' + pattern))[10:20]
    for pattern in ('*pi0*/*.root', '*pion*/*.root')
)

In [None]:
# With preprocess=True and an output_dir the constructor converts the ROOT
# files into pickles under out_dir before it returns.
data_gen = GraphDataGenerator(
    pi0_files,
    pion_files,
    data_dir + 'graph_examples/cell_geo.root',
    32,
    shuffle=False,
    num_procs=32,
    preprocess=True,
    output_dir=out_dir,
)

# gen = data_gen.generator()


Preprocessing and saving data to /p/vast1/karande1/heavyIon/data/preprocessed_data/gn4pions/geo/train/
Proceesing file number 0
Proceesing file number 1
Proceesing file number 2
Proceesing file number 3
Proceesing file number 4
Proceesing file number 5
Proceesing file number 6
Proceesing file number 7
Proceesing file number 8
Proceesing file number 9
Finished processing 38 files
Finished processing 35 files
Finished processing 32 files
Finished processing 34 files
Finished processing 41 files
Finished processing 36 files
Finished processing 33 files
Finished processing 40 files
Finished processing 37 files
Finished processing 39 files


In [None]:
# Validation split: entries 20-29 of the sorted pi0 / pion file lists,
# preprocessed into their own output directory.
out_dir = '/p/vast1/karande1/heavyIon/data/preprocessed_data/gn4pions/geo/val/'
pi0_files, pion_files = (
    np.sort(glob.glob(data_dir + 'graphs.v01-45-gaa27bcb/' + pattern))[20:30]
    for pattern in ('*pi0*/*.root', '*pion*/*.root')
)
data_gen_test = GraphDataGenerator(
    pi0_files,
    pion_files,
    data_dir + 'graph_examples/cell_geo.root',
    32,
    shuffle=False,
    num_procs=32,
    preprocess=True,
    output_dir=out_dir,
)


Preprocessing and saving data to /p/vast1/karande1/heavyIon/data/preprocessed_data/gn4pions/geo/val/
Proceesing file number 0
Proceesing file number 1
Proceesing file number 2
Proceesing file number 3
Proceesing file number 4
Proceesing file number 5
Proceesing file number 6
Proceesing file number 7
Proceesing file number 8
Proceesing file number 9
Finished processing 16 files
Finished processing 15 files
Finished processing 18 files
Finished processing 19 files
Finished processing 10 files
Finished processing 14 files
Finished processing 13 files
Finished processing 12 files
Finished processing 17 files
Finished processing 11 files
