diff --git a/roborock/api.py b/roborock/api.py index c7b3d781..f7802df9 100644 --- a/roborock/api.py +++ b/roborock/api.py @@ -17,7 +17,7 @@ RoborockTimeout, UnknownMethodError, ) -from .roborock_future import RoborockFuture +from .roborock_future import RequestKey, RoborockFuture, WaitingQueue from .roborock_message import ( RoborockMessage, RoborockMessageProtocol, @@ -38,7 +38,7 @@ def __init__(self, device_info: DeviceData) -> None: """Initialize RoborockClient.""" self.device_info = device_info self._nonce = secrets.token_bytes(16) - self._waiting_queue: dict[int, RoborockFuture] = {} + self._waiting_queue = WaitingQueue() self._last_device_msg_in = time.monotonic() self._last_disconnection = time.monotonic() self.keep_alive = KEEPALIVE @@ -89,33 +89,22 @@ async def validate_connection(self) -> None: await self.async_disconnect() await self.async_connect() - async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any: + async def _wait_response(self, request_key: RequestKey, future: RoborockFuture) -> Any: try: - response = await queue.async_get(self.queue_timeout) + response = await future.async_get(self.queue_timeout) if response == "unknown_method": raise UnknownMethodError("Unknown method") return response except (asyncio.TimeoutError, asyncio.CancelledError): - raise RoborockTimeout(f"id={request_id} Timeout after {self.queue_timeout} seconds") from None + raise RoborockTimeout(f"id={request_key} Timeout after {self.queue_timeout} seconds") from None finally: - self._waiting_queue.pop(request_id, None) - - def _async_response(self, request_id: int, protocol_id: int = 0) -> Any: - queue = RoborockFuture(protocol_id) - if request_id in self._waiting_queue and not ( - request_id == 2 and protocol_id == RoborockMessageProtocol.PING_REQUEST - ): - new_id = get_next_int(10000, 32767) - self._logger.warning( - "Attempting to create a future with an existing id %s (%s)... New id is %s. " - "Code may not function properly.", - request_id, - protocol_id, - new_id, - ) - request_id = new_id - self._waiting_queue[request_id] = queue - return asyncio.ensure_future(self._wait_response(request_id, queue)) + self._waiting_queue.safe_pop(request_key) + + + def _async_response(self, request_key: RequestKey) -> Any: + future = RoborockFuture() + self._waiting_queue.put(request_key, future) + return asyncio.ensure_future(self._wait_response(request_key, future)) @abstractmethod async def send_message(self, roborock_message: RoborockMessage): diff --git a/roborock/cloud_api.py b/roborock/cloud_api.py index 4387fcbf..92f581fe 100644 --- a/roborock/cloud_api.py +++ b/roborock/cloud_api.py @@ -14,7 +14,7 @@ from .containers import DeviceData, UserData from .exceptions import RoborockException, VacuumError from .protocol import MessageParser, md5hex -from .roborock_future import RoborockFuture +from .roborock_future import RequestKey _LOGGER = logging.getLogger(__name__) CONNECT_REQUEST_ID = 0 @@ -72,19 +72,17 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None: self._mqtt_password = rriot.s self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.k)[16:] self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password) - self._waiting_queue: dict[int, RoborockFuture] = {} self._mutex = Lock() def _mqtt_on_connect(self, *args, **kwargs): _, __, ___, rc, ____ = args - connection_queue = self._waiting_queue.get(CONNECT_REQUEST_ID) + if not (connection_queue := self._waiting_queue.safe_pop(RequestKey(CONNECT_REQUEST_ID), "connect")): + self._logger.info("Received unexpected connect event") + return if rc != mqtt.MQTT_ERR_SUCCESS: message = f"Failed to connect ({mqtt.error_string(rc)})" self._logger.error(message) - if connection_queue: - connection_queue.set_exception(VacuumError(message)) - else: - self._logger.debug("Failed to notify connect future, not in queue") + connection_queue.set_exception(VacuumError(message)) return self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}") topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}" @@ -92,12 +90,10 @@ def _mqtt_on_connect(self, *args, **kwargs): if result != 0: message = f"Failed to subscribe ({mqtt.error_string(rc)})" self._logger.error(message) - if connection_queue: - connection_queue.set_exception(VacuumError(message)) + connection_queue.set_exception(VacuumError(message)) return self._logger.info(f"Subscribed to topic {topic}") - if connection_queue: - connection_queue.set_result(True) + connection_queue.set_result(True) def _mqtt_on_message(self, *args, **kwargs): client, __, msg = args @@ -112,8 +108,7 @@ def _mqtt_on_disconnect(self, *args, **kwargs): try: exc = RoborockException(mqtt.error_string(rc)) if rc != mqtt.MQTT_ERR_SUCCESS else None super().on_connection_lost(exc) - connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID) - if connection_queue: + if connection_queue := self._waiting_queue.safe_pop(RequestKey(DISCONNECT_REQUEST_ID), "disconnect"): connection_queue.set_result(True) except Exception as ex: self._logger.exception(ex) @@ -124,10 +119,11 @@ def is_connected(self) -> bool: def _sync_disconnect(self) -> Any: if not self.is_connected(): + self._logger.debug("Already disconnected from mqtt") return None self._logger.info("Disconnecting from mqtt") - disconnected_future = self._async_response(DISCONNECT_REQUEST_ID) + disconnected_future = self._async_response(RequestKey(DISCONNECT_REQUEST_ID)) rc = self._mqtt_client.disconnect() if rc == mqtt.MQTT_ERR_NO_CONN: @@ -149,7 +145,7 @@ def _sync_connect(self) -> Any: raise RoborockException("Mqtt information was not entered. Cannot connect.") self._logger.debug("Connecting to mqtt") - connected_future = self._async_response(CONNECT_REQUEST_ID) + connected_future = self._async_response(RequestKey(CONNECT_REQUEST_ID)) self._mqtt_client.connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE) self._mqtt_client.maybe_restart_loop() return connected_future diff --git a/roborock/roborock_future.py b/roborock/roborock_future.py index 9563f785..2c8f9861 100644 --- a/roborock/roborock_future.py +++ b/roborock/roborock_future.py @@ -1,16 +1,69 @@ from __future__ import annotations +import logging from asyncio import Future +from dataclasses import dataclass +from threading import Lock from typing import Any import async_timeout from .exceptions import VacuumError +from .roborock_message import RoborockMessageProtocol + +_LOGGER = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class RequestKey: + """A key for a Roborock message request.""" + + request_id: int + protocol: RoborockMessageProtocol | int = 0 + + def __str__(self) -> str: + """Get the key for the request.""" + return f"{self.request_id}-{self.protocol}" + + +class WaitingQueue: + """A threadsafe waiting queue for Roborock messages.""" + + def __init__(self) -> None: + """Initialize the waiting queue.""" + self._lock = Lock() + self._queue: dict[RequestKey, RoborockFuture] = {} + + def put(self, request_key: RequestKey, future: RoborockFuture) -> None: + """Create a future for the given protocol.""" + _LOGGER.debug("Putting request key %s in the queue", request_key) + with self._lock: + if request_key in self._queue: + raise ValueError(f"Request key {request_key} already exists in the queue") + self._queue[request_key] = future + + def safe_pop(self, request_key: RequestKey, label: str | None = None) -> RoborockFuture | None: + """Get the future from the queue if it has not yet been popped, otherwise ignore. + + The label is used for logging when the request key is not found in the queue. + """ + _LOGGER.debug("Popping request key %s (%s) from the queue", request_key, label) + with self._lock: + future = self._queue.pop(request_key, None) + if future is None and label is not None: + _LOGGER.warning("Received message for key %s (%s) not found in the queue", request_key, label) + return future class RoborockFuture: - def __init__(self, protocol: int): - self.protocol = protocol + """A threadsafe asyncio Future for Roborock messages. + + The results may be set from a background thread. The future + must be awaited in an asyncio event loop. + """ + + def __init__(self): + """Initialize the Roborock future.""" self.fut: Future = Future() self.loop = self.fut.get_loop() @@ -28,7 +81,8 @@ def _set_exception(self, exc: VacuumError) -> None: def set_exception(self, exc: VacuumError) -> None: self.loop.call_soon_threadsafe(self._set_exception, exc) - async def async_get(self, timeout: float | int) -> tuple[Any, VacuumError | None]: + async def async_get(self, timeout: float | int) -> Any: + """Get the result from the future or raises an error.""" try: async with async_timeout.timeout(timeout): return await self.fut diff --git a/roborock/version_1_apis/roborock_client_v1.py b/roborock/version_1_apis/roborock_client_v1.py index 2068975b..e82b2196 100644 --- a/roborock/version_1_apis/roborock_client_v1.py +++ b/roborock/version_1_apis/roborock_client_v1.py @@ -47,6 +47,7 @@ WashTowelMode, ) from roborock.protocol import Utils +from roborock.roborock_future import RequestKey from roborock.roborock_message import ( ROBOROCK_DATA_CONSUMABLE_PROTOCOL, ROBOROCK_DATA_STATUS_PROTOCOL, @@ -391,8 +392,8 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None: if data_point_number == "102": data_point_response = json.loads(data_point) request_id = data_point_response.get("id") - queue = self._waiting_queue.get(request_id) - if queue and queue.protocol == protocol: + request_key = RequestKey(request_id, protocol) + if queue := self._waiting_queue.safe_pop(request_key, "v1_rpc"): error = data_point_response.get("error") if error: queue.set_exception( @@ -406,8 +407,6 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None: if isinstance(result, list) and len(result) == 1: result = result[0] queue.set_result(result) - else: - self._logger.debug("Received response for unknown request id %s", request_id) else: try: data_protocol = RoborockDataProtocol(int(data_point_number)) @@ -467,19 +466,15 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None: except ValueError as err: raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err decompressed = Utils.decompress(decrypted) - queue = self._waiting_queue.get(request_id) - if queue: + request_key = RequestKey(request_id, protocol) + if queue := self._waiting_queue.safe_pop(request_key, "v1_map"): if isinstance(decompressed, list): decompressed = decompressed[0] queue.set_result(decompressed) - else: - self._logger.debug("Received response for unknown request id %s", request_id) else: - queue = self._waiting_queue.get(data.seq) - if queue: + request_key = RequestKey(data.seq, protocol) + if queue := self._waiting_queue.safe_pop(request_key, "v1_other"): queue.set_result(data.payload) - else: - self._logger.debug("Received response for unknown request id %s", data.seq) except Exception as ex: self._logger.exception(ex) diff --git a/roborock/version_1_apis/roborock_local_client_v1.py b/roborock/version_1_apis/roborock_local_client_v1.py index 8d11a225..948b69d8 100644 --- a/roborock/version_1_apis/roborock_local_client_v1.py +++ b/roborock/version_1_apis/roborock_local_client_v1.py @@ -5,6 +5,7 @@ from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException from ..exceptions import VacuumError from ..protocol import MessageParser +from ..roborock_future import RequestKey from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol from ..util import RoborockLoggerAdapter from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1 @@ -54,15 +55,19 @@ async def send_message(self, roborock_message: RoborockMessage): response_protocol = request_id + 1 else: request_id = roborock_message.get_request_id() + _LOGGER.debug("Getting next request id: %s", request_id) response_protocol = RoborockMessageProtocol.GENERAL_REQUEST if request_id is None: raise RoborockException(f"Failed build message {roborock_message}") local_key = self.device_info.device.local_key msg = MessageParser.build(roborock_message, local_key=local_key) + request_key = RequestKey(request_id, response_protocol) if method: - self._logger.debug(f"id={request_id} Requesting method {method} with {params}") + self._logger.debug(f"id={request_key} Requesting method {method} with {params}") + else: + self._logger.debug(f"id={request_key} Requesting with {params}") # Send the command to the Roborock device - async_response = self._async_response(request_id, response_protocol) + async_response = self._async_response(request_key) self._send_msg_raw(msg) diagnostic_key = method if method is not None else "unknown" try: diff --git a/roborock/version_1_apis/roborock_mqtt_client_v1.py b/roborock/version_1_apis/roborock_mqtt_client_v1.py index d4d074e2..8c201784 100644 --- a/roborock/version_1_apis/roborock_mqtt_client_v1.py +++ b/roborock/version_1_apis/roborock_mqtt_client_v1.py @@ -11,6 +11,7 @@ from ..containers import DeviceData, UserData from ..exceptions import CommandVacuumError, RoborockException, VacuumError from ..protocol import MessageParser, Utils +from ..roborock_future import RequestKey from ..roborock_message import ( RoborockMessage, RoborockMessageProtocol, @@ -47,11 +48,11 @@ async def send_message(self, roborock_message: RoborockMessage): response_protocol = ( RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE ) - + request_key = RequestKey(request_id, response_protocol) local_key = self.device_info.device.local_key msg = MessageParser.build(roborock_message, local_key, False) - self._logger.debug(f"id={request_id} Requesting method {method} with {params}") - async_response = self._async_response(request_id, response_protocol) + self._logger.debug(f"id={request_key} Requesting method {method} with {params}") + async_response = self._async_response(request_key) self._send_msg_raw(msg) diagnostic_key = method if method is not None else "unknown" try: @@ -67,9 +68,9 @@ async def send_message(self, roborock_message: RoborockMessage): "response": response, } if response_protocol == RoborockMessageProtocol.MAP_RESPONSE: - self._logger.debug(f"id={request_id} Response from {method}: {len(response)} bytes") + self._logger.debug(f"id={request_key} Response from {method}: {len(response)} bytes") else: - self._logger.debug(f"id={request_id} Response from {method}: {response}") + self._logger.debug(f"id={request_key} Response from {method}: {response}") return response async def _send_command( diff --git a/roborock/version_a01_apis/roborock_client_a01.py b/roborock/version_a01_apis/roborock_client_a01.py index b736c348..519ae2d4 100644 --- a/roborock/version_a01_apis/roborock_client_a01.py +++ b/roborock/version_a01_apis/roborock_client_a01.py @@ -33,6 +33,7 @@ ZeoTemperature, ) from roborock.containers import DyadProductInfo, DyadSndState, RoborockCategory +from roborock.roborock_future import RequestKey from roborock.roborock_message import ( RoborockDyadDataProtocol, RoborockMessage, @@ -142,9 +143,9 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None: if data_point_protocol in entries: # Auto convert into data struct we want. converted_response = entries[data_point_protocol].post_process_fn(data_point) - queue = self._waiting_queue.get(int(data_point_number)) - if queue and queue.protocol == protocol: - queue.set_result(converted_response) + request_key = RequestKey(int(data_point_number), protocol) + if future := self._waiting_queue.safe_pop(request_key, "a01"): + future.set_result(converted_response) @abstractmethod async def update_values( diff --git a/roborock/version_a01_apis/roborock_mqtt_client_a01.py b/roborock/version_a01_apis/roborock_mqtt_client_a01.py index f19daf44..93ec2705 100644 --- a/roborock/version_a01_apis/roborock_mqtt_client_a01.py +++ b/roborock/version_a01_apis/roborock_mqtt_client_a01.py @@ -10,6 +10,7 @@ from roborock.containers import DeviceData, RoborockCategory, UserData from roborock.exceptions import RoborockException from roborock.protocol import MessageParser +from roborock.roborock_future import RequestKey from roborock.roborock_message import ( RoborockDyadDataProtocol, RoborockMessage, @@ -50,7 +51,7 @@ async def send_message(self, roborock_message: RoborockMessage): futures = [] if "10000" in payload["dps"]: for dps in json.loads(payload["dps"]["10000"]): - futures.append(self._async_response(dps, response_protocol)) + futures.append(self._async_response(RequestKey(dps, response_protocol))) self._send_msg_raw(m) responses = await asyncio.gather(*futures, return_exceptions=True) dps_responses: dict[int, typing.Any] = {} diff --git a/tests/test_queue.py b/tests/test_queue.py index 84767477..b835c4ca 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -1,32 +1,123 @@ import asyncio +import logging import pytest from roborock.exceptions import VacuumError -from roborock.roborock_future import RoborockFuture +from roborock.roborock_future import RequestKey, RoborockFuture, WaitingQueue +from roborock.roborock_message import RoborockMessageProtocol +_LOGGER = logging.getLogger(__name__) +TIMEOUT = 5 -def test_can_create(): - RoborockFuture(1) +async def test_can_create() -> None: + RoborockFuture() -@pytest.mark.asyncio -async def test_set_result(): - rq = RoborockFuture(1) + +async def test_set_result() -> None: + rq = RoborockFuture() rq.set_result("test") assert await rq.async_get(1) == "test" -@pytest.mark.asyncio -async def test_set_exception(): - rq = RoborockFuture(1) +async def test_set_exception() -> None: + rq = RoborockFuture() rq.set_exception(VacuumError("test")) with pytest.raises(VacuumError): assert await rq.async_get(1) -@pytest.mark.asyncio -async def test_get_timeout(): - rq = RoborockFuture(1) +async def test_get_timeout() -> None: + rq = RoborockFuture() with pytest.raises(asyncio.TimeoutError): await rq.async_get(0.01) + + +@pytest.mark.parametrize( + "key", + [ + RequestKey(1), + RequestKey(1, RoborockMessageProtocol.RPC_RESPONSE), + ], +) +async def test_queue_result_in_thread(key: RequestKey) -> None: + queue = WaitingQueue() + + future = RoborockFuture() + queue.put(key, future) + + def set_result_in_thread(): + fut = queue.safe_pop(key) + assert fut is not None + fut.set_result("value1") + + loop = asyncio.get_event_loop() + task = loop.run_in_executor(None, set_result_in_thread) + await task + + assert await future.async_get(TIMEOUT) == "value1" + + +@pytest.mark.parametrize( + "key", + [ + RequestKey(1), + RequestKey(1, RoborockMessageProtocol.RPC_RESPONSE), + ], +) +async def test_queue_set_exception_in_thread(key: RequestKey) -> None: + queue = WaitingQueue() + + future = RoborockFuture() + queue.put(key, future) + + def set_result_in_thread(): + fut = queue.safe_pop(key) + assert fut is not None + fut.set_exception(VacuumError("value1")) + + loop = asyncio.get_event_loop() + task = loop.run_in_executor(None, set_result_in_thread) + await task + + with pytest.raises(VacuumError, match="value1"): + await future.async_get(TIMEOUT) + + +@pytest.mark.parametrize( + "key", + [ + RequestKey(1), + RequestKey(1, RoborockMessageProtocol.RPC_RESPONSE), + ], +) +async def test_queue_item_not_found(key: RequestKey) -> None: + queue = WaitingQueue() + assert queue.safe_pop(key) is None + + +@pytest.mark.parametrize( + "key", + [ + RequestKey(1), + RequestKey(1, RoborockMessageProtocol.RPC_RESPONSE), + ], +) +async def test_queue_duplicate_item_fails(key: RequestKey) -> None: + queue = WaitingQueue() + future1 = RoborockFuture() + queue.put(key, future1) + + future2 = RoborockFuture() + with pytest.raises(ValueError): + queue.put(key, future2) + + +async def test_unique_protocol() -> None: + queue = WaitingQueue() + future1 = RoborockFuture() + queue.put(RequestKey(1, RoborockMessageProtocol.RPC_RESPONSE), future1) + + future2 = RoborockFuture() + queue.put(RequestKey(1, RoborockMessageProtocol.GENERAL_RESPONSE), future2)