From e8b03fc2cc3f9c7baf833f3fe42ae265d1807d15 Mon Sep 17 00:00:00 2001
From: Dominick Cichon
Date: Mon, 17 Oct 2022 14:51:16 +0200
Subject: [PATCH] Add capability for building summed waveform over channel
 subset (#565)

* Added data fields for top/bottom array waveforms.

* Added data_top and data_bot fields to peaks. Added capability of
  building separate waveforms (separation according to channel index).
  Added test for new feature. TODO: Remove data_bot when all tests are
  finished.

* Switched to boolean flags in store_downsampled_waveform because numba
  does not like hashmaps with dynamic instead of static keys.

* Refactored store_downsampled_waveform to downsample summed and
  separate array waveforms at the same time (similar to what Peter did)
  to avoid issues due to changing peak['length'].

* Removed the bottom array waveform, leaving only the top one.

* Fixed test.

* Added digitize_top argument.

Co-authored-by: Joran R. Angevaare
Co-authored-by: Andrii Terliuk
Co-authored-by: Daniel Wenz <43881800+WenzDaniel@users.noreply.github.com>
Co-authored-by: Joran R. Angevaare
Co-authored-by: Joran Angevaare
---
 strax/dtypes.py                    |  9 +++++--
 strax/processing/peak_building.py  | 39 ++++++++++++++++++++++++------
 strax/processing/peak_merging.py   | 34 ++++++++++++++++++++------
 strax/processing/peak_splitting.py | 14 +++++++----
 tests/test_peak_processing.py      |  9 +++++--
 5 files changed, 81 insertions(+), 24 deletions(-)

diff --git a/strax/dtypes.py b/strax/dtypes.py
index 6e049aa37..240e7189b 100644
--- a/strax/dtypes.py
+++ b/strax/dtypes.py
@@ -175,14 +175,14 @@ def hitlet_with_data_dtype(n_samples=2):
     return dtype + additional_fields
 
 
-def peak_dtype(n_channels=100, n_sum_wv_samples=200, n_widths=11):
+def peak_dtype(n_channels=100, n_sum_wv_samples=200, n_widths=11, digitize_top=True):
     """Data type for peaks - ranges across all channels in a detector
     Remember to set channel to -1 (todo: make enum)
     """
     if n_channels == 1:
         raise ValueError("Must have more than one channel")
         # Otherwise array changes shape?? badness ensues
 
-    return peak_interval_dtype + [
+    dtype = peak_interval_dtype + [
         # For peaklets this is likely to be overwritten:
         (('Classification of the peak(let)', 'type'), np.int8),
@@ -209,6 +209,11 @@ def peak_dtype(n_channels=100, n_sum_wv_samples=200, n_widths=11):
         (('Maximum interior goodness of split', 'max_goodness_of_split'),
          np.float32),
     ]
+    if digitize_top:
+        top_field = (('Waveform data in PE/sample (not PE/ns!), top array',
+                      'data_top'), np.float32, n_sum_wv_samples)
+        dtype.insert(5, top_field)
+    return dtype
 
 
 def copy_to_buffer(source: np.ndarray,
diff --git a/strax/processing/peak_building.py b/strax/processing/peak_building.py
index e4518f155..2b1a3a53d 100644
--- a/strax/processing/peak_building.py
+++ b/strax/processing/peak_building.py
@@ -126,39 +126,50 @@ def find_peaks(hits, adc_to_pe,
 
 @export
 @numba.jit(nopython=True, nogil=True, cache=True)
-def store_downsampled_waveform(p, wv_buffer):
-    """Downsample the waveform in buffer and store it in p['data']
+def store_downsampled_waveform(p, wv_buffer, store_in_data_top=False,
+                               wv_buffer_top=np.ones(1, dtype=np.float32)):
+    """Downsample the waveform in buffer and store it in p['data'] and
+    in p['data_top'] if indicated to do so.
 
     :param p: Row of a strax peak array, or compatible type.
     Note that p['dt'] is adjusted to match the downsampling.
     :param wv_buffer: numpy array containing sum waveform during the peak
     at the input peak's sampling resolution p['dt'].
-
-    The number of samples to take from wv_buffer, and thus the downsampling
-    factor, is determined from p['dt'] and p['length'].
+    :param store_in_data_top: Boolean which indicates whether to also store
+        into p['data_top'].
 
     When downsampling results in a fractional number of samples, the peak is
     shortened rather than extended. This causes data loss, but it is
     necessary to prevent overlaps between peaks.
     """
+    n_samples = len(p['data'])
+
     downsample_factor = int(np.ceil(p['length'] / n_samples))
     if downsample_factor > 1:
         # Compute peak length after downsampling.
         # Do not ceil: see docstring!
         p['length'] = int(np.floor(p['length'] / downsample_factor))
+        if store_in_data_top:
+            p['data_top'][:p['length']] = \
+                wv_buffer_top[:p['length'] * downsample_factor] \
+                .reshape(-1, downsample_factor) \
+                .sum(axis=1)
         p['data'][:p['length']] = \
             wv_buffer[:p['length'] * downsample_factor] \
             .reshape(-1, downsample_factor) \
             .sum(axis=1)
         p['dt'] *= downsample_factor
     else:
+        if store_in_data_top:
+            p['data_top'][:p['length']] = wv_buffer_top[:p['length']]
         p['data'][:p['length']] = wv_buffer[:p['length']]
 
 
 @export
 @numba.jit(nopython=True, nogil=True, cache=True)
-def sum_waveform(peaks, hits, records, record_links, adc_to_pe, select_peaks_indices=None):
+def sum_waveform(peaks, hits, records, record_links, adc_to_pe, n_top_channels=0,
+                 select_peaks_indices=None):
     """Compute sum waveforms for all peaks in peaks.
     Only builds summed waveform over regions in which hits were found.
     This is required to avoid any bias due to zero-padding and baselining.
@@ -169,6 +180,7 @@ def sum_waveform(peaks, hits, records, record_links, adc_to_pe, select_peaks_ind
         to record_i.
     :param records: Records to be used to build peaks.
     :param record_links: Tuple of previous and next records.
+    :param n_top_channels: Number of top array channels.
     :param select_peaks_indices: Indices of the peaks for partial processing.
         In the form of np.array([np.int, np.int, ..]). If None (default),
         all the peaks are used for the summation.
@@ -191,6 +203,9 @@ def sum_waveform(peaks, hits, records, record_links, adc_to_pe, select_peaks_ind
 
     # Need a little more even for downsampling..
     swv_buffer = np.zeros(peaks['length'].max() * 2, dtype=np.float32)
 
+    if n_top_channels > 0:
+        twv_buffer = np.zeros(peaks['length'].max() * 2, dtype=np.float32)
+
     n_channels = len(peaks[0]['area_per_channel'])
     area_per_channel = np.zeros(n_channels, dtype=np.float32)
 
@@ -206,6 +221,9 @@ def sum_waveform(peaks, hits, records, record_links, adc_to_pe, select_peaks_ind
         p_length = p['length']
         swv_buffer[:min(2 * p_length, len(swv_buffer))] = 0
 
+        if n_top_channels > 0:
+            twv_buffer[:min(2 * p_length, len(twv_buffer))] = 0
+
         # Clear area and area per channel
         # (in case find_peaks already populated them)
         area_per_channel *= 0
@@ -272,11 +290,18 @@ def sum_waveform(peaks, hits, records, record_links, adc_to_pe, select_peaks_ind
                 hit_data *= adc_to_pe[ch]
                 swv_buffer[p_start:p_end] += hit_data
 
+                if n_top_channels > 0:
+                    if ch < n_top_channels:
+                        twv_buffer[p_start:p_end] += hit_data
+
                 area_pe = hit_data.sum()
                 area_per_channel[ch] += area_pe
                 p['area'] += area_pe
 
-        store_downsampled_waveform(p, swv_buffer)
+        if n_top_channels > 0:
+            store_downsampled_waveform(p, swv_buffer, True, twv_buffer)
+        else:
+            store_downsampled_waveform(p, swv_buffer)
 
         p['n_saturated_channels'] = p['saturated_channel'].sum()
         p['area_per_channel'][:] = area_per_channel
diff --git a/strax/processing/peak_merging.py b/strax/processing/peak_merging.py
index 709b9ab5d..c91b9d163 100644
--- a/strax/processing/peak_merging.py
+++ b/strax/processing/peak_merging.py
@@ -13,8 +13,8 @@ def merge_peaks(peaks, start_merge_at, end_merge_at,
     :param peaks: Record array of strax peak dtype.
     :param start_merge_at: Indices to start merge at
     :param end_merge_at: EXCLUSIVE indices to end merge at
-    :param max_buffer: Maximum number of samples in the sum_waveforms of
-        the resulting peaks (after merging).
+    :param max_buffer: Maximum number of samples in the sum_waveforms
+        and other waveforms of the resulting peaks (after merging).
 
     Peaks must be constructed based on the properties of constituent peaks,
     it being too time-consuming to revert to records/hits.
@@ -24,6 +24,7 @@ def merge_peaks(peaks, start_merge_at, end_merge_at,
 
     # Do the merging. Could numbafy this to optimize, probably...
     buffer = np.zeros(max_buffer, dtype=np.float32)
+    buffer_top = np.zeros(max_buffer, dtype=np.float32)
 
     for new_i, new_p in enumerate(new_peaks):
 
@@ -39,7 +40,7 @@ def merge_peaks(peaks, start_merge_at, end_merge_at,
         new_p['length'] = \
             (strax.endtime(last_peak) - new_p['time']) // common_dt
 
-        # re-zero relevant part of buffer (overkill? not sure if
+        # re-zero relevant part of buffers (overkill? not sure if
         # this saves much time)
         buffer[:min(
             int(
                 (
                     last_peak['time']
                     + (last_peak['length'] * old_peaks['dt'].max())
                     - first_peak['time']) / common_dt
             ),
             len(buffer)
         )] = 0
+        buffer_top[:min(
+            int(
+                (
+                    last_peak['time']
+                    + (last_peak['length'] * old_peaks['dt'].max())
+                    - first_peak['time']) / common_dt
+            ),
+            len(buffer_top)
+        )] = 0
 
         for p in old_peaks:
-            # Upsample the sum waveform into the buffer
+            # Upsample the sum and top array waveforms into their buffers
             upsample = p['dt'] // common_dt
             n_after = p['length'] * upsample
             i0 = (p['time'] - new_p['time']) // common_dt
             buffer[i0: i0 + n_after] = \
                 np.repeat(p['data'][:p['length']], upsample) / upsample
+            buffer_top[i0: i0 + n_after] = \
+                np.repeat(p['data_top'][:p['length']], upsample) / upsample
 
             # Handle the other peak attributes
             new_p['area'] += p['area']
@@ -65,8 +77,9 @@ def merge_peaks(peaks, start_merge_at, end_merge_at,
             new_p['n_hits'] += p['n_hits']
             new_p['saturated_channel'][p['saturated_channel'] == 1] = 1
 
-        # Downsample the buffer into new_p['data']
-        strax.store_downsampled_waveform(new_p, buffer)
+        # Downsample the buffers into new_p['data']
+        # and new_p['data_top']
+        strax.store_downsampled_waveform(new_p, buffer, True, buffer_top)
 
         new_p['n_saturated_channels'] = new_p['saturated_channel'].sum()
 
@@ -140,7 +153,7 @@ def _replace_merged(result, orig, merge, skip_windows):
 
 @export
 @numba.njit(cache=True, nogil=True)
-def add_lone_hits(peaks, lone_hits, to_pe):
+def add_lone_hits(peaks, lone_hits, to_pe, n_top_channels=0):
     """
     Function which adds information from lone hits to peaks if a lone hit
     is inside a peak (e.g. after merging). Modifies peak area and data
@@ -149,6 +162,7 @@ def add_lone_hits(peaks, lone_hits, to_pe):
     :param peaks: Numpy array of peaks
     :param lone_hits: Numpy array of lone_hits
     :param to_pe: Gain values to convert lone hit area into PE.
+    :param n_top_channels: Number of top array channels.
     """
     fully_contained_index = strax.fully_contained_in(lone_hits, peaks)
 
@@ -160,6 +174,10 @@ def add_lone_hits(peaks, lone_hits, to_pe):
         p['area'] += lh_area
         p['area_per_channel'][lh_i['channel']] += lh_area
 
-        # Add lone hit as delta pulse to waveform:
+        # Add lone hit as delta pulse to waveforms:
         index = (lh_i['time'] - p['time']) // p['dt']
         p['data'][index] += lh_area
+
+        if n_top_channels > 0:
+            if lh_i['channel'] < n_top_channels:
+                p['data_top'][index] += lh_area
diff --git a/strax/processing/peak_splitting.py b/strax/processing/peak_splitting.py
index 316eb0699..590adff34 100644
--- a/strax/processing/peak_splitting.py
+++ b/strax/processing/peak_splitting.py
@@ -7,7 +7,7 @@
 
 @export
 def split_peaks(peaks, hits, records, rlinks, to_pe, algorithm='local_minimum',
-                data_type='peaks', **kwargs):
+                data_type='peaks', n_top_channels=0, **kwargs):
     """Return peaks split according to algorithm, with waveforms summed
     and widths computed.
 
@@ -27,6 +27,7 @@ def split_peaks(peaks, hits, records, rlinks, to_pe, algorithm='local_minimum',
     :param data_type: 'peaks' or 'hitlets'. Specifies whether to use
         sum_waveform or get_hitlets_data to compute the waveform of
         the new split peaks/hitlets.
+    :param n_top_channels: Number of top array channels.
     :param result_dtype: dtype of the result.
 
     Any other options are passed to the algorithm.
@@ -37,7 +38,8 @@ def split_peaks(peaks, hits, records, rlinks, to_pe, algorithm='local_minimum',
     data_type_is_not_supported = data_type not in ('hitlets', 'peaks')
     if data_type_is_not_supported:
         raise TypeError(f'Data_type "{data_type}" is not supported.')
-    return splitter(peaks, hits, records, rlinks, to_pe, data_type, **kwargs)
+    return splitter(peaks, hits, records, rlinks, to_pe, data_type,
+                    n_top_channels=n_top_channels, **kwargs)
 
 
 NO_MORE_SPLITS = -9999999
@@ -55,6 +57,7 @@ class PeakSplitter:
         new split peaks/hitlets.
     :param do_iterations: maximum number of times peaks are recursively split.
     :param min_area: Minimum area to do split. Smaller peaks are not split.
+    :param n_top_channels: Number of top array channels.
 
     The function find_split_points(), implemented in each subclass
     defines the algorithm, which takes in a peak's waveform and
@@ -65,7 +68,7 @@ class PeakSplitter:
     find_split_args_defaults: tuple
 
     def __call__(self, peaks, hits, records, rlinks, to_pe, data_type,
-                 do_iterations=1, min_area=0, **kwargs):
+                 do_iterations=1, min_area=0, n_top_channels=0, **kwargs):
         if not len(records) or not len(peaks) or not do_iterations:
             return peaks
 
@@ -102,7 +105,7 @@ def __call__(self, peaks, hits, records, rlinks, to_pe, data_type,
         if is_split.sum() != 0:
             # Found new peaks: compute basic properties
             if data_type == 'peaks':
-                strax.sum_waveform(new_peaks, hits, records, rlinks, to_pe)
+                strax.sum_waveform(new_peaks, hits, records, rlinks, to_pe, n_top_channels)
                 strax.compute_widths(new_peaks)
             elif data_type == 'hitlets':
                 # Add record fields here
@@ -111,7 +114,8 @@ def __call__(self, peaks, hits, records, rlinks, to_pe, data_type,
             # ... and recurse (if needed)
             new_peaks = self(new_peaks, hits, records, rlinks, to_pe, data_type,
                              do_iterations=do_iterations - 1,
-                             min_area=min_area, **kwargs)
+                             min_area=min_area,
+                             n_top_channels=n_top_channels, **kwargs)
             if np.any(new_peaks['length'] == 0):
                 raise ValueError('Want to add a new zero-length peak after splitting!')
 
diff --git a/tests/test_peak_processing.py b/tests/test_peak_processing.py
index 3ebc76bd6..dd4185a78 100644
--- a/tests/test_peak_processing.py
+++ b/tests/test_peak_processing.py
@@ -64,6 +64,7 @@ def test__build_hit_waveform(records):
 def test_sum_waveform(records):
     # Make a single big peak to contain all the records
     n_ch = 100
+    n_top_channels = 50
 
     rlinks = strax.record_links(records)
     hits = strax.find_hits(records, np.ones(n_ch))
@@ -77,7 +78,7 @@ def test_sum_waveform(records):
                              min_area=0,
                              min_channels=1,
                              max_duration=10_000_000)
-    strax.sum_waveform(peaks, hits, records, rlinks, np.ones(n_ch))
+    strax.sum_waveform(peaks, hits, records, rlinks, np.ones(n_ch), n_top_channels)
 
     for p in peaks:
         # Area measures must be consistent
@@ -93,9 +94,13 @@ def test_sum_waveform(records):
 
         assert np.all(p['data'][:p['length']] == sum_wv)
 
+        # Top array waveforms must be equal to or smaller than
+        # the total waveform, sample by sample
+        assert np.all(p['data_top'] <= p['data'])
+
     # Finally check that we can also use a selection of peaks to sum
-    strax.sum_waveform(peaks, hits, records, rlinks, np.ones(n_ch), select_peaks_indices=np.array([0]))
+    strax.sum_waveform(peaks, hits, records, rlinks, np.ones(n_ch), n_top_channels,
+                       select_peaks_indices=np.array([0]))
 
 
 @settings(deadline=None)
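
For illustration, a minimal sketch of what the new digitize_top flag does to
the peak dtype; the field names follow the dtypes.py hunk above, and the
argument values are arbitrary:

    import numpy as np
    import strax

    # digitize_top=True (the default) inserts 'data_top' into the dtype,
    # sized like the summed-waveform field 'data'.
    peaks = np.zeros(1, dtype=strax.peak_dtype(n_channels=100,
                                               n_sum_wv_samples=200))
    assert 'data_top' in peaks.dtype.names
    assert peaks[0]['data_top'].shape == peaks[0]['data'].shape == (200,)

    # digitize_top=False reproduces the old dtype, without 'data_top'.
    plain = np.zeros(1, dtype=strax.peak_dtype(digitize_top=False))
    assert 'data_top' not in plain.dtype.names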
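
store_downsampled_waveform now rebins the summed and top-array buffers with
the same factor, so p['data'] and p['data_top'] stay aligned. A sketch with
made-up buffer contents, calling the function directly on one row of a peaks
array (merge_peaks above invokes it the same way from plain Python):

    import numpy as np
    import strax

    peaks = np.zeros(1, dtype=strax.peak_dtype(n_channels=100,
                                               n_sum_wv_samples=200))
    p = peaks[0]
    p['length'] = 400   # peak is twice as long as the 200-sample data field
    p['dt'] = 10

    swv_buffer = np.ones(800, dtype=np.float32)        # summed waveform
    twv_buffer = 0.5 * np.ones(800, dtype=np.float32)  # top-array share

    strax.store_downsampled_waveform(p, swv_buffer, True, twv_buffer)

    assert p['dt'] == 20 and p['length'] == 200  # downsampled by factor 2
    # Both waveforms were rebinned identically:
    assert np.all(p['data_top'] == 0.5 * p['data'])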
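
The fractional-downsampling caveat in the docstring can be checked with the
same arithmetic the function uses; with assumed numbers, a 500-sample peak
squeezed into 200 data samples loses its last two samples:

    import numpy as np

    length, n_samples = 500, 200
    downsample_factor = int(np.ceil(length / n_samples))    # -> 3
    new_length = int(np.floor(length / downsample_factor))  # -> 166

    # 166 * 3 = 498 < 500: two trailing samples are dropped, the price
    # paid for never letting a peak grow into its neighbour.
    assert new_length * downsample_factor == 498 < length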
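
merge_peaks upsamples each constituent waveform (and now its top-array
counterpart) to the common dt with np.repeat, then re-downsamples via
store_downsampled_waveform. A standalone numpy sketch with invented numbers,
showing that the repeat/divide step conserves area and round-trips exactly
through the reshape/sum used above:

    import numpy as np

    data = np.array([4., 8., 2.], dtype=np.float32)  # stored PE/sample
    upsample = 4                                     # p['dt'] // common_dt

    # The repeat/divide from merge_peaks: same area, finer binning
    buffer = np.repeat(data, upsample) / upsample
    assert buffer.sum() == data.sum()

    # The reshape/sum from store_downsampled_waveform undoes it exactly
    assert np.all(buffer.reshape(-1, upsample).sum(axis=1) == data)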