Skip to content

Commit

Permalink
Fix multi-record_i problem
Browse files Browse the repository at this point in the history
  • Loading branch information
WenzDaniel committed Jul 19, 2021
1 parent b0ce5d9 commit 81fdc11
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions strax/processing/peak_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,20 @@ def sum_waveform(peaks, hits, records, adc_to_pe, select_peaks_indices=None):
# Records exhausted before peaks exhausted
# TODO: this is a strange case, maybe raise warning/error?
break

# Find hits which contribute to peak:
start = last_hit_seen
for last_hit_seen in range(last_hit_seen, len(hits)):
h = hits[last_hit_seen]

if h['time'] > (p['time'] + p['length'] * p['dt']):
break

if h['time'] < p['time']:
start += 1
continue

hits_in_peak = _sort_by_record_i(hits[start:last_hit_seen+1]) # +1 since last_hit_seen is exclusive

# Scan over records that overlap
for right_r_i in range(left_r_i, len(records)):
Expand Down Expand Up @@ -245,22 +259,19 @@ def sum_waveform(peaks, hits, records, adc_to_pe, select_peaks_indices=None):
# TODO: check numba does casting correctly here!

pe_waveform = np.zeros(len(r['data']))
for h_i in range(last_hit_seen, len(hits)):
for _, h in enumerate(hits_in_peak):
# Loop over hits only sum waveform on region inside hits.
# This is needed as we would otherwise add up baseline
# leading to some bias.
h = hits[h_i]

if h['record_i'] < right_r_i:
last_hit_seen += 1
continue
if h['record_i'] > right_r_i:
break

h_start = h['left_integration']
h_end = h['right_integration']
pe_waveform[h_start:h_end] += (multiplier * r['data'][h_start:h_end] + bl_fpart)
last_hit_seen += 1

pe_waveform *= adc_to_pe[ch]
swv_buffer[p_start:p_end] += pe_waveform[r_start:r_end]
Expand All @@ -274,6 +285,10 @@ def sum_waveform(peaks, hits, records, adc_to_pe, select_peaks_indices=None):
p['n_saturated_channels'] = p['saturated_channel'].sum()
p['area_per_channel'][:] = area_per_channel

@numba.njit
def _sort_by_record_i(hits):
sort_index = np.argsort(hits['record_i'])
return hits[sort_index]

@export
def find_peak_groups(peaks, gap_threshold,
Expand Down

0 comments on commit 81fdc11

Please sign in to comment.