Skip to content

Commit

Permalink
Deregister partially replaced multi-output plugins (#775)
Browse files Browse the repository at this point in the history
* Deregister partially replaced multi-output plugins
  • Loading branch information
JelleAalbers committed Nov 20, 2023
1 parent a9d4ebb commit 672dacb
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
15 changes: 15 additions & 0 deletions strax/context.py
Expand Up @@ -328,9 +328,24 @@ def register(self, plugin_class):
if isinstance(plugin_class.provides, str):
plugin_class.provides = tuple([plugin_class.provides])

# Register the plugin for all datatypes it provides,
# tracking which plugins we booted out.
deregistered = []
for p in plugin_class.provides:
old_plugin_class = self._plugin_class_registry.get(p, None)
if old_plugin_class and old_plugin_class != plugin_class:
deregistered.append(old_plugin_class)
self._plugin_class_registry[p] = plugin_class

# If we booted a plugin from a datatype, we must boot it from other
# datatypes it makes too, to preserve a one-to-one mapping between
# datatypes and registered plugins.
for old_plugin in deregistered:
for d in old_plugin.provides:
currently_registered = self._plugin_class_registry.get(d)
if old_plugin == currently_registered:
del self._plugin_class_registry[d]

already_seen = []
for plugin in self._plugin_class_registry.values():
if plugin in already_seen:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_context.py
Expand Up @@ -9,6 +9,7 @@
import unittest
import shutil
import uuid
import pytest


def _apply_function_to_data(function) -> ty.Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -224,6 +225,28 @@ def test_get_plugins_with_cache(self):
not_cached_plugins.keys() == cached_plugins.keys()
), f"_get_plugins returns different plugins if cached!"

def test_multioutput_deregistration(self):
"""Test that a multi-output plugin is deregistered once one of its outputs is provided by
another plugin."""

class RecordsPlus(strax.Plugin):
depends_on = tuple()
data_kind = dict(records="records", plus="plus")
dtype = dict(records=strax.record_dtype(), plus=strax.record_dtype())
provides = ("records", "plus")

st = self.get_context(False)
st.register(RecordsPlus)
st.register(Records)

# Records will make the records, not RecordsPlus
plugins = st._get_plugins(("records",), run_id)
assert isinstance(plugins["records"], Records)

# Plus is no longer available.
with pytest.raises(KeyError):
plugins = st.key_for("plus", run_id)

def test_register_no_defaults(self, runs_default_allowed=False):
"""Test if we only register a plugin with no run-defaults."""
st = self.get_context(runs_default_allowed)
Expand Down

0 comments on commit 672dacb

Please sign in to comment.