Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions roborock/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from roborock.device_features import DeviceFeatures
from roborock.devices.cache import Cache, CacheData
from roborock.devices.device import RoborockDevice
from roborock.devices.device_manager import DeviceManager, create_device_manager, create_home_data_api
from roborock.devices.device_manager import DeviceManager, UserParams, create_device_manager
from roborock.devices.traits import Trait
from roborock.devices.traits.v1 import V1TraitMixin
from roborock.devices.traits.v1.consumeable import ConsumableAttribute
Expand Down Expand Up @@ -135,8 +135,11 @@ async def ensure_device_manager(self) -> DeviceManager:
"""Ensure device manager is initialized."""
if self.device_manager is None:
cache_data = self.context.cache_data()
home_data_api = create_home_data_api(cache_data.email, cache_data.user_data)
self.device_manager = await create_device_manager(cache_data.user_data, home_data_api, self.context)
user_params = UserParams(
username=cache_data.email,
user_data=cache_data.user_data,
)
self.device_manager = await create_device_manager(user_params, cache=self.context)
# Cache devices for quick lookup
devices = await self.device_manager.get_devices()
self._devices = {device.duid: device for device in devices}
Expand Down
78 changes: 51 additions & 27 deletions roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import asyncio
import enum
import logging
from collections.abc import Awaitable, Callable
from collections.abc import Callable
from dataclasses import dataclass

import aiohttp

Expand All @@ -18,7 +19,7 @@
from roborock.mqtt.roborock_session import create_lazy_mqtt_session
from roborock.mqtt.session import MqttSession
from roborock.protocol import create_mqtt_params
from roborock.web_api import RoborockApiClient
from roborock.web_api import RoborockApiClient, UserWebApiClient

from .cache import Cache, NoCache
from .channel import Channel
Expand All @@ -30,12 +31,11 @@

__all__ = [
"create_device_manager",
"create_home_data_api",
"UserParams",
"DeviceManager",
]


HomeDataApi = Callable[[], Awaitable[HomeData]]
DeviceCreator = Callable[[HomeData, HomeDataDevice, HomeDataProduct], RoborockDevice]


Expand All @@ -53,7 +53,7 @@ class DeviceManager:

def __init__(
self,
home_data_api: HomeDataApi,
web_api: UserWebApiClient,
device_creator: DeviceCreator,
mqtt_session: MqttSession,
cache: Cache,
Expand All @@ -62,7 +62,7 @@ def __init__(

This takes ownership of the MQTT session and will close it when the manager is closed.
"""
self._home_data_api = home_data_api
self._web_api = web_api
self._cache = cache
self._device_creator = device_creator
self._devices: dict[str, RoborockDevice] = {}
Expand All @@ -73,7 +73,7 @@ async def discover_devices(self) -> list[RoborockDevice]:
cache_data = await self._cache.get()
if not cache_data.home_data:
_LOGGER.debug("No cached home data found, fetching from API")
cache_data.home_data = await self._home_data_api()
cache_data.home_data = await self._web_api.get_home_data()
await self._cache.set(cache_data)
home_data = cache_data.home_data

Expand Down Expand Up @@ -108,45 +108,69 @@ async def close(self) -> None:
await asyncio.gather(*tasks)


def create_home_data_api(
email: str, user_data: UserData, base_url: str | None = None, session: aiohttp.ClientSession | None = None
) -> HomeDataApi:
"""Create a home data API wrapper.
@dataclass
class UserParams:
"""Parameters for creating a new session with Roborock devices.

This function creates a wrapper around the Roborock API client to fetch
home data for the user.
These parameters include the username, user data for authentication,
and an optional base URL for the Roborock API. The `user_data` and `base_url`
parameters are obtained from `RoborockApiClient` during the login process.
"""
# Note: This will auto discover the API base URL. This can be improved
# by caching this next to `UserData` if needed to avoid unnecessary API calls.
client = RoborockApiClient(username=email, base_url=base_url, session=session)

return create_home_data_from_api_client(client, user_data)
username: str
"""The username (email) used for logging in."""

user_data: UserData
"""This is the user data containing authentication information."""

base_url: str | None = None
"""Optional base URL for the Roborock API.

This is used to speed up connection times by avoiding the need to
discover the API base URL each time. If not provided, the API client
will attempt to discover it automatically which may take multiple requests.
"""


def create_home_data_from_api_client(client: RoborockApiClient, user_data: UserData) -> HomeDataApi:
def create_web_api_wrapper(
user_params: UserParams,
*,
cache: Cache | None = None,
session: aiohttp.ClientSession | None = None,
) -> UserWebApiClient:
"""Create a home data API wrapper from an existing API client."""

async def home_data_api() -> HomeData:
return await client.get_home_data_v3(user_data)
# Note: This will auto discover the API base URL. This can be improved
# by caching this next to `UserData` if needed to avoid unnecessary API calls.
client = RoborockApiClient(username=user_params.username, base_url=user_params.base_url, session=session)

return home_data_api
return UserWebApiClient(client, user_params.user_data)


async def create_device_manager(
user_data: UserData,
home_data_api: HomeDataApi,
user_params: UserParams,
*,
cache: Cache | None = None,
map_parser_config: MapParserConfig | None = None,
session: aiohttp.ClientSession | None = None,
) -> DeviceManager:
"""Convenience function to create and initialize a DeviceManager.

The Home Data is fetched using the provided home_data_api callable which
is exposed this way to allow for swapping out other implementations to
include caching or other optimizations.
Args:
user_params: Parameters for creating the user session.
cache: Optional cache implementation to use for caching device data.
map_parser_config: Optional configuration for parsing maps.
session: Optional aiohttp ClientSession to use for HTTP requests.

Returns:
An initialized DeviceManager with discovered devices.
"""
if cache is None:
cache = NoCache()

web_api = create_web_api_wrapper(user_params, session=session, cache=cache)
user_data = user_params.user_data

mqtt_params = create_mqtt_params(user_data.rriot)
mqtt_session = await create_lazy_mqtt_session(mqtt_params)

Expand Down Expand Up @@ -176,6 +200,6 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
return RoborockDevice(device, product, channel, trait)

manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session, cache=cache)
manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache)
await manager.discover_devices()
return manager
18 changes: 18 additions & 0 deletions roborock/web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,3 +707,21 @@ def _get_hawk_authentication(rriot: RRiot, url: str, formdata: dict | None = Non
)
mac = base64.b64encode(hmac.new(rriot.h.encode(), prestr.encode(), hashlib.sha256).digest()).decode()
return f'Hawk id="{rriot.u}",s="{rriot.s}",ts="{timestamp}",nonce="{nonce}",mac="{mac}"'


class UserWebApiClient:
"""Wrapper around RoborockApiClient to provide information for a specific user.
This binds a RoborockApiClient to a specific user context with the
provided UserData. This allows for easier access to user-specific data,
to avoid needing to pass UserData around and mock out the web API.
"""

def __init__(self, web_api: RoborockApiClient, user_data: UserData) -> None:
"""Initialize the wrapper with the API client and user data."""
self._web_api = web_api
self._user_data = user_data

async def get_home_data(self) -> HomeData:
"""Fetch home data using the API client."""
return await self._web_api.get_home_data_v3(self._user_data)
104 changes: 49 additions & 55 deletions tests/devices/test_device_manager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
"""Tests for the DeviceManager class."""

from collections.abc import Generator
from collections.abc import Generator, Iterator
from unittest.mock import AsyncMock, Mock, patch

import pytest

from roborock.data import HomeData, UserData
from roborock.devices.cache import CacheData, InMemoryCache
from roborock.devices.device_manager import create_device_manager, create_home_data_api
from roborock.devices.cache import InMemoryCache
from roborock.devices.device_manager import UserParams, create_device_manager, create_web_api_wrapper
from roborock.exceptions import RoborockException

from .. import mock_data

USER_DATA = UserData.from_dict(mock_data.USER_DATA)
USER_PARAMS = UserParams(username="test_user", user_data=USER_DATA)
NETWORK_INFO = mock_data.NETWORK_INFO


Expand All @@ -33,32 +34,40 @@ def channel_fixture() -> Generator[Mock, None, None]:
yield mock_channel


async def home_home_data_no_devices() -> HomeData:
@pytest.fixture(name="home_data_no_devices")
def home_data_no_devices_fixture() -> Iterator[HomeData]:
"""Mock home data API that returns no devices."""
return HomeData(
id=1,
name="Test Home",
devices=[],
products=[],
)


async def mock_home_data() -> HomeData:
with patch("roborock.devices.device_manager.UserWebApiClient.get_home_data") as mock_home_data:
home_data = HomeData(
id=1,
name="Test Home",
devices=[],
products=[],
)
mock_home_data.return_value = home_data
yield home_data


@pytest.fixture(name="home_data")
def home_data_fixture() -> Iterator[HomeData]:
"""Mock home data API that returns devices."""
return HomeData.from_dict(mock_data.HOME_DATA_RAW)
with patch("roborock.devices.device_manager.UserWebApiClient.get_home_data") as mock_home_data:
home_data = HomeData.from_dict(mock_data.HOME_DATA_RAW)
mock_home_data.return_value = home_data
yield home_data


async def test_no_devices() -> None:
async def test_no_devices(home_data_no_devices: HomeData) -> None:
"""Test the DeviceManager created with no devices returned from the API."""

device_manager = await create_device_manager(USER_DATA, home_home_data_no_devices)
device_manager = await create_device_manager(USER_PARAMS)
devices = await device_manager.get_devices()
assert devices == []


async def test_with_device() -> None:
async def test_with_device(home_data: HomeData) -> None:
"""Test the DeviceManager created with devices returned from the API."""
device_manager = await create_device_manager(USER_DATA, mock_home_data)
device_manager = await create_device_manager(USER_PARAMS)
devices = await device_manager.get_devices()
assert len(devices) == 1
assert devices[0].duid == "abc123"
Expand All @@ -72,64 +81,49 @@ async def test_with_device() -> None:
await device_manager.close()


async def test_get_non_existent_device() -> None:
async def test_get_non_existent_device(home_data: HomeData) -> None:
"""Test getting a non-existent device."""
device_manager = await create_device_manager(USER_DATA, mock_home_data)
device_manager = await create_device_manager(USER_PARAMS)
device = await device_manager.get_device("non_existent_duid")
assert device is None
await device_manager.close()


async def test_home_data_api_exception() -> None:
"""Test the home data API with an exception."""

async def home_data_api_exception() -> HomeData:
raise RoborockException("Test exception")

with pytest.raises(RoborockException, match="Test exception"):
await create_device_manager(USER_DATA, home_data_api_exception)


async def test_create_home_data_api_exception() -> None:
"""Test that exceptions from the home data API are propagated through the wrapper."""

with patch("roborock.devices.device_manager.RoborockApiClient.get_home_data_v3") as mock_get_home_data:
mock_get_home_data.side_effect = RoborockException("Test exception")
api = create_home_data_api(USER_DATA, mock_get_home_data)
user_params = UserParams(username="test_user", user_data=USER_DATA)
api = create_web_api_wrapper(user_params)

with pytest.raises(RoborockException, match="Test exception"):
await api()
await api.get_home_data()


async def test_cache_logic() -> None:
"""Test that the cache logic works correctly."""
call_count = 0

async def mock_home_data_with_counter() -> HomeData:
async def mock_home_data_with_counter(*args, **kwargs) -> HomeData:
nonlocal call_count
call_count += 1
return HomeData.from_dict(mock_data.HOME_DATA_RAW)

class TestCache:
def __init__(self):
self._data = CacheData()

async def get(self) -> CacheData:
return self._data

async def set(self, value: CacheData) -> None:
self._data = value

# First call happens during create_device_manager initialization
device_manager = await create_device_manager(USER_DATA, mock_home_data_with_counter, cache=InMemoryCache())
assert call_count == 1

# Second call should use cache, not increment call_count
devices2 = await device_manager.discover_devices()
assert call_count == 1 # Should still be 1, not 2
assert len(devices2) == 1

await device_manager.close()
assert len(devices2) == 1

await device_manager.close()
with patch(
"roborock.devices.device_manager.RoborockApiClient.get_home_data_v3",
side_effect=mock_home_data_with_counter,
):
device_manager = await create_device_manager(USER_PARAMS, cache=InMemoryCache())
assert call_count == 1

# Second call should use cache, not increment call_count
devices2 = await device_manager.discover_devices()
assert call_count == 1 # Should still be 1, not 2
assert len(devices2) == 1

await device_manager.close()
assert len(devices2) == 1

await device_manager.close()
Loading