Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch plugin cache #748

Merged
merged 4 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 13 additions & 1 deletion strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,19 @@ def _get_plugins(self,
instance, which is referenced under multiple keys in the output dict.
"""
if self._plugins_are_cached(targets):
return self.__get_plugins_from_cache(run_id)
cached_plugins = self.__get_plugins_from_cache(run_id)
plugins = {}
targets = list(targets)
for target in targets:
if target in plugins:
continue

target_plugin = cached_plugins[target]
for provides in target_plugin.provides:
plugins[provides] = target_plugin

targets += list(target_plugin.depends_on)
return plugins

# Check all config options are taken by some registered plugin class
# (helps spot typos)
Expand Down
24 changes: 24 additions & 0 deletions strax/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,30 @@ def compute(self, records):
return p


@strax.takes_config(
strax.Option('base_area', type=int, default=0),
strax.Option('give_wrong_dtype', type=bool, default=False),
strax.Option('bonus_area', type=int, default=0))
class PeaksWoPerRunDefault(strax.Plugin):
"""Same as peak plugin but without per run default option
to allow for plugin caching.
"""
provides = 'peaks'
data_kind = 'peaks'
depends_on = ('records',)
dtype = strax.peak_dtype()
parallel = True

def compute(self, records):
if self.config['give_wrong_dtype']:
return np.zeros(5, [('a', np.int64), ('b', np.float64)])
p = np.zeros(len(records), self.dtype)
p['time'] = records['time']
p['length'] = p['dt'] = 1
p['area'] = self.config['base_area'] + self.config['bonus_area']
return p


# Another peak-kind plugin, to test time_range selection
# with unaligned chunks
class PeakClassification(strax.Plugin):
Expand Down
13 changes: 12 additions & 1 deletion tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import strax
from strax.testutils import Records, Peaks, PeakClassification, run_id
from strax.testutils import Records, Peaks, PeaksWoPerRunDefault, PeakClassification, run_id
import tempfile
import numpy as np
from hypothesis import given, settings
Expand Down Expand Up @@ -215,6 +215,17 @@ def tearDown(self):
if os.path.exists(self.tempdir):
shutil.rmtree(self.tempdir)

def test_get_plugins_with_cache(self):
st = self.get_context(False)
st.register(Records)
st.register(PeaksWoPerRunDefault)
st.register(PeakClassification)

not_cached_plugins = st._get_plugins(('peaks',), run_id)
st._get_plugins(('peak_classification',), run_id)
cached_plugins = st._get_plugins(('peaks',), run_id)
assert not_cached_plugins.keys() == cached_plugins.keys(), f'_get_plugins returns different plugins if cached!'

def test_register_no_defaults(self, runs_default_allowed=False):
"""Test if we only register a plugin with no run-defaults"""
st = self.get_context(runs_default_allowed)
Expand Down