-
Notifications
You must be signed in to change notification settings - Fork 38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix memory leak peaksplitting #309
Changes from 4 commits
3cfd250
d800492
cb52792
7127c79
17b4c4e
9d123f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum', | |
:param records: Records from which peaks were built | ||
:param to_pe: ADC to PE conversion factor array (of n_channels) | ||
:param algorithm: 'local_minimum' or 'natural_breaks'. | ||
:param data_type: 'peaks' or 'hitlets'. Specifies whether to use | ||
:param data_type: 'peaks' or 'hitlets'. Specifies whether to use | ||
sum_wavefrom or get_hitlets_data to compute the waveform of | ||
the new split peaks/hitlets. | ||
:param result_dtype: dtype of the result. | ||
|
@@ -25,7 +25,7 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum', | |
""" | ||
splitter = dict(local_minimum=LocalMinimumSplitter, | ||
natural_breaks=NaturalBreaksSplitter)[algorithm]() | ||
|
||
if data_type == 'hitlets': | ||
# This is only needed once. | ||
_, next_ri = strax.record_links(records) | ||
|
@@ -62,7 +62,7 @@ class PeakSplitter: | |
find_split_args_defaults: tuple | ||
|
||
def __call__(self, peaks, records, to_pe, data_type, | ||
next_ri=None, do_iterations=1, min_area=0, **kwargs): | ||
next_ri=None, do_iterations=1, min_area=0, **kwargs): | ||
if not len(records) or not len(peaks) or not do_iterations: | ||
return peaks | ||
|
||
|
@@ -86,42 +86,28 @@ def __call__(self, peaks, records, to_pe, data_type, | |
|
||
is_split = np.zeros(len(peaks), dtype=np.bool_) | ||
|
||
# data_kind specific_outputs: | ||
if data_type == 'peaks': | ||
@numba.njit | ||
def specific_output(r, p, split_i, bonus_output): | ||
if split_i == NO_MORE_SPLITS: | ||
p['max_goodness_of_split'] = bonus_output | ||
# although the iteration will end anyway afterwards: | ||
r['max_gap'] = -1 # Too lazy to compute this | ||
split_function = {'peaks': self._split_peaks, | ||
'hitlets': self._split_hitlets} | ||
if data_type not in split_function: | ||
raise ValueError(f'Data_type "{data_type}" is not supported.') | ||
|
||
elif data_type == 'hitlets': | ||
@numba.njit | ||
def specific_output(r, p, split_i, bonus_output): | ||
if split_i == NO_MORE_SPLITS: | ||
return | ||
r['record_i'] = p['record_i'] | ||
else: | ||
raise TypeError(f'Unknown data_type. "{data_type}" is not supported.') | ||
new_peaks = self._split_peaks( | ||
new_peaks = split_function[data_type]( | ||
# Numba doesn't like self as argument, but it's ok with functions... | ||
split_finder=self.find_split_points, | ||
peaks=peaks, | ||
is_split=is_split, | ||
orig_dt=records[0]['dt'], | ||
min_area=min_area, | ||
args_options=tuple(args_options), | ||
specific_output=specific_output, | ||
result_dtype=peaks.dtype) | ||
|
||
if is_split.sum() != 0: | ||
# Found new peaks: compute basic properties | ||
if data_type == 'peaks': | ||
strax.sum_waveform(new_peaks, records, to_pe) | ||
elif data_type == 'hitlets': | ||
# Add record fields here | ||
strax.update_new_hitlets(new_peaks, records, next_ri, to_pe) | ||
else: | ||
raise ValueError(f'Data_type "{data_type}" is not supported.') | ||
|
||
strax.compute_widths(new_peaks) | ||
|
||
|
@@ -138,7 +124,7 @@ def specific_output(r, p, split_i, bonus_output): | |
@strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4)) | ||
@numba.jit(nopython=True, nogil=True) | ||
def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area, | ||
specific_output, args_options, | ||
args_options, | ||
_result_buffer=None, result_dtype=None): | ||
"""Loop over peaks, pass waveforms to algorithm, construct | ||
new peaks if and where a split occurs. | ||
|
@@ -155,26 +141,21 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area, | |
w = p['data'][:p['length']] | ||
for split_i, bonus_output in split_finder( | ||
w, p['dt'], p_i, *args_options): | ||
|
||
# This is a bit odd here. Due tp the specific_outputs we have to get r | ||
# although we may not need it at all, but I do not see any nice way around | ||
# this. | ||
r = new_peaks[offset] | ||
specific_output(r, p, split_i, bonus_output) | ||
if split_i == NO_MORE_SPLITS: | ||
# No idea if this if-statement can be integrated into | ||
# specific return | ||
p['max_goodness_of_split'] = bonus_output | ||
# although the iteration will end anyway afterwards: | ||
continue | ||
|
||
is_split[p_i] = True | ||
r = new_peaks[offset] | ||
r['time'] = p['time'] + prev_split_i * p['dt'] | ||
r['channel'] = p['channel'] | ||
# Set the dt to the original (lowest) dt first; | ||
# this may change when the sum waveform of the new peak | ||
# is computed | ||
r['dt'] = orig_dt | ||
r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt | ||
|
||
r['max_gap'] = -1 # Too lazy to compute this | ||
if r['length'] <= 0: | ||
print(p['data']) | ||
print(prev_split_i, split_i) | ||
|
@@ -189,6 +170,54 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area, | |
|
||
yield offset | ||
|
||
@staticmethod | ||
@strax.growing_result(dtype=strax.hitlet_dtype(), chunk_size=int(1e4)) | ||
@numba.jit(nopython=True, nogil=True) | ||
def _split_hitlets(split_finder, hits, orig_dt, is_split, min_area, | ||
args_options, | ||
_result_buffer=None, result_dtype=None): | ||
"""Loop over peaks, pass waveforms to algorithm, construct | ||
new peaks if and where a split occurs. | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add here some warning that changes in this function might also be applied in _split_peaks There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks Daniel, did so in 17b4c4e |
||
new_hits = _result_buffer | ||
offset = 0 | ||
|
||
for p_i, p in enumerate(hits): | ||
if p['area'] < min_area: | ||
continue | ||
|
||
prev_split_i = 0 | ||
w = p['data'][:p['length']] | ||
for split_i, bonus_output in split_finder( | ||
w, p['dt'], p_i, *args_options): | ||
if split_i == NO_MORE_SPLITS: | ||
return | ||
r['record_i'] = p['record_i'] | ||
|
||
is_split[p_i] = True | ||
r = new_hits[offset] | ||
r['time'] = p['time'] + prev_split_i * p['dt'] | ||
r['channel'] = p['channel'] | ||
# Set the dt to the original (lowest) dt first; | ||
# this may change when the sum waveform of the new peak | ||
# is computed | ||
r['dt'] = orig_dt | ||
r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt | ||
r['max_gap'] = -1 # Too lazy to compute this | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hitlets does not support 'max_gap' so simply remove this line. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in 17b4c4e |
||
if r['length'] <= 0: | ||
print(p['data']) | ||
print(prev_split_i, split_i) | ||
raise ValueError("Attempt to create invalid peak!") | ||
|
||
offset += 1 | ||
if offset == len(new_hits): | ||
yield offset | ||
offset = 0 | ||
|
||
prev_split_i = split_i | ||
|
||
yield offset | ||
|
||
@staticmethod | ||
def find_split_points(w, dt, peak_i, *args_options): | ||
"""This function is overwritten by LocalMinimumSplitter or LocalMinimumSplitter | ||
|
@@ -267,7 +296,7 @@ class NaturalBreaksSplitter(PeakSplitter): | |
close as we can get to it given the peaks sampling) on either side. | ||
""" | ||
find_split_args_defaults = ( | ||
('threshold', None), # will be a numpy array of len(peaks) | ||
('threshold', None), # will be a numpy array of len(peaks) | ||
('normalize', False), | ||
('split_low', False), | ||
('filter_wing_width', 0)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You have to change these lines into arguments, since
peaks
is calledhits
in_split_hitlets