From e640e47819064f07ef21df0bdb6629180a747935 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Tue, 1 Jul 2025 08:33:23 -0700 Subject: [PATCH 1/9] feat: Update device manager and device to establish an MQTT subscription --- roborock/cli.py | 49 +++++++++--------------------- roborock/devices/device.py | 38 +++++++++++++++++++++-- roborock/devices/device_manager.py | 44 ++++++++++++++++++++++----- roborock/devices/mqtt_channel.py | 44 +++++++++++++++++++++++++++ 4 files changed, 132 insertions(+), 43 deletions(-) create mode 100644 roborock/devices/mqtt_channel.py diff --git a/roborock/cli.py b/roborock/cli.py index 4532ca21..0e6881eb 100644 --- a/roborock/cli.py +++ b/roborock/cli.py @@ -12,9 +12,9 @@ from pyshark.packet.packet import Packet # type: ignore from roborock import RoborockException -from roborock.containers import DeviceData, HomeDataProduct, LoginData -from roborock.mqtt.roborock_session import create_mqtt_session -from roborock.protocol import MessageParser, create_mqtt_params +from roborock.containers import DeviceData, HomeData, HomeDataProduct, LoginData +from roborock.devices.device_manager import create_device_manager, create_home_data_api +from roborock.protocol import MessageParser from roborock.util import run_sync from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1 @@ -101,44 +101,25 @@ async def session(ctx, duration: int): context: RoborockContext = ctx.obj login_data = context.login_data() - # Discovery devices if not already available - if not login_data.home_data: - await _discover(ctx) - login_data = context.login_data() - if not login_data.home_data or not login_data.home_data.devices: - raise RoborockException("Unable to discover devices") - - all_devices = login_data.home_data.devices + login_data.home_data.received_devices - click.echo(f"Discovered devices: {', '.join([device.name for device in all_devices])}") - - rriot = login_data.user_data.rriot - params = create_mqtt_params(rriot) - - mqtt_session = await create_mqtt_session(params) - click.echo("Starting MQTT session...") - if not mqtt_session.connected: - raise RoborockException("Failed to connect to MQTT broker") + home_data_api = create_home_data_api(login_data.email, login_data.user_data) - def on_message(bytes: bytes): - """Callback function to handle incoming MQTT messages.""" - # Decode the first 20 bytes of the message for display - bytes = bytes[:20] + async def home_data_cache() -> HomeData: + if login_data.home_data is None: + login_data.home_data = await home_data_api() + context.update(login_data) + return login_data.home_data - click.echo(f"Received message: {bytes}...") + # Create device manager + device_manager = await create_device_manager(login_data.user_data, home_data_cache) - unsubs = [] - for device in all_devices: - device_topic = f"rr/m/o/{rriot.u}/{params.username}/{device.duid}" - unsub = await mqtt_session.subscribe(device_topic, on_message) - unsubs.append(unsub) + devices = await device_manager.get_devices() + click.echo(f"Discovered devices: {', '.join([device.name for device in devices])}") click.echo("MQTT session started. Listening for messages...") await asyncio.sleep(duration) - click.echo("Stopping MQTT session...") - for unsub in unsubs: - unsub() - await mqtt_session.close() + # Close the device manager (this will close all devices and MQTT session) + await device_manager.close() async def _discover(ctx): diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 926be8c4..fbf37f5b 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -10,6 +10,8 @@ from roborock.containers import HomeDataDevice, HomeDataProduct, UserData +from .mqtt_channel import MqttChannel + _LOGGER = logging.getLogger(__name__) __all__ = [ @@ -29,11 +31,22 @@ class DeviceVersion(enum.StrEnum): class RoborockDevice: """Unified Roborock device class with automatic connection setup.""" - def __init__(self, user_data: UserData, device_info: HomeDataDevice, product_info: HomeDataProduct) -> None: - """Initialize the RoborockDevice with device info, user data, and capabilities.""" + def __init__( + self, + user_data: UserData, + device_info: HomeDataDevice, + product_info: HomeDataProduct, + mqtt_channel: MqttChannel, + ) -> None: + """Initialize the RoborockDevice. + + The device takes ownership of the MQTT channel for communication with the device + and will close it when the device is closed. + """ self._user_data = user_data self._device_info = device_info self._product_info = product_info + self._mqtt_channel = mqtt_channel @property def duid(self) -> str: @@ -63,3 +76,24 @@ def device_version(self) -> str: self._device_info.name, ) return DeviceVersion.UNKNOWN + + async def connect(self) -> None: + """Connect to the device using MQTT. + + This method will set up the MQTT channel for communication with the device. + """ + await self._mqtt_channel.subscribe(self._on_mqtt_message) + + async def close(self) -> None: + """Close the MQTT connection to the device. + + This method will unsubscribe from the MQTT channel and clean up resources. + """ + await self._mqtt_channel.close() + + def _on_mqtt_message(self, message: bytes) -> None: + """Handle incoming MQTT messages from the device. + + This method should be overridden in subclasses to handle specific device messages. + """ + _LOGGER.debug("Received message from device %s: %s", self.duid, message[:50]) # Log first 50 bytes for brevity diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 3a95dd13..b664fc3d 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -10,8 +10,13 @@ UserData, ) from roborock.devices.device import RoborockDevice +from roborock.mqtt.roborock_session import create_mqtt_session +from roborock.mqtt.session import MqttSession +from roborock.protocol import create_mqtt_params from roborock.web_api import RoborockApiClient +from .mqtt_channel import MqttChannel + _LOGGER = logging.getLogger(__name__) __all__ = [ @@ -34,11 +39,16 @@ def __init__( self, home_data_api: HomeDataApi, device_creator: DeviceCreator, + mqtt_session: MqttSession, ) -> None: - """Initialize the DeviceManager with user data and optional cache storage.""" + """Initialize the DeviceManager with user data and optional cache storage. + + This takes ownership of the MQTT session and will close it when the manager is closed. + """ self._home_data_api = home_data_api self._device_creator = device_creator self._devices: dict[str, RoborockDevice] = {} + self._mqtt_session = mqtt_session async def discover_devices(self) -> list[RoborockDevice]: """Discover all devices for the logged-in user.""" @@ -46,9 +56,15 @@ async def discover_devices(self) -> list[RoborockDevice]: device_products = home_data.device_products _LOGGER.debug("Discovered %d devices %s", len(device_products), home_data) - self._devices = { - duid: self._device_creator(device, product) for duid, (device, product) in device_products.items() - } + new_devices = {} + for duid, (device, product) in device_products.items(): + if duid in self._devices: + continue + new_device = self._device_creator(device, product) + await new_device.connect() + new_devices[duid] = new_device + + self._devices.update(new_devices) return list(self._devices.values()) async def get_device(self, duid: str) -> RoborockDevice | None: @@ -59,6 +75,14 @@ async def get_devices(self) -> list[RoborockDevice]: """Get all discovered devices.""" return list(self._devices.values()) + async def close(self) -> None: + """Close all MQTT connections and clean up resources.""" + for device in self._devices.values(): + await device.close() + self._devices.clear() + if self._mqtt_session: + await self._mqtt_session.close() + def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi: """Create a home data API wrapper. @@ -67,7 +91,9 @@ def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi: home data for the user. """ - client = RoborockApiClient(email, user_data) + # Note: This will auto discover the API base URL. This can be improved + # by caching this next to `UserData` if needed to avoid unnecessary API calls. + client = RoborockApiClient(email) async def home_data_api() -> HomeData: return await client.get_home_data(user_data) @@ -83,9 +109,13 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi) include caching or other optimizations. """ + mqtt_params = create_mqtt_params(user_data.rriot) + mqtt_session = await create_mqtt_session(mqtt_params) + def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice: - return RoborockDevice(user_data, device, product) + mqtt_channel = MqttChannel(mqtt_session, device.duid, user_data.rriot, mqtt_params) + return RoborockDevice(user_data, device, product, mqtt_channel) - manager = DeviceManager(home_data_api, device_creator) + manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session) await manager.discover_devices() return manager diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py new file mode 100644 index 00000000..f8095b3b --- /dev/null +++ b/roborock/devices/mqtt_channel.py @@ -0,0 +1,44 @@ +import logging +from collections.abc import Callable + +from roborock.containers import RRiot +from roborock.mqtt.session import MqttParams, MqttSession + +_LOGGER = logging.getLogger(__name__) + + +class MqttChannel: + """RPC-style channel for communicating with a specific device over MQTT. + + This currently only supports listening to messages and does not yet + support RPC functionality. + """ + + def __init__(self, mqtt_session: MqttSession, duid: str, rriot: RRiot, mqtt_params: MqttParams): + self._mqtt_session = mqtt_session + self._duid = duid + self._rriot = rriot + self._mqtt_params = mqtt_params + self._unsub: Callable[[], None] | None = None + + @property + def _publish_topic(self) -> str: + """Topic to send commands to the device.""" + return f"rr/m/i/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}" + + @property + def _subscribe_topic(self) -> str: + """Topic to receive responses from the device.""" + return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}" + + async def subscribe(self, callback: Callable[[bytes], None]) -> None: + """Subscribe to the device's response topic.""" + if self._unsub: + raise ValueError("Already subscribed to the response topic") + self._unsub = await self._mqtt_session.subscribe(self._subscribe_topic, callback) + + async def close(self) -> None: + """Close the MQTT subscription.""" + if self._unsub: + self._unsub() + self._unsub = None From 8d16c95bb648d78854a4fbaa989d372e42c2b064 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Tue, 1 Jul 2025 21:58:54 -0700 Subject: [PATCH 2/9] feat: Add test coverage to device modules --- roborock/devices/device.py | 16 +++++-- roborock/devices/mqtt_channel.py | 18 ++++---- tests/devices/test_device.py | 40 +++++++++++++++++ tests/devices/test_device_manager.py | 8 ++++ tests/devices/test_mqtt_channel.py | 64 ++++++++++++++++++++++++++++ 5 files changed, 133 insertions(+), 13 deletions(-) create mode 100644 tests/devices/test_device.py create mode 100644 tests/devices/test_mqtt_channel.py diff --git a/roborock/devices/device.py b/roborock/devices/device.py index fbf37f5b..0ac43b28 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -6,6 +6,7 @@ import enum import logging +from collections.abc import Callable from functools import cached_property from roborock.containers import HomeDataDevice, HomeDataProduct, UserData @@ -40,13 +41,16 @@ def __init__( ) -> None: """Initialize the RoborockDevice. - The device takes ownership of the MQTT channel for communication with the device - and will close it when the device is closed. + The device takes ownership of the MQTT channel for communication with the device. + Use `connect()` to establish the connection, which will set up the MQTT channel + for receiving messages from the device. Use `close()` to unsubscribe from the MQTT + channel. """ self._user_data = user_data self._device_info = device_info self._product_info = product_info self._mqtt_channel = mqtt_channel + self._unsub: Callable[[], None] | None = None @property def duid(self) -> str: @@ -82,14 +86,18 @@ async def connect(self) -> None: This method will set up the MQTT channel for communication with the device. """ - await self._mqtt_channel.subscribe(self._on_mqtt_message) + if self._unsub: + raise ValueError("Already connected to the device") + self._unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) async def close(self) -> None: """Close the MQTT connection to the device. This method will unsubscribe from the MQTT channel and clean up resources. """ - await self._mqtt_channel.close() + if self._unsub: + self._unsub() + self._unsub = None def _on_mqtt_message(self, message: bytes) -> None: """Handle incoming MQTT messages from the device. diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index f8095b3b..6dab1f90 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -31,14 +31,14 @@ def _subscribe_topic(self) -> str: """Topic to receive responses from the device.""" return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}" - async def subscribe(self, callback: Callable[[bytes], None]) -> None: - """Subscribe to the device's response topic.""" - if self._unsub: - raise ValueError("Already subscribed to the response topic") - self._unsub = await self._mqtt_session.subscribe(self._subscribe_topic, callback) + async def subscribe(self, callback: Callable[[bytes], None]) -> Callable[[], None]: + """Subscribe to the device's response topic. - async def close(self) -> None: - """Close the MQTT subscription.""" + The callback will be called with the message payload when a message is received. + If already subscribed, raises ValueError. + + Returns a callable that can be used to unsubscribe from the topic. + """ if self._unsub: - self._unsub() - self._unsub = None + raise ValueError("Already subscribed to the response topic") + return await self._mqtt_session.subscribe(self._subscribe_topic, callback) diff --git a/tests/devices/test_device.py b/tests/devices/test_device.py new file mode 100644 index 00000000..6e2b5d5f --- /dev/null +++ b/tests/devices/test_device.py @@ -0,0 +1,40 @@ +"""Tests for the Device class.""" + +from unittest.mock import AsyncMock, Mock + +from roborock.containers import HomeData, UserData +from roborock.devices.device import DeviceVersion, RoborockDevice + +from .. import mock_data + +USER_DATA = UserData.from_dict(mock_data.USER_DATA) +HOME_DATA = HomeData.from_dict(mock_data.HOME_DATA_RAW) + + +async def test_device_connection() -> None: + """Test the Device connection setup.""" + + unsub = Mock() + subscribe = AsyncMock() + subscribe.return_value = unsub + mqtt_channel = AsyncMock() + mqtt_channel.subscribe = subscribe + + device = RoborockDevice( + USER_DATA, + device_info=HOME_DATA.devices[0], + product_info=HOME_DATA.products[0], + mqtt_channel=mqtt_channel, + ) + assert device.duid == "abc123" + assert device.name == "Roborock S7 MaxV" + assert device.device_version == DeviceVersion.V1 + + assert not subscribe.called + + await device.connect() + assert subscribe.called + assert not unsub.called + + await device.close() + assert unsub.called diff --git a/tests/devices/test_device_manager.py b/tests/devices/test_device_manager.py index fac33344..6c97c23b 100644 --- a/tests/devices/test_device_manager.py +++ b/tests/devices/test_device_manager.py @@ -1,5 +1,6 @@ """Tests for the DeviceManager class.""" +from collections.abc import Generator from unittest.mock import patch import pytest @@ -14,6 +15,13 @@ USER_DATA = UserData.from_dict(mock_data.USER_DATA) +@pytest.fixture(autouse=True) +def setup_mqtt_session() -> Generator[None, None, None]: + """Fixture to set up the MQTT session for the tests.""" + with patch("roborock.devices.device_manager.create_mqtt_session"): + yield + + async def home_home_data_no_devices() -> HomeData: """Mock home data API that returns no devices.""" return HomeData( diff --git a/tests/devices/test_mqtt_channel.py b/tests/devices/test_mqtt_channel.py new file mode 100644 index 00000000..cbf59d1f --- /dev/null +++ b/tests/devices/test_mqtt_channel.py @@ -0,0 +1,64 @@ +"""Tests for the MqttChannel class.""" + +from collections.abc import Generator +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from roborock.containers import HomeData, UserData +from roborock.devices.mqtt_channel import MqttChannel +from roborock.mqtt.session import MqttParams + +from .. import mock_data + +USER_DATA = UserData.from_dict(mock_data.USER_DATA) +TEST_MQTT_PARAMS = MqttParams( + host="localhost", + port=1883, + tls=False, + username="username", + password="password", + timeout=10.0, +) + + +@pytest.fixture(autouse=True) +def setup_mqtt_session() -> Generator[None, None, None]: + """Fixture to set up the MQTT session for the tests.""" + with patch("roborock.devices.device_manager.create_mqtt_session"): + yield + + +async def home_home_data_no_devices() -> HomeData: + """Mock home data API that returns no devices.""" + return HomeData( + id=1, + name="Test Home", + devices=[], + products=[], + ) + + +async def mock_home_data() -> HomeData: + """Mock home data API that returns devices.""" + return HomeData.from_dict(mock_data.HOME_DATA_RAW) + + +async def test_mqtt_channel() -> None: + """Test MQTT channel setup.""" + + mock_session = AsyncMock() + + channel = MqttChannel(mock_session, duid="abc123", rriot=USER_DATA.rriot, mqtt_params=TEST_MQTT_PARAMS) + + unsub = Mock() + mock_session.subscribe.return_value = unsub + + callback = Mock() + result = await channel.subscribe(callback) + + assert mock_session.subscribe.called + assert mock_session.subscribe.call_args[0][0] == "rr/m/o/user123/username/abc123" + assert mock_session.subscribe.call_args[0][1] == callback + + assert result == unsub From e49fd9547717dfccac3c8770902b9e61e0aa5e63 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Tue, 1 Jul 2025 22:08:57 -0700 Subject: [PATCH 3/9] feat: Add test coverage for device manager close --- tests/devices/test_device_manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/devices/test_device_manager.py b/tests/devices/test_device_manager.py index 6c97c23b..e09087f5 100644 --- a/tests/devices/test_device_manager.py +++ b/tests/devices/test_device_manager.py @@ -60,12 +60,15 @@ async def test_with_device() -> None: assert device.name == "Roborock S7 MaxV" assert device.device_version == DeviceVersion.V1 + await device_manager.close() + async def test_get_non_existent_device() -> None: """Test getting a non-existent device.""" device_manager = await create_device_manager(USER_DATA, mock_home_data) device = await device_manager.get_device("non_existent_duid") assert device is None + await device_manager.close() async def test_home_data_api_exception() -> None: From 5c33055b0e681c67695db26aebca839c6064c44b Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Wed, 2 Jul 2025 06:34:53 -0700 Subject: [PATCH 4/9] feat: Update roborock/devices/mqtt_channel.py --- roborock/devices/mqtt_channel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index 6dab1f90..dae04bc0 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -19,7 +19,6 @@ def __init__(self, mqtt_session: MqttSession, duid: str, rriot: RRiot, mqtt_para self._duid = duid self._rriot = rriot self._mqtt_params = mqtt_params - self._unsub: Callable[[], None] | None = None @property def _publish_topic(self) -> str: From 341d96d1a13a8f2bd8f8b0fbdba8fb0eecd7c137 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Wed, 2 Jul 2025 07:07:33 -0700 Subject: [PATCH 5/9] feat: Apply suggestions from code review --- roborock/devices/mqtt_channel.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index dae04bc0..a0e223fb 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -34,10 +34,7 @@ async def subscribe(self, callback: Callable[[bytes], None]) -> Callable[[], Non """Subscribe to the device's response topic. The callback will be called with the message payload when a message is received. - If already subscribed, raises ValueError. Returns a callable that can be used to unsubscribe from the topic. """ - if self._unsub: - raise ValueError("Already subscribed to the response topic") return await self._mqtt_session.subscribe(self._subscribe_topic, callback) From 06b9178d40c7edf458875e272a63e8a548605900 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Wed, 2 Jul 2025 07:34:23 -0700 Subject: [PATCH 6/9] feat: Add support for sending/recieving messages --- roborock/devices/device.py | 5 +- roborock/devices/device_manager.py | 2 +- roborock/devices/mqtt_channel.py | 88 ++++++++++- tests/devices/test_mqtt_channel.py | 235 +++++++++++++++++++++++++++-- 4 files changed, 306 insertions(+), 24 deletions(-) diff --git a/roborock/devices/device.py b/roborock/devices/device.py index 0ac43b28..44cdfd01 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -10,6 +10,7 @@ from functools import cached_property from roborock.containers import HomeDataDevice, HomeDataProduct, UserData +from roborock.roborock_message import RoborockMessage from .mqtt_channel import MqttChannel @@ -99,9 +100,9 @@ async def close(self) -> None: self._unsub() self._unsub = None - def _on_mqtt_message(self, message: bytes) -> None: + def _on_mqtt_message(self, message: RoborockMessage) -> None: """Handle incoming MQTT messages from the device. This method should be overridden in subclasses to handle specific device messages. """ - _LOGGER.debug("Received message from device %s: %s", self.duid, message[:50]) # Log first 50 bytes for brevity + _LOGGER.debug("Received message from device %s: %s", self.duid, message) diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index b664fc3d..22d3cb9e 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -113,7 +113,7 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi) mqtt_session = await create_mqtt_session(mqtt_params) def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice: - mqtt_channel = MqttChannel(mqtt_session, device.duid, user_data.rriot, mqtt_params) + mqtt_channel = MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params) return RoborockDevice(user_data, device, product, mqtt_channel) manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session) diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index a0e223fb..e75f1d16 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -1,25 +1,40 @@ +"""Modules for communicating with specific Roborock devices over MQTT.""" + +import asyncio import logging from collections.abc import Callable +from json import JSONDecodeError from roborock.containers import RRiot +from roborock.exceptions import RoborockException from roborock.mqtt.session import MqttParams, MqttSession +from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder +from roborock.roborock_message import RoborockMessage _LOGGER = logging.getLogger(__name__) class MqttChannel: - """RPC-style channel for communicating with a specific device over MQTT. + """Simple RPC-style channel for communicating with a device over MQTT. - This currently only supports listening to messages and does not yet - support RPC functionality. + Handles request/response correlation and timeouts, but leaves message + format most parsing to higher-level components. """ - def __init__(self, mqtt_session: MqttSession, duid: str, rriot: RRiot, mqtt_params: MqttParams): + def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot: RRiot, mqtt_params: MqttParams): self._mqtt_session = mqtt_session self._duid = duid + self._local_key = local_key self._rriot = rriot self._mqtt_params = mqtt_params + # RPC support + self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {} + self._decoder = create_mqtt_decoder(local_key) + self._encoder = create_mqtt_encoder(local_key) + # Use a regular lock since we need to access from sync callback + self._queue_lock = asyncio.Lock() + @property def _publish_topic(self) -> str: """Topic to send commands to the device.""" @@ -30,11 +45,72 @@ def _subscribe_topic(self) -> str: """Topic to receive responses from the device.""" return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}" - async def subscribe(self, callback: Callable[[bytes], None]) -> Callable[[], None]: + async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: """Subscribe to the device's response topic. The callback will be called with the message payload when a message is received. + All messages received will be processed through the provided callback, even + those sent in response to the `send_command` command. + Returns a callable that can be used to unsubscribe from the topic. """ - return await self._mqtt_session.subscribe(self._subscribe_topic, callback) + + def message_handler(payload: bytes) -> None: + if not (messages := self._decoder(payload)): + _LOGGER.warning("Failed to decode MQTT message: %s", payload) + return + for message in messages: + asyncio.create_task(self._resolve_future_with_lock(message)) + try: + callback(message) + except Exception as e: + _LOGGER.exception("Uncaught error in message handler callback: %s", e) + + return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler) + + async def _resolve_future_with_lock(self, message: RoborockMessage) -> None: + """Resolve waiting future with proper locking.""" + if (request_id := message.get_request_id()) is None: + _LOGGER.debug("Received message with no request_id") + return + async with self._queue_lock: + if (future := self._waiting_queue.pop(request_id, None)) is not None: + if not future.done(): + future.set_result(message) + else: + _LOGGER.warning("Received message for completed future: request_id=%s", request_id) + else: + _LOGGER.warning("Received message with no waiting handler: request_id=%s", request_id) + + async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage: + """Send a command message and wait for the response message. + + Returns the raw response message - caller is responsible for parsing. + """ + try: + if (request_id := message.get_request_id()) is None: + raise RoborockException("Message must have a request_id for RPC calls") + except (ValueError, JSONDecodeError) as err: + _LOGGER.exception("Error getting request_id from message: %s", err) + raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err + + future: asyncio.Future[RoborockMessage] = asyncio.Future() + async with self._queue_lock: + self._waiting_queue[request_id] = future + + try: + encoded_msg = self._encoder(message) + await self._mqtt_session.publish(self._publish_topic, encoded_msg) + + return await asyncio.wait_for(future, timeout=timeout) + + except asyncio.TimeoutError as ex: + async with self._queue_lock: + self._waiting_queue.pop(request_id, None) + raise RoborockException(f"Command timed out after {timeout}s") from ex + except Exception: + logging.exception("Uncaught error sending command") + async with self._queue_lock: + self._waiting_queue.pop(request_id, None) + raise diff --git a/tests/devices/test_mqtt_channel.py b/tests/devices/test_mqtt_channel.py index cbf59d1f..bdffb7ee 100644 --- a/tests/devices/test_mqtt_channel.py +++ b/tests/devices/test_mqtt_channel.py @@ -1,13 +1,18 @@ """Tests for the MqttChannel class.""" -from collections.abc import Generator +import asyncio +import json +from collections.abc import Callable, Generator from unittest.mock import AsyncMock, Mock, patch import pytest from roborock.containers import HomeData, UserData from roborock.devices.mqtt_channel import MqttChannel +from roborock.exceptions import RoborockException from roborock.mqtt.session import MqttParams +from roborock.protocol import Decoder, Encoder, create_mqtt_decoder, create_mqtt_encoder +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from .. import mock_data @@ -20,13 +25,61 @@ password="password", timeout=10.0, ) +TEST_LOCAL_KEY = "local_key" +TEST_REQUEST = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=json.dumps({"dps": {"101": json.dumps({"id": 12345, "method": "get_status"})}}).encode(), +) +TEST_RESPONSE = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=json.dumps({"dps": {"102": json.dumps({"id": 12345, "result": {"state": "cleaning"}})}}).encode(), +) +TEST_REQUEST2 = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=json.dumps({"dps": {"101": json.dumps({"id": 54321, "method": "get_status"})}}).encode(), +) +TEST_RESPONSE2 = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=json.dumps({"dps": {"102": json.dumps({"id": 54321, "result": {"state": "cleaning"}})}}).encode(), +) +ENCODER = create_mqtt_encoder(TEST_LOCAL_KEY) +DECODER = create_mqtt_decoder(TEST_LOCAL_KEY) -@pytest.fixture(autouse=True) -def setup_mqtt_session() -> Generator[None, None, None]: + +@pytest.fixture(name="mqtt_session", autouse=True) +def setup_mqtt_session() -> Generator[Mock, None, None]: """Fixture to set up the MQTT session for the tests.""" - with patch("roborock.devices.device_manager.create_mqtt_session"): - yield + mock_session = AsyncMock() + with patch("roborock.devices.device_manager.create_mqtt_session", return_value=mock_session): + yield mock_session + + +@pytest.fixture(name="mqtt_channel", autouse=True) +def setup_mqtt_channel(mqtt_session: Mock) -> MqttChannel: + """Fixture to set up the MQTT channel for the tests.""" + return MqttChannel( + mqtt_session, duid="abc123", local_key=TEST_LOCAL_KEY, rriot=USER_DATA.rriot, mqtt_params=TEST_MQTT_PARAMS + ) + + +@pytest.fixture(name="received_messages", autouse=True) +async def setup_subscribe_callback(mqtt_channel: MqttChannel) -> list[RoborockMessage]: + """Fixture to record messages received by the subscriber.""" + messages: list[RoborockMessage] = [] + await mqtt_channel.subscribe(messages.append) + return messages + + +@pytest.fixture(name="mqtt_message_handler") +async def setup_message_handler(mqtt_session: Mock, mqtt_channel: MqttChannel) -> Callable[[bytes], None]: + """Fixture to allow simulating incoming MQTT messages.""" + # Subscribe to set up message handling. We grab the message handler callback + # and use it to simulate receiving a response. + assert mqtt_session.subscribe + subscribe_call_args = mqtt_session.subscribe.call_args + message_handler = subscribe_call_args[0][1] + return message_handler async def home_home_data_no_devices() -> HomeData: @@ -44,21 +97,173 @@ async def mock_home_data() -> HomeData: return HomeData.from_dict(mock_data.HOME_DATA_RAW) -async def test_mqtt_channel() -> None: +async def test_mqtt_channel(mqtt_session: Mock, mqtt_channel: MqttChannel) -> None: """Test MQTT channel setup.""" - mock_session = AsyncMock() - - channel = MqttChannel(mock_session, duid="abc123", rriot=USER_DATA.rriot, mqtt_params=TEST_MQTT_PARAMS) - unsub = Mock() - mock_session.subscribe.return_value = unsub + mqtt_session.subscribe.return_value = unsub callback = Mock() - result = await channel.subscribe(callback) + result = await mqtt_channel.subscribe(callback) - assert mock_session.subscribe.called - assert mock_session.subscribe.call_args[0][0] == "rr/m/o/user123/username/abc123" - assert mock_session.subscribe.call_args[0][1] == callback + assert mqtt_session.subscribe.called + assert mqtt_session.subscribe.call_args[0][0] == "rr/m/o/user123/username/abc123" assert result == unsub + + +async def test_send_command_success( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], +) -> None: + """Test successful RPC command sending and response handling.""" + # Send a test request. We use a task so we can simulate receiving the response + # while the command is still being processed. + command_task = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST)) + await asyncio.sleep(0.01) # yield + + # Simulate receiving the response message via MQTT + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + + # Get the result + result = await command_task + + # Verify the command was sent + assert mqtt_session.publish.called + assert mqtt_session.publish.call_args[0][0] == "rr/m/i/user123/username/abc123" + raw_sent_msg = mqtt_session.publish.call_args[0][1] # == b"encoded_message" + decoded_message = next(iter(DECODER(raw_sent_msg))) + assert decoded_message == TEST_REQUEST + assert decoded_message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert decoded_message.get_request_id() == 12345 + + # Verify we got the response message back + assert result == TEST_RESPONSE + + +async def test_send_command_without_request_id( + mqtt_session: Mock, mqtt_channel: MqttChannel, mqtt_message_handler: Callable[[bytes], None], +) -> None: + """Test sending command without request ID raises exception.""" + # Create a message without request ID + test_message = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + payload=b"no_request_id", + ) + + with pytest.raises(RoborockException, match="Message must have a request_id"): + await mqtt_channel.send_command(test_message) + + +async def test_handle_messages_no_waiting_handler( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test handling messages when no handler is waiting.""" + # Simulate receiving the response message via MQTT + mqtt_message_handler(ENCODER(TEST_REQUEST)) + await asyncio.sleep(0.01) # yield + + assert len(caplog.records) == 1 + assert caplog.records[0].levelname == "WARNING" + assert "Received message with no waiting handler: request_id=12345" in caplog.records[0].message + + +async def test_concurrent_commands( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test handling multiple concurrent RPC commands.""" + + # Create multiple test messages with different request IDs + # Start both commands concurrently + task1 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + task2 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST2, timeout=5.0)) + await asyncio.sleep(0.01) # yield + + # Create responses for both + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + + mqtt_message_handler(ENCODER(TEST_RESPONSE2)) + await asyncio.sleep(0.01) # yield + + # Both should complete successfully + result1 = await task1 + result2 = await task2 + + assert result1 == TEST_RESPONSE + assert result2 == TEST_RESPONSE2 + + assert not caplog.records + + +async def test_handle_completed_future( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test handling response for an already completed future.""" + # Send request + task = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + await asyncio.sleep(0.01) # yield + + # Send the response twice + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + + # Task completes and second message is dropped with a warning + result = await task + assert result == TEST_RESPONSE + + assert len(caplog.records) == 1 + assert caplog.records[0].levelname == "WARNING" + assert "Received message with no waiting handler: request_id=12345" in caplog.records[0].message + + +async def test_subscribe_callback_with_rpc_response( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + received_messages: list[RoborockMessage], + mqtt_message_handler: Callable[[bytes], None], +) -> None: + """Test that subscribe callback is called along with RPC handling.""" + # Send request + task = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + await asyncio.sleep(0.01) # yield + + assert not received_messages + + # Send the response + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + + # Task completes and second message is dropped with a warning + result = await task + assert result == TEST_RESPONSE + + # The subscribe callback should have been called with the same response + assert len(received_messages) == 1 + assert received_messages[0] == TEST_RESPONSE + + +async def test_message_decode_error( + mqtt_message_handler: Callable[[bytes], None], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test an error during message decoding.""" + mqtt_message_handler(b"invalid_payload") + await asyncio.sleep(0.01) # yield + + assert len(caplog.records) == 1 + assert caplog.records[0].levelname == "WARNING" + assert "Failed to decode MQTT message" in caplog.records[0].message From 364e88ee056d47ca8edccead36b29c6cecddda42 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Wed, 2 Jul 2025 07:42:39 -0700 Subject: [PATCH 7/9] feat: Simplify rpc handling and tests --- roborock/devices/mqtt_channel.py | 9 ++--- tests/devices/test_mqtt_channel.py | 64 +++++++++++++++++------------- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index e75f1d16..c40103db 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -76,12 +76,9 @@ async def _resolve_future_with_lock(self, message: RoborockMessage) -> None: return async with self._queue_lock: if (future := self._waiting_queue.pop(request_id, None)) is not None: - if not future.done(): - future.set_result(message) - else: - _LOGGER.warning("Received message for completed future: request_id=%s", request_id) + future.set_result(message) else: - _LOGGER.warning("Received message with no waiting handler: request_id=%s", request_id) + _LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id) async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage: """Send a command message and wait for the response message. @@ -97,6 +94,8 @@ async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> future: asyncio.Future[RoborockMessage] = asyncio.Future() async with self._queue_lock: + if request_id in self._waiting_queue: + raise RoborockException(f"Request ID {request_id} already pending, cannot send command") self._waiting_queue[request_id] = future try: diff --git a/tests/devices/test_mqtt_channel.py b/tests/devices/test_mqtt_channel.py index bdffb7ee..8efa5664 100644 --- a/tests/devices/test_mqtt_channel.py +++ b/tests/devices/test_mqtt_channel.py @@ -11,7 +11,7 @@ from roborock.devices.mqtt_channel import MqttChannel from roborock.exceptions import RoborockException from roborock.mqtt.session import MqttParams -from roborock.protocol import Decoder, Encoder, create_mqtt_decoder, create_mqtt_encoder +from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from .. import mock_data @@ -144,7 +144,9 @@ async def test_send_command_success( async def test_send_command_without_request_id( - mqtt_session: Mock, mqtt_channel: MqttChannel, mqtt_message_handler: Callable[[bytes], None], + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], ) -> None: """Test sending command without request ID raises exception.""" # Create a message without request ID @@ -157,22 +159,6 @@ async def test_send_command_without_request_id( await mqtt_channel.send_command(test_message) -async def test_handle_messages_no_waiting_handler( - mqtt_session: Mock, - mqtt_channel: MqttChannel, - mqtt_message_handler: Callable[[bytes], None], - caplog: pytest.LogCaptureFixture, -) -> None: - """Test handling messages when no handler is waiting.""" - # Simulate receiving the response message via MQTT - mqtt_message_handler(ENCODER(TEST_REQUEST)) - await asyncio.sleep(0.01) # yield - - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert "Received message with no waiting handler: request_id=12345" in caplog.records[0].message - - async def test_concurrent_commands( mqtt_session: Mock, mqtt_channel: MqttChannel, @@ -204,6 +190,31 @@ async def test_concurrent_commands( assert not caplog.records +async def test_concurrent_commands_same_request_id( + mqtt_session: Mock, + mqtt_channel: MqttChannel, + mqtt_message_handler: Callable[[bytes], None], +) -> None: + """Test that we are not allowed to send two commands with the same request id.""" + + # Create multiple test messages with different request IDs + # Start both commands concurrently + task1 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + task2 = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) + await asyncio.sleep(0.01) # yield + + # Create response + mqtt_message_handler(ENCODER(TEST_RESPONSE)) + await asyncio.sleep(0.01) # yield + + # Both should complete successfully + result1 = await task1 + assert result1 == TEST_RESPONSE + + with pytest.raises(RoborockException, match="Request ID 12345 already pending, cannot send command"): + await task2 + + async def test_handle_completed_future( mqtt_session: Mock, mqtt_channel: MqttChannel, @@ -221,14 +232,10 @@ async def test_handle_completed_future( mqtt_message_handler(ENCODER(TEST_RESPONSE)) await asyncio.sleep(0.01) # yield - # Task completes and second message is dropped with a warning + # Task completes and second message is not associated with a waiting handler result = await task assert result == TEST_RESPONSE - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert "Received message with no waiting handler: request_id=12345" in caplog.records[0].message - async def test_subscribe_callback_with_rpc_response( mqtt_session: Mock, @@ -236,24 +243,25 @@ async def test_subscribe_callback_with_rpc_response( received_messages: list[RoborockMessage], mqtt_message_handler: Callable[[bytes], None], ) -> None: - """Test that subscribe callback is called along with RPC handling.""" + """Test that subscribe callback is called independent of RPC handling.""" # Send request task = asyncio.create_task(mqtt_channel.send_command(TEST_REQUEST, timeout=5.0)) await asyncio.sleep(0.01) # yield assert not received_messages - # Send the response + # Send the response for this command and an unrelated command mqtt_message_handler(ENCODER(TEST_RESPONSE)) await asyncio.sleep(0.01) # yield + mqtt_message_handler(ENCODER(TEST_RESPONSE2)) + await asyncio.sleep(0.01) # yield - # Task completes and second message is dropped with a warning + # Task completes result = await task assert result == TEST_RESPONSE # The subscribe callback should have been called with the same response - assert len(received_messages) == 1 - assert received_messages[0] == TEST_RESPONSE + assert received_messages == [TEST_RESPONSE, TEST_RESPONSE2] async def test_message_decode_error( From a558e122001748f67c0be6885d66e629aa1243ae Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Wed, 2 Jul 2025 19:45:07 -0700 Subject: [PATCH 8/9] feat: Gather tasks --- roborock/devices/device_manager.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 22d3cb9e..3244b261 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -1,5 +1,6 @@ """Module for discovering Roborock devices.""" +import asyncio import logging from collections.abc import Awaitable, Callable @@ -56,6 +57,7 @@ async def discover_devices(self) -> list[RoborockDevice]: device_products = home_data.device_products _LOGGER.debug("Discovered %d devices %s", len(device_products), home_data) + # These are connected serially to avoid overwhelming the MQTT broker new_devices = {} for duid, (device, product) in device_products.items(): if duid in self._devices: @@ -77,11 +79,10 @@ async def get_devices(self) -> list[RoborockDevice]: async def close(self) -> None: """Close all MQTT connections and clean up resources.""" - for device in self._devices.values(): - await device.close() + tasks = [device.close() for device in self._devices.values()] self._devices.clear() - if self._mqtt_session: - await self._mqtt_session.close() + tasks.append(self._mqtt_session.close()) + await asyncio.gather(*tasks) def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi: From 63d7bba4be6ed0496314f158ef61db607c3ec84b Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Wed, 2 Jul 2025 19:56:40 -0700 Subject: [PATCH 9/9] feat: Add debug lines --- roborock/devices/mqtt_channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roborock/devices/mqtt_channel.py b/roborock/devices/mqtt_channel.py index c40103db..00a01210 100644 --- a/roborock/devices/mqtt_channel.py +++ b/roborock/devices/mqtt_channel.py @@ -32,7 +32,6 @@ def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot: self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {} self._decoder = create_mqtt_decoder(local_key) self._encoder = create_mqtt_encoder(local_key) - # Use a regular lock since we need to access from sync callback self._queue_lock = asyncio.Lock() @property @@ -61,6 +60,7 @@ def message_handler(payload: bytes) -> None: _LOGGER.warning("Failed to decode MQTT message: %s", payload) return for message in messages: + _LOGGER.debug("Received message: %s", message) asyncio.create_task(self._resolve_future_with_lock(message)) try: callback(message)