Fixing peaklet baseline bias (#486)
* Fixing peaklet baseline bias

* Fix multi-record_i problem

* Revert "Fix multi-record_i problem"

This reverts commit 81fdc11.

* Fix record_i in multi-peaks

* Fixed bug of wrong area compared to wf

* Revert order for record_i and time check

* Add found next start in case peak ends

* Changed beyond peak case for peak splitting within hit

* Modified splitting test accordingly

* Extended tests according to change.

* Rename todo to please codefactor...

* Fix empty inputs

* Allow integration bounds beyond record

* Make find_hit_integration_bounds non private.

* Unify return

* Remove return as things are modified inplace

* Add test for hit integration bounds

* Refactored summed waveform to include hit integration bounds.

* Updated splitting accordingly

* forgot left_hit_i

* Minor fixes

* Comment in le/re bounds outside record

* Fix peak area estimate

* Fix small bug in saturated channels

* Add additional arguments to function calls in tests

* Added test and small clean up.

* Fixed test

* Updated test to obey peak_finding rules for hits

* Remove print statements

* Change hit_waveform to buffer

* Updated doc string

* Small fix

* Revert "Updated doc string"

This reverts commit 10ca471.

* Add docs again

* Updated doc string of peak splitting

* Update doc-string

* Refactored function, changed doc string, removed todo

* Remove comment
WenzDaniel committed Aug 25, 2021
1 parent 4adacdc commit d77b241
Showing 5 changed files with 267 additions and 115 deletions.
205 changes: 138 additions & 67 deletions strax/processing/peak_building.py
@@ -154,11 +154,18 @@ def store_downsampled_waveform(p, wv_buffer):

@export
@numba.jit(nopython=True, nogil=True, cache=True)
def sum_waveform(peaks, records, adc_to_pe, select_peaks_indices=None):
"""Compute sum waveforms for all peaks in peaks
def sum_waveform(peaks, hits, records, record_links, adc_to_pe, select_peaks_indices=None):
"""Compute sum waveforms for all peaks in peaks. Only builds summed
waveform other regions in which hits were found. This is required
to avoid any bias due to zero-padding and baselining.
Will downsample sum waveforms if they do not fit in per-peak buffer
:arg select_peaks_indices: Indices of the peaks for partial
:param peaks: Peaks for which the summed waveform should be built.
:param hits: Hits which are inside peaks. Must be sorted according
to record_i.
:param records: Records to be used to build peaks.
:param record_links: Tuple of previous and next records.
:param select_peaks_indices: Indices of the peaks for partial
processing. In the form of np.array([np.int, np.int, ..]). If
None (default), all the peaks are used for the summation.
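
For orientation, a caller of the new signature might look like this minimal sketch (the min_amplitude threshold is hypothetical; records, peaks, and adc_to_pe are assumed to already exist as the usual strax arrays):

import numpy as np
import strax

hits = strax.find_hits(records, min_amplitude=50)  # hypothetical threshold
hits = hits[np.argsort(hits['record_i'])]          # ensure sorting by record_i, as required
rlinks = strax.record_links(records)               # (prev_record_i, next_record_i)
strax.sum_waveform(peaks, hits, records, rlinks, adc_to_pe)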
@@ -173,18 +180,21 @@ def sum_waveform(peaks, records, adc_to_pe, select_peaks_indices=None):
if not len(select_peaks_indices):
return
dt = records[0]['dt']
n_samples_record = len(records[0]['data'])
prev_record_i, next_record_i = record_links

# Big buffer to hold even largest sum waveforms
# Need a little more even for downsampling..
swv_buffer = np.zeros(peaks['length'].max() * 2, dtype=np.float32)

# Index of first record that could still contribute to subsequent peaks
# Records before this do not need to be considered anymore
left_r_i = 0

n_channels = len(peaks[0]['area_per_channel'])
area_per_channel = np.zeros(n_channels, dtype=np.float32)

# Hit index for hits in peaks
left_h_i = 0
# Create hit waveform buffer
hit_waveform = np.zeros(hits['length'].max(), dtype=np.float32)

for peak_i in select_peaks_indices:
p = peaks[peak_i]
# Clear the relevant part of the swv buffer for use
@@ -197,55 +207,68 @@ def sum_waveform(peaks, records, adc_to_pe, select_peaks_indices=None):
area_per_channel *= 0
p['area'] = 0

# Find first record that contributes to this peak
for left_r_i in range(left_r_i, len(records)):
r = records[left_r_i]
# Find first hit that contributes to this peak
for left_h_i in range(left_h_i, len(hits)):
h = hits[left_h_i]
# TODO: need test that fails if we replace < with <= here
if p['time'] < r['time'] + r['length'] * dt:
if p['time'] < h['time'] + h['length'] * dt:
break
else:
# Records exhausted before peaks exhausted
# Hits exhausted before peaks exhausted
# TODO: this is a strange case, maybe raise warning/error?
break

# Scan over records that overlap
for right_r_i in range(left_r_i, len(records)):
r = records[right_r_i]
ch = r['channel']
multiplier = 2**r['amplitude_bit_shift']
assert p['dt'] == r['dt'], "Records and peaks must have same dt"
# Scan over hits that overlap with peak
for right_h_i in range(left_h_i, len(hits)):
h = hits[right_h_i]
record_i = h['record_i']
ch = h['channel']
assert p['dt'] == h['dt'], "Hits and peaks must have same dt"

shift = (p['time'] - r['time']) // dt
n_r = r['length']
n_p = p_length
shift = (p['time'] - h['time']) // dt
n_samples_hit = h['length']
n_samples_peak = p_length

if shift <= -n_p:
# Record is completely to the right of the peak;
if shift <= -n_samples_peak:
# Hit is completely to the right of the peak;
# we've seen all overlapping records
break

if n_r <= shift:
if n_samples_hit <= shift:
# The (real) data in this record does not actually overlap
# with the peak
# (although a previous, longer record did overlap)
# (although a previous, longer hit did overlap)
continue

(r_start, r_end), (p_start, p_end) = strax.overlap_indices(
r['time'] // dt, n_r,
p['time'] // dt, n_p)
# Get overlapping samples between hit and peak:
(h_start, h_end), (p_start, p_end) = strax.overlap_indices(
h['time'] // dt, n_samples_hit,
p['time'] // dt, n_samples_peak)

hit_waveform[:] = 0

# Get record which belongs to the main part of the hit (w/o integration bounds):
r = records[record_i]

max_in_record = r['data'][r_start:r_end].max() * multiplier
p['saturated_channel'][ch] |= np.int8(max_in_record >= np.int16(r['baseline']))
is_saturated = _build_hit_waveform(h, r, hit_waveform)

bl_fpart = r['baseline'] % 1
# TODO: check numba does casting correctly here!
pe_waveform = adc_to_pe[ch] * (
multiplier * r['data'][r_start:r_end]
+ bl_fpart)
# Now check if we also have to go to prev/next record due to integration bounds.
# If bounds are outside of peak we chop when building the summed waveform later.
if h['left_integration'] < 0 and prev_record_i[record_i] != -1:
r = records[prev_record_i[record_i]]
is_saturated |= _build_hit_waveform(h, r, hit_waveform)

swv_buffer[p_start:p_end] += pe_waveform
if h['right_integration'] > n_samples_record and next_record_i[record_i] != -1:
r = records[next_record_i[record_i]]
is_saturated |= _build_hit_waveform(h, r, hit_waveform)

area_pe = pe_waveform.sum()
p['saturated_channel'][ch] |= is_saturated

hit_data = hit_waveform[h_start:h_end]
hit_data *= adc_to_pe[ch]
swv_buffer[p_start:p_end] += hit_data

area_pe = hit_data.sum()
area_per_channel[ch] += area_pe
p['area'] += area_pe
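
For reference, strax.overlap_indices maps the overlap of two sample intervals into each interval's own index space; a small worked example with invented numbers:

import strax

# A hit covering samples [10, 15) and a peak covering [12, 20):
(h_start, h_end), (p_start, p_end) = strax.overlap_indices(10, 5, 12, 8)
assert (h_start, h_end) == (2, 5)  # overlap is samples 2..4 of the hit
assert (p_start, p_end) == (0, 3)  # ... landing in samples 0..2 of the peak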

@@ -255,6 +278,30 @@
p['area_per_channel'][:] = area_per_channel


@numba.njit(cache=True, nogil=True)
def _build_hit_waveform(hit, record, hit_waveform):
"""
Adds the samples in which a record and a hit overlap to hit_waveform.
Updates hit_waveform inplace. Result is still in ADC counts.
:returns: Boolean indicating whether the record saturated within the hit.
"""
(h_start_record, h_end_record), (r_start, r_end) = strax.overlap_indices(
hit['time'] // hit['dt'], hit['length'],
record['time'] // record['dt'], record['length'])

# Get record properties:
record_data = record['data'][r_start:r_end]
multiplier = 2**record['amplitude_bit_shift']
bl_fpart = record['baseline'] % 1
max_in_record = record_data.max() * multiplier

# Build hit waveform:
hit_waveform[h_start_record:h_end_record] = (multiplier * record_data + bl_fpart)

return np.int8(max_in_record >= np.int16(record['baseline']))
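
The multiplier and bl_fpart terms reflect how strax stores record data: samples are baselined as int(baseline) - raw, so the truncated fractional part of the baseline must be added back per sample (the multiplier additionally undoes any amplitude bit shift). A toy illustration with invented numbers:

import numpy as np

baseline = 99.7
raw = np.array([90, 80, 95], dtype=np.int16)  # hypothetical digitizer samples
stored = int(baseline) - raw                  # what record['data'] holds: [9, 19, 4]
corrected = stored + baseline % 1             # ADC above true baseline: ~[9.7, 19.7, 4.7]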


@export
def find_peak_groups(peaks, gap_threshold,
left_extension=0, right_extension=0,
@@ -292,50 +339,74 @@ def find_peak_groups(peaks, gap_threshold,
##
# Lone hit integration
##

@export
@numba.njit(nogil=True, cache=True)
def _find_hit_integration_bounds(
lone_hits, peaks, records, save_outside_hits, n_channels):
""""Update lone hits to include integration bounds
save_outside_hits: in ns!!
def find_hit_integration_bounds(
hits, excluded_intervals, records, save_outside_hits, n_channels,
allow_bounds_beyond_records=False):
""""Update (lone) hits to include integration bounds. Please note
that time and length of the original hit are not changed!
:param hits: Hits or lone hits which should be extended by
integration bounds.
:param excluded_intervals: Regions into which hits should not be extended,
e.g. peaks for lone hits. If not needed, just pass a zero-length
strax.time_fields array.
:param records: Records in which hits were found.
:param save_outside_hits: Hit extension to the left and right in ns,
not samples!
:param n_channels: Number of channels for given detector.
:param allow_bounds_beyond_records: If True, extend left/
right_integration beyond record boundaries, e.g. to negative
samples on the left side.
"""
result = np.zeros((len(lone_hits), 2), dtype=np.int64)
if not len(lone_hits):
result = np.zeros((len(hits), 2), dtype=np.int64)
if not len(hits):
return result

# By default, use save_outside_hits to determine bounds
result[:, 0] = lone_hits['time'] - save_outside_hits[0]
result[:, 1] = strax.endtime(lone_hits) + save_outside_hits[1]
result[:, 0] = hits['time'] - save_outside_hits[0]
result[:, 1] = strax.endtime(hits) + save_outside_hits[1]

NO_EARLIER_HIT = -1
last_hit_index = np.ones(n_channels, dtype=np.int32) * NO_EARLIER_HIT

n_peaks = len(peaks)
n_intervals = len(excluded_intervals)
FAR_AWAY = 9223372036_854775807 # np.iinfo(np.int64).max, April 2262
peak_i = 0
interval_i = 0

for hit_i, h in enumerate(lone_hits):
for hit_i, h in enumerate(hits):
ch = h['channel']

# Find end of previous peak and start of next peak
# (note peaks are disjoint from any lone hit, even though
# lone hits may not be disjoint from each other)
while peak_i < n_peaks and peaks[peak_i]['time'] < h['time']:
peak_i += 1
prev_p_end = strax.endtime(peaks[peak_i - 1]) if peak_i != 0 else 0
next_p_start = peaks[peak_i]['time'] if peak_i != n_peaks else FAR_AWAY
while interval_i < n_intervals and excluded_intervals[interval_i]['time'] < h['time']:
interval_i += 1

if interval_i != 0:
prev_interval_end = strax.endtime(excluded_intervals[interval_i - 1])
else:
prev_interval_end = 0

if interval_i != n_intervals:
next_interval_start = excluded_intervals[interval_i]['time']
else:
next_interval_start = FAR_AWAY

# Ensure we do not integrate parts of peaks
# or (at least for now) beyond the record in which the hit was found
r = records[h['record_i']]
result[hit_i][0] = max(prev_p_end,
r['time'],
result[hit_i][0])
result[hit_i][1] = min(next_p_start,
strax.endtime(r),
result[hit_i][1])
if allow_bounds_beyond_records:
result[hit_i][0] = max(prev_interval_end,
result[hit_i][0])
result[hit_i][1] = min(next_interval_start,
result[hit_i][1])
else:
result[hit_i][0] = max(prev_interval_end,
r['time'],
result[hit_i][0])
result[hit_i][1] = min(next_interval_start,
strax.endtime(r),
result[hit_i][1])

if last_hit_index[ch] != NO_EARLIER_HIT:
# Ensure previous hit does not integrate the over-threshold region
@@ -350,9 +421,9 @@ def _find_hit_integration_bounds(
last_hit_index[ch] = hit_i

# Convert to index in record and store
t0 = records[lone_hits['record_i']]['time']
dt = records[lone_hits['record_i']]['dt']
for hit_i, h in enumerate(lone_hits):
t0 = records[hits['record_i']]['time']
dt = records[hits['record_i']]['dt']
for hit_i, h in enumerate(hits):
h['left_integration'] = (result[hit_i, 0] - t0[hit_i]) // dt[hit_i]
h['right_integration'] = (result[hit_i, 1] - t0[hit_i]) // dt[hit_i]
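
A sketch of calling the now-public function for ordinary hits with no excluded intervals (the extension values are hypothetical; hits, records, and adc_to_pe are assumed to exist):

import numpy as np
import strax

save_outside = np.array([3, 20])                     # ns to the left / right
no_intervals = np.zeros(0, dtype=strax.time_fields)  # nothing to exclude
strax.find_hit_integration_bounds(
    hits, no_intervals, records, save_outside,
    len(adc_to_pe),  # n_channels
    True)            # allow_bounds_beyond_records
# hits['left_integration'] / hits['right_integration'] now hold sample indices
# relative to each hit's own record and may lie outside [0, record length).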

@@ -373,8 +444,8 @@ def integrate_lone_hits(
TODO: this doesn't extend the integration range beyond record boundaries
"""
_find_hit_integration_bounds(
lone_hits, peaks, records, save_outside_hits, n_channels)
find_hit_integration_bounds(lone_hits, peaks, records, save_outside_hits,
n_channels)
for hit_i, h in enumerate(lone_hits):
r = records[h['record_i']]
start, end = h['left_integration'], h['right_integration']
19 changes: 14 additions & 5 deletions strax/processing/peak_splitting.py
@@ -6,14 +6,22 @@


@export
def split_peaks(peaks, records, to_pe, algorithm='local_minimum',
def split_peaks(peaks, hits, records, rlinks, to_pe, algorithm='local_minimum',
data_type='peaks', **kwargs):
"""Return peaks split according to algorithm, with waveforms summed
and widths computed.
Note:
Can also be used for hitlet splitting with the local_minimum
splitter; just pass hitlets instead of peaks.
:param peaks: Original peaks. Sum waveform must have been built
and properties must have been computed (if you use them)
:param hits: Hits found in records (or None in case of hitlet
splitting).
:param records: Records from which peaks were built
:param rlinks: strax.record_links for given records
(or None in case of hitlet splitting).
:param to_pe: ADC to PE conversion factor array (of n_channels)
:param algorithm: 'local_minimum' or 'natural_breaks'.
:param data_type: 'peaks' or 'hitlets'. Specifies whether to use
@@ -29,7 +37,7 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum',
data_type_is_not_supported = data_type not in ('hitlets', 'peaks')
if data_type_is_not_supported:
raise TypeError(f'Data_type "{data_type}" is not supported.')
return splitter(peaks, records, to_pe, data_type, **kwargs)
return splitter(peaks, hits, records, rlinks, to_pe, data_type, **kwargs)


NO_MORE_SPLITS = -9999999
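
With the new arguments threaded through, calls might look like the following sketch (the splitter thresholds are hypothetical; for hitlets the hits and rlinks arguments may simply be None):

import strax

# Peaks: hits and record links are needed so the sum waveform can be rebuilt.
new_peaks = strax.split_peaks(
    peaks, hits, records, rlinks, to_pe,
    algorithm='local_minimum', data_type='peaks',
    min_height=25, min_ratio=4)  # hypothetical local-minimum settings

# Hitlets: waveforms come from get_hitlets_data instead.
new_hitlets = strax.split_peaks(
    hitlets, None, records, None, to_pe,
    algorithm='local_minimum', data_type='hitlets')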
@@ -40,6 +48,7 @@ class PeakSplitter:
:param peaks: Original peaks. Sum waveform must have been built
and properties must have been computed (if you use them).
:param records: Records from which peaks were built.
:param rlinks: strax.record_links for given records.
:param to_pe: ADC to PE conversion factor array (of n_channels).
:param data_type: 'peaks' or 'hitlets'. Specifies whether to use
sum_waveform or get_hitlets_data to compute the waveform of the
@@ -55,7 +64,7 @@ class PeakSplitter:
"""
find_split_args_defaults: tuple

def __call__(self, peaks, records, to_pe, data_type,
def __call__(self, peaks, hits, records, rlinks, to_pe, data_type,
do_iterations=1, min_area=0, **kwargs):
if not len(records) or not len(peaks) or not do_iterations:
return peaks
@@ -93,14 +102,14 @@ def __call__(self, peaks, records, to_pe, data_type,
if is_split.sum() != 0:
# Found new peaks: compute basic properties
if data_type == 'peaks':
strax.sum_waveform(new_peaks, records, to_pe)
strax.sum_waveform(new_peaks, hits, records, rlinks, to_pe)
strax.compute_widths(new_peaks)
elif data_type == 'hitlets':
# Add record fields here
new_peaks = strax.sort_by_time(new_peaks) # Hitlets are not necessarily sorted after splitting
new_peaks = strax.get_hitlets_data(new_peaks, records, to_pe)
# ... and recurse (if needed)
new_peaks = self(new_peaks, records, to_pe, data_type,
new_peaks = self(new_peaks, hits, records, rlinks, to_pe, data_type,
do_iterations=do_iterations - 1,
min_area=min_area, **kwargs)
if np.any(new_peaks['length'] == 0):
