# Decorators Application — Memoization (Advanced Problems + Solutions)

This notebook contains advanced memoization exercises (with solutions) using decorators.

Topics:
- Correct memoization keys for `args/kwargs`
- Handling unhashable inputs via deep-freezing
- Bounded caches (LRU)
- TTL expiration
- Exception caching policy
- Async memoization with concurrency de-duplication
- Per-instance method memoization without memory leaks (weakrefs)
- Best-practice use of `functools.lru_cache`, `functools.cache`, and `cached_property`


In [31]:
import asyncio
import functools
import inspect
import threading
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Hashable
from weakref import WeakKeyDictionary


## Exercise 1 — Memoization basics: prove the repeated-work problem

Write a naive recursive Fibonacci implementation and show that it repeats work.

Requirements:
- Use 1-based indexing: F(1)=1, F(2)=1
- Track how many times the function is called for `n=10`

Then (in later exercises) you'll fix it with memoization.


In [32]:
# --- Solution (Exercise 1) ---

def count_calls(fn: Callable[..., Any]):
    """Simple call counter decorator used for measurement in this notebook."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Keep the count on the wrapper itself so that resetting
        # `wrapper.call_count = 0` from outside actually takes effect.
        wrapper.call_count += 1  # type: ignore[attr-defined]
        return fn(*args, **kwargs)

    wrapper.call_count = 0  # type: ignore[attr-defined]
    return wrapper

@count_calls
def fib_naive_counted(n: int) -> int:
    """Naive recursive Fibonacci (1-based), with correct recursive call counting."""
    if n <= 0:
        raise ValueError('n must be >= 1')
    return 1 if n <= 2 else fib_naive_counted(n - 1) + fib_naive_counted(n - 2)


### Demo (Exercise 1)
Naive recursion recomputes the same subproblems many times, so the call count grows exponentially with `n`: computing F(10) takes 109 calls.


In [33]:
fib_naive_counted.call_count = 0  # type: ignore[attr-defined]
val = fib_naive_counted(10)
val, fib_naive_counted.call_count  # type: ignore[attr-defined]


(55, 109)
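
The count of 109 is not arbitrary. As a rough check, the sketch below counts calls for a few values of `n` and compares them with the closed form `2 * F(n) - 1`, which follows from the recurrence `calls(n) = 1 + calls(n - 1) + calls(n - 2)`. The helpers `_fib_counting` and `count_for` are illustrative names, not part of the notebook's solution.

```python
def _fib_counting(n: int, counter: list[int]) -> int:
    """Naive 1-based Fibonacci that increments counter[0] on every call."""
    counter[0] += 1
    if n <= 2:
        return 1
    return _fib_counting(n - 1, counter) + _fib_counting(n - 2, counter)


def count_for(n: int) -> tuple[int, int]:
    """Return (F(n), number of calls made to compute it)."""
    counter = [0]
    value = _fib_counting(n, counter)
    return value, counter[0]


for n in (5, 10, 20):
    value, calls = count_for(n)
    # The call count matches 2 * F(n) - 1 for every n >= 1.
    print(n, value, calls, calls == 2 * value - 1)
```

Because the call count is proportional to F(n) itself, it grows at the same exponential rate as the Fibonacci numbers.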

## Exercise 2 — A correct memoize decorator for args + kwargs

Implement `memoize`:

Requirements:
- Works for arbitrary `*args` and `**kwargs`
- Cache key must be stable regardless of kwargs order
- Preserves metadata using `functools.wraps`
- Exposes:
  - `wrapper.cache_clear()`
  - `wrapper.cache_info()` returning `{hits, misses, size}`

Best practice: fail loudly on unhashable inputs (handled in Exercise 3).


In [34]:
# --- Solution (Exercise 2) ---

def _default_key(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Hashable:
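    # Building the key never hashes anything; the TypeError for unhashable
    # elements surfaces when the key is first used in a dict lookup.
    # Note: positional and keyword spellings of the same call, e.g. f(1, 2)
    # and f(1, b=2), produce different keys.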
    return (args, tuple(sorted(kwargs.items())))

def memoize(
    fn: Callable[..., Any] | None = None,
    *,
    key: Callable[[tuple[Any, ...], dict[str, Any]], Hashable] | None = None,
):
    """Memoize a function by caching results based on args/kwargs.

    Usage:
        @memoize
        def f(...): ...

        @memoize(key=custom_key)
        def g(...): ...
    """

    def decorate(target: Callable[..., Any]):
        cache: dict[Hashable, Any] = {}
        hits = 0
        misses = 0
        make_key = key if key is not None else _default_key

        @functools.wraps(target)
        def wrapper(*args, **kwargs):
            nonlocal hits, misses
            k = make_key(args, kwargs)
            if k in cache:
                hits += 1
                return cache[k]
            misses += 1
            val = target(*args, **kwargs)
            cache[k] = val
            return val

        def cache_clear():
            nonlocal hits, misses
            cache.clear()
            hits = 0
            misses = 0

        def cache_info():
            return {'hits': hits, 'misses': misses, 'size': len(cache)}

        wrapper.cache_clear = cache_clear  # type: ignore[attr-defined]
        wrapper.cache_info = cache_info    # type: ignore[attr-defined]
        return wrapper

    if callable(fn):
        return decorate(fn)
    return decorate
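
The default key treats `add(1, 2, c=3)` and `add(1, b=2, c=3)` as different calls, because positional and keyword arguments land in different parts of the key. If that matters, one option is to normalize the call against the function's signature before building the key. The following is a minimal sketch rather than the notebook's solution: `memoize_bound` and its `cache` attribute are hypothetical names, and it assumes the decorated function has no `**kwargs` parameter (its bound value would be an unhashable dict).

```python
import functools
import inspect


def memoize_bound(fn):
    """Hypothetical variant: key on the bound arguments, not the raw call shape."""
    sig = inspect.signature(fn)
    cache = {}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)  # raises TypeError on a bad call
        bound.apply_defaults()             # fill in omitted defaults
        # bound.arguments is ordered by parameter, so the key is stable.
        key = tuple(bound.arguments.items())
        if key not in cache:
            cache[key] = fn(*args, **kwargs)
        return cache[key]

    wrapper.cache = cache  # exposed only for inspection in this sketch
    return wrapper


calls = []


@memoize_bound
def add(a, b=0, *, c=0):
    calls.append((a, b, c))
    return a + b + c


add(1, 2, c=3)
add(a=1, b=2, c=3)
add(1, b=2, c=3)
print(len(calls))  # all three spellings share one cache entry
```

The cost is a `Signature.bind` call on every invocation, which is noticeably slower than building a tuple, so this is only worth it when callers really do mix call styles.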


## Exercise 3 — Handle unhashable inputs via deep-freezing (advanced)

Memoization keys must be hashable. Implement `deep_freeze(obj)` that converts:
- `list` -> `tuple` (recursively)
- `dict` -> sorted tuple of key/value pairs (recursively)
- `set` -> `frozenset` (recursively)

Then implement `memoize_frozen` that uses deep-freezing in its key builder.

Best practices:
- Only freeze known container types.
- Do not attempt to freeze arbitrary custom objects (treat as-is).
- Be explicit about limitations.


In [35]:
# --- Solution (Exercise 3) ---

def deep_freeze(obj: Any) -> Any:
    """Convert common mutable containers into immutable equivalents recursively.

    Supported:
      - list -> tuple
      - dict -> tuple(sorted((k, v), ...))
      - set -> frozenset
      - tuple -> tuple (elements frozen)

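    Other objects are returned as-is.

    Limitations:
      - Container types are not tagged, so `[1, 2]` and `(1, 2)` freeze to the
        same value, as do a dict and a tuple of its sorted items.
      - Dict keys must be mutually orderable; mixed key types such as
        `{1: ..., "a": ...}` raise TypeError when sorted.
    """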
    if isinstance(obj, list):
        return tuple(deep_freeze(x) for x in obj)
    if isinstance(obj, dict):
        return tuple(sorted((deep_freeze(k), deep_freeze(v)) for k, v in obj.items()))
    if isinstance(obj, set):
        return frozenset(deep_freeze(x) for x in obj)
    if isinstance(obj, tuple):
        return tuple(deep_freeze(x) for x in obj)
    return obj

def _frozen_key(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Hashable:
    frozen_args = deep_freeze(args)
    frozen_kwargs = deep_freeze(kwargs)
    return (frozen_args, frozen_kwargs)

def memoize_frozen(fn: Callable[..., Any] | None = None):
    return memoize(fn, key=_frozen_key)  # type: ignore[arg-type]
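
Because `deep_freeze` does not record which container it froze, a list and a tuple with the same elements map to the same key. Where that distinction matters, one possible variation is to tag each frozen value with its container kind. The sketch below is an illustration under that assumption; `freeze_tagged` is a hypothetical name, and it stores dict items in a `frozenset` so that keys need not be sortable.

```python
from typing import Any


def freeze_tagged(obj: Any) -> Any:
    """Hypothetical variant of deep_freeze that keeps container kinds distinct."""
    if isinstance(obj, list):
        return ("list", tuple(freeze_tagged(x) for x in obj))
    if isinstance(obj, tuple):
        return ("tuple", tuple(freeze_tagged(x) for x in obj))
    if isinstance(obj, dict):
        # A frozenset of pairs ignores insertion order and avoids sorting keys.
        items = frozenset((freeze_tagged(k), freeze_tagged(v)) for k, v in obj.items())
        return ("dict", items)
    if isinstance(obj, (set, frozenset)):
        return ("set", frozenset(freeze_tagged(x) for x in obj))
    return obj  # assumed hashable; anything else fails when hashed


print(freeze_tagged([1, 2]) == freeze_tagged((1, 2)))        # kinds differ
print(freeze_tagged({"a": 1, "b": [2]}) == freeze_tagged({"b": [2], "a": 1}))
print(hash(freeze_tagged({1: "x", "a": "y"})) is not None)   # mixed key types are fine
```

Tagging makes keys larger and slightly slower to hash, so the untagged version remains a reasonable default when callers never mix lists and tuples.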


## Exercise 4 — Bounded memoization: LRU cache decorator

Implement `lru_memoize(maxsize=128, thread_safe=True)`:

Requirements:
- Keeps at most `maxsize` items; evict **least-recently-used**
- Accepts args/kwargs keys (use `_default_key` from Exercise 2)
- Thread-safe when `thread_safe=True`
- Exposes:
  - `cache_clear()`, `cache_info()` -> `{hits, misses, size, maxsize}`

Best practice:
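- Compute the function result outside the lock so that a slow call does not block other threads.
- Hold the lock only around cache reads and mutations. The trade-off is that two threads missing on the same key at the same time may both compute the value.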


In [36]:
# --- Solution (Exercise 4) ---

def lru_memoize(
    *,
    maxsize: int = 128,
    thread_safe: bool = True,
    key: Callable[[tuple[Any, ...], dict[str, Any]], Hashable] | None = None,
):
    if maxsize < 1:
        raise ValueError('maxsize must be >= 1')
    make_key = key if key is not None else _default_key

    def decorate(fn: Callable[..., Any]):
        cache: OrderedDict[Hashable, Any] = OrderedDict()
        hits = 0
        misses = 0
        lock = threading.Lock() if thread_safe else None

        def _get(k: Hashable):
            nonlocal hits
            if k in cache:
                hits += 1
                cache.move_to_end(k, last=True)
                return True, cache[k]
            return False, None

        def _put(k: Hashable, v: Any):
            cache[k] = v
            cache.move_to_end(k, last=True)
            while len(cache) > maxsize:
                cache.popitem(last=False)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal misses
            k = make_key(args, kwargs)

            if lock is None:
                ok, val = _get(k)
                if ok:
                    return val
                misses += 1
                val = fn(*args, **kwargs)
                _put(k, val)
                return val

            # Locking path
            with lock:
                ok, val = _get(k)
                if ok:
                    return val
                misses += 1

            # Compute outside lock
            val = fn(*args, **kwargs)

            with lock:
                _put(k, val)
            return val

        def cache_clear():
            nonlocal hits, misses
            if lock is None:
                cache.clear()
                hits = 0
                misses = 0
                return
            with lock:
                cache.clear()
                hits = 0
                misses = 0

        def cache_info():
            if lock is None:
                return {'hits': hits, 'misses': misses, 'size': len(cache), 'maxsize': maxsize}
            with lock:
                return {'hits': hits, 'misses': misses, 'size': len(cache), 'maxsize': maxsize}

        wrapper.cache_clear = cache_clear  # type: ignore[attr-defined]
        wrapper.cache_info = cache_info    # type: ignore[attr-defined]
        return wrapper

    return decorate
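
The eviction logic rests entirely on how `OrderedDict` tracks order: `move_to_end` marks a key as most recently used, and `popitem(last=False)` removes the oldest one. The following standalone sketch isolates that mechanism; `touch` and the module-level `lru` dict are illustrative names only.

```python
from collections import OrderedDict

lru: "OrderedDict[int, int]" = OrderedDict()


def touch(key: int, value: int, maxsize: int = 2) -> None:
    """Record an access: refresh an existing key or insert and evict the oldest."""
    if key in lru:
        lru.move_to_end(key)          # now the most recently used entry
        return
    lru[key] = value                  # new entries start as most recent
    while len(lru) > maxsize:
        lru.popitem(last=False)       # drop the least recently used entry


touch(1, 1)
touch(2, 4)
touch(1, 1)   # refreshes key 1, so key 2 becomes the oldest
touch(3, 9)   # exceeds maxsize and evicts key 2
print(list(lru))
```

Keys are listed from least to most recently used, so the first key is always the next eviction candidate.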


## Exercise 5 — TTL memoization (time-based expiration)

Implement `ttl_memoize(ttl_seconds, maxsize=128, time_func=time.monotonic)`:

Requirements:
- Cache entries expire after `ttl_seconds`
- Use LRU eviction when exceeding `maxsize`
- Provide `cache_clear()` and `cache_info()` -> `{hits, misses, size, maxsize, ttl_seconds}`

Best practice:
- Accept an injectable `time_func` so tests don't need real sleeping.
- Purge expired entries opportunistically on access.


In [37]:
# --- Solution (Exercise 5) ---

def ttl_memoize(
    *,
    ttl_seconds: float,
    maxsize: int = 128,
    time_func: Callable[[], float] = time.monotonic,
    key: Callable[[tuple[Any, ...], dict[str, Any]], Hashable] | None = None,
):
    if ttl_seconds <= 0:
        raise ValueError('ttl_seconds must be > 0')
    if maxsize < 1:
        raise ValueError('maxsize must be >= 1')
    make_key = key if key is not None else _default_key

    def decorate(fn: Callable[..., Any]):
        # OrderedDict: key -> (expires_at, value)
        cache: OrderedDict[Hashable, tuple[float, Any]] = OrderedDict()
        hits = 0
        misses = 0

        def _purge_expired(now: float):
            # LRU order != expiration order, so scan.
            expired = [k for k, (exp, _) in cache.items() if exp <= now]
            for k in expired:
                cache.pop(k, None)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal hits, misses
            k = make_key(args, kwargs)
            now = time_func()
            _purge_expired(now)

            if k in cache:
                hits += 1
                exp, val = cache.pop(k)
                cache[k] = (exp, val)
                return val

            misses += 1
            val = fn(*args, **kwargs)
            # Start the TTL when the value becomes available, not when the call began.
            exp = time_func() + ttl_seconds
            cache[k] = (exp, val)
            cache.move_to_end(k, last=True)
            while len(cache) > maxsize:
                cache.popitem(last=False)
            return val

        def cache_clear():
            nonlocal hits, misses
            cache.clear()
            hits = 0
            misses = 0

        def cache_info():
            return {
                'hits': hits,
                'misses': misses,
                'size': len(cache),
                'maxsize': maxsize,
                'ttl_seconds': ttl_seconds,
            }

        wrapper.cache_clear = cache_clear  # type: ignore[attr-defined]
        wrapper.cache_info = cache_info    # type: ignore[attr-defined]
        return wrapper

    return decorate
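
The solution scans the whole cache for expired entries on every call, which costs O(n) per access. An alternative, sketched here under the assumption that stale entries for untouched keys may linger until they are evicted some other way, is to check expiry only for the key being read. `TTLStore` is a hypothetical helper, not part of the exercise.

```python
from typing import Any, Callable, Hashable


class TTLStore:
    """Hypothetical key/value store that expires entries lazily on read."""

    def __init__(self, ttl_seconds: float, clock: Callable[[], float]):
        self._ttl = ttl_seconds
        self._clock = clock
        self._data: dict[Hashable, tuple[float, Any]] = {}

    def put(self, key: Hashable, value: Any) -> None:
        self._data[key] = (self._clock() + self._ttl, value)

    def get(self, key: Hashable, default: Any = None) -> Any:
        entry = self._data.get(key)
        if entry is None:
            return default
        expires_at, value = entry
        if expires_at <= self._clock():   # same boundary rule as the solution
            del self._data[key]           # drop only the entry we looked at
            return default
        return value


now = [0.0]                               # injectable fake clock
store = TTLStore(5.0, lambda: now[0])
store.put("a", 1)
now[0] = 4.9
print(store.get("a"))                     # still fresh
now[0] = 5.0
print(store.get("a"))                     # expired exactly at the TTL boundary
```

Lazy expiry keeps reads O(1), but it needs a size bound such as LRU eviction to stop never-read keys from accumulating.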


## Exercise 6 — Exception caching policy

Implement `memoize_policy(cache_exceptions=False)`:

- If `cache_exceptions=False` (recommended default):
  - exceptions are **not** cached
  - if a call fails, later calls retry the computation
- If `cache_exceptions=True`:
  - cache exceptions and re-raise the cached exception for future calls

Best practice:
- Most application logic should not cache transient failures.


In [38]:
# --- Solution (Exercise 6) ---

@dataclass
class _CachedException:
    exc: BaseException

def memoize_policy(
    *,
    cache_exceptions: bool = False,
    key: Callable[[tuple[Any, ...], dict[str, Any]], Hashable] | None = None,
):
    make_key = key if key is not None else _default_key

    def decorate(fn: Callable[..., Any]):
        cache: dict[Hashable, Any] = {}

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            k = make_key(args, kwargs)
            if k in cache:
                v = cache[k]
                if isinstance(v, _CachedException):
                    raise v.exc
                return v

            try:
                v = fn(*args, **kwargs)
            except Exception as e:
                if cache_exceptions:
                    cache[k] = _CachedException(e)
                raise

            cache[k] = v
            return v

        def cache_clear():
            cache.clear()

        wrapper.cache_clear = cache_clear  # type: ignore[attr-defined]
        return wrapper

    return decorate
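
One side effect of `cache_exceptions=True` is worth knowing: raising the same exception instance repeatedly extends its `__traceback__` each time (this is how CPython behaves), so a cached exception that is re-raised often carries an ever longer traceback. The sketch below demonstrates the effect and one possible mitigation, clearing the traceback with `with_traceback(None)` before re-raising. All function names here are illustrative.

```python
def tb_depth(exc: BaseException) -> int:
    """Count the frames recorded in an exception's traceback chain."""
    depth = 0
    tb = exc.__traceback__
    while tb is not None:
        depth += 1
        tb = tb.tb_next
    return depth


def reraise_same(exc: BaseException) -> None:
    raise exc                           # keeps and extends the old traceback


def reraise_reset(exc: BaseException) -> None:
    raise exc.with_traceback(None)      # start from a clean traceback


def depths(reraiser, times: int = 3) -> list[int]:
    cached = ValueError("no")
    seen = []
    for _ in range(times):
        try:
            reraiser(cached)
        except ValueError as caught:
            seen.append(tb_depth(caught))
    return seen


growing = depths(reraise_same)
stable = depths(reraise_reset)
print(growing, stable)
```

Resetting the traceback discards the frames of the original failure, so whether that is acceptable depends on how much the first traceback matters for debugging.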


## Exercise 7 — Async memoization with concurrency de-duplication

Implement `async_memoize()` for `async def` functions:

Goals:
- If multiple tasks call the function concurrently with the same args/kwargs, only compute once.
- Cache the **awaited result**, not the coroutine (a coroutine object can be awaited only once).
- If the computation raises, remove the cache entry so future calls can retry.
- Expose `cache_clear()` and `cache_info()`.

Hint:
- Cache `asyncio.Task` objects keyed by args/kwargs.


In [39]:
# --- Solution (Exercise 7) ---

def async_memoize(*, key: Callable[[tuple[Any, ...], dict[str, Any]], Hashable] | None = None):
    make_key = key if key is not None else _default_key

    def decorate(fn: Callable[..., Any]):
        # inspect.iscoroutinefunction is preferred; the asyncio variant is
        # deprecated as of Python 3.14.
        if not inspect.iscoroutinefunction(fn):
            raise TypeError('async_memoize can only decorate async functions')

        cache: dict[Hashable, asyncio.Task] = {}
        hits = 0
        misses = 0
        lock = asyncio.Lock()

        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            nonlocal hits, misses
            k = make_key(args, kwargs)

            async with lock:
                if k in cache:
                    hits += 1
                    task = cache[k]
                else:
                    misses += 1
                    task = asyncio.create_task(fn(*args, **kwargs))
                    cache[k] = task

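            try:
                return await task
            except BaseException:
                # Do not cache failures; allow retries. CancelledError derives from
                # BaseException, so catching only Exception would leave a cancelled
                # task in the cache and every later call would re-raise it.
                async with lock:
                    if cache.get(k) is task:
                        cache.pop(k, None)
                raise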

        def cache_clear():
            nonlocal hits, misses
            cache.clear()
            hits = 0
            misses = 0

        def cache_info():
            return {'hits': hits, 'misses': misses, 'size': len(cache)}

        wrapper.cache_clear = cache_clear  # type: ignore[attr-defined]
        wrapper.cache_info = cache_info    # type: ignore[attr-defined]
        return wrapper

    return decorate


## Exercise 8 — Per-instance method memoization (avoid memory leaks)

Memoizing methods can accidentally keep instances alive if the cache strongly references `self`.

Implement `memoize_method` that:
- memoizes results per instance (each instance has its own cache)
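- uses weak references to avoid keeping instances alive: `WeakKeyDictionary` when instances are hashable, or `id(obj)` keys cleaned up by `weakref.finalize` when they are not (the solution below takes the second route)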
- supports args/kwargs keys
- preserves metadata

Hint:
- Use a descriptor with `__get__` to bind a wrapper per instance.


In [40]:
# --- Solution (Exercise 8) ---
# Supports unhashable instances (e.g., normal @dataclass) without memory leaks.

import weakref

class _BoundMemoizedMethod:
    def __init__(self, descriptor: "memoize_method", obj: Any):
        self._descriptor = descriptor
        self._obj = obj
        functools.update_wrapper(self, descriptor._fn)

    def __call__(self, *args, **kwargs):
        return self._descriptor._call(self._obj, *args, **kwargs)

    def cache_clear(self):
        self._descriptor._clear(self._obj)


class memoize_method:
    """Memoize an instance method per-instance without keeping instances alive.

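    Note: WeakKeyDictionary requires hashable instances; many dataclasses are unhashable.
    This implementation keys by id(obj) and uses weakref.finalize to clean up.
    Instances must still support weak references, so a class that defines
    __slots__ needs a '__weakref__' slot.
    """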

    def __init__(
        self,
        fn: Callable[..., Any],
        *,
        key: Callable[[tuple[Any, ...], dict[str, Any]], Hashable] | None = None,
    ):
        self._fn = fn
        self._key = key if key is not None else _default_key

        # id(obj) -> per-instance cache
        self._caches_by_id: dict[int, dict[Hashable, Any]] = {}

        # id(obj) -> finalizer that deletes the cache when obj is GC'd
        self._finalizers_by_id: dict[int, weakref.finalize] = {}

        functools.update_wrapper(self, fn)

    def _drop(self, oid: int) -> None:
        self._caches_by_id.pop(oid, None)
        self._finalizers_by_id.pop(oid, None)

    def _get_cache(self, obj: Any) -> dict[Hashable, Any]:
        oid = id(obj)
        cache = self._caches_by_id.get(oid)
        if cache is None:
            cache = {}
            self._caches_by_id[oid] = cache
            # Ensure cleanup when the instance is garbage-collected
            self._finalizers_by_id[oid] = weakref.finalize(obj, self._drop, oid)
        return cache

    def _call(self, obj: Any, *args, **kwargs):
        cache = self._get_cache(obj)
        k = self._key(args, kwargs)
        if k in cache:
            return cache[k]
        v = self._fn(obj, *args, **kwargs)
        cache[k] = v
        return v

    def _clear(self, obj: Any) -> None:
        cache = self._caches_by_id.get(id(obj))
        if cache is not None:
            cache.clear()

    def __get__(self, obj: Any, objtype=None):
        if obj is None:
            return self
        return _BoundMemoizedMethod(self, obj)
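
The cleanup behaviour is easy to verify on its own. The sketch below, which is independent of `memoize_method`, registers a per-object cache under `id(obj)` and checks that `weakref.finalize` removes it once the object is gone. `Point`, `registry`, and `cache_for` are hypothetical names used only for this illustration.

```python
import gc
import weakref
from dataclasses import dataclass


@dataclass
class Point:
    """eq=True without frozen=True sets __hash__ to None, so Point is unhashable."""
    x: int


registry: dict[int, dict] = {}


def cache_for(obj) -> dict:
    """Return the per-object cache, creating it and its finalizer on first use."""
    oid = id(obj)
    if oid not in registry:
        registry[oid] = {}
        # Runs registry.pop(oid, None) when obj is collected; holds no strong ref to obj.
        weakref.finalize(obj, registry.pop, oid, None)
    return registry[oid]


p = Point(1)
cache_for(p)["answer"] = 42
size_alive = len(registry)

del p
gc.collect()          # CPython frees p at `del`; the explicit collect covers other runtimes
size_after = len(registry)
print(size_alive, size_after)
```

Because the finalizer runs before the object's memory can be reused, a later object that happens to receive the same `id` starts with an empty cache.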


## Exercise 9 — Best practice: use the standard library when possible

Tasks:
1) Use `functools.lru_cache` to memoize Fibonacci.
2) Inspect `.cache_info()` and `.cache_clear()`.
3) Show that `inspect.signature` still works due to `__wrapped__`.
4) Compare `functools.cache` (unbounded) vs `lru_cache(maxsize=...)`.
5) Use `functools.cached_property` for per-instance memoization of an expensive property.


In [41]:
# --- Solution (Exercise 9) ---

from functools import cache, cached_property, lru_cache

@lru_cache(maxsize=256)
def fib_lru(n: int) -> int:
    if n <= 0:
        raise ValueError('n must be >= 1')
    return 1 if n <= 2 else fib_lru(n - 1) + fib_lru(n - 2)

@cache
def fib_cache(n: int) -> int:
    if n <= 0:
        raise ValueError('n must be >= 1')
    return 1 if n <= 2 else fib_cache(n - 1) + fib_cache(n - 2)

@dataclass
class Expensive:
    x: int
    calls: int = 0

    @cached_property
    def heavy(self) -> int:
        # Simulate an expensive computation
        self.calls += 1  # the dataclass is not frozen, so plain assignment works
        return self.x * self.x
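
To make the difference between the bounded and unbounded decorators concrete, the following small sketch runs the same access pattern through both and inspects `cache_info()`. The functions `square_bounded` and `square_unbounded` are illustrative and not part of the exercise solution.

```python
import inspect
from functools import cache, lru_cache


@lru_cache(maxsize=2)
def square_bounded(x: int) -> int:
    return x * x


@cache
def square_unbounded(x: int) -> int:
    return x * x


for x in (1, 2, 1, 3, 2):
    square_bounded(x)      # 3 evicts 2, then 2 evicts 1
    square_unbounded(x)    # nothing is ever evicted

bounded = square_bounded.cache_info()
unbounded = square_unbounded.cache_info()
print(bounded)
print(unbounded)

# Both decorators set __wrapped__, so the original signature stays visible.
same_signature = inspect.signature(square_bounded) == inspect.signature(
    square_bounded.__wrapped__
)
print(same_signature)
```

The bounded cache trades an extra miss for a fixed memory ceiling, while `functools.cache` reports `maxsize=None` and grows with every distinct argument.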


# Verification / Tests

Run this cell to validate correctness and best-practice behavior.


In [42]:
# Exercise 1
fib_naive_counted.call_count = 0  # type: ignore[attr-defined]
assert fib_naive_counted(10) == 55
assert fib_naive_counted.call_count > 50  # type: ignore[attr-defined]

# Exercise 2: args/kwargs order stability + signature preservation
calls = {"n": 0}

@memoize
def add(a: int, b: int = 0, *, c: int = 0) -> int:
    calls["n"] += 1
    return a + b + c

add.cache_clear()  # type: ignore[attr-defined]
assert add(1, 2, c=3) == 6
assert add(1, 2, c=3) == 6
assert add(1, 2, **{"c": 3}) == 6
info = add.cache_info()  # type: ignore[attr-defined]
assert calls["n"] == 1
assert info["hits"] == 2 and info["misses"] == 1

# Robust signature check (works even with: from __future__ import annotations)
assert inspect.signature(add) == inspect.signature(add.__wrapped__)  # type: ignore[attr-defined]

# Exercise 3: unhashable inputs handled via freezing
calls2 = {"n": 0}

@memoize_frozen
def sum_payload(payload: dict[str, Any]) -> int:
    calls2["n"] += 1
    return sum(payload["values"]) + payload["bias"]

p1 = {"values": [1, 2, 3], "bias": 10}
p2 = {"bias": 10, "values": [1, 2, 3]}  # same meaning, different order
assert sum_payload(p1) == 16
assert sum_payload(p2) == 16
assert calls2["n"] == 1

# Exercise 4: LRU eviction
calls3 = {"n": 0}

@lru_memoize(maxsize=2)
def square(x: int) -> int:
    calls3["n"] += 1
    return x * x

square.cache_clear()  # type: ignore[attr-defined]
assert square(1) == 1
assert square(2) == 4
assert square(1) == 1  # hit
assert square(3) == 9  # evicts 2 (LRU)
assert square.cache_info()["size"] == 2  # type: ignore[attr-defined]
before = calls3["n"]
square(2)  # recompute (was evicted)
assert calls3["n"] == before + 1

# Exercise 5: TTL without real sleeping (injectable clock)
class FakeClock:
    def __init__(self):
        self.t = 0.0

    def now(self):
        return self.t

    def advance(self, dt: float):
        self.t += dt

clk = FakeClock()
calls4 = {"n": 0}

@ttl_memoize(ttl_seconds=5.0, maxsize=10, time_func=clk.now)
def inc(x: int) -> int:
    calls4["n"] += 1
    return x + 1

inc.cache_clear()  # type: ignore[attr-defined]
assert inc(1) == 2
assert inc(1) == 2
assert calls4["n"] == 1
clk.advance(6.0)
assert inc(1) == 2
assert calls4["n"] == 2

# Exercise 6: exception caching policy
state = {"fail": True, "calls": 0}

@memoize_policy(cache_exceptions=False)
def flaky(x: int) -> int:
    state["calls"] += 1
    if state["fail"]:
        raise RuntimeError("transient")
    return x * 2

try:
    flaky(2)
    raise AssertionError("Expected RuntimeError")
except RuntimeError:
    pass

state["fail"] = False
assert flaky(2) == 4
assert state["calls"] == 2

state2 = {"calls": 0}

@memoize_policy(cache_exceptions=True)
def always_fails(x: int) -> int:
    state2["calls"] += 1
    raise ValueError("no")

for _ in range(3):
    try:
        always_fails(1)
    except ValueError:
        pass

assert state2["calls"] == 1  # cached exception

# Exercise 7: async memoize de-duplicates concurrent calls
async_calls = {"n": 0}

@async_memoize()
async def slow_double(x: int) -> int:
    async_calls["n"] += 1
    await asyncio.sleep(0)
    return x * 2

async def _async_test():
    slow_double.cache_clear()  # type: ignore[attr-defined]
    res = await asyncio.gather(slow_double(5), slow_double(5), slow_double(5))
    assert res == [10, 10, 10]
    assert async_calls["n"] == 1
    info = slow_double.cache_info()  # type: ignore[attr-defined]
    assert info["hits"] == 2 and info["misses"] == 1

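# Top-level await works in Jupyter/IPython; in a plain script, call
# asyncio.run(_async_test()) instead.
await _async_test()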

# Exercise 8: per-instance memoization
@dataclass
class C:
    base: int
    calls: int = 0

    @memoize_method
    def compute(self, x: int) -> int:
        self.calls += 1
        return self.base + x

c1 = C(10)
c2 = C(100)
assert c1.compute(5) == 15
assert c1.compute(5) == 15
assert c1.calls == 1
assert c2.compute(5) == 105
assert c2.calls == 1

c1.compute.cache_clear()  # type: ignore[attr-defined]
assert c1.compute(5) == 15
assert c1.calls == 2

# Exercise 9: functools tools
fib_lru.cache_clear()
assert fib_lru(35) == 9227465
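ci = fib_lru.cache_info()
# From a cold cache, each n in 1..35 misses once, and each of fib_lru(4)..fib_lru(35)
# gets exactly one hit when it looks up n - 2.
assert ci.misses == 35 and ci.hits == 32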

# Robust signature check (works even with: from __future__ import annotations)
assert inspect.signature(fib_lru) == inspect.signature(fib_lru.__wrapped__)  # type: ignore[attr-defined]

fib_cache.cache_clear()
assert fib_cache(35) == 9227465

e = Expensive(7)
assert e.calls == 0
assert e.heavy == 49
assert e.heavy == 49
assert e.calls == 1  # cached_property computed once

print("All checks passed ✅")


All checks passed ✅
