Fix lineage if per run default is not allowed #483
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
import numpy as np | ||
import pandas as pd | ||
import strax | ||
import hashlib | ||
|
||
export, __all__ = strax.exporter() | ||
__all__ += ['RUN_DEFAULTS_KEY'] | ||
|
@@ -91,7 +92,7 @@ class Context: | |
|
||
runs: ty.Union[pd.DataFrame, type(None)] = None | ||
_run_defaults_cache: dict = None | ||
|
||
_fixed_plugin_cache: dict = None | ||
storage: ty.List[strax.StorageFrontend] | ||
|
||
def __init__(self, | ||
|
@@ -415,12 +416,56 @@ def _set_plugin_config(self, p, run_id, tolerant=True): | |
parent_name = opt.parent_option_name | ||
|
||
mes = (f'Cannot find "{parent_name}" among the options of the parent.' | ||
f' Either you specified by accident {option_name} as child option' | ||
f' Either you specified by accident {option_name} as child option' | ||
f' or you specified the wrong parent_option_name. Have you specified ' | ||
'the correct parent option name?') | ||
assert parent_name in p.config, mes | ||
p.config[parent_name] = option_value | ||
|
||
def _config_hash(self): | ||
""" | ||
Dump the current config to a hash as a sanity check for building | ||
the _fixed_plugin_cache. If any item changes in the config, so | ||
does this hash. | ||
""" | ||
return hashlib.sha1(str(self.config).encode('ascii')).hexdigest() | ||
# return strax.deterministic_hash(self.config) | ||
|
||
def _plugins_are_cached(self, targets: ty.Tuple[str]) -> bool: | ||
"""Check if all the requested targets are in the _fixed_plugin_cache""" | ||
if self.context_config['use_per_run_defaults'] or self._fixed_plugin_cache is None: | ||
# There is no point in caching if plugins (lineage) can | ||
# change per run or the cache is empty. | ||
return False | ||
|
||
config_hash = self._config_hash() | ||
if config_hash in self._fixed_plugin_cache: | ||
targets_in_cache = [t in self._fixed_plugin_cache[config_hash] for t in targets] | ||
if all(targets_in_cache): | ||
return True | ||
|
||
return False | ||
|
||
def _plugins_to_cache(self, plugins: dict) -> None: | ||
if self.context_config['use_per_run_defaults']: | ||
# There is no point in caching if plugins (lineage) can change per run | ||
return | ||
if self._fixed_plugin_cache is None or self._config_hash() not in self._fixed_plugin_cache: | ||
self._fixed_plugin_cache = {self._config_hash(): dict()} | ||
I assume Yossi's comment is also the reason why you are always making a new dictionary when the hash cannot be found?
Very keenly spotted, Daniel; I should have put a comment here. My suspicion is that the cache is created only once, or at most twice if you change some option or register some plugin. I judged that eating away memory by keeping old caches was more likely than someone flipping between options often (although neither is very likely).
Yes, this would be my guess, too. As I said, I am fine with it as it is, but if you can implement Yossi's suggestion easily it would be a nice-to-have.
I fully agree.
||
for target, plugin in plugins.items(): | ||
self._fixed_plugin_cache[self._config_hash()][target] = plugin | ||
|
||
def __get_plugins_from_cache(self, | ||
run_id: str) -> ty.Dict[str, strax.Plugin]: | ||
# Doubly underscored since we don't do any key-checks etc here | ||
"""Load requested plugins from the plugin_cache""" | ||
requested_plugins = {} | ||
for target, plugin in self._fixed_plugin_cache[self._config_hash()].items(): | ||
# Lineage is fixed, just replace the run_id | ||
plugin.run_id = run_id | ||
requested_plugins[target] = plugin | ||
return requested_plugins | ||
|
||
def _get_plugins(self, | ||
targets: ty.Tuple[str], | ||
run_id: str) -> ty.Dict[str, strax.Plugin]: | ||
|
@@ -429,6 +474,9 @@ def _get_plugins(self, | |
For a plugin that produces multiple outputs, we make only a single | ||
instance, which is referenced under multiple keys in the output dict. | ||
""" | ||
if self._plugins_are_cached(targets): | ||
return self.__get_plugins_from_cache(run_id) | ||
|
||
# Check all config options are taken by some registered plugin class | ||
# (helps spot typos) | ||
all_opts = set().union(*[ | ||
|
@@ -441,84 +489,88 @@ def _get_plugins(self, | |
# Initialize plugins for the entire computation graph | ||
# (most likely far further down than we need) | ||
# to get lineages and dependency info. | ||
def get_plugin(data_kind): | ||
nonlocal plugins | ||
def get_plugin(data_type): | ||
nonlocal non_local_plugins | ||
|
||
if data_kind not in self._plugin_class_registry: | ||
raise KeyError(f"No plugin class registered that provides {data_kind}") | ||
if data_type not in self._plugin_class_registry: | ||
raise KeyError(f"No plugin class registered that provides {data_type}") | ||
|
||
p = self._plugin_class_registry[data_kind]() | ||
plugin = self._plugin_class_registry[data_type]() | ||
|
||
d_provides = None # just to make codefactor happy | ||
for d_provides in p.provides: | ||
plugins[d_provides] = p | ||
for d_provides in plugin.provides: | ||
non_local_plugins[d_provides] = plugin | ||
|
||
p.run_id = run_id | ||
plugin.run_id = run_id | ||
|
||
# The plugin may not get all the required options here | ||
# but we don't know if we need the plugin yet | ||
self._set_plugin_config(p, run_id, tolerant=True) | ||
self._set_plugin_config(plugin, run_id, tolerant=True) | ||
|
||
p.deps = {d_depends: get_plugin(d_depends) for d_depends in p.depends_on} | ||
plugin.deps = {d_depends: get_plugin(d_depends) for d_depends in plugin.depends_on} | ||
|
||
last_provide = d_provides | ||
|
||
if p.child_plugin: | ||
if plugin.child_plugin: | ||
# Plugin is a child of another plugin, hence we have to | ||
# drop the parents config from the lineage | ||
configs = {} | ||
|
||
# Getting information about the parent: | ||
parent_class = p.__class__.__bases__[0] | ||
parent_class = plugin.__class__.__bases__[0] | ||
# Get all parent options which are overwritten by a child: | ||
parent_options = [option.parent_option_name for option in p.takes_config.values() | ||
parent_options = [option.parent_option_name for option in plugin.takes_config.values() | ||
if option.child_option] | ||
|
||
for option_name, v in p.config.items(): | ||
for option_name, v in plugin.config.items(): | ||
# Looping over all settings, option_name is either the option name of the | ||
# parent or the child. | ||
if option_name in parent_options: | ||
# In case it is the parent we continue | ||
continue | ||
|
||
if p.takes_config[option_name].track: | ||
if plugin.takes_config[option_name].track: | ||
# Add all options which should be tracked: | ||
configs[option_name] = v | ||
|
||
# Also adding name and version of the parent to the lineage: | ||
configs[parent_class.__name__] = parent_class.__version__ | ||
|
||
p.lineage = {last_provide: (p.__class__.__name__, | ||
p.version(run_id), | ||
configs)} | ||
|
||
plugin.lineage = {last_provide: ( | ||
plugin.__class__.__name__, | ||
plugin.version(run_id), | ||
configs)} | ||
else: | ||
p.lineage = {last_provide: (p.__class__.__name__, | ||
p.version(run_id), | ||
{q: v for q, v in p.config.items() | ||
if p.takes_config[q].track})} | ||
for d_depends in p.depends_on: | ||
p.lineage.update(p.deps[d_depends].lineage) | ||
|
||
if not hasattr(p, 'data_kind') and not p.multi_output: | ||
if len(p.depends_on): | ||
plugin.lineage = {last_provide: ( | ||
plugin.__class__.__name__, | ||
plugin.version(run_id), | ||
{option: setting for option, setting | ||
in plugin.config.items() | ||
if plugin.takes_config[option].track})} | ||
for d_depends in plugin.depends_on: | ||
plugin.lineage.update(plugin.deps[d_depends].lineage) | ||
|
||
if not hasattr(plugin, 'data_kind') and not plugin.multi_output: | ||
if len(plugin.depends_on): | ||
# Assume data kind is the same as the first dependency | ||
first_dep = p.depends_on[0] | ||
p.data_kind = p.deps[first_dep].data_kind_for(first_dep) | ||
first_dep = plugin.depends_on[0] | ||
plugin.data_kind = plugin.deps[first_dep].data_kind_for(first_dep) | ||
else: | ||
# No dependencies: assume provided data kind and | ||
# data type are synonymous | ||
p.data_kind = p.provides[0] | ||
plugin.data_kind = plugin.provides[0] | ||
|
||
p.fix_dtype() | ||
plugin.fix_dtype() | ||
|
||
return p | ||
return plugin | ||
|
||
plugins = {} | ||
for t in targets: | ||
non_local_plugins = {} | ||
for t in targets: | ||
p = get_plugin(t) | ||
plugins[t] = p | ||
non_local_plugins[t] = p | ||
|
||
return plugins | ||
self._plugins_to_cache(non_local_plugins) | ||
return non_local_plugins | ||
|
||
def _per_run_default_allowed_check(self, option_name, option): | ||
"""Check if an option of a registered plugin is allowed""" | ||
|
@@ -659,7 +711,8 @@ def check_cache(d): | |
def concat_loader(*args, **kwargs): | ||
for x in ldrs: | ||
yield from x(*args, **kwargs) | ||
ldr = lambda *args, **kwargs : concat_loader(*args, **kwargs) | ||
|
||
ldr = lambda *args, **kwargs: concat_loader(*args, **kwargs) | ||
|
||
if ldr: | ||
# Found it! No need to make it or look in other frontends | ||
|
@@ -856,7 +909,7 @@ def to_absolute_time_range(self, run_id, targets=None, time_range=None, | |
:param time_within: row of strax data (e.g. event) | ||
:param full_range: If True returns full time_range of the run. | ||
""" | ||
|
||
selection = ((time_range is None) + | ||
(seconds_range is None) + | ||
(time_within is None) + | ||
|
@@ -875,7 +928,7 @@ def to_absolute_time_range(self, run_id, targets=None, time_range=None, | |
# Force time range to be integers, since float math on large numbers | ||
# in not precise | ||
time_range = tuple([int(x) for x in time_range]) | ||
|
||
if full_range: | ||
time_range = self.estimate_run_start_and_end(run_id, targets) | ||
return time_range | ||
|
@@ -1195,6 +1248,7 @@ def accumulate(self, | |
if function is None: | ||
def function(arr): | ||
return arr | ||
|
||
function_takes_fields = False | ||
|
||
for chunk in self.get_iter(run_id, targets, | ||
|
@@ -1369,7 +1423,7 @@ def _check_forbidden(self): | |
Otherwise, try to make it a tuple""" | ||
self.context_config['forbid_creation_of'] = strax.to_str_tuple( | ||
self.context_config['forbid_creation_of']) | ||
|
||
def _apply_function(self, | ||
chunk_data: np.ndarray, | ||
run_id: ty.Union[str, tuple, list], | ||
|
@@ -1516,8 +1570,8 @@ def provided_dtypes(self, runid='0'): | |
:return: dictionary of provided dtypes with their corresponding lineage hash, save_when, version | ||
""" | ||
hashes = set([(d, self.key_for(runid, d).lineage_hash, p.save_when, p.__version__) | ||
for p in self._plugin_class_registry.values() | ||
for d in p.provides]) | ||
for p in self._plugin_class_registry.values() | ||
for d in p.provides]) | ||
|
||
return {dtype: dict(hash=h, save_when=save_when.name, version=version) | ||
for dtype, h, save_when, version in hashes} | ||
|
@@ -1560,7 +1614,6 @@ def add_method(cls, f): | |
data is not returned. | ||
""" + select_docs | ||
|
||
|
||
for attr in dir(Context): | ||
attr_val = getattr(Context, attr) | ||
if hasattr(attr_val, '__doc__'): | ||
|
I tried using the deterministic hash, but it does not like immutabledict. Perhaps there is another solution, but my guess is that this serves the purpose quite well.

Unfortunately this is not deterministic: you need to sort any dictionaries and sets to get a deterministic string. Also, the string representation of non-builtin values may depend on the version of the package that provides them. Why not just fix the hashablize function here?
This will catch almost any dict-like object, not just the builtin dict.
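
To make the point about sorting concrete, here is a minimal sketch of an order-independent config hash. It is not strax's `deterministic_hash` or `hashablize`; the names `canonicalize` and `config_hash_sketch` are hypothetical, and falling back to `repr` for non-JSON values is an assumption that keeps the version-dependence caveat mentioned above.

```python
import hashlib
import json
from collections.abc import Mapping, Set


def canonicalize(obj):
    """Recursively convert obj into plain data with a fixed ordering."""
    if isinstance(obj, Mapping):
        # Any dict-like object becomes a plain dict; key order is
        # normalised later by json.dumps(sort_keys=True).
        return {str(k): canonicalize(v) for k, v in obj.items()}
    if isinstance(obj, Set):
        # Sets have no guaranteed iteration order, so sort explicitly.
        return sorted((canonicalize(v) for v in obj), key=repr)
    if isinstance(obj, (list, tuple)):
        return [canonicalize(v) for v in obj]
    return obj


def config_hash_sketch(config):
    """Hypothetical helper: SHA-1 of a canonical JSON dump of config."""
    # default=repr is an assumption for values JSON cannot encode; such
    # reprs may still differ between package versions.
    payload = json.dumps(canonicalize(config), sort_keys=True, default=repr)
    return hashlib.sha1(payload.encode('utf-8')).hexdigest()


def naive_hash(config):
    """The str()-based approach from the diff, for comparison."""
    return hashlib.sha1(str(config).encode('ascii')).hexdigest()


a = {'gain': 1, 'channels': {3, 1, 2}}
b = {'channels': {2, 3, 1}, 'gain': 1}  # same content, different order
same_with_sketch = config_hash_sketch(a) == config_hash_sketch(b)
same_with_naive = naive_hash(a) == naive_hash(b)
```

With these two configs the canonical hash agrees while the `str()`-based one does not, because `str(dict)` follows insertion order.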

Thanks Yossi, sure, let me try that.
Actually, it's not much of an issue if it's non-deterministic: the contexts will always just build the `_config_hash` on the fly, so any changes in the order of the config or the version of a package may lead to a different context_hash, but as long as it does not change for every run, it shouldn't matter. But you are right, better to make it deterministic: even if there is no issue now, some day there might be if one assumes it is deterministic (which shouldn't be a bad assumption).
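
A minimal sketch of why a changing hash only costs a rebuild: the cache is a single slot keyed by a hash computed on the fly, and it is replaced whenever that hash is not found. `PluginCacheSketch` is a hypothetical stand-in, not the strax `Context`; the string "plugins" and the `builds` counter exist only for illustration.

```python
import hashlib


class PluginCacheSketch:
    """Hypothetical stand-in for the caching part of a context."""

    def __init__(self, config):
        self.config = dict(config)
        self._cache = None  # layout: {config_hash: {target: plugin}}
        self.builds = 0     # counts how often a plugin had to be built

    def _config_hash(self):
        # Computed on the fly from the current config (sorted for stability)
        items = sorted(self.config.items())
        return hashlib.sha1(repr(items).encode('utf-8')).hexdigest()

    def get_plugin(self, target):
        key = self._config_hash()
        if self._cache is None or key not in self._cache:
            # First call or changed config: start a fresh dict so that
            # plugins built for older configs do not pile up in memory.
            self._cache = {key: {}}
        if target not in self._cache[key]:
            self.builds += 1
            self._cache[key][target] = f'{target}-plugin-{self.builds}'
        return self._cache[key][target]


ctx = PluginCacheSketch({'threshold': 5})
ctx.get_plugin('peaks')
ctx.get_plugin('peaks')          # served from the cache
builds_before_change = ctx.builds
ctx.config['threshold'] = 6      # new hash, so the old slot is dropped
ctx.get_plugin('peaks')
builds_after_change = ctx.builds
```

The trade-off is the one discussed in the inline thread: flipping back to an earlier config forces a rebuild, in exchange for never holding more than one set of cached plugins.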

@jmosbacher, I tried your suggestion but it doesn't work. I initially also tried something like this (by doing the same for `immutabledict` instead of `Mapping`); however, it turns out that since immutabledict has a hash method, this whole logic is never accessed. I'll propose something similar to get it working.
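
The effect described here can be reproduced with a small sketch. It assumes a hashablize-style helper that only converts an object when `hash(obj)` raises; `hashablize_sketch` and `FrozenMap` are hypothetical stand-ins, not the strax function or the immutabledict package.

```python
from collections.abc import Mapping


class FrozenMap(Mapping):
    """Minimal hashable, read-only mapping (stand-in for immutabledict)."""

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __hash__(self):
        return hash(frozenset(self._data.items()))


def hashablize_sketch(obj):
    """Convert obj into something hashable, but only if hash(obj) fails."""
    try:
        hash(obj)
    except TypeError:
        if isinstance(obj, Mapping):
            # Only reached for unhashable mappings such as the builtin dict
            return tuple((k, hashablize_sketch(v))
                         for k, v in sorted(obj.items()))
        if isinstance(obj, (list, tuple)):
            return tuple(hashablize_sketch(v) for v in obj)
        raise
    return obj


plain = {'b': 1, 'a': 2}
frozen = FrozenMap(plain)
converted_plain = hashablize_sketch(plain)
converted_frozen = hashablize_sketch(frozen)
```

The builtin dict is turned into a sorted tuple of pairs, whereas the hashable mapping is returned untouched because `hash()` succeeds and the `Mapping` branch is never entered.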

Oh, sorry about that. I assumed the issue was with the hash function, but I guess it was in the JSON encoder all along? In that case, maybe add the Mapping logic to the NumpyJSONEncoder class?
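
A rough sketch of what handling `Mapping` in a JSON encoder could look like. It uses only the standard library; `MappingJSONEncoder` and `FrozenMap` are hypothetical names, and strax's actual `NumpyJSONEncoder` is not reproduced here.

```python
import json
from collections.abc import Mapping


class FrozenMap(Mapping):
    """Minimal read-only mapping that is not a dict subclass."""

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


class MappingJSONEncoder(json.JSONEncoder):
    """Hypothetical encoder that serialises any Mapping as a JSON object."""

    def default(self, obj):
        # default() is only called for objects json cannot encode natively,
        # which includes dict-like objects that do not subclass dict.
        if isinstance(obj, Mapping):
            return dict(obj)
        return super().default(obj)


config = {'outer': FrozenMap(b=1, a=2)}
encoded = json.dumps(config, cls=MappingJSONEncoder, sort_keys=True)
```

Because `sort_keys=True` also applies to the dict returned from `default`, the encoded string does not depend on the insertion order of the mapping.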