Skip to content

Commit

Permalink
Also copy dps and remove redundant checks. (#777)
Browse files Browse the repository at this point in the history
* Also copy dps and remove redundant checks.

* Remove unused function

* Fix infer_dtype propagation
  • Loading branch information
WenzDaniel committed Dec 14, 2023
1 parent 8c54d5b commit b0ca3cb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 30 deletions.
37 changes: 8 additions & 29 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,14 +626,14 @@ def _context_hash(self):
If any item changes in the config, so does this hash.
"""
base_hash_on_config = self.config.copy()
self._base_hash_on_config = self.config.copy()
# Also take into account the versions of the plugins registered
base_hash_on_config.update({
self._base_hash_on_config.update({
data_type: (plugin.__version__, plugin.compressor, plugin.input_timeout)
for data_type, plugin in self._plugin_class_registry.items()
if not data_type.startswith("_temp_")
})
return strax.deterministic_hash(base_hash_on_config)
return strax.deterministic_hash(self._base_hash_on_config)

def _plugins_are_cached(
self,
Expand Down Expand Up @@ -667,16 +667,6 @@ def _plugins_to_cache(self, plugins: dict) -> None:
for target, plugin in plugins.items():
self._fixed_plugin_cache[context_hash][target] = plugin

def _fix_dependency(self, plugin_registry: dict, end_plugin: str):
"""Starting from end-plugin, fix the dtype until there is nothing left to fix.
Keep in mind that dtypes can be chained.
"""
for go_to in plugin_registry[end_plugin].depends_on:
self._fix_dependency(plugin_registry, go_to)
plugin_registry[end_plugin].fix_dtype()

def __get_requested_plugins_from_cache(
self,
run_id: str,
Expand All @@ -685,9 +675,8 @@ def __get_requested_plugins_from_cache(
# Doubly underscored since we don't do any key-checks etc here
"""Load requested plugins from the plugin_cache."""
requested_plugins = {}
for target, plugin in self._fixed_plugin_cache[
self._context_hash()
].items(): # type: ignore
cached_plugins = self._fixed_plugin_cache[self._context_hash()] # type: ignore
for target, plugin in cached_plugins.items():
if target in requested_plugins:
# If e.g. target is already seen because the plugin is
# multi output
Expand All @@ -700,19 +689,9 @@ def __get_requested_plugins_from_cache(
for provides in strax.to_str_tuple(requested_p.provides):
requested_plugins[provides] = requested_p

# At this stage, all the plugins should be in requested_plugins
# To prevent infinite copying, we are only now linking the
# dependencies of each plugin to another where needed.
for target, plugin in requested_plugins.items():
plugin.deps = {
dependency: requested_plugins[dependency] for dependency in plugin.depends_on
}

# Finally, fix the dtype. Since infer_dtype may depend on the
# entire deps chain, we need to start at the last plugin and go
# all the way down to the lowest level.
for target_plugins in targets:
self._fix_dependency(requested_plugins, target_plugins)
# Finally, fix the dtype.
for plugin in requested_plugins.values():
plugin.fix_dtype()

requested_plugins = {i: v for i, v in requested_plugins.items() if i in targets}
return requested_plugins
Expand Down
10 changes: 9 additions & 1 deletion strax/plugins/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,15 @@ def __copy__(self, _deep_copy=False):
# As explained in PR #485 only copy attributes whereof we know
# don't depend on the run_id (for use_per_run_defaults == False).
# Otherwise we might copy run-dependent things like to_pe.
for attribute in ["dtype", "lineage", "takes_config", "__version__", "config", "data_kind"]:
for attribute in [
"dtype",
"lineage",
"takes_config",
"__version__",
"config",
"data_kind",
"deps",
]:
source_value = getattr(self, attribute)
if _deep_copy:
plugin_copy.__setattr__(attribute, deepcopy(source_value))
Expand Down

0 comments on commit b0ca3cb

Please sign in to comment.