# Decorators (Part 1) — Advanced Problems (with Solutions)

This notebook contains advanced exercises on Python decorators: closures, `functools.wraps`, decorator factories, stacking, async support, caching, type enforcement, and method/descriptor behavior.

**How to use**: read each exercise, try implementing it yourself (optional), then compare with the included solution and run the tests at the end.


In [1]:
from __future__ import annotations

import asyncio
import contextlib
import functools
import inspect
import threading
import time
import types
import typing
import weakref
from collections import OrderedDict
from weakref import WeakKeyDictionary


## Exercise 1 — A robust call-counter decorator (optional args, thread-safe, resettable)

Write a decorator `call_counter` that can be used **with or without arguments**:

- `@call_counter`
- `@call_counter(thread_safe=True)`

Requirements:

1. Preserves metadata (`__name__`, `__doc__`, and signature introspection) using `functools.wraps`.
2. Counts calls in an internal counter.
3. Exposes:
   - `wrapper.call_count` (an `int` attribute)
   - `wrapper.reset()` (sets count to zero)
4. If `thread_safe=True`, increments are protected by a lock.

Starter sketch:

```python
def call_counter(fn=None, *, thread_safe: bool = False):
    ...
```


In [2]:
# --- Solution (Exercise 1) ---

def call_counter(fn=None, *, thread_safe: bool = False):
    """Count how many times a function is called.

    Can be used as:
        @call_counter
        @call_counter(thread_safe=True)

    Adds to the returned wrapper:
        - call_count (int attribute, updated on every call before the wrapped
          function runs, so calls that raise are counted too)
        - reset() method, which sets call_count back to zero

    Args:
        fn: The function being decorated (bare ``@call_counter`` form), or
            ``None`` when the decorator is called with keyword arguments.
        thread_safe: When True, counter updates are serialized with a
            ``threading.Lock``; otherwise no lock is taken.

    Raises:
        TypeError: If ``fn`` is given but is not callable. This catches the easy
            mistake ``@call_counter(True)``, which would otherwise silently
            ignore the intended ``thread_safe=True``.
    """
    if fn is not None and not callable(fn):
        raise TypeError(
            "call_counter() takes a callable or keyword arguments only; "
            "did you mean call_counter(thread_safe=...)?"
        )

    def decorate(target):
        count = 0
        # nullcontext() is a reusable no-op context manager, so the locked and
        # unlocked modes share a single code path.
        guard = threading.Lock() if thread_safe else contextlib.nullcontext()

        @functools.wraps(target)
        def wrapper(*args, **kwargs):
            nonlocal count
            with guard:
                count += 1
                wrapper.call_count = count  # type: ignore[attr-defined]
            return target(*args, **kwargs)

        def reset():
            nonlocal count
            with guard:
                count = 0
                wrapper.call_count = 0  # type: ignore[attr-defined]

        wrapper.call_count = 0  # type: ignore[attr-defined]
        wrapper.reset = reset    # type: ignore[attr-defined]
        return wrapper

    # Support @call_counter (fn given) and @call_counter(...) (fn is None).
    return decorate(fn) if fn is not None else decorate


## Exercise 2 — A `trace` decorator usable with or without arguments

Implement `trace` that prints a one-line trace on every call.

Usage:
- `@trace`
- `@trace(prefix='DBG', max_len=60, show_return=False)`

Requirements:
- Use `functools.wraps`.
- Format: `PREFIX func_name(arg1_repr, kw=value_repr, ...)`
- Truncate each `repr(...)` to `max_len` characters; when something is cut off, the last kept character is replaced by `…`.
- If `show_return=True`, also print `-> return_repr`.

Tip: build a small helper `short_repr(obj, max_len)`.


In [3]:
# --- Solution (Exercise 2) ---

def _short_repr(obj, max_len: int) -> str:
    s = repr(obj)
    if len(s) <= max_len:
        return s
    return s[: max(0, max_len - 1)] + "…"

def trace(fn=None, *, prefix: str = "", max_len: int = 80, show_return: bool = True):
    """Print a one-line trace per call (supports optional arguments)."""
    def decorate(target):
        @functools.wraps(target)
        def wrapper(*args, **kwargs):
            parts = []
            parts.extend(_short_repr(a, max_len) for a in args)
            parts.extend(f"{k}={_short_repr(v, max_len)}" for k, v in kwargs.items())
            pre = (prefix + " ") if prefix else ""
            call_s = f"{pre}{target.__name__}({', '.join(parts)})"
            if show_return:
                result = target(*args, **kwargs)
                print(f"{call_s} -> {_short_repr(result, max_len)}")
                return result
            else:
                print(call_s)
                return target(*args, **kwargs)
        return wrapper

    if callable(fn):
        return decorate(fn)
    return decorate


## Exercise 3 — Runtime type enforcement from annotations

Write `@enforce_types` that checks argument and return types using the function's type annotations.

Requirements:

1. Use `typing.get_type_hints` + `inspect.signature` to bind arguments.
2. For each annotated parameter, raise `TypeError` if the provided value does not match.
3. If a return annotation exists, validate the returned value too.
4. Support a **practical subset** of typing constructs:
   - `Any` (skip check)
   - `Union` / `Optional`
   - `list[T]`, `set[T]`, `tuple[T, ...]`, `dict[K, V]`
   - plain classes like `int`, `str`, custom classes

Notes:
- This is a teaching exercise, not a full type checker.
- Use recursion for container element checks.


In [4]:
# --- Solution (Exercise 3) ---

_Any = typing.Any

def _is_any(t) -> bool:
    return t is _Any

def _origin(t):
    return getattr(t, "__origin__", None)

def _args(t):
    return getattr(t, "__args__", ())

def _type_name(t) -> str:
    try:
        return t.__name__
    except Exception:
        return str(t)

def _check_type(value, expected) -> bool:
    """Return True if value conforms to expected (best-effort subset)."""
    if _is_any(expected):
        return True

    # Forward refs are resolved by get_type_hints already.
    origin = _origin(expected)
    args = _args(expected)

    # Union / Optional
    if origin is typing.Union:
        return any(_check_type(value, t) for t in args)

    # NoneType
    if expected is type(None):  # noqa: E721
        return value is None

    # Containers
    if origin in (list, typing.List):
        if not isinstance(value, list):
            return False
        (elem_t,) = args if args else (_Any,)
        return all(_check_type(v, elem_t) for v in value)

    if origin in (set, typing.Set):
        if not isinstance(value, set):
            return False
        (elem_t,) = args if args else (_Any,)
        return all(_check_type(v, elem_t) for v in value)

    if origin in (dict, typing.Dict):
        if not isinstance(value, dict):
            return False
        key_t, val_t = args if len(args) == 2 else (_Any, _Any)
        return all(_check_type(k, key_t) and _check_type(v, val_t) for k, v in value.items())

    if origin in (tuple, typing.Tuple):
        if not isinstance(value, tuple):
            return False
        if not args:
            return True
        # tuple[T, ...] (homogeneous)
        if len(args) == 2 and args[1] is Ellipsis:
            return all(_check_type(v, args[0]) for v in value)
        # tuple[T1, T2, ...] (fixed length)
        if len(value) != len(args):
            return False
        return all(_check_type(v, t) for v, t in zip(value, args))

    # Fallback: plain class check
    try:
        return isinstance(value, expected)
    except TypeError:
        # Some typing objects are not valid in isinstance
        return True  # best-effort: don't block execution

def enforce_types(fn):
    """Enforce runtime types based on annotations (best-effort subset).

    Bound arguments (after defaults are applied) and, if annotated, the return
    value are validated with ``_check_type``. A mismatch raises ``TypeError``
    naming the offending argument or the return value.

    Details:
      * ``*args: T`` / ``**kwargs: T`` annotations are checked against each
        element/value, which is their static meaning, not against the collected
        tuple/dict.
      * ``async def`` functions get an async wrapper, and the awaited result
        (not the coroutine object) is checked against the return annotation.
      * Type hints are resolved on the first call rather than at decoration
        time. Under ``from __future__ import annotations`` this lets a method
        annotate parameters with its own, not-yet-defined class. An
        unresolvable name therefore raises ``NameError`` at call time.
    """
    sig = inspect.signature(fn)
    params = sig.parameters
    resolved = None  # cached result of typing.get_type_hints(fn)

    def hints():
        nonlocal resolved
        if resolved is None:
            resolved = typing.get_type_hints(fn)
        return resolved

    def check_arguments(args, kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        type_hints = hints()
        for name, value in bound.arguments.items():
            if name not in type_hints:
                continue
            expected = type_hints[name]
            kind = params[name].kind
            # Variadic annotations describe each item, not the container.
            if kind is inspect.Parameter.VAR_POSITIONAL:
                candidates = value
            elif kind is inspect.Parameter.VAR_KEYWORD:
                candidates = value.values()
            else:
                candidates = (value,)
            for item in candidates:
                if not _check_type(item, expected):
                    raise TypeError(
                        f"Argument '{name}' expected {_type_name(expected)}, got {type(item).__name__}"
                    )

    def check_return(result):
        type_hints = hints()
        if "return" in type_hints:
            expected_ret = type_hints["return"]
            if not _check_type(result, expected_ret):
                raise TypeError(
                    f"Return value expected {_type_name(expected_ret)}, got {type(result).__name__}"
                )

    if inspect.iscoroutinefunction(fn):
        @functools.wraps(fn)
        async def awrapper(*args, **kwargs):
            check_arguments(args, kwargs)
            result = await fn(*args, **kwargs)
            check_return(result)
            return result

        return awrapper

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        check_arguments(args, kwargs)
        result = fn(*args, **kwargs)
        check_return(result)
        return result

    return wrapper


## Exercise 4 — A retry decorator that works for sync *and* async functions

Implement `retry` as a **decorator factory**:

```python
@retry(tries=5, exceptions=(ValueError,), delay=0.01, backoff=2.0)
def flaky(...):
    ...
```

Requirements:
- Retries on the specified `exceptions` up to `tries` total attempts.
- Waits `delay` seconds before the 2nd attempt, then multiplies the delay by `backoff` before each later attempt (no wait at all when `delay` is 0).
- Works for both normal and `async def` functions.
- Accepts an optional `sleep` argument for testing (sync sleep or async sleep).
- Uses `functools.wraps`.


In [5]:
# --- Solution (Exercise 4) ---

def retry(*, tries: int = 3,
          exceptions: tuple[type[BaseException], ...] = (Exception,),
          delay: float = 0.0,
          backoff: float = 1.0,
          sleep=None):
    """Decorator factory: retry a sync or async function on selected exceptions.

    Args:
        tries: Total number of attempts (>= 1), including the first call.
        exceptions: Exception types that trigger a retry. Anything else
            propagates immediately.
        delay: Seconds to wait before the 2nd attempt (>= 0). No wait when 0.
        backoff: Multiplier (>= 1.0) applied to the delay after each retry.
        sleep: Optional replacement for the sleep function (useful in tests).
            It defaults to ``time.sleep`` for sync functions and
            ``asyncio.sleep`` for ``async def`` functions.
            For async functions it may be a plain callable or a coroutine
            function; its result is awaited only if it is awaitable.
            Sync functions need a non-async ``sleep``.

    If every attempt fails, the exception from the final attempt propagates.

    Raises:
        ValueError: On invalid ``tries``, ``backoff`` or ``delay``.
        TypeError: At decoration time, if a coroutine-function ``sleep`` is
            given for a sync function. It could never be awaited, so no wait
            would actually happen.
    """
    if tries < 1:
        raise ValueError("tries must be >= 1")
    if backoff < 1.0:
        raise ValueError("backoff must be >= 1.0")
    if delay < 0:
        raise ValueError("delay must be >= 0")

    def decorate(fn):
        if inspect.iscoroutinefunction(fn):
            async_sleep = sleep if sleep is not None else asyncio.sleep

            @functools.wraps(fn)
            async def awrapper(*args, **kwargs):
                d = delay
                for _ in range(tries - 1):
                    try:
                        return await fn(*args, **kwargs)
                    except exceptions:
                        pass  # retryable failure: back off, then try again
                    if d > 0:
                        pause = async_sleep(d)
                        # Support both coroutine sleeps and plain callables.
                        if inspect.isawaitable(pause):
                            await pause
                    d *= backoff
                # Final attempt: any exception propagates to the caller.
                return await fn(*args, **kwargs)

            return awrapper

        sync_sleep = sleep if sleep is not None else time.sleep
        if inspect.iscoroutinefunction(sync_sleep):
            raise TypeError("an async sleep function cannot be used to retry a sync function")

        @functools.wraps(fn)
        def swrapper(*args, **kwargs):
            d = delay
            for _ in range(tries - 1):
                try:
                    return fn(*args, **kwargs)
                except exceptions:
                    pass  # retryable failure: back off, then try again
                if d > 0:
                    sync_sleep(d)
                d *= backoff
            # Final attempt: any exception propagates to the caller.
            return fn(*args, **kwargs)

        return swrapper

    return decorate


## Exercise 5 — A TTL + LRU cache decorator

Implement `ttl_cache` as a decorator factory:

```python
@ttl_cache(ttl_seconds=2.0, maxsize=128)
def expensive(x): ...
```

Requirements:
- Cache results keyed by `(args, kwargs)` (or a user-provided `key` function).
- Evict least-recently-used items beyond `maxsize`.
- Entries expire after `ttl_seconds` and are recomputed on next access.
- Provide:
  - `wrapper.cache_clear()`
  - `wrapper.cache_info()` returning a dict with hits/misses/size/maxsize

Best practice tip: keep the cache internal, but expose simple observability hooks.


In [6]:
# --- Solution (Exercise 5) ---

def _default_cache_key(args, kwargs):
    """Create a hashable cache key from args/kwargs or raise TypeError."""
    try:
        return (args, tuple(sorted(kwargs.items())))
    except TypeError as e:
        raise TypeError("Unhashable arguments; provide a custom key=... function") from e

def ttl_cache(*, ttl_seconds: float, maxsize: int = 128, key=None, clock=None):
    """Decorator factory: memoize results with per-entry expiry and LRU eviction.

    Args:
        ttl_seconds: Lifetime (> 0) of a cached entry, measured from the moment
            the computed value is stored.
        maxsize: Maximum number of entries (>= 1). The least recently used
            entry is evicted first.
        key: Optional ``key(args, kwargs) -> hashable`` used instead of
            ``_default_cache_key``. It is needed for unhashable arguments.
        clock: Optional zero-argument callable returning seconds as a float.
            It defaults to ``time.monotonic``, which, unlike ``time.time``, is
            not affected by system clock adjustments. It can be injected for
            deterministic tests.

    The wrapper exposes:
      * ``cache_clear()``, which drops all entries and resets the hit/miss counters.
      * ``cache_info()``, which returns a dict with hits, misses, size (live
        entries), maxsize and ttl_seconds.

    Exceptions raised by the function are not cached. The function runs outside
    the lock, so two threads missing on the same key may both compute it; the
    last value stored wins.
    """
    if ttl_seconds <= 0:
        raise ValueError("ttl_seconds must be > 0")
    if maxsize < 1:
        raise ValueError("maxsize must be >= 1")
    now_fn = clock if clock is not None else time.monotonic

    def decorate(fn):
        make_key = key if key is not None else _default_cache_key
        cache = OrderedDict()  # key -> (expires_at, value), least recently used first
        hits = 0
        misses = 0
        lock = threading.Lock()

        def _purge_expired(now: float):
            # Entries are ordered by recency of use, not by expiry, so scan them all.
            expired_keys = [k for k, (exp, _) in cache.items() if exp <= now]
            for k in expired_keys:
                del cache[k]

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal hits, misses
            k = make_key(args, kwargs)
            with lock:
                _purge_expired(now_fn())
                if k in cache:
                    cache.move_to_end(k)  # mark as most recently used
                    hits += 1
                    return cache[k][1]
                misses += 1

            # Compute outside the lock so a slow call doesn't block other threads.
            val = fn(*args, **kwargs)

            with lock:
                # Expiry counts from when the value became available, so a call
                # slower than ttl_seconds does not store an already-expired entry.
                cache[k] = (now_fn() + ttl_seconds, val)
                # Assigning to an existing key keeps its old position; refresh it.
                cache.move_to_end(k)
                while len(cache) > maxsize:
                    cache.popitem(last=False)  # evict LRU
            return val

        def cache_clear():
            nonlocal hits, misses
            with lock:
                cache.clear()
                hits = 0
                misses = 0

        def cache_info():
            with lock:
                _purge_expired(now_fn())  # report live entries only
                return {
                    "hits": hits,
                    "misses": misses,
                    "size": len(cache),
                    "maxsize": maxsize,
                    "ttl_seconds": ttl_seconds,
                }

        wrapper.cache_clear = cache_clear  # type: ignore[attr-defined]
        wrapper.cache_info = cache_info    # type: ignore[attr-defined]
        return wrapper

    return decorate


## Exercise 6 — Translate exceptions (sync + async)

Create `translate_exceptions(mapping, default=None)` returning a decorator that transforms raised exceptions.

- `mapping` is a dict like `{KeyError: ValueError, ZeroDivisionError: lambda e: RuntimeError('bad math')}`
- If an exception is caught and transformed, chain the original exception using `raise new_exc from e`.
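- If `default` is provided, use it to transform any exception whose type does not match a `mapping` key. Like the mapping values, `default` may be an exception class or a callable `default(e)` that returns an exception instance.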
- Must work for both sync and async functions.
- Preserve metadata with `wraps`.


In [7]:
# --- Solution (Exercise 6) ---

def translate_exceptions(mapping: dict, *, default=None):
    """Decorator factory: re-raise selected exceptions as other exception types.

    Args:
        mapping: ``{source_exception_type: spec}``. Entries are tried in
            insertion order and the first ``isinstance`` match wins. A ``spec``
            is one of:
              * an exception class, instantiated with ``str(original)``;
              * a callable ``spec(original) -> exception instance``;
              * ``None``, which re-raises the original unchanged.
        default: Optional spec applied to any exception not matched by
            ``mapping``.

    The replacement is raised with ``raise new from original``, so the original
    is kept as ``__cause__``. Only ``Exception`` subclasses are intercepted.
    Works for both sync and ``async def`` functions.

    Raises:
        TypeError: Checked when the factory is called. Raised if ``mapping`` is
            not a dict, if a key is not usable with ``isinstance`` (a type or
            tuple of types), or if a value or ``default`` is neither ``None``
            nor callable.
    """
    if not isinstance(mapping, dict):
        raise TypeError("mapping must be a dict")

    def _check_spec(spec):
        # Exception classes are callable, so this accepts every valid form.
        if spec is not None and not callable(spec):
            raise TypeError("mapping/default values must be exception classes or callables")

    # Validate eagerly: a bad mapping should fail at decoration time, not mask
    # the real exception later when a translation is attempted.
    for src_type, spec in mapping.items():
        try:
            isinstance(None, src_type)
        except TypeError:
            raise TypeError(
                f"mapping keys must be exception types (or tuples of them), got {src_type!r}"
            ) from None
        _check_spec(spec)
    _check_spec(default)

    def _make_new(exc, spec):
        # spec can be None (keep original), an exception class or a callable(exc)->exception
        if spec is None:
            return None
        if isinstance(spec, type) and issubclass(spec, BaseException):
            return spec(str(exc))
        if callable(spec):
            new_exc = spec(exc)
            if not isinstance(new_exc, BaseException):
                raise TypeError("exception factory must return an exception instance")
            return new_exc
        raise TypeError("mapping/default values must be exception classes or callables")

    def _translate(exc):
        """Return the exception to raise instead of ``exc``, or None to re-raise ``exc``."""
        for src_type, spec in mapping.items():
            if isinstance(exc, src_type):
                # First match wins, even when its spec is None (no fallback to default).
                return _make_new(exc, spec)
        return _make_new(exc, default)  # None when there is no default

    def decorate(fn):
        if inspect.iscoroutinefunction(fn):
            @functools.wraps(fn)
            async def awrapper(*args, **kwargs):
                try:
                    return await fn(*args, **kwargs)
                except Exception as e:
                    new_exc = _translate(e)
                    if new_exc is None:
                        raise
                    raise new_exc from e

            return awrapper

        @functools.wraps(fn)
        def swrapper(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except Exception as e:
                new_exc = _translate(e)
                if new_exc is None:
                    raise
                raise new_exc from e

        return swrapper

    return decorate


## Exercise 7 — A descriptor-based decorator for per-instance method call counts

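This exercise highlights an advanced detail: **functions are descriptors**. When a function stored on a class is looked up through an instance, the function's `__get__` returns a bound method. A decorator that replaces the function with another object must implement `__get__` itself to get the same per-instance behavior.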

Implement `@count_calls_per_instance` that works on instance methods and maintains counts *per instance*, not globally.

Requirements:
- When used on a method `C.m`, counts are stored per `C` instance.
- Access counts via:
  - `obj.m.call_count`
  - `obj.m.reset()`
- Preserve metadata (`__name__`, `__doc__`).

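Hint: use a descriptor object with `__get__`, and store the counts somewhere that does not keep instances alive (e.g. a `WeakKeyDictionary`) to avoid memory leaks. Note that a `WeakKeyDictionary` looks keys up by hash and equality. It therefore cannot hold unhashable instances (such as plain `@dataclass` objects), and instances that compare equal would share one counter.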


In [8]:
# --- Solution (Exercise 7) ---

class _BoundCountedMethod:
    def __init__(self, decorator, obj):
        self._decorator = decorator
        self._obj = obj
        functools.update_wrapper(self, decorator._fn)

    def __call__(self, *args, **kwargs):
        return self._decorator._call(self._obj, *args, **kwargs)

    @property
    def call_count(self) -> int:
        return self._decorator._get_count(self._obj)

    def reset(self) -> None:
        self._decorator._reset(self._obj)

class count_calls_per_instance:
    """Decorator for instance methods that counts calls per instance.

    Counts are keyed by object identity (``id``), not by hash or equality.
    Instances that compare equal therefore keep separate counters, and
    unhashable instances (such as plain ``@dataclass`` objects) are supported.
    On an instance's first counted call, a ``weakref.finalize`` callback is
    registered to drop its entry when the instance is garbage-collected.
    Counting therefore never keeps instances alive, and a recycled ``id`` never
    inherits a stale count. Instances must support weak references.

    Access through an instance returns a ``_BoundCountedMethod``, which
    supports ``obj.m(...)``, ``obj.m.call_count`` and ``obj.m.reset()``.
    Access through the class returns this decorator, which is callable like a
    plain function: ``C.m(obj, ...)``.
    """
    def __init__(self, fn):
        self._fn = fn
        self._counts = {}  # id(instance) -> call count
        functools.update_wrapper(self, fn)

    def _get_count(self, obj) -> int:
        return self._counts.get(id(obj), 0)

    def _reset(self, obj) -> None:
        key = id(obj)
        # Only touch existing entries. A missing entry already reads as 0, and
        # creating one here would skip the finalizer registered in _call.
        if key in self._counts:
            self._counts[key] = 0

    def _call(self, obj, *args, **kwargs):
        key = id(obj)
        if key not in self._counts:
            # First call for this instance: remove its entry when it dies.
            weakref.finalize(obj, self._counts.pop, key, None)
            self._counts[key] = 0
        self._counts[key] += 1
        return self._fn(obj, *args, **kwargs)

    def __call__(self, obj, *args, **kwargs):
        """Support unbound calls through the class: ``C.m(obj, ...)``."""
        return self._call(obj, *args, **kwargs)

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        return _BoundCountedMethod(self, obj)


## Exercise 8 — Decorator stacking and preserving `__wrapped__` for introspection

Create two decorators:

1. `@timed` — measures elapsed time and attaches `last_elapsed` (seconds) to the wrapper.
2. `@requires(predicate)` — decorator factory that blocks execution unless `predicate(*args, **kwargs)` is True.

Requirements:
- Both use `functools.wraps` so that `inspect.signature` still reveals the original signature.
- Demonstrate stacking order and verify that `inspect.signature` still works.


In [9]:
# --- Solution (Exercise 8) ---

def timed(fn):
    """Measure how long each call of ``fn`` takes.

    The most recent duration in seconds (a ``time.perf_counter`` difference)
    is stored on ``wrapper.last_elapsed``; it is 0.0 before the first call.
    The duration is recorded even when the call raises. ``async def``
    functions get an async wrapper that times the awaited execution, not just
    creation of the coroutine object.
    """
    if inspect.iscoroutinefunction(fn):
        @functools.wraps(fn)
        async def awrapper(*args, **kwargs):
            t0 = time.perf_counter()
            try:
                return await fn(*args, **kwargs)
            finally:
                awrapper.last_elapsed = time.perf_counter() - t0  # type: ignore[attr-defined]

        wrapper = awrapper
    else:
        @functools.wraps(fn)
        def swrapper(*args, **kwargs):
            t0 = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                swrapper.last_elapsed = time.perf_counter() - t0  # type: ignore[attr-defined]

        wrapper = swrapper

    wrapper.last_elapsed = 0.0  # type: ignore[attr-defined]
    return wrapper

def requires(predicate):
    """Decorator factory: run the function only when ``predicate`` approves.

    ``predicate`` is called with exactly the positional and keyword arguments
    of each call. A truthy result lets the call through; a falsy one raises
    ``PermissionError`` without running the function.

    Raises:
        TypeError: Immediately, if ``predicate`` is not callable.
    """
    if not callable(predicate):
        raise TypeError("predicate must be callable")

    def decorate(fn):
        @functools.wraps(fn)
        def guarded(*args, **kwargs):
            allowed = predicate(*args, **kwargs)
            if allowed:
                return fn(*args, **kwargs)
            raise PermissionError(f"Predicate blocked call to {fn.__name__}")

        return guarded

    return decorate


## Quick reference: what `wraps` really buys you

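When you use `functools.wraps(fn)`, it copies metadata such as `__module__`, `__name__`, `__qualname__`, `__doc__` and `__annotations__` onto the wrapper, merges `fn.__dict__` into the wrapper's `__dict__`, and sets `wrapper.__wrapped__ = fn`.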
Many inspection tools (including `inspect.signature`) will follow `__wrapped__` and show the original signature,
even if the wrapper uses `*args, **kwargs` internally.


## Test suite (run this cell)

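These tests validate correctness and highlight best practices. If anything fails, review the corresponding solution. The async retry test uses a top-level `await`, which works in Jupyter/IPython but not in a plain Python script; there, wrap it in `asyncio.run(...)`.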


In [11]:
# --- Tests ---
# Smoke tests for every exercise. Each section decorates small sample
# functions, then asserts on their results and on the attributes the
# decorators attach.
# NOTE: the async part of Exercise 4 uses a top-level `await`. That only works
# where IPython/Jupyter runs the cell inside an event loop; in a plain .py
# script that line is a SyntaxError.

# Exercise 1
@call_counter
def _f(a, b=1):
    """doc"""
    return a + b

assert _f(1, 2) == 3
assert _f(2) == 3
assert _f.call_count == 2
_f.reset()
assert _f.call_count == 0
assert _f.__name__ == "_f"
assert "doc" in (_f.__doc__ or "")
assert str(inspect.signature(_f)) == "(a, b=1)"  # thanks to __wrapped__, which inspect.signature follows

# Exercise 2
# Capture trace output by temporarily replacing builtins.print with a stub
# that records each printed line. The real print is restored in `finally`
# even if an assertion fails.
out = []
def _capture_print(*args, **kwargs):
    out.append(" ".join(str(a) for a in args))
_real_print = print
try:
    builtins = __import__("builtins")
    builtins.print = _capture_print

    @trace(prefix="DBG", max_len=20)
    def g(x, y=1):
        return x + y

    assert g(10, y=2) == 12
    assert any(s.startswith("DBG g(") and "->" in s for s in out)
finally:
    builtins.print = _real_print

# Exercise 3
@enforce_types
def h(x: int, items: list[int]) -> int:
    return x + sum(items)

assert h(1, [2, 3]) == 6
# Wrong top-level argument type -> TypeError naming the argument.
try:
    h("1", [2, 3])  # type: ignore[arg-type]
    raise AssertionError("Expected TypeError")
except TypeError as e:
    assert "Argument 'x'" in str(e)

# Wrong element type inside list[int] -> TypeError naming `items`.
try:
    h(1, [2, "3"])  # type: ignore[list-item]
    raise AssertionError("Expected TypeError")
except TypeError as e:
    assert "items" in str(e)

@enforce_types
def h2(x: int) -> str:
    return x  # type: ignore[return-value]

# Return annotation is checked as well.
try:
    h2(1)
    raise AssertionError("Expected TypeError for return")
except TypeError as e:
    assert "Return value" in str(e)

# Exercise 4 (sync)
# fake_sleep records requested delays instead of sleeping: the expected
# delays are `delay`, then `delay * backoff` (two failures before success).
calls = {"n": 0}
slept = []
def fake_sleep(t):
    slept.append(t)

@retry(tries=4, exceptions=(ValueError,), delay=0.1, backoff=2.0, sleep=fake_sleep)
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ValueError("nope")
    return "ok"

assert flaky() == "ok"
assert calls["n"] == 3
assert slept == [0.1, 0.2]

# Exercise 4 (async)
# Same idea with an async fake sleep: one failure, then success.
async_calls = {"n": 0}
async_slept = []
async def fake_async_sleep(t):
    async_slept.append(t)

@retry(tries=3, exceptions=(RuntimeError,), delay=0.05, backoff=2.0, sleep=fake_async_sleep)
async def aflaky():
    async_calls["n"] += 1
    if async_calls["n"] < 2:
        raise RuntimeError("nope")
    return 123

async def _run_async_tests():
    assert await aflaky() == 123
    assert async_slept == [0.05]

# Top-level await: notebook-only (see the NOTE at the top of this cell).
await _run_async_tests()

# Exercise 5
@ttl_cache(ttl_seconds=0.5, maxsize=2)
def sq(x):
    return x * x

# First call misses, second call hits.
assert sq(2) == 4
assert sq(2) == 4
info = sq.cache_info()
assert info["hits"] == 1 and info["misses"] == 1

# Expiration: a real 0.6 s sleep, past the 0.5 s TTL, forces a recompute.
time.sleep(0.6)
assert sq(2) == 4
info2 = sq.cache_info()
assert info2["misses"] >= 2

# LRU behavior (maxsize=2): three distinct keys leave only two entries.
sq.cache_clear()
sq(1); sq(2); sq(3)
assert sq.cache_info()["size"] == 2

# Exercise 6
@translate_exceptions({KeyError: ValueError})
def lookup(d, k):
    return d[k]

# KeyError is translated to ValueError, with the original kept as __cause__.
try:
    lookup({}, "x")
    raise AssertionError("Expected ValueError")
except ValueError as e:
    assert isinstance(e.__cause__, KeyError)

@translate_exceptions({KeyError: ValueError}, default=RuntimeError)
def boom(kind):
    if kind == "key":
        raise KeyError("x")
    raise ZeroDivisionError("z")

# ZeroDivisionError has no mapping entry, so `default` (RuntimeError) applies.
try:
    boom("other")
    raise AssertionError("Expected RuntimeError")
except RuntimeError as e:
    assert isinstance(e.__cause__, ZeroDivisionError)

# Exercise 7
class C:
    @count_calls_per_instance
    def inc(self, x=1):
        return x + 1

# Counts are per instance: calls on c1 do not affect c2.
c1 = C()
c2 = C()
assert c1.inc(10) == 11
assert c1.inc.call_count == 1
assert c1.inc() == 2
assert c1.inc.call_count == 2
assert c2.inc.call_count == 0
c1.inc.reset()
assert c1.inc.call_count == 0

# Exercise 8
# @timed is outermost, so it wraps (and times) the @requires check; a
# PermissionError raised by the predicate propagates through timed.
@timed
@requires(lambda x: x >= 0)
def sqrt_like(x: float) -> float:
    return x ** 0.5

assert abs(sqrt_like(9.0) - 3.0) < 1e-9
assert hasattr(sqrt_like, "last_elapsed")

try:
    sqrt_like(-1.0)
    raise AssertionError("Expected PermissionError")
except PermissionError:
    pass

print("All tests passed ✅")


All tests passed ✅
