Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 38 additions & 37 deletions roborock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import struct
import time
from random import randint
from typing import Any, Callable, Coroutine, Optional
from typing import Any, Callable, Coroutine, Mapping, Optional

import aiohttp
from Crypto.Cipher import AES
Expand Down Expand Up @@ -85,8 +85,8 @@ async def request(self, method: str, url: str, params=None, data=None, headers=N


class RoborockClient:
def __init__(self, endpoint: str, device_info: RoborockDeviceInfo) -> None:
self.device_info = device_info
def __init__(self, endpoint: str, devices_info: Mapping[str, RoborockDeviceInfo]) -> None:
self.devices_info = devices_info
self._endpoint = endpoint
self._nonce = secrets.token_bytes(16)
self._waiting_queue: dict[int, RoborockFuture] = {}
Expand Down Expand Up @@ -200,27 +200,27 @@ def _get_payload(self, method: RoborockCommand, params: Optional[list] = None, s
)
return request_id, timestamp, payload

async def send_command(self, method: RoborockCommand, params: Optional[list] = None):
async def send_command(self, device_id: str, method: RoborockCommand, params: Optional[list] = None):
raise NotImplementedError

async def get_status(self) -> Status | None:
status = await self.send_command(RoborockCommand.GET_STATUS)
async def get_status(self, device_id: str) -> Status | None:
status = await self.send_command(device_id, RoborockCommand.GET_STATUS)
if isinstance(status, dict):
return Status.from_dict(status)
return None

async def get_dnd_timer(self) -> DNDTimer | None:
async def get_dnd_timer(self, device_id: str) -> DNDTimer | None:
try:
dnd_timer = await self.send_command(RoborockCommand.GET_DND_TIMER)
dnd_timer = await self.send_command(device_id, RoborockCommand.GET_DND_TIMER)
if isinstance(dnd_timer, dict):
return DNDTimer.from_dict(dnd_timer)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_clean_summary(self) -> CleanSummary | None:
async def get_clean_summary(self, device_id: str) -> CleanSummary | None:
try:
clean_summary = await self.send_command(RoborockCommand.GET_CLEAN_SUMMARY)
clean_summary = await self.send_command(device_id, RoborockCommand.GET_CLEAN_SUMMARY)
if isinstance(clean_summary, dict):
return CleanSummary.from_dict(clean_summary)
elif isinstance(clean_summary, list):
Expand All @@ -232,54 +232,55 @@ async def get_clean_summary(self) -> CleanSummary | None:
_LOGGER.error(e)
return None

async def get_clean_record(self, record_id: int) -> CleanRecord | None:
async def get_clean_record(self, device_id: str, record_id: int) -> CleanRecord | None:
try:
clean_record = await self.send_command(RoborockCommand.GET_CLEAN_RECORD, [record_id])
clean_record = await self.send_command(device_id, RoborockCommand.GET_CLEAN_RECORD, [record_id])
if isinstance(clean_record, dict):
return CleanRecord.from_dict(clean_record)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_consumable(self) -> Consumable | None:
async def get_consumable(self, device_id: str) -> Consumable | None:
try:
consumable = await self.send_command(RoborockCommand.GET_CONSUMABLE)
consumable = await self.send_command(device_id, RoborockCommand.GET_CONSUMABLE)
if isinstance(consumable, dict):
return Consumable.from_dict(consumable)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_wash_towel_mode(self) -> WashTowelMode | None:
async def get_wash_towel_mode(self, device_id: str) -> WashTowelMode | None:
try:
washing_mode = await self.send_command(RoborockCommand.GET_WASH_TOWEL_MODE)
washing_mode = await self.send_command(device_id, RoborockCommand.GET_WASH_TOWEL_MODE)
if isinstance(washing_mode, dict):
return WashTowelMode.from_dict(washing_mode)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_dust_collection_mode(self) -> DustCollectionMode | None:
async def get_dust_collection_mode(self, device_id: str) -> DustCollectionMode | None:
try:
dust_collection = await self.send_command(RoborockCommand.GET_DUST_COLLECTION_MODE)
dust_collection = await self.send_command(device_id, RoborockCommand.GET_DUST_COLLECTION_MODE)
if isinstance(dust_collection, dict):
return DustCollectionMode.from_dict(dust_collection)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_smart_wash_params(self) -> SmartWashParams | None:
async def get_smart_wash_params(self, device_id: str) -> SmartWashParams | None:
try:
mop_wash_mode = await self.send_command(RoborockCommand.GET_SMART_WASH_PARAMS)
mop_wash_mode = await self.send_command(device_id, RoborockCommand.GET_SMART_WASH_PARAMS)
if isinstance(mop_wash_mode, dict):
return SmartWashParams.from_dict(mop_wash_mode)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_dock_summary(self, dock_type: RoborockEnum) -> DockSummary | None:
async def get_dock_summary(self, device_id: str, dock_type: RoborockEnum) -> DockSummary | None:
"""Gets the status summary from the dock with the methods available for a given dock.

:param device_id: Device id
:param dock_type: RoborockDockTypeCode"""
try:
commands: list[
Expand All @@ -288,11 +289,11 @@ async def get_dock_summary(self, dock_type: RoborockEnum) -> DockSummary | None:
Any,
DustCollectionMode | WashTowelMode | SmartWashParams | None,
]
] = [self.get_dust_collection_mode()]
] = [self.get_dust_collection_mode(device_id)]
if dock_type == RoborockDockTypeCode["3"]:
commands += [
self.get_wash_towel_mode(),
self.get_smart_wash_params(),
self.get_wash_towel_mode(device_id),
self.get_smart_wash_params(device_id),
]
[dust_collection_mode, wash_towel_mode, smart_wash_params] = unpack_list(
list(await asyncio.gather(*commands)), 3
Expand All @@ -303,21 +304,21 @@ async def get_dock_summary(self, dock_type: RoborockEnum) -> DockSummary | None:
_LOGGER.error(e)
return None

async def get_prop(self) -> DeviceProp | None:
async def get_prop(self, device_id: str) -> DeviceProp | None:
[status, dnd_timer, clean_summary, consumable] = await asyncio.gather(
*[
self.get_status(),
self.get_dnd_timer(),
self.get_clean_summary(),
self.get_consumable(),
self.get_status(device_id),
self.get_dnd_timer(device_id),
self.get_clean_summary(device_id),
self.get_consumable(device_id),
]
)
last_clean_record = None
if clean_summary and clean_summary.records and len(clean_summary.records) > 0:
last_clean_record = await self.get_clean_record(clean_summary.records[0])
last_clean_record = await self.get_clean_record(device_id, clean_summary.records[0])
dock_summary = None
if status and status.dock_type is not None and status.dock_type != RoborockDockTypeCode["0"]:
dock_summary = await self.get_dock_summary(status.dock_type)
dock_summary = await self.get_dock_summary(device_id, status.dock_type)
if any([status, dnd_timer, clean_summary, consumable]):
return DeviceProp(
status,
Expand All @@ -329,27 +330,27 @@ async def get_prop(self) -> DeviceProp | None:
)
return None

async def get_multi_maps_list(self) -> MultiMapsList | None:
async def get_multi_maps_list(self, device_id) -> MultiMapsList | None:
try:
multi_maps_list = await self.send_command(RoborockCommand.GET_MULTI_MAPS_LIST)
multi_maps_list = await self.send_command(device_id, RoborockCommand.GET_MULTI_MAPS_LIST)
if isinstance(multi_maps_list, dict):
return MultiMapsList.from_dict(multi_maps_list)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_networking(self) -> NetworkInfo | None:
async def get_networking(self, device_id) -> NetworkInfo | None:
try:
networking_info = await self.send_command(RoborockCommand.GET_NETWORK_INFO)
networking_info = await self.send_command(device_id, RoborockCommand.GET_NETWORK_INFO)
if isinstance(networking_info, dict):
return NetworkInfo.from_dict(networking_info)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_room_mapping(self) -> list[RoomMapping]:
async def get_room_mapping(self, device_id: str) -> list[RoomMapping]:
"""Gets the mapping from segment id -> iot id. Only works on local api."""
mapping = await self.send_command(RoborockCommand.GET_ROOM_MAPPING)
mapping = await self.send_command(device_id, RoborockCommand.GET_ROOM_MAPPING)
if isinstance(mapping, list):
return [
RoomMapping(segment_id=segment_id, iot_id=iot_id) # type: ignore
Expand Down
18 changes: 7 additions & 11 deletions roborock/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,30 +102,26 @@ async def list_devices(ctx):
await _discover(ctx)
login_data = context.login_data()
home_data = login_data.home_data
device_name_id = ", ".join(
[f"{device.name}: {device.duid}" for device in home_data.devices + home_data.received_devices]
)
click.echo(f"Known devices {device_name_id}")
click.echo(f"Known devices {', '.join([device.name for device in home_data.devices + home_data.received_devices])}")


@click.command()
@click.option("--device_id", required=True)
@click.option("--cmd", required=True)
@click.option("--params", required=False)
@click.pass_context
@run_sync()
async def command(ctx, cmd, device_id, params):
async def command(ctx, cmd, params):
context: RoborockContext = ctx.obj
login_data = context.login_data()
if not login_data.home_data:
await _discover(ctx)
login_data = context.login_data()
home_data = login_data.home_data
devices = home_data.devices + home_data.received_devices
device = next((device for device in devices if device.duid == device_id), None)
device_info = RoborockDeviceInfo(device=device)
mqtt_client = RoborockMqttClient(login_data.user_data, device_info)
await mqtt_client.send_command(cmd, params)
device_map: dict[str, RoborockDeviceInfo] = {}
for device in home_data.devices + home_data.received_devices:
device_map[device.duid] = RoborockDeviceInfo(device=device)
mqtt_client = RoborockMqttClient(login_data.user_data, device_map)
await mqtt_client.send_command(home_data.devices[0].duid, cmd, params)
mqtt_client.__del__()


Expand Down
25 changes: 13 additions & 12 deletions roborock/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import uuid
from asyncio import Lock
from typing import Optional
from typing import Mapping, Optional
from urllib.parse import urlparse

import paho.mqtt.client as mqtt
Expand All @@ -25,12 +25,12 @@
class RoborockMqttClient(RoborockClient, mqtt.Client):
_thread: threading.Thread

def __init__(self, user_data: UserData, device_info: RoborockDeviceInfo) -> None:
def __init__(self, user_data: UserData, devices_info: Mapping[str, RoborockDeviceInfo]) -> None:
rriot = user_data.rriot
if rriot is None:
raise RoborockException("Got no rriot data from user_data")
endpoint = base64.b64encode(md5bin(rriot.k)[8:14]).decode()
RoborockClient.__init__(self, endpoint, device_info)
RoborockClient.__init__(self, endpoint, devices_info)
mqtt.Client.__init__(self, protocol=mqtt.MQTTv5)
self._mqtt_user = rriot.u
self._hashed_user = md5hex(self._mqtt_user + ":" + rriot.k)[2:10]
Expand Down Expand Up @@ -63,7 +63,7 @@ def on_connect(self, *args, **kwargs) -> None:
connection_queue.resolve((None, VacuumError(rc, message)))
return
_LOGGER.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/#"
(result, mid) = self.subscribe(topic)
if result != 0:
message = f"Failed to subscribe (rc: {result})"
Expand All @@ -77,7 +77,8 @@ def on_connect(self, *args, **kwargs) -> None:

def on_message(self, *args, **kwargs) -> None:
_, __, msg = args
messages, _ = RoborockParser.decode(msg.payload, self.device_info.device.local_key)
device_id = msg.topic.split("/").pop()
messages, _ = RoborockParser.decode(msg.payload, self.devices_info[device_id].device.local_key)
super().on_message(messages)

def on_disconnect(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -150,21 +151,21 @@ async def async_connect(self) -> None:
async def validate_connection(self) -> None:
await self.async_connect()

def _send_msg_raw(self, msg) -> None:
info = self.publish(f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg)
def _send_msg_raw(self, device_id, msg) -> None:
info = self.publish(f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{device_id}", msg)
if info.rc != mqtt.MQTT_ERR_SUCCESS:
raise RoborockException(f"Failed to publish (rc: {info.rc})")

async def send_command(self, method: RoborockCommand, params: Optional[list] = None):
async def send_command(self, device_id: str, method: RoborockCommand, params: Optional[list] = None):
await self.validate_connection()
request_id, timestamp, payload = super()._get_payload(method, params, True)
_LOGGER.debug(f"id={request_id} Requesting method {method} with {params}")
request_protocol = 101
response_protocol = 301 if method in SPECIAL_COMMANDS else 102
roborock_message = RoborockMessage(timestamp=timestamp, protocol=request_protocol, payload=payload)
local_key = self.device_info.device.local_key
local_key = self.devices_info[device_id].device.local_key
msg = RoborockParser.encode(roborock_message, local_key)
self._send_msg_raw(msg)
self._send_msg_raw(device_id, msg)
(response, err) = await self._async_response(request_id, response_protocol)
if err:
raise CommandVacuumError(method, err) from err
Expand All @@ -174,9 +175,9 @@ async def send_command(self, method: RoborockCommand, params: Optional[list] = N
_LOGGER.debug(f"id={request_id} Response from {method}: {response}")
return response

async def get_map_v1(self):
async def get_map_v1(self, device_id):
try:
return await self.send_command(RoborockCommand.GET_MAP_V1)
return await self.send_command(device_id, RoborockCommand.GET_MAP_V1)
except RoborockException as e:
_LOGGER.error(e)
return None
12 changes: 10 additions & 2 deletions roborock/code_mappings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import logging
from enum import Enum
from typing import Any, Type, TypeVar

_StrEnumT = TypeVar("_StrEnumT", bound="RoborockEnum")

_LOGGER = logging.getLogger(__name__)


class RoborockEnum(str, Enum):
def __new__(cls: Type[_StrEnumT], value: str, *args: Any, **kwargs: Any) -> _StrEnumT:
Expand All @@ -18,11 +21,15 @@ def __str__(self):

@classmethod
def _missing_(cls: Type[_StrEnumT], code: object):
return cls._member_map_.get(str(code))
if cls._member_map_.get(str(code)):
return cls._member_map_.get(str(code))
else:
_LOGGER.warning(f"Unknown code {code} for {cls.__name__}")
return cls._member_map_.get(str(-9999))

@classmethod
def as_dict(cls: Type[_StrEnumT]):
return {int(i.name): i.value for i in cls}
return {int(i.name): i.value for i in cls if i.value != "UNKNOWN"}

@classmethod
def values(cls: Type[_StrEnumT]):
Expand All @@ -42,6 +49,7 @@ def __getitem__(cls: Type[_StrEnumT], item):


def create_code_enum(name: str, data: dict) -> RoborockEnum:
data[-9999] = "UNKNOWN"
return RoborockEnum(name, {str(key): value for key, value in data.items()})


Expand Down
Loading