Skip to content

Commit

Permalink
Added off method for both local and distributed APIs. (#2)
Browse files Browse the repository at this point in the history
* Added off method for both local and distributed APIs.

* Lowered stress test, added version.py, and bumped to 0.0.2.
  • Loading branch information
edavalosanaya committed Nov 1, 2023
1 parent 0c16b69 commit bd6710c
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 8 deletions.
1 change: 1 addition & 0 deletions aiodistbus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .eventbus import DEventBus, EventBus
from .protocols import Event
from .registry import registry
from .version import __version__
from .wrapper import DataClassEvent, make_evented

setup_log()
Expand Down
19 changes: 18 additions & 1 deletion aiodistbus/entrypoint/aentrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ async def wrapper(event: Event):
return wrapper

@abstractmethod
async def _update_handlers(self, event_type: Optional[str] = None):
async def _update_handlers(
self, event_type: Optional[str] = None, remove: bool = False
):
...

####################################################################
Expand Down Expand Up @@ -142,6 +144,21 @@ async def on(

await self._update_handlers(event_type)

async def off(self, event_type: str):
"""Remove a handler
Args:
event_type (str): Event type
"""
# Track handlers (supporting wildcards)
if "*" not in event_type:
del self._handlers[event_type]
else:
del self._wildcards[event_type]

await self._update_handlers(event_type, remove=True)

@abstractmethod
async def emit(
self, event_type: str, data: Any, id: Optional[str] = None
Expand Down
11 changes: 7 additions & 4 deletions aiodistbus/entrypoint/dentrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,20 @@ async def _run(self):
if len(coros) > 0:
await asyncio.gather(*coros)

async def _update_handlers(self, event_type: Optional[str] = None):
async def _update_handlers(
self, event_type: Optional[str] = None, remove: bool = False
):
if not self.subscriber:
return

if remove and event_type:
self.subscriber.setsockopt(zmq.UNSUBSCRIBE, event_type.encode("utf-8"))

if event_type:
self.subscriber.setsockopt(zmq.SUBSCRIBE, event_type.encode("utf-8"))
# logger.debug("SUBSCRIBER: Subscribed to %s", event_type)
else:
for event_type in self._handlers.keys():
self.subscriber.setsockopt(zmq.SUBSCRIBE, event_type.encode("utf-8"))
# logger.debug("SUBSCRIBER: Subscribed to %s", event_type)

async def _pulse_sub(self):
self.pulse_count += 1
Expand All @@ -120,7 +123,7 @@ async def _pulse_check(self):

# If too many failures, close
if self.pulse_fail > self.pulse_limit:
logger.error(f"aiodistbus: Pulse failure limit reached")
logger.error("aiodistbus: Pulse failure limit reached")
if self._on_disrupt:
await self._on_disrupt()
await self.close()
Expand Down
7 changes: 6 additions & 1 deletion aiodistbus/entrypoint/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@ def __init__(self, block: bool = True):
self.block = block
self._bus: Optional[EventBus] = None

async def _update_handlers(self, event_type: Optional[str] = None):
async def _update_handlers(
self, event_type: Optional[str] = None, remove: bool = False
):
if self._bus is None:
return

if remove and event_type:
await self._bus._off(self.id, event_type)

if event_type:
if event_type in self._handlers:
await self._bus._on(self.id, self._handlers[event_type])
Expand Down
7 changes: 7 additions & 0 deletions aiodistbus/eventbus/eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ async def _on(self, id: str, handler: Handler):
self._subs[handler.event_type][id] = sub
self._dtypes[handler.event_type] = handler.dtype

async def _off(self, id: str, event_type: str):
if "*" in event_type:
del self._wildcard_subs[event_type][id]
else:
del self._subs[event_type][id]
del self._dtypes[event_type]

def _remove(self, id: str):
to_be_removed: List[str] = []
for route, subs in self._subs.items():
Expand Down
1 change: 1 addition & 0 deletions aiodistbus/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.0.2"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

[project]
name = "aiodistbus"
version = "0.0.1"
version = "0.0.2"
description = "ZeroMQ Distributed EventBus for Python"
authors = [
{name = "Eduardo Davalos", email="eduardo.davalos.anaya@vanderbilt.edu"},
Expand Down
23 changes: 23 additions & 0 deletions test/test_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,26 @@ async def test_local_bus_wildcard(bus, entrypoints):
# Assert
assert event.id in e1._received
assert len(e1._received) == 1


async def test_local_bus_off(bus, entrypoints):

# Create resources
e1, e2 = entrypoints

# Add funcs
await e1.on("test", func, ExampleEvent)

# Connect
await e1.connect(bus)
await e2.connect(bus)

# Remove
await e1.off("test")

# Send message
event = await e2.emit("test", ExampleEvent("Hello"))

# Assert
assert event.id not in e1._received
assert len(e1._received) == 0
25 changes: 25 additions & 0 deletions test/test_dbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,28 @@ async def test_dbus_emit_wildcard(dbus, dentrypoints):
# Assert
assert event1 and event1.id not in e1._received
assert event2 and event2.id in e1._received


async def test_dbus_off(dbus, dentrypoints):

# Create resources
e1, e2 = dentrypoints

# Add funcs
await e1.on("test", func, ExampleEvent)

# Connect
await e1.connect(dbus.ip, dbus.port)
await e2.connect(dbus.ip, dbus.port)

# Off
await e1.off("test")

# Send message
event1 = await e2.emit("test", ExampleEvent("Hello"))

# Need to flush
await dbus.flush()

# Assert
assert event1 and event1.id not in e1._received
2 changes: 1 addition & 1 deletion test/test_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger = logging.getLogger("aiodistbus")

N = 5
M = 1000
M = 100


@dataclass
Expand Down

0 comments on commit bd6710c

Please sign in to comment.