Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ jobs:
with:
fetch-depth: 0
- uses: wagoid/commitlint-github-action@v5.3.0
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- uses: pre-commit/action@v3.0.0

test:
strategy:
Expand Down Expand Up @@ -57,8 +65,8 @@ jobs:
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
persist-credentials: false
fetch-depth: 0
persist-credentials: false

# Run semantic release:
# - Update CHANGELOG.md
Expand Down
16 changes: 16 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_stages: [ commit ]


repos:
- repo: https://github.com/python-poetry/poetry
rev: 1.3.2
hooks:
- id: poetry-check
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.931
hooks:
- id: mypy
exclude: cli.py
additional_dependencies: [ "types-paho-mqtt" ]
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[mypy]
check_untyped_defs = True
exclude = cli.py
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry.dev-dependencies]
pytest-asyncio = "*"
pytest = "*"
pre-commit = "*"
mypy = "*"

[tool.semantic_release]
branch = "main"
Expand Down
111 changes: 76 additions & 35 deletions roborock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import struct
import time
from random import randint
from typing import Any, Callable
from typing import Optional, Any, Callable, Coroutine, Mapping

import aiohttp
from Crypto.Cipher import AES
Expand All @@ -23,7 +23,7 @@
from roborock.exceptions import (
RoborockException, RoborockTimeout, VacuumError,
)
from .code_mappings import RoborockDockTypeCode
from .code_mappings import RoborockDockTypeCode, RoborockEnum
from .containers import (
UserData,
Status,
Expand Down Expand Up @@ -63,29 +63,29 @@ def md5hex(message: str) -> str:


class PreparedRequest:
def __init__(self, base_url: str, base_headers: dict = None) -> None:
def __init__(self, base_url: str, base_headers: Optional[dict] = None) -> None:
self.base_url = base_url
self.base_headers = base_headers or {}

async def request(
self, method: str, url: str, params=None, data=None, headers=None
) -> dict | list:
self, method: str, url: str, params=None, data=None, headers=None
) -> dict:
_url = "/".join(s.strip("/") for s in [self.base_url, url])
_headers = {**self.base_headers, **(headers or {})}
async with aiohttp.ClientSession() as session:
async with session.request(
method,
_url,
params=params,
data=data,
headers=_headers,
method,
_url,
params=params,
data=data,
headers=_headers,
) as resp:
return await resp.json()


class RoborockClient:

def __init__(self, endpoint: str, devices_info: dict[str, RoborockDeviceInfo]) -> None:
def __init__(self, endpoint: str, devices_info: Mapping[str, RoborockDeviceInfo]) -> None:
self.devices_info = devices_info
self._endpoint = endpoint
self._nonce = secrets.token_bytes(16)
Expand Down Expand Up @@ -161,7 +161,7 @@ async def _async_response(self, request_id: int, protocol_id: int = 0) -> tuple[
del self._waiting_queue[request_id]

def _get_payload(
self, method: RoborockCommand, params: list = None, secured=False
self, method: RoborockCommand, params: Optional[list] = None, secured=False
):
timestamp = math.floor(time.time())
request_id = randint(10000, 99999)
Expand All @@ -187,24 +187,26 @@ def _get_payload(
return request_id, timestamp, payload

async def send_command(
self, device_id: str, method: RoborockCommand, params: list = None
self, device_id: str, method: RoborockCommand, params: Optional[list] = None
):
raise NotImplementedError

async def get_status(self, device_id: str) -> Status:
async def get_status(self, device_id: str) -> Status | None:
status = await self.send_command(device_id, RoborockCommand.GET_STATUS)
if isinstance(status, dict):
return Status.from_dict(status)
return None

async def get_dnd_timer(self, device_id: str) -> DNDTimer:
async def get_dnd_timer(self, device_id: str) -> DNDTimer | None:
try:
dnd_timer = await self.send_command(device_id, RoborockCommand.GET_DND_TIMER)
if isinstance(dnd_timer, dict):
return DNDTimer.from_dict(dnd_timer)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_clean_summary(self, device_id: str) -> CleanSummary:
async def get_clean_summary(self, device_id: str) -> CleanSummary | None:
try:
clean_summary = await self.send_command(
device_id, RoborockCommand.GET_CLEAN_SUMMARY
Expand All @@ -215,8 +217,9 @@ async def get_clean_summary(self, device_id: str) -> CleanSummary:
return CleanSummary(clean_time=int.from_bytes(clean_summary, 'big'))
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_clean_record(self, device_id: str, record_id: int) -> CleanRecord:
async def get_clean_record(self, device_id: str, record_id: int) -> CleanRecord | None:
try:
clean_record = await self.send_command(
device_id, RoborockCommand.GET_CLEAN_RECORD, [record_id]
Expand All @@ -225,56 +228,68 @@ async def get_clean_record(self, device_id: str, record_id: int) -> CleanRecord:
return CleanRecord.from_dict(clean_record)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_consumable(self, device_id: str) -> Consumable:
async def get_consumable(self, device_id: str) -> Consumable | None:
try:
consumable = await self.send_command(device_id, RoborockCommand.GET_CONSUMABLE)
if isinstance(consumable, dict):
return Consumable.from_dict(consumable)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_wash_towel_mode(self, device_id: str) -> WashTowelMode:
async def get_wash_towel_mode(self, device_id: str) -> WashTowelMode | None:
try:
washing_mode = await self.send_command(device_id, RoborockCommand.GET_WASH_TOWEL_MODE)
if isinstance(washing_mode, dict):
return WashTowelMode.from_dict(washing_mode)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_dust_collection_mode(self, device_id: str) -> DustCollectionMode:
async def get_dust_collection_mode(self, device_id: str) -> DustCollectionMode | None:
try:
dust_collection = await self.send_command(device_id, RoborockCommand.GET_DUST_COLLECTION_MODE)
if isinstance(dust_collection, dict):
return DustCollectionMode.from_dict(dust_collection)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_smart_wash_params(self, device_id: str) -> SmartWashParams:
async def get_smart_wash_params(self, device_id: str) -> SmartWashParams | None:
try:
mop_wash_mode = await self.send_command(device_id, RoborockCommand.GET_SMART_WASH_PARAMS)
if isinstance(mop_wash_mode, dict):
return SmartWashParams.from_dict(mop_wash_mode)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_dock_summary(self, device_id: str, dock_type: RoborockDockTypeCode) -> RoborockDockSummary:
async def get_dock_summary(self, device_id: str, dock_type: RoborockEnum) -> RoborockDockSummary | None:
"""Gets the status summary from the dock with the methods available for a given dock.

:param dock_type: RoborockDockTypeCode"""
if RoborockDockTypeCode.name != "RoborockDockTypeCode":
raise RoborockException("Invalid enum given for dock type")
try:
commands = [self.get_dust_collection_mode(device_id)]
commands: list[Coroutine[Any, Any, DustCollectionMode | WashTowelMode | SmartWashParams | None]] = [
self.get_dust_collection_mode(device_id)]
if dock_type == RoborockDockTypeCode['3']:
commands += [self.get_wash_towel_mode(device_id), self.get_smart_wash_params(device_id)]
[
dust_collection_mode,
wash_towel_mode,
smart_wash_params
] = (
list(await asyncio.gather(*commands))
+ [None, None]
list(await asyncio.gather(*commands))
+ [None, None]
)[:3]

return RoborockDockSummary(dust_collection_mode, wash_towel_mode, smart_wash_params)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_prop(self, device_id: str) -> RoborockDeviceProp | None:
[status, dnd_timer, clean_summary, consumable] = await asyncio.gather(
Expand All @@ -299,7 +314,7 @@ async def get_prop(self, device_id: str) -> RoborockDeviceProp | None:
)
return None

async def get_multi_maps_list(self, device_id) -> MultiMapsList:
async def get_multi_maps_list(self, device_id) -> MultiMapsList | None:
try:
multi_maps_list = await self.send_command(
device_id, RoborockCommand.GET_MULTI_MAPS_LIST
Expand All @@ -308,14 +323,16 @@ async def get_multi_maps_list(self, device_id) -> MultiMapsList:
return MultiMapsList.from_dict(multi_maps_list)
except RoborockTimeout as e:
_LOGGER.error(e)
return None

async def get_networking(self, device_id) -> NetworkInfo:
async def get_networking(self, device_id) -> NetworkInfo | None:
try:
networking_info = await self.send_command(device_id, RoborockCommand.GET_NETWORK_INFO)
if isinstance(networking_info, dict):
return NetworkInfo.from_dict(networking_info)
except RoborockTimeout as e:
_LOGGER.error(e)
return None


class RoborockApiClient:
Expand All @@ -334,9 +351,14 @@ async def _get_base_url(self) -> str:
"/api/v1/getUrlByEmail",
params={"email": self._username, "needtwostepauth": "false"},
)
if response is None:
raise RoborockException("get url by email returned None")
if response.get("code") != 200:
raise RoborockException(response.get("error"))
self.base_url = response.get("data").get("url")
response_data = response.get("data")
if response_data is None:
raise RoborockException("response does not have 'data'")
self.base_url = response_data.get("url")
return self.base_url

def _get_header_client_id(self):
Expand All @@ -358,7 +380,8 @@ async def request_code(self) -> None:
"type": "auth",
},
)

if code_response is None:
raise RoborockException("Failed to get a response from send email code")
if code_response.get("code") != 200:
raise RoborockException(code_response.get("msg"))

Expand All @@ -376,10 +399,14 @@ async def pass_login(self, password: str) -> UserData:
"needtwostepauth": "false",
},
)

if login_response is None:
raise RoborockException("Login response is none")
if login_response.get("code") != 200:
raise RoborockException(login_response.get("msg"))
return UserData.from_dict(login_response.get("data"))
user_data = login_response.get("data")
if not isinstance(user_data, dict):
raise RoborockException("Got unexpected data type for user_data")
return UserData.from_dict(user_data)

async def code_login(self, code) -> UserData:
base_url = await self._get_base_url()
Expand All @@ -395,15 +422,21 @@ async def code_login(self, code) -> UserData:
"verifycodetype": "AUTH_EMAIL_CODE",
},
)

if login_response is None:
raise RoborockException("Login request response is None")
if login_response.get("code") != 200:
raise RoborockException(login_response.get("msg"))
return UserData.from_dict(login_response.get("data"))
user_data = login_response.get("data")
if not isinstance(user_data, dict):
raise RoborockException("Got unexpected data type for user_data")
return UserData.from_dict(user_data)

async def get_home_data(self, user_data: UserData) -> HomeData:
base_url = await self._get_base_url()
header_clientid = self._get_header_client_id()
rriot = user_data.rriot
if rriot is None:
raise RoborockException("rriot is none")
home_id_request = PreparedRequest(
base_url, {"header_clientid": header_clientid}
)
Expand All @@ -412,9 +445,12 @@ async def get_home_data(self, user_data: UserData) -> HomeData:
"/api/v1/getHomeDetail",
headers={"Authorization": user_data.token},
)
if home_id_response is None:
raise RoborockException("home_id_response is None")
if home_id_response.get("code") != 200:
raise RoborockException(home_id_response.get("msg"))
home_id = home_id_response.get("data").get("rrHomeId")

home_id = home_id_response['data'].get("rrHomeId")
timestamp = math.floor(time.time())
nonce = secrets.token_urlsafe(6)
prestr = ":".join(
Expand All @@ -431,6 +467,8 @@ async def get_home_data(self, user_data: UserData) -> HomeData:
mac = base64.b64encode(
hmac.new(rriot.h.encode(), prestr.encode(), hashlib.sha256).digest()
).decode()
if rriot.r.a is None:
raise RoborockException("Missing field 'a' in rriot reference")
home_request = PreparedRequest(
rriot.r.a,
{
Expand All @@ -442,4 +480,7 @@ async def get_home_data(self, user_data: UserData) -> HomeData:
if not home_response.get("success"):
raise RoborockException(home_response)
home_data = home_response.get("result")
return HomeData.from_dict(home_data)
if isinstance(home_data, dict):
return HomeData.from_dict(home_data)
else:
raise RoborockException("home_response result was an unexpected type")
4 changes: 3 additions & 1 deletion roborock/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import logging
from pathlib import Path
Expand All @@ -16,7 +18,7 @@

class RoborockContext:
roborock_file = Path("~/.roborock").expanduser()
_login_data: LoginData = None
_login_data: LoginData | None = None

def __init__(self):
self.reload()
Expand Down
Loading