Skip to content

Commit

Permalink
Merge pull request #304 from jorana/apply_function_to_data
Browse files Browse the repository at this point in the history
Apply function to data
  • Loading branch information
JoranAngevaare committed Aug 17, 2020
2 parents 0c85fd4 + 55e0110 commit 9dc02cf
Showing 1 changed file with 37 additions and 2 deletions.
39 changes: 37 additions & 2 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@
'possibly be removed, see issue #246'),
strax.Option(name='free_options', default=tuple(),
help='Do not warn if any of these options are passed, '
'even when no registered plugin takes them.')
'even when no registered plugin takes them.'),
strax.Option(name='apply_data_function', default=tuple(),
help='Apply a function to the data prior to returning the'
'data. The function should take two positional arguments: '
'func(<data>, <targets>).')
)
@export
class Context:
Expand Down Expand Up @@ -1058,7 +1062,10 @@ def get_array(self, run_id: ty.Union[str, tuple, list],
max_workers=max_workers,
**kwargs)
results = [x.data for x in source]
return np.concatenate(results)

results = np.concatenate(results)
results = self._apply_function(results, targets)
return results

def accumulate(self,
run_id: ty.Union[str, tuple, list],
Expand All @@ -1079,6 +1086,10 @@ def accumulate(self,
* None -> nothing accumulated
If not provided, the identity function is used.
NB: Additionally and independently, if there are any functions registered
under context_config['apply_data_function'] these are applied first directly
after loading the data.
:param fields: Fields of the function output to accumulate.
If not provided, all output fields will be accumulated.
Expand Down Expand Up @@ -1111,6 +1122,7 @@ def function(arr):

for chunk in self.get_iter(run_id, targets, **kwargs):
data = chunk.data
data = self._apply_function(data, targets)

if n_chunks == 0:
result['start'] = chunk.start
Expand Down Expand Up @@ -1283,6 +1295,29 @@ def _check_forbidden(self):
Otherwise, try to make it a tuple"""
self.context_config['forbid_creation_of'] = strax.to_str_tuple(
self.context_config['forbid_creation_of'])

def _apply_function(self, data, targets):
    """
    Apply the functions stored in the context config to data before it is
    returned (e.g. from get_array or accumulate). Functions stored in
    context_config['apply_data_function'] should take exactly two positional
    arguments: data and targets.

    The functions are applied in the order in which they are stored; each
    one receives the output of the previous one as its data argument.

    :param data: Any type of data
    :param targets: list/tuple of strings of data type names to get
    :return: the data after applying the function(s)
    :raises ValueError: if context_config['apply_data_function'] is not a
        tuple or list
    :raises TypeError: if one of the stored entries is not callable
    """
    apply_functions = self.context_config['apply_data_function']
    if not isinstance(apply_functions, (tuple, list)):
        raise ValueError(f"apply_data_function in context config should be tuple of "
                         f"functions. Instead got {apply_functions}")
    for function in apply_functions:
        # callable() looks at the type, so an object that merely carries an
        # instance attribute named '__call__' is rejected here with a clear
        # message instead of failing later at the call site.
        if not callable(function):
            raise TypeError(f'apply_data_function in the context_config got '
                            f'{function} but expected callable function with two '
                            f'positional arguments: f(data, targets).')
        # The signature is not inspected: a function that does not accept
        # (data, targets) raises its own TypeError when called below.
        data = function(data, targets)
    return data

@classmethod
def add_method(cls, f):
Expand Down

0 comments on commit 9dc02cf

Please sign in to comment.