Skip to content

Commit 362ec1d

Browse files
authored
chore: Overhaul new device trait interfaces (#489)
* chore: Overhaul new device trait interfaces * chore: Only allow a single trait * feat: update CLI with new properties * chore: remove unnecessarily local variables * chore: add comment about rpc channel hacks and separate property files * chore: Rename b01 properties to match v1 * chore: fix return types in CleanSummaryTrait
1 parent ed46bce commit 362ec1d

27 files changed

+577
-479
lines changed

roborock/cli.py

Lines changed: 23 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import json
2828
import logging
2929
import threading
30+
from collections.abc import Callable
3031
from dataclasses import asdict, dataclass
3132
from pathlib import Path
3233
from typing import Any, cast
@@ -43,6 +44,8 @@
4344
from roborock.devices.cache import Cache, CacheData
4445
from roborock.devices.device import RoborockDevice
4546
from roborock.devices.device_manager import DeviceManager, create_device_manager, create_home_data_api
47+
from roborock.devices.traits import Trait
48+
from roborock.devices.traits.v1 import V1TraitMixin
4649
from roborock.protocol import MessageParser
4750
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
4851
from roborock.web_api import RoborockApiClient
@@ -377,23 +380,30 @@ async def execute_scene(ctx, scene_id):
377380
await client.execute_scene(cache_data.user_data, scene_id)
378381

379382

383+
async def _v1_trait(context: RoborockContext, device_id: str, display_func: Callable[[], V1TraitMixin]) -> Trait:
384+
device_manager = await context.get_device_manager()
385+
device = await device_manager.get_device(device_id)
386+
if device.v1_properties is None:
387+
raise RoborockException(f"Device {device.name} does not support V1 protocol")
388+
389+
trait = display_func(device.v1_properties)
390+
await trait.refresh()
391+
return trait
392+
393+
394+
async def _display_v1_trait(context: RoborockContext, device_id: str, display_func: Callable[[], Trait]) -> None:
395+
trait = await _v1_trait(context, device_id, display_func)
396+
click.echo(dump_json(trait.as_dict()))
397+
398+
380399
@session.command()
381400
@click.option("--device_id", required=True)
382401
@click.pass_context
383402
@async_command
384403
async def status(ctx, device_id: str):
385404
"""Get device status."""
386405
context: RoborockContext = ctx.obj
387-
388-
device_manager = await context.get_device_manager()
389-
device = await device_manager.get_device(device_id)
390-
391-
if not (status_trait := device.traits.get("status")):
392-
click.echo(f"Device {device.name} does not have a status trait")
393-
return
394-
395-
status_result = await status_trait.get_status()
396-
click.echo(dump_json(status_result.as_dict()))
406+
await _display_v1_trait(context, device_id, lambda v1: v1.status)
397407

398408

399409
@session.command()
@@ -403,15 +413,7 @@ async def status(ctx, device_id: str):
403413
async def clean_summary(ctx, device_id: str):
404414
"""Get device clean summary."""
405415
context: RoborockContext = ctx.obj
406-
407-
device_manager = await context.get_device_manager()
408-
device = await device_manager.get_device(device_id)
409-
if not (clean_summary_trait := device.traits.get("clean_summary")):
410-
click.echo(f"Device {device.name} does not have a clean summary trait")
411-
return
412-
413-
clean_summary_result = await clean_summary_trait.get_clean_summary()
414-
click.echo(dump_json(clean_summary_result.as_dict()))
416+
await _display_v1_trait(context, device_id, lambda v1: v1.clean_summary)
415417

416418

417419
@session.command()
@@ -421,17 +423,7 @@ async def clean_summary(ctx, device_id: str):
421423
async def volume(ctx, device_id: str):
422424
"""Get device volume."""
423425
context: RoborockContext = ctx.obj
424-
425-
device_manager = await context.get_device_manager()
426-
device = await device_manager.get_device(device_id)
427-
428-
if not (volume_trait := device.traits.get("sound_volume")):
429-
click.echo(f"Device {device.name} does not have a volume trait")
430-
return
431-
432-
volume_result = await volume_trait.get_volume()
433-
click.echo(f"Device {device_id} volume:")
434-
click.echo(volume_result)
426+
await _display_v1_trait(context, device_id, lambda v1: v1.sound_volume)
435427

436428

437429
@session.command()
@@ -442,14 +434,7 @@ async def volume(ctx, device_id: str):
442434
async def set_volume(ctx, device_id: str, volume: int):
443435
"""Set the devicevolume."""
444436
context: RoborockContext = ctx.obj
445-
446-
device_manager = await context.get_device_manager()
447-
device = await device_manager.get_device(device_id)
448-
449-
if not (volume_trait := device.traits.get("sound_volume")):
450-
click.echo(f"Device {device.name} does not have a volume trait")
451-
return
452-
437+
volume_trait = await _v1_trait(context, device_id, lambda v1: v1.sound_volume)
453438
await volume_trait.set_volume(volume)
454439
click.echo(f"Set Device {device_id} volume to {volume}")
455440

roborock/devices/device.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
import logging
88
from abc import ABC
9-
from collections.abc import Callable, Mapping
10-
from types import MappingProxyType
9+
from collections.abc import Callable
1110

1211
from roborock.containers import HomeDataDevice
1312
from roborock.roborock_message import RoborockMessage
1413

1514
from .channel import Channel
16-
from .traits.trait import Trait
15+
from .traits import Trait
16+
from .traits.traits_mixin import TraitsMixin
1717

1818
_LOGGER = logging.getLogger(__name__)
1919

@@ -22,33 +22,35 @@
2222
]
2323

2424

25-
class RoborockDevice(ABC):
25+
class RoborockDevice(ABC, TraitsMixin):
2626
"""A generic channel for establishing a connection with a Roborock device.
2727
2828
Individual channel implementations have their own methods for speaking to
2929
the device that hide some of the protocol specific complexity, but they
3030
are still specialized for the device type and protocol.
31+
32+
Attributes of the device are exposed through traits, which are mixed in
33+
through the TraitsMixin class. Traits are optional and may not be present
34+
on all devices.
3135
"""
3236

3337
def __init__(
3438
self,
3539
device_info: HomeDataDevice,
3640
channel: Channel,
37-
traits: list[Trait],
41+
trait: Trait,
3842
) -> None:
3943
"""Initialize the RoborockDevice.
4044
4145
The device takes ownership of the channel for communication with the device.
4246
Use `connect()` to establish the connection, which will set up the appropriate
4347
protocol channel. Use `close()` to clean up all connections.
4448
"""
49+
TraitsMixin.__init__(self, trait)
4550
self._duid = device_info.duid
4651
self._name = device_info.name
4752
self._channel = channel
4853
self._unsub: Callable[[], None] | None = None
49-
self._trait_map = {trait.name: trait for trait in traits}
50-
if len(self._trait_map) != len(traits):
51-
raise ValueError("Duplicate trait names found in traits list")
5254

5355
@property
5456
def duid(self) -> str:
@@ -81,8 +83,3 @@ async def close(self) -> None:
8183
def _on_message(self, message: RoborockMessage) -> None:
8284
"""Handle incoming messages from the device."""
8385
_LOGGER.debug("Received message from device: %s", message)
84-
85-
@property
86-
def traits(self) -> Mapping[str, Trait]:
87-
"""Return the traits of the device."""
88-
return MappingProxyType(self._trait_map)

roborock/devices/device_manager.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import aiohttp
99

10-
from roborock.code_mappings import RoborockCategory
1110
from roborock.containers import (
1211
HomeData,
1312
HomeDataDevice,
@@ -23,14 +22,7 @@
2322
from .cache import Cache, NoCache
2423
from .channel import Channel
2524
from .mqtt_channel import create_mqtt_channel
26-
from .traits.b01.props import B01PropsApi
27-
from .traits.clean_summary import CleanSummaryTrait
28-
from .traits.dnd import DoNotDisturbTrait
29-
from .traits.dyad import DyadApi
30-
from .traits.sound_volume import SoundVolumeTrait
31-
from .traits.status import StatusTrait
32-
from .traits.trait import Trait
33-
from .traits.zeo import ZeoApi
25+
from .traits import Trait, a01, b01, v1
3426
from .v1_channel import create_v1_channel
3527

3628
_LOGGER = logging.getLogger(__name__)
@@ -153,30 +145,20 @@ async def create_device_manager(
153145

154146
def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
155147
channel: Channel
156-
traits: list[Trait] = []
157-
# TODO: Define a registration mechanism/factory for v1 traits
148+
trait: Trait
158149
match device.pv:
159150
case DeviceVersion.V1:
160151
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache)
161-
traits.append(StatusTrait(product, channel.rpc_channel))
162-
traits.append(DoNotDisturbTrait(channel.rpc_channel))
163-
traits.append(CleanSummaryTrait(channel.rpc_channel))
164-
traits.append(SoundVolumeTrait(channel.rpc_channel))
152+
trait = v1.create(product, channel.rpc_channel)
165153
case DeviceVersion.A01:
166-
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
167-
match product.category:
168-
case RoborockCategory.WET_DRY_VAC:
169-
traits.append(DyadApi(mqtt_channel))
170-
case RoborockCategory.WASHING_MACHINE:
171-
traits.append(ZeoApi(mqtt_channel))
172-
case _:
173-
raise NotImplementedError(f"Device {device.name} has unsupported category {product.category}")
154+
channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
155+
trait = a01.create(product, channel)
174156
case DeviceVersion.B01:
175157
channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
176-
traits.append(B01PropsApi(channel))
158+
trait = b01.create(channel)
177159
case _:
178160
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
179-
return RoborockDevice(device, channel, traits)
161+
return RoborockDevice(device, channel, trait)
180162

181163
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session, cache=cache)
182164
await manager.discover_devices()
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Module for device traits."""
2+
3+
from abc import ABC
4+
5+
__all__ = [
6+
"Trait",
7+
"traits_mixin",
8+
"v1",
9+
"a01",
10+
"b01",
11+
]
12+
13+
14+
class Trait(ABC):
15+
"""Base class for all traits."""
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from typing import Any
2+
3+
from roborock.containers import HomeDataProduct, RoborockCategory
4+
from roborock.devices.a01_channel import send_decoded_command
5+
from roborock.devices.mqtt_channel import MqttChannel
6+
from roborock.devices.traits import Trait
7+
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol
8+
9+
__init__ = [
10+
"DyadApi",
11+
"ZeoApi",
12+
]
13+
14+
15+
class DyadApi(Trait):
16+
"""API for interacting with Dyad devices."""
17+
18+
def __init__(self, channel: MqttChannel) -> None:
19+
"""Initialize the Dyad API."""
20+
self._channel = channel
21+
22+
async def query_values(self, protocols: list[RoborockDyadDataProtocol]) -> dict[RoborockDyadDataProtocol, Any]:
23+
"""Query the device for the values of the given Dyad protocols."""
24+
params = {RoborockDyadDataProtocol.ID_QUERY: [int(p) for p in protocols]}
25+
return await send_decoded_command(self._channel, params)
26+
27+
async def set_value(self, protocol: RoborockDyadDataProtocol, value: Any) -> dict[RoborockDyadDataProtocol, Any]:
28+
"""Set a value for a specific protocol on the device."""
29+
params = {protocol: value}
30+
return await send_decoded_command(self._channel, params)
31+
32+
33+
class ZeoApi(Trait):
34+
"""API for interacting with Zeo devices."""
35+
36+
name = "zeo"
37+
38+
def __init__(self, channel: MqttChannel) -> None:
39+
"""Initialize the Zeo API."""
40+
self._channel = channel
41+
42+
async def query_values(self, protocols: list[RoborockZeoProtocol]) -> dict[RoborockZeoProtocol, Any]:
43+
"""Query the device for the values of the given protocols."""
44+
params = {RoborockZeoProtocol.ID_QUERY: [int(p) for p in protocols]}
45+
return await send_decoded_command(self._channel, params)
46+
47+
async def set_value(self, protocol: RoborockZeoProtocol, value: Any) -> dict[RoborockZeoProtocol, Any]:
48+
"""Set a value for a specific protocol on the device."""
49+
params = {protocol: value}
50+
return await send_decoded_command(self._channel, params)
51+
52+
53+
def create(product: HomeDataProduct, mqtt_channel: MqttChannel) -> DyadApi | ZeoApi:
54+
"""Create traits for A01 devices."""
55+
match product.category:
56+
case RoborockCategory.WET_DRY_VAC:
57+
return DyadApi(mqtt_channel)
58+
case RoborockCategory.WASHING_MACHINE:
59+
return ZeoApi(mqtt_channel)
60+
case _:
61+
raise NotImplementedError(f"Unsupported category {product.category}")
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Traits for B01 devices."""
2+
3+
from roborock import RoborockB01Methods
4+
from roborock.devices.b01_channel import send_decoded_command
5+
from roborock.devices.mqtt_channel import MqttChannel
6+
from roborock.devices.traits import Trait
7+
from roborock.roborock_message import RoborockB01Props
8+
9+
__init__ = [
10+
"create_b01_traits",
11+
"PropertiesApi",
12+
]
13+
14+
15+
class PropertiesApi(Trait):
16+
"""API for interacting with B01 devices."""
17+
18+
def __init__(self, channel: MqttChannel) -> None:
19+
"""Initialize the B01Props API."""
20+
self._channel = channel
21+
22+
async def query_values(self, props: list[RoborockB01Props]) -> None:
23+
"""Query the device for the values of the given Dyad protocols."""
24+
await send_decoded_command(
25+
self._channel, dps=10000, command=RoborockB01Methods.GET_PROP, params={"property": props}
26+
)
27+
28+
29+
def create(channel: MqttChannel) -> PropertiesApi:
30+
"""Create traits for B01 devices."""
31+
return PropertiesApi(channel)

roborock/devices/traits/b01/props.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

0 commit comments

Comments
 (0)