diff --git a/roborock/cli.py b/roborock/cli.py index df3fe1e9..c32aac80 100644 --- a/roborock/cli.py +++ b/roborock/cli.py @@ -46,7 +46,7 @@ from roborock.device_features import DeviceFeatures from roborock.devices.cache import Cache, CacheData from roborock.devices.device import RoborockDevice -from roborock.devices.device_manager import DeviceManager, create_device_manager, create_home_data_api +from roborock.devices.device_manager import DeviceManager, UserParams, create_device_manager from roborock.devices.traits import Trait from roborock.devices.traits.v1 import V1TraitMixin from roborock.devices.traits.v1.consumeable import ConsumableAttribute @@ -135,8 +135,11 @@ async def ensure_device_manager(self) -> DeviceManager: """Ensure device manager is initialized.""" if self.device_manager is None: cache_data = self.context.cache_data() - home_data_api = create_home_data_api(cache_data.email, cache_data.user_data) - self.device_manager = await create_device_manager(cache_data.user_data, home_data_api, self.context) + user_params = UserParams( + username=cache_data.email, + user_data=cache_data.user_data, + ) + self.device_manager = await create_device_manager(user_params, cache=self.context) # Cache devices for quick lookup devices = await self.device_manager.get_devices() self._devices = {device.duid: device for device in devices} diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 5218a839..33f4ed41 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -3,7 +3,8 @@ import asyncio import enum import logging -from collections.abc import Awaitable, Callable +from collections.abc import Callable +from dataclasses import dataclass import aiohttp @@ -18,7 +19,7 @@ from roborock.mqtt.roborock_session import create_lazy_mqtt_session from roborock.mqtt.session import MqttSession from roborock.protocol import create_mqtt_params -from roborock.web_api import RoborockApiClient +from roborock.web_api import RoborockApiClient, UserWebApiClient from .cache import Cache, NoCache from .channel import Channel @@ -30,12 +31,11 @@ __all__ = [ "create_device_manager", - "create_home_data_api", + "UserParams", "DeviceManager", ] -HomeDataApi = Callable[[], Awaitable[HomeData]] DeviceCreator = Callable[[HomeData, HomeDataDevice, HomeDataProduct], RoborockDevice] @@ -53,7 +53,7 @@ class DeviceManager: def __init__( self, - home_data_api: HomeDataApi, + web_api: UserWebApiClient, device_creator: DeviceCreator, mqtt_session: MqttSession, cache: Cache, @@ -62,7 +62,7 @@ def __init__( This takes ownership of the MQTT session and will close it when the manager is closed. """ - self._home_data_api = home_data_api + self._web_api = web_api self._cache = cache self._device_creator = device_creator self._devices: dict[str, RoborockDevice] = {} @@ -73,7 +73,7 @@ async def discover_devices(self) -> list[RoborockDevice]: cache_data = await self._cache.get() if not cache_data.home_data: _LOGGER.debug("No cached home data found, fetching from API") - cache_data.home_data = await self._home_data_api() + cache_data.home_data = await self._web_api.get_home_data() await self._cache.set(cache_data) home_data = cache_data.home_data @@ -108,45 +108,69 @@ async def close(self) -> None: await asyncio.gather(*tasks) -def create_home_data_api( - email: str, user_data: UserData, base_url: str | None = None, session: aiohttp.ClientSession | None = None -) -> HomeDataApi: - """Create a home data API wrapper. +@dataclass +class UserParams: + """Parameters for creating a new session with Roborock devices. - This function creates a wrapper around the Roborock API client to fetch - home data for the user. + These parameters include the username, user data for authentication, + and an optional base URL for the Roborock API. The `user_data` and `base_url` + parameters are obtained from `RoborockApiClient` during the login process. """ - # Note: This will auto discover the API base URL. This can be improved - # by caching this next to `UserData` if needed to avoid unnecessary API calls. - client = RoborockApiClient(username=email, base_url=base_url, session=session) - return create_home_data_from_api_client(client, user_data) + username: str + """The username (email) used for logging in.""" + + user_data: UserData + """This is the user data containing authentication information.""" + + base_url: str | None = None + """Optional base URL for the Roborock API. + + This is used to speed up connection times by avoiding the need to + discover the API base URL each time. If not provided, the API client + will attempt to discover it automatically which may take multiple requests. + """ -def create_home_data_from_api_client(client: RoborockApiClient, user_data: UserData) -> HomeDataApi: +def create_web_api_wrapper( + user_params: UserParams, + *, + cache: Cache | None = None, + session: aiohttp.ClientSession | None = None, +) -> UserWebApiClient: """Create a home data API wrapper from an existing API client.""" - async def home_data_api() -> HomeData: - return await client.get_home_data_v3(user_data) + # Note: This will auto discover the API base URL. This can be improved + # by caching this next to `UserData` if needed to avoid unnecessary API calls. + client = RoborockApiClient(username=user_params.username, base_url=user_params.base_url, session=session) - return home_data_api + return UserWebApiClient(client, user_params.user_data) async def create_device_manager( - user_data: UserData, - home_data_api: HomeDataApi, + user_params: UserParams, + *, cache: Cache | None = None, map_parser_config: MapParserConfig | None = None, + session: aiohttp.ClientSession | None = None, ) -> DeviceManager: """Convenience function to create and initialize a DeviceManager. - The Home Data is fetched using the provided home_data_api callable which - is exposed this way to allow for swapping out other implementations to - include caching or other optimizations. + Args: + user_params: Parameters for creating the user session. + cache: Optional cache implementation to use for caching device data. + map_parser_config: Optional configuration for parsing maps. + session: Optional aiohttp ClientSession to use for HTTP requests. + + Returns: + An initialized DeviceManager with discovered devices. """ if cache is None: cache = NoCache() + web_api = create_web_api_wrapper(user_params, session=session, cache=cache) + user_data = user_params.user_data + mqtt_params = create_mqtt_params(user_data.rriot) mqtt_session = await create_lazy_mqtt_session(mqtt_params) @@ -176,6 +200,6 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}") return RoborockDevice(device, product, channel, trait) - manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session, cache=cache) + manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache) await manager.discover_devices() return manager diff --git a/roborock/web_api.py b/roborock/web_api.py index 2806163a..d4163b3b 100644 --- a/roborock/web_api.py +++ b/roborock/web_api.py @@ -707,3 +707,21 @@ def _get_hawk_authentication(rriot: RRiot, url: str, formdata: dict | None = Non ) mac = base64.b64encode(hmac.new(rriot.h.encode(), prestr.encode(), hashlib.sha256).digest()).decode() return f'Hawk id="{rriot.u}",s="{rriot.s}",ts="{timestamp}",nonce="{nonce}",mac="{mac}"' + + +class UserWebApiClient: + """Wrapper around RoborockApiClient to provide information for a specific user. + + This binds a RoborockApiClient to a specific user context with the + provided UserData. This allows for easier access to user-specific data, + to avoid needing to pass UserData around and mock out the web API. + """ + + def __init__(self, web_api: RoborockApiClient, user_data: UserData) -> None: + """Initialize the wrapper with the API client and user data.""" + self._web_api = web_api + self._user_data = user_data + + async def get_home_data(self) -> HomeData: + """Fetch home data using the API client.""" + return await self._web_api.get_home_data_v3(self._user_data) diff --git a/tests/devices/test_device_manager.py b/tests/devices/test_device_manager.py index 57a057a8..24065609 100644 --- a/tests/devices/test_device_manager.py +++ b/tests/devices/test_device_manager.py @@ -1,18 +1,19 @@ """Tests for the DeviceManager class.""" -from collections.abc import Generator +from collections.abc import Generator, Iterator from unittest.mock import AsyncMock, Mock, patch import pytest from roborock.data import HomeData, UserData -from roborock.devices.cache import CacheData, InMemoryCache -from roborock.devices.device_manager import create_device_manager, create_home_data_api +from roborock.devices.cache import InMemoryCache +from roborock.devices.device_manager import UserParams, create_device_manager, create_web_api_wrapper from roborock.exceptions import RoborockException from .. import mock_data USER_DATA = UserData.from_dict(mock_data.USER_DATA) +USER_PARAMS = UserParams(username="test_user", user_data=USER_DATA) NETWORK_INFO = mock_data.NETWORK_INFO @@ -33,32 +34,40 @@ def channel_fixture() -> Generator[Mock, None, None]: yield mock_channel -async def home_home_data_no_devices() -> HomeData: +@pytest.fixture(name="home_data_no_devices") +def home_data_no_devices_fixture() -> Iterator[HomeData]: """Mock home data API that returns no devices.""" - return HomeData( - id=1, - name="Test Home", - devices=[], - products=[], - ) - - -async def mock_home_data() -> HomeData: + with patch("roborock.devices.device_manager.UserWebApiClient.get_home_data") as mock_home_data: + home_data = HomeData( + id=1, + name="Test Home", + devices=[], + products=[], + ) + mock_home_data.return_value = home_data + yield home_data + + +@pytest.fixture(name="home_data") +def home_data_fixture() -> Iterator[HomeData]: """Mock home data API that returns devices.""" - return HomeData.from_dict(mock_data.HOME_DATA_RAW) + with patch("roborock.devices.device_manager.UserWebApiClient.get_home_data") as mock_home_data: + home_data = HomeData.from_dict(mock_data.HOME_DATA_RAW) + mock_home_data.return_value = home_data + yield home_data -async def test_no_devices() -> None: +async def test_no_devices(home_data_no_devices: HomeData) -> None: """Test the DeviceManager created with no devices returned from the API.""" - device_manager = await create_device_manager(USER_DATA, home_home_data_no_devices) + device_manager = await create_device_manager(USER_PARAMS) devices = await device_manager.get_devices() assert devices == [] -async def test_with_device() -> None: +async def test_with_device(home_data: HomeData) -> None: """Test the DeviceManager created with devices returned from the API.""" - device_manager = await create_device_manager(USER_DATA, mock_home_data) + device_manager = await create_device_manager(USER_PARAMS) devices = await device_manager.get_devices() assert len(devices) == 1 assert devices[0].duid == "abc123" @@ -72,64 +81,49 @@ async def test_with_device() -> None: await device_manager.close() -async def test_get_non_existent_device() -> None: +async def test_get_non_existent_device(home_data: HomeData) -> None: """Test getting a non-existent device.""" - device_manager = await create_device_manager(USER_DATA, mock_home_data) + device_manager = await create_device_manager(USER_PARAMS) device = await device_manager.get_device("non_existent_duid") assert device is None await device_manager.close() -async def test_home_data_api_exception() -> None: - """Test the home data API with an exception.""" - - async def home_data_api_exception() -> HomeData: - raise RoborockException("Test exception") - - with pytest.raises(RoborockException, match="Test exception"): - await create_device_manager(USER_DATA, home_data_api_exception) - - async def test_create_home_data_api_exception() -> None: """Test that exceptions from the home data API are propagated through the wrapper.""" with patch("roborock.devices.device_manager.RoborockApiClient.get_home_data_v3") as mock_get_home_data: mock_get_home_data.side_effect = RoborockException("Test exception") - api = create_home_data_api(USER_DATA, mock_get_home_data) + user_params = UserParams(username="test_user", user_data=USER_DATA) + api = create_web_api_wrapper(user_params) with pytest.raises(RoborockException, match="Test exception"): - await api() + await api.get_home_data() async def test_cache_logic() -> None: """Test that the cache logic works correctly.""" call_count = 0 - async def mock_home_data_with_counter() -> HomeData: + async def mock_home_data_with_counter(*args, **kwargs) -> HomeData: nonlocal call_count call_count += 1 return HomeData.from_dict(mock_data.HOME_DATA_RAW) - class TestCache: - def __init__(self): - self._data = CacheData() - - async def get(self) -> CacheData: - return self._data - - async def set(self, value: CacheData) -> None: - self._data = value - # First call happens during create_device_manager initialization - device_manager = await create_device_manager(USER_DATA, mock_home_data_with_counter, cache=InMemoryCache()) - assert call_count == 1 - - # Second call should use cache, not increment call_count - devices2 = await device_manager.discover_devices() - assert call_count == 1 # Should still be 1, not 2 - assert len(devices2) == 1 - - await device_manager.close() - assert len(devices2) == 1 - - await device_manager.close() + with patch( + "roborock.devices.device_manager.RoborockApiClient.get_home_data_v3", + side_effect=mock_home_data_with_counter, + ): + device_manager = await create_device_manager(USER_PARAMS, cache=InMemoryCache()) + assert call_count == 1 + + # Second call should use cache, not increment call_count + devices2 = await device_manager.discover_devices() + assert call_count == 1 # Should still be 1, not 2 + assert len(devices2) == 1 + + await device_manager.close() + assert len(devices2) == 1 + + await device_manager.close()