In [2]:
from cachetools import TTLCache

# Module-level cache shared with the decorated resolver defined later in this
# notebook: holds at most 100 entries, each expiring `ttl=300` time units after
# insertion (seconds, with cachetools' default `time.monotonic` timer).
cache = TTLCache[int, int](maxsize=100, ttl=300)

In [3]:
import functools
from collections.abc import MutableMapping
from typing import Any, Protocol

from cachetools import Cache


class AsyncCacheResolver[K, V](Protocol):
    async def __call__(self, keys: set[K], *args, **kwargs) -> dict[K, V]:
        ...


def aget_partial_cache[K, V](cache: Cache[Any, Any]):
    def decorator(func: AsyncCacheResolver[K, V]):
        async def get_partial_cache_wrapper(keys: set[Any], *args, **kwargs):
            uncached_keys = keys.difference(cache.keys())
            uncached_dict = await func(uncached_keys, *args, **kwargs)
            cache.update(uncached_dict)
            return {k: cache[k] for k in keys}

        return get_partial_cache_wrapper

    return decorator


@aget_partial_cache(cache)
async def beans(keys: set[int], beans: int, *args, **kwargs):
    """Demo resolver: map each requested key to twice its value.

    The ``beans`` argument and any extra arguments are accepted but unused.
    """
    doubled = {}
    for key in keys:
        doubled[key] = key * 2
    return doubled


# Demo (notebook top-level await): keys not already in `cache` are resolved by
# `beans` and stored; the second positional argument fills its unused `beans`
# parameter.
await beans({1, 2, 3}, 5)

{1: 2, 2: 4, 3: 6}