From 72f935cdeb8db308847271c0625fac3f6eeb8a19 Mon Sep 17 00:00:00 2001 From: Daniel Wenz <43881800+WenzDaniel@users.noreply.github.com> Date: Tue, 21 Nov 2023 14:58:00 +0100 Subject: [PATCH] Fix caching issue (#768) * Fix caching issue * Minor fixes * Change back order of arguements * Move lineage into an additional function * Speed up caching. Only load requested plugins * Increase safety counter add raise * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: dachengx Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- strax/context.py | 210 ++++++++++++++++++++++++++--------------------- 1 file changed, 116 insertions(+), 94 deletions(-) diff --git a/strax/context.py b/strax/context.py index 5b69ba13a..a0bdda375 100644 --- a/strax/context.py +++ b/strax/context.py @@ -677,7 +677,11 @@ def _fix_dependency(self, plugin_registry: dict, end_plugin: str): self._fix_dependency(plugin_registry, go_to) plugin_registry[end_plugin].fix_dtype() - def __get_plugins_from_cache(self, run_id: str) -> ty.Dict[str, strax.Plugin]: + def __get_requested_plugins_from_cache( + self, + run_id: str, + targets: ty.Tuple[str], + ) -> ty.Dict[str, strax.Plugin]: # Doubly underscored since we don't do any key-checks etc here """Load requested plugins from the plugin_cache.""" requested_plugins = {} @@ -703,15 +707,20 @@ def __get_plugins_from_cache(self, run_id: str) -> ty.Dict[str, strax.Plugin]: plugin.deps = { dependency: requested_plugins[dependency] for dependency in plugin.depends_on } + # Finally, fix the dtype. Since infer_dtype may depend on the # entire deps chain, we need to start at the last plugin and go # all the way down to the lowest level. - for final_plugins in self._get_end_targets(requested_plugins): - self._fix_dependency(requested_plugins, final_plugins) + for target_plugins in targets: + self._fix_dependency(requested_plugins, target_plugins) + + requested_plugins = {i: v for i, v in requested_plugins.items() if i in targets} return requested_plugins def _get_plugins( - self, targets: ty.Union[ty.Tuple[str], ty.List[str]], run_id: str + self, + targets: ty.Union[ty.Tuple[str], ty.List[str]], + run_id: str, ) -> ty.Dict[str, strax.Plugin]: """Return dictionary of plugin instances necessary to compute targets from scratch. @@ -719,21 +728,6 @@ def _get_plugins( referenced under multiple keys in the output dict. """ - if self._plugins_are_cached(targets): - cached_plugins = self.__get_plugins_from_cache(run_id) - plugins = {} - targets = list(targets) - while targets: - target = targets.pop(0) - if target in plugins: - continue - - target_plugin = cached_plugins[target] - for provides in target_plugin.provides: - plugins[provides] = target_plugin - targets += list(target_plugin.depends_on) - return plugins - # Check all config options are taken by some registered plugin class # (helps spot typos) all_opts = set().union(*[ @@ -743,98 +737,126 @@ def _get_plugins( if not (k in all_opts or k in self.context_config["free_options"]): self.log.warning(f"Option {k} not taken by any registered plugin") - # Initialize plugins for the entire computation graph - # (most likely far further down than we need) - # to get lineages and dependency info. - def get_plugin(data_type): - nonlocal non_local_plugins + plugins = {} + targets = list(targets) + safety_counter = 0 + while targets and safety_counter < 10_000: + safety_counter += 1 + targets = list(set(targets)) # Remove duplicates from list. + target = targets.pop(0) + if target in plugins: + continue - if data_type not in self._plugin_class_registry: - raise KeyError(f"No plugin class registered that provides {data_type}") + target_plugin = self.__get_plugin(run_id, target) + for provides in target_plugin.provides: + plugins[provides] = target_plugin + targets += list(target_plugin.depends_on) - plugin = self._plugin_class_registry[data_type]() + _not_all_plugins_initalized = (safety_counter == 10_000) & len(targets) + if _not_all_plugins_initalized: + raise ValueError( + "Could not initalize all plugins to compute target from scratch. " + f"The reamining targets missing are: {targets}" + ) - d_provides = None # just to make codefactor happy - for d_provides in plugin.provides: - non_local_plugins[d_provides] = plugin + return plugins - plugin.run_id = run_id + def __get_plugin(self, run_id: str, data_type: str): + """Get single plugin either from cache or initialize it.""" + # Check if plugin for data_type is already cached + if self._plugins_are_cached((data_type,)): + cached_plugins = self.__get_requested_plugins_from_cache(run_id, (data_type,)) + target_plugin = cached_plugins[data_type] + return target_plugin - # The plugin may not get all the required options here - # but we don't know if we need the plugin yet - self._set_plugin_config(plugin, run_id, tolerant=True) + if data_type not in self._plugin_class_registry: + raise KeyError(f"No plugin class registered that provides {data_type}") - plugin.deps = {d_depends: get_plugin(d_depends) for d_depends in plugin.depends_on} + plugin = self._plugin_class_registry[data_type]() - last_provide = d_provides + plugin.run_id = run_id - if plugin.child_plugin: - # Plugin is a child of another plugin, hence we have to - # drop the parents config from the lineage - configs = {} + # The plugin may not get all the required options here + # but we don't know if we need the plugin yet + self._set_plugin_config(plugin, run_id, tolerant=True) - # Getting information about the parent: - parent_class = plugin.__class__.__bases__[0] - # Get all parent options which are overwritten by a child: - parent_options = [ - option.parent_option_name - for option in plugin.takes_config.values() - if option.child_option - ] + plugin.deps = { + d_depends: self.__get_plugin(run_id, d_depends) for d_depends in plugin.depends_on + } - for option_name, v in plugin.config.items(): - # Looping over all settings, option_name is either the option name of the - # parent or the child. - if option_name in parent_options: - # In case it is the parent we continue - continue + self.__add_lineage_to_plugin(run_id, plugin) - if plugin.takes_config[option_name].track: - # Add all options which should be tracked: - configs[option_name] = v + if not hasattr(plugin, "data_kind") and not plugin.multi_output: + if len(plugin.depends_on): + # Assume data kind is the same as the first dependency + first_dep = plugin.depends_on[0] + plugin.data_kind = plugin.deps[first_dep].data_kind_for(first_dep) + else: + # No dependencies: assume provided data kind and + # data type are synonymous + plugin.data_kind = plugin.provides[0] - # Also adding name and version of the parent to the lineage: - configs[parent_class.__name__] = parent_class.__version__ + plugin.fix_dtype() - plugin.lineage = { - last_provide: (plugin.__class__.__name__, plugin.version(run_id), configs) - } - else: - plugin.lineage = { - last_provide: ( - plugin.__class__.__name__, - plugin.version(run_id), - { - option: setting - for option, setting in plugin.config.items() - if plugin.takes_config[option].track - }, - ) - } - for d_depends in plugin.depends_on: - plugin.lineage.update(plugin.deps[d_depends].lineage) - - if not hasattr(plugin, "data_kind") and not plugin.multi_output: - if len(plugin.depends_on): - # Assume data kind is the same as the first dependency - first_dep = plugin.depends_on[0] - plugin.data_kind = plugin.deps[first_dep].data_kind_for(first_dep) - else: - # No dependencies: assume provided data kind and - # data type are synonymous - plugin.data_kind = plugin.provides[0] + # Add plugin to cache + self._plugins_to_cache({data_type: plugin for data_type in plugin.provides}) - plugin.fix_dtype() + return plugin - return plugin + def __add_lineage_to_plugin(self, run_id, plugin): + """Adds lineage to plugin in place. - non_local_plugins = {} - for t in targets: - p = get_plugin(t) - non_local_plugins[t] = p + Also adds parent infromation in case of a child plugin. + + """ + last_provide = [d_provides for d_provides in plugin.provides][-1] + + if plugin.child_plugin: + # Plugin is a child of another plugin, hence we have to + # drop the parents config from the lineage + configs = {} + + # Getting information about the parent: + parent_class = plugin.__class__.__bases__[0] + # Get all parent options which are overwritten by a child: + parent_options = [ + option.parent_option_name + for option in plugin.takes_config.values() + if option.child_option + ] + + for option_name, v in plugin.config.items(): + # Looping over all settings, option_name is either the option name of the + # parent or the child. + if option_name in parent_options: + # In case it is the parent we continue + continue + + if plugin.takes_config[option_name].track: + # Add all options which should be tracked: + configs[option_name] = v + + # Also adding name and version of the parent to the lineage: + configs[parent_class.__name__] = parent_class.__version__ + + plugin.lineage = { + last_provide: (plugin.__class__.__name__, plugin.version(run_id), configs) + } + else: + plugin.lineage = { + last_provide: ( + plugin.__class__.__name__, + plugin.version(run_id), + { + option: setting + for option, setting in plugin.config.items() + if plugin.takes_config[option].track + }, + ) + } - self._plugins_to_cache(non_local_plugins) - return non_local_plugins + for d_depends in plugin.depends_on: + plugin.lineage.update(plugin.deps[d_depends].lineage) def _per_run_default_allowed_check(self, option_name, option): """Check if an option of a registered plugin is allowed."""