Skip to content

Commit

Permalink
Fix caching issue (#768)
Browse files Browse the repository at this point in the history
* Fix caching issue

* Minor fixes

* Change back order of arguements

* Move lineage into an additional function

* Speed up caching. Only load requested plugins

* Increase safety counter add raise

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: dachengx <dx2227@columbia.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 21, 2023
1 parent 672dacb commit 72f935c
Showing 1 changed file with 116 additions and 94 deletions.
210 changes: 116 additions & 94 deletions strax/context.py
Expand Up @@ -677,7 +677,11 @@ def _fix_dependency(self, plugin_registry: dict, end_plugin: str):
self._fix_dependency(plugin_registry, go_to)
plugin_registry[end_plugin].fix_dtype()

def __get_plugins_from_cache(self, run_id: str) -> ty.Dict[str, strax.Plugin]:
def __get_requested_plugins_from_cache(
self,
run_id: str,
targets: ty.Tuple[str],
) -> ty.Dict[str, strax.Plugin]:
# Doubly underscored since we don't do any key-checks etc here
"""Load requested plugins from the plugin_cache."""
requested_plugins = {}
Expand All @@ -703,37 +707,27 @@ def __get_plugins_from_cache(self, run_id: str) -> ty.Dict[str, strax.Plugin]:
plugin.deps = {
dependency: requested_plugins[dependency] for dependency in plugin.depends_on
}

# Finally, fix the dtype. Since infer_dtype may depend on the
# entire deps chain, we need to start at the last plugin and go
# all the way down to the lowest level.
for final_plugins in self._get_end_targets(requested_plugins):
self._fix_dependency(requested_plugins, final_plugins)
for target_plugins in targets:
self._fix_dependency(requested_plugins, target_plugins)

requested_plugins = {i: v for i, v in requested_plugins.items() if i in targets}
return requested_plugins

def _get_plugins(
self, targets: ty.Union[ty.Tuple[str], ty.List[str]], run_id: str
self,
targets: ty.Union[ty.Tuple[str], ty.List[str]],
run_id: str,
) -> ty.Dict[str, strax.Plugin]:
"""Return dictionary of plugin instances necessary to compute targets from scratch.
For a plugin that produces multiple outputs, we make only a single instance, which is
referenced under multiple keys in the output dict.
"""
if self._plugins_are_cached(targets):
cached_plugins = self.__get_plugins_from_cache(run_id)
plugins = {}
targets = list(targets)
while targets:
target = targets.pop(0)
if target in plugins:
continue

target_plugin = cached_plugins[target]
for provides in target_plugin.provides:
plugins[provides] = target_plugin
targets += list(target_plugin.depends_on)
return plugins

# Check all config options are taken by some registered plugin class
# (helps spot typos)
all_opts = set().union(*[
Expand All @@ -743,98 +737,126 @@ def _get_plugins(
if not (k in all_opts or k in self.context_config["free_options"]):
self.log.warning(f"Option {k} not taken by any registered plugin")

# Initialize plugins for the entire computation graph
# (most likely far further down than we need)
# to get lineages and dependency info.
def get_plugin(data_type):
nonlocal non_local_plugins
plugins = {}
targets = list(targets)
safety_counter = 0
while targets and safety_counter < 10_000:
safety_counter += 1
targets = list(set(targets)) # Remove duplicates from list.
target = targets.pop(0)
if target in plugins:
continue

if data_type not in self._plugin_class_registry:
raise KeyError(f"No plugin class registered that provides {data_type}")
target_plugin = self.__get_plugin(run_id, target)
for provides in target_plugin.provides:
plugins[provides] = target_plugin
targets += list(target_plugin.depends_on)

plugin = self._plugin_class_registry[data_type]()
_not_all_plugins_initalized = (safety_counter == 10_000) & len(targets)
if _not_all_plugins_initalized:
raise ValueError(
"Could not initalize all plugins to compute target from scratch. "
f"The reamining targets missing are: {targets}"
)

d_provides = None # just to make codefactor happy
for d_provides in plugin.provides:
non_local_plugins[d_provides] = plugin
return plugins

plugin.run_id = run_id
def __get_plugin(self, run_id: str, data_type: str):
"""Get single plugin either from cache or initialize it."""
# Check if plugin for data_type is already cached
if self._plugins_are_cached((data_type,)):
cached_plugins = self.__get_requested_plugins_from_cache(run_id, (data_type,))
target_plugin = cached_plugins[data_type]
return target_plugin

# The plugin may not get all the required options here
# but we don't know if we need the plugin yet
self._set_plugin_config(plugin, run_id, tolerant=True)
if data_type not in self._plugin_class_registry:
raise KeyError(f"No plugin class registered that provides {data_type}")

plugin.deps = {d_depends: get_plugin(d_depends) for d_depends in plugin.depends_on}
plugin = self._plugin_class_registry[data_type]()

last_provide = d_provides
plugin.run_id = run_id

if plugin.child_plugin:
# Plugin is a child of another plugin, hence we have to
# drop the parents config from the lineage
configs = {}
# The plugin may not get all the required options here
# but we don't know if we need the plugin yet
self._set_plugin_config(plugin, run_id, tolerant=True)

# Getting information about the parent:
parent_class = plugin.__class__.__bases__[0]
# Get all parent options which are overwritten by a child:
parent_options = [
option.parent_option_name
for option in plugin.takes_config.values()
if option.child_option
]
plugin.deps = {
d_depends: self.__get_plugin(run_id, d_depends) for d_depends in plugin.depends_on
}

for option_name, v in plugin.config.items():
# Looping over all settings, option_name is either the option name of the
# parent or the child.
if option_name in parent_options:
# In case it is the parent we continue
continue
self.__add_lineage_to_plugin(run_id, plugin)

if plugin.takes_config[option_name].track:
# Add all options which should be tracked:
configs[option_name] = v
if not hasattr(plugin, "data_kind") and not plugin.multi_output:
if len(plugin.depends_on):
# Assume data kind is the same as the first dependency
first_dep = plugin.depends_on[0]
plugin.data_kind = plugin.deps[first_dep].data_kind_for(first_dep)
else:
# No dependencies: assume provided data kind and
# data type are synonymous
plugin.data_kind = plugin.provides[0]

# Also adding name and version of the parent to the lineage:
configs[parent_class.__name__] = parent_class.__version__
plugin.fix_dtype()

plugin.lineage = {
last_provide: (plugin.__class__.__name__, plugin.version(run_id), configs)
}
else:
plugin.lineage = {
last_provide: (
plugin.__class__.__name__,
plugin.version(run_id),
{
option: setting
for option, setting in plugin.config.items()
if plugin.takes_config[option].track
},
)
}
for d_depends in plugin.depends_on:
plugin.lineage.update(plugin.deps[d_depends].lineage)

if not hasattr(plugin, "data_kind") and not plugin.multi_output:
if len(plugin.depends_on):
# Assume data kind is the same as the first dependency
first_dep = plugin.depends_on[0]
plugin.data_kind = plugin.deps[first_dep].data_kind_for(first_dep)
else:
# No dependencies: assume provided data kind and
# data type are synonymous
plugin.data_kind = plugin.provides[0]
# Add plugin to cache
self._plugins_to_cache({data_type: plugin for data_type in plugin.provides})

plugin.fix_dtype()
return plugin

return plugin
def __add_lineage_to_plugin(self, run_id, plugin):
"""Adds lineage to plugin in place.
non_local_plugins = {}
for t in targets:
p = get_plugin(t)
non_local_plugins[t] = p
Also adds parent infromation in case of a child plugin.
"""
last_provide = [d_provides for d_provides in plugin.provides][-1]

if plugin.child_plugin:
# Plugin is a child of another plugin, hence we have to
# drop the parents config from the lineage
configs = {}

# Getting information about the parent:
parent_class = plugin.__class__.__bases__[0]
# Get all parent options which are overwritten by a child:
parent_options = [
option.parent_option_name
for option in plugin.takes_config.values()
if option.child_option
]

for option_name, v in plugin.config.items():
# Looping over all settings, option_name is either the option name of the
# parent or the child.
if option_name in parent_options:
# In case it is the parent we continue
continue

if plugin.takes_config[option_name].track:
# Add all options which should be tracked:
configs[option_name] = v

# Also adding name and version of the parent to the lineage:
configs[parent_class.__name__] = parent_class.__version__

plugin.lineage = {
last_provide: (plugin.__class__.__name__, plugin.version(run_id), configs)
}
else:
plugin.lineage = {
last_provide: (
plugin.__class__.__name__,
plugin.version(run_id),
{
option: setting
for option, setting in plugin.config.items()
if plugin.takes_config[option].track
},
)
}

self._plugins_to_cache(non_local_plugins)
return non_local_plugins
for d_depends in plugin.depends_on:
plugin.lineage.update(plugin.deps[d_depends].lineage)

def _per_run_default_allowed_check(self, option_name, option):
"""Check if an option of a registered plugin is allowed."""
Expand Down

0 comments on commit 72f935c

Please sign in to comment.