From eeefbac9779404258b14ee31bcb58b0c4a6de92f Mon Sep 17 00:00:00 2001 From: Luke Date: Thu, 27 Apr 2023 17:19:35 -0400 Subject: [PATCH 1/5] fix: add functionality for missing enum values --- roborock/code_mappings.py | 13 +++++++++++-- tests/test_containers.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/roborock/code_mappings.py b/roborock/code_mappings.py index 9e147460..bd8e8bbf 100644 --- a/roborock/code_mappings.py +++ b/roborock/code_mappings.py @@ -1,10 +1,13 @@ from __future__ import annotations +import logging from enum import Enum from typing import Any, Type, TypeVar _StrEnumT = TypeVar("_StrEnumT", bound="RoborockEnum") +_LOGGER = logging.getLogger(__name__) + class RoborockEnum(str, Enum): def __new__(cls: Type[_StrEnumT], value: str, *args: Any, **kwargs: Any) -> _StrEnumT: @@ -18,11 +21,15 @@ def __str__(self): @classmethod def _missing_(cls: Type[_StrEnumT], code: object): - return cls._member_map_.get(str(code)) + if cls._member_map_.get(str(code)): + return cls._member_map_.get(str(code)) + else: + _LOGGER.warning(f"Unknown code {code} for {cls.__name__}") + return cls._member_map_.get(str(-9999)) @classmethod def as_dict(cls: Type[_StrEnumT]): - return {int(i.name): i.value for i in cls} + return {int(i.name): i.value for i in cls if i.value != "UNKNOWN"} @classmethod def values(cls: Type[_StrEnumT]): @@ -42,6 +49,7 @@ def __getitem__(cls: Type[_StrEnumT], item): def create_code_enum(name: str, data: dict) -> RoborockEnum: + data[-9999] = "UNKNOWN" return RoborockEnum(name, {str(key): value for key, value in data.items()}) @@ -138,6 +146,7 @@ def create_code_enum(name: str, data: dict) -> RoborockEnum: 202: "moderate", 203: "intense", 204: "custom", + 207: "custom | relatively large", }, ) diff --git a/tests/test_containers.py b/tests/test_containers.py index f14fc699..0b82c0cb 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -187,3 +187,13 @@ def test_clean_record(): assert cr.avoid_count == 19 assert cr.wash_count == 2 assert cr.map_flag == 0 + + +def test_no_value(): + modified_status = STATUS.copy() + modified_status["mop_mode"] = 9999 + s = Status.from_dict(modified_status) + + assert s.mop_mode == RoborockMopModeCode["-9999"] + assert "-9999" not in RoborockMopModeCode.keys() + assert "UNKNOWN" not in RoborockMopModeCode.values() From a1f727c0bcdc25145d37fd66722ef944dbcaf39b Mon Sep 17 00:00:00 2001 From: Luke Date: Thu, 27 Apr 2023 17:21:17 -0400 Subject: [PATCH 2/5] fix: temp removed 207 --- roborock/code_mappings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/roborock/code_mappings.py b/roborock/code_mappings.py index bd8e8bbf..9bdd3cb5 100644 --- a/roborock/code_mappings.py +++ b/roborock/code_mappings.py @@ -146,7 +146,6 @@ def create_code_enum(name: str, data: dict) -> RoborockEnum: 202: "moderate", 203: "intense", 204: "custom", - 207: "custom | relatively large", }, ) From f32e163003a98758a27f6ddc106c33fca2cf00af Mon Sep 17 00:00:00 2001 From: Luke Date: Thu, 27 Apr 2023 17:21:38 -0400 Subject: [PATCH 3/5] Revert "chore: linting" This reverts commit 58b46835d609794210f8c49daddbc7d25cee011d. --- tests/conftest.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d5cb5a54..ab90395e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ from roborock import HomeData, UserData from roborock.cloud_api import RoborockMqttClient -from roborock.containers import RoborockDeviceInfo from tests.mock_data import HOME_DATA_RAW, USER_DATA @@ -10,7 +9,7 @@ def mqtt_client(): user_data = UserData.from_dict(USER_DATA) home_data = HomeData.from_dict(HOME_DATA_RAW) - device_info = RoborockDeviceInfo(device=home_data.devices[0]) - client = RoborockMqttClient(user_data, device_info) + device_map = {home_data.devices[0].duid: home_data.devices[0].local_key} + client = RoborockMqttClient(user_data, device_map) yield client # Clean up any resources after the test From 0a237a2a4c64d0900c1aaec34b0d586cc3534f19 Mon Sep 17 00:00:00 2001 From: Luke Date: Thu, 27 Apr 2023 17:21:42 -0400 Subject: [PATCH 4/5] Revert "chore: linting" This reverts commit 2ed367cba5e9b4199fdea935305fb47f85a8c1e7. --- roborock/local_api.py | 2 +- tests/test_api.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/roborock/local_api.py b/roborock/local_api.py index 3eebd5e1..ac33e5f0 100644 --- a/roborock/local_api.py +++ b/roborock/local_api.py @@ -8,12 +8,12 @@ import async_timeout +from roborock.util import get_running_loop_or_create_one from .api import QUEUE_TIMEOUT, SPECIAL_COMMANDS, RoborockClient from .containers import RoborockLocalDeviceInfo from .exceptions import CommandVacuumError, RoborockConnectionException, RoborockException from .roborock_message import RoborockMessage, RoborockParser from .typing import CommandInfoMap, RoborockCommand -from .util import get_running_loop_or_create_one _LOGGER = logging.getLogger(__name__) diff --git a/tests/test_api.py b/tests/test_api.py index 99c10a0a..4e821b0e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -2,11 +2,11 @@ import paho.mqtt.client as mqtt import pytest +from roborock.containers import RoborockDeviceInfo from roborock import HomeData, RoborockDockDustCollectionModeCode, RoborockDockWashTowelModeCode, UserData from roborock.api import PreparedRequest, RoborockApiClient from roborock.cloud_api import RoborockMqttClient -from roborock.containers import RoborockDeviceInfo from tests.mock_data import BASE_URL_REQUEST, GET_CODE_RESPONSE, HOME_DATA_RAW, USER_DATA From aa9d87a7707146115b62c79d05da16593057f529 Mon Sep 17 00:00:00 2001 From: Luke Date: Thu, 27 Apr 2023 17:21:48 -0400 Subject: [PATCH 5/5] Revert "fix: using single device api" This reverts commit e689e8d141acff998fd524ace923621fc0f91d0c. --- roborock/api.py | 75 +++++++++---------- roborock/cli.py | 18 ++--- roborock/cloud_api.py | 25 +++---- roborock/local_api.py | 164 +++++++++++++++++++++++++----------------- tests/test_api.py | 23 +++--- 5 files changed, 168 insertions(+), 137 deletions(-) diff --git a/roborock/api.py b/roborock/api.py index 92f8034c..bcda9a26 100644 --- a/roborock/api.py +++ b/roborock/api.py @@ -14,7 +14,7 @@ import struct import time from random import randint -from typing import Any, Callable, Coroutine, Optional +from typing import Any, Callable, Coroutine, Mapping, Optional import aiohttp from Crypto.Cipher import AES @@ -85,8 +85,8 @@ async def request(self, method: str, url: str, params=None, data=None, headers=N class RoborockClient: - def __init__(self, endpoint: str, device_info: RoborockDeviceInfo) -> None: - self.device_info = device_info + def __init__(self, endpoint: str, devices_info: Mapping[str, RoborockDeviceInfo]) -> None: + self.devices_info = devices_info self._endpoint = endpoint self._nonce = secrets.token_bytes(16) self._waiting_queue: dict[int, RoborockFuture] = {} @@ -200,27 +200,27 @@ def _get_payload(self, method: RoborockCommand, params: Optional[list] = None, s ) return request_id, timestamp, payload - async def send_command(self, method: RoborockCommand, params: Optional[list] = None): + async def send_command(self, device_id: str, method: RoborockCommand, params: Optional[list] = None): raise NotImplementedError - async def get_status(self) -> Status | None: - status = await self.send_command(RoborockCommand.GET_STATUS) + async def get_status(self, device_id: str) -> Status | None: + status = await self.send_command(device_id, RoborockCommand.GET_STATUS) if isinstance(status, dict): return Status.from_dict(status) return None - async def get_dnd_timer(self) -> DNDTimer | None: + async def get_dnd_timer(self, device_id: str) -> DNDTimer | None: try: - dnd_timer = await self.send_command(RoborockCommand.GET_DND_TIMER) + dnd_timer = await self.send_command(device_id, RoborockCommand.GET_DND_TIMER) if isinstance(dnd_timer, dict): return DNDTimer.from_dict(dnd_timer) except RoborockTimeout as e: _LOGGER.error(e) return None - async def get_clean_summary(self) -> CleanSummary | None: + async def get_clean_summary(self, device_id: str) -> CleanSummary | None: try: - clean_summary = await self.send_command(RoborockCommand.GET_CLEAN_SUMMARY) + clean_summary = await self.send_command(device_id, RoborockCommand.GET_CLEAN_SUMMARY) if isinstance(clean_summary, dict): return CleanSummary.from_dict(clean_summary) elif isinstance(clean_summary, list): @@ -232,54 +232,55 @@ async def get_clean_summary(self) -> CleanSummary | None: _LOGGER.error(e) return None - async def get_clean_record(self, record_id: int) -> CleanRecord | None: + async def get_clean_record(self, device_id: str, record_id: int) -> CleanRecord | None: try: - clean_record = await self.send_command(RoborockCommand.GET_CLEAN_RECORD, [record_id]) + clean_record = await self.send_command(device_id, RoborockCommand.GET_CLEAN_RECORD, [record_id]) if isinstance(clean_record, dict): return CleanRecord.from_dict(clean_record) except RoborockTimeout as e: _LOGGER.error(e) return None - async def get_consumable(self) -> Consumable | None: + async def get_consumable(self, device_id: str) -> Consumable | None: try: - consumable = await self.send_command(RoborockCommand.GET_CONSUMABLE) + consumable = await self.send_command(device_id, RoborockCommand.GET_CONSUMABLE) if isinstance(consumable, dict): return Consumable.from_dict(consumable) except RoborockTimeout as e: _LOGGER.error(e) return None - async def get_wash_towel_mode(self) -> WashTowelMode | None: + async def get_wash_towel_mode(self, device_id: str) -> WashTowelMode | None: try: - washing_mode = await self.send_command(RoborockCommand.GET_WASH_TOWEL_MODE) + washing_mode = await self.send_command(device_id, RoborockCommand.GET_WASH_TOWEL_MODE) if isinstance(washing_mode, dict): return WashTowelMode.from_dict(washing_mode) except RoborockTimeout as e: _LOGGER.error(e) return None - async def get_dust_collection_mode(self) -> DustCollectionMode | None: + async def get_dust_collection_mode(self, device_id: str) -> DustCollectionMode | None: try: - dust_collection = await self.send_command(RoborockCommand.GET_DUST_COLLECTION_MODE) + dust_collection = await self.send_command(device_id, RoborockCommand.GET_DUST_COLLECTION_MODE) if isinstance(dust_collection, dict): return DustCollectionMode.from_dict(dust_collection) except RoborockTimeout as e: _LOGGER.error(e) return None - async def get_smart_wash_params(self) -> SmartWashParams | None: + async def get_smart_wash_params(self, device_id: str) -> SmartWashParams | None: try: - mop_wash_mode = await self.send_command(RoborockCommand.GET_SMART_WASH_PARAMS) + mop_wash_mode = await self.send_command(device_id, RoborockCommand.GET_SMART_WASH_PARAMS) if isinstance(mop_wash_mode, dict): return SmartWashParams.from_dict(mop_wash_mode) except RoborockTimeout as e: _LOGGER.error(e) return None - async def get_dock_summary(self, dock_type: RoborockEnum) -> DockSummary | None: + async def get_dock_summary(self, device_id: str, dock_type: RoborockEnum) -> DockSummary | None: """Gets the status summary from the dock with the methods available for a given dock. + :param device_id: Device id :param dock_type: RoborockDockTypeCode""" try: commands: list[ @@ -288,11 +289,11 @@ async def get_dock_summary(self, dock_type: RoborockEnum) -> DockSummary | None: Any, DustCollectionMode | WashTowelMode | SmartWashParams | None, ] - ] = [self.get_dust_collection_mode()] + ] = [self.get_dust_collection_mode(device_id)] if dock_type == RoborockDockTypeCode["3"]: commands += [ - self.get_wash_towel_mode(), - self.get_smart_wash_params(), + self.get_wash_towel_mode(device_id), + self.get_smart_wash_params(device_id), ] [dust_collection_mode, wash_towel_mode, smart_wash_params] = unpack_list( list(await asyncio.gather(*commands)), 3 @@ -303,21 +304,21 @@ async def get_dock_summary(self, dock_type: RoborockEnum) -> DockSummary | None: _LOGGER.error(e) return None - async def get_prop(self) -> DeviceProp | None: + async def get_prop(self, device_id: str) -> DeviceProp | None: [status, dnd_timer, clean_summary, consumable] = await asyncio.gather( *[ - self.get_status(), - self.get_dnd_timer(), - self.get_clean_summary(), - self.get_consumable(), + self.get_status(device_id), + self.get_dnd_timer(device_id), + self.get_clean_summary(device_id), + self.get_consumable(device_id), ] ) last_clean_record = None if clean_summary and clean_summary.records and len(clean_summary.records) > 0: - last_clean_record = await self.get_clean_record(clean_summary.records[0]) + last_clean_record = await self.get_clean_record(device_id, clean_summary.records[0]) dock_summary = None if status and status.dock_type is not None and status.dock_type != RoborockDockTypeCode["0"]: - dock_summary = await self.get_dock_summary(status.dock_type) + dock_summary = await self.get_dock_summary(device_id, status.dock_type) if any([status, dnd_timer, clean_summary, consumable]): return DeviceProp( status, @@ -329,27 +330,27 @@ async def get_prop(self) -> DeviceProp | None: ) return None - async def get_multi_maps_list(self) -> MultiMapsList | None: + async def get_multi_maps_list(self, device_id) -> MultiMapsList | None: try: - multi_maps_list = await self.send_command(RoborockCommand.GET_MULTI_MAPS_LIST) + multi_maps_list = await self.send_command(device_id, RoborockCommand.GET_MULTI_MAPS_LIST) if isinstance(multi_maps_list, dict): return MultiMapsList.from_dict(multi_maps_list) except RoborockTimeout as e: _LOGGER.error(e) return None - async def get_networking(self) -> NetworkInfo | None: + async def get_networking(self, device_id) -> NetworkInfo | None: try: - networking_info = await self.send_command(RoborockCommand.GET_NETWORK_INFO) + networking_info = await self.send_command(device_id, RoborockCommand.GET_NETWORK_INFO) if isinstance(networking_info, dict): return NetworkInfo.from_dict(networking_info) except RoborockTimeout as e: _LOGGER.error(e) return None - async def get_room_mapping(self) -> list[RoomMapping]: + async def get_room_mapping(self, device_id: str) -> list[RoomMapping]: """Gets the mapping from segment id -> iot id. Only works on local api.""" - mapping = await self.send_command(RoborockCommand.GET_ROOM_MAPPING) + mapping = await self.send_command(device_id, RoborockCommand.GET_ROOM_MAPPING) if isinstance(mapping, list): return [ RoomMapping(segment_id=segment_id, iot_id=iot_id) # type: ignore diff --git a/roborock/cli.py b/roborock/cli.py index 8669be46..2727a67d 100644 --- a/roborock/cli.py +++ b/roborock/cli.py @@ -102,30 +102,26 @@ async def list_devices(ctx): await _discover(ctx) login_data = context.login_data() home_data = login_data.home_data - device_name_id = ", ".join( - [f"{device.name}: {device.duid}" for device in home_data.devices + home_data.received_devices] - ) - click.echo(f"Known devices {device_name_id}") + click.echo(f"Known devices {', '.join([device.name for device in home_data.devices + home_data.received_devices])}") @click.command() -@click.option("--device_id", required=True) @click.option("--cmd", required=True) @click.option("--params", required=False) @click.pass_context @run_sync() -async def command(ctx, cmd, device_id, params): +async def command(ctx, cmd, params): context: RoborockContext = ctx.obj login_data = context.login_data() if not login_data.home_data: await _discover(ctx) login_data = context.login_data() home_data = login_data.home_data - devices = home_data.devices + home_data.received_devices - device = next((device for device in devices if device.duid == device_id), None) - device_info = RoborockDeviceInfo(device=device) - mqtt_client = RoborockMqttClient(login_data.user_data, device_info) - await mqtt_client.send_command(cmd, params) + device_map: dict[str, RoborockDeviceInfo] = {} + for device in home_data.devices + home_data.received_devices: + device_map[device.duid] = RoborockDeviceInfo(device=device) + mqtt_client = RoborockMqttClient(login_data.user_data, device_map) + await mqtt_client.send_command(home_data.devices[0].duid, cmd, params) mqtt_client.__del__() diff --git a/roborock/cloud_api.py b/roborock/cloud_api.py index baa71aec..0444be28 100644 --- a/roborock/cloud_api.py +++ b/roborock/cloud_api.py @@ -5,7 +5,7 @@ import threading import uuid from asyncio import Lock -from typing import Optional +from typing import Mapping, Optional from urllib.parse import urlparse import paho.mqtt.client as mqtt @@ -25,12 +25,12 @@ class RoborockMqttClient(RoborockClient, mqtt.Client): _thread: threading.Thread - def __init__(self, user_data: UserData, device_info: RoborockDeviceInfo) -> None: + def __init__(self, user_data: UserData, devices_info: Mapping[str, RoborockDeviceInfo]) -> None: rriot = user_data.rriot if rriot is None: raise RoborockException("Got no rriot data from user_data") endpoint = base64.b64encode(md5bin(rriot.k)[8:14]).decode() - RoborockClient.__init__(self, endpoint, device_info) + RoborockClient.__init__(self, endpoint, devices_info) mqtt.Client.__init__(self, protocol=mqtt.MQTTv5) self._mqtt_user = rriot.u self._hashed_user = md5hex(self._mqtt_user + ":" + rriot.k)[2:10] @@ -63,7 +63,7 @@ def on_connect(self, *args, **kwargs) -> None: connection_queue.resolve((None, VacuumError(rc, message))) return _LOGGER.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}") - topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}" + topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/#" (result, mid) = self.subscribe(topic) if result != 0: message = f"Failed to subscribe (rc: {result})" @@ -77,7 +77,8 @@ def on_connect(self, *args, **kwargs) -> None: def on_message(self, *args, **kwargs) -> None: _, __, msg = args - messages, _ = RoborockParser.decode(msg.payload, self.device_info.device.local_key) + device_id = msg.topic.split("/").pop() + messages, _ = RoborockParser.decode(msg.payload, self.devices_info[device_id].device.local_key) super().on_message(messages) def on_disconnect(self, *args, **kwargs) -> None: @@ -150,21 +151,21 @@ async def async_connect(self) -> None: async def validate_connection(self) -> None: await self.async_connect() - def _send_msg_raw(self, msg) -> None: - info = self.publish(f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg) + def _send_msg_raw(self, device_id, msg) -> None: + info = self.publish(f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{device_id}", msg) if info.rc != mqtt.MQTT_ERR_SUCCESS: raise RoborockException(f"Failed to publish (rc: {info.rc})") - async def send_command(self, method: RoborockCommand, params: Optional[list] = None): + async def send_command(self, device_id: str, method: RoborockCommand, params: Optional[list] = None): await self.validate_connection() request_id, timestamp, payload = super()._get_payload(method, params, True) _LOGGER.debug(f"id={request_id} Requesting method {method} with {params}") request_protocol = 101 response_protocol = 301 if method in SPECIAL_COMMANDS else 102 roborock_message = RoborockMessage(timestamp=timestamp, protocol=request_protocol, payload=payload) - local_key = self.device_info.device.local_key + local_key = self.devices_info[device_id].device.local_key msg = RoborockParser.encode(roborock_message, local_key) - self._send_msg_raw(msg) + self._send_msg_raw(device_id, msg) (response, err) = await self._async_response(request_id, response_protocol) if err: raise CommandVacuumError(method, err) from err @@ -174,9 +175,9 @@ async def send_command(self, method: RoborockCommand, params: Optional[list] = N _LOGGER.debug(f"id={request_id} Response from {method}: {response}") return response - async def get_map_v1(self): + async def get_map_v1(self, device_id): try: - return await self.send_command(RoborockCommand.GET_MAP_V1) + return await self.send_command(device_id, RoborockCommand.GET_MAP_V1) except RoborockException as e: _LOGGER.error(e) return None diff --git a/roborock/local_api.py b/roborock/local_api.py index ac33e5f0..89ea7559 100644 --- a/roborock/local_api.py +++ b/roborock/local_api.py @@ -4,57 +4,39 @@ import logging import socket from asyncio import Lock, Transport -from typing import Optional +from typing import Callable, Mapping, Optional import async_timeout -from roborock.util import get_running_loop_or_create_one from .api import QUEUE_TIMEOUT, SPECIAL_COMMANDS, RoborockClient from .containers import RoborockLocalDeviceInfo from .exceptions import CommandVacuumError, RoborockConnectionException, RoborockException from .roborock_message import RoborockMessage, RoborockParser from .typing import CommandInfoMap, RoborockCommand +from .util import get_running_loop_or_create_one _LOGGER = logging.getLogger(__name__) -class RoborockLocalClient(RoborockClient, asyncio.Protocol): - def __init__(self, device_info: RoborockLocalDeviceInfo): - super().__init__("abc", device_info) +class RoborockLocalClient(RoborockClient): + def __init__(self, devices_info: Mapping[str, RoborockLocalDeviceInfo]): + super().__init__("abc", devices_info) self.loop = get_running_loop_or_create_one() - self.ip = device_info.network_info.ip + self.device_listener: dict[str, RoborockSocketListener] = { + device_id: RoborockSocketListener( + device_info.network_info.ip, device_info.device.local_key, self.on_message, self.on_disconnect + ) + for device_id, device_info in devices_info.items() + } self._batch_structs: list[RoborockMessage] = [] self._executing = False - self.remaining = b"" - self.transport: Transport | None = None - self._mutex = Lock() - - def data_received(self, message): - if self.remaining: - message = self.remaining + message - self.remaining = b"" - (parser_msg, remaining) = RoborockParser.decode(message, self.device_info.device.local_key) - self.remaining = remaining - self.on_message(parser_msg) - - def connection_lost(self, exc: Optional[Exception]): - self.on_disconnect(exc) - - def is_connected(self): - return self.transport and self.transport.is_reading() async def async_connect(self) -> None: - try: - if not self.is_connected(): - async with async_timeout.timeout(QUEUE_TIMEOUT): - _LOGGER.info(f"Connecting to {self.ip}") - self.transport, _ = await self.loop.create_connection(lambda: self, self.ip, 58867) # type: ignore - except Exception as e: - raise RoborockConnectionException(f"Failed connecting to {self.ip}") from e + await asyncio.gather(*[listener.connect() for listener in self.device_listener.values()]) async def async_disconnect(self) -> None: - if self.transport: - self.transport.close() + for listener in self.device_listener.values(): + listener.disconnect() def build_roborock_message(self, method: RoborockCommand, params: Optional[list] = None) -> RoborockMessage: secured = True if method in SPECIAL_COMMANDS else False @@ -75,9 +57,9 @@ def build_roborock_message(self, method: RoborockCommand, params: Optional[list] payload=payload, ) - async def send_command(self, method: RoborockCommand, params: Optional[list] = None): + async def send_command(self, device_id: str, method: RoborockCommand, params: Optional[list] = None): roborock_message = self.build_roborock_message(method, params) - return (await self.send_message(roborock_message))[0] + return (await self.send_message(device_id, roborock_message))[0] async def async_local_response(self, roborock_message: RoborockMessage): request_id = roborock_message.get_request_id() @@ -90,37 +72,30 @@ async def async_local_response(self, roborock_message: RoborockMessage): _LOGGER.debug(f"id={request_id} Response from {roborock_message.get_method()}: {response}") return response - def _send_msg_raw(self, data: bytes): - try: - if not self.transport: - raise RoborockException("Can not send message without connection") - self.transport.write(data) - except Exception as e: - raise RoborockException(e) from e - - async def send_message(self, roborock_messages: list[RoborockMessage] | RoborockMessage): - async with self._mutex: - await self.async_connect() - if isinstance(roborock_messages, RoborockMessage): - roborock_messages = [roborock_messages] - local_key = self.device_info.device.local_key - msg = RoborockParser.encode(roborock_messages, local_key) - # Send the command to the Roborock device - if not self.should_keepalive(): - await self.async_disconnect() - - _LOGGER.debug(f"Requesting device with {roborock_messages}") - self._send_msg_raw(msg) - - responses = await asyncio.gather( - *[self.async_local_response(roborock_message) for roborock_message in roborock_messages], - return_exceptions=True, - ) - exception = next((response for response in responses if isinstance(response, BaseException)), None) - if exception: - await self.async_disconnect() - raise exception - return responses + async def send_message(self, device_id: str, roborock_messages: list[RoborockMessage] | RoborockMessage): + if isinstance(roborock_messages, RoborockMessage): + roborock_messages = [roborock_messages] + local_key = self.devices_info[device_id].device.local_key + msg = RoborockParser.encode(roborock_messages, local_key) + # Send the command to the Roborock device + listener = self.device_listener.get(device_id) + if listener is None: + raise RoborockException(f"No device listener for {device_id}") + if not self.should_keepalive(): + listener.disconnect() + + _LOGGER.debug(f"Requesting device with {roborock_messages}") + await listener.send_message(msg) + + responses = await asyncio.gather( + *[self.async_local_response(roborock_message) for roborock_message in roborock_messages], + return_exceptions=True, + ) + exception = next((response for response in responses if isinstance(response, BaseException)), None) + if exception: + listener.disconnect() + raise exception + return responses class RoborockSocket(socket.socket): @@ -129,3 +104,62 @@ class RoborockSocket(socket.socket): @property def is_closed(self): return self._closed + + +class RoborockSocketListener(asyncio.Protocol): + roborock_port = 58867 + + def __init__( + self, + ip: str, + local_key: str, + on_message: Callable[[list[RoborockMessage]], None], + on_disconnect: Callable[[Optional[Exception]], None], + timeout: float | int = QUEUE_TIMEOUT, + ): + self.ip = ip + self.local_key = local_key + self.loop = get_running_loop_or_create_one() + self.on_message = on_message + self.on_disconnect = on_disconnect + self.timeout = timeout + self.remaining = b"" + self.transport: Transport | None = None + self._mutex = Lock() + + def data_received(self, message): + if self.remaining: + message = self.remaining + message + self.remaining = b"" + (parser_msg, remaining) = RoborockParser.decode(message, self.local_key) + self.remaining = remaining + self.on_message(parser_msg) + + def connection_lost(self, exc: Optional[Exception]): + self.on_disconnect(exc) + + def is_connected(self): + return self.transport and self.transport.is_reading() + + async def connect(self): + try: + if not self.is_connected(): + async with async_timeout.timeout(self.timeout): + _LOGGER.info(f"Connecting to {self.ip}") + self.transport, _ = await self.loop.create_connection(lambda: self, self.ip, 58867) # type: ignore + except Exception as e: + raise RoborockConnectionException(f"Failed connecting to {self.ip}") from e + + def disconnect(self): + if self.transport: + self.transport.close() + + async def send_message(self, data: bytes) -> None: + async with self._mutex: + await self.connect() + try: + if not self.transport: + raise RoborockException("Can not send message without connection") + self.transport.write(data) + except Exception as e: + raise RoborockException(e) from e diff --git a/tests/test_api.py b/tests/test_api.py index 4e821b0e..5fa0c85a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -2,7 +2,6 @@ import paho.mqtt.client as mqtt import pytest -from roborock.containers import RoborockDeviceInfo from roborock import HomeData, RoborockDockDustCollectionModeCode, RoborockDockWashTowelModeCode, UserData from roborock.api import PreparedRequest, RoborockApiClient @@ -20,8 +19,8 @@ def test_can_create_prepared_request(): def test_can_create_mqtt_roborock(): home_data = HomeData.from_dict(HOME_DATA_RAW) - device_info = RoborockDeviceInfo(device=home_data.devices[0]) - RoborockMqttClient(UserData.from_dict(USER_DATA), device_info) + device_map = {home_data.devices[0].duid: home_data.devices[0]} + RoborockMqttClient(UserData.from_dict(USER_DATA), device_map) def test_sync_connect(mqtt_client): @@ -69,11 +68,11 @@ async def test_get_home_data(): @pytest.mark.asyncio async def test_get_dust_collection_mode(): home_data = HomeData.from_dict(HOME_DATA_RAW) - device_info = RoborockDeviceInfo(device=home_data.devices[0]) - rmc = RoborockMqttClient(UserData.from_dict(USER_DATA), device_info) + device_map = {home_data.devices[0].duid: home_data.devices[0]} + rmc = RoborockMqttClient(UserData.from_dict(USER_DATA), device_map) with patch("roborock.cloud_api.RoborockMqttClient.send_command") as command: command.return_value = {"mode": 1} - dust = await rmc.get_dust_collection_mode() + dust = await rmc.get_dust_collection_mode(home_data.devices[0].duid) assert dust is not None assert dust.mode == RoborockDockDustCollectionModeCode["1"] @@ -81,11 +80,11 @@ async def test_get_dust_collection_mode(): @pytest.mark.asyncio async def test_get_mop_wash_mode(): home_data = HomeData.from_dict(HOME_DATA_RAW) - device_info = RoborockDeviceInfo(device=home_data.devices[0]) - rmc = RoborockMqttClient(UserData.from_dict(USER_DATA), device_info) + device_map = {home_data.devices[0].duid: home_data.devices[0]} + rmc = RoborockMqttClient(UserData.from_dict(USER_DATA), device_map) with patch("roborock.cloud_api.RoborockMqttClient.send_command") as command: command.return_value = {"smart_wash": 0, "wash_interval": 1500} - mop_wash = await rmc.get_smart_wash_params() + mop_wash = await rmc.get_smart_wash_params(home_data.devices[0].duid) assert mop_wash is not None assert mop_wash.smart_wash == 0 assert mop_wash.wash_interval == 1500 @@ -94,10 +93,10 @@ async def test_get_mop_wash_mode(): @pytest.mark.asyncio async def test_get_washing_mode(): home_data = HomeData.from_dict(HOME_DATA_RAW) - device_info = RoborockDeviceInfo(device=home_data.devices[0]) - rmc = RoborockMqttClient(UserData.from_dict(USER_DATA), device_info) + device_map = {home_data.devices[0].duid: home_data.devices[0]} + rmc = RoborockMqttClient(UserData.from_dict(USER_DATA), device_map) with patch("roborock.cloud_api.RoborockMqttClient.send_command") as command: command.return_value = {"wash_mode": 2} - washing_mode = await rmc.get_wash_towel_mode() + washing_mode = await rmc.get_wash_towel_mode(home_data.devices[0].duid) assert washing_mode is not None assert washing_mode.wash_mode == RoborockDockWashTowelModeCode["2"]