Allow creation and storing of superruns if SaveWhen > 0 #509

Merged
5 commits, merged Aug 20, 2021
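
A minimal usage sketch of what this PR enables (illustrative only: st, the run ids and the superrun name are made up, and a context configured as in tests/test_superruns.py is assumed):

st.set_context_config({'write_superruns': True})
st.define_run('_my_superrun', ['run_0', 'run_1'])  # superrun names start with an underscore
st.make('_my_superrun', 'peaks')  # processes the subruns and, with this PR, also stores the superrun itself
assert st.is_stored('_my_superrun', 'peaks')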
11 changes: 7 additions & 4 deletions strax/context.py
@@ -675,7 +675,6 @@ def get_components(self, run_id: str,
"""Return components for setting up a processor
{get_docs}
"""

save = strax.to_str_tuple(save)
targets = strax.to_str_tuple(targets)

@@ -725,7 +724,7 @@ def check_cache(target_i):
# we have to deactivate the storage converter mode.
stc_mode = self.context_config['storage_converter']
self.context_config['storage_converter'] = False
self.make(list(sub_run_spec.keys()), target_i)
self.make(list(sub_run_spec.keys()), target_i, save=(target_i,))
self.context_config['storage_converter'] = stc_mode

ldrs = []
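
Passing save=(target_i,) matters because the subrun target may come from a SaveWhen.EXPLICIT plugin: without it the subrun data would never be written, and the loaders collected into ldrs below could not find it. The same keyword is available on make() for users; an illustrative call with a made-up run id:

st.make('run_0', 'peaks_extension', save=('peaks_extension',))  # force-store an EXPLICIT data type for a single run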
@@ -743,7 +742,8 @@ def check_cache(target_i):
if not loader:
raise RuntimeError(
f"Could not load {target_i} for subrun {subrun} "
f"even though we made it??")
"even though we made it? Is the plugin "
"you are requesting a SaveWhen.NEVER-plguin?")
ldrs.append(loader)

def concat_loader(*args, **kwargs):
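
For reference, such a concatenating loader can be sketched as a generator that chains the per-subrun loaders in order (a simplification, not necessarily the verbatim strax code):

def concat_loader(*args, **kwargs):
    # yield every chunk of every subrun loader, in subrun order
    for loader in ldrs:
        yield from loader(*args, **kwargs)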
@@ -776,6 +776,7 @@ def concat_loader(*args, **kwargs):
raise strax.DataNotAvailable(
f"{target_i} for {run_id} not found in any storage, and "
"your context specifies it cannot be created.")

to_compute[target_i] = target_plugin
for dep_d in target_plugin.depends_on:
check_cache(dep_d)
@@ -797,7 +798,9 @@ def concat_loader(*args, **kwargs):
if target_i not in targets:
return
elif target_plugin.save_when == strax.SaveWhen.EXPLICIT:
if target_i not in save:
# If we arrive here for a superrun, the user wants to save it,
# since self.context_config['write_superruns'] is True.
if target_i not in save and not _is_superrun:
return
else:
assert target_plugin.save_when == strax.SaveWhen.ALWAYS
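
For orientation, the saver gating changed in the last hunk can be summarized as a small predicate (a simplified sketch with illustrative names, not the verbatim context code):

import strax

def should_attach_saver(save_when, target, targets, save, is_superrun):
    # Per the inline comment above, is_superrun only reaches this point when the
    # context was configured with 'write_superruns': True.
    if save_when == strax.SaveWhen.NEVER:
        return False
    if save_when == strax.SaveWhen.TARGET:
        return target in targets
    if save_when == strax.SaveWhen.EXPLICIT:
        # New in this PR: superrun data is written even when not listed in save
        return target in save or is_superrun
    return True  # strax.SaveWhen.ALWAYS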
45 changes: 43 additions & 2 deletions tests/test_superruns.py
@@ -24,7 +24,7 @@ def setUp(self, superrun_name='_superrun_test'):
provide_run_metadata=True,
readonly=False,
deep_scan=True)],
register=[Records, RecordsExtension, Peaks],
register=[Records, RecordsExtension, Peaks, PeaksExtension],
config={'bonus_area': 42}
)
self.context.set_context_config({'write_superruns': True,
@@ -213,7 +213,24 @@ def test_superrun_triggers_subrun_processing(self):
self.context.make(self.superrun_name, 'peaks')
assert self.context.is_stored(self.superrun_name, 'peaks')
assert self.context.is_stored(self.subrun_ids[0], 'peaks')


def test_superruns_and_save_when(self):
"""
Tests that only the highest data level is stored for SaveWhen.EXPLICIT plugins.
"""
assert not self.context.is_stored(self.superrun_name, 'peaks')
assert not self.context.is_stored(self.subrun_ids[0], 'peaks')
assert not self.context.is_stored(self.subrun_ids[0], 'peaks_extension')
assert not self.context.is_stored(self.superrun_name, 'peaks_extension')
self.context._plugin_class_registry['peaks'].save_when = strax.SaveWhen.EXPLICIT

self.context.make(self.superrun_name, 'peaks_extension')
assert not self.context.is_stored(self.superrun_name, 'peaks')
assert not self.context.is_stored(self.subrun_ids[0], 'peaks')
assert self.context.is_stored(self.superrun_name, 'peaks_extension')
assert self.context.is_stored(self.subrun_ids[0], 'peaks_extension')


def tearDown(self):
if os.path.exists(self.tempdir):
shutil.rmtree(self.tempdir)
@@ -267,3 +284,27 @@ def compute(self, records):
res['dt'] = records['dt']
res['additional_field'] = self.config['some_additional_value']
return res


@strax.takes_config(
strax.Option(
name='some_additional_peak_value',
default=42,
help="Some additional value for merger",
)
)
class PeaksExtension(strax.Plugin):

depends_on = 'peaks'
provides = 'peaks_extension'
save_when = strax.SaveWhen.EXPLICIT
dtype = strax.time_dt_fields + [(('Some additional field', 'some_additional_peak_field'), np.int16)]

def compute(self, peaks):

res = np.zeros(len(peaks), self.dtype)
res['time'] = peaks['time']
res['length'] = peaks['length']
res['dt'] = peaks['dt']
res['some_additional_peak_field'] = self.config['some_additional_peak_value']
return res