Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ApiHandlers subscription #205

Merged
merged 3 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions aiounifi/interfaces/api_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations

from collections.abc import Callable, ItemsView, Iterator, ValuesView
import logging
from typing import TYPE_CHECKING, Any, Final, Generic, final
import enum
from typing import TYPE_CHECKING, Any, Final, Generic, Optional, final

from ..models import ResourceType
from ..models.request_object import RequestObject
Expand All @@ -14,10 +14,20 @@
from ..models.event import Event, EventKey
from ..models.message import Message, MessageKey

SubscriptionType = Callable[[str, str], None]

class ItemEvent(enum.Enum):
"""The event action of the item."""

ADDED = "added"
CHANGED = "changed"
DELETED = "deleted"


CallbackType = Callable[[ItemEvent, str], None]
SubscriptionType = tuple[CallbackType, Optional[tuple[ItemEvent, ...]]]
UnsubscribeType = Callable[[], None]

LOGGER = logging.getLogger(__name__)
ID_FILTER_ALL = "*"

SOURCE_DATA: Final = "data"
SOURCE_EVENT: Final = "event"
Expand All @@ -37,7 +47,7 @@ def __init__(self, controller: Controller) -> None:
"""Initialize API items."""
self.controller = controller
self._items: dict[int | str, ResourceType] = {}
self._subscribers: list[SubscriptionType] = []
self._subscribers: dict[str, list[SubscriptionType]] = {ID_FILTER_ALL: []}

if message_filter := self.process_messages + self.remove_messages:
controller.messages.subscribe(self.process_message, message_filter)
Expand Down Expand Up @@ -86,12 +96,11 @@ def process_item(self, raw: dict[str, Any]) -> str:
if (obj_id := raw[self.obj_id_key]) in self._items:
obj = self._items[obj_id]
obj.update(raw=raw)
self.signal_subscribers(ItemEvent.CHANGED, obj_id)
return ""

self._items[obj_id] = self.item_cls(raw, self.controller)

for callback in self._subscribers:
callback("added", obj_id)
self.signal_subscribers(ItemEvent.ADDED, obj_id)

return obj_id

Expand All @@ -102,19 +111,53 @@ def remove_item(self, raw: dict[str, Any]) -> str:
if (obj_id := raw[self.obj_id_key]) in self._items:
obj = self._items.pop(obj_id)
obj.clear_callbacks()
self.signal_subscribers(ItemEvent.DELETED, obj_id)
return obj_id
return ""

def subscribe(self, callback: SubscriptionType) -> UnsubscribeType:
def signal_subscribers(self, event: ItemEvent, obj_id: str) -> None:
"""Signal subscribers."""
subscribers: list[SubscriptionType] = (
self._subscribers.get(obj_id, []) + self._subscribers[ID_FILTER_ALL]
)
for callback, event_filter in subscribers:
if event_filter is not None and event not in event_filter:
continue
callback(event, obj_id)

def subscribe(
self,
callback: CallbackType,
event_filter: tuple[ItemEvent, ...] | ItemEvent | None = None,
id_filter: tuple[str] | str | None = None,
) -> UnsubscribeType:
"""Subscribe to added events.

"callback" - callback function to call when an event emits.
Return function to unsubscribe.
"""
self._subscribers.append(callback)
if isinstance(event_filter, ItemEvent):
event_filter = (event_filter,)
subscription = (callback, event_filter)

_id_filter: tuple[str]
if id_filter is None:
_id_filter = (ID_FILTER_ALL,)
elif isinstance(id_filter, str):
_id_filter = (id_filter,)

for obj_id in _id_filter:
if obj_id not in self._subscribers:
self._subscribers[obj_id] = []
self._subscribers[obj_id].append(subscription)

def unsubscribe() -> None:
self._subscribers.remove(callback)
for obj_id in _id_filter:
if obj_id not in self._subscribers:
continue
if subscription not in self._subscribers[obj_id]:
continue
self._subscribers[obj_id].remove(subscription)

return unsubscribe

Expand Down
134 changes: 134 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Test API handlers."""

from unittest.mock import Mock

import pytest

from aiounifi.interfaces.api_handlers import APIHandler, ItemEvent


@pytest.mark.parametrize(
"event_filter",
[
None,
{ItemEvent.ADDED, ItemEvent.CHANGED, ItemEvent.DELETED},
],
)
async def test_api_handler_subscriptions(event_filter):
"""Test process and remove item."""
handler = APIHandler(Mock())
handler.obj_id_key = "key"
handler.item_cls = Mock()

unsub = handler.subscribe(mock_subscribe_cb := Mock(), event_filter)

assert handler.process_item({}) == ""
mock_subscribe_cb.assert_not_called()

assert handler.process_item({"key": "1"}) == "1"
mock_subscribe_cb.assert_called_with(ItemEvent.ADDED, "1")

assert handler.process_item({"key": "1"}) == ""
mock_subscribe_cb.assert_called_with(ItemEvent.CHANGED, "1")

assert handler.remove_item({"key": "1"}) == "1"
mock_subscribe_cb.assert_called_with(ItemEvent.DELETED, "1")

assert handler.remove_item({"key": "2"}) == ""

# Process raw

assert handler.process_raw([{}]) == set()

assert handler.process_raw([{"key": "2"}]) == {"2"}
mock_subscribe_cb.assert_called_with(ItemEvent.ADDED, "2")

assert handler.process_raw([{"key": "2"}]) == set()
mock_subscribe_cb.assert_called_with(ItemEvent.CHANGED, "2")

assert handler.remove_item({"key": "2"}) == "2"
mock_subscribe_cb.assert_called_with(ItemEvent.DELETED, "2")

unsub()

unsub() # Empty list of object ID

handler._subscribers.clear()

unsub() # Object ID does not exist in subscribers


async def test_api_handler_subscriptions_event_filter_added():
"""Test process and remove item."""
handler = APIHandler(Mock())
handler.obj_id_key = "key"
handler.item_cls = Mock()

unsub = handler.subscribe(mock_subscribe_cb := Mock(), ItemEvent.ADDED)

assert handler.process_item({}) == ""
mock_subscribe_cb.assert_not_called()

assert handler.process_item({"key": "1"}) == "1"
mock_subscribe_cb.assert_called_with(ItemEvent.ADDED, "1")

assert handler.process_item({"key": "1"}) == ""
assert mock_subscribe_cb.call_count == 1

assert handler.remove_item({"key": "1"}) == "1"
assert mock_subscribe_cb.call_count == 1

assert handler.remove_item({"key": "2"}) == ""

# Process raw

assert handler.process_raw([{}]) == set()

assert handler.process_raw([{"key": "2"}]) == {"2"}
mock_subscribe_cb.assert_called_with(ItemEvent.ADDED, "2")

assert handler.process_raw([{"key": "2"}]) == set()
assert mock_subscribe_cb.call_count == 2

assert handler.remove_item({"key": "2"}) == "2"
assert mock_subscribe_cb.call_count == 2

unsub()


async def test_api_handler_subscriptions_id_filter():
"""Test process and remove item."""
handler = APIHandler(Mock())
handler.obj_id_key = "key"
handler.item_cls = Mock()

unsub = handler.subscribe(mock_subscribe_cb := Mock(), id_filter="1")

assert handler.process_item({}) == ""
mock_subscribe_cb.assert_not_called()

assert handler.process_item({"key": "1"}) == "1"
mock_subscribe_cb.assert_called_with(ItemEvent.ADDED, "1")

assert handler.process_item({"key": "1"}) == ""
mock_subscribe_cb.assert_called_with(ItemEvent.CHANGED, "1")

assert handler.remove_item({"key": "1"}) == "1"
mock_subscribe_cb.assert_called_with(ItemEvent.DELETED, "1")

assert handler.remove_item({"key": "2"}) == ""

# Process raw

assert handler.process_raw([{}]) == set()

assert handler.process_raw([{"key": "2"}]) == {"2"}
assert mock_subscribe_cb.call_count == 3

assert handler.process_raw([{"key": "2"}]) == set()
assert mock_subscribe_cb.call_count == 3

assert handler.remove_item({"key": "2"}) == "2"
assert mock_subscribe_cb.call_count == 3

unsub()