Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update apply function to data & test #422

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 19 additions & 8 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@
'even when no registered plugin takes them.'),
strax.Option(name='apply_data_function', default=tuple(),
help='Apply a function to the data prior to returning the'
'data. The function should take two positional arguments: '
'func(<data>, <targets>).')
'data. The function should take three positional arguments: '
'func(<data>, <run_id>, <targets>).')
)
@export
class Context:
Expand Down Expand Up @@ -960,6 +960,11 @@ def get_iter(self, run_id: str,
if not isinstance(result, strax.Chunk):
raise ValueError(f"Got type {type(result)} rather than "
f"a strax Chunk from the processor!")
# Apply functions known to contexts if any.
result.data = self._apply_function(result.data,
run_id,
targets)

result.data = self.apply_selection(
result.data,
selection_str=selection_str,
Expand All @@ -968,7 +973,6 @@ def get_iter(self, run_id: str,
time_selection=time_selection)
self._update_progress_bar(
pbar, t_start, t_end, n_chunks, result.end)

yield result
_p.close()

Expand Down Expand Up @@ -1148,7 +1152,6 @@ def get_array(self, run_id: ty.Union[str, tuple, list],
results = [x.data for x in source]

results = np.concatenate(results)
results = self._apply_function(results, targets)
return results

def accumulate(self,
Expand Down Expand Up @@ -1210,7 +1213,6 @@ def function(arr):
for chunk in self.get_iter(run_id, targets,
**kwargs):
data = chunk.data
data = self._apply_function(data, targets)

if n_chunks == 0:
result['start'] = chunk.start
Expand Down Expand Up @@ -1381,17 +1383,26 @@ def _check_forbidden(self):
self.context_config['forbid_creation_of'] = strax.to_str_tuple(
self.context_config['forbid_creation_of'])

def _apply_function(self, data, targets):
def _apply_function(self,
chunk_data: np.ndarray,
run_id: ty.Union[str, tuple, list],
targets: ty.Union[str, tuple, list],
) -> np.ndarray:
"""
Apply functions stored in the context config to any data that is returned via
get_array, get_df or accumulate. Functions stored in
context_config['apply_data_function'] should take exactly three positional
arguments: data, run_id and targets.
:param data: Any type of data
:param run_id: run_id of the data.
:param targets: list/tuple of strings of data type names to get
:return: the data after applying the function(s)
"""
apply_functions = self.context_config['apply_data_function']
if hasattr(apply_functions, '__call__'):
# Apparently someone did not read the docstring and inserted
# a single function instead of a list.
apply_functions = [apply_functions]
if not isinstance(apply_functions, (tuple, list)):
raise ValueError(f"apply_data_function in context config should be tuple of "
f"functions. Instead got {apply_functions}")
Expand All @@ -1401,8 +1412,8 @@ def _apply_function(self, data, targets):
f'{function} but expected callable function with three '
f'positional arguments: f(data, run_id, targets).')
# Make sure that the function takes three arguments (data, run_id and targets)
data = function(data, targets)
return data
chunk_data = function(chunk_data, run_id, targets)
return chunk_data

def copy_to_frontend(self,
run_id: str,
Expand Down
111 changes: 111 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import strax
from strax.testutils import Records, run_id
import tempfile
import numpy as np
from hypothesis import given
import hypothesis.strategies as st
import typing as ty


def _apply_function_to_data(function
                            ) -> ty.Tuple[np.ndarray, np.ndarray]:
    """
    Inner core to test apply function to data.

    :param function: some callable function that takes three positional
        arguments: func(<data>, <run_id>, <targets>)
    :return: tuple of (records, records with the function applied)
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # NB: use a dedicated name for the context. The original local
        # ``st`` shadowed the module-level ``hypothesis.strategies as st``,
        # which would break any later use of strategies in this scope.
        context = strax.Context(storage=strax.DataDirectory(temp_dir,
                                                            deep_scan=True),
                                register=[Records])

        # First load normal records (no apply-function configured)
        records = context.get_array(run_id, 'records')

        # Next update the context and load again with the function applied
        context.set_context_config({'apply_data_function': function})
        changed_records = context.get_array(run_id, 'records')
    return records, changed_records


def test_apply_pass_to_data():
    """
    Check that applying an identity function leaves the loaded data
    completely untouched.

    :return: None
    """

    def identity(data, run_id, targets):
        """Return the data unchanged."""
        return data

    original, processed = _apply_function_to_data(identity)
    assert np.all(original == processed)


@given(st.integers(min_value=-10, max_value=10))
def test_apply_ch_shift_to_data(magic_shift: int):
    """
    Shift the channel field by an arbitrary offset via an apply-function
    and verify the loaded data reflects exactly that shift.

    :param magic_shift: offset to add to the channel field
    :return: None
    """

    def shift_channel(data, run_id, targets):
        """Return a copy of the data with the channel field offset."""
        shifted = data.copy()
        shifted['channel'] += magic_shift
        return shifted

    original, changed = _apply_function_to_data(shift_channel)
    assert len(original) == len(changed)
    assert np.all((changed['channel'] - (original['channel'] + magic_shift)) == 0)


def test_apply_drop_data():
    """
    Drop a random subset of rows via an apply-function and check that
    exactly the rows flagged as kept are returned.

    :return: None
    """

    class Drop:
        """Small class to keep track of the number of dropped rows"""

        def __init__(self):
            # Instance attribute. The original class-level ``kept = []``
            # was a mutable class attribute shared between all instances
            # (and mutated via ``+=``) - a classic Python pitfall.
            self.kept = []

        def drop(self, data, run_id, targets):
            """Drop a random portion of the data"""
            # Random boolean mask; correctness only requires that the
            # bookkeeping below matches whatever was actually kept.
            keep = np.random.randint(0, 2, len(data)).astype(np.bool_)

            # Keep in mind that we do this on a per chunk basis!
            self.kept.append(keep)
            return data.copy()[keep]

    # Init the bookkeeping class
    dropper = Drop()
    r, r_changed = _apply_function_to_data(dropper.drop)

    # The number of records should be consistent with the masks we kept
    assert np.all(r[np.concatenate(dropper.kept)] == r_changed)


def test_accumulate():
    """
    Test the accumulate function. Should add the results and accumulate
    per chunk. Let's add channels and verify the results are correct.

    :return: None
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        context = strax.Context(
            storage=strax.DataDirectory(temp_dir, deep_scan=True),
            register=[Records])
        expected = np.sum(context.get_array(run_id, 'records')['channel'])
        accumulated = context.accumulate(run_id,
                                         'records',
                                         fields='channel')['channel']
    assert expected == accumulated