Skip to content

Commit

Permalink
Improve ApiHandlers subscription (#205)
Browse files Browse the repository at this point in the history
* Improve ApiHandlers subscription

* Add tests

* Improve tests
  • Loading branch information
Kane610 committed Oct 11, 2022
1 parent 1dc517d commit a56269b
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 11 deletions.
65 changes: 54 additions & 11 deletions aiounifi/interfaces/api_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations

from collections.abc import Callable, ItemsView, Iterator, ValuesView
import logging
from typing import TYPE_CHECKING, Any, Final, Generic, final
import enum
from typing import TYPE_CHECKING, Any, Final, Generic, Optional, final

from ..models import ResourceType
from ..models.request_object import RequestObject
Expand All @@ -14,10 +14,20 @@
from ..models.event import Event, EventKey
from ..models.message import Message, MessageKey

SubscriptionType = Callable[[str, str], None]

class ItemEvent(enum.Enum):
"""The event action of the item."""

ADDED = "added"
CHANGED = "changed"
DELETED = "deleted"


CallbackType = Callable[[ItemEvent, str], None]
SubscriptionType = tuple[CallbackType, Optional[tuple[ItemEvent, ...]]]
UnsubscribeType = Callable[[], None]

LOGGER = logging.getLogger(__name__)
ID_FILTER_ALL = "*"

SOURCE_DATA: Final = "data"
SOURCE_EVENT: Final = "event"
Expand All @@ -37,7 +47,7 @@ def __init__(self, controller: Controller) -> None:
"""Initialize API items."""
self.controller = controller
self._items: dict[int | str, ResourceType] = {}
self._subscribers: list[SubscriptionType] = []
self._subscribers: dict[str, list[SubscriptionType]] = {ID_FILTER_ALL: []}

if message_filter := self.process_messages + self.remove_messages:
controller.messages.subscribe(self.process_message, message_filter)
Expand Down Expand Up @@ -86,12 +96,11 @@ def process_item(self, raw: dict[str, Any]) -> str:
if (obj_id := raw[self.obj_id_key]) in self._items:
obj = self._items[obj_id]
obj.update(raw=raw)
self.signal_subscribers(ItemEvent.CHANGED, obj_id)
return ""

self._items[obj_id] = self.item_cls(raw, self.controller)

for callback in self._subscribers:
callback("added", obj_id)
self.signal_subscribers(ItemEvent.ADDED, obj_id)

return obj_id

Expand All @@ -102,19 +111,53 @@ def remove_item(self, raw: dict[str, Any]) -> str:
if (obj_id := raw[self.obj_id_key]) in self._items:
obj = self._items.pop(obj_id)
obj.clear_callbacks()
self.signal_subscribers(ItemEvent.DELETED, obj_id)
return obj_id
return ""

def subscribe(self, callback: SubscriptionType) -> UnsubscribeType:
def signal_subscribers(self, event: ItemEvent, obj_id: str) -> None:
"""Signal subscribers."""
subscribers: list[SubscriptionType] = (
self._subscribers.get(obj_id, []) + self._subscribers[ID_FILTER_ALL]
)
for callback, event_filter in subscribers:
if event_filter is not None and event not in event_filter:
continue
callback(event, obj_id)

def subscribe(
self,
callback: CallbackType,
event_filter: tuple[ItemEvent, ...] | ItemEvent | None = None,
id_filter: tuple[str] | str | None = None,
) -> UnsubscribeType:
"""Subscribe to added events.
"callback" - callback function to call when an event emits.
Return function to unsubscribe.
"""
self._subscribers.append(callback)
if isinstance(event_filter, ItemEvent):
event_filter = (event_filter,)
subscription = (callback, event_filter)

_id_filter: tuple[str]
if id_filter is None:
_id_filter = (ID_FILTER_ALL,)
elif isinstance(id_filter, str):
_id_filter = (id_filter,)

for obj_id in _id_filter:
if obj_id not in self._subscribers:
self._subscribers[obj_id] = []
self._subscribers[obj_id].append(subscription)

def unsubscribe() -> None:
self._subscribers.remove(callback)
for obj_id in _id_filter:
if obj_id not in self._subscribers:
continue
if subscription not in self._subscribers[obj_id]:
continue
self._subscribers[obj_id].remove(subscription)

return unsubscribe

Expand Down
134 changes: 134 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Test API handlers."""

from unittest.mock import Mock

import pytest

from aiounifi.interfaces.api_handlers import APIHandler, ItemEvent


@pytest.mark.parametrize(
"event_filter",
[
None,
{ItemEvent.ADDED, ItemEvent.CHANGED, ItemEvent.DELETED},
],
)
async def test_api_handler_subscriptions(event_filter):
"""Test process and remove item."""
handler = APIHandler(Mock())
handler.obj_id_key = "key"
handler.item_cls = Mock()

unsub = handler.subscribe(mock_subscribe_cb := Mock(), event_filter)

assert handler.process_item({}) == ""
mock_subscribe_cb.assert_not_called()

assert handler.process_item({"key": "1"}) == "1"
mock_subscribe_cb.assert_called_with(ItemEvent.ADDED, "1")

assert handler.process_item({"key": "1"}) == ""
mock_subscribe_cb.assert_called_with(ItemEvent.CHANGED, "1")

assert handler.remove_item({"key": "1"}) == "1"
mock_subscribe_cb.assert_called_with(ItemEvent.DELETED, "1")

assert handler.remove_item({"key": "2"}) == ""

# Process raw

assert handler.process_raw([{}]) == set()

assert handler.process_raw([{"key": "2"}]) == {"2"}
mock_subscribe_cb.assert_called_with(ItemEvent.ADDED, "2")

assert handler.process_raw([{"key": "2"}]) == set()
mock_subscribe_cb.assert_called_with(ItemEvent.CHANGED, "2")

assert handler.remove_item({"key": "2"}) == "2"
mock_subscribe_cb.assert_called_with(ItemEvent.DELETED, "2")

unsub()

unsub() # Empty list of object ID

handler._subscribers.clear()

unsub() # Object ID does not exist in subscribers


async def test_api_handler_subscriptions_event_filter_added():
"""Test process and remove item."""
handler = APIHandler(Mock())
handler.obj_id_key = "key"
handler.item_cls = Mock()

unsub = handler.subscribe(mock_subscribe_cb := Mock(), ItemEvent.ADDED)

assert handler.process_item({}) == ""
mock_subscribe_cb.assert_not_called()

assert handler.process_item({"key": "1"}) == "1"
mock_subscribe_cb.assert_called_with(ItemEvent.ADDED, "1")

assert handler.process_item({"key": "1"}) == ""
assert mock_subscribe_cb.call_count == 1

assert handler.remove_item({"key": "1"}) == "1"
assert mock_subscribe_cb.call_count == 1

assert handler.remove_item({"key": "2"}) == ""

# Process raw

assert handler.process_raw([{}]) == set()

assert handler.process_raw([{"key": "2"}]) == {"2"}
mock_subscribe_cb.assert_called_with(ItemEvent.ADDED, "2")

assert handler.process_raw([{"key": "2"}]) == set()
assert mock_subscribe_cb.call_count == 2

assert handler.remove_item({"key": "2"}) == "2"
assert mock_subscribe_cb.call_count == 2

unsub()


async def test_api_handler_subscriptions_id_filter():
"""Test process and remove item."""
handler = APIHandler(Mock())
handler.obj_id_key = "key"
handler.item_cls = Mock()

unsub = handler.subscribe(mock_subscribe_cb := Mock(), id_filter="1")

assert handler.process_item({}) == ""
mock_subscribe_cb.assert_not_called()

assert handler.process_item({"key": "1"}) == "1"
mock_subscribe_cb.assert_called_with(ItemEvent.ADDED, "1")

assert handler.process_item({"key": "1"}) == ""
mock_subscribe_cb.assert_called_with(ItemEvent.CHANGED, "1")

assert handler.remove_item({"key": "1"}) == "1"
mock_subscribe_cb.assert_called_with(ItemEvent.DELETED, "1")

assert handler.remove_item({"key": "2"}) == ""

# Process raw

assert handler.process_raw([{}]) == set()

assert handler.process_raw([{"key": "2"}]) == {"2"}
assert mock_subscribe_cb.call_count == 3

assert handler.process_raw([{"key": "2"}]) == set()
assert mock_subscribe_cb.call_count == 3

assert handler.remove_item({"key": "2"}) == "2"
assert mock_subscribe_cb.call_count == 3

unsub()

0 comments on commit a56269b

Please sign in to comment.