Skip to content

Commit cce1c1b

Browse files
authored
feat: Add a DnD trait and fix bugs in the rpc channels (#471)
* feat: Add a DnD trait and fix bugs in the rpc channels * chore: Simplify command sending * chore: revert changes to rpc channel
1 parent b227911 commit cce1c1b

File tree

8 files changed

+212
-13
lines changed

8 files changed

+212
-13
lines changed

roborock/containers.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ def from_dict(cls, data: dict[str, Any]):
134134
return None
135135
field_types = {field.name: field.type for field in dataclasses.fields(cls)}
136136
result: dict[str, Any] = {}
137-
for key, value in data.items():
138-
key = _decamelize(key)
137+
for orig_key, value in data.items():
138+
key = _decamelize(orig_key)
139139
if (field_type := field_types.get(key)) is None:
140140
continue
141141
if value == "None" or value is None:
@@ -178,16 +178,18 @@ class RoborockBaseTimer(RoborockBase):
178178
end_hour: int | None = None
179179
end_minute: int | None = None
180180
enabled: int | None = None
181-
start_time: datetime.time | None = None
182-
end_time: datetime.time | None = None
183181

184-
def __post_init__(self) -> None:
185-
self.start_time = (
182+
@property
183+
def start_time(self) -> datetime.time | None:
184+
return (
186185
datetime.time(hour=self.start_hour, minute=self.start_minute)
187186
if self.start_hour is not None and self.start_minute is not None
188187
else None
189188
)
190-
self.end_time = (
189+
190+
@property
191+
def end_time(self) -> datetime.time | None:
192+
return (
191193
datetime.time(hour=self.end_hour, minute=self.end_minute)
192194
if self.end_hour is not None and self.end_minute is not None
193195
else None

roborock/devices/device_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .channel import Channel
2323
from .mqtt_channel import create_mqtt_channel
2424
from .traits.b01.props import B01PropsApi
25+
from .traits.dnd import DoNotDisturbTrait
2526
from .traits.dyad import DyadApi
2627
from .traits.status import StatusTrait
2728
from .traits.trait import Trait
@@ -152,6 +153,7 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
152153
case DeviceVersion.V1:
153154
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache)
154155
traits.append(StatusTrait(product, channel.rpc_channel))
156+
traits.append(DoNotDisturbTrait(channel.rpc_channel))
155157
case DeviceVersion.A01:
156158
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
157159
match product.category:

roborock/devices/traits/dnd.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Module for Roborock V1 devices.
2+
3+
This interface is experimental and subject to breaking changes without notice
4+
until the API is stable.
5+
"""
6+
7+
import logging
8+
9+
from roborock.containers import DnDTimer
10+
from roborock.devices.v1_rpc_channel import V1RpcChannel
11+
from roborock.roborock_typing import RoborockCommand
12+
13+
from .trait import Trait
14+
15+
_LOGGER = logging.getLogger(__name__)
16+
17+
__all__ = [
18+
"DoNotDisturbTrait",
19+
]
20+
21+
22+
class DoNotDisturbTrait(Trait):
23+
"""Trait for managing Do Not Disturb (DND) settings on Roborock devices."""
24+
25+
name = "do_not_disturb"
26+
27+
def __init__(self, rpc_channel: V1RpcChannel) -> None:
28+
"""Initialize the DoNotDisturbTrait."""
29+
self._rpc_channel = rpc_channel
30+
31+
async def get_dnd_timer(self) -> DnDTimer:
32+
"""Get the current Do Not Disturb (DND) timer settings of the device."""
33+
return await self._rpc_channel.send_command(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer)
34+
35+
async def set_dnd_timer(self, dnd_timer: DnDTimer) -> None:
36+
"""Set the Do Not Disturb (DND) timer settings of the device."""
37+
await self._rpc_channel.send_command(RoborockCommand.SET_DND_TIMER, params=dnd_timer.as_dict())
38+
39+
async def clear_dnd_timer(self) -> None:
40+
"""Clear the Do Not Disturb (DND) timer settings of the device."""
41+
await self._rpc_channel.send_command(RoborockCommand.CLOSE_DND_TIMER)

roborock/devices/traits/status.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,20 @@
1212
S7MaxVStatus,
1313
Status,
1414
)
15+
from roborock.devices.v1_rpc_channel import V1RpcChannel
1516
from roborock.roborock_typing import RoborockCommand
1617

17-
from ..v1_rpc_channel import V1RpcChannel
1818
from .trait import Trait
1919

2020
_LOGGER = logging.getLogger(__name__)
2121

2222
__all__ = [
23-
"Status",
23+
"StatusTrait",
2424
]
2525

2626

2727
class StatusTrait(Trait):
28-
"""Unified Roborock device class with automatic connection setup."""
28+
"""Trait for managing the status of Roborock devices."""
2929

3030
name = "status"
3131

roborock/devices/v1_rpc_channel.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,17 +132,21 @@ async def _send_raw_command(
132132
params: ParamsType = None,
133133
) -> Any:
134134
"""Send a command and return a parsed response RoborockBase type."""
135-
_LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params)
136135
request_message = RequestMessage(method, params=params)
136+
_LOGGER.debug(
137+
"Sending command (%s, request_id=%s): %s, params=%s", self._name, request_message.request_id, method, params
138+
)
137139
message = self._payload_encoder(request_message)
138140

139141
future: asyncio.Future[dict[str, Any]] = asyncio.Future()
140142

141143
def find_response(response_message: RoborockMessage) -> None:
142144
try:
143145
decoded = decode_rpc_response(response_message)
144-
except RoborockException:
146+
except RoborockException as ex:
147+
_LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex)
145148
return
149+
_LOGGER.debug("Received response (request_id=%s): %s", self._name, decoded.request_id)
146150
if decoded.request_id == request_message.request_id:
147151
future.set_result(decoded.data)
148152

roborock/protocols/v1_protocol.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class ResponseMessage:
109109
def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
110110
"""Decode a V1 RPC_RESPONSE message."""
111111
if not message.payload:
112-
raise RoborockException("Invalid V1 message format: missing payload")
112+
return ResponseMessage(request_id=message.seq, data={})
113113
try:
114114
payload = json.loads(message.payload.decode())
115115
except (json.JSONDecodeError, TypeError) as e:
@@ -141,6 +141,8 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
141141
_LOGGER.debug("Decoded V1 message result: %s", result)
142142
if isinstance(result, list) and result:
143143
result = result[0]
144+
if isinstance(result, str) and result == "ok":
145+
result = {}
144146
if not isinstance(result, dict):
145147
raise RoborockException(f"Invalid V1 message format: 'result' should be a dictionary for {message.payload!r}")
146148
return ResponseMessage(request_id=request_id, data=result)

tests/devices/traits/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for device traits."""

tests/devices/traits/test_dnd.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""Tests for the DoNotDisturbTrait class."""
2+
3+
from unittest.mock import AsyncMock
4+
5+
import pytest
6+
7+
from roborock.containers import DnDTimer
8+
from roborock.devices.traits.dnd import DoNotDisturbTrait
9+
from roborock.devices.v1_rpc_channel import V1RpcChannel
10+
from roborock.roborock_typing import RoborockCommand
11+
12+
13+
@pytest.fixture
14+
def mock_rpc_channel() -> AsyncMock:
15+
"""Create a mock RPC channel."""
16+
mock_channel = AsyncMock(spec=V1RpcChannel)
17+
# Ensure send_command is an AsyncMock that returns awaitable coroutines
18+
mock_channel.send_command = AsyncMock()
19+
return mock_channel
20+
21+
22+
@pytest.fixture
23+
def dnd_trait(mock_rpc_channel: AsyncMock) -> DoNotDisturbTrait:
24+
"""Create a DoNotDisturbTrait instance with mocked dependencies."""
25+
return DoNotDisturbTrait(mock_rpc_channel)
26+
27+
28+
@pytest.fixture
29+
def sample_dnd_timer() -> DnDTimer:
30+
"""Create a sample DnDTimer for testing."""
31+
return DnDTimer(
32+
start_hour=22,
33+
start_minute=0,
34+
end_hour=8,
35+
end_minute=0,
36+
enabled=1,
37+
)
38+
39+
40+
def test_trait_name(dnd_trait: DoNotDisturbTrait) -> None:
41+
"""Test that the trait has the correct name."""
42+
assert dnd_trait.name == "do_not_disturb"
43+
44+
45+
async def test_get_dnd_timer_success(
46+
dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer
47+
) -> None:
48+
"""Test successfully getting DnD timer settings."""
49+
# Setup mock to return the sample DnD timer
50+
mock_rpc_channel.send_command.return_value = sample_dnd_timer
51+
52+
# Call the method
53+
result = await dnd_trait.get_dnd_timer()
54+
55+
# Verify the result
56+
assert result == sample_dnd_timer
57+
assert result.start_hour == 22
58+
assert result.start_minute == 0
59+
assert result.end_hour == 8
60+
assert result.end_minute == 0
61+
assert result.enabled == 1
62+
63+
# Verify the RPC call was made correctly
64+
mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer)
65+
66+
67+
async def test_get_dnd_timer_disabled(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None:
68+
"""Test getting DnD timer when it's disabled."""
69+
disabled_timer = DnDTimer(
70+
start_hour=22,
71+
start_minute=0,
72+
end_hour=8,
73+
end_minute=0,
74+
enabled=0,
75+
)
76+
mock_rpc_channel.send_command.return_value = disabled_timer
77+
78+
result = await dnd_trait.get_dnd_timer()
79+
80+
assert result.enabled == 0
81+
mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer)
82+
83+
84+
async def test_set_dnd_timer_success(
85+
dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer
86+
) -> None:
87+
"""Test successfully setting DnD timer settings."""
88+
# Call the method
89+
await dnd_trait.set_dnd_timer(sample_dnd_timer)
90+
91+
# Verify the RPC call was made correctly with dataclass converted to dict
92+
93+
expected_params = {
94+
"startHour": 22,
95+
"startMinute": 0,
96+
"endHour": 8,
97+
"endMinute": 0,
98+
"enabled": 1,
99+
}
100+
mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.SET_DND_TIMER, params=expected_params)
101+
102+
103+
async def test_clear_dnd_timer_success(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None:
104+
"""Test successfully clearing DnD timer settings."""
105+
# Call the method
106+
await dnd_trait.clear_dnd_timer()
107+
108+
# Verify the RPC call was made correctly
109+
mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.CLOSE_DND_TIMER)
110+
111+
112+
async def test_get_dnd_timer_propagates_exception(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None:
113+
"""Test that exceptions from RPC channel are propagated in get_dnd_timer."""
114+
from roborock.exceptions import RoborockException
115+
116+
# Setup mock to raise an exception
117+
mock_rpc_channel.send_command.side_effect = RoborockException("Communication error")
118+
119+
# Verify the exception is propagated
120+
with pytest.raises(RoborockException, match="Communication error"):
121+
await dnd_trait.get_dnd_timer()
122+
123+
124+
async def test_set_dnd_timer_propagates_exception(
125+
dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer
126+
) -> None:
127+
"""Test that exceptions from RPC channel are propagated in set_dnd_timer."""
128+
from roborock.exceptions import RoborockException
129+
130+
# Setup mock to raise an exception
131+
mock_rpc_channel.send_command.side_effect = RoborockException("Communication error")
132+
133+
# Verify the exception is propagated
134+
with pytest.raises(RoborockException, match="Communication error"):
135+
await dnd_trait.set_dnd_timer(sample_dnd_timer)
136+
137+
138+
async def test_clear_dnd_timer_propagates_exception(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None:
139+
"""Test that exceptions from RPC channel are propagated in clear_dnd_timer."""
140+
from roborock.exceptions import RoborockException
141+
142+
# Setup mock to raise an exception
143+
mock_rpc_channel.send_command.side_effect = RoborockException("Communication error")
144+
145+
# Verify the exception is propagated
146+
with pytest.raises(RoborockException, match="Communication error"):
147+
await dnd_trait.clear_dnd_timer()

0 commit comments

Comments
 (0)