From fde1d0fee435c629ee18f6d86bc165163799721f Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Wed, 20 Aug 2025 21:59:36 -0700 Subject: [PATCH 1/2] chore: unify callback handling recipes across mqtt and local channels --- roborock/callbacks.py | 139 ++++++++++++++++ roborock/devices/local_channel.py | 25 +-- roborock/devices/mqtt_channel.py | 16 +- roborock/mqtt/roborock_session.py | 20 +-- tests/devices/test_local_channel.py | 4 +- tests/devices/test_mqtt_channel.py | 4 +- tests/test_callbacks.py | 249 ++++++++++++++++++++++++++++ 7 files changed, 407 insertions(+), 50 deletions(-) create mode 100644 roborock/callbacks.py create mode 100644 tests/test_callbacks.py diff --git a/roborock/callbacks.py b/roborock/callbacks.py new file mode 100644 index 00000000..e3f13c2b --- /dev/null +++ b/roborock/callbacks.py @@ -0,0 +1,139 @@ +"""Module for managing callback utility functions.""" + +import logging +from collections.abc import Callable +from typing import Generic, TypeVar + +_LOGGER = logging.getLogger(__name__) + +K = TypeVar("K") +V = TypeVar("V") + + +def safe_callback(callback: Callable[[V], None], logger: logging.Logger | None = None) -> Callable[[V], None]: + """Wrap a callback to catch and log exceptions. + + This is useful for ensuring that errors in callbacks do not propagate + and cause unexpected behavior. Any failures during callback execution will be logged. + """ + + if logger is None: + logger = _LOGGER + + def wrapper(value: V) -> None: + try: + callback(value) + except Exception as ex: # noqa: BLE001 + logger.error("Uncaught error in callback '%s': %s", callback.__name__, ex) + + return wrapper + + +class CallbackMap(Generic[K, V]): + """A mapping of callbacks for specific keys.""" + + def __init__(self, logger: logging.Logger | None = None) -> None: + self._callbacks: dict[K, list[Callable[[V], None]]] = {} + self._logger = logger or _LOGGER + + def keys(self) -> list[K]: + """Get all keys in the callback map.""" + return list(self._callbacks.keys()) + + def add_callback(self, key: K, callback: Callable[[V], None]) -> Callable[[], None]: + """Add a callback for a specific key. + + Any failures during callback execution will be logged. + + Returns a callable that can be used to remove the callback. + """ + self._callbacks.setdefault(key, []).append(callback) + + def remove_callback() -> None: + """Remove the callback for the specific key.""" + if cb_list := self._callbacks.get(key): + cb_list.remove(callback) + if not cb_list: + del self._callbacks[key] + + return remove_callback + + def get_callbacks(self, key: K) -> list[Callable[[V], None]]: + """Get all callbacks for a specific key.""" + return self._callbacks.get(key, []) + + def __call__(self, key: K, value: V) -> None: + """Invoke all callbacks for a specific key.""" + for callback in self.get_callbacks(key): + safe_callback(callback, self._logger)(value) + + +class CallbackList(Generic[V]): + """A list of callbacks for specific keys.""" + + def __init__(self, logger: logging.Logger | None = None) -> None: + self._callbacks: list[Callable[[V], None]] = [] + self._logger = logger or _LOGGER + + def add_callback(self, callback: Callable[[V], None]) -> Callable[[], None]: + """Add a callback to the list. + + Any failures during callback execution will be logged. + + Returns a callable that can be used to remove the callback. + """ + self._callbacks.append(callback) + + return lambda: self._callbacks.remove(callback) + + def __call__(self, value: V) -> None: + """Invoke all callbacks in the list.""" + for callback in self._callbacks: + safe_callback(callback, self._logger)(value) + + +def decoder_callback( + decoder: Callable[[K], list[V]], callback: Callable[[V], None], logger: logging.Logger | None = None +) -> Callable[[K], None]: + """Create a callback that decodes messages using a decoder and invokes a callback. + + Any failures during decoding will be logged. + """ + if logger is None: + logger = _LOGGER + + safe_cb = safe_callback(callback, logger) + + def wrapper(data: K) -> None: + if not (messages := decoder(data)): + logger.warning("Failed to decode message: %s", data) + return + for message in messages: + _LOGGER.debug("Decoded message: %s", message) + safe_cb(message) + + return wrapper + + + +def dipspatch_callback( + callback: Callable[[V], None], logger: logging.Logger | None = None +) -> Callable[[list[V]], None]: + """Create a callback that decodes messages using a decoder and invokes a callback. + + Any failures during decoding will be logged. + """ + if logger is None: + logger = _LOGGER + + safe_cb = safe_callback(callback, logger) + + def wrapper(data: K) -> None: + if not (messages := decoder(data)): + logger.warning("Failed to decode message: %s", data) + return + for message in messages: + _LOGGER.debug("Decoded message: %s", message) + safe_cb(message) + + return wrapper diff --git a/roborock/devices/local_channel.py b/roborock/devices/local_channel.py index 1401368e..eda1305a 100644 --- a/roborock/devices/local_channel.py +++ b/roborock/devices/local_channel.py @@ -5,6 +5,7 @@ from collections.abc import Callable from dataclasses import dataclass +from roborock.callbacks import CallbackList, decoder_callback from roborock.exceptions import RoborockConnectionException, RoborockException from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder from roborock.roborock_message import RoborockMessage @@ -42,11 +43,13 @@ def __init__(self, host: str, local_key: str): self._host = host self._transport: asyncio.Transport | None = None self._protocol: _LocalProtocol | None = None - self._subscribers: list[Callable[[RoborockMessage], None]] = [] + self._subscribers: CallbackList[RoborockMessage] = CallbackList(_LOGGER) self._is_connected = False self._decoder: Decoder = create_local_decoder(local_key) self._encoder: Encoder = create_local_encoder(local_key) + # Callback to decode messages and dispatch to subscribers + self._data_received: Callable[[bytes], None] = decoder_callback(self._decoder, self._subscribers, _LOGGER) @property def is_connected(self) -> bool: @@ -76,19 +79,6 @@ def close(self) -> None: self._transport = None self._is_connected = False - def _data_received(self, data: bytes) -> None: - """Handle incoming data from the transport.""" - if not (messages := self._decoder(data)): - _LOGGER.warning("Failed to decode local message: %s", data) - return - for message in messages: - _LOGGER.debug("Received message: %s", message) - for callback in self._subscribers: - try: - callback(message) - except Exception as e: - _LOGGER.exception("Uncaught error in message handler callback: %s", e) - def _connection_lost(self, exc: Exception | None) -> None: """Handle connection loss.""" _LOGGER.warning("Connection lost to %s", self._host, exc_info=exc) @@ -97,12 +87,7 @@ def _connection_lost(self, exc: Exception | None) -> None: async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: """Subscribe to all messages from the device.""" - self._subscribers.append(callback) - - def unsubscribe() -> None: - self._subscribers.remove(callback) - - return unsubscribe + return self._subscribers.add_callback(callback) async def publish(self, message: RoborockMessage) -> None: """Send a command message. diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index 79877115..55e469ee 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable +from roborock.callbacks import decoder_callback from roborock.containers import HomeDataDevice, RRiot, UserData from roborock.exceptions import RoborockException from roborock.mqtt.session import MqttParams, MqttSession, MqttSessionException @@ -56,19 +57,8 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab Returns a callable that can be used to unsubscribe from the topic. """ - - def message_handler(payload: bytes) -> None: - if not (messages := self._decoder(payload)): - _LOGGER.warning("Failed to decode MQTT message: %s", payload) - return - for message in messages: - _LOGGER.debug("Received message: %s", message) - try: - callback(message) - except Exception as e: - _LOGGER.exception("Uncaught error in message handler callback: %s", e) - - return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler) + dispatch = decoder_callback(self._decoder, callback, _LOGGER) + return await self._mqtt_session.subscribe(self._subscribe_topic, dispatch) async def publish(self, message: RoborockMessage) -> None: """Publish a command message. diff --git a/roborock/mqtt/roborock_session.py b/roborock/mqtt/roborock_session.py index b253e6f3..026ed7c2 100644 --- a/roborock/mqtt/roborock_session.py +++ b/roborock/mqtt/roborock_session.py @@ -17,6 +17,8 @@ import aiomqtt from aiomqtt import MqttError, TLSParameters +from roborock.callbacks import CallbackMap + from .session import MqttParams, MqttSession, MqttSessionException _LOGGER = logging.getLogger(__name__) @@ -53,7 +55,7 @@ def __init__(self, params: MqttParams): self._backoff = MIN_BACKOFF_INTERVAL self._client: aiomqtt.Client | None = None self._client_lock = asyncio.Lock() - self._listeners: dict[str, list[Callable[[bytes], None]]] = {} + self._listeners: CallbackMap[str, bytes] = CallbackMap(_LOGGER) @property def connected(self) -> bool: @@ -164,7 +166,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client: # Re-establish any existing subscriptions async with self._client_lock: self._client = client - for topic in self._listeners: + for topic in self._listeners.keys(): _LOGGER.debug("Re-establishing subscription to topic %s", topic) # TODO: If this fails it will break the whole connection. Make # this retry again in the background with backoff. @@ -179,13 +181,7 @@ async def _process_message_loop(self, client: aiomqtt.Client) -> None: _LOGGER.debug("Processing MQTT messages") async for message in client.messages: _LOGGER.debug("Received message: %s", message) - for listener in self._listeners.get(message.topic.value, []): - try: - listener(message.payload) - except asyncio.CancelledError: - raise - except Exception as e: - _LOGGER.exception("Uncaught exception in subscriber callback: %s", e) + self._listeners(message.topic.value, message.payload) async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]: """Subscribe to messages on the specified topic and invoke the callback for new messages. @@ -196,9 +192,7 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call The returned callable unsubscribes from the topic when called. """ _LOGGER.debug("Subscribing to topic %s", topic) - if topic not in self._listeners: - self._listeners[topic] = [] - self._listeners[topic].append(callback) + unsub = self._listeners.add_callback(topic, callback) async with self._client_lock: if self._client: @@ -210,7 +204,7 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call else: _LOGGER.debug("Client not connected, will establish subscription later") - return lambda: self._listeners[topic].remove(callback) + return unsub async def publish(self, topic: str, message: bytes) -> None: """Publish a message on the topic.""" diff --git a/tests/devices/test_local_channel.py b/tests/devices/test_local_channel.py index 6fb7485a..9bc02e53 100644 --- a/tests/devices/test_local_channel.py +++ b/tests/devices/test_local_channel.py @@ -148,7 +148,7 @@ async def test_message_decode_error(local_channel: LocalChannel, caplog: pytest. assert len(caplog.records) == 1 assert caplog.records[0].levelname == "WARNING" - assert "Failed to decode local message" in caplog.records[0].message + assert "Failed to decode message" in caplog.records[0].message async def test_subscribe_callback( @@ -181,7 +181,7 @@ def failing_callback(message: RoborockMessage) -> None: await asyncio.sleep(0.01) # yield # Should log the exception but not crash - assert any("Uncaught error in message handler callback" in record.message for record in caplog.records) + assert any("Uncaught error in callback 'failing_callback'" in record.message for record in caplog.records) async def test_unsubscribe(local_channel: LocalChannel, mock_loop: Mock) -> None: diff --git a/tests/devices/test_mqtt_channel.py b/tests/devices/test_mqtt_channel.py index 44ac2bbf..f385596a 100644 --- a/tests/devices/test_mqtt_channel.py +++ b/tests/devices/test_mqtt_channel.py @@ -150,7 +150,7 @@ async def test_message_decode_error( assert len(caplog.records) == 1 assert caplog.records[0].levelname == "WARNING" - assert "Failed to decode MQTT message" in caplog.records[0].message + assert "Failed to decode message" in caplog.records[0].message unsub() @@ -255,7 +255,7 @@ def failing_callback(message: RoborockMessage) -> None: # Check that exception was logged error_records = [record for record in caplog.records if record.levelname == "ERROR"] assert len(error_records) == 1 - assert "Uncaught error in message handler callback" in error_records[0].message + assert "Uncaught error in callback 'failing_callback'" in error_records[0].message # Unsubscribe all remaining subscribers unsub1() diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py new file mode 100644 index 00000000..afd4698b --- /dev/null +++ b/tests/test_callbacks.py @@ -0,0 +1,249 @@ +"""Tests for the callbacks module.""" + +import logging +from unittest.mock import Mock + +from roborock.callbacks import CallbackList, CallbackMap, safe_callback + + +def test_safe_callback_successful_execution(): + """Test that safe_callback executes callback successfully.""" + mock_callback = Mock() + wrapped = safe_callback(mock_callback) + + wrapped("test_value") + + mock_callback.assert_called_once_with("test_value") + + +def test_safe_callback_catches_exception(): + """Test that safe_callback catches and logs exceptions.""" + + def failing_callback(value): + raise ValueError("Test exception") + + mock_logger = Mock(spec=logging.Logger) + wrapped = safe_callback(failing_callback, mock_logger) + + # Should not raise exception + wrapped("test_value") + + mock_logger.error.assert_called_once() + assert "Uncaught error in callback" in mock_logger.error.call_args[0][0] + + +def test_safe_callback_uses_default_logger(): + """Test that safe_callback uses default logger when none provided.""" + + def failing_callback(value): + raise ValueError("Test exception") + + wrapped = safe_callback(failing_callback) + + # Should not raise exception + wrapped("test_value") + + +# CallbackMap tests + + +def test_callback_map_add_callback_and_invoke(): + """Test adding callback and invoking it.""" + callback_map = CallbackMap[str, str]() + mock_callback = Mock() + + remove_fn = callback_map.add_callback("key1", mock_callback) + callback_map("key1", "test_value") + + mock_callback.assert_called_once_with("test_value") + assert callable(remove_fn) + + +def test_callback_map_multiple_callbacks_same_key(): + """Test multiple callbacks for the same key.""" + callback_map = CallbackMap[str, str]() + mock_callback1 = Mock() + mock_callback2 = Mock() + + callback_map.add_callback("key1", mock_callback1) + callback_map.add_callback("key1", mock_callback2) + callback_map("key1", "test_value") + + mock_callback1.assert_called_once_with("test_value") + mock_callback2.assert_called_once_with("test_value") + + +def test_callback_map_different_keys(): + """Test callbacks for different keys.""" + callback_map = CallbackMap[str, str]() + mock_callback1 = Mock() + mock_callback2 = Mock() + + callback_map.add_callback("key1", mock_callback1) + callback_map.add_callback("key2", mock_callback2) + + callback_map("key1", "value1") + callback_map("key2", "value2") + + mock_callback1.assert_called_once_with("value1") + mock_callback2.assert_called_once_with("value2") + + +def test_callback_map_get_callbacks(): + """Test getting callbacks for a key.""" + callback_map = CallbackMap[str, str]() + mock_callback = Mock() + + # No callbacks initially + assert callback_map.get_callbacks("key1") == [] + + # Add callback + callback_map.add_callback("key1", mock_callback) + callbacks = callback_map.get_callbacks("key1") + + assert len(callbacks) == 1 + assert callbacks[0] == mock_callback + + +def test_callback_map_remove_callback(): + """Test removing callback.""" + callback_map = CallbackMap[str, str]() + mock_callback = Mock() + + remove_fn = callback_map.add_callback("key1", mock_callback) + + # Callback should be there + assert len(callback_map.get_callbacks("key1")) == 1 + + # Remove callback + remove_fn() + + # Callback should be gone + assert callback_map.get_callbacks("key1") == [] + + +def test_callback_map_remove_callback_cleans_up_key(): + """Test that removing last callback for a key removes the key.""" + callback_map = CallbackMap[str, str]() + mock_callback = Mock() + + remove_fn = callback_map.add_callback("key1", mock_callback) + + # Key should exist + assert "key1" in callback_map._callbacks + + # Remove callback + remove_fn() + + # Key should be removed + assert "key1" not in callback_map._callbacks + + +def test_callback_map_exception_handling(caplog): + """Test that exceptions in callbacks are handled gracefully.""" + callback_map = CallbackMap[str, str]() + + def failing_callback(value): + raise ValueError("Test exception") + + callback_map.add_callback("key1", failing_callback) + + with caplog.at_level(logging.ERROR): + callback_map("key1", "test_value") + + assert "Uncaught error in callback" in caplog.text + + +def test_callback_map_custom_logger(): + """Test using custom logger.""" + mock_logger = Mock(spec=logging.Logger) + callback_map = CallbackMap[str, str](logger=mock_logger) + + def failing_callback(value): + raise ValueError("Test exception") + + callback_map.add_callback("key1", failing_callback) + callback_map("key1", "test_value") + + mock_logger.error.assert_called_once() + + +# CallbackList tests + + +def test_callback_list_add_callback_and_invoke(): + """Test adding callback and invoking it.""" + callback_list = CallbackList[str]() + mock_callback = Mock() + + remove_fn = callback_list.add_callback(mock_callback) + callback_list("test_value") + + mock_callback.assert_called_once_with("test_value") + assert callable(remove_fn) + + +def test_callback_list_multiple_callbacks(): + """Test multiple callbacks in the list.""" + callback_list = CallbackList[str]() + mock_callback1 = Mock() + mock_callback2 = Mock() + + callback_list.add_callback(mock_callback1) + callback_list.add_callback(mock_callback2) + callback_list("test_value") + + mock_callback1.assert_called_once_with("test_value") + mock_callback2.assert_called_once_with("test_value") + + +def test_callback_list_remove_callback(): + """Test removing callback from list.""" + callback_list = CallbackList[str]() + mock_callback1 = Mock() + mock_callback2 = Mock() + + remove_fn1 = callback_list.add_callback(mock_callback1) + callback_list.add_callback(mock_callback2) + + # Both should be called + callback_list("test_value") + assert mock_callback1.call_count == 1 + assert mock_callback2.call_count == 1 + + # Remove first callback + remove_fn1() + + # Only second should be called + callback_list("test_value2") + assert mock_callback1.call_count == 1 # Still 1 + assert mock_callback2.call_count == 2 # Now 2 + + +def test_callback_list_exception_handling(caplog): + """Test that exceptions in callbacks are handled gracefully.""" + callback_list = CallbackList[str]() + + def failing_callback(value): + raise ValueError("Test exception") + + callback_list.add_callback(failing_callback) + + with caplog.at_level(logging.ERROR): + callback_list("test_value") + + assert "Uncaught error in callback" in caplog.text + + +def test_callback_list_custom_logger(): + """Test using custom logger.""" + mock_logger = Mock(spec=logging.Logger) + callback_list = CallbackList[str](logger=mock_logger) + + def failing_callback(value): + raise ValueError("Test exception") + + callback_list.add_callback(failing_callback) + callback_list("test_value") + + mock_logger.error.assert_called_once() From ad4d0556360c80b5b61ca483c208bd1fe676b9be Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Wed, 20 Aug 2025 22:04:26 -0700 Subject: [PATCH 2/2] chore: fix style and comments --- roborock/callbacks.py | 39 +++++++++++++-------------------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/roborock/callbacks.py b/roborock/callbacks.py index e3f13c2b..63868042 100644 --- a/roborock/callbacks.py +++ b/roborock/callbacks.py @@ -30,7 +30,11 @@ def wrapper(value: V) -> None: class CallbackMap(Generic[K, V]): - """A mapping of callbacks for specific keys.""" + """A mapping of callbacks for specific keys. + + This allows for registering multiple callbacks for different keys and invoking them + when a value is received for a specific key. + """ def __init__(self, logger: logging.Logger | None = None) -> None: self._callbacks: dict[K, list[Callable[[V], None]]] = {} @@ -69,7 +73,11 @@ def __call__(self, key: K, value: V) -> None: class CallbackList(Generic[V]): - """A list of callbacks for specific keys.""" + """A list of callbacks that can be invoked. + + This combines a list of callbacks into a single callable. Callers can add + additional callbacks to the list at any time. + """ def __init__(self, logger: logging.Logger | None = None) -> None: self._callbacks: list[Callable[[V], None]] = [] @@ -97,31 +105,10 @@ def decoder_callback( ) -> Callable[[K], None]: """Create a callback that decodes messages using a decoder and invokes a callback. - Any failures during decoding will be logged. - """ - if logger is None: - logger = _LOGGER - - safe_cb = safe_callback(callback, logger) - - def wrapper(data: K) -> None: - if not (messages := decoder(data)): - logger.warning("Failed to decode message: %s", data) - return - for message in messages: - _LOGGER.debug("Decoded message: %s", message) - safe_cb(message) - - return wrapper - - - -def dipspatch_callback( - callback: Callable[[V], None], logger: logging.Logger | None = None -) -> Callable[[list[V]], None]: - """Create a callback that decodes messages using a decoder and invokes a callback. + The decoder converts a value into a list of values. The callback is then invoked + for each value in the list. - Any failures during decoding will be logged. + Any failures during decoding or invoking the callbacks will be logged. """ if logger is None: logger = _LOGGER