Skip to content

Commit 4e7e776

Browse files
allenporterCopilot
andauthored
feat: Simplify device manager creation (#570)
* feat: Simplify device manager creation Remove the need to pass in a home data api by taking all information up front to create the roborock web API in the device manager creation interface. This creates a new Web Api wrapper that will also be used for getting routine information in a followup PR, which was the initial motivation behind this change. Overall this will allow additional code simplification in callers when onboarding the new API, completely removing the need to continue using the Web API directly for devices. * Update roborock/web_api.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update roborock/devices/device_manager.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * chore: go back to old web API name --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 6a5db1d commit 4e7e776

File tree

4 files changed

+124
-85
lines changed

4 files changed

+124
-85
lines changed

roborock/cli.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from roborock.device_features import DeviceFeatures
4747
from roborock.devices.cache import Cache, CacheData
4848
from roborock.devices.device import RoborockDevice
49-
from roborock.devices.device_manager import DeviceManager, create_device_manager, create_home_data_api
49+
from roborock.devices.device_manager import DeviceManager, UserParams, create_device_manager
5050
from roborock.devices.traits import Trait
5151
from roborock.devices.traits.v1 import V1TraitMixin
5252
from roborock.devices.traits.v1.consumeable import ConsumableAttribute
@@ -135,8 +135,11 @@ async def ensure_device_manager(self) -> DeviceManager:
135135
"""Ensure device manager is initialized."""
136136
if self.device_manager is None:
137137
cache_data = self.context.cache_data()
138-
home_data_api = create_home_data_api(cache_data.email, cache_data.user_data)
139-
self.device_manager = await create_device_manager(cache_data.user_data, home_data_api, self.context)
138+
user_params = UserParams(
139+
username=cache_data.email,
140+
user_data=cache_data.user_data,
141+
)
142+
self.device_manager = await create_device_manager(user_params, cache=self.context)
140143
# Cache devices for quick lookup
141144
devices = await self.device_manager.get_devices()
142145
self._devices = {device.duid: device for device in devices}

roborock/devices/device_manager.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import asyncio
44
import enum
55
import logging
6-
from collections.abc import Awaitable, Callable
6+
from collections.abc import Callable
7+
from dataclasses import dataclass
78

89
import aiohttp
910

@@ -18,7 +19,7 @@
1819
from roborock.mqtt.roborock_session import create_lazy_mqtt_session
1920
from roborock.mqtt.session import MqttSession
2021
from roborock.protocol import create_mqtt_params
21-
from roborock.web_api import RoborockApiClient
22+
from roborock.web_api import RoborockApiClient, UserWebApiClient
2223

2324
from .cache import Cache, NoCache
2425
from .channel import Channel
@@ -30,12 +31,11 @@
3031

3132
__all__ = [
3233
"create_device_manager",
33-
"create_home_data_api",
34+
"UserParams",
3435
"DeviceManager",
3536
]
3637

3738

38-
HomeDataApi = Callable[[], Awaitable[HomeData]]
3939
DeviceCreator = Callable[[HomeData, HomeDataDevice, HomeDataProduct], RoborockDevice]
4040

4141

@@ -53,7 +53,7 @@ class DeviceManager:
5353

5454
def __init__(
5555
self,
56-
home_data_api: HomeDataApi,
56+
web_api: UserWebApiClient,
5757
device_creator: DeviceCreator,
5858
mqtt_session: MqttSession,
5959
cache: Cache,
@@ -62,7 +62,7 @@ def __init__(
6262
6363
This takes ownership of the MQTT session and will close it when the manager is closed.
6464
"""
65-
self._home_data_api = home_data_api
65+
self._web_api = web_api
6666
self._cache = cache
6767
self._device_creator = device_creator
6868
self._devices: dict[str, RoborockDevice] = {}
@@ -73,7 +73,7 @@ async def discover_devices(self) -> list[RoborockDevice]:
7373
cache_data = await self._cache.get()
7474
if not cache_data.home_data:
7575
_LOGGER.debug("No cached home data found, fetching from API")
76-
cache_data.home_data = await self._home_data_api()
76+
cache_data.home_data = await self._web_api.get_home_data()
7777
await self._cache.set(cache_data)
7878
home_data = cache_data.home_data
7979

@@ -108,45 +108,69 @@ async def close(self) -> None:
108108
await asyncio.gather(*tasks)
109109

110110

111-
def create_home_data_api(
112-
email: str, user_data: UserData, base_url: str | None = None, session: aiohttp.ClientSession | None = None
113-
) -> HomeDataApi:
114-
"""Create a home data API wrapper.
111+
@dataclass
112+
class UserParams:
113+
"""Parameters for creating a new session with Roborock devices.
115114
116-
This function creates a wrapper around the Roborock API client to fetch
117-
home data for the user.
115+
These parameters include the username, user data for authentication,
116+
and an optional base URL for the Roborock API. The `user_data` and `base_url`
117+
parameters are obtained from `RoborockApiClient` during the login process.
118118
"""
119-
# Note: This will auto discover the API base URL. This can be improved
120-
# by caching this next to `UserData` if needed to avoid unnecessary API calls.
121-
client = RoborockApiClient(username=email, base_url=base_url, session=session)
122119

123-
return create_home_data_from_api_client(client, user_data)
120+
username: str
121+
"""The username (email) used for logging in."""
122+
123+
user_data: UserData
124+
"""This is the user data containing authentication information."""
125+
126+
base_url: str | None = None
127+
"""Optional base URL for the Roborock API.
128+
129+
This is used to speed up connection times by avoiding the need to
130+
discover the API base URL each time. If not provided, the API client
131+
will attempt to discover it automatically which may take multiple requests.
132+
"""
124133

125134

126-
def create_home_data_from_api_client(client: RoborockApiClient, user_data: UserData) -> HomeDataApi:
135+
def create_web_api_wrapper(
136+
user_params: UserParams,
137+
*,
138+
cache: Cache | None = None,
139+
session: aiohttp.ClientSession | None = None,
140+
) -> UserWebApiClient:
127141
"""Create a home data API wrapper from an existing API client."""
128142

129-
async def home_data_api() -> HomeData:
130-
return await client.get_home_data_v3(user_data)
143+
# Note: This will auto discover the API base URL. This can be improved
144+
# by caching this next to `UserData` if needed to avoid unnecessary API calls.
145+
client = RoborockApiClient(username=user_params.username, base_url=user_params.base_url, session=session)
131146

132-
return home_data_api
147+
return UserWebApiClient(client, user_params.user_data)
133148

134149

135150
async def create_device_manager(
136-
user_data: UserData,
137-
home_data_api: HomeDataApi,
151+
user_params: UserParams,
152+
*,
138153
cache: Cache | None = None,
139154
map_parser_config: MapParserConfig | None = None,
155+
session: aiohttp.ClientSession | None = None,
140156
) -> DeviceManager:
141157
"""Convenience function to create and initialize a DeviceManager.
142158
143-
The Home Data is fetched using the provided home_data_api callable which
144-
is exposed this way to allow for swapping out other implementations to
145-
include caching or other optimizations.
159+
Args:
160+
user_params: Parameters for creating the user session.
161+
cache: Optional cache implementation to use for caching device data.
162+
map_parser_config: Optional configuration for parsing maps.
163+
session: Optional aiohttp ClientSession to use for HTTP requests.
164+
165+
Returns:
166+
An initialized DeviceManager with discovered devices.
146167
"""
147168
if cache is None:
148169
cache = NoCache()
149170

171+
web_api = create_web_api_wrapper(user_params, session=session, cache=cache)
172+
user_data = user_params.user_data
173+
150174
mqtt_params = create_mqtt_params(user_data.rriot)
151175
mqtt_session = await create_lazy_mqtt_session(mqtt_params)
152176

@@ -176,6 +200,6 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat
176200
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
177201
return RoborockDevice(device, product, channel, trait)
178202

179-
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session, cache=cache)
203+
manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache)
180204
await manager.discover_devices()
181205
return manager

roborock/web_api.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,3 +707,21 @@ def _get_hawk_authentication(rriot: RRiot, url: str, formdata: dict | None = Non
707707
)
708708
mac = base64.b64encode(hmac.new(rriot.h.encode(), prestr.encode(), hashlib.sha256).digest()).decode()
709709
return f'Hawk id="{rriot.u}",s="{rriot.s}",ts="{timestamp}",nonce="{nonce}",mac="{mac}"'
710+
711+
712+
class UserWebApiClient:
713+
"""Wrapper around RoborockApiClient to provide information for a specific user.
714+
715+
This binds a RoborockApiClient to a specific user context with the
716+
provided UserData. This allows for easier access to user-specific data,
717+
to avoid needing to pass UserData around and mock out the web API.
718+
"""
719+
720+
def __init__(self, web_api: RoborockApiClient, user_data: UserData) -> None:
721+
"""Initialize the wrapper with the API client and user data."""
722+
self._web_api = web_api
723+
self._user_data = user_data
724+
725+
async def get_home_data(self) -> HomeData:
726+
"""Fetch home data using the API client."""
727+
return await self._web_api.get_home_data_v3(self._user_data)
Lines changed: 49 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
"""Tests for the DeviceManager class."""
22

3-
from collections.abc import Generator
3+
from collections.abc import Generator, Iterator
44
from unittest.mock import AsyncMock, Mock, patch
55

66
import pytest
77

88
from roborock.data import HomeData, UserData
9-
from roborock.devices.cache import CacheData, InMemoryCache
10-
from roborock.devices.device_manager import create_device_manager, create_home_data_api
9+
from roborock.devices.cache import InMemoryCache
10+
from roborock.devices.device_manager import UserParams, create_device_manager, create_web_api_wrapper
1111
from roborock.exceptions import RoborockException
1212

1313
from .. import mock_data
1414

1515
USER_DATA = UserData.from_dict(mock_data.USER_DATA)
16+
USER_PARAMS = UserParams(username="test_user", user_data=USER_DATA)
1617
NETWORK_INFO = mock_data.NETWORK_INFO
1718

1819

@@ -33,32 +34,40 @@ def channel_fixture() -> Generator[Mock, None, None]:
3334
yield mock_channel
3435

3536

36-
async def home_home_data_no_devices() -> HomeData:
37+
@pytest.fixture(name="home_data_no_devices")
38+
def home_data_no_devices_fixture() -> Iterator[HomeData]:
3739
"""Mock home data API that returns no devices."""
38-
return HomeData(
39-
id=1,
40-
name="Test Home",
41-
devices=[],
42-
products=[],
43-
)
44-
45-
46-
async def mock_home_data() -> HomeData:
40+
with patch("roborock.devices.device_manager.UserWebApiClient.get_home_data") as mock_home_data:
41+
home_data = HomeData(
42+
id=1,
43+
name="Test Home",
44+
devices=[],
45+
products=[],
46+
)
47+
mock_home_data.return_value = home_data
48+
yield home_data
49+
50+
51+
@pytest.fixture(name="home_data")
52+
def home_data_fixture() -> Iterator[HomeData]:
4753
"""Mock home data API that returns devices."""
48-
return HomeData.from_dict(mock_data.HOME_DATA_RAW)
54+
with patch("roborock.devices.device_manager.UserWebApiClient.get_home_data") as mock_home_data:
55+
home_data = HomeData.from_dict(mock_data.HOME_DATA_RAW)
56+
mock_home_data.return_value = home_data
57+
yield home_data
4958

5059

51-
async def test_no_devices() -> None:
60+
async def test_no_devices(home_data_no_devices: HomeData) -> None:
5261
"""Test the DeviceManager created with no devices returned from the API."""
5362

54-
device_manager = await create_device_manager(USER_DATA, home_home_data_no_devices)
63+
device_manager = await create_device_manager(USER_PARAMS)
5564
devices = await device_manager.get_devices()
5665
assert devices == []
5766

5867

59-
async def test_with_device() -> None:
68+
async def test_with_device(home_data: HomeData) -> None:
6069
"""Test the DeviceManager created with devices returned from the API."""
61-
device_manager = await create_device_manager(USER_DATA, mock_home_data)
70+
device_manager = await create_device_manager(USER_PARAMS)
6271
devices = await device_manager.get_devices()
6372
assert len(devices) == 1
6473
assert devices[0].duid == "abc123"
@@ -72,64 +81,49 @@ async def test_with_device() -> None:
7281
await device_manager.close()
7382

7483

75-
async def test_get_non_existent_device() -> None:
84+
async def test_get_non_existent_device(home_data: HomeData) -> None:
7685
"""Test getting a non-existent device."""
77-
device_manager = await create_device_manager(USER_DATA, mock_home_data)
86+
device_manager = await create_device_manager(USER_PARAMS)
7887
device = await device_manager.get_device("non_existent_duid")
7988
assert device is None
8089
await device_manager.close()
8190

8291

83-
async def test_home_data_api_exception() -> None:
84-
"""Test the home data API with an exception."""
85-
86-
async def home_data_api_exception() -> HomeData:
87-
raise RoborockException("Test exception")
88-
89-
with pytest.raises(RoborockException, match="Test exception"):
90-
await create_device_manager(USER_DATA, home_data_api_exception)
91-
92-
9392
async def test_create_home_data_api_exception() -> None:
9493
"""Test that exceptions from the home data API are propagated through the wrapper."""
9594

9695
with patch("roborock.devices.device_manager.RoborockApiClient.get_home_data_v3") as mock_get_home_data:
9796
mock_get_home_data.side_effect = RoborockException("Test exception")
98-
api = create_home_data_api(USER_DATA, mock_get_home_data)
97+
user_params = UserParams(username="test_user", user_data=USER_DATA)
98+
api = create_web_api_wrapper(user_params)
9999

100100
with pytest.raises(RoborockException, match="Test exception"):
101-
await api()
101+
await api.get_home_data()
102102

103103

104104
async def test_cache_logic() -> None:
105105
"""Test that the cache logic works correctly."""
106106
call_count = 0
107107

108-
async def mock_home_data_with_counter() -> HomeData:
108+
async def mock_home_data_with_counter(*args, **kwargs) -> HomeData:
109109
nonlocal call_count
110110
call_count += 1
111111
return HomeData.from_dict(mock_data.HOME_DATA_RAW)
112112

113-
class TestCache:
114-
def __init__(self):
115-
self._data = CacheData()
116-
117-
async def get(self) -> CacheData:
118-
return self._data
119-
120-
async def set(self, value: CacheData) -> None:
121-
self._data = value
122-
123113
# First call happens during create_device_manager initialization
124-
device_manager = await create_device_manager(USER_DATA, mock_home_data_with_counter, cache=InMemoryCache())
125-
assert call_count == 1
126-
127-
# Second call should use cache, not increment call_count
128-
devices2 = await device_manager.discover_devices()
129-
assert call_count == 1 # Should still be 1, not 2
130-
assert len(devices2) == 1
131-
132-
await device_manager.close()
133-
assert len(devices2) == 1
134-
135-
await device_manager.close()
114+
with patch(
115+
"roborock.devices.device_manager.RoborockApiClient.get_home_data_v3",
116+
side_effect=mock_home_data_with_counter,
117+
):
118+
device_manager = await create_device_manager(USER_PARAMS, cache=InMemoryCache())
119+
assert call_count == 1
120+
121+
# Second call should use cache, not increment call_count
122+
devices2 = await device_manager.discover_devices()
123+
assert call_count == 1 # Should still be 1, not 2
124+
assert len(devices2) == 1
125+
126+
await device_manager.close()
127+
assert len(devices2) == 1
128+
129+
await device_manager.close()

0 commit comments

Comments
 (0)