In [1]:
import cProfile
from io import StringIO
from functools import wraps
import time
from collections import defaultdict
import pstats
from tqdm import tqdm
import torch
from pympler import asizeof
import numpy as np
import pandas as pd
from IPython.display import clear_output
import matplotlib.pyplot as plot

from time_res_util import get_compiled_NF_model
from momentum_prediction_util import load_defaultdict, SiPMSignalProcessor

Using device cuda:0


In [2]:
def profile_function(func):
    """
    Decorator to profile a specific function using cProfile
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        profiler = cProfile.Profile()
        try:
            return profiler.runcall(func, *args, **kwargs)
        finally:
            s = StringIO()
            stats = pstats.Stats(profiler, stream=s).sort_stats('cumulative')
            stats.print_stats(20)  # Print top 20 time-consuming operations
            print(s.getvalue())
    return wrapper

'''MEMORY PROFILING'''
import linecache
import os
import tracemalloc

def display_top(snapshot, key_type='lineno', limit=3):
    snapshot = snapshot.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    top_stats = snapshot.statistics(key_type)

    print("Top %s lines" % limit)
    for index, stat in enumerate(top_stats[:limit], 1):
        frame = stat.traceback[0]
        # replace "/path/to/module/file.py" with "module/file.py"
        filename = os.sep.join(frame.filename.split(os.sep)[-2:])
        print("#%s: %s:%s: %.1f KiB"
              % (index, filename, frame.lineno, stat.size / 1024))
        line = linecache.getline(frame.filename, frame.lineno).strip()
        if line:
            print('    %s' % line)

    other = top_stats[limit:]
    if other:
        size = sum(stat.size for stat in other)
        print("%s other: %.1f KiB" % (len(other), size / 1024))
    total = sum(stat.size for stat in top_stats)
    print("Total allocated size: %.1f KiB" % (total / 1024))

tracemalloc.start()

In [3]:
inputProcessedData = "./data/processed_data/jan_13_new_analyze_10events.json"
model_compile = get_compiled_NF_model()
processed_data = load_defaultdict(inputProcessedData)


  self.load_state_dict(torch.load(path))


## Current fastest

In [12]:
def newer_prepare_nn_input(processed_data = processed_data, normalizing_flow=model_compile, batch_size=50000, device='cuda',pixel_threshold = 5):
    processer = SiPMSignalProcessor()
    
    all_context = []
    all_time_pixels = []
    all_metadata = []
    num_pixel_list = ["num_pixels_high_z","num_pixels_low_z"]
    print("Processing data in new_prepare_nn_input...")
    for event_idx, event_data in tqdm(processed_data.items()):
        for stave_idx, stave_data in event_data.items():
            for layer_idx, layer_data in stave_data.items():
                for segment_idx, segment_data in layer_data.items():
                    trueID_list = []
                    for particle_id, particle_data in segment_data.items():
#                         print(f"keys of particle data: {particle_data.keys()}")
#                         print(f"types: {type(particle_data['z_pos'])},{type(particle_data['hittheta'])},{type(particle_data['hitmomentum'])}")
                        base_context = torch.tensor([particle_data['z_pos'], particle_data['hittheta'], particle_data['hitmomentum']], 
                                                    dtype=torch.float32)
                        base_time_pixels_low = torch.tensor([particle_data['time'], particle_data['num_pixels_low_z']], 
                                                        dtype=torch.float32)
                        base_time_pixels_high = torch.tensor([particle_data['time'], particle_data['num_pixels_high_z']], 
                                                        dtype=torch.float32)
                        if particle_data['trueID'] not in  trueID_list:
                            trueID_list.append(particle_data['trueID'])
                        for SiPM_idx in range(2):
                            z_pos = particle_data['z_pos']
                            context = base_context.clone()
                            context[0] = z_pos
                            num_pixel_tag = num_pixel_list[SiPM_idx]
                            all_context.append(context.repeat(particle_data[num_pixel_tag], 1))
                            if(SiPM_idx == 0):
                                all_time_pixels.append(base_time_pixels_high.repeat(particle_data[num_pixel_tag], 1))
                            else:
                                all_time_pixels.append(base_time_pixels_low.repeat(particle_data[num_pixel_tag], 1))
                            # Assuming particle_data is a dictionary-like object and trueID_list is defined
                            fields = [
                                'truemomentum', 'trueID', 'truePID', 'hitID', 'hitPID', 
                                'truetheta', 'truephi', 'strip_x', 'strip_y', 'strip_z', 
                                'hit_x', 'hit_y', 'hit_z', 'KMU_trueID', 'KMU_truePID', 
                                'KMU_true_phi', 'KMU_true_momentum_mag', 'KMU_endpoint_x', 
                                'KMU_endpoint_y', 'KMU_endpoint_z'
                            ]

                            # Print types of each particle_data field
#                             for field in fields:
#                                 value = particle_data.get(field, None)
#                                 print(f"{field}: {type(value)}")

#                             # Print the type of len(trueID_list)
#                             print(f"len(trueID_list): {type(len(trueID_list))}")

                            all_metadata.extend([(event_idx,stave_idx, layer_idx,segment_idx, SiPM_idx, particle_data['truemomentum'],particle_data['trueID'],particle_data['truePID'],particle_data['hitID'],particle_data['hitPID'],particle_data['truetheta'],particle_data['truephi'],particle_data['strip_x'],particle_data['strip_y'],particle_data['strip_z'],len(trueID_list),particle_data['hit_x'],particle_data['hit_y'],particle_data['hit_z'],particle_data['KMU_trueID'],particle_data['KMU_truePID'],particle_data['KMU_true_phi'],particle_data['KMU_true_momentum_mag'],particle_data['KMU_endpoint_x'],particle_data['KMU_endpoint_y'],particle_data['KMU_endpoint_z'])] * particle_data[num_pixel_tag])

    all_context = torch.cat(all_context)
    all_time_pixels = torch.cat(all_time_pixels)
    
    print("Sampling data...")
    sampled_data = []
    begin = time.time()
    for i in tqdm(range(0, len(all_context), batch_size)):
        batch_end = min(i + batch_size, len(all_context))
        batch_context = all_context[i:batch_end].to(device)
        batch_time_pixels = all_time_pixels[i:batch_end]
        
        with torch.no_grad():
            samples = abs(normalizing_flow.sample(num_samples=len(batch_context), context=batch_context)[0]).squeeze(1)
        
        sampled_data.extend(samples.cpu() + batch_time_pixels[:, 0])
    end = time.time()
    print(f"sampling took {end - begin} seconds")
    print("Processing signal...")
    
    
    # VARIABLES FOR SAVING DATA AS DF
    processer = SiPMSignalProcessor()
    rows = []

    seen_keys = set()
    curr_key = (-1,-1,-1,-1)

    current_samples = [[],[]] 
    processor = SiPMSignalProcessor()

    translated_trueID = 0
    trueID_dict_running_idx = 0
    trueID_dict = {}

    begin = time.time()

#     sample_idx = 0
    for (event_idx,stave_idx, layer_idx,segment_idx, SiPM_idx, momentum,trueID,truePID,hitID,hitPID,theta,phi,strip_x,strip_y,strip_z,trueID_list_len,hit_x,hit_y,hit_z,KMU_trueID,KMU_truePID,KMU_true_phi,KMU_true_momentum_mag,KMU_endpoint_x,KMU_endpoint_y,KMU_endpoint_z), sample in zip(all_metadata, sampled_data):

        # Work with all samples of one SiPM together
        key = (event_idx, stave_idx, layer_idx, segment_idx)
        
        if key in seen_keys:
            if key == curr_key:
                current_samples[SiPM_idx].append(sample)
            else:
                continue
                print(f"ERROR: key: {key} | curr_key: {curr_key}")
        # First key
        elif curr_key == (-1,-1,-1,-1):
            current_samples[SiPM_idx].append(sample)
            seen_keys.add(key)
            curr_key = key
        # End of curr_key: perform calc
        else:
            #calculate photon stuff on current_samples

            '''IMPLEMENTING PREDICTION INPUT PULSE SEGMENT BY SEGMENT'''
            curr_event_idx = curr_key[0]
            curr_stave_idx = curr_key[1]
            curr_layer_idx = curr_key[2]
            curr_segment_idx = curr_key[3]
            for curr_SiPM_idx in range(2):
                trigger = False
                photon_times = np.array(current_samples[curr_SiPM_idx]) * 10 **(-9)
                if(len(photon_times) > 0):
                    time_arr,waveform = processor.generate_waveform(photon_times)
                    timing = processer.get_pulse_timing(waveform,threshold = pixel_threshold)
                    if(timing is not None):
                        #scale inputs to avoid exploding gradients
                        curr_charge = processor.integrate_charge(waveform) * 1e6
                        curr_timing = timing * 1e8
                        trigger = True
                    #skip segments that don't pass the threshold
                    else:
                        continue
                #skip segments with no photon hits
                else:
                    continue
                if(trueID_list_len > 1):
                    translated_trueID = -1
                else:
                    if((event_idx,trueID) not in trueID_dict):
                        trueID_dict[(event_idx,trueID)] = trueID_dict_running_idx
                        trueID_dict_running_idx += 1
                    translated_trueID = trueID_dict[(event_idx,trueID)]
                new_row = {
                    "event_idx"      : curr_event_idx,
                    "stave_idx"      : curr_stave_idx,
                    "layer_idx"      : curr_layer_idx,
                    "segment_idx"    : curr_segment_idx,
                    "SiPM_idx"    : curr_SiPM_idx,
                    "trueID"         : translated_trueID,
                    "truePID"        : trueID,
                    "hitID"          : hitID,
                    "P"              : momentum,
                    "Theta"          : theta,
                    "Phi"            : phi,
                    "strip_x"        : strip_z,
                    "strip_y"        : strip_x,
                    "strip_z"        : strip_y,
                    "hit_x"          : hit_x,
                    "hit_y"          : hit_y,
                    "hit_z"          : hit_z,
                    "KMU_endpoint_x" : KMU_endpoint_x,
                    "KMU_endpoint_y" : KMU_endpoint_y,
                    "KMU_endpoint_z" : KMU_endpoint_z,
                    "Charge"         : curr_charge,
                    "Time"           : curr_timing
                }
                rows.append(new_row)
            ''' END IMPLEMENTATION '''
            #reset current samples for new key
            seen_keys.add(key)
            current_samples = [[],[]]
            current_samples.append(sample)
            curr_key = key
                


    end = time.time()
    ret_df = pd.DataFrame(rows)
    print(f"Creating DF took {end - begin} seconds")
    return ret_df

## Test: claude way

In [24]:
from itertools import groupby
from operator import itemgetter

# Create a key function that extracts the grouping fields
def get_key(item):
    metadata, _ = item
    return metadata[:4]  # event_idx, stave_idx, layer_idx, segment_idx

def test_newer_prepare_nn_input(processed_data = processed_data, normalizing_flow=model_compile, batch_size=50000, device='cuda',pixel_threshold = 5):
    processer = SiPMSignalProcessor()
    
    all_context = []
    all_time_pixels = []
    all_metadata = []
    num_pixel_list = ["num_pixels_high_z","num_pixels_low_z"]
    print("Processing data in new_prepare_nn_input...")
    for event_idx, event_data in tqdm(processed_data.items()):
        for stave_idx, stave_data in event_data.items():
            for layer_idx, layer_data in stave_data.items():
                for segment_idx, segment_data in layer_data.items():
                    trueID_list = []
                    for particle_id, particle_data in segment_data.items():
#                         print(f"keys of particle data: {particle_data.keys()}")
#                         print(f"types: {type(particle_data['z_pos'])},{type(particle_data['hittheta'])},{type(particle_data['hitmomentum'])}")
                        base_context = torch.tensor([particle_data['z_pos'], particle_data['hittheta'], particle_data['hitmomentum']], 
                                                    dtype=torch.float32)
                        base_time_pixels_low = torch.tensor([particle_data['time'], particle_data['num_pixels_low_z']], 
                                                        dtype=torch.float32)
                        base_time_pixels_high = torch.tensor([particle_data['time'], particle_data['num_pixels_high_z']], 
                                                        dtype=torch.float32)
                        if particle_data['trueID'] not in  trueID_list:
                            trueID_list.append(particle_data['trueID'])
                        for SiPM_idx in range(2):
                            z_pos = particle_data['z_pos']
                            context = base_context.clone()
                            context[0] = z_pos
                            num_pixel_tag = num_pixel_list[SiPM_idx]
                            all_context.append(context.repeat(particle_data[num_pixel_tag], 1))
                            if(SiPM_idx == 0):
                                all_time_pixels.append(base_time_pixels_high.repeat(particle_data[num_pixel_tag], 1))
                            else:
                                all_time_pixels.append(base_time_pixels_low.repeat(particle_data[num_pixel_tag], 1))
                            # Assuming particle_data is a dictionary-like object and trueID_list is defined
                            fields = [
                                'truemomentum', 'trueID', 'truePID', 'hitID', 'hitPID', 
                                'truetheta', 'truephi', 'strip_x', 'strip_y', 'strip_z', 
                                'hit_x', 'hit_y', 'hit_z', 'KMU_trueID', 'KMU_truePID', 
                                'KMU_true_phi', 'KMU_true_momentum_mag', 'KMU_endpoint_x', 
                                'KMU_endpoint_y', 'KMU_endpoint_z'
                            ]

                            # Print types of each particle_data field
#                             for field in fields:
#                                 value = particle_data.get(field, None)
#                                 print(f"{field}: {type(value)}")

#                             # Print the type of len(trueID_list)
#                             print(f"len(trueID_list): {type(len(trueID_list))}")

                            all_metadata.extend([(event_idx,stave_idx, layer_idx,segment_idx, SiPM_idx, particle_data['truemomentum'],particle_data['trueID'],particle_data['truePID'],particle_data['hitID'],particle_data['hitPID'],particle_data['truetheta'],particle_data['truephi'],particle_data['strip_x'],particle_data['strip_y'],particle_data['strip_z'],len(trueID_list),particle_data['hit_x'],particle_data['hit_y'],particle_data['hit_z'],particle_data['KMU_trueID'],particle_data['KMU_truePID'],particle_data['KMU_true_phi'],particle_data['KMU_true_momentum_mag'],particle_data['KMU_endpoint_x'],particle_data['KMU_endpoint_y'],particle_data['KMU_endpoint_z'])] * particle_data[num_pixel_tag])

    all_context = torch.cat(all_context)
    all_time_pixels = torch.cat(all_time_pixels)
    
    print("Sampling data...")
    sampled_data = []
    begin = time.time()
    for i in tqdm(range(0, len(all_context), batch_size)):
        batch_end = min(i + batch_size, len(all_context))
        batch_context = all_context[i:batch_end].to(device)
        batch_time_pixels = all_time_pixels[i:batch_end]
        
        with torch.no_grad():
            samples = abs(normalizing_flow.sample(num_samples=len(batch_context), context=batch_context)[0]).squeeze(1)
        
        sampled_data.extend(samples.cpu() + batch_time_pixels[:, 0])
    end = time.time()
    print(f"sampling took {end - begin} seconds")
    print("Processing signal...")
    processor = SiPMSignalProcessor()
    rows = []
    trueID_dict = {}
    trueID_dict_running_idx = 0
    event_first_hits = {}

    # Sort the data first (required for groupby)
    sorted_data = sorted(zip(all_metadata, sampled_data), key=get_key)

    # Process each group
    for key, group in groupby(sorted_data, key=get_key):
        event_idx, stave_idx, layer_idx, segment_idx = key

        # Initialize arrays for both SiPMs
        sipm_samples = [[], []]

        # Get the first metadata tuple for this group (they should all be the same within a group)
        first_item = next(group)
        metadata = first_item[0]
        _, _, _, _, _, momentum,trueID,truePID,hitID,hitPID,theta,phi,strip_x,strip_y,strip_z,trueID_list_len,hit_x,hit_y,hit_z,KMU_trueID,KMU_truePID,KMU_true_phi,KMU_true_momentum_mag,KMU_endpoint_x,KMU_endpoint_y,KMU_endpoint_z = metadata
        sipm_samples[first_item[0][4]].append(first_item[1])

        # Process rest of group
        for metadata, sample in group:
            sipm_idx = metadata[4]
            sipm_samples[sipm_idx].append(sample)

        # Process each SiPM's samples
        for curr_SiPM_idx in range(2):
            if not sipm_samples[curr_SiPM_idx]:
                continue

            photon_times = np.array(sipm_samples[curr_SiPM_idx]) * 10**(-9)
            time_arr, waveform = processor.generate_waveform(photon_times)
            timing = processor.get_pulse_timing(waveform, threshold=pixel_threshold)

            if timing is None:
                continue

            curr_charge = processor.integrate_charge(waveform) * 1e6
            curr_timing = timing * 1e8
            
            if event_idx not in event_first_hits or curr_timing < event_first_hits[event_idx][0]:
                event_first_hits[event_idx] = (curr_timing, strip_z, strip_x)

            # Handle trueID translation
            if trueID_list_len > 1:
                translated_trueID = -1
            else:
                event_true_key = (event_idx, trueID)
                if event_true_key not in trueID_dict:
                    trueID_dict[event_true_key] = trueID_dict_running_idx
                    trueID_dict_running_idx += 1
                translated_trueID = trueID_dict[event_true_key]

            # Create row
            rows.append({
                "event_idx": event_idx,
                "stave_idx": stave_idx,
                "layer_idx": layer_idx,
                "segment_idx": segment_idx,
                "SiPM_idx": curr_SiPM_idx,
                "trueID": translated_trueID,
                "truePID": truePID,
                "hitID": hitID,
                "P"              : momentum,
                "Theta"          : theta,
                "Phi"            : phi,
                "strip_x"        : strip_z,
                "strip_y"        : strip_x,
                "strip_z"        : strip_y,
                "hit_x"          : hit_x,
                "hit_y"          : hit_y,
                "hit_z"          : hit_z,
                "KMU_endpoint_x" : KMU_endpoint_x,
                "KMU_endpoint_y" : KMU_endpoint_y,
                "KMU_endpoint_z" : KMU_endpoint_z,
                "Charge"         : curr_charge,
                "Time"           : curr_timing
            })

    ret_df = pd.DataFrame(rows)
    
    ret_df['first_hit_time'] = ret_df['event_idx'].map(lambda x: event_first_hits[x][0])
    ret_df['first_hit_strip_z'] = ret_df['event_idx'].map(lambda x: event_first_hits[x][1])
    ret_df['first_hit_strip_x'] = ret_df['event_idx'].map(lambda x: event_first_hits[x][2])
    return ret_df
data = test_newer_prepare_nn_input()

Processing data in new_prepare_nn_input...


100%|██████████| 10/10 [00:01<00:00,  8.60it/s]


Sampling data...


100%|██████████| 21/21 [00:13<00:00,  1.52it/s]


sampling took 13.836615085601807 seconds
Processing signal...


In [23]:
from line_profiler import LineProfiler

profiler = LineProfiler()
profiler.add_function(test_newer_prepare_nn_input)
profiler.run('test_newer_prepare_nn_input()')
profiler.print_stats()
with open('profiling/Analyze_dev/test_first_hit_profile_jan_18_10events.txt', 'w') as f:
    profiler.print_stats(stream=f)

Processing data in new_prepare_nn_input...


100%|██████████| 10/10 [00:01<00:00,  6.09it/s]


Sampling data...


100%|██████████| 21/21 [00:13<00:00,  1.52it/s]


sampling took 13.783637285232544 seconds
Processing signal...
Timer unit: 1e-09 s

Total time: 36.0925 s
File: /tmp/ipykernel_4051716/562588580.py
Function: test_newer_prepare_nn_input at line 9

Line #      Hits         Time  Per Hit   % Time  Line Contents
     9                                           def test_newer_prepare_nn_input(processed_data = processed_data, normalizing_flow=model_compile, batch_size=50000, device='cuda',pixel_threshold = 5):
    10         1    2098220.0    2e+06      0.0      processer = SiPMSignalProcessor()
    11                                               
    12         1        893.0    893.0      0.0      all_context = []
    13         1        863.0    863.0      0.0      all_time_pixels = []
    14         1        664.0    664.0      0.0      all_metadata = []
    15         1       1088.0   1088.0      0.0      num_pixel_list = ["num_pixels_high_z","num_pixels_low_z"]
    16         1     172556.0 172556.0      0.0      print("Processing dat

## Testing

In [32]:
class SiPMSignalProcessor:
    def __init__(self, 
                 sampling_rate=40,  # 40 GHz sampling rate
                 tau_rise=1,       # 1 ns rise time
                 tau_fall=10,      # 50 ns fall time
                 window=200,       # 200 ns time window
                 cfd_delay=5,      # 5 ns delay for CFD
                 cfd_fraction=0.3):   # 30% fraction for CFD
        
        self.sampling_rate = sampling_rate
        self.tau_rise = tau_rise
        self.tau_fall = tau_fall
        self.window = window
        self.cfd_delay = cfd_delay
        self.cfd_fraction = cfd_fraction
        
        # Time array for single pulse shape
        self.time = np.arange(0, self.window, 1/self.sampling_rate)
        
        # Generate single pulse shape
        self.pulse_shape = self._generate_pulse_shape()
    
    def _generate_pulse_shape(self):
        """Generate normalized pulse shape for a single photon"""
        shape = (1 - np.exp(-self.time/self.tau_rise)) * np.exp(-self.time/self.tau_fall)
        return shape / np.max(shape)  # Normalize
    
    def generate_waveform(self, photon_times):
        """Generate waveform from list of photon arrival times"""
        # Initialize waveform array
        waveform = np.zeros_like(self.time)
        
        # Add pulse for each photon
        for t in photon_times:
            if 0 <= t < self.window:
                idx = int(t * self.sampling_rate)
                remaining_samples = len(self.time) - idx
                waveform[idx:] += self.pulse_shape[:remaining_samples]
        
        return self.time, waveform
    
    def integrate_charge(self, waveform, integration_start=0, integration_time=100):
        """Integrate charge in specified time window"""
        start_idx = int(integration_start * self.sampling_rate)
        end_idx = int((integration_start + integration_time) * self.sampling_rate)
        
        # Integrate using trapezoidal rule
        charge = np.trapezoid(waveform[start_idx:end_idx], dx=1/self.sampling_rate)
        return charge
    def constant_threshold_timing(self,waveform,threshold):
        for i in range(len(self.time)):
            if(waveform[i] > threshold):
                return self.time[i]
        return -1
        
    def apply_cfd(self, waveform, use_interpolation=True):
        """Apply Constant Fraction Discrimination to the waveform.

        Parameters:
        -----------
        waveform : numpy.ndarray
            Input waveform to process
        use_interpolation : bool, optional
            If True, use linear interpolation for sub-sample precision
            If False, return the sample index of zero crossing
            Default is True

        Returns:
        --------
        tuple (numpy.ndarray, float)
            CFD processed waveform and the zero-crossing time in seconds.
            If use_interpolation is False, zero-crossing time will be aligned
            to sample boundaries.
        """
        # Calculate delay in samples
        delay_samples = int(self.cfd_delay * self.sampling_rate)

        # Create delayed and attenuated versions of the waveform
        delayed_waveform = np.pad(waveform, (delay_samples, 0))[:-delay_samples]
        attenuated_waveform = -self.cfd_fraction * waveform

        # Calculate CFD waveform
        cfd_waveform = delayed_waveform + attenuated_waveform

        # Find all zero crossings
        zero_crossings = np.where(np.diff(np.signbit(cfd_waveform)))[0]

        if len(zero_crossings) < 2:  # Need at least two crossings for valid CFD
            return cfd_waveform, None

        # Find the rising edge of the original pulse
        pulse_start = np.where(waveform > np.max(waveform) * 0.1)[0]  # 10% threshold
        if len(pulse_start) == 0:
            return cfd_waveform, None
        pulse_start = pulse_start[0]

        # Find the first zero crossing that occurs after the pulse starts
        valid_crossings = zero_crossings[zero_crossings > pulse_start]
        if len(valid_crossings) == 0:
            return cfd_waveform, None

        crossing_idx = valid_crossings[0]

        if not use_interpolation:
            # Simply return the sample index converted to time
            crossing_time = crossing_idx / self.sampling_rate
        else:
            # Use linear interpolation for sub-sample precision
            y1 = cfd_waveform[crossing_idx]
            y2 = cfd_waveform[crossing_idx + 1]

            # Calculate fractional position of zero crossing
            fraction = -y1 / (y2 - y1)

            # Calculate precise crossing time
            crossing_time = (crossing_idx + fraction) / self.sampling_rate

        return cfd_waveform, crossing_time


    def get_pulse_timing(self, waveform, threshold=0.1):
        """Get pulse timing using CFD method with additional validation.
        
        Parameters:
        -----------
        waveform : numpy.ndarray
            Input waveform to analyze
        threshold : float
            Minimum amplitude threshold for valid pulses (relative to max amplitude)
            
        Returns:
        --------
        float or None
            Timestamp of the pulse in seconds, or None if no valid pulse found
        """
        # Check if pulse amplitude exceeds threshold
        max_amplitude = np.max(waveform)
        if max_amplitude < threshold:
            return None
            
        # Apply CFD
        _, crossing_time = self.apply_cfd(waveform)
        
        return crossing_time

In [94]:
processed_data = processed_data
normalizing_flow=model_compile
batch_size=50000
device='cuda'
pixel_threshold = 5
processer = SiPMSignalProcessor()

pixel_dict = {}

all_context = []
all_time_pixels = []
all_metadata = []
num_pixel_list = ["num_pixels_high_z","num_pixels_low_z"]
print("Preparing input for NF")
for event_idx, event_data in tqdm(processed_data.items()):
    for stave_idx, stave_data in event_data.items():
        for layer_idx, layer_data in stave_data.items():
            for segment_idx, segment_data in layer_data.items():
                trueID_list = []
                for particle_id, particle_data in segment_data.items():
#                         print(f"keys of particle data: {particle_data.keys()}")
#                         print(f"types: {type(particle_data['z_pos'])},{type(particle_data['hittheta'])},{type(particle_data['hitmomentum'])}")
                    base_context = torch.tensor([particle_data['z_pos'], particle_data['hittheta'], particle_data['hitmomentum']], 
                                                dtype=torch.float32)
                    base_time_pixels_low = torch.tensor([particle_data['time'], particle_data['num_pixels_low_z']], 
                                                    dtype=torch.float32)
                    base_time_pixels_high = torch.tensor([particle_data['time'], particle_data['num_pixels_high_z']], 
                                                    dtype=torch.float32)
                    if particle_data['trueID'] not in  trueID_list:
                        trueID_list.append(particle_data['trueID'])
                    for SiPM_idx in range(2):
                        z_pos = particle_data['z_pos']
                        context = base_context.clone()
                        context[0] = z_pos
                        num_pixel_tag = num_pixel_list[SiPM_idx]
                        all_context.append(context.repeat(particle_data[num_pixel_tag], 1))
                        if(SiPM_idx == 0):
                            all_time_pixels.append(base_time_pixels_high.repeat(particle_data[num_pixel_tag], 1))
                        else:
                            all_time_pixels.append(base_time_pixels_low.repeat(particle_data[num_pixel_tag], 1))
                        # Assuming particle_data is a dictionary-like object and trueID_list is defined
                        fields = [
                            'truemomentum', 'trueID', 'truePID', 'hitID', 'hitPID', 
                            'truetheta', 'truephi', 'strip_x', 'strip_y', 'strip_z', 
                            'hit_x', 'hit_y', 'hit_z', 'KMU_trueID', 'KMU_truePID', 
                            'KMU_true_phi', 'KMU_true_momentum_mag', 'KMU_endpoint_x', 
                            'KMU_endpoint_y', 'KMU_endpoint_z'
                        ]

                        all_metadata.extend([(event_idx,stave_idx, layer_idx,segment_idx, SiPM_idx, particle_data['truemomentum'],particle_data['trueID'],particle_data['truePID'],particle_data['hitID'],particle_data['hitPID'],particle_data['truetheta'],particle_data['truephi'],particle_data['strip_x'],particle_data['strip_y'],particle_data['strip_z'],len(trueID_list),particle_data['hit_x'],particle_data['hit_y'],particle_data['hit_z'],particle_data['KMU_trueID'],particle_data['KMU_truePID'],particle_data['KMU_true_phi'],particle_data['KMU_true_momentum_mag'],particle_data['KMU_endpoint_x'],particle_data['KMU_endpoint_y'],particle_data['KMU_endpoint_z'])] * particle_data[num_pixel_tag])
                        particle_key = (event_idx,stave_idx,layer_idx,segment_idx)
                        if(particle_key in pixel_dict):
                            pixel_dict[particle_key][0] +=particle_data["num_pixels_high_z"]
                            pixel_dict[particle_key][1] +=particle_data["num_pixels_low_z"]
                        else:
                            pixel_dict[particle_key] =[particle_data["num_pixels_high_z"],particle_data["num_pixels_low_z"]]
all_context = torch.cat(all_context)
all_time_pixels = torch.cat(all_time_pixels)

print("Sampling data...")
sampled_data = []
begin = time.time()
for i in tqdm(range(0, len(all_context), batch_size)):
    batch_end = min(i + batch_size, len(all_context))
    batch_context = all_context[i:batch_end].to(device)
    batch_time_pixels = all_time_pixels[i:batch_end]

    with torch.no_grad():
        samples = abs(normalizing_flow.sample(num_samples=len(batch_context), context=batch_context)[0]).squeeze(1)

    sampled_data.extend(samples.cpu() + batch_time_pixels[:, 0])
end = time.time()
print(f"sampling took {end - begin} seconds")
print("Processing signal...")

Preparing input for NF


100%|██████████| 10/10 [00:01<00:00,  6.92it/s]


Sampling data...


100%|██████████| 21/21 [00:15<00:00,  1.39it/s]

sampling took 15.160471200942993 seconds
Processing signal...





In [99]:
# VARIABLES FOR SAVING DATA AS DF
processer = SiPMSignalProcessor()
rows = []

seen_keys = set()
curr_key = (-1,-1,-1,-1)

pixel_counter = np.zeros(2,dtype=int)
processor = SiPMSignalProcessor()

translated_trueID = 0
trueID_dict_running_idx = 0
trueID_dict = {}

begin = time.time()

#     sample_idx = 0
for (event_idx,stave_idx, layer_idx,segment_idx, SiPM_idx, momentum,trueID,truePID,hitID,hitPID,theta,phi,strip_x,strip_y,strip_z,trueID_list_len,hit_x,hit_y,hit_z,KMU_trueID,KMU_truePID,KMU_true_phi,KMU_true_momentum_mag,KMU_endpoint_x,KMU_endpoint_y,KMU_endpoint_z), sample in zip(all_metadata, sampled_data):

    #progress bar
#         floor_percent = int(np.floor(len(sampled_data) / 100))
#         if(sample_idx % floor_percent == 0):
#             curr_time = time.time()
#             print(f"Signal Processing is now {int(np.floor(sample_idx / len(sampled_data) * 100))}% complete (time elapsed: {curr_time - begin})")
#             clear_output(wait = True)
#         sample_idx += 1

    # Work with all samples of one SiPM together
    key = (event_idx, stave_idx, layer_idx, segment_idx)
    if key in seen_keys:
        if key == curr_key:
            current_samples[SiPM_idx][pixel_counter[SiPM_idx]] = sample
            pixel_counter[SiPM_idx] = pixel_counter[SiPM_idx] + 1
        else:
            continue
            print(f"ERROR: key: {key} | curr_key: {curr_key}")
    # First key
    elif curr_key == (-1,-1,-1,-1):
        current_samples = [np.empty(pixel_dict[key][0]),np.empty(pixel_dict[key][1])]
        current_samples[SiPM_idx][pixel_counter[SiPM_idx]] = sample
        pixel_counter[SiPM_idx] = pixel_counter[SiPM_idx] + 1
        seen_keys.add(key)
        curr_key = key
    # End of curr_key: perform calc
    else:
        #calculate photon stuff on current_samples

        '''IMPLEMENTING PREDICTION INPUT PULSE SEGMENT BY SEGMENT'''
        curr_event_idx = curr_key[0]
        curr_stave_idx = curr_key[1]
        curr_layer_idx = curr_key[2]
        curr_segment_idx = curr_key[3]
        for curr_SiPM_idx in range(2):
            trigger = False
            photon_times_not_np = current_samples[curr_SiPM_idx]
            photon_times = np.array(photon_times_not_np)
            if(len(photon_times) > 0):
                time_arr,waveform = processor.generate_waveform(photon_times)
                timing = processer.get_pulse_timing(waveform,threshold = pixel_threshold)
                if(timing is not None):
                    #scale inputs to avoid exploding gradients
                    curr_charge = processor.integrate_charge(waveform) / 100
                    curr_timing = timing /10
                    trigger = True
                #skip segments that don't pass the threshold
                else:
                    continue
            #skip segments with no photon hits
            else:
                continue
            if(trueID_list_len > 1):
                translated_trueID = -1
            else:
                if((event_idx,trueID) not in trueID_dict):
                    trueID_dict[(event_idx,trueID)] = trueID_dict_running_idx
                    trueID_dict_running_idx += 1
                translated_trueID = trueID_dict[(event_idx,trueID)]
            new_row = {
                "event_idx"      : curr_event_idx,
                "stave_idx"      : curr_stave_idx,
                "layer_idx"      : curr_layer_idx,
                "segment_idx"    : curr_segment_idx,
                "SiPM_idx"    : curr_SiPM_idx,
                "trueID"         : translated_trueID,
                "truePID"        : trueID,
                "hitID"          : hitID,
                "P"              : momentum,
                "Theta"          : theta,
                "Phi"            : phi,
                "strip_x"        : strip_z,
                "strip_y"        : strip_x,
                "strip_z"        : strip_y,
                "hit_x"          : hit_x,
                "hit_y"          : hit_y,
                "hit_z"          : hit_z,
                "KMU_endpoint_x" : KMU_endpoint_x,
                "KMU_endpoint_y" : KMU_endpoint_y,
                "KMU_endpoint_z" : KMU_endpoint_z,
                "Charge"         : curr_charge,
                "Time"           : curr_timing
            }
            rows.append(new_row)
        ''' END IMPLEMENTATION '''
        #reset current samples for new key
        seen_keys.add(key)
        pixel_counter = pixel_counter = np.zeros(2,dtype=int)
        current_samples = [np.empty(pixel_dict[key][0]),np.empty(pixel_dict[key][1])]
        current_samples[SiPM_idx][pixel_counter[SiPM_idx]] = sample
        pixel_counter[SiPM_idx] = pixel_counter[SiPM_idx] + 1
        curr_key = key


end = time.time()
ret_df = pd.DataFrame(rows)
print(f"Creating DF took {end - begin} seconds")

Creating DF took 19.41303563117981 seconds


In [97]:
current_samples[SiPM_idx]

array([4.64141715e-310, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
       0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
       0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
       0.00000000e+000, 6.95310027e-310, 4.94065646e-324, 4.94065646e-324,
       6.90719599e-310, 6.95310027e-310, 1.48219694e-323, 6.32404027e-322,
       4.94065646e-324, 6.90719585e-310, 4.94065646e-324, 0.00000000e+000,
       0.00000000e+000, 6.95310027e-310, 0.00000000e+000, 6.90719585e-310,
       6.95310027e-310, 6.90709997e-310, 2.12199579e-314, 6.95310027e-310,
       2.12199585e-314, 6.95293141e-310, 2.12199579e-314, 6.90709630e-310,
       6.95310027e-310, 6.95310027e-310, 7.63918485e-313, 4.94065646e-323,
       0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 0.00000000e+000,
       6.90697860e-310, 1.40107117e+001, 1.71535511e+001, 1.56047277e+001,
       1.48282042e+001, 1.63265953e+001, 1.33217020e+001, 1.87257118e+001,
       2.18567371e+001, 1

In [98]:
pixel_counter

array([32, 66])

In [82]:
pixel_counter[SiPM_idx]

np.int64(1)

In [83]:
key

('0', '1', '2', '32')

In [84]:
SiPM_idx

1

In [88]:
pixel_dict[('0','1','2','32')]

[0,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 2,
 1,
 4,
 1,
 4,
 1,
 2,
 1,
 2,
 0,
 0,
 0,
 0,
 1,
 4,
 1,
 4,
 3,
 7,
 3,
 7,
 2,
 5,
 2,
 5,
 1,
 3,
 1,
 3,
 0,
 0,
 0,
 0,
 1,
 5,
 1,
 5,
 0,
 0,
 0,
 0,
 1,
 6,
 1,
 6,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 3,
 1,
 3]

In [87]:
event_idx

'0'