Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple targets to be computed simultaneously #408

Merged
merged 6 commits into from
Mar 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 44 additions & 14 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def register_all(self, module):

for x in dir(module):
x = getattr(module, x)
if type(x) != type(type):
if not isinstance(x, type(type)):
continue
if issubclass(x, strax.Plugin):
self.register(x)
Expand Down Expand Up @@ -525,6 +525,19 @@ def get_plugin(data_kind):

return plugins

@staticmethod
def _get_end_targets(plugins: dict) -> ty.Tuple[str]:
    """
    Return the data types that are provided by a plugin in ``plugins``
    but not depended on by any other plugin in it — i.e. the "end
    points" of the dependency graph that no further computation consumes.

    :param plugins: mapping of data-type name -> plugin instance
    :return: tuple of end-target data-type names
    """
    provides = [prov for p in plugins.values()
                for prov in strax.to_str_tuple(p.provides)]
    depends_on = [dep for p in plugins.values()
                  for dep in strax.to_str_tuple(p.depends_on)]
    # Use a set difference, not a symmetric difference: a dependency
    # that is loaded from storage (so its plugin is absent from
    # ``plugins``) appears in depends_on but not provides, and a
    # symmetric difference would wrongly report it as an end target.
    uniques = list(set(provides) - set(depends_on))
    return strax.to_str_tuple(uniques)

@property
def _find_options(self):

Expand Down Expand Up @@ -570,7 +583,7 @@ def _get_partial_loader_for(self, key, time_range=None, chunk_number=None):

def get_components(self, run_id: str,
targets=tuple(), save=tuple(),
time_range=None, chunk_number=None
time_range=None, chunk_number=None,
) -> strax.ProcessorComponents:
"""Return components for setting up a processor
{get_docs}
Expand All @@ -579,15 +592,11 @@ def get_components(self, run_id: str,
save = strax.to_str_tuple(save)
targets = strax.to_str_tuple(targets)

# Although targets is a tuple, we only support one target at the moment
# we could just make it a string!
assert len(targets) == 1, f"Found {len(targets)} instead of 1 target"
if len(targets[0]) == 1:
raise ValueError(
f"Plugin names must be more than one letter, not {targets[0]}")
for t in targets:
if len(t) == 1:
raise ValueError(f"Plugin names must be more than one letter, not {t}")

plugins = self._get_plugins(targets, run_id)
target = targets[0] # See above, already restricted to one target

# Get savers/loaders, and meanwhile filter out plugins that do not
# have to do computation. (their instances will stick around
Expand Down Expand Up @@ -761,7 +770,14 @@ def concat_loader(*args, **kwargs):
intersec = list(plugins.keys() & loaders.keys())
if len(intersec):
raise RuntimeError(f"{intersec} both computed and loaded?!")

if len(targets) > 1:
final_plugin = self._get_end_targets(plugins)[:1]
self.log.warning(
f'Multiple targets detected! This is only suitable for mass '
f'producing datatypes since only {final_plugin} will be '
f'subscribed in the mailbox system!')
else:
final_plugin = targets
# For the plugins which will run computations,
# check all required options are available or set defaults.
# Also run any user-defined setup
Expand All @@ -772,7 +788,7 @@ def concat_loader(*args, **kwargs):
plugins=plugins,
loaders=loaders,
savers=dict(savers),
targets=targets)
targets=strax.to_str_tuple(final_plugin))

def estimate_run_start_and_end(self, run_id, targets=None):
"""Return run start and end time in ns since epoch.
Expand Down Expand Up @@ -865,8 +881,9 @@ def get_iter(self, run_id: str,
time_selection='fully_contained',
selection_str=None,
keep_columns=None,
_chunk_number=None,
allow_multiple=False,
progress_bar=True,
_chunk_number=None,
**kwargs) -> ty.Iterator[strax.Chunk]:
"""Compute target for run_id and iterate over results.

Expand Down Expand Up @@ -899,7 +916,7 @@ def get_iter(self, run_id: str,
dict(depends_on=tuple(targets)))
self.register(p)
targets = (temp_name,)
else:
elif not allow_multiple:
raise RuntimeError("Cannot automerge different data kinds!")

components = self.get_components(run_id,
Expand Down Expand Up @@ -1105,6 +1122,10 @@ def get_array(self, run_id: ty.Union[str, tuple, list],
{get_docs}
"""
run_ids = strax.to_str_tuple(run_id)

if kwargs.get('allow_multiple', False):
raise RuntimeError('Cannot allow_multiple with get_array/get_df')

if len(run_ids) > 1:
results = strax.multi_run(
self.get_array, run_ids, targets=targets,
Expand Down Expand Up @@ -1165,6 +1186,9 @@ def accumulate(self,
n_chunks: number of chunks in run
n_rows: number of data entries in run
"""
if kwargs.get('allow_multiple', False):
raise RuntimeError('Cannot allow_multiple with accumulate')

n_chunks = 0
seen_data = False
result = {'n_rows': 0}
Expand All @@ -1175,7 +1199,8 @@ def function(arr):
return arr
function_takes_fields = False

for chunk in self.get_iter(run_id, targets, **kwargs):
for chunk in self.get_iter(run_id, targets,
**kwargs):
data = chunk.data
data = self._apply_function(data, targets)

Expand Down Expand Up @@ -1511,6 +1536,11 @@ def add_method(cls, f):
Many plugins save automatically anyway.
:param max_workers: Number of worker threads/processes to spawn.
In practice more CPUs may be used due to strax's multithreading.
:param allow_multiple: Allow multiple targets to be computed
simultaneously without merging the results of the target. This can
be used when mass producing plugins that are not of the same
datakind. Don't try to use this in get_array or get_df because the
data is not returned.
""" + select_docs


Expand Down
41 changes: 41 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,44 @@ def test_superrun():
np.testing.assert_array_equal(p1['area'], np.zeros(len(p1)))
np.testing.assert_array_equal(p2['area'], np.zeros(len(p2)))
np.testing.assert_array_equal(ps, np.concatenate([p1, p2]))


def test_allow_multiple(targets=('peaks', 'records')):
    """Test that allow_multiple=True lets several targets of different
    data kinds be computed at once, and that every entry point that
    cannot support it refuses to do so.

    :param targets: tuple of data-type names to produce simultaneously
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        mystrax = strax.Context(storage=strax.DataDirectory(temp_dir,
                                                            deep_scan=True),
                                register=[Records, Peaks])
        assert not mystrax.is_stored(run_id, 'peaks')

        # get_array and get_df cannot return data for multiple targets,
        # so they must raise even when allow_multiple is requested.
        for function in [mystrax.get_array, mystrax.get_df]:
            try:
                function(run_id=run_id,
                         allow_multiple=True,
                         targets=targets)
            except RuntimeError:
                # Great, this doesn't work (and it shouldn't!)
                continue
            raise ValueError(f'{function} could run with allow_multiple')

        # Without allow_multiple, multiple targets of different kinds
        # must be refused. Fail explicitly if no exception is raised;
        # a bare try/except-pass would silently hide that regression.
        try:
            mystrax.make(run_id=run_id,
                         targets=targets)
        except RuntimeError:
            # Great, we shouldn't be allowed
            pass
        else:
            raise ValueError(
                'make() accepted multiple targets without allow_multiple')

        assert not mystrax.is_stored(run_id, 'peaks')
        mystrax.make(run_id=run_id,
                     allow_multiple=True,
                     targets=targets)

        # Every requested target must now be on disk.
        for t in targets:
            assert mystrax.is_stored(run_id, t)


def test_allow_multiple_inverted():
    """Re-run the multi-target test with the target order reversed.

    Here the first target actually depends on the second, so the
    processing must subscribe the first target as the endpoint of the
    processing chain.
    """
    test_allow_multiple(targets=('records', 'peaks',))