Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions roborock/broadcast_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

import asyncio
import json
import logging
from asyncio import BaseTransport, Lock

from construct import ( # type: ignore
Bytes,
Checksum,
Int16ub,
Int32ub,
RawCopy,
Struct,
)

from roborock.containers import BroadcastMessage
from roborock.protocol import EncryptionAdapter, Utils, _Parser

_LOGGER = logging.getLogger(__name__)

BROADCAST_TOKEN = b"qWKYcdQWrbm9hPqe"


class RoborockProtocol(asyncio.DatagramProtocol):
def __init__(self, timeout: int = 5):
self.timeout = timeout
self.transport: BaseTransport | None = None
self.devices_found: list[BroadcastMessage] = []
self._mutex = Lock()

def datagram_received(self, data, _):
[broadcast_message], _ = BroadcastParser.parse(data)
if broadcast_message.payload:
parsed_message = BroadcastMessage.from_dict(json.loads(broadcast_message.payload))
_LOGGER.debug(f"Received broadcast: {parsed_message}")
self.devices_found.append(parsed_message)

async def discover(self):
async with self._mutex:
try:
loop = asyncio.get_event_loop()
self.transport, _ = await loop.create_datagram_endpoint(lambda: self, local_addr=("0.0.0.0", 58866))
await asyncio.sleep(self.timeout)
return self.devices_found
finally:
self.close()
self.devices_found = []

def close(self):
self.transport.close() if self.transport else None


_BroadcastMessage = Struct(
"message"
/ RawCopy(
Struct(
"version" / Bytes(3),
"seq" / Int32ub,
"protocol" / Int16ub,
"payload" / EncryptionAdapter(lambda ctx: BROADCAST_TOKEN),
)
),
"checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data),
)


BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)
52 changes: 1 addition & 51 deletions roborock/protocol.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from __future__ import annotations

import asyncio
import binascii
import gzip
import hashlib
import json
import logging
from asyncio import BaseTransport, Lock
from collections.abc import Callable
from urllib.parse import urlparse

Expand All @@ -31,7 +28,7 @@
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad

from roborock.containers import BroadcastMessage, RRiot
from roborock.containers import RRiot
from roborock.exceptions import RoborockException
from roborock.mqtt.session import MqttParams
from roborock.roborock_message import RoborockMessage
Expand All @@ -40,7 +37,6 @@
SALT = b"TXdfu$jyZ#TZHsg4"
A01_HASH = "726f626f726f636b2d67a6d6da"
B01_HASH = "5wwh9ikChRjASpMU8cxg7o1d2E"
BROADCAST_TOKEN = b"qWKYcdQWrbm9hPqe"
AP_CONFIG = 1
SOCK_DISCOVERY = 2

Expand All @@ -51,38 +47,6 @@ def md5hex(message: str) -> str:
return md5.hexdigest()


class RoborockProtocol(asyncio.DatagramProtocol):
def __init__(self, timeout: int = 5):
self.timeout = timeout
self.transport: BaseTransport | None = None
self.devices_found: list[BroadcastMessage] = []
self._mutex = Lock()

def __del__(self):
self.close()

def datagram_received(self, data, _):
[broadcast_message], _ = BroadcastParser.parse(data)
if broadcast_message.payload:
parsed_message = BroadcastMessage.from_dict(json.loads(broadcast_message.payload))
_LOGGER.debug(f"Received broadcast: {parsed_message}")
self.devices_found.append(parsed_message)

async def discover(self):
async with self._mutex:
try:
loop = asyncio.get_event_loop()
self.transport, _ = await loop.create_datagram_endpoint(lambda: self, local_addr=("0.0.0.0", 58866))
await asyncio.sleep(self.timeout)
return self.devices_found
finally:
self.close()
self.devices_found = []

def close(self):
self.transport.close() if self.transport else None


class Utils:
"""Util class for protocol manipulation."""

Expand Down Expand Up @@ -324,19 +288,6 @@ def _build(self, obj, stream, context, path):
"remaining" / Optional(GreedyBytes),
)

_BroadcastMessage = Struct(
"message"
/ RawCopy(
Struct(
"version" / Bytes(3),
"seq" / Int32ub,
"protocol" / Int16ub,
"payload" / EncryptionAdapter(lambda ctx: BROADCAST_TOKEN),
)
),
"checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data),
)


class _Parser:
def __init__(self, con: Construct, required_local_key: bool):
Expand Down Expand Up @@ -390,7 +341,6 @@ def build(


MessageParser: _Parser = _Parser(_Messages, True)
BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)


def create_mqtt_params(rriot: RRiot) -> MqttParams:
Expand Down