1616import time
1717from asyncio import Lock
1818from asyncio .exceptions import TimeoutError , CancelledError
19+ from typing import Any
1920from urllib .parse import urlparse
2021
2122import aiohttp
5354MQTT_KEEPALIVE = 60
5455
5556
56- def md5hex (message : str ):
57+ def md5hex (message : str ) -> str :
5758 md5 = hashlib .md5 ()
5859 md5 .update (message .encode ())
5960 return md5 .hexdigest ()
6061
6162
62- def md5bin (message : str ):
63+ def md5bin (message : str ) -> bytes :
6364 md5 = hashlib .md5 ()
6465 md5 .update (message .encode ())
6566 return md5 .digest ()
6667
6768
68- def encode_timestamp (_timestamp : int ):
69+ def encode_timestamp (_timestamp : int ) -> str :
6970 hex_value = f"{ _timestamp :x} " .zfill (8 )
7071 return "" .join (list (map (lambda idx : hex_value [idx ], [5 , 6 , 3 , 7 , 1 , 2 , 0 , 4 ])))
7172
7273
7374class PreparedRequest :
74- def __init__ (self , base_url : str , base_headers : dict = None ):
75+ def __init__ (self , base_url : str , base_headers : dict = None ) -> None :
7576 self .base_url = base_url
7677 self .base_headers = base_headers or {}
7778
7879 async def request (
7980 self , method : str , url : str , params = None , data = None , headers = None
80- ):
81+ ) -> dict | list :
8182 _url = "/" .join (s .strip ("/" ) for s in [self .base_url , url ])
8283 _headers = {** self .base_headers , ** (headers or {})}
8384 async with aiohttp .ClientSession () as session :
@@ -99,7 +100,7 @@ async def request(
99100class RoborockMqttClient (mqtt .Client ):
100101 _thread : threading .Thread
101102
102- def __init__ (self , user_data : UserData , device_map : dict [str , RoborockDeviceInfo ]):
103+ def __init__ (self , user_data : UserData , device_map : dict [str , RoborockDeviceInfo ]) -> None :
103104 rriot = user_data .rriot
104105 self ._mqtt_user = rriot .user
105106 self ._mqtt_domain = rriot .domain
@@ -126,11 +127,11 @@ def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo
126127 self ._last_device_msg_in = mqtt .time_func ()
127128 self ._last_disconnection = mqtt .time_func ()
128129
129- def __del__ (self ):
130+ def __del__ (self ) -> None :
130131 self .sync_disconnect ()
131132
132133 @run_in_executor ()
133- async def on_connect (self , _client , _ , __ , rc , ___ = None ):
134+ async def on_connect (self , _client , _ , __ , rc , ___ = None ) -> None :
134135 connection_queue = self ._waiting_queue .get (0 )
135136 if rc != mqtt .MQTT_ERR_SUCCESS :
136137 message = f"Failed to connect (rc: { rc } )"
@@ -156,7 +157,7 @@ async def on_connect(self, _client, _, __, rc, ___=None):
156157 await connection_queue .async_put ((True , None ), timeout = QUEUE_TIMEOUT )
157158
158159 @run_in_executor ()
159- async def on_message (self , _client , _ , msg , __ = None ):
160+ async def on_message (self , _client , _ , msg , __ = None ) -> None :
160161 try :
161162 async with self ._mutex :
162163 self ._last_device_msg_in = mqtt .time_func ()
@@ -219,7 +220,7 @@ async def on_message(self, _client, _, msg, __=None):
219220 _LOGGER .exception (ex )
220221
221222 @run_in_executor ()
222- async def on_disconnect (self , _client : mqtt .Client , _ , rc , __ = None ):
223+ async def on_disconnect (self , _client : mqtt .Client , _ , rc , __ = None ) -> None :
223224 try :
224225 async with self ._mutex :
225226 self ._last_disconnection = mqtt .time_func ()
@@ -241,28 +242,28 @@ async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None):
241242 _LOGGER .exception (ex )
242243
243244 @run_in_executor ()
244- async def _async_check_keepalive (self ):
245+ async def _async_check_keepalive (self ) -> None :
245246 async with self ._mutex :
246247 now = mqtt .time_func ()
247248 if now - self ._last_disconnection > self ._keepalive ** 2 and now - self ._last_device_msg_in > self ._keepalive :
248249 self ._ping_t = self ._last_device_msg_in
249250
250- def _check_keepalive (self ):
251+ def _check_keepalive (self ) -> None :
251252 self ._async_check_keepalive ()
252253 super ()._check_keepalive ()
253254
254- def sync_stop_loop (self ):
255+ def sync_stop_loop (self ) -> None :
255256 if self ._thread :
256257 _LOGGER .info ("Stopping mqtt loop" )
257258 super ().loop_stop ()
258259
259- def sync_start_loop (self ):
260+ def sync_start_loop (self ) -> None :
260261 if not self ._thread or not self ._thread .is_alive ():
261262 self .sync_stop_loop ()
262263 _LOGGER .info ("Starting mqtt loop" )
263264 super ().loop_start ()
264265
265- def sync_disconnect (self ):
266+ def sync_disconnect (self ) -> bool :
266267 rc = mqtt .MQTT_ERR_AGAIN
267268 if self .is_connected ():
268269 _LOGGER .info ("Disconnecting from mqtt" )
@@ -271,7 +272,7 @@ def sync_disconnect(self):
271272 raise RoborockException (f"Failed to disconnect (rc:{ rc } )" )
272273 return rc == mqtt .MQTT_ERR_SUCCESS
273274
274- def sync_connect (self ):
275+ def sync_connect (self ) -> bool :
275276 rc = mqtt .MQTT_ERR_AGAIN
276277 self .sync_start_loop ()
277278 if not self .is_connected ():
@@ -285,7 +286,7 @@ def sync_connect(self):
285286 raise RoborockException (f"Failed to connect (rc:{ rc } )" )
286287 return rc == mqtt .MQTT_ERR_SUCCESS
287288
288- async def _async_response (self , request_id : int , protocol_id : int = 0 ):
289+ async def _async_response (self , request_id : int , protocol_id : int = 0 ) -> tuple [ Any , RoborockException | None ] :
289290 try :
290291 queue = RoborockQueue (protocol_id )
291292 self ._waiting_queue [request_id ] = queue
@@ -298,7 +299,7 @@ async def _async_response(self, request_id: int, protocol_id: int = 0):
298299 finally :
299300 del self ._waiting_queue [request_id ]
300301
301- async def async_disconnect (self ):
302+ async def async_disconnect (self ) -> Any :
302303 async with self ._mutex :
303304 disconnecting = self .sync_disconnect ()
304305 if disconnecting :
@@ -307,7 +308,7 @@ async def async_disconnect(self):
307308 raise RoborockException (err ) from err
308309 return response
309310
310- async def async_connect (self ):
311+ async def async_connect (self ) -> Any :
311312 async with self ._mutex :
312313 connecting = self .sync_connect ()
313314 if connecting :
@@ -316,10 +317,10 @@ async def async_connect(self):
316317 raise RoborockException (err ) from err
317318 return response
318319
319- async def validate_connection (self ):
320+ async def validate_connection (self ) -> None :
320321 await self .async_connect ()
321322
322- def _decode_msg (self , msg , device : HomeDataDevice ):
323+ def _decode_msg (self , msg , device : HomeDataDevice ) -> dict [ str , Any ] :
323324 if msg [0 :3 ] != "1.0" .encode ():
324325 raise RoborockException ("Unknown protocol version" )
325326 crc32 = binascii .crc32 (msg [0 : len (msg ) - 4 ])
@@ -344,7 +345,7 @@ def _decode_msg(self, msg, device: HomeDataDevice):
344345 "payload" : decrypted_payload ,
345346 }
346347
347- def _send_msg_raw (self , device_id , protocol , timestamp , payload ):
348+ def _send_msg_raw (self , device_id , protocol , timestamp , payload ) -> None :
348349 local_key = self .device_map [device_id ].device .local_key
349350 aes_key = md5bin (encode_timestamp (timestamp ) + local_key + self ._salt )
350351 cipher = AES .new (aes_key , AES .MODE_ECB )
@@ -438,7 +439,7 @@ async def get_consumable(self, device_id: str) -> Consumable:
438439 if isinstance (consumable , dict ):
439440 return Consumable (consumable )
440441
441- async def get_prop (self , device_id : str ):
442+ async def get_prop (self , device_id : str ) -> RoborockDeviceProp :
442443 [status , dnd_timer , clean_summary , consumable ] = await asyncio .gather (
443444 * [
444445 self .get_status (device_id ),
@@ -457,7 +458,7 @@ async def get_prop(self, device_id: str):
457458 status , dnd_timer , clean_summary , consumable , last_clean_record
458459 )
459460
460- async def get_multi_maps_list (self , device_id ):
461+ async def get_multi_maps_list (self , device_id ) -> MultiMapsList :
461462 multi_maps_list = await self .send_command (
462463 device_id , RoborockCommand .GET_MULTI_MAPS_LIST
463464 )
@@ -476,7 +477,7 @@ def __init__(self, username: str, base_url=None) -> None:
476477 self .base_url = base_url
477478 self ._device_identifier = secrets .token_urlsafe (16 )
478479
479- async def _get_base_url (self ):
480+ async def _get_base_url (self ) -> str :
480481 if not self .base_url :
481482 url_request = PreparedRequest (self ._default_url )
482483 response = await url_request .request (
@@ -495,7 +496,7 @@ def _get_header_client_id(self):
495496 md5 .update (self ._device_identifier .encode ())
496497 return base64 .b64encode (md5 .digest ()).decode ()
497498
498- async def request_code (self ):
499+ async def request_code (self ) -> None :
499500 base_url = await self ._get_base_url ()
500501 header_clientid = self ._get_header_client_id ()
501502 code_request = PreparedRequest (base_url , {"header_clientid" : header_clientid })
@@ -512,7 +513,7 @@ async def request_code(self):
512513 if code_response .get ("code" ) != 200 :
513514 raise RoborockException (code_response .get ("msg" ))
514515
515- async def pass_login (self , password : str ):
516+ async def pass_login (self , password : str ) -> UserData :
516517 base_url = await self ._get_base_url ()
517518 header_clientid = self ._get_header_client_id ()
518519
@@ -531,7 +532,7 @@ async def pass_login(self, password: str):
531532 raise RoborockException (login_response .get ("msg" ))
532533 return UserData (login_response .get ("data" ))
533534
534- async def code_login (self , code ):
535+ async def code_login (self , code ) -> UserData :
535536 base_url = await self ._get_base_url ()
536537 header_clientid = self ._get_header_client_id ()
537538
@@ -550,7 +551,7 @@ async def code_login(self, code):
550551 raise RoborockException (login_response .get ("msg" ))
551552 return UserData (login_response .get ("data" ))
552553
553- async def get_home_data (self , user_data : UserData ):
554+ async def get_home_data (self , user_data : UserData ) -> HomeData :
554555 base_url = await self ._get_base_url ()
555556 header_clientid = self ._get_header_client_id ()
556557 rriot = user_data .rriot
0 commit comments