diff --git a/roborock/containers.py b/roborock/containers.py index 4e78c8f0..8f08da15 100644 --- a/roborock/containers.py +++ b/roborock/containers.py @@ -134,8 +134,8 @@ def from_dict(cls, data: dict[str, Any]): return None field_types = {field.name: field.type for field in dataclasses.fields(cls)} result: dict[str, Any] = {} - for key, value in data.items(): - key = _decamelize(key) + for orig_key, value in data.items(): + key = _decamelize(orig_key) if (field_type := field_types.get(key)) is None: continue if value == "None" or value is None: @@ -178,16 +178,18 @@ class RoborockBaseTimer(RoborockBase): end_hour: int | None = None end_minute: int | None = None enabled: int | None = None - start_time: datetime.time | None = None - end_time: datetime.time | None = None - def __post_init__(self) -> None: - self.start_time = ( + @property + def start_time(self) -> datetime.time | None: + return ( datetime.time(hour=self.start_hour, minute=self.start_minute) if self.start_hour is not None and self.start_minute is not None else None ) - self.end_time = ( + + @property + def end_time(self) -> datetime.time | None: + return ( datetime.time(hour=self.end_hour, minute=self.end_minute) if self.end_hour is not None and self.end_minute is not None else None diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 52b03b01..853a5f57 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -22,6 +22,7 @@ from .channel import Channel from .mqtt_channel import create_mqtt_channel from .traits.b01.props import B01PropsApi +from .traits.dnd import DoNotDisturbTrait from .traits.dyad import DyadApi from .traits.status import StatusTrait from .traits.trait import Trait @@ -152,6 +153,7 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock case DeviceVersion.V1: channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache) traits.append(StatusTrait(product, channel.rpc_channel)) + traits.append(DoNotDisturbTrait(channel.rpc_channel)) case DeviceVersion.A01: mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device) match product.category: diff --git a/roborock/devices/traits/dnd.py b/roborock/devices/traits/dnd.py new file mode 100644 index 00000000..0c57717e --- /dev/null +++ b/roborock/devices/traits/dnd.py @@ -0,0 +1,41 @@ +"""Module for Roborock V1 devices. + +This interface is experimental and subject to breaking changes without notice +until the API is stable. +""" + +import logging + +from roborock.containers import DnDTimer +from roborock.devices.v1_rpc_channel import V1RpcChannel +from roborock.roborock_typing import RoborockCommand + +from .trait import Trait + +_LOGGER = logging.getLogger(__name__) + +__all__ = [ + "DoNotDisturbTrait", +] + + +class DoNotDisturbTrait(Trait): + """Trait for managing Do Not Disturb (DND) settings on Roborock devices.""" + + name = "do_not_disturb" + + def __init__(self, rpc_channel: V1RpcChannel) -> None: + """Initialize the DoNotDisturbTrait.""" + self._rpc_channel = rpc_channel + + async def get_dnd_timer(self) -> DnDTimer: + """Get the current Do Not Disturb (DND) timer settings of the device.""" + return await self._rpc_channel.send_command(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer) + + async def set_dnd_timer(self, dnd_timer: DnDTimer) -> None: + """Set the Do Not Disturb (DND) timer settings of the device.""" + await self._rpc_channel.send_command(RoborockCommand.SET_DND_TIMER, params=dnd_timer.as_dict()) + + async def clear_dnd_timer(self) -> None: + """Clear the Do Not Disturb (DND) timer settings of the device.""" + await self._rpc_channel.send_command(RoborockCommand.CLOSE_DND_TIMER) diff --git a/roborock/devices/traits/status.py b/roborock/devices/traits/status.py index d7d622d9..e9a210d0 100644 --- a/roborock/devices/traits/status.py +++ b/roborock/devices/traits/status.py @@ -12,20 +12,20 @@ S7MaxVStatus, Status, ) +from roborock.devices.v1_rpc_channel import V1RpcChannel from roborock.roborock_typing import RoborockCommand -from ..v1_rpc_channel import V1RpcChannel from .trait import Trait _LOGGER = logging.getLogger(__name__) __all__ = [ - "Status", + "StatusTrait", ] class StatusTrait(Trait): - """Unified Roborock device class with automatic connection setup.""" + """Trait for managing the status of Roborock devices.""" name = "status" diff --git a/roborock/devices/v1_rpc_channel.py b/roborock/devices/v1_rpc_channel.py index 52122e08..bad24a2f 100644 --- a/roborock/devices/v1_rpc_channel.py +++ b/roborock/devices/v1_rpc_channel.py @@ -132,8 +132,10 @@ async def _send_raw_command( params: ParamsType = None, ) -> Any: """Send a command and return a parsed response RoborockBase type.""" - _LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params) request_message = RequestMessage(method, params=params) + _LOGGER.debug( + "Sending command (%s, request_id=%s): %s, params=%s", self._name, request_message.request_id, method, params + ) message = self._payload_encoder(request_message) future: asyncio.Future[dict[str, Any]] = asyncio.Future() @@ -141,8 +143,10 @@ async def _send_raw_command( def find_response(response_message: RoborockMessage) -> None: try: decoded = decode_rpc_response(response_message) - except RoborockException: + except RoborockException as ex: + _LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex) return + _LOGGER.debug("Received response (request_id=%s): %s", self._name, decoded.request_id) if decoded.request_id == request_message.request_id: future.set_result(decoded.data) diff --git a/roborock/protocols/v1_protocol.py b/roborock/protocols/v1_protocol.py index 6829b7db..8c1b70dd 100644 --- a/roborock/protocols/v1_protocol.py +++ b/roborock/protocols/v1_protocol.py @@ -109,7 +109,7 @@ class ResponseMessage: def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: """Decode a V1 RPC_RESPONSE message.""" if not message.payload: - raise RoborockException("Invalid V1 message format: missing payload") + return ResponseMessage(request_id=message.seq, data={}) try: payload = json.loads(message.payload.decode()) except (json.JSONDecodeError, TypeError) as e: @@ -141,6 +141,8 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage: _LOGGER.debug("Decoded V1 message result: %s", result) if isinstance(result, list) and result: result = result[0] + if isinstance(result, str) and result == "ok": + result = {} if not isinstance(result, dict): raise RoborockException(f"Invalid V1 message format: 'result' should be a dictionary for {message.payload!r}") return ResponseMessage(request_id=request_id, data=result) diff --git a/tests/devices/traits/__init__.py b/tests/devices/traits/__init__.py new file mode 100644 index 00000000..21bbdbcd --- /dev/null +++ b/tests/devices/traits/__init__.py @@ -0,0 +1 @@ +"""Tests for device traits.""" diff --git a/tests/devices/traits/test_dnd.py b/tests/devices/traits/test_dnd.py new file mode 100644 index 00000000..fddc55cf --- /dev/null +++ b/tests/devices/traits/test_dnd.py @@ -0,0 +1,147 @@ +"""Tests for the DoNotDisturbTrait class.""" + +from unittest.mock import AsyncMock + +import pytest + +from roborock.containers import DnDTimer +from roborock.devices.traits.dnd import DoNotDisturbTrait +from roborock.devices.v1_rpc_channel import V1RpcChannel +from roborock.roborock_typing import RoborockCommand + + +@pytest.fixture +def mock_rpc_channel() -> AsyncMock: + """Create a mock RPC channel.""" + mock_channel = AsyncMock(spec=V1RpcChannel) + # Ensure send_command is an AsyncMock that returns awaitable coroutines + mock_channel.send_command = AsyncMock() + return mock_channel + + +@pytest.fixture +def dnd_trait(mock_rpc_channel: AsyncMock) -> DoNotDisturbTrait: + """Create a DoNotDisturbTrait instance with mocked dependencies.""" + return DoNotDisturbTrait(mock_rpc_channel) + + +@pytest.fixture +def sample_dnd_timer() -> DnDTimer: + """Create a sample DnDTimer for testing.""" + return DnDTimer( + start_hour=22, + start_minute=0, + end_hour=8, + end_minute=0, + enabled=1, + ) + + +def test_trait_name(dnd_trait: DoNotDisturbTrait) -> None: + """Test that the trait has the correct name.""" + assert dnd_trait.name == "do_not_disturb" + + +async def test_get_dnd_timer_success( + dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer +) -> None: + """Test successfully getting DnD timer settings.""" + # Setup mock to return the sample DnD timer + mock_rpc_channel.send_command.return_value = sample_dnd_timer + + # Call the method + result = await dnd_trait.get_dnd_timer() + + # Verify the result + assert result == sample_dnd_timer + assert result.start_hour == 22 + assert result.start_minute == 0 + assert result.end_hour == 8 + assert result.end_minute == 0 + assert result.enabled == 1 + + # Verify the RPC call was made correctly + mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer) + + +async def test_get_dnd_timer_disabled(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None: + """Test getting DnD timer when it's disabled.""" + disabled_timer = DnDTimer( + start_hour=22, + start_minute=0, + end_hour=8, + end_minute=0, + enabled=0, + ) + mock_rpc_channel.send_command.return_value = disabled_timer + + result = await dnd_trait.get_dnd_timer() + + assert result.enabled == 0 + mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer) + + +async def test_set_dnd_timer_success( + dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer +) -> None: + """Test successfully setting DnD timer settings.""" + # Call the method + await dnd_trait.set_dnd_timer(sample_dnd_timer) + + # Verify the RPC call was made correctly with dataclass converted to dict + + expected_params = { + "startHour": 22, + "startMinute": 0, + "endHour": 8, + "endMinute": 0, + "enabled": 1, + } + mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.SET_DND_TIMER, params=expected_params) + + +async def test_clear_dnd_timer_success(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None: + """Test successfully clearing DnD timer settings.""" + # Call the method + await dnd_trait.clear_dnd_timer() + + # Verify the RPC call was made correctly + mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.CLOSE_DND_TIMER) + + +async def test_get_dnd_timer_propagates_exception(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None: + """Test that exceptions from RPC channel are propagated in get_dnd_timer.""" + from roborock.exceptions import RoborockException + + # Setup mock to raise an exception + mock_rpc_channel.send_command.side_effect = RoborockException("Communication error") + + # Verify the exception is propagated + with pytest.raises(RoborockException, match="Communication error"): + await dnd_trait.get_dnd_timer() + + +async def test_set_dnd_timer_propagates_exception( + dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer +) -> None: + """Test that exceptions from RPC channel are propagated in set_dnd_timer.""" + from roborock.exceptions import RoborockException + + # Setup mock to raise an exception + mock_rpc_channel.send_command.side_effect = RoborockException("Communication error") + + # Verify the exception is propagated + with pytest.raises(RoborockException, match="Communication error"): + await dnd_trait.set_dnd_timer(sample_dnd_timer) + + +async def test_clear_dnd_timer_propagates_exception(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None: + """Test that exceptions from RPC channel are propagated in clear_dnd_timer.""" + from roborock.exceptions import RoborockException + + # Setup mock to raise an exception + mock_rpc_channel.send_command.side_effect = RoborockException("Communication error") + + # Verify the exception is propagated + with pytest.raises(RoborockException, match="Communication error"): + await dnd_trait.clear_dnd_timer()