Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speedup get_source with lookupdict #599

Merged
merged 9 commits into from
Dec 9, 2021
Merged
90 changes: 72 additions & 18 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,37 +1666,91 @@ def get_source(self,
start processing from, if no base is available, return None.

:param run_id: run_id
:param target: target
:param target: target
:param check_forbidden: Check that we are not requesting to make
a plugin that is forbidden by the context to be created.
:return: set of plugin names that are needed to start processing
from and are needed in order to build this target.
"""
if self.is_stored(run_id, target):
return {target}

deps = strax.to_str_tuple(self._plugin_class_registry[target].depends_on)
if not deps:
try:
return set(plugin_name
for plugin_name, plugin_stored in
self.stored_dependencies(run_id=run_id,
target=target,
check_forbidden=check_forbidden
).items()
if plugin_stored
)
except strax.DataNotAvailable:
return None

def stored_dependencies(self,
                        run_id: str,
                        target: ty.Union[str, list, tuple],
                        check_forbidden: bool = True,
                        _targets_stored: ty.Optional[dict] = None,
                        ) -> ty.Optional[dict]:
    """
    For a given run_id and target(s), get a dictionary of all the data
    types that are required to build the target(s), mapped to whether
    each of them is stored.

    The dependency tree is walked recursively. A branch is not explored
    any deeper once a stored data type is found, so the dependencies of
    a stored data type do not appear in the result unless they are
    reached through another branch.

    :param run_id: run_id
    :param target: target or a list of targets
    :param check_forbidden: Check that we are not requesting to make
        a plugin that is forbidden by the context to be created.
    :param _targets_stored: internal accumulator shared between the
        recursive calls (data type -> is stored). It doubles as a
        lookup table so that every data type is checked only once.
        Callers should normally leave this at None.
    :return: dictionary of data types (keys) required for building
        the requested target(s) and if they are stored (values)
    :raises strax.DataNotAvailable: if there is at least one data
        type that is not stored and has no dependency or if it
        cannot be created
    """
    if _targets_stored is None:
        _targets_stored = dict()

    targets = strax.to_str_tuple(target)
    if len(targets) > 1:
        # Multiple targets, do them all. Each call fills the shared
        # accumulator in place, so the return values are not needed.
        for dep in targets:
            self.stored_dependencies(run_id,
                                     dep,
                                     check_forbidden=check_forbidden,
                                     _targets_stored=_targets_stored,
                                     )
        return _targets_stored

    # Make sure we have the string not ('target',)
    target = targets[0]

    if target in _targets_stored:
        # Already handled via another branch of the dependency tree.
        # Return the accumulator (rather than None) so that every
        # non-raising exit of this method yields the same object.
        return _targets_stored

    this_target_is_stored = self.is_stored(run_id, target)
    _targets_stored[target] = this_target_is_stored

    if this_target_is_stored:
        # Processing can start from here, no need to look any deeper
        return _targets_stored

    # Need to init the class e.g. if we want to allow depends on like this:
    # https://github.com/XENONnT/cutax/blob/d7ec0685650d03771fef66507fd6882676151b9b/cutax/cutlist.py#L33  # noqa
    plugin = self._plugin_class_registry[target]()
    dependencies = strax.to_str_tuple(plugin.depends_on)
    if not dependencies:
        raise strax.DataNotAvailable(f'Lowest level dependency {target} is not stored')

    forbidden = strax.to_str_tuple(self.context_config['forbid_creation_of'])
    forbidden_warning = (
        'For {run_id}:{target}, you are not allowed to make {dep} and '
        'it is not stored. Disable check with check_forbidden=False'
    )
    if check_forbidden and target in forbidden:
        # Not stored and not allowed to be made: nothing to start from
        raise strax.DataNotAvailable(
            forbidden_warning.format(run_id=run_id, target=target, dep=target,))

    # Not stored but can be made: continue with whatever it depends on
    self.stored_dependencies(run_id,
                             target=dependencies,
                             check_forbidden=check_forbidden,
                             _targets_stored=_targets_stored,
                             )
    return _targets_stored

def _is_stored_in_sf(self, run_id, target,
storage_frontend: strax.StorageFrontend) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def test_get_source(self):
# since we cannot make peaks!
assert st.get_source(run_id, 'peaks', check_forbidden=True) is None
assert st.get_source(run_id, 'cut_peaks', check_forbidden=True) is None
assert st.get_source(run_id, ('peaks', 'cut_peaks'), check_forbidden=True) is None

# We could ignore the error though
assert st.get_source(run_id, 'peaks', check_forbidden=False) == {'records'}
Expand All @@ -304,8 +305,9 @@ def test_get_source(self):
st.set_context_config({'forbid_creation_of': ()})
st.make(run_id, 'peaks')
assert st.get_source(run_id, 'records') == {'records'}
assert st.get_source(run_id, ('records', 'peaks')) == {'records', 'peaks'}
assert st.get_source(run_id, 'peaks') == {'peaks'}
assert st.get_source(run_id, 'peaks') == {'peaks'}
assert st.get_source(run_id, 'cut_peaks') == {'peaks'}

@staticmethod
def get_dummy_peaks_dependency():
Expand Down