Skip to content
49 changes: 15 additions & 34 deletions roborock/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from pyshark.packet.packet import Packet # type: ignore

from roborock import RoborockException
from roborock.containers import DeviceData, HomeDataProduct, LoginData
from roborock.mqtt.roborock_session import create_mqtt_session
from roborock.protocol import MessageParser, create_mqtt_params
from roborock.containers import DeviceData, HomeData, HomeDataProduct, LoginData
from roborock.devices.device_manager import create_device_manager, create_home_data_api
from roborock.protocol import MessageParser
from roborock.util import run_sync
from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
Expand Down Expand Up @@ -101,44 +101,25 @@ async def session(ctx, duration: int):
context: RoborockContext = ctx.obj
login_data = context.login_data()

# Discovery devices if not already available
if not login_data.home_data:
await _discover(ctx)
login_data = context.login_data()
if not login_data.home_data or not login_data.home_data.devices:
raise RoborockException("Unable to discover devices")

all_devices = login_data.home_data.devices + login_data.home_data.received_devices
click.echo(f"Discovered devices: {', '.join([device.name for device in all_devices])}")

rriot = login_data.user_data.rriot
params = create_mqtt_params(rriot)

mqtt_session = await create_mqtt_session(params)
click.echo("Starting MQTT session...")
if not mqtt_session.connected:
raise RoborockException("Failed to connect to MQTT broker")
home_data_api = create_home_data_api(login_data.email, login_data.user_data)

def on_message(bytes: bytes):
"""Callback function to handle incoming MQTT messages."""
# Decode the first 20 bytes of the message for display
bytes = bytes[:20]
async def home_data_cache() -> HomeData:
if login_data.home_data is None:
login_data.home_data = await home_data_api()
context.update(login_data)
return login_data.home_data

click.echo(f"Received message: {bytes}...")
# Create device manager
device_manager = await create_device_manager(login_data.user_data, home_data_cache)

unsubs = []
for device in all_devices:
device_topic = f"rr/m/o/{rriot.u}/{params.username}/{device.duid}"
unsub = await mqtt_session.subscribe(device_topic, on_message)
unsubs.append(unsub)
devices = await device_manager.get_devices()
click.echo(f"Discovered devices: {', '.join([device.name for device in devices])}")

click.echo("MQTT session started. Listening for messages...")
await asyncio.sleep(duration)

click.echo("Stopping MQTT session...")
for unsub in unsubs:
unsub()
await mqtt_session.close()
# Close the device manager (this will close all devices and MQTT session)
await device_manager.close()


async def _discover(ctx):
Expand Down
47 changes: 45 additions & 2 deletions roborock/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@

import enum
import logging
from collections.abc import Callable
from functools import cached_property

from roborock.containers import HomeDataDevice, HomeDataProduct, UserData
from roborock.roborock_message import RoborockMessage

from .mqtt_channel import MqttChannel

_LOGGER = logging.getLogger(__name__)

Expand All @@ -29,11 +33,25 @@ class DeviceVersion(enum.StrEnum):
class RoborockDevice:
"""Unified Roborock device class with automatic connection setup."""

def __init__(self, user_data: UserData, device_info: HomeDataDevice, product_info: HomeDataProduct) -> None:
"""Initialize the RoborockDevice with device info, user data, and capabilities."""
def __init__(
self,
user_data: UserData,
device_info: HomeDataDevice,
product_info: HomeDataProduct,
mqtt_channel: MqttChannel,
) -> None:
"""Initialize the RoborockDevice.

The device takes ownership of the MQTT channel for communication with the device.
Use `connect()` to establish the connection, which will set up the MQTT channel
for receiving messages from the device. Use `close()` to unsubscribe from the MQTT
channel.
"""
self._user_data = user_data
self._device_info = device_info
self._product_info = product_info
self._mqtt_channel = mqtt_channel
self._unsub: Callable[[], None] | None = None

@property
def duid(self) -> str:
Expand Down Expand Up @@ -63,3 +81,28 @@ def device_version(self) -> str:
self._device_info.name,
)
return DeviceVersion.UNKNOWN

async def connect(self) -> None:
"""Connect to the device using MQTT.

This method will set up the MQTT channel for communication with the device.
"""
if self._unsub:
raise ValueError("Already connected to the device")
self._unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)

async def close(self) -> None:
"""Close the MQTT connection to the device.

This method will unsubscribe from the MQTT channel and clean up resources.
"""
if self._unsub:
self._unsub()
self._unsub = None

def _on_mqtt_message(self, message: RoborockMessage) -> None:
"""Handle incoming MQTT messages from the device.

This method should be overridden in subclasses to handle specific device messages.
"""
_LOGGER.debug("Received message from device %s: %s", self.duid, message)
45 changes: 38 additions & 7 deletions roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module for discovering Roborock devices."""

import asyncio
import logging
from collections.abc import Awaitable, Callable

Expand All @@ -10,8 +11,13 @@
UserData,
)
from roborock.devices.device import RoborockDevice
from roborock.mqtt.roborock_session import create_mqtt_session
from roborock.mqtt.session import MqttSession
from roborock.protocol import create_mqtt_params
from roborock.web_api import RoborockApiClient

from .mqtt_channel import MqttChannel

_LOGGER = logging.getLogger(__name__)

__all__ = [
Expand All @@ -34,21 +40,33 @@ def __init__(
self,
home_data_api: HomeDataApi,
device_creator: DeviceCreator,
mqtt_session: MqttSession,
) -> None:
"""Initialize the DeviceManager with user data and optional cache storage."""
"""Initialize the DeviceManager with user data and optional cache storage.

This takes ownership of the MQTT session and will close it when the manager is closed.
"""
self._home_data_api = home_data_api
self._device_creator = device_creator
self._devices: dict[str, RoborockDevice] = {}
self._mqtt_session = mqtt_session

async def discover_devices(self) -> list[RoborockDevice]:
"""Discover all devices for the logged-in user."""
home_data = await self._home_data_api()
device_products = home_data.device_products
_LOGGER.debug("Discovered %d devices %s", len(device_products), home_data)

self._devices = {
duid: self._device_creator(device, product) for duid, (device, product) in device_products.items()
}
# These are connected serially to avoid overwhelming the MQTT broker
new_devices = {}
for duid, (device, product) in device_products.items():
if duid in self._devices:
continue
new_device = self._device_creator(device, product)
await new_device.connect()
new_devices[duid] = new_device

self._devices.update(new_devices)
return list(self._devices.values())

async def get_device(self, duid: str) -> RoborockDevice | None:
Expand All @@ -59,6 +77,13 @@ async def get_devices(self) -> list[RoborockDevice]:
"""Get all discovered devices."""
return list(self._devices.values())

async def close(self) -> None:
"""Close all MQTT connections and clean up resources."""
tasks = [device.close() for device in self._devices.values()]
self._devices.clear()
tasks.append(self._mqtt_session.close())
await asyncio.gather(*tasks)


def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi:
"""Create a home data API wrapper.
Expand All @@ -67,7 +92,9 @@ def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi:
home data for the user.
"""

client = RoborockApiClient(email, user_data)
# Note: This will auto discover the API base URL. This can be improved
# by caching this next to `UserData` if needed to avoid unnecessary API calls.
client = RoborockApiClient(email)

async def home_data_api() -> HomeData:
return await client.get_home_data(user_data)
Expand All @@ -83,9 +110,13 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi)
include caching or other optimizations.
"""

mqtt_params = create_mqtt_params(user_data.rriot)
mqtt_session = await create_mqtt_session(mqtt_params)

def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
return RoborockDevice(user_data, device, product)
mqtt_channel = MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
return RoborockDevice(user_data, device, product, mqtt_channel)

manager = DeviceManager(home_data_api, device_creator)
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)
await manager.discover_devices()
return manager
115 changes: 115 additions & 0 deletions roborock/devices/mqtt_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Modules for communicating with specific Roborock devices over MQTT."""

import asyncio
import logging
from collections.abc import Callable
from json import JSONDecodeError

from roborock.containers import RRiot
from roborock.exceptions import RoborockException
from roborock.mqtt.session import MqttParams, MqttSession
from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
from roborock.roborock_message import RoborockMessage

_LOGGER = logging.getLogger(__name__)


class MqttChannel:
"""Simple RPC-style channel for communicating with a device over MQTT.

Handles request/response correlation and timeouts, but leaves message
format most parsing to higher-level components.
"""

def __init__(self, mqtt_session: MqttSession, duid: str, local_key: str, rriot: RRiot, mqtt_params: MqttParams):
self._mqtt_session = mqtt_session
self._duid = duid
self._local_key = local_key
self._rriot = rriot
self._mqtt_params = mqtt_params

# RPC support
self._waiting_queue: dict[int, asyncio.Future[RoborockMessage]] = {}
self._decoder = create_mqtt_decoder(local_key)
self._encoder = create_mqtt_encoder(local_key)
self._queue_lock = asyncio.Lock()

@property
def _publish_topic(self) -> str:
"""Topic to send commands to the device."""
return f"rr/m/i/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"

@property
def _subscribe_topic(self) -> str:
"""Topic to receive responses from the device."""
return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"

async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
"""Subscribe to the device's response topic.

The callback will be called with the message payload when a message is received.

All messages received will be processed through the provided callback, even
those sent in response to the `send_command` command.

Returns a callable that can be used to unsubscribe from the topic.
"""

def message_handler(payload: bytes) -> None:
if not (messages := self._decoder(payload)):
_LOGGER.warning("Failed to decode MQTT message: %s", payload)
return
for message in messages:
_LOGGER.debug("Received message: %s", message)
asyncio.create_task(self._resolve_future_with_lock(message))
try:
callback(message)
except Exception as e:
_LOGGER.exception("Uncaught error in message handler callback: %s", e)

return await self._mqtt_session.subscribe(self._subscribe_topic, message_handler)

async def _resolve_future_with_lock(self, message: RoborockMessage) -> None:
"""Resolve waiting future with proper locking."""
if (request_id := message.get_request_id()) is None:
_LOGGER.debug("Received message with no request_id")
return
async with self._queue_lock:
if (future := self._waiting_queue.pop(request_id, None)) is not None:
future.set_result(message)
else:
_LOGGER.debug("Received message with no waiting handler: request_id=%s", request_id)

async def send_command(self, message: RoborockMessage, timeout: float = 10.0) -> RoborockMessage:
"""Send a command message and wait for the response message.

Returns the raw response message - caller is responsible for parsing.
"""
try:
if (request_id := message.get_request_id()) is None:
raise RoborockException("Message must have a request_id for RPC calls")
except (ValueError, JSONDecodeError) as err:
_LOGGER.exception("Error getting request_id from message: %s", err)
raise RoborockException(f"Invalid message format, Message must have a request_id: {err}") from err

future: asyncio.Future[RoborockMessage] = asyncio.Future()
async with self._queue_lock:
if request_id in self._waiting_queue:
raise RoborockException(f"Request ID {request_id} already pending, cannot send command")
self._waiting_queue[request_id] = future

try:
encoded_msg = self._encoder(message)
await self._mqtt_session.publish(self._publish_topic, encoded_msg)

return await asyncio.wait_for(future, timeout=timeout)

except asyncio.TimeoutError as ex:
async with self._queue_lock:
self._waiting_queue.pop(request_id, None)
raise RoborockException(f"Command timed out after {timeout}s") from ex
except Exception:
logging.exception("Uncaught error sending command")
async with self._queue_lock:
self._waiting_queue.pop(request_id, None)
raise
40 changes: 40 additions & 0 deletions tests/devices/test_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Tests for the Device class."""

from unittest.mock import AsyncMock, Mock

from roborock.containers import HomeData, UserData
from roborock.devices.device import DeviceVersion, RoborockDevice

from .. import mock_data

USER_DATA = UserData.from_dict(mock_data.USER_DATA)
HOME_DATA = HomeData.from_dict(mock_data.HOME_DATA_RAW)


async def test_device_connection() -> None:
"""Test the Device connection setup."""

unsub = Mock()
subscribe = AsyncMock()
subscribe.return_value = unsub
mqtt_channel = AsyncMock()
mqtt_channel.subscribe = subscribe

device = RoborockDevice(
USER_DATA,
device_info=HOME_DATA.devices[0],
product_info=HOME_DATA.products[0],
mqtt_channel=mqtt_channel,
)
assert device.duid == "abc123"
assert device.name == "Roborock S7 MaxV"
assert device.device_version == DeviceVersion.V1

assert not subscribe.called

await device.connect()
assert subscribe.called
assert not unsub.called

await device.close()
assert unsub.called
Loading