# Decorator Application (Decorator Class) — Advanced Problems with Solutions

This notebook is a problem set about writing **decorators as callable classes** (i.e., objects implementing `__call__`).

## Best practices you should use throughout
- Preserve wrapped function metadata with `functools.wraps` / `functools.update_wrapper`
- Keep decorators compatible with arbitrary signatures (`*args`, `**kwargs`)
- Prefer `time.perf_counter()` / `time.monotonic()` for timing
- Make state explicit and testable (avoid hidden globals)
- For decorators that wrap *methods*, understand descriptors (`__get__`)
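To make the first bullet concrete, here is a minimal comparison of a wrapper with and without `functools.wraps`. The names `plain_decorator`, `wrapping_decorator`, `area`, and `perimeter` are illustrative only and are not part of the problems below.

```python
import functools
import inspect


def plain_decorator(fn):
    # No functools.wraps: the wrapper keeps its own name, docstring, and signature
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def wrapping_decorator(fn):
    # functools.wraps copies __name__, __qualname__, __doc__, __module__, ...,
    # merges __dict__, and sets wrapper.__wrapped__ = fn
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


@plain_decorator
def area(width: float, height: float) -> float:
    """Return width * height."""
    return width * height


@wrapping_decorator
def perimeter(width: float, height: float) -> float:
    """Return 2 * (width + height)."""
    return 2 * (width + height)


print(area.__name__, area.__doc__)            # wrapper None
print(perimeter.__name__, perimeter.__doc__)  # perimeter Return 2 * (width + height).
# inspect.signature follows __wrapped__, so introspection sees the real parameters
print(inspect.signature(area))       # (*args, **kwargs)
print(inspect.signature(perimeter))  # (width: float, height: float) -> float
```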



---
## Setup
Run this cell once. Some problems use `logging`, `asyncio`, and timing helpers.


In [1]:
import asyncio
import functools
import inspect
import logging
import threading
import time
from collections import OrderedDict, deque
from dataclasses import dataclass
from typing import Any, Callable, Dict, Hashable, Optional, Tuple, Type, TypeVar, get_type_hints
from weakref import WeakKeyDictionary

logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")

F = TypeVar("F", bound=Callable[..., Any])


---
## Problem 1 — Parameterized class decorator with good metadata

Create a **decorator class** `LoggedCall` that can be used like:

```python
@LoggedCall(prefix='[API]', log_args=True, log_return=False)
def f(x, y): ...
```

### Requirements
1. `LoggedCall(prefix, log_args, log_return, logger=None)` stores configuration on the instance.
2. When the wrapped function is called, log (via the `logging` module) a single line:
   - Always include prefix and function name
   - Optionally include args/kwargs (`log_args=True`)
   - Optionally include returned value (`log_return=True`)
3. Preserve metadata: `__name__`, `__doc__`, and `__wrapped__` should behave like `functools.wraps`.
4. Works for plain functions *and* instance methods.

### Hint
Use `functools.wraps(fn)` inside `__call__`.


In [2]:
# YOUR TURN (optional): implement LoggedCall
# class LoggedCall:
#     ...


### Solution


In [3]:
class LoggedCall:
    def __init__(
        self,
        prefix: str = "",
        *,
        log_args: bool = True,
        log_return: bool = False,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self.prefix = prefix
        self.log_args = log_args
        self.log_return = log_return
        self.logger = logger  # can be None; resolved per-function in __call__

    def __call__(self, fn: F) -> F:
        logger = self.logger or logging.getLogger(fn.__module__)

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            parts = []
            head = f"{self.prefix}{fn.__name__}"
            parts.append(head)

            if self.log_args:
                parts.append(f"args={args!r}")
                parts.append(f"kwargs={kwargs!r}")

            result = fn(*args, **kwargs)

            if self.log_return:
                parts.append(f"return={result!r}")

            logger.info(" | ".join(parts))
            return result

        return wrapper  # type: ignore[return-value]


@LoggedCall(prefix="[DEMO] ", log_args=True, log_return=True)
def add(a: int, b: int) -> int:
    """Adds two integers."""
    return a + b


class Greeter:
    @LoggedCall(prefix="[METHOD] ", log_args=False, log_return=True)
    def hello(self, name: str) -> str:
        return f"Hello, {name}!"


# Tests
assert add.__name__ == "add"
assert "Adds two integers" in (add.__doc__ or "")
assert hasattr(add, "__wrapped__")
assert add(2, 3) == 5

g = Greeter()
assert g.hello("Ada") == "Hello, Ada!"
print("Problem 1 OK")


INFO:__main__:[DEMO] add | args=(2, 3) | kwargs={} | return=5
INFO:__main__:[METHOD] hello | return='Hello, Ada!'


Problem 1 OK
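Requirement 4 works in this solution without any `__get__` because `__call__` returns an ordinary nested function, and functions are descriptors that bind `self` on attribute lookup. A callable *instance* used as the wrapper does not get that for free. The minimal contrast below uses illustrative names (`NestedWrapper`, `InstanceAsWrapper`, `Demo`) and previews why Problem 4 needs `__get__`.

```python
import functools


class NestedWrapper:
    """Decorator class whose __call__ returns a plain nested function."""

    def __call__(self, fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        return wrapper


class InstanceAsWrapper:
    """Decorator class whose *instance* replaces the function (no __get__)."""

    def __init__(self, fn):
        self.fn = fn
        functools.update_wrapper(self, fn)

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class Demo:
    @NestedWrapper()
    def ok(self):
        return "bound"

    @InstanceAsWrapper
    def broken(self):
        return "bound"


d = Demo()
print(d.ok())  # bound: functions define __get__, so `self` is passed automatically
try:
    d.broken()  # the instance has no __get__, so `self` is never supplied
except TypeError as exc:
    print("TypeError:", exc)
```

Calling `Demo.broken(d)` with an explicit instance still works, which is a quick way to confirm that only the binding step is missing.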


---
## Problem 2 — Stateful call counter (thread-safe)

Write a decorator class `CallCounter` that counts how many times the wrapped function has been called.

### Requirements
1. Usage:
   ```python
   @CallCounter()
   def f(...): ...
   ```
2. The wrapper must have attributes:
   - `.calls` → current count (int)
   - `.reset()` → set count back to 0
3. Must be **thread-safe**.
4. Preserve metadata (`wraps`).


In [4]:
# YOUR TURN (optional): implement CallCounter
# class CallCounter:
#     ...


### Solution


In [5]:
class CallCounter:
    def __init__(self) -> None:
        self._lock = threading.Lock()

    def __call__(self, fn: F) -> F:
        calls = 0  # kept in closure; protected by lock

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            nonlocal calls
            with self._lock:
                calls += 1
                wrapper.calls = calls  # type: ignore[attr-defined]
            return fn(*args, **kwargs)

        def reset() -> None:
            nonlocal calls
            with self._lock:
                calls = 0
                wrapper.calls = 0  # type: ignore[attr-defined]

        wrapper.calls = 0  # type: ignore[attr-defined]
        wrapper.reset = reset  # type: ignore[attr-defined]
        return wrapper  # type: ignore[return-value]


@CallCounter()
def mul(a: int, b: int) -> int:
    return a * b


assert mul.calls == 0
assert mul(2, 4) == 8
assert mul.calls == 1
assert mul(3, 5) == 15
assert mul.calls == 2
mul.reset()
assert mul.calls == 0
print("Problem 2 OK")


Problem 2 OK
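A stress test cannot prove thread-safety (CPython's GIL can hide races), but it is a cheap way to exercise requirement 3. The sketch below uses a hypothetical, trimmed-down `CountCalls` that mirrors `CallCounter`'s lock-protected increment and calls it from 8 threads; with the lock in place, the final count is exact.

```python
import functools
import threading


class CountCalls:
    """Hypothetical minimal counter: same idea as CallCounter (lock-protected increment)."""

    def __call__(self, fn):
        lock = threading.Lock()

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            with lock:
                wrapper.calls += 1  # read-modify-write guarded by the lock
            return fn(*args, **kwargs)

        wrapper.calls = 0
        return wrapper


@CountCalls()
def noop():
    return None


def hammer(n):
    for _ in range(n):
        noop()


threads = [threading.Thread(target=hammer, args=(10_000,)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(noop.calls)  # 80000
```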


---
## Problem 3 — Decorator class usable with *or without* arguments

Implement a timing decorator class `timed` that supports both:

```python
@timed
def f(): ...

@timed(unit='ms')
def g(): ...
```

### Requirements
1. `unit` is either `'s'` (the default) or `'ms'`.
2. Wrapper should return the original function's result.
3. Expose last measured duration on the wrapper as `.last` (float, in chosen units).
4. Preserve metadata.

### Hint
This is easiest if the class's `__init__` can accept either a function or `None`.


In [6]:
# YOUR TURN (optional): implement timed
# class timed:
#     ...


### Solution


In [7]:
import functools
import time
from typing import Any, Callable, Optional, TypeVar

F = TypeVar("F", bound=Callable[..., Any])

class timed:
    """
    Timing decorator that works as:

        @timed
        def f(...): ...

    and:

        @timed(unit="ms")
        def g(...): ...

    After each call, the decorated callable exposes:
        .last  (float)  # duration in chosen units
    """

    def __init__(self, fn: Optional[F] = None, *, unit: str = "s") -> None:
        if unit not in ("s", "ms"):
            raise ValueError("unit must be 's' or 'ms'")
        self.unit = unit
        self.fn: Optional[F] = fn
        self.last: float = 0.0  # always available on the decorator object

        if fn is not None:
            # Make this instance look like the function (name, doc, __wrapped__, etc.)
            functools.update_wrapper(self, fn)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Case 1: used as @timed(unit="ms") -> first call receives the function to decorate
        if self.fn is None:
            if len(args) != 1 or not callable(args[0]) or kwargs:
                raise TypeError("Decorator not initialized: use as @timed or @timed(unit='ms').")
            fn = args[0]
            return type(self)(fn, unit=self.unit)

        # Case 2: normal function invocation
        factor = 1000.0 if self.unit == "ms" else 1.0
        start = time.perf_counter()
        try:
            return self.fn(*args, **kwargs)
        finally:
            self.last = (time.perf_counter() - start) * factor


# ---- tests ----

@timed
def busy_wait(n: int) -> int:
    s = 0
    for i in range(n):
        s += i
    return s

@timed(unit="ms")
def busy_wait_ms(n: int) -> int:
    s = 0
    for i in range(n):
        s += i
    return s

busy_wait(10_000)
assert busy_wait.last >= 0.0

busy_wait_ms(10_000)
assert busy_wait_ms.last >= 0.0

print("timed decorator OK", busy_wait.last, busy_wait_ms.last)


timed decorator OK 0.0004803999327123165 0.49349991604685783
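Because `timed` returns a class instance (not a nested function) and defines no `__get__`, decorating a *method* with it would not bind `self`. If method support were wanted, one option is to bind the way functions do, via `types.MethodType`. The sketch below is a hypothetical `TimedMethod` (seconds only, not part of the problem's requirements); note that `.last` lives on the decorator, so it is shared by all instances of the class.

```python
import functools
import time
import types


class TimedMethod:
    """Sketch: an instance-as-wrapper timer that also works on methods."""

    def __init__(self, fn):
        self.fn = fn
        self.last = 0.0
        functools.update_wrapper(self, fn)

    def __get__(self, instance, owner=None):
        if instance is None:
            return self  # accessed on the class: return the decorator itself
        # Bind like a normal function would: calling the bound method
        # invokes self.__call__(instance, *args, **kwargs)
        return types.MethodType(self, instance)

    def __call__(self, *args, **kwargs):
        start = time.perf_counter()
        try:
            return self.fn(*args, **kwargs)
        finally:
            self.last = time.perf_counter() - start


class Worker:
    @TimedMethod
    def total(self, n):
        return sum(range(n))


w = Worker()
print(w.total(1000))           # 499500
print(Worker.total.last >= 0)  # True; .last is shared by all Worker instances
```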


---
## Problem 4 — Method decorator that needs a descriptor (`__get__`)

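Sometimes you want the **decorator instance itself** to be the wrapper (instead of returning a nested function).
Plain functions are descriptors, so a nested wrapper function binds `self` automatically when it is looked up on an instance; an arbitrary callable object does not. For methods, you therefore need to implement the descriptor protocol (`__get__`) yourself.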

Implement `PerInstanceMemoize` as a decorator **class** used like:

```python
class C:
    @PerInstanceMemoize
    def f(self, x): ...
```

### Requirements
1. Cache results **per instance** (each instance has its own cache).
2. Use `WeakKeyDictionary` to avoid memory leaks.
3. Keys should be based on `(args, frozenset(kwargs.items()))` **excluding `self`**.
4. Preserve metadata on the descriptor object (`__name__`, `__doc__`).
5. Should raise a clear `TypeError` if arguments are unhashable.
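A subtlety behind requirement 5: building the key tuple never hashes its contents, so an unhashable *positional* argument is only detected when the key is hashed (for example by `key in cache`). Unhashable *keyword* values fail earlier, because `frozenset(...)` hashes each `(name, value)` pair. A minimal illustration with made-up arguments:

```python
args = ([1, 2],)          # a list as a positional argument
kwargs = {"flag": True}

# Building the key succeeds: tuples don't hash their items on creation
key = (args, frozenset(kwargs.items()))

try:
    hash(key)  # hashing recurses into the list and fails here
except TypeError as exc:
    print("unhashable:", exc)
```

So a check that is meant to raise a clear error should hash the key explicitly inside its `try` block.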


In [8]:
# YOUR TURN (optional): implement PerInstanceMemoize
# class PerInstanceMemoize:
#     ...


### Solution


In [9]:
class PerInstanceMemoize:
    def __init__(self, fn: F) -> None:
        self.fn = fn
        self._caches: "WeakKeyDictionary[object, Dict[Hashable, Any]]" = WeakKeyDictionary()
        functools.update_wrapper(self, fn)

    def __get__(self, instance: Any, owner: Any) -> Any:
        # Accessed on class -> return descriptor itself
        if instance is None:
            return self
        # Bind instance by returning a callable that passes instance explicitly
        return functools.partial(self.__call__, instance)

    def __call__(self, instance: Any, *args: Any, **kwargs: Any) -> Any:
        try:
            key = (args, frozenset(kwargs.items()))
            hash(key)  # building the tuple doesn't hash its items; this does
        except TypeError as e:
            raise TypeError("Arguments to memoized method must be hashable.") from e
        cache = self._caches.setdefault(instance, {})

        if key in cache:
            return cache[key]
        result = self.fn(instance, *args, **kwargs)
        cache[key] = result
        return result


class Fib:
    def __init__(self) -> None:
        self.work = 0

    @PerInstanceMemoize
    def fib(self, n: int) -> int:
        self.work += 1
        if n < 2:
            return n
        return self.fib(n - 1) + self.fib(n - 2)


a = Fib()
b = Fib()
assert a.fib(10) == 55
assert a.work == 11  # each n in 0..10 is computed once (naive recursion would make 177 calls)
assert b.fib(10) == 55
assert b.work == 11  # b has its own cache, so it repeats the work independently
print("Problem 4 OK")


Problem 4 OK
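Requirement 2's `WeakKeyDictionary` has two consequences worth seeing in isolation: an entry disappears once its instance is garbage-collected, and every key must be both hashable and weak-referenceable. A class with `__slots__` but no `__weakref__` slot fails the second condition, as does a class that defines `__eq__` without `__hash__`. The classes `Plain` and `Slotted` below are illustrative.

```python
import gc
import weakref


class Plain:
    pass


class Slotted:
    __slots__ = ("x",)  # no "__weakref__" slot -> instances can't be weakly referenced


caches = weakref.WeakKeyDictionary()

obj = Plain()
caches[obj] = {(10,): 55}  # per-instance cache, keyed weakly by the instance
print(len(caches))  # 1

del obj
gc.collect()  # CPython usually frees immediately; collect() makes it explicit
print(len(caches))  # 0: the entry disappeared with the instance

try:
    caches[Slotted()] = {}
except TypeError as exc:
    print("TypeError:", exc)
```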


---
## Problem 5 — LRU + TTL cache as a decorator class

Implement a decorator class `LRUTTLCache(maxsize=128, ttl=1.0)`.

### Requirements
1. Caches results by arguments for the decorated function.
2. Evict using **LRU** when size exceeds `maxsize`.
3. Each entry expires after `ttl` seconds (use `time.monotonic()`).
4. Provide `.cache_info()` on the wrapper returning a dict with keys:
   - `hits`, `misses`, `size`, `maxsize`, `ttl`
5. Preserve metadata.

### Hint
`OrderedDict` is convenient for LRU.
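The hint relies on two `OrderedDict` operations: `move_to_end(key)` marks an entry as most recently used, and `popitem(last=False)` removes the least recently used one. A minimal illustration, with hypothetical names `lru`, `touch`, and `maxsize`:

```python
from collections import OrderedDict

lru = OrderedDict()
maxsize = 2


def touch(key, value):
    """Insert or refresh `key`, evicting the least recently used entry if needed."""
    lru[key] = value
    lru.move_to_end(key, last=True)  # most recently used lives at the end
    while len(lru) > maxsize:
        evicted, _ = lru.popitem(last=False)  # least recently used is at the front
        print("evicted", evicted)


touch("a", 1)
touch("b", 2)
lru.move_to_end("a")  # a cache hit on "a" refreshes its position
touch("c", 3)         # prints: evicted b
print(list(lru))      # ['a', 'c']
```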


In [10]:
# YOUR TURN (optional): implement LRUTTLCache
# class LRUTTLCache:
#     ...


### Solution


In [11]:
class LRUTTLCache:
    def __init__(self, *, maxsize: int = 128, ttl: float = 1.0) -> None:
        if maxsize <= 0:
            raise ValueError("maxsize must be positive.")
        if ttl <= 0:
            raise ValueError("ttl must be positive.")
        self.maxsize = maxsize
        self.ttl = ttl

    def __call__(self, fn: F) -> F:
        cache: "OrderedDict[Hashable, Tuple[float, Any]]" = OrderedDict()
        hits = 0
        misses = 0
        lock = threading.Lock()

        def make_key(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Hashable:
            try:
                key = (args, frozenset(kwargs.items()))
                hash(key)  # building the tuple doesn't hash its items; this does
            except TypeError as e:
                raise TypeError("Arguments to cached function must be hashable.") from e
            return key

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            nonlocal hits, misses
            now = time.monotonic()
            key = make_key(args, kwargs)
            with lock:
                if key in cache:
                    ts, value = cache[key]
                    if now - ts <= self.ttl:
                        hits += 1
                        cache.move_to_end(key, last=True)
                        return value
                    # expired
                    cache.pop(key, None)

                misses += 1

            value = fn(*args, **kwargs)

            with lock:
                cache[key] = (now, value)
                cache.move_to_end(key, last=True)
                while len(cache) > self.maxsize:
                    cache.popitem(last=False)
            return value

        def cache_info() -> Dict[str, Any]:
            with lock:
                return {
                    "hits": hits,
                    "misses": misses,
                    "size": len(cache),
                    "maxsize": self.maxsize,
                    "ttl": self.ttl,
                }

        wrapper.cache_info = cache_info  # type: ignore[attr-defined]
        return wrapper  # type: ignore[return-value]


@LRUTTLCache(maxsize=2, ttl=0.2)
def slow_square(x: int) -> int:
    time.sleep(0.01)
    return x * x


assert slow_square(2) == 4
assert slow_square(2) == 4  # should hit
info = slow_square.cache_info()
assert info["hits"] >= 1 and info["misses"] >= 1

time.sleep(0.25)  # expire
assert slow_square(2) == 4  # miss after ttl
print("Problem 5 OK")


Problem 5 OK
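The test above relies on `time.sleep(0.25)`, which is slow and timing-sensitive. In the spirit of "make state explicit and testable", one option is to inject the clock. `LRUTTLCache` as written calls `time.monotonic()` directly and takes no such parameter, so the sketch below isolates only the TTL rule in a hypothetical `TTLStore` (no LRU) driven by a manually advanced `FakeClock`.

```python
import time
from typing import Any, Callable, Dict, Hashable, Tuple


class TTLStore:
    """Hypothetical TTL store with an injectable clock (no LRU), for deterministic tests."""

    def __init__(self, ttl: float, clock: Callable[[], float] = time.monotonic) -> None:
        self.ttl = ttl
        self.clock = clock
        self._data: Dict[Hashable, Tuple[float, Any]] = {}

    def get(self, key: Hashable, default: Any = None) -> Any:
        entry = self._data.get(key)
        if entry is None:
            return default
        stored_at, value = entry
        if self.clock() - stored_at > self.ttl:  # same "age <= ttl is fresh" rule
            del self._data[key]
            return default
        return value

    def set(self, key: Hashable, value: Any) -> None:
        self._data[key] = (self.clock(), value)


class FakeClock:
    """Manually advanced clock: tests control time instead of sleeping."""

    def __init__(self) -> None:
        self.now = 0.0

    def __call__(self) -> float:
        return self.now


clock = FakeClock()
store = TTLStore(ttl=0.2, clock=clock)
store.set("x", 4)
clock.now = 0.1
print(store.get("x"))  # 4 (still fresh)
clock.now = 0.3
print(store.get("x"))  # None (expired: 0.3 - 0.0 > 0.2)
```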


---
## Problem 6 — Retry decorator class for sync *and* async functions

Implement `Retry` that works for both synchronous and asynchronous callables.

### Requirements
1. Usage:
   ```python
   @Retry(tries=3, delay=0.01, backoff=2.0, exceptions=(ValueError,))
   def f(...): ...

   @Retry(...)
   async def g(...): ...
   ```
2. Retry up to `tries` total attempts.
3. After each failed attempt, multiply the delay by `backoff` before the next retry.
4. Allow injecting `sleep_fn` (defaults to `time.sleep` for sync, `asyncio.sleep` for async).
5. Preserve metadata.
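Requirements 2 and 3 together pin down the sleep schedule: `tries` attempts leave room for at most `tries - 1` sleeps, the first lasting `delay` and each later one multiplied by `backoff` (a zero delay can simply be skipped). A hypothetical helper `backoff_delays` that computes that schedule with the same repeated multiplication:

```python
from typing import List


def backoff_delays(tries: int, delay: float, backoff: float) -> List[float]:
    """Delays slept between attempts: one fewer sleep than attempts."""
    delays = []
    d = delay
    for _ in range(tries - 1):  # no sleep after the final attempt
        delays.append(d)
        d *= backoff
    return delays


print(backoff_delays(tries=4, delay=0.01, backoff=2.0))  # [0.01, 0.02, 0.04]
```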


In [12]:
# YOUR TURN (optional): implement Retry
# class Retry:
#     ...


### Solution


In [13]:
class Retry:
    def __init__(
        self,
        *,
        tries: int = 3,
        delay: float = 0.0,
        backoff: float = 1.0,
        exceptions: Tuple[Type[BaseException], ...] = (Exception,),
        sleep_fn: Optional[Callable[[float], Any]] = None,
    ) -> None:
        if tries <= 0:
            raise ValueError("tries must be positive.")
        if delay < 0:
            raise ValueError("delay must be >= 0.")
        if backoff <= 0:
            raise ValueError("backoff must be > 0.")
        self.tries = tries
        self.delay = delay
        self.backoff = backoff
        self.exceptions = exceptions
        self.sleep_fn = sleep_fn

    def __call__(self, fn: F) -> F:
        is_async = inspect.iscoroutinefunction(fn)

        if is_async:
            async def _async_sleep(d: float) -> None:
                if self.sleep_fn is None:
                    await asyncio.sleep(d)
                else:
                    res = self.sleep_fn(d)
                    if inspect.isawaitable(res):
                        await res

            @functools.wraps(fn)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                d = self.delay
                last_exc: Optional[BaseException] = None
                for attempt in range(1, self.tries + 1):
                    try:
                        return await fn(*args, **kwargs)  # type: ignore[misc]
                    except self.exceptions as e:
                        last_exc = e
                        if attempt == self.tries:
                            raise
                        if d:
                            await _async_sleep(d)
                        d *= self.backoff
                raise last_exc  # pragma: no cover

            return async_wrapper  # type: ignore[return-value]

        # sync
        def _sleep(d: float) -> None:
            if self.sleep_fn is None:
                time.sleep(d)
            else:
                self.sleep_fn(d)

        @functools.wraps(fn)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            d = self.delay
            last_exc: Optional[BaseException] = None
            for attempt in range(1, self.tries + 1):
                try:
                    return fn(*args, **kwargs)
                except self.exceptions as e:
                    last_exc = e
                    if attempt == self.tries:
                        raise
                    if d:
                        _sleep(d)
                    d *= self.backoff
            raise last_exc  # pragma: no cover

        return sync_wrapper  # type: ignore[return-value]


# Tests (sync): fake_sleep is a no-op, so retries don't actually wait
attempts = {"n": 0}
def fake_sleep(_: float) -> None:
    return None

@Retry(tries=4, delay=0.01, backoff=2.0, exceptions=(ValueError,), sleep_fn=fake_sleep)
def flaky() -> int:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ValueError("nope")
    return 42

assert flaky() == 42
assert attempts["n"] == 3

# Tests (async) - Jupyter-safe
async_attempts = {"n": 0}

@Retry(tries=3, delay=0.0, exceptions=(RuntimeError,))
async def aflaky() -> str:
    async_attempts["n"] += 1
    if async_attempts["n"] < 2:
        raise RuntimeError("nope")
    return "ok"

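# Jupyter already runs an event loop, so asyncio.run() would raise RuntimeError here.
# Use top-level await in notebooks (in a plain script, use asyncio.run(...) instead):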
assert await aflaky() == "ok"
assert async_attempts["n"] == 2
print("Problem 6 OK (async)")



Problem 6 OK (async)


---
## Problem 7 — Token-bucket rate limiter decorator class

Implement `RateLimit(rate, per, key=None)` using a token bucket.

### Requirements
1. Tokens refill continuously at `rate / per` tokens per second, i.e. `rate` tokens every `per` seconds.
2. Each call consumes **one** token. If no token is available, raise `RuntimeError`.
3. Support an optional `key` function that derives a bucket key from `*args, **kwargs`.
   - If `key` is `None`, treat all calls as the same bucket.
4. Must be thread-safe.
5. Preserve metadata.
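The continuous refill in requirement 1 is plain arithmetic: after `elapsed` seconds a bucket gains `elapsed * rate / per` tokens. The hypothetical helper `refill` below isolates that step, assuming the bucket holds at most `rate` tokens. With `rate=2, per=0.5` (the values used in the test below) that is 4 tokens per second.

```python
def refill(tokens: float, elapsed: float, rate: float, per: float) -> float:
    """Continuous refill at rate/per tokens per second, capped at `rate` tokens (assumed capacity)."""
    return min(rate, tokens + elapsed * (rate / per))


# rate=2, per=0.5 -> 4 tokens/second, capacity 2
print(refill(0.0, 0.25, rate=2.0, per=0.5))  # 1.0: one call allowed again
print(refill(0.0, 0.6, rate=2.0, per=0.5))   # 2.0: 2.4 capped at the capacity
```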


In [14]:
# YOUR TURN (optional): implement RateLimit
# class RateLimit:
#     ...


### Solution


In [15]:
class RateLimit:
    def __init__(self, *, rate: float, per: float, key: Optional[Callable[..., Hashable]] = None) -> None:
        if rate <= 0:
            raise ValueError("rate must be > 0.")
        if per <= 0:
            raise ValueError("per must be > 0.")
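        self.rate = float(rate)
        self.per = float(per)
        # Bucket capacity: normally `rate`, but at least one token so that a
        # fractional rate (e.g. rate=0.5, per=1.0) can still admit calls.
        self.capacity = max(self.rate, 1.0)
        self.key = key
        self._lock = threading.Lock()
        # key -> (tokens, last_time)
        self._state: Dict[Hashable, Tuple[float, float]] = {}

    def __call__(self, fn: F) -> F:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            bucket = self.key(*args, **kwargs) if self.key is not None else "__global__"
            now = time.monotonic()
            with self._lock:
                # New buckets start full
                tokens, last = self._state.get(bucket, (self.capacity, now))
                # Continuous refill at rate/per tokens per second, capped at capacity
                elapsed = now - last
                tokens = min(self.capacity, tokens + elapsed * (self.rate / self.per))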
                if tokens < 1.0:
                    raise RuntimeError(f"Rate limit exceeded for bucket {bucket!r}")
                tokens -= 1.0
                self._state[bucket] = (tokens, now)
            return fn(*args, **kwargs)

        return wrapper  # type: ignore[return-value]


@RateLimit(rate=2, per=0.5)
def ping() -> str:
    return "pong"

# First two should pass quickly
assert ping() == "pong"
assert ping() == "pong"
# Third call should fail: the bucket is empty and refills at only 4 tokens/s
try:
    ping()
    raise AssertionError("Expected rate limit exception")
except RuntimeError:
    pass

time.sleep(0.6)  # 0.6 s * 4 tokens/s = 2.4 tokens, capped at capacity 2
assert ping() == "pong"
print("Problem 7 OK")


Problem 7 OK


---
## Problem 8 — Runtime type checking decorator class

Write a decorator class `EnforceTypes` that checks arguments and return values using annotations.

### Requirements
1. Use `typing.get_type_hints` to obtain evaluated types.
2. Use `inspect.signature` to bind arguments and apply defaults.
3. For each parameter with an annotation, verify `isinstance(value, annotation)`.
   - If the annotation is `typing.Any`, skip.
4. If a return annotation exists, check it too.
5. Raise `TypeError` with a helpful message showing which name failed.
6. Preserve metadata.

> Note: This problem intentionally focuses on basic `isinstance`-style checks.
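The note matters in practice: `isinstance(value, List[int])` raises `TypeError` instead of returning `False`, so a pure `isinstance` check cannot handle parameterized generics. One possible extension, assuming a *shallow* check is acceptable, is a hypothetical `shallow_isinstance` helper that unwraps `Optional`/`Union` and checks only the container type of generics. It does not check element types and does not handle `Literal` or `Annotated`.

```python
import types
from typing import Any, List, Optional, Union, get_args, get_origin


def shallow_isinstance(value: Any, tp: Any) -> bool:
    """isinstance() that tolerates a few typing constructs (shallow check only)."""
    if tp is Any:
        return True
    origin = get_origin(tp)
    if origin is None:
        return isinstance(value, tp)  # plain classes: int, str, MyClass, ...
    if origin is Union or origin is getattr(types, "UnionType", None):
        # Optional[X] / Union[X, Y] / X | Y (3.10+): accept if any member matches
        return any(shallow_isinstance(value, arg) for arg in get_args(tp))
    # Parameterized generics such as List[int]: check the container type only
    return isinstance(value, origin)


print(shallow_isinstance(None, Optional[int]))  # True
print(shallow_isinstance("3", Optional[int]))   # False
print(shallow_isinstance(["a"], List[int]))     # True: elements are NOT checked
```

Keep in mind that `bool` is a subclass of `int`, so `True` passes an `int` check with either approach.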


In [16]:
# YOUR TURN (optional): implement EnforceTypes
# class EnforceTypes:
#     ...


### Solution


In [17]:
class EnforceTypes:
    def __init__(self, *, check_return: bool = True) -> None:
        self.check_return = check_return

    def __call__(self, fn: F) -> F:
        sig = inspect.signature(fn)
        hints = get_type_hints(fn)
        ret_type = hints.get("return", None)

        def should_skip(t: Any) -> bool:
            return t is Any

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()

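            for name, value in bound.arguments.items():
                expected = hints.get(name)
                if expected is None or should_skip(expected):
                    continue
                kind = sig.parameters[name].kind
                # An annotation on *args / **kwargs describes each element, not the container
                if kind is inspect.Parameter.VAR_POSITIONAL:
                    values = tuple(value)
                elif kind is inspect.Parameter.VAR_KEYWORD:
                    values = tuple(value.values())
                else:
                    values = (value,)
                for v in values:
                    if not isinstance(v, expected):
                        raise TypeError(
                            f"{fn.__name__}(): argument '{name}' expected {expected}, got {type(v)}"
                        )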

            result = fn(*args, **kwargs)

            if self.check_return and (ret_type is not None) and not should_skip(ret_type):
                if not isinstance(result, ret_type):
                    raise TypeError(
                        f"{fn.__name__}(): return expected {ret_type}, got {type(result)}"
                    )
            return result

        return wrapper  # type: ignore[return-value]


@EnforceTypes()
def repeat(s: str, n: int) -> str:
    return s * n

assert repeat("ha", 3) == "hahaha"
try:
    repeat("ha", "3")  # type: ignore[arg-type]
    raise AssertionError("Expected TypeError")
except TypeError:
    pass

print("Problem 8 OK")


Problem 8 OK
