Skip to content

feat: Add an explicit module for caching #432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions roborock/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from roborock import RoborockException
from roborock.containers import DeviceData, HomeData, HomeDataProduct, LoginData, NetworkInfo, RoborockBase, UserData
from roborock.devices.cache import Cache, CacheData
from roborock.devices.device_manager import create_device_manager, create_home_data_api
from roborock.protocol import MessageParser
from roborock.util import run_sync
Expand Down Expand Up @@ -39,7 +40,7 @@ class ConnectionCache(RoborockBase):
network_info: dict[str, NetworkInfo] | None = None


class RoborockContext:
class RoborockContext(Cache):
roborock_file = Path("~/.roborock").expanduser()
_cache_data: ConnectionCache | None = None

Expand Down Expand Up @@ -68,6 +69,18 @@ def cache_data(self) -> ConnectionCache:
self.validate()
return self._cache_data

async def get(self) -> CacheData:
"""Get cached value."""
connection_cache = self.cache_data()
return CacheData(home_data=connection_cache.home_data, network_info=connection_cache.network_info or {})

async def set(self, value: CacheData) -> None:
"""Set value in the cache."""
connection_cache = self.cache_data()
connection_cache.home_data = value.home_data
connection_cache.network_info = value.network_info
self.update(connection_cache)


@click.option("-d", "--debug", default=False, count=True)
@click.version_option(package_name="python-roborock")
Expand Down Expand Up @@ -119,14 +132,8 @@ async def session(ctx, duration: int):

home_data_api = create_home_data_api(cache_data.email, cache_data.user_data)

async def home_data_cache() -> HomeData:
if cache_data.home_data is None:
cache_data.home_data = await home_data_api()
context.update(cache_data)
return cache_data.home_data

# Create device manager
device_manager = await create_device_manager(cache_data.user_data, home_data_cache)
device_manager = await create_device_manager(cache_data.user_data, home_data_api, context)

devices = await device_manager.get_devices()
click.echo(f"Discovered devices: {', '.join([device.name for device in devices])}")
Expand Down
1 change: 1 addition & 0 deletions roborock/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
__all__ = [
"device",
"device_manager",
"cache",
]
57 changes: 57 additions & 0 deletions roborock/devices/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""This module provides caching functionality for the Roborock device management system.

This module defines a cache interface that you may use to cache device
information to avoid unnecessary API calls. Callers may implement
this interface to provide their own caching mechanism.
"""

from dataclasses import dataclass, field
from typing import Protocol

from roborock.containers import HomeData, NetworkInfo


@dataclass
class CacheData:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think adding local key would be very helpful here. That theoretically would then have everything you need for local usage

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is in home_data.devices.local_key?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes sorry - was conflating user data in my head - my bad

"""Data structure for caching device information."""

home_data: HomeData | None = None
"""Home data containing device and product information."""

network_info: dict[str, NetworkInfo] = field(default_factory=dict)
"""Network information indexed by device DUID."""


class Cache(Protocol):
"""Protocol for a cache that can store and retrieve values."""

async def get(self) -> CacheData:
"""Get cached value."""
...

async def set(self, value: CacheData) -> None:
"""Set value in the cache."""
...


class InMemoryCache(Cache):
"""In-memory cache implementation."""

def __init__(self):
self._data = CacheData()

async def get(self) -> CacheData:
return self._data

async def set(self, value: CacheData) -> None:
self._data = value


class NoCache(Cache):
"""No-op cache implementation."""

async def get(self) -> CacheData:
return CacheData()

async def set(self, value: CacheData) -> None:
pass
25 changes: 19 additions & 6 deletions roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from roborock.protocol import create_mqtt_params
from roborock.web_api import RoborockApiClient

from .cache import Cache, NoCache
from .channel import Channel
from .mqtt_channel import create_mqtt_channel
from .traits.dyad import DyadApi
Expand All @@ -32,8 +33,6 @@
"create_device_manager",
"create_home_data_api",
"DeviceManager",
"HomeDataApi",
"DeviceCreator",
]


Expand All @@ -57,19 +56,27 @@ def __init__(
home_data_api: HomeDataApi,
device_creator: DeviceCreator,
mqtt_session: MqttSession,
cache: Cache,
) -> None:
"""Initialize the DeviceManager with user data and optional cache storage.

This takes ownership of the MQTT session and will close it when the manager is closed.
"""
self._home_data_api = home_data_api
self._cache = cache
self._device_creator = device_creator
self._devices: dict[str, RoborockDevice] = {}
self._mqtt_session = mqtt_session

async def discover_devices(self) -> list[RoborockDevice]:
"""Discover all devices for the logged-in user."""
home_data = await self._home_data_api()
cache_data = await self._cache.get()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should have a way to force an update

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to see if we can address this in the cache implementation e.g. caller can flush their own cache implementation if they don't want caching. If that doesn't work we can add this as an explicit flag here?

The reason why i am thinking about this is because we also may need to flush network info as well sometimes, so it seems nice to invalidate the whole cache and try to refresh.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay sounds good!

if not cache_data.home_data:
_LOGGER.debug("No cached home data found, fetching from API")
cache_data.home_data = await self._home_data_api()
await self._cache.set(cache_data)
home_data = cache_data.home_data

device_products = home_data.device_products
_LOGGER.debug("Discovered %d devices %s", len(device_products), home_data)

Expand Down Expand Up @@ -118,13 +125,19 @@ async def home_data_api() -> HomeData:
return home_data_api


async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi) -> DeviceManager:
async def create_device_manager(
user_data: UserData,
home_data_api: HomeDataApi,
cache: Cache | None = None,
) -> DeviceManager:
"""Convenience function to create and initialize a DeviceManager.

The Home Data is fetched using the provided home_data_api callable which
is exposed this way to allow for swapping out other implementations to
include caching or other optimizations.
"""
if cache is None:
cache = NoCache()

mqtt_params = create_mqtt_params(user_data.rriot)
mqtt_session = await create_mqtt_session(mqtt_params)
Expand All @@ -135,7 +148,7 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
# TODO: Define a registration mechanism/factory for v1 traits
match device.pv:
case DeviceVersion.V1:
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device)
channel = create_v1_channel(user_data, mqtt_params, mqtt_session, device, cache)
traits.append(StatusTrait(product, channel.rpc_channel))
case DeviceVersion.A01:
mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device)
Expand All @@ -150,6 +163,6 @@ def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> Roborock
raise NotImplementedError(f"Device {device.name} has unsupported version {device.pv}")
return RoborockDevice(device, channel, traits)

manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session, cache=cache)
await manager.discover_devices()
return manager
27 changes: 20 additions & 7 deletions roborock/devices/v1_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from roborock.roborock_message import RoborockMessage
from roborock.roborock_typing import RoborockCommand

from .cache import Cache
from .channel import Channel
from .local_channel import LocalChannel, LocalSession, create_local_session
from .mqtt_channel import MqttChannel
Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(
security_data: SecurityData,
mqtt_channel: MqttChannel,
local_session: LocalSession,
cache: Cache,
) -> None:
"""Initialize the V1Channel.

Expand All @@ -62,7 +64,7 @@ def __init__(
self._mqtt_unsub: Callable[[], None] | None = None
self._local_unsub: Callable[[], None] | None = None
self._callback: Callable[[RoborockMessage], None] | None = None
self._networking_info: NetworkInfo | None = None
self._cache = cache

@property
def is_connected(self) -> bool:
Expand Down Expand Up @@ -131,19 +133,26 @@ async def _get_networking_info(self) -> NetworkInfo:

This is a cloud only command used to get the local device's IP address.
"""
cache_data = await self._cache.get()
if cache_data.network_info and (network_info := cache_data.network_info.get(self._device_uid)):
_LOGGER.debug("Using cached network info for device %s", self._device_uid)
return network_info
try:
return await self._mqtt_rpc_channel.send_command(
network_info = await self._mqtt_rpc_channel.send_command(
RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo
)
except RoborockException as e:
raise RoborockException(f"Network info failed for device {self._device_uid}") from e
_LOGGER.debug("Network info for device %s: %s", self._device_uid, network_info)
cache_data.network_info[self._device_uid] = network_info
await self._cache.set(cache_data)
return network_info

async def _local_connect(self) -> Callable[[], None]:
"""Set up local connection if possible."""
_LOGGER.debug("Attempting to connect to local channel for device %s", self._device_uid)
if self._networking_info is None:
self._networking_info = await self._get_networking_info()
host = self._networking_info.ip
networking_info = await self._get_networking_info()
host = networking_info.ip
_LOGGER.debug("Connecting to local channel at %s", host)
self._local_channel = self._local_session(host)
try:
Expand All @@ -168,10 +177,14 @@ def _on_local_message(self, message: RoborockMessage) -> None:


def create_v1_channel(
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
user_data: UserData,
mqtt_params: MqttParams,
mqtt_session: MqttSession,
device: HomeDataDevice,
cache: Cache,
) -> V1Channel:
"""Create a V1Channel for the given device."""
security_data = create_security_data(user_data.rriot)
mqtt_channel = MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
local_session = create_local_session(device.local_key)
return V1Channel(device.duid, security_data, mqtt_channel, local_session=local_session)
return V1Channel(device.duid, security_data, mqtt_channel, local_session=local_session, cache=cache)
35 changes: 35 additions & 0 deletions tests/devices/test_device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from roborock.containers import HomeData, UserData
from roborock.devices.cache import CacheData, InMemoryCache
from roborock.devices.device_manager import create_device_manager, create_home_data_api
from roborock.exceptions import RoborockException

Expand Down Expand Up @@ -98,3 +99,37 @@ async def test_create_home_data_api_exception() -> None:

with pytest.raises(RoborockException, match="Test exception"):
await api()


async def test_cache_logic() -> None:
"""Test that the cache logic works correctly."""
call_count = 0

async def mock_home_data_with_counter() -> HomeData:
nonlocal call_count
call_count += 1
return HomeData.from_dict(mock_data.HOME_DATA_RAW)

class TestCache:
def __init__(self):
self._data = CacheData()

async def get(self) -> CacheData:
return self._data

async def set(self, value: CacheData) -> None:
self._data = value

# First call happens during create_device_manager initialization
device_manager = await create_device_manager(USER_DATA, mock_home_data_with_counter, cache=InMemoryCache())
assert call_count == 1

# Second call should use cache, not increment call_count
devices2 = await device_manager.discover_devices()
assert call_count == 1 # Should still be 1, not 2
assert len(devices2) == 1

await device_manager.close()
assert len(devices2) == 1

await device_manager.close()
49 changes: 49 additions & 0 deletions tests/devices/test_v1_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest

from roborock.containers import NetworkInfo, RoborockStateCode, S5MaxStatus, UserData
from roborock.devices.cache import CacheData, InMemoryCache
from roborock.devices.local_channel import LocalChannel, LocalSession
from roborock.devices.mqtt_channel import MqttChannel
from roborock.devices.v1_channel import V1Channel
Expand Down Expand Up @@ -105,6 +106,7 @@ def setup_v1_channel(
security_data=TEST_SECURITY_DATA,
mqtt_channel=mock_mqtt_channel,
local_session=mock_local_session,
cache=InMemoryCache(),
)


Expand Down Expand Up @@ -408,6 +410,52 @@ async def test_v1_channel_networking_info_retrieved_during_connection(
mock_local_session.assert_called_once_with(mock_data.NETWORK_INFO["ip"])


async def test_v1_channel_networking_info_cached_during_connection(
mock_mqtt_channel: Mock,
mock_local_channel: Mock,
mock_local_session: Mock,
) -> None:
"""Test that networking information is cached and reused on subsequent connections."""

# Create a cache with pre-populated network info
cache_data = CacheData()
cache_data.network_info[TEST_DEVICE_UID] = TEST_NETWORKING_INFO

mock_cache = AsyncMock()
mock_cache.get.return_value = cache_data
mock_cache.set = AsyncMock()

# Setup: MQTT and local connections succeed
mock_mqtt_channel.subscribe.return_value = Mock()
mock_local_channel.subscribe.return_value = Mock()

# Create V1Channel with the mock cache
v1_channel = V1Channel(
device_uid=TEST_DEVICE_UID,
security_data=TEST_SECURITY_DATA,
mqtt_channel=mock_mqtt_channel,
local_session=mock_local_session,
cache=mock_cache,
)

# Subscribe - should use cached network info
await v1_channel.subscribe(Mock())

# Verify both connections are established
assert v1_channel.is_mqtt_connected
assert v1_channel.is_local_connected

# Verify network info was NOT requested via MQTT (cache hit)
mock_mqtt_channel.send_message.assert_not_called()

# Verify local session was created with the correct IP from cache
mock_local_session.assert_called_once_with(mock_data.NETWORK_INFO["ip"])

# Verify cache was accessed but not updated (cache hit)
mock_cache.get.assert_called_once()
mock_cache.set.assert_not_called()


# V1Channel edge cases tests


Expand Down Expand Up @@ -513,6 +561,7 @@ async def test_v1_channel_full_subscribe_and_command_flow(
security_data=TEST_SECURITY_DATA,
mqtt_channel=mock_mqtt_channel,
local_session=mock_local_session,
cache=InMemoryCache(),
)

# Mock network info for local connection
Expand Down
Loading