Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add funcs for getting stored source #590

Merged
merged 8 commits into from
Nov 18, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 50 additions & 5 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def _context_hash(self):
{data_type: (plugin.__version__, plugin.compressor, plugin.input_timeout)
for data_type, plugin in self._plugin_class_registry.items()
if not data_type.startswith('_temp_')
})
})
return strax.deterministic_hash(base_hash_on_config)

def _plugins_are_cached(self, targets: ty.Tuple[str],) -> bool:
Expand Down Expand Up @@ -1381,10 +1381,9 @@ def get_df(self, run_id: ty.Union[str, tuple, list],
f"array fields. Please use get_array.")
raise


def get_zarr(self, run_ids, targets, storage='./strax_temp_data',
progress_bar=False, overwrite=True, **kwargs):
"""get perisistant arrays using zarr. This is useful when
def get_zarr(self, run_ids, targets, storage='./strax_temp_data',
progress_bar=False, overwrite=True, **kwargs):
"""get persistent arrays using zarr. This is useful when
loading large amounts of data that cannot fit in memory
zarr is very compatible with dask.
Targets are loaded into separate arrays and runs are merged.
Expand Down Expand Up @@ -1656,6 +1655,52 @@ def copy_to_frontend(self,
f'Trying to write {data_key} to {t_sf} which already exists, '
'do you have two storage frontends writing to the same place?')

def get_source(self,
run_id: str,
target: str,
check_forbidden: bool = True,
) -> ty.Union[set, None]:
"""
For a given run_id and target get the stored bases where we can
start processing from, if no base is available, return None.

:param run_id: run_id
:param target: target
:param check_forbidden: Check that we are not requesting to make
a plugin that is forbidden by the context to be created.
:return: set of plugin names that are needed to start processing
from and are needed in order to build this target.
"""
if self.is_stored(run_id, target):
return {target}

deps = strax.to_str_tuple(self._plugin_class_registry[target].depends_on)
if not deps:
return None

forbidden = strax.to_str_tuple(self.context_config['forbid_creation_of'])
if check_forbidden and target in forbidden:
return None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe be more explicit in this case and raise an error or make a warning. Like you have further down.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point!


stored_sources = set()
for dep in deps:
if self.is_stored(run_id, dep):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isnt this redundent? shouldnt the recursive call to self.get_source check if its stored or not?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah great point!

stored_sources |= {dep}
elif check_forbidden and dep in forbidden:
self.log.warning(f'For run {run_id}:{target}, you are not '
f'allowed to make {dep} and it is not stored. '
f'Disable with check_forbidden=False'
)
return None
else:
deeper = self.get_source(run_id, dep, check_forbidden=check_forbidden)
if deeper is None:
self.log.info(f'For run {run_id}, requested dependency '
f'{dep} for {target} is not stored')
return None
stored_sources |= deeper
return stored_sources

def _is_stored_in_sf(self, run_id, target,
storage_frontend: strax.StorageFrontend) -> bool:
"""
Expand Down
44 changes: 40 additions & 4 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tempfile
import numpy as np
from hypothesis import given, settings
import hypothesis.strategies as st
import hypothesis.strategies as strategy
import typing as ty
import os
import unittest
Expand Down Expand Up @@ -49,7 +49,7 @@ def nothing(data, r, t):


@settings(deadline=None)
@given(st.integers(min_value=-10, max_value=10))
@given(strategy.integers(min_value=-10, max_value=10))
def test_apply_ch_shift_to_data(magic_shift: int):
"""
Apply some magic shift number to the channel field and check the results
Expand Down Expand Up @@ -178,7 +178,7 @@ def test_copy_to_frontend():
)


class TestPerRunDefaults(unittest.TestCase):
class TestContext(unittest.TestCase):
"""Test the per-run defaults options of a context"""
def setUp(self):
"""
Expand Down Expand Up @@ -244,7 +244,7 @@ def test_deregister(self):
st.register(Records)
st.register(Peaks)
st.deregister_plugins_with_missing_dependencies()
assert all([p in st._plugin_class_registry for p in 'peaks records'.split()])
assert all([p in st._plugin_class_registry for p in 'peaks records'.split()])
st._plugin_class_registry.pop('records', None)
st.deregister_plugins_with_missing_dependencies()
assert st._plugin_class_registry.pop('peaks', None) is None
Expand Down Expand Up @@ -277,3 +277,39 @@ def _has_per_run_default(plugin) -> bool:
# Found one option
break
return has_per_run_defaults

def test_get_source(self):
"""See if we get the correct answer for each of the plugins"""
st = self.get_context(True)
st.register(Records)
st.register(Peaks)
st.register(self.get_dummy_peaks_dependency())
st.set_context_config({'forbid_creation_of': ('peaks',)})
for target in 'records peaks cut_peaks'.split():
# Nothing is available and nothing should find a source
assert not st.is_stored(run_id, target)
assert st.get_source(run_id, target) is None

# Now make a source "records"
st.make(run_id, 'records')
assert st.get_source(run_id, 'records') == {'records'}
# since we cannot make peaks!
assert st.get_source(run_id, 'peaks', check_forbidden=True) is None
assert st.get_source(run_id, 'cut_peaks', check_forbidden=True) is None

# We could ignore the error though
assert st.get_source(run_id, 'peaks', check_forbidden=False) == {'records'}
assert st.get_source(run_id, 'cut_peaks', check_forbidden=False) == {'records'}

st.set_context_config({'forbid_creation_of': ()})
st.make(run_id, 'peaks')
assert st.get_source(run_id, 'records') == {'records'}
assert st.get_source(run_id, 'peaks') == {'peaks'}
assert st.get_source(run_id, 'peaks') == {'peaks'}

@staticmethod
def get_dummy_peaks_dependency():
class DummyDependsOnPeaks(strax.CutPlugin):
depends_on = 'peaks'
provides = 'cut_peaks'
return DummyDependsOnPeaks