From 8e9f1ba5d5ac0a4b3da8deeb525d29eb369793e0 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 13 Sep 2025 07:00:40 -0700 Subject: [PATCH 1/3] feat: Add a DnD trait and fix bugs in the rpc channels --- roborock/containers.py | 16 +-- roborock/devices/device_manager.py | 2 + roborock/devices/traits/dnd.py | 42 ++++++++ roborock/devices/traits/status.py | 11 ++- roborock/devices/v1_channel.py | 7 +- roborock/devices/v1_rpc_channel.py | 8 +- roborock/protocols/v1_protocol.py | 4 +- tests/devices/test_v1_channel.py | 12 +-- tests/devices/test_v1_device.py | 2 +- tests/devices/traits/__init__.py | 1 + tests/devices/traits/test_dnd.py | 153 +++++++++++++++++++++++++++++ 11 files changed, 234 insertions(+), 24 deletions(-) create mode 100644 roborock/devices/traits/dnd.py create mode 100644 tests/devices/traits/__init__.py create mode 100644 tests/devices/traits/test_dnd.py diff --git a/roborock/containers.py b/roborock/containers.py index 4e78c8f0..8f08da15 100644 --- a/roborock/containers.py +++ b/roborock/containers.py @@ -134,8 +134,8 @@ def from_dict(cls, data: dict[str, Any]): return None field_types = {field.name: field.type for field in dataclasses.fields(cls)} result: dict[str, Any] = {} - for key, value in data.items(): - key = _decamelize(key) + for orig_key, value in data.items(): + key = _decamelize(orig_key) if (field_type := field_types.get(key)) is None: continue if value == "None" or value is None: @@ -178,16 +178,18 @@ class RoborockBaseTimer(RoborockBase): end_hour: int | None = None end_minute: int | None = None enabled: int | None = None - start_time: datetime.time | None = None - end_time: datetime.time | None = None - def __post_init__(self) -> None: - self.start_time = ( + @property + def start_time(self) -> datetime.time | None: + return ( datetime.time(hour=self.start_hour, minute=self.start_minute) if self.start_hour is not None and self.start_minute is not None else None ) - self.end_time = ( + + @property + def end_time(self) -> datetime.time | None: + return ( datetime.time(hour=self.end_hour, minute=self.end_minute) if self.end_hour is not None and self.end_minute is not None else None diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 52b03b01..853a5f57 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -22,6 +22,7 @@ from .channel import Channel from .mqtt_channel import create_mqtt_channel from .traits.b01.props import B01PropsApi +from .traits.dnd import DoNotDisturbTrait from .traits.dyad import DyadApi from .traits.status import StatusTrait from .traits.trait import Trait @@ -152,6 +153,7 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock case DeviceVersion.V1: channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache) traits.append(StatusTrait(product, channel.rpc_channel)) + traits.append(DoNotDisturbTrait(channel.rpc_channel)) case DeviceVersion.A01: mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device) match product.category: diff --git a/roborock/devices/traits/dnd.py b/roborock/devices/traits/dnd.py new file mode 100644 index 00000000..476ecb33 --- /dev/null +++ b/roborock/devices/traits/dnd.py @@ -0,0 +1,42 @@ +"""Module for Roborock V1 devices. + +This interface is experimental and subject to breaking changes without notice +until the API is stable. +""" + +import logging +from collections.abc import Callable + +from roborock.containers import DnDTimer +from roborock.devices.v1_rpc_channel import V1RpcChannel +from roborock.roborock_typing import RoborockCommand + +from .trait import Trait + +_LOGGER = logging.getLogger(__name__) + +__all__ = [ + "DoNotDisturbTrait", +] + + +class DoNotDisturbTrait(Trait): + """Trait for managing Do Not Disturb (DND) settings on Roborock devices.""" + + name = "do_not_disturb" + + def __init__(self, rpc_channel: Callable[[], V1RpcChannel]) -> None: + """Initialize the DoNotDisturbTrait.""" + self._rpc_channel = rpc_channel + + async def get_dnd_timer(self) -> DnDTimer: + """Get the current Do Not Disturb (DND) timer settings of the device.""" + return await self._rpc_channel().send_command(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer) + + async def set_dnd_timer(self, dnd_timer: DnDTimer) -> None: + """Set the Do Not Disturb (DND) timer settings of the device.""" + await self._rpc_channel().send_command(RoborockCommand.SET_DND_TIMER, params=dnd_timer.as_dict()) + + async def clear_dnd_timer(self) -> None: + """Clear the Do Not Disturb (DND) timer settings of the device.""" + await self._rpc_channel().send_command(RoborockCommand.CLOSE_DND_TIMER) diff --git a/roborock/devices/traits/status.py b/roborock/devices/traits/status.py index d7d622d9..ce837d27 100644 --- a/roborock/devices/traits/status.py +++ b/roborock/devices/traits/status.py @@ -5,6 +5,7 @@ """ import logging +from collections.abc import Callable from roborock.containers import ( HomeDataProduct, @@ -12,24 +13,24 @@ S7MaxVStatus, Status, ) +from roborock.devices.v1_rpc_channel import V1RpcChannel from roborock.roborock_typing import RoborockCommand -from ..v1_rpc_channel import V1RpcChannel from .trait import Trait _LOGGER = logging.getLogger(__name__) __all__ = [ - "Status", + "StatusTrait", ] class StatusTrait(Trait): - """Unified Roborock device class with automatic connection setup.""" + """Trait for managing the status of Roborock devices.""" name = "status" - def __init__(self, product_info: HomeDataProduct, rpc_channel: V1RpcChannel) -> None: + def __init__(self, product_info: HomeDataProduct, rpc_channel: Callable[[], V1RpcChannel]) -> None: """Initialize the StatusTrait.""" self._product_info = product_info self._rpc_channel = rpc_channel @@ -40,4 +41,4 @@ async def get_status(self) -> Status: This is a placeholder command and will likely be changed/moved in the future. """ status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus) - return await self._rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type) + return await self._rpc_channel().send_command(RoborockCommand.GET_STATUS, response_type=status_type) diff --git a/roborock/devices/v1_channel.py b/roborock/devices/v1_channel.py index b7371e33..cda5e40a 100644 --- a/roborock/devices/v1_channel.py +++ b/roborock/devices/v1_channel.py @@ -81,9 +81,12 @@ def is_mqtt_connected(self) -> bool: """Return whether MQTT connection is available.""" return self._mqtt_unsub is not None and self._mqtt_channel.is_connected - @property def rpc_channel(self) -> V1RpcChannel: - """Return the combined RPC channel prefers local with a fallback to MQTT.""" + """Return the combined RPC channel prefers local with a fallback to MQTT. + + This is dynamic based on the current connection status. That is, it may return + a different channel depending on whether local or MQTT is available. + """ return self._combined_rpc_channel or self._mqtt_rpc_channel @property diff --git a/roborock/devices/v1_rpc_channel.py b/roborock/devices/v1_rpc_channel.py index 52122e08..bad24a2f 100644 --- a/roborock/devices/v1_rpc_channel.py +++ b/roborock/devices/v1_rpc_channel.py @@ -132,8 +132,10 @@ async def _send_raw_command( params: ParamsType = None, ) -> Any: """Send a command and return a parsed response RoborockBase type.""" - _LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params) request_message = RequestMessage(method, params=params) + _LOGGER.debug( + "Sending command (%s, request_id=%s): %s, params=%s", self._name, request_message.request_id, method, params + ) message = self._payload_encoder(request_message) future: asyncio.Future[dict[str, Any]] = asyncio.Future() @@ -141,8 +143,10 @@ async def _send_raw_command( def find_response(response_message: RoborockMessage) -> None: try: decoded = decode_rpc_response(response_message) - except RoborockException: + except RoborockException as ex: + _LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex) return + _LOGGER.debug("Received response (request_id=%s): %s", self._name, decoded.request_id) if decoded.request_id == request_message.request_id: future.set_result(decoded.data) diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 6829b7db..8c1b70dd 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -109,7 +109,7 @@ class ResponseMessage: def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: """Decode a V1 RPC_RESPONSE message.""" if not message.payload: - raise RoborockException("Invalid V1 message format: missing payload") + return ResponseMessage(request_id=message.seq, data={}) try: payload = json.loads(message.payload.decode()) except (json.JSONDecodeError, TypeError) as e: @@ -141,6 +141,8 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: _LOGGER.debug("Decoded V1 message result: %s", result) if isinstance(result, list) and result: result = result[0] + if isinstance(result, str) and result == "ok": + result = {} if not isinstance(result, dict): raise RoborockException(f"Invalid V1 message format: 'result' should be a dictionary for {message.payload!r}") return ResponseMessage(request_id=request_id, data=result) diff --git a/tests/devices/test_v1_channel.py b/tests/devices/test_v1_channel.py index d7711f42..ff729fb2 100644 --- a/tests/devices/test_v1_channel.py +++ b/tests/devices/test_v1_channel.py @@ -254,7 +254,7 @@ async def test_v1_channel_send_command_local_preferred( # Send command mock_local_channel.response_queue.append(TEST_RESPONSE) - result = await v1_channel.rpc_channel.send_command( + result = await v1_channel.rpc_channel().send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, ) @@ -280,7 +280,7 @@ async def test_v1_channel_send_command_local_fails( # Send command with pytest.raises(RoborockException, match="Local failed"): - await v1_channel.rpc_channel.send_command( + await v1_channel.rpc_channel().send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, ) @@ -300,7 +300,7 @@ async def test_v1_channel_send_decoded_command_mqtt_only( # Send command mock_mqtt_channel.response_queue.append(TEST_RESPONSE) - result = await v1_channel.rpc_channel.send_command( + result = await v1_channel.rpc_channel().send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, ) @@ -322,7 +322,7 @@ async def test_v1_channel_send_decoded_command_with_params( # Send command with params mock_local_channel.response_queue.append(TEST_RESPONSE) test_params = {"volume": 80} - await v1_channel.rpc_channel.send_command( + await v1_channel.rpc_channel().send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, params=test_params, @@ -444,7 +444,7 @@ async def test_v1_channel_command_encoding_validation( # Send local command and capture the request mock_local_channel.response_queue.append(TEST_RESPONSE_2) - await v1_channel.rpc_channel.send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50}) + await v1_channel.rpc_channel().send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50}) assert mock_local_channel.published_messages local_message = mock_local_channel.published_messages[0] @@ -512,7 +512,7 @@ async def test_v1_channel_full_subscribe_and_command_flow( # Send a command (should use local) mock_local_channel.response_queue.append(TEST_RESPONSE) - result = await v1_channel.rpc_channel.send_command( + result = await v1_channel.rpc_channel().send_command( RoborockCommand.GET_STATUS, response_type=S5MaxStatus, ) diff --git a/tests/devices/test_v1_device.py b/tests/devices/test_v1_device.py index 10ef8073..c50932a8 100644 --- a/tests/devices/test_v1_device.py +++ b/tests/devices/test_v1_device.py @@ -44,7 +44,7 @@ def traits_fixture(rpc_channel: AsyncMock) -> list[Trait]: return [ StatusTrait( product_info=HOME_DATA.products[0], - rpc_channel=rpc_channel, + rpc_channel=lambda: rpc_channel, ) ] diff --git a/tests/devices/traits/__init__.py b/tests/devices/traits/__init__.py new file mode 100644 index 00000000..21bbdbcd --- /dev/null +++ b/tests/devices/traits/__init__.py @@ -0,0 +1 @@ +"""Tests for device traits.""" diff --git a/tests/devices/traits/test_dnd.py b/tests/devices/traits/test_dnd.py new file mode 100644 index 00000000..38e5c07d --- /dev/null +++ b/tests/devices/traits/test_dnd.py @@ -0,0 +1,153 @@ +"""Tests for the DoNotDisturbTrait class.""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +from roborock.containers import DnDTimer +from roborock.devices.traits.dnd import DoNotDisturbTrait +from roborock.devices.v1_rpc_channel import V1RpcChannel +from roborock.roborock_typing import RoborockCommand + + +@pytest.fixture +def mock_rpc_channel() -> AsyncMock: + """Create a mock RPC channel.""" + mock_channel = AsyncMock(spec=V1RpcChannel) + # Ensure send_command is an AsyncMock that returns awaitable coroutines + mock_channel.send_command = AsyncMock() + return mock_channel + + +@pytest.fixture +def mock_rpc_channel_callable(mock_rpc_channel: AsyncMock) -> Mock: + """Create a callable that returns the mock RPC channel.""" + return Mock(return_value=mock_rpc_channel) + + +@pytest.fixture +def dnd_trait(mock_rpc_channel_callable: Mock) -> DoNotDisturbTrait: + """Create a DoNotDisturbTrait instance with mocked dependencies.""" + return DoNotDisturbTrait(mock_rpc_channel_callable) + + +@pytest.fixture +def sample_dnd_timer() -> DnDTimer: + """Create a sample DnDTimer for testing.""" + return DnDTimer( + start_hour=22, + start_minute=0, + end_hour=8, + end_minute=0, + enabled=1, + ) + + +def test_trait_name(dnd_trait: DoNotDisturbTrait) -> None: + """Test that the trait has the correct name.""" + assert dnd_trait.name == "do_not_disturb" + + +async def test_get_dnd_timer_success( + dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer +) -> None: + """Test successfully getting DnD timer settings.""" + # Setup mock to return the sample DnD timer + mock_rpc_channel.send_command.return_value = sample_dnd_timer + + # Call the method + result = await dnd_trait.get_dnd_timer() + + # Verify the result + assert result == sample_dnd_timer + assert result.start_hour == 22 + assert result.start_minute == 0 + assert result.end_hour == 8 + assert result.end_minute == 0 + assert result.enabled == 1 + + # Verify the RPC call was made correctly + mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer) + + +async def test_get_dnd_timer_disabled(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None: + """Test getting DnD timer when it's disabled.""" + disabled_timer = DnDTimer( + start_hour=22, + start_minute=0, + end_hour=8, + end_minute=0, + enabled=0, + ) + mock_rpc_channel.send_command.return_value = disabled_timer + + result = await dnd_trait.get_dnd_timer() + + assert result.enabled == 0 + mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer) + + +async def test_set_dnd_timer_success( + dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer +) -> None: + """Test successfully setting DnD timer settings.""" + # Call the method + await dnd_trait.set_dnd_timer(sample_dnd_timer) + + # Verify the RPC call was made correctly with dataclass converted to dict + + expected_params = { + "startHour": 22, + "startMinute": 0, + "endHour": 8, + "endMinute": 0, + "enabled": 1, + } + mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.SET_DND_TIMER, params=expected_params) + + +async def test_clear_dnd_timer_success(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None: + """Test successfully clearing DnD timer settings.""" + # Call the method + await dnd_trait.clear_dnd_timer() + + # Verify the RPC call was made correctly + mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.CLOSE_DND_TIMER) + + +async def test_get_dnd_timer_propagates_exception(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None: + """Test that exceptions from RPC channel are propagated in get_dnd_timer.""" + from roborock.exceptions import RoborockException + + # Setup mock to raise an exception + mock_rpc_channel.send_command.side_effect = RoborockException("Communication error") + + # Verify the exception is propagated + with pytest.raises(RoborockException, match="Communication error"): + await dnd_trait.get_dnd_timer() + + +async def test_set_dnd_timer_propagates_exception( + dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer +) -> None: + """Test that exceptions from RPC channel are propagated in set_dnd_timer.""" + from roborock.exceptions import RoborockException + + # Setup mock to raise an exception + mock_rpc_channel.send_command.side_effect = RoborockException("Communication error") + + # Verify the exception is propagated + with pytest.raises(RoborockException, match="Communication error"): + await dnd_trait.set_dnd_timer(sample_dnd_timer) + + +async def test_clear_dnd_timer_propagates_exception(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None: + """Test that exceptions from RPC channel are propagated in clear_dnd_timer.""" + from roborock.exceptions import RoborockException + + # Setup mock to raise an exception + mock_rpc_channel.send_command.side_effect = RoborockException("Communication error") + + # Verify the exception is propagated + with pytest.raises(RoborockException, match="Communication error"): + await dnd_trait.clear_dnd_timer() From 1286b9b7c330efe2ab6720de075b10aedfe8e332 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 13 Sep 2025 13:24:05 -0700 Subject: [PATCH 2/3] chore: Simplify command sending --- roborock/devices/traits/dnd.py | 8 ++++---- roborock/devices/traits/status.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/roborock/devices/traits/dnd.py b/roborock/devices/traits/dnd.py index 476ecb33..1fad3b97 100644 --- a/roborock/devices/traits/dnd.py +++ b/roborock/devices/traits/dnd.py @@ -27,16 +27,16 @@ class DoNotDisturbTrait(Trait): def __init__(self, rpc_channel: Callable[[], V1RpcChannel]) -> None: """Initialize the DoNotDisturbTrait.""" - self._rpc_channel = rpc_channel + self._send_command = lambda *args, **kwargs: rpc_channel().send_command(*args, **kwargs) async def get_dnd_timer(self) -> DnDTimer: """Get the current Do Not Disturb (DND) timer settings of the device.""" - return await self._rpc_channel().send_command(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer) + return await self._send_command(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer) async def set_dnd_timer(self, dnd_timer: DnDTimer) -> None: """Set the Do Not Disturb (DND) timer settings of the device.""" - await self._rpc_channel().send_command(RoborockCommand.SET_DND_TIMER, params=dnd_timer.as_dict()) + await self._send_command(RoborockCommand.SET_DND_TIMER, params=dnd_timer.as_dict()) async def clear_dnd_timer(self) -> None: """Clear the Do Not Disturb (DND) timer settings of the device.""" - await self._rpc_channel().send_command(RoborockCommand.CLOSE_DND_TIMER) + await self._send_command(RoborockCommand.CLOSE_DND_TIMER) diff --git a/roborock/devices/traits/status.py b/roborock/devices/traits/status.py index ce837d27..bb6b3aed 100644 --- a/roborock/devices/traits/status.py +++ b/roborock/devices/traits/status.py @@ -33,7 +33,7 @@ class StatusTrait(Trait): def __init__(self, product_info: HomeDataProduct, rpc_channel: Callable[[], V1RpcChannel]) -> None: """Initialize the StatusTrait.""" self._product_info = product_info - self._rpc_channel = rpc_channel + self._send_command = lambda *args, **kwargs: rpc_channel().send_command(*args, **kwargs) async def get_status(self) -> Status: """Get the current status of the device. @@ -41,4 +41,4 @@ async def get_status(self) -> Status: This is a placeholder command and will likely be changed/moved in the future. """ status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus) - return await self._rpc_channel().send_command(RoborockCommand.GET_STATUS, response_type=status_type) + return await self._send_command(RoborockCommand.GET_STATUS, response_type=status_type) From 6199707e1af43a5329e6de64a642e71171c85d3e Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 14 Sep 2025 09:30:53 -0700 Subject: [PATCH 3/3] chore: revert changes to rpc channel --- roborock/devices/traits/dnd.py | 11 +++++------ roborock/devices/traits/status.py | 7 +++---- roborock/devices/v1_channel.py | 7 ++----- tests/devices/test_v1_channel.py | 12 ++++++------ tests/devices/test_v1_device.py | 2 +- tests/devices/traits/test_dnd.py | 12 +++--------- 6 files changed, 20 insertions(+), 31 deletions(-) diff --git a/roborock/devices/traits/dnd.py b/roborock/devices/traits/dnd.py index 1fad3b97..0c57717e 100644 --- a/roborock/devices/traits/dnd.py +++ b/roborock/devices/traits/dnd.py @@ -5,7 +5,6 @@ """ import logging -from collections.abc import Callable from roborock.containers import DnDTimer from roborock.devices.v1_rpc_channel import V1RpcChannel @@ -25,18 +24,18 @@ class DoNotDisturbTrait(Trait): name = "do_not_disturb" - def __init__(self, rpc_channel: Callable[[], V1RpcChannel]) -> None: + def __init__(self, rpc_channel: V1RpcChannel) -> None: """Initialize the DoNotDisturbTrait.""" - self._send_command = lambda *args, **kwargs: rpc_channel().send_command(*args, **kwargs) + self._rpc_channel = rpc_channel async def get_dnd_timer(self) -> DnDTimer: """Get the current Do Not Disturb (DND) timer settings of the device.""" - return await self._send_command(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer) + return await self._rpc_channel.send_command(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer) async def set_dnd_timer(self, dnd_timer: DnDTimer) -> None: """Set the Do Not Disturb (DND) timer settings of the device.""" - await self._send_command(RoborockCommand.SET_DND_TIMER, params=dnd_timer.as_dict()) + await self._rpc_channel.send_command(RoborockCommand.SET_DND_TIMER, params=dnd_timer.as_dict()) async def clear_dnd_timer(self) -> None: """Clear the Do Not Disturb (DND) timer settings of the device.""" - await self._send_command(RoborockCommand.CLOSE_DND_TIMER) + await self._rpc_channel.send_command(RoborockCommand.CLOSE_DND_TIMER) diff --git a/roborock/devices/traits/status.py b/roborock/devices/traits/status.py index bb6b3aed..e9a210d0 100644 --- a/roborock/devices/traits/status.py +++ b/roborock/devices/traits/status.py @@ -5,7 +5,6 @@ """ import logging -from collections.abc import Callable from roborock.containers import ( HomeDataProduct, @@ -30,10 +29,10 @@ class StatusTrait(Trait): name = "status" - def __init__(self, product_info: HomeDataProduct, rpc_channel: Callable[[], V1RpcChannel]) -> None: + def __init__(self, product_info: HomeDataProduct, rpc_channel: V1RpcChannel) -> None: """Initialize the StatusTrait.""" self._product_info = product_info - self._send_command = lambda *args, **kwargs: rpc_channel().send_command(*args, **kwargs) + self._rpc_channel = rpc_channel async def get_status(self) -> Status: """Get the current status of the device. @@ -41,4 +40,4 @@ async def get_status(self) -> Status: This is a placeholder command and will likely be changed/moved in the future. """ status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus) - return await self._send_command(RoborockCommand.GET_STATUS, response_type=status_type) + return await self._rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type) diff --git a/roborock/devices/v1_channel.py b/roborock/devices/v1_channel.py index cda5e40a..b7371e33 100644 --- a/roborock/devices/v1_channel.py +++ b/roborock/devices/v1_channel.py @@ -81,12 +81,9 @@ def is_mqtt_connected(self) -> bool: """Return whether MQTT connection is available.""" return self._mqtt_unsub is not None and self._mqtt_channel.is_connected + @property def rpc_channel(self) -> V1RpcChannel: - """Return the combined RPC channel prefers local with a fallback to MQTT. - - This is dynamic based on the current connection status. That is, it may return - a different channel depending on whether local or MQTT is available. - """ + """Return the combined RPC channel prefers local with a fallback to MQTT.""" return self._combined_rpc_channel or self._mqtt_rpc_channel @property diff --git a/tests/devices/test_v1_channel.py b/tests/devices/test_v1_channel.py index ff729fb2..d7711f42 100644 --- a/tests/devices/test_v1_channel.py +++ b/tests/devices/test_v1_channel.py @@ -254,7 +254,7 @@ async def test_v1_channel_send_command_local_preferred( # Send command mock_local_channel.response_queue.append(TEST_RESPONSE) - result = await v1_channel.rpc_channel().send_command( + result = await v1_channel.rpc_channel.send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, ) @@ -280,7 +280,7 @@ async def test_v1_channel_send_command_local_fails( # Send command with pytest.raises(RoborockException, match="Local failed"): - await v1_channel.rpc_channel().send_command( + await v1_channel.rpc_channel.send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, ) @@ -300,7 +300,7 @@ async def test_v1_channel_send_decoded_command_mqtt_only( # Send command mock_mqtt_channel.response_queue.append(TEST_RESPONSE) - result = await v1_channel.rpc_channel().send_command( + result = await v1_channel.rpc_channel.send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, ) @@ -322,7 +322,7 @@ async def test_v1_channel_send_decoded_command_with_params( # Send command with params mock_local_channel.response_queue.append(TEST_RESPONSE) test_params = {"volume": 80} - await v1_channel.rpc_channel().send_command( + await v1_channel.rpc_channel.send_command( RoborockCommand.CHANGE_SOUND_VOLUME, response_type=S5MaxStatus, params=test_params, @@ -444,7 +444,7 @@ async def test_v1_channel_command_encoding_validation( # Send local command and capture the request mock_local_channel.response_queue.append(TEST_RESPONSE_2) - await v1_channel.rpc_channel().send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50}) + await v1_channel.rpc_channel.send_command(RoborockCommand.CHANGE_SOUND_VOLUME, params={"volume": 50}) assert mock_local_channel.published_messages local_message = mock_local_channel.published_messages[0] @@ -512,7 +512,7 @@ async def test_v1_channel_full_subscribe_and_command_flow( # Send a command (should use local) mock_local_channel.response_queue.append(TEST_RESPONSE) - result = await v1_channel.rpc_channel().send_command( + result = await v1_channel.rpc_channel.send_command( RoborockCommand.GET_STATUS, response_type=S5MaxStatus, ) diff --git a/tests/devices/test_v1_device.py b/tests/devices/test_v1_device.py index c50932a8..10ef8073 100644 --- a/tests/devices/test_v1_device.py +++ b/tests/devices/test_v1_device.py @@ -44,7 +44,7 @@ def traits_fixture(rpc_channel: AsyncMock) -> list[Trait]: return [ StatusTrait( product_info=HOME_DATA.products[0], - rpc_channel=lambda: rpc_channel, + rpc_channel=rpc_channel, ) ] diff --git a/tests/devices/traits/test_dnd.py b/tests/devices/traits/test_dnd.py index 38e5c07d..fddc55cf 100644 --- a/tests/devices/traits/test_dnd.py +++ b/tests/devices/traits/test_dnd.py @@ -1,6 +1,6 @@ """Tests for the DoNotDisturbTrait class.""" -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock import pytest @@ -20,15 +20,9 @@ def mock_rpc_channel() -> AsyncMock: @pytest.fixture -def mock_rpc_channel_callable(mock_rpc_channel: AsyncMock) -> Mock: - """Create a callable that returns the mock RPC channel.""" - return Mock(return_value=mock_rpc_channel) - - -@pytest.fixture -def dnd_trait(mock_rpc_channel_callable: Mock) -> DoNotDisturbTrait: +def dnd_trait(mock_rpc_channel: AsyncMock) -> DoNotDisturbTrait: """Create a DoNotDisturbTrait instance with mocked dependencies.""" - return DoNotDisturbTrait(mock_rpc_channel_callable) + return DoNotDisturbTrait(mock_rpc_channel) @pytest.fixture