diff --git a/aiodistbus/__init__.py b/aiodistbus/__init__.py index 0de5a08..109d6c7 100644 --- a/aiodistbus/__init__.py +++ b/aiodistbus/__init__.py @@ -5,6 +5,7 @@ from .eventbus import DEventBus, EventBus from .protocols import Event from .registry import registry +from .version import __version__ from .wrapper import DataClassEvent, make_evented setup_log() diff --git a/aiodistbus/entrypoint/aentrypoint.py b/aiodistbus/entrypoint/aentrypoint.py index 0d6ec2d..33a3d1e 100644 --- a/aiodistbus/entrypoint/aentrypoint.py +++ b/aiodistbus/entrypoint/aentrypoint.py @@ -87,7 +87,9 @@ async def wrapper(event: Event): return wrapper @abstractmethod - async def _update_handlers(self, event_type: Optional[str] = None): + async def _update_handlers( + self, event_type: Optional[str] = None, remove: bool = False + ): ... #################################################################### @@ -142,6 +144,21 @@ async def on( await self._update_handlers(event_type) + async def off(self, event_type: str): + """Remove a handler + + Args: + event_type (str): Event type + + """ + # Track handlers (supporting wildcards) + if "*" not in event_type: + del self._handlers[event_type] + else: + del self._wildcards[event_type] + + await self._update_handlers(event_type, remove=True) + @abstractmethod async def emit( self, event_type: str, data: Any, id: Optional[str] = None diff --git a/aiodistbus/entrypoint/dentrypoint.py b/aiodistbus/entrypoint/dentrypoint.py index dfede8e..4c11713 100644 --- a/aiodistbus/entrypoint/dentrypoint.py +++ b/aiodistbus/entrypoint/dentrypoint.py @@ -95,17 +95,20 @@ async def _run(self): if len(coros) > 0: await asyncio.gather(*coros) - async def _update_handlers(self, event_type: Optional[str] = None): + async def _update_handlers( + self, event_type: Optional[str] = None, remove: bool = False + ): if not self.subscriber: return + if remove and event_type: + self.subscriber.setsockopt(zmq.UNSUBSCRIBE, event_type.encode("utf-8")) + if event_type: self.subscriber.setsockopt(zmq.SUBSCRIBE, event_type.encode("utf-8")) - # logger.debug("SUBSCRIBER: Subscribed to %s", event_type) else: for event_type in self._handlers.keys(): self.subscriber.setsockopt(zmq.SUBSCRIBE, event_type.encode("utf-8")) - # logger.debug("SUBSCRIBER: Subscribed to %s", event_type) async def _pulse_sub(self): self.pulse_count += 1 @@ -120,7 +123,7 @@ async def _pulse_check(self): # If too many failures, close if self.pulse_fail > self.pulse_limit: - logger.error(f"aiodistbus: Pulse failure limit reached") + logger.error("aiodistbus: Pulse failure limit reached") if self._on_disrupt: await self._on_disrupt() await self.close() diff --git a/aiodistbus/entrypoint/entrypoint.py b/aiodistbus/entrypoint/entrypoint.py index deade4a..4ca5c8a 100644 --- a/aiodistbus/entrypoint/entrypoint.py +++ b/aiodistbus/entrypoint/entrypoint.py @@ -19,10 +19,15 @@ def __init__(self, block: bool = True): self.block = block self._bus: Optional[EventBus] = None - async def _update_handlers(self, event_type: Optional[str] = None): + async def _update_handlers( + self, event_type: Optional[str] = None, remove: bool = False + ): if self._bus is None: return + if remove and event_type: + await self._bus._off(self.id, event_type) + if event_type: if event_type in self._handlers: await self._bus._on(self.id, self._handlers[event_type]) diff --git a/aiodistbus/eventbus/eventbus.py b/aiodistbus/eventbus/eventbus.py index 2f48631..3f6840a 100644 --- a/aiodistbus/eventbus/eventbus.py +++ b/aiodistbus/eventbus/eventbus.py @@ -37,6 +37,13 @@ async def _on(self, id: str, handler: Handler): self._subs[handler.event_type][id] = sub self._dtypes[handler.event_type] = handler.dtype + async def _off(self, id: str, event_type: str): + if "*" in event_type: + del self._wildcard_subs[event_type][id] + else: + del self._subs[event_type][id] + del self._dtypes[event_type] + def _remove(self, id: str): to_be_removed: List[str] = [] for route, subs in self._subs.items(): diff --git a/aiodistbus/version.py b/aiodistbus/version.py index e69de29..3b93d0b 100644 --- a/aiodistbus/version.py +++ b/aiodistbus/version.py @@ -0,0 +1 @@ +__version__ = "0.0.2" diff --git a/pyproject.toml b/pyproject.toml index e65404f..aa6ce97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "aiodistbus" -version = "0.0.1" +version = "0.0.2" description = "ZeroMQ Distributed EventBus for Python" authors = [ {name = "Eduardo Davalos", email="eduardo.davalos.anaya@vanderbilt.edu"}, diff --git a/test/test_bus.py b/test/test_bus.py index 13b91e0..f46df38 100644 --- a/test/test_bus.py +++ b/test/test_bus.py @@ -71,3 +71,26 @@ async def test_local_bus_wildcard(bus, entrypoints): # Assert assert event.id in e1._received assert len(e1._received) == 1 + + +async def test_local_bus_off(bus, entrypoints): + + # Create resources + e1, e2 = entrypoints + + # Add funcs + await e1.on("test", func, ExampleEvent) + + # Connect + await e1.connect(bus) + await e2.connect(bus) + + # Remove + await e1.off("test") + + # Send message + event = await e2.emit("test", ExampleEvent("Hello")) + + # Assert + assert event.id not in e1._received + assert len(e1._received) == 0 diff --git a/test/test_dbus.py b/test/test_dbus.py index 813108d..9ae32be 100644 --- a/test/test_dbus.py +++ b/test/test_dbus.py @@ -104,3 +104,28 @@ async def test_dbus_emit_wildcard(dbus, dentrypoints): # Assert assert event1 and event1.id not in e1._received assert event2 and event2.id in e1._received + + +async def test_dbus_off(dbus, dentrypoints): + + # Create resources + e1, e2 = dentrypoints + + # Add funcs + await e1.on("test", func, ExampleEvent) + + # Connect + await e1.connect(dbus.ip, dbus.port) + await e2.connect(dbus.ip, dbus.port) + + # Off + await e1.off("test") + + # Send message + event1 = await e2.emit("test", ExampleEvent("Hello")) + + # Need to flush + await dbus.flush() + + # Assert + assert event1 and event1.id not in e1._received diff --git a/test/test_stress.py b/test/test_stress.py index 5ccb699..185f50d 100644 --- a/test/test_stress.py +++ b/test/test_stress.py @@ -12,7 +12,7 @@ logger = logging.getLogger("aiodistbus") N = 5 -M = 1000 +M = 100 @dataclass