Skip to content

Commit 9f97a2b

Browse files
authored
fix: fix a bug with local / mqtt fallback (#475)
1 parent 4b97db2 commit 9f97a2b

File tree

3 files changed

+25
-19
lines changed

3 files changed

+25
-19
lines changed

roborock/devices/v1_channel.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .channel import Channel
2323
from .local_channel import LocalChannel, LocalSession, create_local_session
2424
from .mqtt_channel import MqttChannel
25-
from .v1_rpc_channel import V1RpcChannel, create_combined_rpc_channel, create_mqtt_rpc_channel
25+
from .v1_rpc_channel import PickFirstAvailable, V1RpcChannel, create_local_rpc_channel, create_mqtt_rpc_channel
2626

2727
_LOGGER = logging.getLogger(__name__)
2828

@@ -60,7 +60,11 @@ def __init__(
6060
self._mqtt_rpc_channel = create_mqtt_rpc_channel(mqtt_channel, security_data)
6161
self._local_session = local_session
6262
self._local_channel: LocalChannel | None = None
63-
self._combined_rpc_channel: V1RpcChannel | None = None
63+
self._local_rpc_channel: V1RpcChannel | None = None
64+
# Prefer local, fallback to MQTT
65+
self._combined_rpc_channel = PickFirstAvailable(
66+
[lambda: self._local_rpc_channel, lambda: self._mqtt_rpc_channel]
67+
)
6468
self._mqtt_unsub: Callable[[], None] | None = None
6569
self._local_unsub: Callable[[], None] | None = None
6670
self._callback: Callable[[RoborockMessage], None] | None = None
@@ -84,7 +88,7 @@ def is_mqtt_connected(self) -> bool:
8488
@property
8589
def rpc_channel(self) -> V1RpcChannel:
8690
"""Return the combined RPC channel prefers local with a fallback to MQTT."""
87-
return self._combined_rpc_channel or self._mqtt_rpc_channel
91+
return self._combined_rpc_channel
8892

8993
@property
9094
def mqtt_rpc_channel(self) -> V1RpcChannel:
@@ -160,7 +164,7 @@ async def _local_connect(self) -> Callable[[], None]:
160164
except RoborockException as e:
161165
self._local_channel = None
162166
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e
163-
self._combined_rpc_channel = create_combined_rpc_channel(self._local_channel, self._mqtt_rpc_channel)
167+
self._local_rpc_channel = create_local_rpc_channel(self._local_channel)
164168
return await self._local_channel.subscribe(self._on_local_message)
165169

166170
def _on_mqtt_message(self, message: RoborockMessage) -> None:

roborock/devices/v1_rpc_channel.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,15 @@ async def _send_raw_command(
8888
raise NotImplementedError
8989

9090

91-
class CombinedV1RpcChannel(BaseV1RpcChannel):
92-
"""A V1 RPC channel that can use both local and MQTT channels, preferring local when available."""
91+
class PickFirstAvailable(BaseV1RpcChannel):
92+
"""A V1 RPC channel that tries multiple channels and picks the first that works."""
9393

9494
def __init__(
95-
self, local_channel: LocalChannel, local_rpc_channel: V1RpcChannel, mqtt_channel: V1RpcChannel
95+
self,
96+
channel_cbs: list[Callable[[], V1RpcChannel | None]],
9697
) -> None:
97-
"""Initialize the combined channel with local and MQTT channels."""
98-
self._local_channel = local_channel
99-
self._local_rpc_channel = local_rpc_channel
100-
self._mqtt_rpc_channel = mqtt_channel
98+
"""Initialize the pick-first-available channel."""
99+
self._channel_cbs = channel_cbs
101100

102101
async def _send_raw_command(
103102
self,
@@ -106,9 +105,10 @@ async def _send_raw_command(
106105
params: ParamsType = None,
107106
) -> Any:
108107
"""Send a command and return a parsed response RoborockBase type."""
109-
if self._local_channel.is_connected:
110-
return await self._local_rpc_channel.send_command(method, params=params)
111-
return await self._mqtt_rpc_channel.send_command(method, params=params)
108+
for channel_cb in self._channel_cbs:
109+
if channel := channel_cb():
110+
return await channel.send_command(method, params=params)
111+
raise RoborockException("No available connection to send command")
112112

113113

114114
class PayloadEncodedV1RpcChannel(BaseV1RpcChannel):
@@ -170,11 +170,10 @@ def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityDa
170170
)
171171

172172

173-
def create_combined_rpc_channel(local_channel: LocalChannel, mqtt_rpc_channel: V1RpcChannel) -> V1RpcChannel:
174-
"""Create a V1 RPC channel that combines local and MQTT channels."""
175-
local_rpc_channel = PayloadEncodedV1RpcChannel(
173+
def create_local_rpc_channel(local_channel: LocalChannel) -> V1RpcChannel:
174+
"""Create a V1 RPC channel using a local channel."""
175+
return PayloadEncodedV1RpcChannel(
176176
"local",
177177
local_channel,
178178
lambda x: x.encode_message(RoborockMessageProtocol.GENERAL_REQUEST),
179179
)
180-
return CombinedV1RpcChannel(local_channel, local_rpc_channel, mqtt_rpc_channel)

tests/devices/test_v1_channel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,9 @@ async def test_v1_channel_full_subscribe_and_command_flow(
497497
local_session=mock_local_session,
498498
cache=InMemoryCache(),
499499
)
500+
# Get a handle to the V1RpcChannel. It may change which connection is
501+
# active, but getting now to reproduce a bug where it doesn't change.
502+
rpc_channel = v1_channel.rpc_channel
500503

501504
# Mock network info for local connection
502505
callback = Mock()
@@ -512,7 +515,7 @@ async def test_v1_channel_full_subscribe_and_command_flow(
512515

513516
# Send a command (should use local)
514517
mock_local_channel.response_queue.append(TEST_RESPONSE)
515-
result = await v1_channel.rpc_channel.send_command(
518+
result = await rpc_channel.send_command(
516519
RoborockCommand.GET_STATUS,
517520
response_type=S5MaxStatus,
518521
)

0 commit comments

Comments
 (0)