Skip to content

Commit d514ae0

Browse files
Merge branch 'main' into crypto_bump
2 parents 1931073 + 311af16 commit d514ae0

File tree

5 files changed

+158
-150
lines changed

5 files changed

+158
-150
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ roborock = "roborock.cli:main"
1414
python = "^3.8"
1515
click = ">=8"
1616
aiohttp = "*"
17+
async-timeout = "*"
1718
pycryptodome = "~3.17.0"
1819
pycryptodomex = {version = "~3.17.0", markers = "sys_platform == 'darwin'"}
1920
paho-mqtt = "~1.6.1"

roborock/api.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import time
1717
from asyncio import Lock
1818
from asyncio.exceptions import TimeoutError, CancelledError
19+
from typing import Any
1920
from urllib.parse import urlparse
2021

2122
import aiohttp
@@ -53,31 +54,31 @@
5354
MQTT_KEEPALIVE = 60
5455

5556

56-
def md5hex(message: str):
57+
def md5hex(message: str) -> str:
5758
md5 = hashlib.md5()
5859
md5.update(message.encode())
5960
return md5.hexdigest()
6061

6162

62-
def md5bin(message: str):
63+
def md5bin(message: str) -> bytes:
6364
md5 = hashlib.md5()
6465
md5.update(message.encode())
6566
return md5.digest()
6667

6768

68-
def encode_timestamp(_timestamp: int):
69+
def encode_timestamp(_timestamp: int) -> str:
6970
hex_value = f"{_timestamp:x}".zfill(8)
7071
return "".join(list(map(lambda idx: hex_value[idx], [5, 6, 3, 7, 1, 2, 0, 4])))
7172

7273

7374
class PreparedRequest:
74-
def __init__(self, base_url: str, base_headers: dict = None):
75+
def __init__(self, base_url: str, base_headers: dict = None) -> None:
7576
self.base_url = base_url
7677
self.base_headers = base_headers or {}
7778

7879
async def request(
7980
self, method: str, url: str, params=None, data=None, headers=None
80-
):
81+
) -> dict | list:
8182
_url = "/".join(s.strip("/") for s in [self.base_url, url])
8283
_headers = {**self.base_headers, **(headers or {})}
8384
async with aiohttp.ClientSession() as session:
@@ -99,7 +100,7 @@ async def request(
99100
class RoborockMqttClient(mqtt.Client):
100101
_thread: threading.Thread
101102

102-
def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo]):
103+
def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo]) -> None:
103104
rriot = user_data.rriot
104105
self._mqtt_user = rriot.user
105106
self._mqtt_domain = rriot.domain
@@ -126,11 +127,11 @@ def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo
126127
self._last_device_msg_in = mqtt.time_func()
127128
self._last_disconnection = mqtt.time_func()
128129

129-
def __del__(self):
130+
def __del__(self) -> None:
130131
self.sync_disconnect()
131132

132133
@run_in_executor()
133-
async def on_connect(self, _client, _, __, rc, ___=None):
134+
async def on_connect(self, _client, _, __, rc, ___=None) -> None:
134135
connection_queue = self._waiting_queue.get(0)
135136
if rc != mqtt.MQTT_ERR_SUCCESS:
136137
message = f"Failed to connect (rc: {rc})"
@@ -156,7 +157,7 @@ async def on_connect(self, _client, _, __, rc, ___=None):
156157
await connection_queue.async_put((True, None), timeout=QUEUE_TIMEOUT)
157158

158159
@run_in_executor()
159-
async def on_message(self, _client, _, msg, __=None):
160+
async def on_message(self, _client, _, msg, __=None) -> None:
160161
try:
161162
async with self._mutex:
162163
self._last_device_msg_in = mqtt.time_func()
@@ -219,7 +220,7 @@ async def on_message(self, _client, _, msg, __=None):
219220
_LOGGER.exception(ex)
220221

221222
@run_in_executor()
222-
async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None):
223+
async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None) -> None:
223224
try:
224225
async with self._mutex:
225226
self._last_disconnection = mqtt.time_func()
@@ -241,28 +242,28 @@ async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None):
241242
_LOGGER.exception(ex)
242243

243244
@run_in_executor()
244-
async def _async_check_keepalive(self):
245+
async def _async_check_keepalive(self) -> None:
245246
async with self._mutex:
246247
now = mqtt.time_func()
247248
if now - self._last_disconnection > self._keepalive ** 2 and now - self._last_device_msg_in > self._keepalive:
248249
self._ping_t = self._last_device_msg_in
249250

250-
def _check_keepalive(self):
251+
def _check_keepalive(self) -> None:
251252
self._async_check_keepalive()
252253
super()._check_keepalive()
253254

254-
def sync_stop_loop(self):
255+
def sync_stop_loop(self) -> None:
255256
if self._thread:
256257
_LOGGER.info("Stopping mqtt loop")
257258
super().loop_stop()
258259

259-
def sync_start_loop(self):
260+
def sync_start_loop(self) -> None:
260261
if not self._thread or not self._thread.is_alive():
261262
self.sync_stop_loop()
262263
_LOGGER.info("Starting mqtt loop")
263264
super().loop_start()
264265

265-
def sync_disconnect(self):
266+
def sync_disconnect(self) -> bool:
266267
rc = mqtt.MQTT_ERR_AGAIN
267268
if self.is_connected():
268269
_LOGGER.info("Disconnecting from mqtt")
@@ -271,7 +272,7 @@ def sync_disconnect(self):
271272
raise RoborockException(f"Failed to disconnect (rc:{rc})")
272273
return rc == mqtt.MQTT_ERR_SUCCESS
273274

274-
def sync_connect(self):
275+
def sync_connect(self) -> bool:
275276
rc = mqtt.MQTT_ERR_AGAIN
276277
self.sync_start_loop()
277278
if not self.is_connected():
@@ -285,7 +286,7 @@ def sync_connect(self):
285286
raise RoborockException(f"Failed to connect (rc:{rc})")
286287
return rc == mqtt.MQTT_ERR_SUCCESS
287288

288-
async def _async_response(self, request_id: int, protocol_id: int = 0):
289+
async def _async_response(self, request_id: int, protocol_id: int = 0) -> tuple[Any, RoborockException | None]:
289290
try:
290291
queue = RoborockQueue(protocol_id)
291292
self._waiting_queue[request_id] = queue
@@ -298,7 +299,7 @@ async def _async_response(self, request_id: int, protocol_id: int = 0):
298299
finally:
299300
del self._waiting_queue[request_id]
300301

301-
async def async_disconnect(self):
302+
async def async_disconnect(self) -> Any:
302303
async with self._mutex:
303304
disconnecting = self.sync_disconnect()
304305
if disconnecting:
@@ -307,7 +308,7 @@ async def async_disconnect(self):
307308
raise RoborockException(err) from err
308309
return response
309310

310-
async def async_connect(self):
311+
async def async_connect(self) -> Any:
311312
async with self._mutex:
312313
connecting = self.sync_connect()
313314
if connecting:
@@ -316,10 +317,10 @@ async def async_connect(self):
316317
raise RoborockException(err) from err
317318
return response
318319

319-
async def validate_connection(self):
320+
async def validate_connection(self) -> None:
320321
await self.async_connect()
321322

322-
def _decode_msg(self, msg, device: HomeDataDevice):
323+
def _decode_msg(self, msg, device: HomeDataDevice) -> dict[str, Any]:
323324
if msg[0:3] != "1.0".encode():
324325
raise RoborockException("Unknown protocol version")
325326
crc32 = binascii.crc32(msg[0: len(msg) - 4])
@@ -344,7 +345,7 @@ def _decode_msg(self, msg, device: HomeDataDevice):
344345
"payload": decrypted_payload,
345346
}
346347

347-
def _send_msg_raw(self, device_id, protocol, timestamp, payload):
348+
def _send_msg_raw(self, device_id, protocol, timestamp, payload) -> None:
348349
local_key = self.device_map[device_id].device.local_key
349350
aes_key = md5bin(encode_timestamp(timestamp) + local_key + self._salt)
350351
cipher = AES.new(aes_key, AES.MODE_ECB)
@@ -438,7 +439,7 @@ async def get_consumable(self, device_id: str) -> Consumable:
438439
if isinstance(consumable, dict):
439440
return Consumable(consumable)
440441

441-
async def get_prop(self, device_id: str):
442+
async def get_prop(self, device_id: str) -> RoborockDeviceProp:
442443
[status, dnd_timer, clean_summary, consumable] = await asyncio.gather(
443444
*[
444445
self.get_status(device_id),
@@ -457,7 +458,7 @@ async def get_prop(self, device_id: str):
457458
status, dnd_timer, clean_summary, consumable, last_clean_record
458459
)
459460

460-
async def get_multi_maps_list(self, device_id):
461+
async def get_multi_maps_list(self, device_id) -> MultiMapsList:
461462
multi_maps_list = await self.send_command(
462463
device_id, RoborockCommand.GET_MULTI_MAPS_LIST
463464
)
@@ -476,7 +477,7 @@ def __init__(self, username: str, base_url=None) -> None:
476477
self.base_url = base_url
477478
self._device_identifier = secrets.token_urlsafe(16)
478479

479-
async def _get_base_url(self):
480+
async def _get_base_url(self) -> str:
480481
if not self.base_url:
481482
url_request = PreparedRequest(self._default_url)
482483
response = await url_request.request(
@@ -495,7 +496,7 @@ def _get_header_client_id(self):
495496
md5.update(self._device_identifier.encode())
496497
return base64.b64encode(md5.digest()).decode()
497498

498-
async def request_code(self):
499+
async def request_code(self) -> None:
499500
base_url = await self._get_base_url()
500501
header_clientid = self._get_header_client_id()
501502
code_request = PreparedRequest(base_url, {"header_clientid": header_clientid})
@@ -512,7 +513,7 @@ async def request_code(self):
512513
if code_response.get("code") != 200:
513514
raise RoborockException(code_response.get("msg"))
514515

515-
async def pass_login(self, password: str):
516+
async def pass_login(self, password: str) -> UserData:
516517
base_url = await self._get_base_url()
517518
header_clientid = self._get_header_client_id()
518519

@@ -531,7 +532,7 @@ async def pass_login(self, password: str):
531532
raise RoborockException(login_response.get("msg"))
532533
return UserData(login_response.get("data"))
533534

534-
async def code_login(self, code):
535+
async def code_login(self, code) -> UserData:
535536
base_url = await self._get_base_url()
536537
header_clientid = self._get_header_client_id()
537538

@@ -550,7 +551,7 @@ async def code_login(self, code):
550551
raise RoborockException(login_response.get("msg"))
551552
return UserData(login_response.get("data"))
552553

553-
async def get_home_data(self, user_data: UserData):
554+
async def get_home_data(self, user_data: UserData) -> HomeData:
554555
base_url = await self._get_base_url()
555556
header_clientid = self._get_header_client_id()
556557
rriot = user_data.rriot

0 commit comments

Comments
 (0)