Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions roborock/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def from_dict(cls, data: dict[str, Any]):
return None
field_types = {field.name: field.type for field in dataclasses.fields(cls)}
result: dict[str, Any] = {}
for key, value in data.items():
key = _decamelize(key)
for orig_key, value in data.items():
key = _decamelize(orig_key)
if (field_type := field_types.get(key)) is None:
continue
if value == "None" or value is None:
Expand Down Expand Up @@ -178,16 +178,18 @@ class RoborockBaseTimer(RoborockBase):
end_hour: int | None = None
end_minute: int | None = None
enabled: int | None = None
start_time: datetime.time | None = None
end_time: datetime.time | None = None

def __post_init__(self) -> None:
self.start_time = (
@property
def start_time(self) -> datetime.time | None:
return (
datetime.time(hour=self.start_hour, minute=self.start_minute)
if self.start_hour is not None and self.start_minute is not None
else None
)
self.end_time = (

@property
def end_time(self) -> datetime.time | None:
return (
datetime.time(hour=self.end_hour, minute=self.end_minute)
if self.end_hour is not None and self.end_minute is not None
else None
Expand Down
2 changes: 2 additions & 0 deletions roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .channel import Channel
from .mqtt_channel import create_mqtt_channel
from .traits.b01.props import B01PropsApi
from .traits.dnd import DoNotDisturbTrait
from .traits.dyad import DyadApi
from .traits.status import StatusTrait
from .traits.trait import Trait
Expand Down Expand Up @@ -152,6 +153,7 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
case DeviceVersion.V1:
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache)
traits.append(StatusTrait(product, channel.rpc_channel))
traits.append(DoNotDisturbTrait(channel.rpc_channel))
case DeviceVersion.A01:
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
match product.category:
Expand Down
41 changes: 41 additions & 0 deletions roborock/devices/traits/dnd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Module for Roborock V1 devices.

This interface is experimental and subject to breaking changes without notice
until the API is stable.
"""

import logging

from roborock.containers import DnDTimer
from roborock.devices.v1_rpc_channel import V1RpcChannel
from roborock.roborock_typing import RoborockCommand

from .trait import Trait

_LOGGER = logging.getLogger(__name__)

__all__ = [
"DoNotDisturbTrait",
]


class DoNotDisturbTrait(Trait):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Zooming out a bit, where do you see the data from get dnd timer being stored? In the device object? Still in the trait somehow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know yet?

My thinking is on this next step is: the current set of commands isn't complex enough to have any real use cases, so I want to start adding them in. The whole trait syntax is not good yet, and needs to be rewritten, with a few examples. I was only thinking in the context of adding traits to the CLI for now.

One question i'm wondering if is if data needs to be stored at all here? but yeah i think it probably will.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we will need data stored somewhere for simplicity, but i'm open to other solutions.

There could be cases where one trait has multiple entities relying on it, so we don't want to call update() for each entity.

But 100% fine with what you're saying for now, fine with punting this down the road.

"""Trait for managing Do Not Disturb (DND) settings on Roborock devices."""

name = "do_not_disturb"

def __init__(self, rpc_channel: V1RpcChannel) -> None:
"""Initialize the DoNotDisturbTrait."""
self._rpc_channel = rpc_channel

async def get_dnd_timer(self) -> DnDTimer:
"""Get the current Do Not Disturb (DND) timer settings of the device."""
return await self._rpc_channel.send_command(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer)

async def set_dnd_timer(self, dnd_timer: DnDTimer) -> None:
"""Set the Do Not Disturb (DND) timer settings of the device."""
await self._rpc_channel.send_command(RoborockCommand.SET_DND_TIMER, params=dnd_timer.as_dict())

async def clear_dnd_timer(self) -> None:
"""Clear the Do Not Disturb (DND) timer settings of the device."""
await self._rpc_channel.send_command(RoborockCommand.CLOSE_DND_TIMER)
6 changes: 3 additions & 3 deletions roborock/devices/traits/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@
S7MaxVStatus,
Status,
)
from roborock.devices.v1_rpc_channel import V1RpcChannel
from roborock.roborock_typing import RoborockCommand

from ..v1_rpc_channel import V1RpcChannel
from .trait import Trait

_LOGGER = logging.getLogger(__name__)

__all__ = [
"Status",
"StatusTrait",
]


class StatusTrait(Trait):
"""Unified Roborock device class with automatic connection setup."""
"""Trait for managing the status of Roborock devices."""

name = "status"

Expand Down
8 changes: 6 additions & 2 deletions roborock/devices/v1_rpc_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,21 @@ async def _send_raw_command(
params: ParamsType = None,
) -> Any:
"""Send a command and return a parsed response RoborockBase type."""
_LOGGER.debug("Sending command (%s): %s, params=%s", self._name, method, params)
request_message = RequestMessage(method, params=params)
_LOGGER.debug(
"Sending command (%s, request_id=%s): %s, params=%s", self._name, request_message.request_id, method, params
)
message = self._payload_encoder(request_message)

future: asyncio.Future[dict[str, Any]] = asyncio.Future()

def find_response(response_message: RoborockMessage) -> None:
try:
decoded = decode_rpc_response(response_message)
except RoborockException:
except RoborockException as ex:
_LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex)
return
_LOGGER.debug("Received response (request_id=%s): %s", self._name, decoded.request_id)
if decoded.request_id == request_message.request_id:
future.set_result(decoded.data)

Expand Down
4 changes: 3 additions & 1 deletion roborock/protocols/v1_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class ResponseMessage:
def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
"""Decode a V1 RPC_RESPONSE message."""
if not message.payload:
raise RoborockException("Invalid V1 message format: missing payload")
return ResponseMessage(request_id=message.seq, data={})
try:
payload = json.loads(message.payload.decode())
except (json.JSONDecodeError, TypeError) as e:
Expand Down Expand Up @@ -141,6 +141,8 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
_LOGGER.debug("Decoded V1 message result: %s", result)
if isinstance(result, list) and result:
result = result[0]
if isinstance(result, str) and result == "ok":
result = {}
if not isinstance(result, dict):
raise RoborockException(f"Invalid V1 message format: 'result' should be a dictionary for {message.payload!r}")
return ResponseMessage(request_id=request_id, data=result)
Expand Down
1 change: 1 addition & 0 deletions tests/devices/traits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for device traits."""
147 changes: 147 additions & 0 deletions tests/devices/traits/test_dnd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""Tests for the DoNotDisturbTrait class."""

from unittest.mock import AsyncMock

import pytest

from roborock.containers import DnDTimer
from roborock.devices.traits.dnd import DoNotDisturbTrait
from roborock.devices.v1_rpc_channel import V1RpcChannel
from roborock.roborock_typing import RoborockCommand


@pytest.fixture
def mock_rpc_channel() -> AsyncMock:
"""Create a mock RPC channel."""
mock_channel = AsyncMock(spec=V1RpcChannel)
# Ensure send_command is an AsyncMock that returns awaitable coroutines
mock_channel.send_command = AsyncMock()
return mock_channel


@pytest.fixture
def dnd_trait(mock_rpc_channel: AsyncMock) -> DoNotDisturbTrait:
"""Create a DoNotDisturbTrait instance with mocked dependencies."""
return DoNotDisturbTrait(mock_rpc_channel)


@pytest.fixture
def sample_dnd_timer() -> DnDTimer:
"""Create a sample DnDTimer for testing."""
return DnDTimer(
start_hour=22,
start_minute=0,
end_hour=8,
end_minute=0,
enabled=1,
)


def test_trait_name(dnd_trait: DoNotDisturbTrait) -> None:
"""Test that the trait has the correct name."""
assert dnd_trait.name == "do_not_disturb"


async def test_get_dnd_timer_success(
dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer
) -> None:
"""Test successfully getting DnD timer settings."""
# Setup mock to return the sample DnD timer
mock_rpc_channel.send_command.return_value = sample_dnd_timer

# Call the method
result = await dnd_trait.get_dnd_timer()

# Verify the result
assert result == sample_dnd_timer
assert result.start_hour == 22
assert result.start_minute == 0
assert result.end_hour == 8
assert result.end_minute == 0
assert result.enabled == 1

# Verify the RPC call was made correctly
mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer)


async def test_get_dnd_timer_disabled(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None:
"""Test getting DnD timer when it's disabled."""
disabled_timer = DnDTimer(
start_hour=22,
start_minute=0,
end_hour=8,
end_minute=0,
enabled=0,
)
mock_rpc_channel.send_command.return_value = disabled_timer

result = await dnd_trait.get_dnd_timer()

assert result.enabled == 0
mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.GET_DND_TIMER, response_type=DnDTimer)


async def test_set_dnd_timer_success(
dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer
) -> None:
"""Test successfully setting DnD timer settings."""
# Call the method
await dnd_trait.set_dnd_timer(sample_dnd_timer)

# Verify the RPC call was made correctly with dataclass converted to dict

expected_params = {
"startHour": 22,
"startMinute": 0,
"endHour": 8,
"endMinute": 0,
"enabled": 1,
}
mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.SET_DND_TIMER, params=expected_params)


async def test_clear_dnd_timer_success(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None:
"""Test successfully clearing DnD timer settings."""
# Call the method
await dnd_trait.clear_dnd_timer()

# Verify the RPC call was made correctly
mock_rpc_channel.send_command.assert_called_once_with(RoborockCommand.CLOSE_DND_TIMER)


async def test_get_dnd_timer_propagates_exception(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None:
"""Test that exceptions from RPC channel are propagated in get_dnd_timer."""
from roborock.exceptions import RoborockException

# Setup mock to raise an exception
mock_rpc_channel.send_command.side_effect = RoborockException("Communication error")

# Verify the exception is propagated
with pytest.raises(RoborockException, match="Communication error"):
await dnd_trait.get_dnd_timer()


async def test_set_dnd_timer_propagates_exception(
dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock, sample_dnd_timer: DnDTimer
) -> None:
"""Test that exceptions from RPC channel are propagated in set_dnd_timer."""
from roborock.exceptions import RoborockException

# Setup mock to raise an exception
mock_rpc_channel.send_command.side_effect = RoborockException("Communication error")

# Verify the exception is propagated
with pytest.raises(RoborockException, match="Communication error"):
await dnd_trait.set_dnd_timer(sample_dnd_timer)


async def test_clear_dnd_timer_propagates_exception(dnd_trait: DoNotDisturbTrait, mock_rpc_channel: AsyncMock) -> None:
"""Test that exceptions from RPC channel are propagated in clear_dnd_timer."""
from roborock.exceptions import RoborockException

# Setup mock to raise an exception
mock_rpc_channel.send_command.side_effect = RoborockException("Communication error")

# Verify the exception is propagated
with pytest.raises(RoborockException, match="Communication error"):
await dnd_trait.clear_dnd_timer()