Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lineage if per run default is not allowed #483

Merged
merged 11 commits into from
Jul 13, 2021
145 changes: 99 additions & 46 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import pandas as pd
import strax
import hashlib

export, __all__ = strax.exporter()
__all__ += ['RUN_DEFAULTS_KEY']
Expand Down Expand Up @@ -91,7 +92,7 @@ class Context:

runs: ty.Union[pd.DataFrame, type(None)] = None
_run_defaults_cache: dict = None

_fixed_plugin_cache: dict = None
storage: ty.List[strax.StorageFrontend]

def __init__(self,
Expand Down Expand Up @@ -415,12 +416,56 @@ def _set_plugin_config(self, p, run_id, tolerant=True):
parent_name = opt.parent_option_name

mes = (f'Cannot find "{parent_name}" among the options of the parent.'
f' Either you specified by accident {option_name} as child option'
f' Either you specified by accident {option_name} as child option'
f' or you specified the wrong parent_option_name. Have you specified '
'the correct parent option name?')
assert parent_name in p.config, mes
p.config[parent_name] = option_value

def _config_hash(self):
"""
Dump the current config to a hash as a sanity check for building
the _fixed_plugin_cache. If any item changes in the config, so
does this hash.
"""
return hashlib.sha1(str(self.config).encode('ascii')).hexdigest()
# return strax.deterministic_hash(self.config)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried using the deterministic hash but it does not like immutabledict. Perhaps there is another solution, but my guess is that this serves the purpose quite well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately this is not deterministic, you need to sort any dictionaries and sets to get a deterministic string. Also the string representation may depend on the version of the package of non-bultin values. Why not just fix the hashablize function here

from collections.abc import Mapping

254        if isinstance(obj, Mapping): # instead of if isinstance(obj, dict)
...

This will catch almost any dict-like object, not just the builtin dict.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Yossi, sure, let me try that.

Actually, it's not much of an issue if it's non-deterministic: the contexts will always just build the _config_hash on the fly, so any change in the order of the config or in the version of a package may lead to a different context_hash, but as long as it does not change for every run, it shouldn't matter. But you are right, it is better to make it deterministic — even though there is no issue now, some day there might be if someone assumes it is deterministic (which shouldn't be a bad assumption).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jmosbacher , I tried your suggestion but it doesn't work. I initially also tried something like this (by doing the same for immutabledict instead of Mapping); however, it turns out that since immutabledict has a hash method, this whole logic is never accessed.

I'll propose something similar to get it working.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry about that. I assumed the issue was with the hash function but I guess it was in the json encoder all along? In that case maybe add the Mapping logic to the NumpyJSONEncoder class?


def _plugins_are_cached(self, targets: ty.Tuple[str]) -> bool:
"""Check if all the requested targets are in the _fixed_plugin_cache"""
if self.context_config['use_per_run_defaults'] or self._fixed_plugin_cache is None:
# There is no point in caching if plugins (lineage) can
# change per run or the cache is empty.
return False

config_hash = self._config_hash()
if config_hash in self._fixed_plugin_cache:
targets_in_cache = [t in self._fixed_plugin_cache[config_hash] for t in targets]
if all(targets_in_cache):
return True

return False

def _plugins_to_cache(self, plugins: dict) -> None:
if self.context_config['use_per_run_defaults']:
# There is no point in caching if plugins (lineage) can change per run
return
if self._fixed_plugin_cache is None or self._config_hash() not in self._fixed_plugin_cache:
self._fixed_plugin_cache = {self._config_hash(): dict()}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume Yossi's comment is also the reason why you are always making a new dictionary when the hash cannot be found?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very keenly spotted Daniel, I should have put a comment here.

So my suspicion is that the cache is only once created, at most twice if you change some option or register some plugin. I did think the likelihood that one was eating away memory for keeping the cache was greater than the change that someone was flipping between options often (although neither is very likely).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So my suspicion is that the cache is only once created, at most twice if you change some option or register some plugin.

Yes, this would be my guess, too. As I said I am fine like it is but if you can implement Yossi's suggestion easily it would be a nice to have,

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fully agree

for target, plugin in plugins.items():
self._fixed_plugin_cache[self._config_hash()][target] = plugin

def __get_plugins_from_cache(self,
                             run_id: str) -> ty.Dict[str, strax.Plugin]:
    # Doubly underscored since we don't do any key-checks etc here
    """Load requested plugins from the plugin_cache"""
    cached = self._fixed_plugin_cache[self._config_hash()]
    # Lineage of a cached plugin is fixed; only the run_id has to be
    # refreshed before handing the plugins back out.
    for plugin in cached.values():
        plugin.run_id = run_id
    return dict(cached)

def _get_plugins(self,
targets: ty.Tuple[str],
run_id: str) -> ty.Dict[str, strax.Plugin]:
Expand All @@ -429,6 +474,9 @@ def _get_plugins(self,
For a plugin that produces multiple outputs, we make only a single
instance, which is referenced under multiple keys in the output dict.
"""
if self._plugins_are_cached(targets):
return self.__get_plugins_from_cache(run_id)

# Check all config options are taken by some registered plugin class
# (helps spot typos)
all_opts = set().union(*[
Expand All @@ -441,84 +489,88 @@ def _get_plugins(self,
# Initialize plugins for the entire computation graph
# (most likely far further down than we need)
# to get lineages and dependency info.
def get_plugin(data_kind):
nonlocal plugins
def get_plugin(data_type):
nonlocal non_local_plugins

if data_kind not in self._plugin_class_registry:
raise KeyError(f"No plugin class registered that provides {data_kind}")
if data_type not in self._plugin_class_registry:
raise KeyError(f"No plugin class registered that provides {data_type}")

p = self._plugin_class_registry[data_kind]()
plugin = self._plugin_class_registry[data_type]()

d_provides = None # just to make codefactor happy
for d_provides in p.provides:
plugins[d_provides] = p
for d_provides in plugin.provides:
non_local_plugins[d_provides] = plugin

p.run_id = run_id
plugin.run_id = run_id

# The plugin may not get all the required options here
# but we don't know if we need the plugin yet
self._set_plugin_config(p, run_id, tolerant=True)
self._set_plugin_config(plugin, run_id, tolerant=True)

p.deps = {d_depends: get_plugin(d_depends) for d_depends in p.depends_on}
plugin.deps = {d_depends: get_plugin(d_depends) for d_depends in plugin.depends_on}

last_provide = d_provides

if p.child_plugin:
if plugin.child_plugin:
# Plugin is a child of another plugin, hence we have to
# drop the parents config from the lineage
configs = {}

# Getting information about the parent:
parent_class = p.__class__.__bases__[0]
parent_class = plugin.__class__.__bases__[0]
# Get all parent options which are overwritten by a child:
parent_options = [option.parent_option_name for option in p.takes_config.values()
parent_options = [option.parent_option_name for option in plugin.takes_config.values()
if option.child_option]

for option_name, v in p.config.items():
for option_name, v in plugin.config.items():
# Looping over all settings, option_name is either the option name of the
# parent or the child.
if option_name in parent_options:
# In case it is the parent we continue
continue

if p.takes_config[option_name].track:
if plugin.takes_config[option_name].track:
# Add all options which should be tracked:
configs[option_name] = v

# Also adding name and version of the parent to the lineage:
configs[parent_class.__name__] = parent_class.__version__

p.lineage = {last_provide: (p.__class__.__name__,
p.version(run_id),
configs)}

plugin.lineage = {last_provide: (
plugin.__class__.__name__,
plugin.version(run_id),
configs)}
else:
p.lineage = {last_provide: (p.__class__.__name__,
p.version(run_id),
{q: v for q, v in p.config.items()
if p.takes_config[q].track})}
for d_depends in p.depends_on:
p.lineage.update(p.deps[d_depends].lineage)

if not hasattr(p, 'data_kind') and not p.multi_output:
if len(p.depends_on):
plugin.lineage = {last_provide: (
plugin.__class__.__name__,
plugin.version(run_id),
{option: setting for option, setting
in plugin.config.items()
if plugin.takes_config[option].track})}
for d_depends in plugin.depends_on:
plugin.lineage.update(plugin.deps[d_depends].lineage)

if not hasattr(plugin, 'data_kind') and not plugin.multi_output:
if len(plugin.depends_on):
# Assume data kind is the same as the first dependency
first_dep = p.depends_on[0]
p.data_kind = p.deps[first_dep].data_kind_for(first_dep)
first_dep = plugin.depends_on[0]
plugin.data_kind = plugin.deps[first_dep].data_kind_for(first_dep)
else:
# No dependencies: assume provided data kind and
# data type are synonymous
p.data_kind = p.provides[0]
plugin.data_kind = plugin.provides[0]

p.fix_dtype()
plugin.fix_dtype()

return p
return plugin

plugins = {}
for t in targets:
non_local_plugins = {}
for t in targets:
p = get_plugin(t)
plugins[t] = p
non_local_plugins[t] = p

return plugins
self._plugins_to_cache(non_local_plugins)
return non_local_plugins

def _per_run_default_allowed_check(self, option_name, option):
"""Check if an option of a registered plugin is allowed"""
Expand Down Expand Up @@ -659,7 +711,8 @@ def check_cache(d):
def concat_loader(*args, **kwargs):
for x in ldrs:
yield from x(*args, **kwargs)
ldr = lambda *args, **kwargs : concat_loader(*args, **kwargs)

ldr = lambda *args, **kwargs: concat_loader(*args, **kwargs)

if ldr:
# Found it! No need to make it or look in other frontends
Expand Down Expand Up @@ -856,7 +909,7 @@ def to_absolute_time_range(self, run_id, targets=None, time_range=None,
:param time_within: row of strax data (e.g. eent)
:param full_range: If True returns full time_range of the run.
"""

selection = ((time_range is None) +
(seconds_range is None) +
(time_within is None) +
Expand All @@ -875,7 +928,7 @@ def to_absolute_time_range(self, run_id, targets=None, time_range=None,
# Force time range to be integers, since float math on large numbers
# in not precise
time_range = tuple([int(x) for x in time_range])

if full_range:
time_range = self.estimate_run_start_and_end(run_id, targets)
return time_range
Expand Down Expand Up @@ -1195,6 +1248,7 @@ def accumulate(self,
if function is None:
def function(arr):
return arr

function_takes_fields = False

for chunk in self.get_iter(run_id, targets,
Expand Down Expand Up @@ -1369,7 +1423,7 @@ def _check_forbidden(self):
Otherwise, try to make it a tuple"""
self.context_config['forbid_creation_of'] = strax.to_str_tuple(
self.context_config['forbid_creation_of'])

def _apply_function(self,
chunk_data: np.ndarray,
run_id: ty.Union[str, tuple, list],
Expand Down Expand Up @@ -1516,8 +1570,8 @@ def provided_dtypes(self, runid='0'):
:return: dictionary of provided dtypes with their corresponding lineage hash, save_when, version
"""
hashes = set([(d, self.key_for(runid, d).lineage_hash, p.save_when, p.__version__)
for p in self._plugin_class_registry.values()
for d in p.provides])
for p in self._plugin_class_registry.values()
for d in p.provides])

return {dtype: dict(hash=h, save_when=save_when.name, version=version)
for dtype, h, save_when, version in hashes}
Expand Down Expand Up @@ -1560,7 +1614,6 @@ def add_method(cls, f):
data is not returned.
""" + select_docs


for attr in dir(Context):
attr_val = getattr(Context, attr)
if hasattr(attr_val, '__doc__'):
Expand Down