Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow dynamic function creation for dtype copy #395

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 63 additions & 1 deletion strax/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
TODO: file numba issue.
"""
import numpy as np
import typing as ty
import numba


__all__ = ('interval_dtype raw_record_dtype record_dtype hit_dtype peak_dtype '
'DIGITAL_SUM_WAVEFORM_CHANNEL DEFAULT_RECORD_LENGTH '
'time_fields time_dt_fields hitlet_dtype hitlet_with_data_dtype').split()
'time_fields time_dt_fields hitlet_dtype hitlet_with_data_dtype '
'copy_to_buffer').split()

DIGITAL_SUM_WAVEFORM_CHANNEL = -1
DEFAULT_RECORD_LENGTH = 110
Expand Down Expand Up @@ -201,3 +205,61 @@ def peak_dtype(n_channels=100, n_sum_wv_samples=200, n_widths=11):
(('Maximum interior goodness of split',
'max_goodness_of_split'), np.float32),
]


def copy_to_buffer(source: np.ndarray,
                   buffer: np.ndarray,
                   func_name: str,
                   field_names: ty.Optional[ty.Tuple[str, ...]] = None):
    """
    Copy the data from the source to the buffer, e.g. raw_records to
    records. To this end, we dynamically create the njitted function
    with the name 'func_name' (should start with "_").

    :param source: array of input
    :param buffer: array of buffer to fill with values from input
    :param func_name: name under which to store the dynamically created
        function. Should start with an underscore.
    :param field_names: dtype names to copy (if None, use all fields
        that source and buffer have in common)
    :raises ValueError: if source and buffer shapes differ, if any
        requested field is missing from the buffer, or if func_name
        does not start with an underscore.
    """
    if np.shape(source) != np.shape(buffer):
        raise ValueError('Source should be the same length as the buffer')

    if field_names is None:
        # Nothing specified: copy every field the two dtypes share
        field_names = tuple(n for n in source.dtype.names
                            if n in buffer.dtype.names)
    elif not all(n in buffer.dtype.names for n in field_names):
        # Fail early with a clear message; previously only the
        # "no field matches at all" case was caught here and a partial
        # mismatch surfaced as an opaque error while generating the
        # copy function.
        raise ValueError('Trying to copy dtypes that are not in the '
                         'destination')

    if not func_name.startswith('_'):
        raise ValueError('Start function with "_"')

    if func_name not in globals():
        # Compile (once per func_name) a numba function for this copy.
        # NOTE(review): reusing the same func_name with a different
        # field set silently reuses the first compiled function —
        # callers must keep func_name unique per copy layout.
        _create_copy_function(buffer.dtype, field_names, func_name)

    globals()[func_name](source, buffer)


def _create_copy_function(res_dtype, field_names, func_name):
    """Write out a numba-njitted function to copy data.

    Generates source for a function called `func_name` that copies
    `field_names` row-by-row from one structured array into another,
    then exec's it into this module's globals.

    :param res_dtype: dtype of the destination (result) array
    :param field_names: field names to copy; must all exist in res_dtype
    :param func_name: name under which the generated function is stored
    """
    # Cannot use cache=True since we are creating the function dynamically
    code = f'''
@numba.njit(nogil=True)
def {func_name}(source, result):
    for i in range(len(source)):
        s = source[i]
        r = result[i]
'''
    for d in field_names:
        # copy_to_buffer pre-validates field_names, so this is a
        # defensive check only
        if d not in res_dtype.names:
            raise ValueError('This cannot happen')
        if np.shape(res_dtype[d]):
            # Copy array fields as arrays (element-wise slice assignment)
            code += f'\n        r["{d}"][:] = s["{d}"][:]'
        else:
            code += f'\n        r["{d}"] = s["{d}"]'
    # exec on internally-generated code only; field names come from
    # numpy dtypes, not untrusted input
    exec(code, globals())
19 changes: 4 additions & 15 deletions strax/processing/pulse_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import numba
from scipy.ndimage import convolve1d

from warnings import warn
import strax
export, __all__ = strax.exporter()
__all__ += ['NO_RECORD_LINK']
Expand Down Expand Up @@ -76,25 +76,14 @@ def raw_to_records(raw_records):
len(raw_records),
dtype=strax.record_dtype(
record_length_from_dtype(raw_records.dtype)))
copy_raw_records(raw_records, records)
strax.copy_to_buffer(raw_records, records, '_copy_raw_records')
return records


# Numpy record arrays have a rowwise memory layout, so filling it
# rowwise should be faster.
@export
def copy_raw_records(old, new):
    """Copy raw_records into a records buffer.

    Deprecated: thin wrapper kept for backwards compatibility;
    use strax.copy_to_buffer directly instead.

    :param old: source raw_records array
    :param new: destination records buffer (same length as old)
    """
    warn('Deprecated, use strax.copy_to_buffer')
    strax.copy_to_buffer(old, new, '_copy_raw_records')


@export
Expand Down
9 changes: 9 additions & 0 deletions tests/test_general_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,12 @@ def test_overlap_indices(a1, n_a, b1, n_b):
(true_a_inds[0], true_a_inds[-1] + 1),
(true_b_inds[0], true_b_inds[-1] + 1))
assert found == expected


@hypothesis.settings(deadline=None)
@hypothesis.given(strax.testutils.several_fake_records)
def test_raw_to_records(r):
    """Copying records into an equally-shaped buffer reproduces them."""
    destination = np.zeros(len(r), r.dtype)
    strax.copy_to_buffer(r, destination, "_test_r_to_buffer")
    if not len(r):
        return
    assert np.all(destination == r)