From 7705cdabdaf73f22b2d4e8445c20fc7fbff3ca1f Mon Sep 17 00:00:00 2001 From: Luke Date: Tue, 2 Sep 2025 18:42:00 -0400 Subject: [PATCH 1/2] chore: move broadcast_protocol to t's own file --- roborock/broadcast_protocol.py | 71 ++++++++++++++++++++++++++++++++++ roborock/protocol.py | 52 +------------------------ 2 files changed, 72 insertions(+), 51 deletions(-) create mode 100644 roborock/broadcast_protocol.py diff --git a/roborock/broadcast_protocol.py b/roborock/broadcast_protocol.py new file mode 100644 index 00000000..02a56438 --- /dev/null +++ b/roborock/broadcast_protocol.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import asyncio +import json +import logging +from asyncio import BaseTransport, Lock + +from construct import ( # type: ignore + Bytes, + Checksum, + Int16ub, + Int32ub, + RawCopy, + Struct, +) + +from roborock.containers import BroadcastMessage +from roborock.protocol import EncryptionAdapter, Utils, _Parser + +_LOGGER = logging.getLogger(__name__) + +BROADCAST_TOKEN = b"qWKYcdQWrbm9hPqe" + + +class RoborockProtocol(asyncio.DatagramProtocol): + def __init__(self, timeout: int = 5): + self.timeout = timeout + self.transport: BaseTransport | None = None + self.devices_found: list[BroadcastMessage] = [] + self._mutex = Lock() + + def __del__(self): + self.close() + + def datagram_received(self, data, _): + [broadcast_message], _ = BroadcastParser.parse(data) + if broadcast_message.payload: + parsed_message = BroadcastMessage.from_dict(json.loads(broadcast_message.payload)) + _LOGGER.debug(f"Received broadcast: {parsed_message}") + self.devices_found.append(parsed_message) + + async def discover(self): + async with self._mutex: + try: + loop = asyncio.get_event_loop() + self.transport, _ = await loop.create_datagram_endpoint(lambda: self, local_addr=("0.0.0.0", 58866)) + await asyncio.sleep(self.timeout) + return self.devices_found + finally: + self.close() + self.devices_found = [] + + def close(self): + self.transport.close() if self.transport else None + + +_BroadcastMessage = Struct( + "message" + / RawCopy( + Struct( + "version" / Bytes(3), + "seq" / Int32ub, + "protocol" / Int16ub, + "payload" / EncryptionAdapter(lambda ctx: BROADCAST_TOKEN), + ) + ), + "checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data), +) + + +BroadcastParser: _Parser = _Parser(_BroadcastMessage, False) diff --git a/roborock/protocol.py b/roborock/protocol.py index fb4f11ef..08b04cca 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -1,12 +1,9 @@ from __future__ import annotations -import asyncio import binascii import gzip import hashlib -import json import logging -from asyncio import BaseTransport, Lock from collections.abc import Callable from urllib.parse import urlparse @@ -31,7 +28,7 @@ from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad -from roborock.containers import BroadcastMessage, RRiot +from roborock.containers import RRiot from roborock.exceptions import RoborockException from roborock.mqtt.session import MqttParams from roborock.roborock_message import RoborockMessage @@ -40,7 +37,6 @@ SALT = b"TXdfu$jyZ#TZHsg4" A01_HASH = "726f626f726f636b2d67a6d6da" B01_HASH = "5wwh9ikChRjASpMU8cxg7o1d2E" -BROADCAST_TOKEN = b"qWKYcdQWrbm9hPqe" AP_CONFIG = 1 SOCK_DISCOVERY = 2 @@ -51,38 +47,6 @@ def md5hex(message: str) -> str: return md5.hexdigest() -class RoborockProtocol(asyncio.DatagramProtocol): - def __init__(self, timeout: int = 5): - self.timeout = timeout - self.transport: BaseTransport | None = None - self.devices_found: list[BroadcastMessage] = [] - self._mutex = Lock() - - def __del__(self): - self.close() - - def datagram_received(self, data, _): - [broadcast_message], _ = BroadcastParser.parse(data) - if broadcast_message.payload: - parsed_message = BroadcastMessage.from_dict(json.loads(broadcast_message.payload)) - _LOGGER.debug(f"Received broadcast: {parsed_message}") - self.devices_found.append(parsed_message) - - async def discover(self): - async with self._mutex: - try: - loop = asyncio.get_event_loop() - self.transport, _ = await loop.create_datagram_endpoint(lambda: self, local_addr=("0.0.0.0", 58866)) - await asyncio.sleep(self.timeout) - return self.devices_found - finally: - self.close() - self.devices_found = [] - - def close(self): - self.transport.close() if self.transport else None - - class Utils: """Util class for protocol manipulation.""" @@ -324,19 +288,6 @@ def _build(self, obj, stream, context, path): "remaining" / Optional(GreedyBytes), ) -_BroadcastMessage = Struct( - "message" - / RawCopy( - Struct( - "version" / Bytes(3), - "seq" / Int32ub, - "protocol" / Int16ub, - "payload" / EncryptionAdapter(lambda ctx: BROADCAST_TOKEN), - ) - ), - "checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data), -) - class _Parser: def __init__(self, con: Construct, required_local_key: bool): @@ -390,7 +341,6 @@ def build( MessageParser: _Parser = _Parser(_Messages, True) -BroadcastParser: _Parser = _Parser(_BroadcastMessage, False) def create_mqtt_params(rriot: RRiot) -> MqttParams: From b2b5f5e8a96633f3ac0c2e6337b5a00a0a88bb25 Mon Sep 17 00:00:00 2001 From: Luke Lashley Date: Sun, 7 Sep 2025 11:45:37 -0400 Subject: [PATCH 2/2] fix: remove __del__ Co-authored-by: Allen Porter --- roborock/broadcast_protocol.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/roborock/broadcast_protocol.py b/roborock/broadcast_protocol.py index 02a56438..93b5b0a7 100644 --- a/roborock/broadcast_protocol.py +++ b/roborock/broadcast_protocol.py @@ -29,9 +29,6 @@ def __init__(self, timeout: int = 5): self.devices_found: list[BroadcastMessage] = [] self._mutex = Lock() - def __del__(self): - self.close() - def datagram_received(self, data, _): [broadcast_message], _ = BroadcastParser.parse(data) if broadcast_message.payload: