# Decorators 2 — Advanced Practice (with Solutions)

This notebook focuses on **decorator factories** (decorators with parameters), metadata preservation with `functools.wraps`,
and production-style patterns (sync/async support, optional-argument decorators, stacking, and introspection).

**How to use:** Read each problem statement, then study (and run) the solution + tests.

> Tip: In the solutions, pay attention to **where code runs**: decoration time vs call time.


In [1]:
from __future__ import annotations

import asyncio
import inspect
import random
import threading
import time
from collections import deque
from contextvars import ContextVar
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, get_args, get_origin

# Generic return type used by `once` below: its wrapper returns whatever the wrapped function returns.
T = TypeVar("T")

def assert_raises(exc_type: type[BaseException], fn: Callable[..., Any], *args: Any, **kwargs: Any) -> BaseException:
    """Call ``fn(*args, **kwargs)`` and hand back the ``exc_type`` exception it raises.

    A minimal stand-in for ``pytest.raises``: exceptions of any other type
    propagate untouched, and a call that raises nothing fails with
    ``AssertionError``.
    """
    caught: Optional[BaseException] = None
    try:
        fn(*args, **kwargs)
    except exc_type as err:
        caught = err
    if caught is None:
        # Raised outside the try block so it carries no exception context.
        raise AssertionError(f"Expected {exc_type.__name__} to be raised")
    return caught

def pretty_sig(fn: Callable[..., Any]) -> str:
    """Render ``fn`` as ``name(signature)``; handy for checking what decoration preserved."""
    name = fn.__name__
    return "{}{}".format(name, inspect.signature(fn))


## Problem 1 — Parameterized `@timed(...)` that supports **sync and async**

Create a decorator factory `timed(...)` that measures **average runtime** over `num_reps` repetitions.

Requirements:
- Use `@wraps` to preserve metadata (`__name__`, `__doc__`, and signature).
- Support both regular functions **and** `async def` functions.
- Parameters:
  - `num_reps: int = 1` (must be `>= 1`)
  - `clock: Callable[[], float] = time.perf_counter`
  - `logger: Callable[[str], Any] = print`
- Emit exactly one line per call through `logger`: `Avg run time: ...s (N reps)`.

Bonus:
- If `logger` is `None`, do not log anything.

### Solution


In [2]:
def timed(
    num_reps: int = 1,
    *,
    clock: Callable[[], float] = time.perf_counter,
    logger: Optional[Callable[[str], Any]] = print,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that runs the wrapped function ``num_reps`` times and
    reports the mean duration of one run.

    Args:
        num_reps: how many times each decorated call executes the function (int >= 1).
        clock: zero-argument callable returning a time in seconds; it is read
            immediately before and immediately after every run.
        logger: receives ``"Avg run time: <avg>s (<num_reps> reps)"`` once per
            decorated call that completes; ``None`` disables reporting.

    Returns:
        A decorator. Coroutine functions get an ``async`` wrapper that awaits
        every run; other callables get a plain wrapper. Both return the result
        of the final run and preserve metadata via ``functools.wraps``.

    Raises:
        ValueError: if ``num_reps`` is not an int >= 1 (checked when the
            factory is called, i.e. at decoration time).
    """
    bad_count = not isinstance(num_reps, int) or num_reps < 1
    if bad_count:
        raise ValueError("num_reps must be an int >= 1")

    def _report(elapsed: float) -> None:
        # Nothing to do when reporting is disabled.
        if logger is not None:
            logger(f"Avg run time: {elapsed / num_reps:.6f}s ({num_reps} reps)")

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        if not inspect.iscoroutinefunction(fn):

            @wraps(fn)
            def inner(*args: Any, **kwargs: Any) -> Any:
                elapsed = 0.0
                last: Any = None
                for _ in range(num_reps):
                    began = clock()
                    last = fn(*args, **kwargs)
                    elapsed += clock() - began
                _report(elapsed)
                return last

            return inner

        # Coroutine function: each run must be awaited inside the timed window.
        @wraps(fn)
        async def async_inner(*args: Any, **kwargs: Any) -> Any:
            elapsed = 0.0
            last: Any = None
            for _ in range(num_reps):
                began = clock()
                last = await fn(*args, **kwargs)
                elapsed += clock() - began
            _report(elapsed)
            return last

        return async_inner

    return decorator


In [3]:
# Tests / demo (sync)
@timed(3, logger=None)
def add(a: int, b: int) -> int:
    """Add two integers."""
    time.sleep(0.01)
    return a + b

assert add(1, 2) == 3
assert add.__name__ == "add"
assert "Add two integers" in (add.__doc__ or "")

# With `from __future__ import annotations` every annotation is stored as a
# string, so compare per parameter and accept either the class or its name.
sig = inspect.signature(add)
assert sig.parameters["a"].annotation in (int, "int")
assert sig.parameters["b"].annotation in (int, "int")
assert sig.return_annotation in (int, "int")

pretty_sig(add)


"add(a: 'int', b: 'int') -> 'int'"

In [4]:
# Tests / demo (async)
@timed(2, logger=None)
async def async_add(a: int, b: int) -> int:
    await asyncio.sleep(0.01)
    return a + b

assert inspect.iscoroutinefunction(async_add)
result = await async_add(10, 20)
assert result == 30
pretty_sig(async_add)


"async_add(a: 'int', b: 'int') -> 'int'"

## Problem 2 — Decorator that works as `@repeat` **or** `@repeat(...)`

Implement `repeat` so it can be used **with or without parentheses**:

```python
@repeat
def f(...):
    ...

@repeat(num_reps=3)
def g(...):
    ...
```

Requirements:
- Preserve metadata with `@wraps`.
- Execute the function `num_reps` times and return the **last** result.
- Parameter: `num_reps: int = 2` (keyword-only; it can only be changed when the decorator is used with parentheses, so bare `@repeat` always runs the function twice).

### Solution


In [5]:
def repeat(_fn: Optional[Callable[..., Any]] = None, *, num_reps: int = 2) -> Callable[..., Any]:
    """Decorator that calls the function ``num_reps`` times and returns the last result.

    Usable bare (``@repeat``, which always uses the default ``num_reps=2``) or
    with keyword arguments (``@repeat(num_reps=3)``). Metadata is preserved
    with ``functools.wraps``.

    Raises:
        ValueError: if ``num_reps`` is not an int >= 1.
        TypeError: if a non-callable is passed positionally, e.g. ``@repeat(3)``.
    """
    if not isinstance(num_reps, int) or num_reps < 1:
        raise ValueError("num_reps must be an int >= 1")
    # `@repeat(3)` would otherwise bind 3 to `_fn` and "decorate" the integer,
    # yielding a wrapper that only fails later with "'int' object is not callable".
    if _fn is not None and not callable(_fn):
        raise TypeError(
            f"repeat() got a non-callable positional argument {_fn!r}; "
            "pass the count by keyword: @repeat(num_reps=...)"
        )

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(fn)
        def inner(*args: Any, **kwargs: Any) -> Any:
            result: Any = None
            for _ in range(num_reps):
                result = fn(*args, **kwargs)
            return result
        return inner

    # Used as @repeat: Python passed the function itself as the first argument.
    if _fn is not None:
        return decorator(_fn)

    # Used as @repeat(...): return the real decorator.
    return decorator


In [6]:
# Both decorated functions append to the same list, so the counts accumulate.
calls: list[int] = []

@repeat
def bump() -> int:
    calls.append(1)
    return len(calls)

@repeat(num_reps=3)
def bump3() -> int:
    calls.append(1)
    return len(calls)

# Bare @repeat => default num_reps=2: two calls, and the last one returns 2.
assert (bump(), len(calls)) == (2, 2)
# @repeat(num_reps=3) => three more calls, for a running total of 5.
assert (bump3(), len(calls)) == (5, 5)
pretty_sig(bump3)


"bump3() -> 'int'"

## Problem 3 — A parameterized `@debug(...)` that prints **nested call traces**

Write a decorator factory `debug(prefix='DBG', show_return=True)` that logs:
- function name
- args/kwargs
- (optionally) the return value

Hard part: pretty indentation for nested/recursive calls **without a mutable global counter**.
Use `contextvars.ContextVar` so indentation works correctly across threads and async tasks.

### Solution


In [7]:
# Nesting depth of traced calls in the current context (thread / asyncio task).
_call_depth: ContextVar[int] = ContextVar("_call_depth", default=0)

def debug(
    *,
    prefix: str = "DBG",
    show_return: bool = True,
    logger: Callable[[str], Any] = print,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator factory that traces calls with indentation using ContextVar.

    Each traced call logs ``"<indent><prefix> call <name> args=... kwargs=..."``
    and, if ``show_return`` is true, ``"<indent><prefix> return <name> -> <repr>"``.
    The indent is two spaces per enclosing traced call. Depth is tracked in the
    ``_call_depth`` ContextVar rather than a plain global counter, so each
    thread / asyncio task sees its own depth. Exceptions from the wrapped
    function propagate without a "return" line. Sync and async functions are
    both supported; metadata is preserved with ``functools.wraps``.
    """

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        if inspect.iscoroutinefunction(fn):

            @wraps(fn)
            async def async_inner(*args: Any, **kwargs: Any) -> Any:
                depth = _call_depth.get()
                indent = "  " * depth
                token = _call_depth.set(depth + 1)
                # Logging happens inside try so the depth is restored even if
                # the logger itself raises.
                try:
                    logger(f"{indent}{prefix} call {fn.__name__} args={args} kwargs={kwargs}")
                    result = await fn(*args, **kwargs)
                    if show_return:
                        logger(f"{indent}{prefix} return {fn.__name__} -> {result!r}")
                    return result
                finally:
                    _call_depth.reset(token)

            return async_inner

        @wraps(fn)
        def inner(*args: Any, **kwargs: Any) -> Any:
            depth = _call_depth.get()
            indent = "  " * depth
            token = _call_depth.set(depth + 1)
            # See async_inner: keep the "call" log inside try/finally.
            try:
                logger(f"{indent}{prefix} call {fn.__name__} args={args} kwargs={kwargs}")
                result = fn(*args, **kwargs)
                if show_return:
                    logger(f"{indent}{prefix} return {fn.__name__} -> {result!r}")
                return result
            finally:
                _call_depth.reset(token)

        return inner

    return decorator


In [8]:
# Demo with recursion
log: list[str] = []
logger = log.append

@debug(prefix="TRACE", logger=logger)
def factorial(n: int) -> int:
    return 1 if n <= 1 else n * factorial(n - 1)

assert factorial(4) == 24
# One "call" and one "return" line per level (n = 4..1), indented two spaces
# per nesting level: deeper on the way in, shallower on the way out.
assert len(log) == 8
assert log[0] == "TRACE call factorial args=(4,) kwargs={}"
assert log[3] == "      TRACE call factorial args=(1,) kwargs={}"
assert log[4] == "      TRACE return factorial -> 1"
assert log[-1] == "TRACE return factorial -> 24"
log[:6]


['TRACE call factorial args=(4,) kwargs={}',
 '  TRACE call factorial args=(3,) kwargs={}',
 '    TRACE call factorial args=(2,) kwargs={}',
 '      TRACE call factorial args=(1,) kwargs={}',
 '      TRACE return factorial -> 1',
 '    TRACE return factorial -> 2']

## Problem 4 — `@retry(...)` with exponential backoff (+ async support)

Implement a decorator factory `retry(...)`.

Requirements:
- Parameters:
  - `max_attempts: int = 3` (>= 1)
  - `exceptions: tuple[type[BaseException], ...] = (Exception,)`
  - `delay: float = 0.0` (seconds)
  - `backoff: float = 1.0` (multiplier)
  - `jitter: float = 0.0` (adds random +/- jitter * current_delay)
- If it keeps failing, re-raise the **last** exception.
- Works for sync and async functions.
- Preserve metadata with `wraps`.

### Solution


In [9]:
def retry(
    *,
    max_attempts: int = 3,
    exceptions: tuple[type[BaseException], ...] = (Exception,),
    delay: float = 0.0,
    backoff: float = 1.0,
    jitter: float = 0.0,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator factory that retries a sync or async function with exponential backoff.

    Args:
        max_attempts: total number of calls to make (>= 1), including the first.
        exceptions: exception class, or iterable of exception classes, that
            trigger a retry; anything else propagates immediately.
        delay: seconds to wait before the second attempt (>= 0).
        backoff: multiplier applied to the delay after each failed attempt (> 0).
        jitter: fraction of the current delay randomly added/subtracted (>= 0).

    Raises:
        ValueError: for out-of-range numeric parameters.
        TypeError: if ``exceptions`` is not an exception class or an iterable
            of exception classes.

    The decorated function re-raises the exception from its final attempt.
    """
    if not isinstance(max_attempts, int) or max_attempts < 1:
        raise ValueError("max_attempts must be an int >= 1")
    if delay < 0:
        raise ValueError("delay must be >= 0")
    if backoff <= 0:
        raise ValueError("backoff must be > 0")
    if jitter < 0:
        raise ValueError("jitter must be >= 0")

    # Normalize to a tuple of exception classes now. An invalid value (e.g. a
    # list) would otherwise only blow up inside `except`, replacing the real
    # error with "catching classes that do not inherit from BaseException".
    if isinstance(exceptions, type):
        exc_types: tuple[type[BaseException], ...] = (exceptions,)
    else:
        try:
            exc_types = tuple(exceptions)
        except TypeError:
            raise TypeError(
                "exceptions must be an exception class or an iterable of exception classes"
            ) from None
    if not all(isinstance(e, type) and issubclass(e, BaseException) for e in exc_types):
        raise TypeError("exceptions must contain only exception classes")

    def next_delay(current: float) -> float:
        # Apply +/- jitter to the current delay, never going below zero.
        if current <= 0:
            return 0.0
        if jitter:
            current += random.uniform(-jitter * current, jitter * current)
            current = max(0.0, current)
        return current

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        if inspect.iscoroutinefunction(fn):

            @wraps(fn)
            async def async_inner(*args: Any, **kwargs: Any) -> Any:
                current_delay = delay
                # Every attempt but the last swallows matching errors and backs off.
                for _ in range(max_attempts - 1):
                    try:
                        return await fn(*args, **kwargs)
                    except exc_types:
                        pass
                    if current_delay:
                        await asyncio.sleep(next_delay(current_delay))
                    current_delay *= backoff
                # Final attempt: whatever it raises propagates to the caller.
                return await fn(*args, **kwargs)

            return async_inner

        @wraps(fn)
        def inner(*args: Any, **kwargs: Any) -> Any:
            current_delay = delay
            # Every attempt but the last swallows matching errors and backs off.
            for _ in range(max_attempts - 1):
                try:
                    return fn(*args, **kwargs)
                except exc_types:
                    pass
                if current_delay:
                    time.sleep(next_delay(current_delay))
                current_delay *= backoff
            # Final attempt: whatever it raises propagates to the caller.
            return fn(*args, **kwargs)

        return inner

    return decorator


In [10]:
# Sync test: the first two calls fail, the third of four allowed attempts succeeds.
attempts = {"n": 0}

@retry(max_attempts=4, delay=0.0, backoff=1.0)
def flaky() -> str:
    attempts["n"] += 1
    if attempts["n"] >= 3:
        return "ok"
    raise ValueError("nope")

assert (flaky(), attempts["n"]) == ("ok", 3)

# Permanent failure: once max_attempts is exhausted the last KeyError escapes.
@retry(max_attempts=2, exceptions=(KeyError,), delay=0.0)
def always_keyerror() -> None:
    raise KeyError("bad")

ex = assert_raises(KeyError, always_keyerror)
assert "bad" in str(ex)


In [11]:
# Async test
attempts_async = {"n": 0}

@retry(max_attempts=3, delay=0.0)
async def async_flaky() -> int:
    attempts_async["n"] += 1
    if attempts_async["n"] < 2:
        raise RuntimeError("try again")
    return 42

assert await async_flaky() == 42
assert attempts_async["n"] == 2


## Problem 5 — Runtime type validation from annotations (advanced)

Implement `validated(...)` that enforces type hints **at runtime**.

Requirements:
- Only validate parameters/return values that have annotations.
- Use `inspect.signature(fn).bind_partial(...)` to map args to parameter names.
- Support at least:
  - `int`, `str`, etc.
  - `Optional[T]` / `Union[...]`
  - `list[T]`, `tuple[T, ...]`, `dict[K, V]`
- Options:
  - `check_return: bool = True`
  - `strict: bool = False` (if True, require an exact type match and reject subclasses; if False, accept subclasses via `isinstance`, so e.g. `True` passes as an `int`)

Note: This is not meant to be a full type-checker; it's a practical subset.

### Solution


In [12]:
# Imports for the runtime-validation solution. `typing`, `types` and
# `collections.abc` are new here (used by `_is_instance` / `validated`);
# the remaining names duplicate the first cell's imports.
import typing
import types
import collections.abc
from functools import wraps
import inspect
from typing import Any, Callable, TypeVar, get_args, get_origin

T = TypeVar("T")  # redefines the first cell's `T` (presumably so this cell can run standalone)


# Origins that mean "union": typing.Union / Optional, plus `X | Y`
# (types.UnionType only exists on Python 3.10+).
_UNION_ORIGINS: tuple[Any, ...] = (typing.Union,) + (
    (types.UnionType,) if hasattr(types, "UnionType") else ()
)


def _is_instance(value: Any, ann: Any, *, strict: bool) -> bool:
    """Best-effort runtime check that ``value`` conforms to annotation ``ann``.

    Understands ``Any``, ``Annotated[T, ...]`` (checked as ``T``),
    ``Optional``/``Union``/``X | Y``, ``list[T]``, ``set[T]``,
    ``tuple[T, ...]``, fixed-length ``tuple[A, B]``, ``dict[K, V]``,
    ``Callable`` and plain classes. Anything else (TypeVar, Literal,
    unresolved string annotations, ...) is accepted rather than rejected.

    With ``strict=True`` plain classes require ``type(value) is cls``;
    container checks always use ``isinstance``.
    """
    if ann is Any or ann is inspect.Parameter.empty:
        return True

    origin = get_origin(ann)
    args = get_args(ann)

    # Annotated[T, meta...] only attaches metadata; validate against T.
    # validated() resolves hints with include_extras=True, so these reach us.
    if origin is typing.Annotated:
        return _is_instance(value, args[0], strict=strict)

    # Union / Optional / X | Y
    if origin in _UNION_ORIGINS:
        return any(_is_instance(value, t, strict=strict) for t in args)

    if origin is list:
        if not isinstance(value, list):
            return False
        (elem_t,) = args or (Any,)
        return all(_is_instance(v, elem_t, strict=strict) for v in value)

    if origin is tuple:
        if not isinstance(value, tuple):
            return False
        if ann is typing.Tuple:
            # Unparameterized typing.Tuple: any tuple is acceptable.
            return True
        if args == ((),):
            # Pre-3.11 spelling of the empty tuple type `tuple[()]`.
            args = ()
        if len(args) == 2 and args[1] is Ellipsis:
            return all(_is_instance(v, args[0], strict=strict) for v in value)
        if len(args) != len(value):
            return False
        return all(_is_instance(v, t, strict=strict) for v, t in zip(value, args))

    if origin is dict:
        if not isinstance(value, dict):
            return False
        key_t, val_t = args or (Any, Any)
        return all(
            _is_instance(k, key_t, strict=strict) and _is_instance(v, val_t, strict=strict)
            for k, v in value.items()
        )

    if origin is set:
        if not isinstance(value, set):
            return False
        (elem_t,) = args or (Any,)
        return all(_is_instance(v, elem_t, strict=strict) for v in value)

    if origin in (collections.abc.Callable, Callable):
        return callable(value)

    # Plain classes (int, str, custom types)
    if origin is None:
        if isinstance(ann, type):
            return type(value) is ann if strict else isinstance(value, ann)
        return True

    # Fallback: check the origin as a normal class (e.g. frozenset[int], Sequence[str]).
    if isinstance(origin, type):
        return type(value) is origin if strict else isinstance(value, origin)

    return True


def validated(*, check_return: bool = True, strict: bool = False) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator factory enforcing a pragmatic subset of typing annotations at runtime.

    Uses typing.get_type_hints so it works with `from __future__ import annotations`.
    Only annotated parameters (and, if ``check_return``, an annotated return)
    are checked. Annotated ``*args``/``**kwargs`` are checked per element, since
    the annotation describes each extra argument rather than the collecting
    tuple/dict. Coroutine functions get an async wrapper whose return check
    applies to the awaited result.

    The wrapper raises TypeError when an argument or the return value does
    not match its annotation (see ``_is_instance`` for what is understood).
    """
    empty = inspect.Parameter.empty

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        sig = inspect.signature(fn)
        # Parameter kinds, so variadic parameters can be validated element-wise.
        kinds = {name: param.kind for name, param in sig.parameters.items()}

        # Resolve annotations (handles string annotations from __future__ import annotations)
        try:
            hints = typing.get_type_hints(fn, globalns=fn.__globals__, localns=None, include_extras=True)
        except Exception:
            # Best effort: fall back to raw annotations. Unresolved strings are
            # not classes, so _is_instance accepts them unchecked.
            hints = dict(getattr(fn, "__annotations__", {}))

        def check_arguments(args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
            bound = sig.bind_partial(*args, **kwargs)
            bound.apply_defaults()
            for name, value in bound.arguments.items():
                ann = hints.get(name, empty)
                if ann is empty:
                    continue
                kind = kinds.get(name)
                if kind == inspect.Parameter.VAR_POSITIONAL:
                    items = tuple(value)
                elif kind == inspect.Parameter.VAR_KEYWORD:
                    items = tuple(value.values())
                else:
                    items = (value,)
                for item in items:
                    if not _is_instance(item, ann, strict=strict):
                        raise TypeError(
                            f"Argument '{name}' must be {ann!r}, got {type(item).__name__}: {item!r}"
                        )

        def check_result(result: Any) -> None:
            if not check_return:
                return
            ret_ann = hints.get("return", empty)
            if ret_ann is not empty and not _is_instance(result, ret_ann, strict=strict):
                raise TypeError(
                    f"Return value must be {ret_ann!r}, got {type(result).__name__}: {result!r}"
                )

        if inspect.iscoroutinefunction(fn):

            @wraps(fn)
            async def async_inner(*args: Any, **kwargs: Any) -> Any:
                check_arguments(args, kwargs)
                result = await fn(*args, **kwargs)
                check_result(result)
                return result

            return async_inner

        @wraps(fn)
        def inner(*args: Any, **kwargs: Any) -> Any:
            check_arguments(args, kwargs)
            result = fn(*args, **kwargs)
            check_result(result)
            return result

        return inner

    return decorator


In [13]:
# Element types inside containers are checked, not just the container type.
@validated()
def join_ints(xs: list[int], sep: str = ",") -> str:
    return sep.join(map(str, xs))

assert join_ints([1, 2, 3]) == "1,2,3"
assert_raises(TypeError, join_ints, ["1", "2"])     # list[str] is not list[int]
assert_raises(TypeError, join_ints, [1, 2], sep=5)  # keyword args are checked too

# Return values are validated after the call completes.
@validated(check_return=True)
def bad_return(xs: list[int]) -> int:
    return "nope"  # type: ignore[return-value]

assert_raises(TypeError, bad_return, [1, 2, 3])


TypeError("Return value must be <class 'int'>, got str: 'nope'")

## Problem 6 — `@once(...)`: execute exactly once (thread-safe)

Implement a decorator factory `once(cache_exceptions=False)`.

Behavior:
- The first call runs the function and caches either:
  - the result, or
  - the exception (only if `cache_exceptions=True`).
- Subsequent calls:
  - return the cached result, or
  - re-raise the cached exception.

Requirements:
- Thread-safe.
- Preserve metadata.

### Solution


In [14]:
def once(*, cache_exceptions: bool = False) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Decorator factory that executes the wrapped function at most once (thread-safe).

    The first call runs ``fn`` and caches its result; later calls return that
    result without calling ``fn`` again (their arguments are ignored). If the
    first call raises:
      * ``cache_exceptions=False``: nothing is cached and the next call retries;
      * ``cache_exceptions=True``: the exception is cached and re-raised by
        every later call.
    A lock serializes the first execution; once settled, reads skip the lock.
    """

    def decorator(fn: Callable[..., T]) -> Callable[..., T]:
        lock = threading.Lock()
        done = False  # True once a result (or a cached exception) is stored
        cached_value: Any = None
        cached_exc: Optional[BaseException] = None
        # Traceback captured when the exception was first caught. Raising the
        # same exception object again appends frames to its __traceback__ each
        # time, so it is restored on every replay to keep it from growing.
        cached_tb: Any = None

        @wraps(fn)
        def inner(*args: Any, **kwargs: Any) -> T:
            nonlocal done, cached_value, cached_exc, cached_tb
            # Fast path: the outcome is already settled.
            if done:
                if cached_exc is not None:
                    raise cached_exc.with_traceback(cached_tb)
                return cached_value  # type: ignore[return-value]

            with lock:
                # Re-check: another thread may have finished while we waited.
                if done:
                    if cached_exc is not None:
                        raise cached_exc.with_traceback(cached_tb)
                    return cached_value  # type: ignore[return-value]

                try:
                    cached_value = fn(*args, **kwargs)
                except BaseException as ex:
                    if cache_exceptions:
                        cached_exc = ex
                        cached_tb = ex.__traceback__
                        done = True
                    raise
                done = True
                return cached_value

        return inner

    return decorator


In [15]:
# Plain @once(): the body runs on the first call only; later calls reuse the value.
counter = {"n": 0}

@once()
def init_value() -> int:
    counter["n"] += 1
    return 123

assert init_value() == init_value() == 123
assert counter["n"] == 1

# @once(cache_exceptions=True): the first failure is cached and replayed.
counter2 = {"n": 0}

@once(cache_exceptions=True)
def init_fail() -> int:
    counter2["n"] += 1
    raise RuntimeError("boom")

assert_raises(RuntimeError, init_fail)
assert_raises(RuntimeError, init_fail)
assert counter2["n"] == 1


## Problem 7 — Class-based decorator `@RateLimit(...)` (stateful)

Implement a class-based decorator `RateLimit(calls, period, mode='raise')`.

Semantics:
- Allow at most `calls` executions per `period` seconds.
- Track timestamps (monotonic clock) in a deque.
- `mode='raise'`: raise `RuntimeError` when limit exceeded.
- `mode='sleep'`: sleep until allowed, then proceed.

Requirements:
- Works for both sync and async functions.
- Preserve metadata with `wraps`.

### Solution


In [16]:
@dataclass(frozen=True)
class RateLimit:
    """Class-based decorator: at most ``calls`` executions per ``period`` seconds.

    Each function decorated by an instance keeps its own sliding window of
    ``time.monotonic()`` timestamps, one per admitted call, recorded when the
    call is admitted. When the window is full:
      * ``mode='raise'`` raises ``RuntimeError("rate limit exceeded")``;
      * ``mode='sleep'`` sleeps until the oldest timestamp expires, then tries again.
    Coroutine functions get an async wrapper (``asyncio.sleep``); everything
    else gets a sync wrapper (``time.sleep``).
    """

    calls: int
    period: float
    mode: str = "raise"  # 'raise' or 'sleep'

    def __post_init__(self) -> None:
        # Validate eagerly so a bad configuration fails at decoration time.
        if self.calls < 1:
            raise ValueError("calls must be >= 1")
        if self.period <= 0:
            raise ValueError("period must be > 0")
        if self.mode not in {"raise", "sleep"}:
            raise ValueError("mode must be 'raise' or 'sleep'")

    def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        history: deque[float] = deque()  # admission times, oldest on the left
        mutex = threading.Lock()
        now = time.monotonic

        def admit() -> float:
            """Try to admit one call: return 0.0 if admitted, else seconds to wait."""
            with mutex:
                stamp = now()
                horizon = stamp - self.period
                # Drop timestamps that have slid out of the window.
                while history and history[0] <= horizon:
                    history.popleft()
                if len(history) < self.calls:
                    pause = 0.0
                else:
                    pause = max(0.0, (history[0] + self.period) - stamp)
                if pause == 0.0:
                    history.append(stamp)
                return pause

        if inspect.iscoroutinefunction(fn):

            @wraps(fn)
            async def async_inner(*args: Any, **kwargs: Any) -> Any:
                pause = admit()
                while pause != 0.0:
                    if self.mode == "raise":
                        raise RuntimeError("rate limit exceeded")
                    await asyncio.sleep(pause)
                    pause = admit()
                return await fn(*args, **kwargs)

            return async_inner

        @wraps(fn)
        def inner(*args: Any, **kwargs: Any) -> Any:
            pause = admit()
            while pause != 0.0:
                if self.mode == "raise":
                    raise RuntimeError("rate limit exceeded")
                time.sleep(pause)
                pause = admit()
            return fn(*args, **kwargs)

        return inner


In [17]:
# Sync, mode="raise": two calls fit in the 0.5 s window, the third is rejected.
@RateLimit(calls=2, period=0.5, mode="raise")
def ping() -> str:
    return "pong"

assert [ping(), ping()] == ["pong", "pong"]
assert_raises(RuntimeError, ping)

# Sync, mode="sleep": the third call waits for a free slot instead of raising.
@RateLimit(calls=2, period=0.3, mode="sleep")
def ping_sleep() -> str:
    return "pong"

assert [ping_sleep() for _ in range(3)] == ["pong"] * 3  # may sleep a bit


In [18]:
# Demo (async)
@RateLimit(calls=1, period=0.2, mode="sleep")
async def aping() -> str:
    return "pong"

assert await aping() == "pong"
assert await aping() == "pong"  # must sleep


## Problem 8 — Stacking decorators and verifying metadata

Stacking multiple decorators is common, but order matters.

Task:
- Create a function decorated with both `@timed(...)` and `@retry(...)`.
- Verify:
  - the signature and name are preserved
  - the retry behavior happens (simulate failure)

Then swap the order of decorators and observe the effect on timing/logging.

### Solution + demo


In [19]:
# Stack order decides who wraps whom. Here retry is outermost, so each retry
# re-enters timed, and timed only logs when its call returns normally.

log_lines: list[str] = []

def list_logger(msg: str) -> None:
    log_lines.append(msg)

attempt = {"n": 0}

@retry(max_attempts=3, delay=0.0)
@timed(1, logger=list_logger)   # 1 rep: extra reps would re-run the body and advance the attempt counter
def sometimes_fails(x: int) -> int:
    attempt["n"] += 1
    if attempt["n"] < 2:
        raise ValueError("fail once")
    return x * 2

assert sometimes_fails(21) == 42
assert attempt["n"] == 2

# Metadata survives both wrappers. Annotations may be strings because of
# `from __future__ import annotations`, so accept either form.
assert sometimes_fails.__name__ == "sometimes_fails"
sig = inspect.signature(sometimes_fails)
assert sig.parameters["x"].annotation in (int, "int")
assert sig.return_annotation in (int, "int")

# The failed first attempt raised out of timed before it could log, so only
# the successful attempt produced a timing line.
assert len(log_lines) == 1 and log_lines[0].startswith("Avg run time: ")
log_lines[:3], pretty_sig(sometimes_fails)


(['Avg run time: 0.000003s (1 reps)'], "sometimes_fails(x: 'int') -> 'int'")

In [20]:
# Swap order to observe different behavior: timed outermost measures *including retries*.

log_lines2: list[str] = []
attempt2 = {"n": 0}

@timed(1, logger=log_lines2.append)
@retry(max_attempts=3, delay=0.0)
def sometimes_fails_2(x: int) -> int:
    attempt2["n"] += 1
    if attempt2["n"] < 3:
        raise ValueError("fail twice")
    return x + 1

assert sometimes_fails_2(10) == 11
assert attempt2["n"] == 3

# Here the timing wrapper sees the retries because it is outside retry:
# one timing line covers all three executions of the body.
assert len(log_lines2) == 1
log_lines2, pretty_sig(sometimes_fails_2)


(['Avg run time: 0.000010s (1 reps)'], "sometimes_fails_2(x: 'int') -> 'int'")