Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory leak peaksplitting #309

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion strax/processing/peak_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def sum_waveform(peaks, records, adc_to_pe):
p['time'] // dt, n_p)

max_in_record = r['data'][r_start:r_end].max() * multiplier
p['saturated_channel'][ch] |= int(max_in_record >= r['baseline'])
p['saturated_channel'][ch] |= np.int8(max_in_record >= r['baseline'])

bl_fpart = r['baseline'] % 1
# TODO: check numba does casting correctly here!
Expand Down
107 changes: 71 additions & 36 deletions strax/processing/peak_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum',
:param records: Records from which peaks were built
:param to_pe: ADC to PE conversion factor array (of n_channels)
:param algorithm: 'local_minimum' or 'natural_breaks'.
:param data_type: 'peaks' or 'hitlets'. Specifies whether to use
:param data_type: 'peaks' or 'hitlets'. Specifies whether to use
sum_wavefrom or get_hitlets_data to compute the waveform of
the new split peaks/hitlets.
:param result_dtype: dtype of the result.
Expand All @@ -25,7 +25,7 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum',
"""
splitter = dict(local_minimum=LocalMinimumSplitter,
natural_breaks=NaturalBreaksSplitter)[algorithm]()

if data_type == 'hitlets':
# This is only needed once.
_, next_ri = strax.record_links(records)
Expand Down Expand Up @@ -62,7 +62,7 @@ class PeakSplitter:
find_split_args_defaults: tuple

def __call__(self, peaks, records, to_pe, data_type,
next_ri=None, do_iterations=1, min_area=0, **kwargs):
next_ri=None, do_iterations=1, min_area=0, **kwargs):
if not len(records) or not len(peaks) or not do_iterations:
return peaks

Expand All @@ -86,42 +86,28 @@ def __call__(self, peaks, records, to_pe, data_type,

is_split = np.zeros(len(peaks), dtype=np.bool_)

# data_kind specific_outputs:
if data_type == 'peaks':
@numba.njit
def specific_output(r, p, split_i, bonus_output):
if split_i == NO_MORE_SPLITS:
p['max_goodness_of_split'] = bonus_output
# although the iteration will end anyway afterwards:
r['max_gap'] = -1 # Too lazy to compute this
split_function = {'peaks': self._split_peaks,
'hitlets': self._split_hitlets}
if data_type not in split_function:
raise ValueError(f'Data_type "{data_type}" is not supported.')

elif data_type == 'hitlets':
@numba.njit
def specific_output(r, p, split_i, bonus_output):
if split_i == NO_MORE_SPLITS:
return
r['record_i'] = p['record_i']
else:
raise TypeError(f'Unknown data_type. "{data_type}" is not supported.')
new_peaks = self._split_peaks(
new_peaks = split_function[data_type](
# Numba doesn't like self as argument, but it's ok with functions...
split_finder=self.find_split_points,
peaks=peaks,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have to change these lines into arguments, since peaks is called hits in _split_hitlets

is_split=is_split,
orig_dt=records[0]['dt'],
min_area=min_area,
args_options=tuple(args_options),
specific_output=specific_output,
result_dtype=peaks.dtype)

if is_split.sum() != 0:
# Found new peaks: compute basic properties
if data_type == 'peaks':
strax.sum_waveform(new_peaks, records, to_pe)
elif data_type == 'hitlets':
# Add record fields here
strax.update_new_hitlets(new_peaks, records, next_ri, to_pe)
else:
raise ValueError(f'Data_type "{data_type}" is not supported.')

strax.compute_widths(new_peaks)

Expand All @@ -138,12 +124,14 @@ def specific_output(r, p, split_i, bonus_output):
@strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
@numba.jit(nopython=True, nogil=True)
def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
specific_output, args_options,
args_options,
_result_buffer=None, result_dtype=None):
"""Loop over peaks, pass waveforms to algorithm, construct
new peaks if and where a split occurs.
"""
# TODO NEEDS TESTS!
# NB: code very similar to _split_hitlets see
# github.com/AxFoundation/strax/pull/309 for more info. Keep in mind
# that changing one function should also be reflected in the other.
new_peaks = _result_buffer
offset = 0

Expand All @@ -155,26 +143,21 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
w = p['data'][:p['length']]
for split_i, bonus_output in split_finder(
w, p['dt'], p_i, *args_options):

# This is a bit odd here. Due tp the specific_outputs we have to get r
# although we may not need it at all, but I do not see any nice way around
# this.
r = new_peaks[offset]
specific_output(r, p, split_i, bonus_output)
if split_i == NO_MORE_SPLITS:
# No idea if this if-statement can be integrated into
# specific return
p['max_goodness_of_split'] = bonus_output
# although the iteration will end anyway afterwards:
continue

is_split[p_i] = True
r = new_peaks[offset]
r['time'] = p['time'] + prev_split_i * p['dt']
r['channel'] = p['channel']
# Set the dt to the original (lowest) dt first;
# this may change when the sum waveform of the new peak
# is computed
r['dt'] = orig_dt
r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt

r['max_gap'] = -1 # Too lazy to compute this
if r['length'] <= 0:
print(p['data'])
print(prev_split_i, split_i)
Expand All @@ -189,6 +172,58 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,

yield offset

@staticmethod
@strax.growing_result(dtype=strax.hitlet_dtype(), chunk_size=int(1e4))
@numba.jit(nopython=True, nogil=True)
def _split_hitlets(split_finder, peaks, orig_dt, is_split, min_area,
args_options,
_result_buffer=None, result_dtype=None):
"""Loop over hits, pass waveforms to algorithm, construct
new hits if and where a split occurs.
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add here some warning that changes in this function might also be applied in _split_peaks

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Daniel, did so in 17b4c4e

# TODO NEEDS TESTS!
# NB: code very similar to _split_peaks see
# github.com/AxFoundation/strax/pull/309 for more info. Keep in mind
# that changing one function should also be reflected in the other.
new_hits = _result_buffer
offset = 0

for h_i, h in enumerate(peaks):
if h['area'] < min_area:
continue

prev_split_i = 0
w = h['data'][:h['length']]
for split_i, bonus_output in split_finder(
w, h['dt'], h_i, *args_options):
if split_i == NO_MORE_SPLITS:
continue

is_split[h_i] = True
r = new_hits[offset]
r['time'] = h['time'] + prev_split_i * h['dt']
r['channel'] = h['channel']
# Hitlet specific
r['record_i'] = h['record_i']
# Set the dt to the original (lowest) dt first;
# this may change when the sum waveform of the new peak
# is computed
r['dt'] = orig_dt
r['length'] = (split_i - prev_split_i) * h['dt'] / orig_dt
if r['length'] <= 0:
print(h['data'])
print(prev_split_i, split_i)
raise ValueError("Attempt to create invalid hitlet!")

offset += 1
if offset == len(new_hits):
yield offset
offset = 0

prev_split_i = split_i

yield offset

@staticmethod
def find_split_points(w, dt, peak_i, *args_options):
"""This function is overwritten by LocalMinimumSplitter or LocalMinimumSplitter
Expand Down Expand Up @@ -267,7 +302,7 @@ class NaturalBreaksSplitter(PeakSplitter):
close as we can get to it given the peaks sampling) on either side.
"""
find_split_args_defaults = (
('threshold', None), # will be a numpy array of len(peaks)
('threshold', None), # will be a numpy array of len(peaks)
('normalize', False),
('split_low', False),
('filter_wing_width', 0))
Expand Down