Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 30 additions & 29 deletions roborock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import time
from asyncio import Lock
from asyncio.exceptions import TimeoutError, CancelledError
from typing import Any
from urllib.parse import urlparse

import aiohttp
Expand Down Expand Up @@ -53,31 +54,31 @@
MQTT_KEEPALIVE = 60


def md5hex(message: str):
def md5hex(message: str) -> str:
md5 = hashlib.md5()
md5.update(message.encode())
return md5.hexdigest()


def md5bin(message: str):
def md5bin(message: str) -> bytes:
md5 = hashlib.md5()
md5.update(message.encode())
return md5.digest()


def encode_timestamp(_timestamp: int):
def encode_timestamp(_timestamp: int) -> str:
hex_value = f"{_timestamp:x}".zfill(8)
return "".join(list(map(lambda idx: hex_value[idx], [5, 6, 3, 7, 1, 2, 0, 4])))


class PreparedRequest:
def __init__(self, base_url: str, base_headers: dict = None):
def __init__(self, base_url: str, base_headers: dict = None) -> None:
self.base_url = base_url
self.base_headers = base_headers or {}

async def request(
self, method: str, url: str, params=None, data=None, headers=None
):
) -> dict | list:
_url = "/".join(s.strip("/") for s in [self.base_url, url])
_headers = {**self.base_headers, **(headers or {})}
async with aiohttp.ClientSession() as session:
Expand All @@ -99,7 +100,7 @@ async def request(
class RoborockMqttClient(mqtt.Client):
_thread: threading.Thread

def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo]):
def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo]) -> None:
rriot = user_data.rriot
self._mqtt_user = rriot.user
self._mqtt_domain = rriot.domain
Expand All @@ -126,11 +127,11 @@ def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo
self._last_device_msg_in = mqtt.time_func()
self._last_disconnection = mqtt.time_func()

def __del__(self):
def __del__(self) -> None:
self.sync_disconnect()

@run_in_executor()
async def on_connect(self, _client, _, __, rc, ___=None):
async def on_connect(self, _client, _, __, rc, ___=None) -> None:
connection_queue = self._waiting_queue.get(0)
if rc != mqtt.MQTT_ERR_SUCCESS:
message = f"Failed to connect (rc: {rc})"
Expand All @@ -156,7 +157,7 @@ async def on_connect(self, _client, _, __, rc, ___=None):
await connection_queue.async_put((True, None), timeout=QUEUE_TIMEOUT)

@run_in_executor()
async def on_message(self, _client, _, msg, __=None):
async def on_message(self, _client, _, msg, __=None) -> None:
try:
async with self._mutex:
self._last_device_msg_in = mqtt.time_func()
Expand Down Expand Up @@ -219,7 +220,7 @@ async def on_message(self, _client, _, msg, __=None):
_LOGGER.exception(ex)

@run_in_executor()
async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None):
async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None) -> None:
try:
async with self._mutex:
self._last_disconnection = mqtt.time_func()
Expand All @@ -241,28 +242,28 @@ async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None):
_LOGGER.exception(ex)

@run_in_executor()
async def _async_check_keepalive(self):
async def _async_check_keepalive(self) -> None:
async with self._mutex:
now = mqtt.time_func()
if now - self._last_disconnection > self._keepalive ** 2 and now - self._last_device_msg_in > self._keepalive:
self._ping_t = self._last_device_msg_in

def _check_keepalive(self):
def _check_keepalive(self) -> None:
self._async_check_keepalive()
super()._check_keepalive()

def sync_stop_loop(self):
def sync_stop_loop(self) -> None:
if self._thread:
_LOGGER.info("Stopping mqtt loop")
super().loop_stop()

def sync_start_loop(self):
def sync_start_loop(self) -> None:
if not self._thread or not self._thread.is_alive():
self.sync_stop_loop()
_LOGGER.info("Starting mqtt loop")
super().loop_start()

def sync_disconnect(self):
def sync_disconnect(self) -> bool:
rc = mqtt.MQTT_ERR_AGAIN
if self.is_connected():
_LOGGER.info("Disconnecting from mqtt")
Expand All @@ -271,7 +272,7 @@ def sync_disconnect(self):
raise RoborockException(f"Failed to disconnect (rc:{rc})")
return rc == mqtt.MQTT_ERR_SUCCESS

def sync_connect(self):
def sync_connect(self) -> bool:
rc = mqtt.MQTT_ERR_AGAIN
self.sync_start_loop()
if not self.is_connected():
Expand All @@ -285,7 +286,7 @@ def sync_connect(self):
raise RoborockException(f"Failed to connect (rc:{rc})")
return rc == mqtt.MQTT_ERR_SUCCESS

async def _async_response(self, request_id: int, protocol_id: int = 0):
async def _async_response(self, request_id: int, protocol_id: int = 0) -> tuple[Any, RoborockException | None]:
try:
queue = RoborockQueue(protocol_id)
self._waiting_queue[request_id] = queue
Expand All @@ -298,7 +299,7 @@ async def _async_response(self, request_id: int, protocol_id: int = 0):
finally:
del self._waiting_queue[request_id]

async def async_disconnect(self):
async def async_disconnect(self) -> Any:
async with self._mutex:
disconnecting = self.sync_disconnect()
if disconnecting:
Expand All @@ -307,7 +308,7 @@ async def async_disconnect(self):
raise RoborockException(err) from err
return response

async def async_connect(self):
async def async_connect(self) -> Any:
async with self._mutex:
connecting = self.sync_connect()
if connecting:
Expand All @@ -316,10 +317,10 @@ async def async_connect(self):
raise RoborockException(err) from err
return response

async def validate_connection(self):
async def validate_connection(self) -> None:
await self.async_connect()

def _decode_msg(self, msg, device: HomeDataDevice):
def _decode_msg(self, msg, device: HomeDataDevice) -> dict[str, Any]:
if msg[0:3] != "1.0".encode():
raise RoborockException("Unknown protocol version")
crc32 = binascii.crc32(msg[0: len(msg) - 4])
Expand All @@ -344,7 +345,7 @@ def _decode_msg(self, msg, device: HomeDataDevice):
"payload": decrypted_payload,
}

def _send_msg_raw(self, device_id, protocol, timestamp, payload):
def _send_msg_raw(self, device_id, protocol, timestamp, payload) -> None:
local_key = self.device_map[device_id].device.local_key
aes_key = md5bin(encode_timestamp(timestamp) + local_key + self._salt)
cipher = AES.new(aes_key, AES.MODE_ECB)
Expand Down Expand Up @@ -438,7 +439,7 @@ async def get_consumable(self, device_id: str) -> Consumable:
if isinstance(consumable, dict):
return Consumable(consumable)

async def get_prop(self, device_id: str):
async def get_prop(self, device_id: str) -> RoborockDeviceProp:
[status, dnd_timer, clean_summary, consumable] = await asyncio.gather(
*[
self.get_status(device_id),
Expand All @@ -457,7 +458,7 @@ async def get_prop(self, device_id: str):
status, dnd_timer, clean_summary, consumable, last_clean_record
)

async def get_multi_maps_list(self, device_id):
async def get_multi_maps_list(self, device_id) -> MultiMapsList:
multi_maps_list = await self.send_command(
device_id, RoborockCommand.GET_MULTI_MAPS_LIST
)
Expand All @@ -476,7 +477,7 @@ def __init__(self, username: str, base_url=None) -> None:
self.base_url = base_url
self._device_identifier = secrets.token_urlsafe(16)

async def _get_base_url(self):
async def _get_base_url(self) -> str:
if not self.base_url:
url_request = PreparedRequest(self._default_url)
response = await url_request.request(
Expand All @@ -495,7 +496,7 @@ def _get_header_client_id(self):
md5.update(self._device_identifier.encode())
return base64.b64encode(md5.digest()).decode()

async def request_code(self):
async def request_code(self) -> None:
base_url = await self._get_base_url()
header_clientid = self._get_header_client_id()
code_request = PreparedRequest(base_url, {"header_clientid": header_clientid})
Expand All @@ -512,7 +513,7 @@ async def request_code(self):
if code_response.get("code") != 200:
raise RoborockException(code_response.get("msg"))

async def pass_login(self, password: str):
async def pass_login(self, password: str) -> UserData:
base_url = await self._get_base_url()
header_clientid = self._get_header_client_id()

Expand All @@ -531,7 +532,7 @@ async def pass_login(self, password: str):
raise RoborockException(login_response.get("msg"))
return UserData(login_response.get("data"))

async def code_login(self, code):
async def code_login(self, code) -> UserData:
base_url = await self._get_base_url()
header_clientid = self._get_header_client_id()

Expand All @@ -550,7 +551,7 @@ async def code_login(self, code):
raise RoborockException(login_response.get("msg"))
return UserData(login_response.get("data"))

async def get_home_data(self, user_data: UserData):
async def get_home_data(self, user_data: UserData) -> HomeData:
base_url = await self._get_base_url()
header_clientid = self._get_header_client_id()
rriot = user_data.rriot
Expand Down
Loading