Skip to content

Commit 30d2577

Browse files
authored
chore: move more things around in version 1 api (#198)
* chore: move more things around in version 1 api * fix: tests
1 parent 4f44d03 commit 30d2577

File tree

11 files changed

+443
-299
lines changed

11 files changed

+443
-299
lines changed

roborock/api.py

Lines changed: 3 additions & 233 deletions
Original file line numberDiff line numberDiff line change
@@ -4,125 +4,35 @@
44

55
import asyncio
66
import base64
7-
import dataclasses
8-
import hashlib
9-
import json
107
import logging
118
import secrets
12-
import struct
139
import time
1410
from collections.abc import Callable, Coroutine
15-
from typing import Any, TypeVar, final
11+
from typing import Any
1612

17-
from .command_cache import CacheableAttribute, CommandType, RoborockAttribute, find_cacheable_attribute, get_cache_map
1813
from .containers import (
19-
Consumable,
2014
DeviceData,
2115
ModelStatus,
22-
RoborockBase,
2316
S7MaxVStatus,
2417
Status,
2518
)
2619
from .exceptions import (
27-
RoborockException,
2820
RoborockTimeout,
2921
UnknownMethodError,
3022
VacuumError,
3123
)
32-
from .protocol import Utils
3324
from .roborock_future import RoborockFuture
3425
from .roborock_message import (
35-
ROBOROCK_DATA_CONSUMABLE_PROTOCOL,
36-
ROBOROCK_DATA_STATUS_PROTOCOL,
37-
RoborockDataProtocol,
3826
RoborockMessage,
39-
RoborockMessageProtocol,
4027
)
4128
from .roborock_typing import RoborockCommand
42-
from .util import RepeatableTask, RoborockLoggerAdapter, get_running_loop_or_create_one
29+
from .util import RoborockLoggerAdapter, get_running_loop_or_create_one
4330

4431
_LOGGER = logging.getLogger(__name__)
4532
KEEPALIVE = 60
46-
RT = TypeVar("RT", bound=RoborockBase)
47-
48-
49-
def md5hex(message: str) -> str:
50-
md5 = hashlib.md5()
51-
md5.update(message.encode())
52-
return md5.hexdigest()
53-
54-
55-
EVICT_TIME = 60
56-
57-
58-
class AttributeCache:
59-
def __init__(self, attribute: RoborockAttribute, api: RoborockClient):
60-
self.attribute = attribute
61-
self.api = api
62-
self.attribute = attribute
63-
self.task = RepeatableTask(self.api.event_loop, self._async_value, EVICT_TIME)
64-
self._value: Any = None
65-
self._mutex = asyncio.Lock()
66-
self.unsupported: bool = False
67-
68-
@property
69-
def value(self):
70-
return self._value
71-
72-
async def _async_value(self):
73-
if self.unsupported:
74-
return None
75-
try:
76-
self._value = await self.api._send_command(self.attribute.get_command)
77-
except UnknownMethodError as err:
78-
# Limit the amount of times we call unsupported methods
79-
self.unsupported = True
80-
raise err
81-
return self._value
82-
83-
async def async_value(self):
84-
async with self._mutex:
85-
if self._value is None:
86-
return await self.task.reset()
87-
return self._value
88-
89-
def stop(self):
90-
self.task.cancel()
91-
92-
async def update_value(self, params):
93-
if self.attribute.set_command is None:
94-
raise RoborockException(f"{self.attribute.attribute} have no set command")
95-
response = await self.api._send_command(self.attribute.set_command, params)
96-
await self._async_value()
97-
return response
98-
99-
async def add_value(self, params):
100-
if self.attribute.add_command is None:
101-
raise RoborockException(f"{self.attribute.attribute} have no add command")
102-
response = await self.api._send_command(self.attribute.add_command, params)
103-
await self._async_value()
104-
return response
105-
106-
async def close_value(self, params=None):
107-
if self.attribute.close_command is None:
108-
raise RoborockException(f"{self.attribute.attribute} have no close command")
109-
response = await self.api._send_command(self.attribute.close_command, params)
110-
await self._async_value()
111-
return response
112-
113-
async def refresh_value(self):
114-
await self._async_value()
115-
116-
117-
@dataclasses.dataclass
118-
class ListenerModel:
119-
protocol_handlers: dict[RoborockDataProtocol, list[Callable[[Status | Consumable], None]]]
120-
cache: dict[CacheableAttribute, AttributeCache]
12133

12234

12335
class RoborockClient:
124-
_listeners: dict[str, ListenerModel] = {}
125-
12636
def __init__(self, endpoint: str, device_info: DeviceData, queue_timeout: int = 4) -> None:
12737
self.event_loop = get_running_loop_or_create_one()
12838
self.device_info = device_info
@@ -136,15 +46,9 @@ def __init__(self, endpoint: str, device_info: DeviceData, queue_timeout: int =
13646
"misc_info": {"Nonce": base64.b64encode(self._nonce).decode("utf-8")}
13747
}
13848
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
139-
self.cache: dict[CacheableAttribute, AttributeCache] = {
140-
cacheable_attribute: AttributeCache(attr, self) for cacheable_attribute, attr in get_cache_map().items()
141-
}
14249
self.is_available: bool = True
14350
self.queue_timeout = queue_timeout
14451
self._status_type: type[Status] = ModelStatus.get(self.device_info.model, S7MaxVStatus)
145-
if device_info.device.duid not in self._listeners:
146-
self._listeners[device_info.device.duid] = ListenerModel({}, self.cache)
147-
self.listener_model = self._listeners[device_info.device.duid]
14852

14953
def __del__(self) -> None:
15054
self.release()
@@ -156,11 +60,9 @@ def status_type(self) -> type[Status]:
15660

15761
def release(self):
15862
self.sync_disconnect()
159-
[item.stop() for item in self.cache.values()]
16063

16164
async def async_release(self):
16265
await self.async_disconnect()
163-
[item.stop() for item in self.cache.values()]
16466

16567
@property
16668
def diagnostic_data(self) -> dict:
@@ -185,95 +87,7 @@ async def async_disconnect(self) -> Any:
18587
raise NotImplementedError
18688

18789
def on_message_received(self, messages: list[RoborockMessage]) -> None:
188-
try:
189-
self._last_device_msg_in = self.time_func()
190-
for data in messages:
191-
protocol = data.protocol
192-
if data.payload and protocol in [
193-
RoborockMessageProtocol.RPC_RESPONSE,
194-
RoborockMessageProtocol.GENERAL_REQUEST,
195-
]:
196-
payload = json.loads(data.payload.decode())
197-
for data_point_number, data_point in payload.get("dps").items():
198-
if data_point_number == "102":
199-
data_point_response = json.loads(data_point)
200-
request_id = data_point_response.get("id")
201-
queue = self._waiting_queue.get(request_id)
202-
if queue and queue.protocol == protocol:
203-
error = data_point_response.get("error")
204-
if error:
205-
queue.resolve(
206-
(
207-
None,
208-
VacuumError(
209-
error.get("code"),
210-
error.get("message"),
211-
),
212-
)
213-
)
214-
else:
215-
result = data_point_response.get("result")
216-
if isinstance(result, list) and len(result) == 1:
217-
result = result[0]
218-
queue.resolve((result, None))
219-
else:
220-
try:
221-
data_protocol = RoborockDataProtocol(int(data_point_number))
222-
self._logger.debug(f"Got device update for {data_protocol.name}: {data_point}")
223-
if data_protocol in ROBOROCK_DATA_STATUS_PROTOCOL:
224-
if data_protocol not in self.listener_model.protocol_handlers:
225-
self._logger.debug(
226-
f"Got status update({data_protocol.name}) before get_status was called."
227-
)
228-
return
229-
value = self.listener_model.cache[CacheableAttribute.status].value
230-
value[data_protocol.name] = data_point
231-
status = self._status_type.from_dict(value)
232-
for listener in self.listener_model.protocol_handlers.get(data_protocol, []):
233-
listener(status)
234-
elif data_protocol in ROBOROCK_DATA_CONSUMABLE_PROTOCOL:
235-
if data_protocol not in self.listener_model.protocol_handlers:
236-
self._logger.debug(
237-
f"Got consumable update({data_protocol.name})"
238-
+ "before get_consumable was called."
239-
)
240-
return
241-
value = self.listener_model.cache[CacheableAttribute.consumable].value
242-
value[data_protocol.name] = data_point
243-
consumable = Consumable.from_dict(value)
244-
for listener in self.listener_model.protocol_handlers.get(data_protocol, []):
245-
listener(consumable)
246-
return
247-
except ValueError:
248-
self._logger.warning(
249-
f"Got listener data for {data_point_number}, data: {data_point}. "
250-
f"This lets us update data quicker, please open an issue "
251-
f"at https://github.com/humbertogontijo/python-roborock/issues"
252-
)
253-
254-
pass
255-
dps = {data_point_number: data_point}
256-
self._logger.debug(f"Got unknown data point {dps}")
257-
elif data.payload and protocol == RoborockMessageProtocol.MAP_RESPONSE:
258-
payload = data.payload[0:24]
259-
[endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", payload)
260-
if endpoint.decode().startswith(self._endpoint):
261-
try:
262-
decrypted = Utils.decrypt_cbc(data.payload[24:], self._nonce)
263-
except ValueError as err:
264-
raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err
265-
decompressed = Utils.decompress(decrypted)
266-
queue = self._waiting_queue.get(request_id)
267-
if queue:
268-
if isinstance(decompressed, list):
269-
decompressed = decompressed[0]
270-
queue.resolve((decompressed, None))
271-
else:
272-
queue = self._waiting_queue.get(data.seq)
273-
if queue:
274-
queue.resolve((data.payload, None))
275-
except Exception as ex:
276-
self._logger.exception(ex)
90+
raise NotImplementedError
27791

27892
def on_connection_lost(self, exc: Exception | None) -> None:
27993
self._last_disconnection = self.time_func()
@@ -320,47 +134,3 @@ async def _send_command(
320134
params: list | dict | int | None = None,
321135
):
322136
raise NotImplementedError
323-
324-
@final
325-
async def send_command(
326-
self,
327-
method: RoborockCommand | str,
328-
params: list | dict | int | None = None,
329-
return_type: type[RT] | None = None,
330-
) -> RT:
331-
cacheable_attribute_result = find_cacheable_attribute(method)
332-
333-
cache = None
334-
command_type = None
335-
if cacheable_attribute_result is not None:
336-
cache = self.cache[cacheable_attribute_result.attribute]
337-
command_type = cacheable_attribute_result.type
338-
339-
response: Any = None
340-
if cache is not None and command_type == CommandType.GET:
341-
response = await cache.async_value()
342-
else:
343-
response = await self._send_command(method, params)
344-
if cache is not None and command_type == CommandType.CHANGE:
345-
await cache.refresh_value()
346-
347-
if return_type:
348-
return return_type.from_dict(response)
349-
return response
350-
351-
def add_listener(
352-
self, protocol: RoborockDataProtocol, listener: Callable, cache: dict[CacheableAttribute, AttributeCache]
353-
) -> None:
354-
self.listener_model.cache = cache
355-
if protocol not in self.listener_model.protocol_handlers:
356-
self.listener_model.protocol_handlers[protocol] = []
357-
self.listener_model.protocol_handlers[protocol].append(listener)
358-
359-
def remove_listener(self, protocol: RoborockDataProtocol, listener: Callable) -> None:
360-
self.listener_model.protocol_handlers[protocol].remove(listener)
361-
362-
async def get_from_cache(self, key: CacheableAttribute) -> AttributeCache | None:
363-
val = self.cache.get(key)
364-
if val is not None:
365-
return await val.async_value()
366-
return None

roborock/cloud_api.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,25 @@
44
import base64
55
import logging
66
import threading
7+
import typing
78
import uuid
89
from asyncio import Lock, Task
910
from typing import Any
1011
from urllib.parse import urlparse
1112

1213
import paho.mqtt.client as mqtt
1314

14-
from .api import KEEPALIVE, RoborockClient, md5hex
15+
from .api import KEEPALIVE, RoborockClient
1516
from .containers import DeviceData, UserData
1617
from .exceptions import RoborockException, VacuumError
17-
from .protocol import MessageParser, Utils
18+
from .protocol import MessageParser, Utils, md5hex
1819
from .roborock_future import RoborockFuture
1920
from .roborock_message import RoborockMessage
2021
from .roborock_typing import RoborockCommand
2122
from .util import RoborockLoggerAdapter
2223

24+
if typing.TYPE_CHECKING:
25+
pass
2326
_LOGGER = logging.getLogger(__name__)
2427
CONNECT_REQUEST_ID = 0
2528
DISCONNECT_REQUEST_ID = 1
@@ -78,7 +81,7 @@ def on_connect(self, *args, **kwargs):
7881
connection_queue.resolve((True, None))
7982

8083
def on_message(self, *args, **kwargs):
81-
_, __, msg = args
84+
client, __, msg = args
8285
try:
8386
messages, _ = MessageParser.parse(msg.payload, local_key=self.device_info.device.local_key)
8487
super().on_message_received(messages)

0 commit comments

Comments
 (0)