In [1]:
import cProfile
from io import StringIO
from functools import wraps
import time
from collections import defaultdict
import pstats
from tqdm import tqdm
import torch
from pympler import asizeof
import numpy as np
import pandas as pd
from IPython.display import clear_output

from time_res_util import get_compiled_NF_model
from momentum_prediction_util import load_defaultdict, SiPMSignalProcessor

Using device cuda:0


In [3]:
def profile_function(func):
    """
    Wrap ``func`` so that every call runs under cProfile.

    When the call finishes (normally or by raising), the 20 entries with the
    largest cumulative time are printed. The wrapped call's return value, or
    its exception, is passed through unchanged.
    """
    @wraps(func)
    def profiled(*args, **kwargs):
        prof = cProfile.Profile()
        try:
            result = prof.runcall(func, *args, **kwargs)
        finally:
            # Report even when func raised, so slow failures can be inspected too.
            report = StringIO()
            pstats.Stats(prof, stream=report).sort_stats('cumulative').print_stats(20)
            print(report.getvalue())
        return result
    return profiled

# ---- MEMORY PROFILING ----
import linecache
import os
import tracemalloc

def display_top(snapshot, key_type='lineno', limit=3):
    """
    Print the ``limit`` largest allocation sites of a tracemalloc snapshot.

    Frames from ``<frozen importlib._bootstrap>`` and ``<unknown>`` are filtered
    out first. Each listed site shows the last two components of its file path,
    its line number, its size in KiB, and its source line (when available).
    These are followed by one aggregate line for the remaining sites and the
    total over all filtered sites.
    """
    ignored = (
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    )
    stats = snapshot.filter_traces(ignored).statistics(key_type)
    shown, rest = stats[:limit], stats[limit:]

    print(f"Top {limit} lines")
    for rank, stat in enumerate(shown, start=1):
        frame = stat.traceback[0]
        # shorten "/path/to/module/file.py" to "module/file.py"
        short_name = os.sep.join(frame.filename.split(os.sep)[-2:])
        print(f"#{rank}: {short_name}:{frame.lineno}: {stat.size / 1024:.1f} KiB")
        source = linecache.getline(frame.filename, frame.lineno).strip()
        if source:
            print(f"    {source}")

    if rest:
        rest_size = sum(s.size for s in rest)
        print(f"{len(rest)} other: {rest_size / 1024:.1f} KiB")
    total_size = sum(s.size for s in stats)
    print(f"Total allocated size: {total_size / 1024:.1f} KiB")

# Start allocation tracing with 1 frame per traceback (the tracemalloc default),
# so that snapshots can be passed to display_top.
# NOTE(review): nothing in the visible cells calls tracemalloc.stop(), so tracing
# stays on for the rest of the session. That adds overhead to every later cell,
# including the timed sampling and signal-processing loops.
tracemalloc.start(1)

In [4]:
# Load the compiled normalizing-flow model and the nested processed hit data,
# which the next cell iterates as event -> stave -> layer -> segment -> particle.
# NOTE(review): the saved output of this cell shows a warning pointing at
# `self.load_state_dict(torch.load(path))` inside the model loader; review that
# torch.load call.
inputProcessedData = "./data/processed_data/jan_13_new_analyze_10events.json"
model_compile = get_compiled_NF_model()
processed_data = load_defaultdict(inputProcessedData)


  self.load_state_dict(torch.load(path))


In [4]:
# Flatten the processed hits into per-pixel normalizing-flow inputs, then draw
# one flow sample per row in GPU batches.
normalizing_flow = model_compile
batch_size = 50000
device = 'cuda'
pixel_threshold = 5  # pulse-timing threshold used by the signal-processing cell

# SiPM 0 reads 'num_pixels_high_z', SiPM 1 reads 'num_pixels_low_z'.
num_pixel_list = ["num_pixels_high_z", "num_pixels_low_z"]

# Truth fields copied into every metadata tuple, in the order the
# signal-processing cell unpacks them. The segment's distinct-trueID count
# sits between the two groups.
TRUTH_FIELDS_BEFORE_COUNT = (
    'truemomentum', 'trueID', 'truePID', 'hitID', 'hitPID',
    'truetheta', 'truephi', 'strip_x', 'strip_y', 'strip_z',
)
TRUTH_FIELDS_AFTER_COUNT = (
    'hit_x', 'hit_y', 'hit_z', 'KMU_trueID', 'KMU_truePID',
    'KMU_true_phi', 'KMU_true_momentum_mag',
    'KMU_endpoint_x', 'KMU_endpoint_y', 'KMU_endpoint_z',
)


def build_flow_inputs(processed_data):
    """
    Expand the nested event/stave/layer/segment/particle dict into flat rows.

    For every particle and each SiPM index (0, 1), the particle contributes
    ``particle_data[num_pixel_list[SiPM_idx]]`` identical rows to:
      * context     -- float32 tensor (N, 3): [z_pos, hittheta, hitmomentum]
      * time_pixels -- float32 tensor (N, 2): [time, pixel count of that SiPM]
      * metadata    -- list of N 26-element tuples:
        (event_idx, stave_idx, layer_idx, segment_idx, SiPM_idx,
         *TRUTH_FIELDS_BEFORE_COUNT, n_distinct_trueIDs, *TRUTH_FIELDS_AFTER_COUNT)

    n_distinct_trueIDs counts every particle of the segment. The previous
    version stored the running count, which labelled the early particles of a
    multi-particle segment as single-particle.

    Returns (0, 3) / (0, 2) tensors and an empty list when there are no particles.
    """
    context_chunks, time_pixel_chunks, metadata = [], [], []
    for event_idx, event_data in tqdm(processed_data.items()):
        for stave_idx, stave_data in event_data.items():
            for layer_idx, layer_data in stave_data.items():
                for segment_idx, segment_data in layer_data.items():
                    n_true_ids = len({p['trueID'] for p in segment_data.values()})
                    for particle_data in segment_data.values():
                        context = torch.tensor(
                            [particle_data['z_pos'], particle_data['hittheta'], particle_data['hitmomentum']],
                            dtype=torch.float32)
                        truth = (tuple(particle_data[f] for f in TRUTH_FIELDS_BEFORE_COUNT)
                                 + (n_true_ids,)
                                 + tuple(particle_data[f] for f in TRUTH_FIELDS_AFTER_COUNT))
                        for SiPM_idx, num_pixel_tag in enumerate(num_pixel_list):
                            n_rows = particle_data[num_pixel_tag]
                            time_pixels = torch.tensor([particle_data['time'], n_rows], dtype=torch.float32)
                            context_chunks.append(context.repeat(n_rows, 1))
                            time_pixel_chunks.append(time_pixels.repeat(n_rows, 1))
                            row_meta = (event_idx, stave_idx, layer_idx, segment_idx, SiPM_idx) + truth
                            metadata.extend([row_meta] * n_rows)
    if not context_chunks:
        return torch.empty((0, 3)), torch.empty((0, 2)), metadata
    return torch.cat(context_chunks), torch.cat(time_pixel_chunks), metadata


def sample_flow_times(flow, context, time_pixels, batch_size=50000, device='cuda'):
    """
    Draw one flow sample per context row and shift it by that row's time.

    Each batch of ``context`` rows is moved to ``device`` and sampled under
    ``torch.no_grad()``. The per-row result is ``|sample| + time_pixels[:, 0]``,
    gathered on the CPU. Returns a 1-D CPU tensor of length ``len(context)``
    (empty when ``context`` is empty). Iterating it yields 0-d tensors, the
    same elements the previous list-based version held.
    """
    sampled_batches = []
    for start in tqdm(range(0, len(context), batch_size)):
        stop = start + batch_size
        batch_context = context[start:stop].to(device)
        with torch.no_grad():
            # assumes flow.sample(...)[0] has shape (B, 1) -- TODO confirm
            samples = flow.sample(num_samples=len(batch_context), context=batch_context)[0]
        sampled_batches.append(samples.abs().squeeze(1).cpu() + time_pixels[start:stop, 0])
    return torch.cat(sampled_batches) if sampled_batches else torch.empty(0)


print("Processing data in new_prepare_nn_input...")
all_context, all_time_pixels, all_metadata = build_flow_inputs(processed_data)

print("Sampling data...")
begin = time.time()
sampled_data = sample_flow_times(normalizing_flow, all_context, all_time_pixels,
                                 batch_size=batch_size, device=device)
end = time.time()
print(f"sampling took {end - begin} seconds")
print("Processing signal...")

Processing data in new_prepare_nn_input...


100%|██████████| 9/9 [00:01<00:00,  7.59it/s]


Sampling data...


100%|██████████| 24/24 [00:16<00:00,  1.44it/s]

sampling took 16.723486185073853 seconds
Processing signal...





In [15]:
# VARIABLES FOR SAVING DATA AS DF
# Group the sampled times by segment key (event, stave, layer, segment), build
# one waveform per SiPM, and keep a row for each SiPM whose pulse timing is found
# at `pixel_threshold` (defined in the sampling cell).
processor = SiPMSignalProcessor()
processer = processor  # kept for compatibility; the original cell built a second, identical processor

rows = []
trueID_dict = {}  # (event_idx, trueID) -> compact index, assigned in first-seen order


def segment_rows(key, meta, sipm_samples):
    """
    Build the DataFrame rows for one finished segment.

    key          -- (event_idx, stave_idx, layer_idx, segment_idx)
    meta         -- a 26-element metadata tuple from this segment (layout as
                    produced by the sampling cell); its truth fields label every row
    sipm_samples -- [sampled times for SiPM 0, sampled times for SiPM 1]

    SiPMs with no samples, or whose pulse timing is None, produce no row.
    Segments with more than one distinct trueID get trueID -1.
    """
    (_event, _stave, _layer, _segment, _sipm, momentum, trueID, truePID, hitID,
     _hitPID, theta, phi, strip_x, strip_y, strip_z, trueID_list_len,
     hit_x, hit_y, hit_z, _KMU_trueID, _KMU_truePID, _KMU_true_phi,
     _KMU_true_momentum_mag, KMU_endpoint_x, KMU_endpoint_y, KMU_endpoint_z) = meta
    event_idx, stave_idx, layer_idx, segment_idx = key

    new_rows = []
    for SiPM_idx, samples in enumerate(sipm_samples):
        if not samples:
            continue  # skip SiPMs with no photon hits
        # presumably converts ns -> s for the waveform generator; TODO confirm units
        photon_times = np.array(samples) * 10 ** (-9)
        _time_arr, waveform = processor.generate_waveform(photon_times)
        timing = processor.get_pulse_timing(waveform, threshold=pixel_threshold)
        if timing is None:
            continue  # skip SiPMs whose pulse does not pass the threshold
        # scale inputs to avoid exploding gradients
        charge = processor.integrate_charge(waveform) * 1e6
        scaled_timing = timing * 1e8
        if trueID_list_len > 1:
            translated_trueID = -1  # more than one distinct trueID in this segment
        else:
            # len(trueID_dict) is the next free index, i.e. a running counter
            translated_trueID = trueID_dict.setdefault((event_idx, trueID), len(trueID_dict))
        new_rows.append({
            "event_idx"      : event_idx,
            "stave_idx"      : stave_idx,
            "layer_idx"      : layer_idx,
            "segment_idx"    : segment_idx,
            "SiPM_idx"       : SiPM_idx,
            "trueID"         : translated_trueID,
            "truePID"        : truePID,
            "hitID"          : hitID,
            "P"              : momentum,
            "Theta"          : theta,
            "Phi"            : phi,
            # NOTE(review): strip coordinates are stored rotated (x<-z, y<-x, z<-y)
            # exactly as in the original cell; confirm this axis convention is intended.
            "strip_x"        : strip_z,
            "strip_y"        : strip_x,
            "strip_z"        : strip_y,
            "hit_x"          : hit_x,
            "hit_y"          : hit_y,
            "hit_z"          : hit_z,
            "KMU_endpoint_x" : KMU_endpoint_x,
            "KMU_endpoint_y" : KMU_endpoint_y,
            "KMU_endpoint_z" : KMU_endpoint_z,
            "Charge"         : charge,
            "Time"           : scaled_timing,
        })
    return new_rows


begin = time.time()

seen_keys = set()           # segment keys already opened (set: O(1) membership)
curr_key = None             # key of the segment being accumulated
curr_meta = None            # latest metadata tuple seen for curr_key
current_samples = [[], []]  # sampled times per SiPM for curr_key
skipped_samples = 0         # samples whose segment was already closed

for meta, sample in tqdm(zip(all_metadata, sampled_data), total=len(sampled_data), desc="Signal processing"):
    key = meta[:4]
    if key != curr_key:
        if key in seen_keys:
            # A closed segment reappearing means the metadata is not grouped by
            # segment; such samples are dropped, as in the original cell.
            skipped_samples += 1
            continue
        if curr_key is not None:
            rows.extend(segment_rows(curr_key, curr_meta, current_samples))
        seen_keys.add(key)
        curr_key = key
        current_samples = [[], []]
    current_samples[meta[4]].append(sample)
    # Keep the latest tuple: if the sampling cell stored a running trueID count,
    # the last tuple of a segment carries the count over all its particles.
    curr_meta = meta

# Segments are flushed when the next one starts, so the last one needs an explicit flush.
if curr_key is not None:
    rows.extend(segment_rows(curr_key, curr_meta, current_samples))
if skipped_samples:
    print(f"WARNING: skipped {skipped_samples} samples belonging to already-processed segments")

trueID_dict_running_idx = len(trueID_dict)  # next free translated trueID (kept for compatibility)
end = time.time()
ret_df = pd.DataFrame(rows)
print(f"Creating DF took {end - begin} seconds")

Creating DF took 36.81203651428223 seconds


In [None]:
from line_profiler import LineProfiler

# Line-profile one waveform-generation call, the inner step of the
# signal-processing loop. The original cell profiled `my_function`, which is not
# defined anywhere in this notebook and raised NameError.
# The input is the first 1000 sampled times (not necessarily a single segment),
# scaled the same way as in the signal-processing cell.
profiler = LineProfiler()
profiler.add_function(SiPMSignalProcessor.generate_waveform)
profile_photon_times = np.array(sampled_data[:1000]) * 10 ** (-9)
profiler.runcall(SiPMSignalProcessor().generate_waveform, profile_photon_times)
profiler.print_stats()