diff --git a/strax/context.py b/strax/context.py index 43cfad2df..5b69ba13a 100644 --- a/strax/context.py +++ b/strax/context.py @@ -328,9 +328,24 @@ def register(self, plugin_class): if isinstance(plugin_class.provides, str): plugin_class.provides = tuple([plugin_class.provides]) + # Register the plugin for all datatypes it provides, + # tracking which plugins we booted out. + deregistered = [] for p in plugin_class.provides: + old_plugin_class = self._plugin_class_registry.get(p, None) + if old_plugin_class and old_plugin_class != plugin_class: + deregistered.append(old_plugin_class) self._plugin_class_registry[p] = plugin_class + # If we booted a plugin from a datatype, we must boot it from other + # datatypes it makes too, to preserve a one-to-one mapping between + # datatypes and registered plugins. + for old_plugin in deregistered: + for d in old_plugin.provides: + currently_registered = self._plugin_class_registry.get(d) + if old_plugin == currently_registered: + del self._plugin_class_registry[d] + already_seen = [] for plugin in self._plugin_class_registry.values(): if plugin in already_seen: diff --git a/tests/test_context.py b/tests/test_context.py index b441e6aef..7602277fa 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -9,6 +9,7 @@ import unittest import shutil import uuid +import pytest def _apply_function_to_data(function) -> ty.Tuple[np.ndarray, np.ndarray]: @@ -224,6 +225,28 @@ def test_get_plugins_with_cache(self): not_cached_plugins.keys() == cached_plugins.keys() ), f"_get_plugins returns different plugins if cached!" + def test_multioutput_deregistration(self): + """Test that a multi-output plugin is deregistered once one of its outputs is provided by + another plugin.""" + + class RecordsPlus(strax.Plugin): + depends_on = tuple() + data_kind = dict(records="records", plus="plus") + dtype = dict(records=strax.record_dtype(), plus=strax.record_dtype()) + provides = ("records", "plus") + + st = self.get_context(False) + st.register(RecordsPlus) + st.register(Records) + + # Records will make the records, not RecordsPlus + plugins = st._get_plugins(("records",), run_id) + assert isinstance(plugins["records"], Records) + + # Plus is no longer available. + with pytest.raises(KeyError): + plugins = st.key_for("plus", run_id) + def test_register_no_defaults(self, runs_default_allowed=False): """Test if we only register a plugin with no run-defaults.""" st = self.get_context(runs_default_allowed)