Skip to content

Commit

Permalink
Patch plugin cache (#748)
Browse files Browse the repository at this point in the history
*Add fix and test
  • Loading branch information
WenzDaniel committed Aug 11, 2023
1 parent c1d0554 commit 32df173
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
14 changes: 13 additions & 1 deletion strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,19 @@ def _get_plugins(self,
instance, which is referenced under multiple keys in the output dict.
"""
if self._plugins_are_cached(targets):
return self.__get_plugins_from_cache(run_id)
cached_plugins = self.__get_plugins_from_cache(run_id)
plugins = {}
targets = list(targets)
while targets:
target = targets.pop(0)
if target in plugins:
continue

target_plugin = cached_plugins[target]
for provides in target_plugin.provides:
plugins[provides] = target_plugin
targets += list(target_plugin.depends_on)
return plugins

# Check all config options are taken by some registered plugin class
# (helps spot typos)
Expand Down
24 changes: 24 additions & 0 deletions strax/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,30 @@ def compute(self, records):
return p


@strax.takes_config(
strax.Option('base_area', type=int, default=0),
strax.Option('give_wrong_dtype', type=bool, default=False),
strax.Option('bonus_area', type=int, default=0))
class PeaksWoPerRunDefault(strax.Plugin):
"""Same as peak plugin but without per run default option
to allow for plugin caching.
"""
provides = 'peaks'
data_kind = 'peaks'
depends_on = ('records',)
dtype = strax.peak_dtype()
parallel = True

def compute(self, records):
if self.config['give_wrong_dtype']:
return np.zeros(5, [('a', np.int64), ('b', np.float64)])
p = np.zeros(len(records), self.dtype)
p['time'] = records['time']
p['length'] = p['dt'] = 1
p['area'] = self.config['base_area'] + self.config['bonus_area']
return p


# Another peak-kind plugin, to test time_range selection
# with unaligned chunks
class PeakClassification(strax.Plugin):
Expand Down
13 changes: 12 additions & 1 deletion tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import strax
from strax.testutils import Records, Peaks, PeakClassification, run_id
from strax.testutils import Records, Peaks, PeaksWoPerRunDefault, PeakClassification, run_id
import tempfile
import numpy as np
from hypothesis import given, settings
Expand Down Expand Up @@ -215,6 +215,17 @@ def tearDown(self):
if os.path.exists(self.tempdir):
shutil.rmtree(self.tempdir)

def test_get_plugins_with_cache(self):
st = self.get_context(False)
st.register(Records)
st.register(PeaksWoPerRunDefault)
st.register(PeakClassification)

not_cached_plugins = st._get_plugins(('peaks',), run_id)
st._get_plugins(('peak_classification',), run_id)
cached_plugins = st._get_plugins(('peaks',), run_id)
assert not_cached_plugins.keys() == cached_plugins.keys(), f'_get_plugins returns different plugins if cached!'

def test_register_no_defaults(self, runs_default_allowed=False):
"""Test if we only register a plugin with no run-defaults"""
st = self.get_context(runs_default_allowed)
Expand Down

0 comments on commit 32df173

Please sign in to comment.