Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deregister partially replaced multi-output plugins #775

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,24 @@ def register(self, plugin_class):
if isinstance(plugin_class.provides, str):
plugin_class.provides = tuple([plugin_class.provides])

# Register the plugin for all datatypes it provides,
# tracking which plugins we booted out.
deregistered = []
for p in plugin_class.provides:
old_plugin_class = self._plugin_class_registry.get(p, None)
if old_plugin_class and old_plugin_class != plugin_class:
deregistered.append(old_plugin_class)
self._plugin_class_registry[p] = plugin_class

# If we booted a plugin from a datatype, we must boot it from other
# datatypes it makes too, to preserve a one-to-one mapping between
# datatypes and registered plugins.
for old_plugin in deregistered:
for d in old_plugin.provides:
currently_registered = self._plugin_class_registry.get(d)
if old_plugin == currently_registered:
del self._plugin_class_registry[d]

already_seen = []
for plugin in self._plugin_class_registry.values():
if plugin in already_seen:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import unittest
import shutil
import uuid
import pytest


def _apply_function_to_data(function) -> ty.Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -224,6 +225,28 @@ def test_get_plugins_with_cache(self):
not_cached_plugins.keys() == cached_plugins.keys()
), f"_get_plugins returns different plugins if cached!"

def test_multioutput_deregistration(self):
"""Test that a multi-output plugin is deregistered once one of its outputs is provided by
another plugin."""

class RecordsPlus(strax.Plugin):
depends_on = tuple()
data_kind = dict(records="records", plus="plus")
dtype = dict(records=strax.record_dtype(), plus=strax.record_dtype())
provides = ("records", "plus")

st = self.get_context(False)
st.register(RecordsPlus)
st.register(Records)

# Records will make the records, not RecordsPlus
plugins = st._get_plugins(("records",), run_id)
assert isinstance(plugins["records"], Records)

# Plus is no longer available.
with pytest.raises(KeyError):
plugins = st.key_for("plus", run_id)

def test_register_no_defaults(self, runs_default_allowed=False):
"""Test if we only register a plugin with no run-defaults."""
st = self.get_context(runs_default_allowed)
Expand Down