Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions roborock/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Module for managing callback utility functions."""

import logging
from collections.abc import Callable
from typing import Generic, TypeVar

_LOGGER = logging.getLogger(__name__)

K = TypeVar("K")
V = TypeVar("V")


def safe_callback(callback: Callable[[V], None], logger: logging.Logger | None = None) -> Callable[[V], None]:
"""Wrap a callback to catch and log exceptions.

This is useful for ensuring that errors in callbacks do not propagate
and cause unexpected behavior. Any failures during callback execution will be logged.
"""

if logger is None:
logger = _LOGGER

def wrapper(value: V) -> None:
try:
callback(value)
except Exception as ex: # noqa: BLE001
logger.error("Uncaught error in callback '%s': %s", callback.__name__, ex)

return wrapper


class CallbackMap(Generic[K, V]):
"""A mapping of callbacks for specific keys.

This allows for registering multiple callbacks for different keys and invoking them
when a value is received for a specific key.
"""

def __init__(self, logger: logging.Logger | None = None) -> None:
self._callbacks: dict[K, list[Callable[[V], None]]] = {}
self._logger = logger or _LOGGER

def keys(self) -> list[K]:
"""Get all keys in the callback map."""
return list(self._callbacks.keys())

def add_callback(self, key: K, callback: Callable[[V], None]) -> Callable[[], None]:
"""Add a callback for a specific key.

Any failures during callback execution will be logged.

Returns a callable that can be used to remove the callback.
"""
self._callbacks.setdefault(key, []).append(callback)

def remove_callback() -> None:
"""Remove the callback for the specific key."""
if cb_list := self._callbacks.get(key):
cb_list.remove(callback)
if not cb_list:
del self._callbacks[key]

return remove_callback

def get_callbacks(self, key: K) -> list[Callable[[V], None]]:
"""Get all callbacks for a specific key."""
return self._callbacks.get(key, [])

def __call__(self, key: K, value: V) -> None:
"""Invoke all callbacks for a specific key."""
for callback in self.get_callbacks(key):
safe_callback(callback, self._logger)(value)


class CallbackList(Generic[V]):
"""A list of callbacks that can be invoked.

This combines a list of callbacks into a single callable. Callers can add
additional callbacks to the list at any time.
"""

def __init__(self, logger: logging.Logger | None = None) -> None:
self._callbacks: list[Callable[[V], None]] = []
self._logger = logger or _LOGGER

def add_callback(self, callback: Callable[[V], None]) -> Callable[[], None]:
"""Add a callback to the list.

Any failures during callback execution will be logged.

Returns a callable that can be used to remove the callback.
"""
self._callbacks.append(callback)

return lambda: self._callbacks.remove(callback)

def __call__(self, value: V) -> None:
"""Invoke all callbacks in the list."""
for callback in self._callbacks:
safe_callback(callback, self._logger)(value)


def decoder_callback(
decoder: Callable[[K], list[V]], callback: Callable[[V], None], logger: logging.Logger | None = None
) -> Callable[[K], None]:
"""Create a callback that decodes messages using a decoder and invokes a callback.

The decoder converts a value into a list of values. The callback is then invoked
for each value in the list.

Any failures during decoding or invoking the callbacks will be logged.
"""
if logger is None:
logger = _LOGGER

safe_cb = safe_callback(callback, logger)

def wrapper(data: K) -> None:
if not (messages := decoder(data)):
logger.warning("Failed to decode message: %s", data)
return
for message in messages:
_LOGGER.debug("Decoded message: %s", message)
safe_cb(message)

return wrapper
25 changes: 5 additions & 20 deletions roborock/devices/local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Callable
from dataclasses import dataclass

from roborock.callbacks import CallbackList, decoder_callback
from roborock.exceptions import RoborockConnectionException, RoborockException
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
from roborock.roborock_message import RoborockMessage
Expand Down Expand Up @@ -42,11 +43,13 @@ def __init__(self, host: str, local_key: str):
self._host = host
self._transport: asyncio.Transport | None = None
self._protocol: _LocalProtocol | None = None
self._subscribers: list[Callable[[RoborockMessage], None]] = []
self._subscribers: CallbackList[RoborockMessage] = CallbackList(_LOGGER)
self._is_connected = False

self._decoder: Decoder = create_local_decoder(local_key)
self._encoder: Encoder = create_local_encoder(local_key)
# Callback to decode messages and dispatch to subscribers
self._data_received: Callable[[bytes], None] = decoder_callback(self._decoder, self._subscribers, _LOGGER)

@property
def is_connected(self) -> bool:
Expand Down Expand Up @@ -76,19 +79,6 @@ def close(self) -> None:
self._transport = None
self._is_connected = False

def _data_received(self, data: bytes) -> None:
"""Handle incoming data from the transport."""
if not (messages := self._decoder(data)):
_LOGGER.warning("Failed to decode local message: %s", data)
return
for message in messages:
_LOGGER.debug("Received message: %s", message)
for callback in self._subscribers:
try:
callback(message)
except Exception as e:
_LOGGER.exception("Uncaught error in message handler callback: %s", e)

def _connection_lost(self, exc: Exception | None) -> None:
"""Handle connection loss."""
_LOGGER.warning("Connection lost to %s", self._host, exc_info=exc)
Expand All @@ -97,12 +87,7 @@ def _connection_lost(self, exc: Exception | None) -> None:

async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
"""Subscribe to all messages from the device."""
self._subscribers.append(callback)

def unsubscribe() -> None:
self._subscribers.remove(callback)

return unsubscribe
return self._subscribers.add_callback(callback)

async def publish(self, message: RoborockMessage) -> None:
"""Send a command message.
Expand Down
16 changes: 3 additions & 13 deletions roborock/devices/mqtt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from collections.abc import Callable

from roborock.callbacks import decoder_callback
from roborock.containers import HomeDataDevice, RRiot, UserData
from roborock.exceptions import RoborockException
from roborock.mqtt.session import MqttParams, MqttSession, MqttSessionException
Expand Down Expand Up @@ -56,19 +57,8 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab

Returns a callable that can be used to unsubscribe from the topic.
"""

def message_handler(payload: bytes) -> None:
if not (messages := self._decoder(payload)):
_LOGGER.warning("Failed to decode MQTT message: %s", payload)
return
for message in messages:
_LOGGER.debug("Received message: %s", message)
try:
callback(message)
except Exception as e:
_LOGGER.exception("Uncaught error in message handler callback: %s", e)

return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)
dispatch = decoder_callback(self._decoder, callback, _LOGGER)
return await self._mqtt_session.subscribe(self._subscribe_topic, dispatch)

async def publish(self, message: RoborockMessage) -> None:
"""Publish a command message.
Expand Down
20 changes: 7 additions & 13 deletions roborock/mqtt/roborock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import aiomqtt
from aiomqtt import MqttError, TLSParameters

from roborock.callbacks import CallbackMap

from .session import MqttParams, MqttSession, MqttSessionException

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -53,7 +55,7 @@ def __init__(self, params: MqttParams):
self._backoff = MIN_BACKOFF_INTERVAL
self._client: aiomqtt.Client | None = None
self._client_lock = asyncio.Lock()
self._listeners: dict[str, list[Callable[[bytes], None]]] = {}
self._listeners: CallbackMap[str, bytes] = CallbackMap(_LOGGER)

@property
def connected(self) -> bool:
Expand Down Expand Up @@ -164,7 +166,7 @@ async def _mqtt_client(self, params: MqttParams) -> aiomqtt.Client:
# Re-establish any existing subscriptions
async with self._client_lock:
self._client = client
for topic in self._listeners:
for topic in self._listeners.keys():
_LOGGER.debug("Re-establishing subscription to topic %s", topic)
# TODO: If this fails it will break the whole connection. Make
# this retry again in the background with backoff.
Expand All @@ -179,13 +181,7 @@ async def _process_message_loop(self, client: aiomqtt.Client) -> None:
_LOGGER.debug("Processing MQTT messages")
async for message in client.messages:
_LOGGER.debug("Received message: %s", message)
for listener in self._listeners.get(message.topic.value, []):
try:
listener(message.payload)
except asyncio.CancelledError:
raise
except Exception as e:
_LOGGER.exception("Uncaught exception in subscriber callback: %s", e)
self._listeners(message.topic.value, message.payload)

async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
"""Subscribe to messages on the specified topic and invoke the callback for new messages.
Expand All @@ -196,9 +192,7 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
The returned callable unsubscribes from the topic when called.
"""
_LOGGER.debug("Subscribing to topic %s", topic)
if topic not in self._listeners:
self._listeners[topic] = []
self._listeners[topic].append(callback)
unsub = self._listeners.add_callback(topic, callback)

async with self._client_lock:
if self._client:
Expand All @@ -210,7 +204,7 @@ async def subscribe(self, topic: str, callback: Callable[[bytes], None]) -> Call
else:
_LOGGER.debug("Client not connected, will establish subscription later")

return lambda: self._listeners[topic].remove(callback)
return unsub

async def publish(self, topic: str, message: bytes) -> None:
"""Publish a message on the topic."""
Expand Down
4 changes: 2 additions & 2 deletions tests/devices/test_local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def test_message_decode_error(local_channel: LocalChannel, caplog: pytest.

assert len(caplog.records) == 1
assert caplog.records[0].levelname == "WARNING"
assert "Failed to decode local message" in caplog.records[0].message
assert "Failed to decode message" in caplog.records[0].message


async def test_subscribe_callback(
Expand Down Expand Up @@ -181,7 +181,7 @@ def failing_callback(message: RoborockMessage) -> None:
await asyncio.sleep(0.01) # yield

# Should log the exception but not crash
assert any("Uncaught error in message handler callback" in record.message for record in caplog.records)
assert any("Uncaught error in callback 'failing_callback'" in record.message for record in caplog.records)


async def test_unsubscribe(local_channel: LocalChannel, mock_loop: Mock) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/devices/test_mqtt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async def test_message_decode_error(

assert len(caplog.records) == 1
assert caplog.records[0].levelname == "WARNING"
assert "Failed to decode MQTT message" in caplog.records[0].message
assert "Failed to decode message" in caplog.records[0].message
unsub()


Expand Down Expand Up @@ -255,7 +255,7 @@ def failing_callback(message: RoborockMessage) -> None:
# Check that exception was logged
error_records = [record for record in caplog.records if record.levelname == "ERROR"]
assert len(error_records) == 1
assert "Uncaught error in message handler callback" in error_records[0].message
assert "Uncaught error in callback 'failing_callback'" in error_records[0].message

# Unsubscribe all remaining subscribers
unsub1()
Expand Down
Loading