66and local connections, preferring local when available.
77"""
88
9+
910import asyncio
1011import logging
1112from collections .abc import Callable
1516from roborock .exceptions import RoborockException
1617from roborock .protocols .v1_protocol import (
1718 CommandType ,
19+ MapResponse ,
1820 ParamsType ,
1921 RequestMessage ,
2022 ResponseData ,
23+ ResponseMessage ,
2124 SecurityData ,
25+ create_map_response_decoder ,
2226 decode_rpc_response ,
2327)
2428from roborock .roborock_message import RoborockMessage , RoborockMessageProtocol
3135
3236
3337_T = TypeVar ("_T" , bound = RoborockBase )
38+ _V = TypeVar ("_V" )
3439
3540
3641class V1RpcChannel (Protocol ):
@@ -120,36 +125,40 @@ def __init__(
120125 name : str ,
121126 channel : MqttChannel | LocalChannel ,
122127 payload_encoder : Callable [[RequestMessage ], RoborockMessage ],
128+ decoder : Callable [[RoborockMessage ], ResponseMessage ] | Callable [[RoborockMessage ], MapResponse | None ],
123129 ) -> None :
124130 """Initialize the channel with a raw channel and an encoder function."""
125131 self ._name = name
126132 self ._channel = channel
127133 self ._payload_encoder = payload_encoder
134+ self ._decoder = decoder
128135
129136 async def _send_raw_command (
130137 self ,
131138 method : CommandType ,
132139 * ,
133140 params : ParamsType = None ,
134- ) -> ResponseData :
141+ ) -> ResponseData | bytes :
135142 """Send a command and return a parsed response RoborockBase type."""
136143 request_message = RequestMessage (method , params = params )
137144 _LOGGER .debug (
138145 "Sending command (%s, request_id=%s): %s, params=%s" , self ._name , request_message .request_id , method , params
139146 )
140147 message = self ._payload_encoder (request_message )
141148
142- future : asyncio .Future [ResponseData ] = asyncio .Future ()
149+ future : asyncio .Future [ResponseData | bytes ] = asyncio .Future ()
143150
144151 def find_response (response_message : RoborockMessage ) -> None :
145152 try :
146- decoded = decode_rpc_response (response_message )
153+ decoded = self . _decoder (response_message )
147154 except RoborockException as ex :
148155 _LOGGER .debug ("Exception while decoding message (%s): %s" , response_message , ex )
149156 return
150- _LOGGER .debug ("Received response (request_id=%s): %s" , self ._name , decoded .request_id )
157+ if decoded is None :
158+ return
159+ _LOGGER .debug ("Received response (%s, request_id=%s)" , self ._name , decoded .request_id )
151160 if decoded .request_id == request_message .request_id :
152- if decoded .api_error :
161+ if isinstance ( decoded , ResponseMessage ) and decoded .api_error :
153162 future .set_exception (decoded .api_error )
154163 else :
155164 future .set_result (decoded .data )
@@ -171,6 +180,7 @@ def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityDa
171180 "mqtt" ,
172181 mqtt_channel ,
173182 lambda x : x .encode_message (RoborockMessageProtocol .RPC_REQUEST , security_data = security_data ),
183+ decode_rpc_response ,
174184 )
175185
176186
@@ -180,4 +190,23 @@ def create_local_rpc_channel(local_channel: LocalChannel) -> V1RpcChannel:
180190 "local" ,
181191 local_channel ,
182192 lambda x : x .encode_message (RoborockMessageProtocol .GENERAL_REQUEST ),
193+ decode_rpc_response ,
194+ )
195+
196+
197+ def create_map_rpc_channel (
198+ mqtt_channel : MqttChannel ,
199+ security_data : SecurityData ,
200+ ) -> V1RpcChannel :
201+ """Create a V1 RPC channel that fetches map data.
202+
203+ This will prefer local channels when available, falling back to MQTT
204+ channels if not. If neither is available, an exception will be raised
205+ when trying to send a command.
206+ """
207+ return PayloadEncodedV1RpcChannel (
208+ "map" ,
209+ mqtt_channel ,
210+ lambda x : x .encode_message (RoborockMessageProtocol .RPC_REQUEST , security_data = security_data ),
211+ create_map_response_decoder (security_data = security_data ),
183212 )
0 commit comments