Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions roborock/protocols/v1_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,18 @@ class MapResponse:
"""The map data, decrypted and decompressed."""


def create_map_response_decoder(security_data: SecurityData) -> Callable[[RoborockMessage], MapResponse]:
def create_map_response_decoder(security_data: SecurityData) -> Callable[[RoborockMessage], MapResponse | None]:
"""Create a decoder for V1 map response messages."""

def _decode_map_response(message: RoborockMessage) -> MapResponse:
def _decode_map_response(message: RoborockMessage) -> MapResponse | None:
"""Decode a V1 map response message."""
if not message.payload or len(message.payload) < 24:
raise RoborockException("Invalid V1 map response format: missing payload")
header, body = message.payload[:24], message.payload[24:]
[endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", header)
if not endpoint.decode().startswith(security_data.endpoint):
raise RoborockException(
f"Invalid V1 map response endpoint: {endpoint!r}, expected {security_data.endpoint!r}"
)
_LOGGER.debug("Received map response requested not made by this device, ignoring.")
return None
try:
decrypted = Utils.decrypt_cbc(body, security_data.nonce)
except ValueError as err:
Expand Down
17 changes: 9 additions & 8 deletions roborock/version_1_apis/roborock_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class RoborockClientV1(RoborockClient, ABC):
"""Roborock client base class for version 1 devices."""

_listeners: dict[str, ListenerModel] = {}
_map_response_decoder: Callable[[RoborockMessage], MapResponse] | None = None
_map_response_decoder: Callable[[RoborockMessage], MapResponse | None] | None = None

def __init__(self, device_info: DeviceData, security_data: SecurityData | None) -> None:
"""Initializes the Roborock client."""
Expand Down Expand Up @@ -439,13 +439,14 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
elif data.payload and protocol == RoborockMessageProtocol.MAP_RESPONSE:
if self._map_response_decoder is not None:
map_response = self._map_response_decoder(data)
queue = self._waiting_queue.get(map_response.request_id)
if queue:
queue.set_result(map_response.data)
else:
self._logger.debug(
"Received unsolicited map response for request_id %s", map_response.request_id
)
if map_response is not None:
queue = self._waiting_queue.get(map_response.request_id)
if queue:
queue.set_result(map_response.data)
else:
self._logger.debug(
"Received unsolicited map response for request_id %s", map_response.request_id
)
else:
queue = self._waiting_queue.get(data.seq)
if queue:
Expand Down
11 changes: 6 additions & 5 deletions tests/protocols/test_v1_protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the v1 protocol message encoding and decoding."""

import json
import logging
import pathlib
from collections.abc import Generator
from unittest.mock import patch
Expand Down Expand Up @@ -183,13 +184,14 @@ def test_create_map_response_decoder():

decoder = create_map_response_decoder(SECURITY_DATA)
result = decoder(message)

assert result is not None
assert result.request_id == 44508
assert result.data == test_data


def test_create_map_response_decoder_invalid_endpoint():
def test_create_map_response_decoder_invalid_endpoint(caplog: pytest.LogCaptureFixture):
"""Test map response decoder with invalid endpoint."""
caplog.set_level(logging.DEBUG)
# Create header with wrong endpoint
header = b"wrongend" + b"\x00" * 8 + b"\xdc\xad" + b"\x00" * 6
payload = header + b"encrypted_data"
Expand All @@ -204,9 +206,8 @@ def test_create_map_response_decoder_invalid_endpoint():
)

decoder = create_map_response_decoder(SECURITY_DATA)

with pytest.raises(RoborockException, match="Invalid V1 map response endpoint"):
decoder(message)
assert decoder(message) is None
assert "Received map response requested not made by this device, ignoring." in caplog.text


def test_create_map_response_decoder_invalid_payload():
Expand Down