# Decorators Application (Timing) — Advanced Problems (with Solutions)

This notebook focuses on **best-practice timing decorators**: correctness, low-overhead measurement, recursion pitfalls, async support, statistics across multiple runs, calibration/warmup, and safe printing/logging.

**How to use**: attempt each exercise, then compare with the included solution. A test/verification section is included at the end.


In [1]:
from __future__ import annotations

import asyncio
import dataclasses
import functools
import inspect
import math
import statistics
import time
import typing
from collections import deque


## Exercise 1 — A minimal, correct `timed` decorator

Implement `@timed` that measures elapsed time of **one call** and stores it on the wrapper:

- Uses `time.perf_counter()`
- Preserves metadata with `functools.wraps`
- Attaches:
  - `wrapper.last_elapsed` (float seconds)
  - `wrapper.last_result` (most recent return value)

Best practices:
- Record elapsed time in a `finally:` block so exceptions still record timing.
- Avoid building expensive argument strings unless explicitly requested.


In [2]:
# --- Solution (Exercise 1) ---

def timed(fn):
    """Decorate ``fn`` so each call's wall-clock duration is recorded.

    The returned wrapper carries two attributes:

    * ``last_elapsed`` -- seconds taken by the most recent call (measured with
      ``time.perf_counter``), recorded even when ``fn`` raises;
    * ``last_result`` -- the value returned by the most recent *successful*
      call (a call that raises leaves the previous result in place).

    Both start as ``0.0`` / ``None`` before the first call.
    """
    # Closure-held state; mirrored onto the wrapper after every call so that
    # callers can read it as plain attributes.
    state = {'elapsed': 0.0, 'result': None}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            state['result'] = fn(*args, **kwargs)
            return state['result']
        finally:
            # Runs on success and on exception alike.
            state['elapsed'] = time.perf_counter() - start
            wrapper.last_elapsed = state['elapsed']  # type: ignore[attr-defined]
            wrapper.last_result = state['result']    # type: ignore[attr-defined]

    wrapper.last_elapsed = state['elapsed']  # type: ignore[attr-defined]
    wrapper.last_result = state['result']    # type: ignore[attr-defined]
    return wrapper


## Exercise 2 — Timing decorator with configurable logging (avoid print by default)

Create `timeit_log` as a decorator factory:

```python
@timeit_log(label='fib', logger=print)
def f(...):
    ...
```

Requirements:
- Parameters:
  - `label: str | None` (prefix in messages; default uses function name)
  - `logger: Callable[[str], None] | None` (if None, do not log)
  - `format_args: bool` (if True, include args/kwargs, but keep it lightweight)
  - `max_len: int` (truncate arg representations)
- Store `last_elapsed` on the wrapper (like Exercise 1).
- Preserve metadata with `wraps`.

Note: logging should be optional because I/O can dwarf fast function timings.


In [3]:
# --- Solution (Exercise 2) ---

def _short_repr(obj, max_len: int) -> str:
    s = repr(obj)
    if len(s) <= max_len:
        return s
    return s[: max(0, max_len - 1)] + '…'

def timeit_log(*, label: str | None = None,
               logger: typing.Callable[[str], None] | None = None,
               format_args: bool = False,
               max_len: int = 80):
    """Build a decorator that times every call and optionally logs it.

    Args:
        label: Prefix used in log messages. Defaults to the wrapped callable's
            ``__name__``, or its ``repr`` when it has none (e.g. a
            ``functools.partial``).
        logger: Called with one message string per call. ``None`` disables
            logging entirely, so no message is built.
        format_args: If true, include the positional and keyword arguments in
            the message, each rendered with ``_short_repr``.
        max_len: Maximum length of each rendered argument.

    Returns:
        A decorator. The wrapper it produces exposes ``last_elapsed``, the
        seconds taken by the most recent call (``0.0`` before the first call).

    Raises:
        ValueError: If ``max_len < 10``.

    Timing and logging run in a ``finally`` block, so calls that raise are
    still measured and logged. Exceptions raised by ``logger`` itself
    propagate to the caller.
    """
    if max_len < 10:
        raise ValueError('max_len should be >= 10 for readability')

    def decorate(fn):
        # Resolve the display name once, at decoration time. Not every
        # callable has __name__ (functools.partial, callable instances).
        # Looking it up inside the `finally` below would raise AttributeError
        # after fn had already run, turning a successful call into a failure.
        if label is not None:
            name = label
        else:
            name = getattr(fn, '__name__', None)
            if name is None:
                name = repr(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            t0 = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                # Read the clock first so the logging work is not counted.
                elapsed = time.perf_counter() - t0
                wrapper.last_elapsed = elapsed  # type: ignore[attr-defined]
                if logger is not None:
                    if format_args:
                        parts = [_short_repr(a, max_len) for a in args]
                        parts.extend(f"{k}={_short_repr(v, max_len)}" for k, v in kwargs.items())
                        arg_s = ', '.join(parts)
                        msg = f"{name}({arg_s}) took {elapsed:.6f}s"
                    else:
                        msg = f"{name} took {elapsed:.6f}s"
                    logger(msg)

        wrapper.last_elapsed = 0.0  # type: ignore[attr-defined]
        return wrapper

    return decorate


## Exercise 3 — Time a function **multiple times** and return summary statistics

Implement `bench` as a decorator factory:

```python
@bench(repeats=30, warmup=5)
def f(...):
    ...
```

Requirements:
- Runs the wrapped function `warmup` times (ignored in stats).
- Then runs it `repeats` times and computes:
  - mean, median, stdev (sample stdev), min, max
- Returns the original function's return value from the **last measured run**.
- Exposes `wrapper.last_stats` as a dict.
- Preserves metadata.

Best practices:
- Use `perf_counter`.
- Keep overhead small.
- For extremely fast functions, results are noisy; warmups help.


In [4]:
# --- Solution (Exercise 3) ---

def bench(*, repeats: int = 20, warmup: int = 3):
    """Build a decorator that benchmarks a function on every call.

    Each call of the decorated function:

    1. runs ``fn`` ``warmup`` times and discards those timings;
    2. runs ``fn`` ``repeats`` more times, timing each run with
       ``time.perf_counter``;
    3. stores summary statistics on ``wrapper.last_stats``. The keys are
       mean, median, sample stdev (0.0 for a single run), min, max, repeats
       and warmup, all as floats.
    4. returns the value produced by the final measured run.

    ``wrapper.last_stats`` is ``None`` until the first call completes.

    Raises:
        ValueError: If ``repeats < 1`` or ``warmup < 0``.
    """
    if repeats < 1:
        raise ValueError('repeats must be >= 1')
    if warmup < 0:
        raise ValueError('warmup must be >= 0')

    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Warm caches / lazy initialisation; these timings are ignored.
            for _ in range(warmup):
                fn(*args, **kwargs)

            durations = []
            outcome = None
            for _ in range(repeats):
                start = time.perf_counter()
                outcome = fn(*args, **kwargs)
                durations.append(time.perf_counter() - start)

            # Sample stdev needs at least two observations.
            spread = statistics.stdev(durations) if len(durations) >= 2 else 0.0
            wrapper.last_stats = {  # type: ignore[attr-defined]
                'mean': float(statistics.fmean(durations)),
                'median': float(statistics.median(durations)),
                'stdev': float(spread),
                'min': float(min(durations)),
                'max': float(max(durations)),
                'repeats': float(repeats),
                'warmup': float(warmup),
            }
            return outcome

        wrapper.last_stats = None  # type: ignore[attr-defined]
        return wrapper

    return decorate


## Exercise 4 — Avoid recursion timing noise (time only the top-level call)

If you decorate a recursive function directly, you'll time (and possibly log) **every recursive call**.

Implement `top_level_timed`:

- Works as a decorator.
- For recursive calls, it should *not* measure every nested call.
- It should only measure the outermost call per thread/task.
- Exposes `last_elapsed` like `timed`.

Hint: use a depth counter stored in `contextvars.ContextVar` (works with async too).


In [5]:
# --- Solution (Exercise 4) ---

import contextvars

# Frozenset of the top_level_timed wrappers currently executing in this
# context. ContextVar values are isolated per thread and per asyncio task, so
# one thread/task never sees another's entries. The value is only ever
# replaced (never mutated), and the var lives at module level, as the
# contextvars docs recommend, instead of inside each decorator closure.
_active_timed_var = contextvars.ContextVar('_top_level_active', default=frozenset())

def top_level_timed(fn):
    """Time only the outermost call of ``fn`` in the current thread/task.

    Re-entrant (recursive) calls of the same wrapper run untimed, so inner
    recursion levels never overwrite ``last_elapsed``. Each decorated function
    is tracked independently. A different ``top_level_timed`` function called
    from inside ``fn`` still gets its own top-level timing. (A single shared
    depth counter would suppress it.)

    The wrapper exposes ``last_elapsed``: seconds taken by the most recent
    outermost call, recorded even if ``fn`` raised. It is ``0.0`` before the
    first call.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        active = _active_timed_var.get()
        if wrapper in active:
            # Nested call of this same function: run it without timing.
            return fn(*args, **kwargs)
        token = _active_timed_var.set(active | {wrapper})
        t0 = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            wrapper.last_elapsed = time.perf_counter() - t0  # type: ignore[attr-defined]
            # Restore the previous set even if fn raised.
            _active_timed_var.reset(token)

    wrapper.last_elapsed = 0.0  # type: ignore[attr-defined]
    return wrapper


## Exercise 5 — Timing for async functions (and a unified decorator)

Implement `timed_any` which can wrap **sync or async** functions:

- If the function is `async def`, `await` it.
- Expose `last_elapsed`.
- Preserve metadata with `wraps`.

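Best practice: choose behavior based on whether `fn` is a coroutine function. Use `inspect.iscoroutinefunction(fn)`; `asyncio.iscoroutinefunction` does the same job but is deprecated as of Python 3.14.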


In [6]:
# --- Solution (Exercise 5) ---

def timed_any(fn):
    """Time calls to a synchronous *or* asynchronous function.

    If ``fn`` is a coroutine function (``async def``), the returned wrapper is
    also a coroutine function that awaits ``fn``. The measured time is then
    wall-clock from the start of the call until the awaited coroutine
    finishes, including time spent suspended while other tasks run.
    Otherwise a plain synchronous wrapper is returned.

    The wrapper exposes ``last_elapsed`` (seconds, via ``time.perf_counter``).
    It is recorded even when ``fn`` raises and is ``0.0`` before the first call.
    """
    # inspect.iscoroutinefunction is the supported check;
    # asyncio.iscoroutinefunction is deprecated since Python 3.14.
    if inspect.iscoroutinefunction(fn):
        @functools.wraps(fn)
        async def awrapper(*args, **kwargs):
            t0 = time.perf_counter()
            try:
                return await fn(*args, **kwargs)
            finally:
                awrapper.last_elapsed = time.perf_counter() - t0  # type: ignore[attr-defined]

        awrapper.last_elapsed = 0.0  # type: ignore[attr-defined]
        return awrapper

    @functools.wraps(fn)
    def swrapper(*args, **kwargs):
        t0 = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            swrapper.last_elapsed = time.perf_counter() - t0  # type: ignore[attr-defined]

    swrapper.last_elapsed = 0.0  # type: ignore[attr-defined]
    return swrapper


## Exercise 6 — Compare implementations fairly (same inputs, many trials)

Write a helper `compare(funcs, *, args=(), kwargs=None, repeats=50, warmup=5)` that:

- Executes each function with the same args/kwargs.
- Uses `perf_counter`.
- Returns a list of results sorted by mean time ascending.

Return structure per function:
```python
{
  'name': func.__name__,
  'mean': ..., 'median': ..., 'stdev': ..., 'min': ..., 'max': ...,
}
```

Best practice: do not print; return data so the caller can decide how to display.


In [7]:
# --- Solution (Exercise 6) ---

def compare(funcs: typing.Iterable[typing.Callable], *, args=(), kwargs=None, repeats: int = 50, warmup: int = 5):
    """Benchmark several callables on identical inputs.

    Each callable is run ``warmup`` times (untimed), then ``repeats`` times
    with ``time.perf_counter`` around each call. The callables are processed
    one after another, all runs of the first before the second.

    Args:
        funcs: Callables to compare. Any iterable; it is consumed once.
        args: Positional arguments passed to every call.
        kwargs: Keyword arguments passed to every call (default: none).
        repeats: Number of measured runs per callable (must be >= 1).
        warmup: Number of discarded runs per callable (must be >= 0).

    Returns:
        A list of dicts ``{'name', 'mean', 'median', 'stdev', 'min', 'max'}``
        (times in seconds; ``stdev`` is the sample stdev, 0.0 for one run),
        sorted by ascending mean. Ties keep input order. ``name`` is the
        callable's ``__name__``, or its ``repr`` if it has none (e.g.
        ``functools.partial``).

    Raises:
        ValueError: If ``repeats < 1`` or ``warmup < 0``.
    """
    if kwargs is None:
        kwargs = {}
    funcs = list(funcs)
    if repeats < 1:
        raise ValueError('repeats must be >= 1')
    if warmup < 0:
        raise ValueError('warmup must be >= 0')

    results = []
    for fn in funcs:
        # Resolve the label up front. Partials and callable instances lack
        # __name__, which previously raised only after the benchmark had run.
        name = getattr(fn, '__name__', None)
        if name is None:
            name = repr(fn)

        for _ in range(warmup):
            fn(*args, **kwargs)

        times = []
        for _ in range(repeats):
            t0 = time.perf_counter()
            fn(*args, **kwargs)
            times.append(time.perf_counter() - t0)

        results.append({
            'name': name,
            'mean': float(statistics.fmean(times)),
            'median': float(statistics.median(times)),
            'stdev': float(statistics.stdev(times) if len(times) >= 2 else 0.0),
            'min': float(min(times)),
            'max': float(max(times)),
        })

    results.sort(key=lambda d: d['mean'])
    return results


## Exercise 7 — Fibonacci implementations + timing pitfalls

Implement three Fibonacci functions (1-based indexing):

1. `fib_recursive(n)` — naive recursion
2. `fib_loop(n)` — iterative loop
3. `fib_reduce(n)` — functional reduce

Then:
- Decorate `fib_recursive` using `top_level_timed` so it times only the outermost call.
- Use `compare([fib_loop, fib_reduce], args=(10000,), repeats=20, warmup=3)` to compare.


In [8]:
# --- Solution (Exercise 7) ---

from functools import reduce

def fib_recursive(n: int) -> int:
    """Return the n-th Fibonacci number (1-based: F1 = F2 = 1) by naive recursion.

    Exponential time; any ``n <= 2`` (including zero or negatives) yields 1.
    """
    return 1 if n <= 2 else fib_recursive(n - 1) + fib_recursive(n - 2)

def fib_loop(n: int) -> int:
    """Return the n-th Fibonacci number (1-based: F1 = F2 = 1) iteratively.

    Linear time; any ``n <= 2`` (including zero or negatives) yields 1.
    """
    prev, curr = 1, 1  # F1, F2
    # range() of a non-positive count is empty, so n <= 2 falls straight through.
    for _ in range(n - 2):
        prev, curr = curr, prev + curr
    return curr

def fib_reduce(n: int) -> int:
    """Fibonacci via ``functools.reduce`` over a ``(a, b)`` state pair.

    The state starts at ``(1, 0)`` and each of the ``n`` steps maps
    ``(a, b) -> (a + b, a)``; the first component is returned.

    NOTE: with this seed the result is shifted by one relative to the 1-based
    ``fib_loop``. It returns F(n+1), e.g. ``fib_reduce(6) == 13`` while
    ``fib_loop(6) == 8``. Any ``n <= 0`` returns 1.
    """
    def step(state, _):
        a, b = state
        return a + b, a

    total, _unused = reduce(step, range(n), (1, 0))
    return total

# Timed entry point for fib_recursive. fib_recursive's body calls the
# undecorated global name `fib_recursive`, so its recursive calls never go
# through this wrapper; only the outer call is timed. (Decorating the function
# in place with @top_level_timed would also keep nested calls untimed, via the
# decorator's own nesting check.)
fib_recursive_timed = top_level_timed(fib_recursive)


## Exercise 8 — A tiny benchmark report formatter (no tables required)

Create `format_compare(results)` that takes the output of `compare(...)` and returns a readable multiline string:

- One line per function
- Include mean/median/stdev (min/max are welcome too), each shown in microseconds, milliseconds, or seconds depending on magnitude

Best practice: keep formatting separate from measurement.


In [9]:
# --- Solution (Exercise 8) ---

def _fmt_seconds(s: float) -> str:
    # Choose a unit for readability
    if s < 1e-3:
        return f'{s * 1e6:.1f}µs'
    if s < 1:
        return f'{s * 1e3:.3f}ms'
    return f'{s:.6f}s'

def format_compare(results: list[dict]) -> str:
    """Turn ``compare(...)`` output into one human-readable line per function.

    Each line reads ``name: mean=…, median=…, stdev=…, min=…, max=…``. Every
    duration is rendered by ``_fmt_seconds``. Lines are joined with ``'\\n'``
    (an empty input gives an empty string).
    """
    rendered = []
    for row in results:
        head = row['name']
        stats = ', '.join(
            f"{key}={_fmt_seconds(row[key])}"
            for key in ('mean', 'median', 'stdev', 'min', 'max')
        )
        rendered.append(f"{head}: {stats}")
    return '\n'.join(rendered)


# Verification / Tests

Run this cell to validate behavior and see example outputs.


In [10]:
# --- Tests / Demos ---
# Notebook cell: uses the definitions from the solution cells above and
# IPython/Jupyter's top-level `await` (Exercise 5 section), so it is not a
# standalone Python script.

# Exercise 1: metadata + last_elapsed
@timed
def add(a, b=0):
    '''adds two numbers'''
    return a + b

assert add(1, 2) == 3
assert isinstance(add.last_elapsed, float)
assert add.__name__ == 'add'
assert 'adds' in (add.__doc__ or '')
# functools.wraps sets __wrapped__, which inspect.signature follows back to
# the original function's signature.
assert str(inspect.signature(add)) == '(a, b=0)'

# Exercise 2: optional logger
# `capture` stands in for a real logger so the emitted message can be inspected.
logs = []
def capture(msg: str):
    logs.append(msg)

@timeit_log(label='ADD', logger=capture, format_args=True, max_len=30)
def add2(a, b=0):
    return a + b

assert add2(3, b=4) == 7
# One message per call; with format_args=True it is "ADD(<args>) took ...s".
assert len(logs) == 1 and logs[0].startswith('ADD(')
assert hasattr(add2, 'last_elapsed')

# Exercise 3: bench
# A single call to `work` runs the body 2 (warmup) + 10 (measured) times.
@bench(repeats=10, warmup=2)
def work(n: int):
    s = 0
    for i in range(n):
        s += i
    return s

assert work(1000) == sum(range(1000))
assert work.last_stats is not None
assert work.last_stats['repeats'] == 10.0

# Exercise 4: top-level timed recursion
# NOTE: _fib_plain recurses on itself rather than on the wrapper, so only the
# outer call passes through _fib_top; this checks the result and that a
# timing was recorded.
def _fib_plain(n: int) -> int:
    if n <= 2:
        return 1
    return _fib_plain(n - 1) + _fib_plain(n - 2)

_fib_top = top_level_timed(_fib_plain)
assert _fib_top(10) == 55
assert _fib_top.last_elapsed >= 0.0

# Exercise 5: async
@timed_any
async def aadd(a, b):
    await asyncio.sleep(0)
    return a + b

async def _run_async():
    v = await aadd(1, 2)
    assert v == 3
    assert aadd.last_elapsed >= 0.0

# Top-level `await` only works in IPython/Jupyter; a plain script would use
# asyncio.run(_run_async()) instead.
await _run_async()

# Exercise 7: fibonacci correctness
assert fib_loop(1) == 1 and fib_loop(2) == 1 and fib_loop(6) == 8
assert fib_reduce(1) == 1 and fib_reduce(2) == 2 and fib_reduce(6) == 13  # note: this reduce version yields F(n+1), one ahead of fib_loop

# Fix for the reduce version discrepancy:
# fib_reduce seeds its state with (1, 0) and applies n steps, so it lands one
# term ahead (F(n+1)). fib_reduce_correct seeds with (F1, F2) = (1, 1) and
# applies n - 2 steps, matching fib_loop's 1-based indexing exactly.
def fib_reduce_correct(n: int) -> int:
    """Return the n-th Fibonacci number (1-based) using functools.reduce."""
    # Start state (a, b) = (1, 1) corresponds to F1=1, F2=1; each step maps
    # (a, b) -> (b, a + b), so after n - 2 steps b == F(n).
    if n <= 2:
        return 1
    a, b = reduce(lambda prev, _: (prev[1], prev[0] + prev[1]), range(n - 2), (1, 1))
    return b

assert fib_reduce_correct(6) == 8

# Compare loop vs reduce_correct fairly: same input, warmup runs, then
# repeated measurements per function (see compare()).
results = compare([fib_loop, fib_reduce_correct], args=(2000,), repeats=15, warmup=3)
report = format_compare(results)
print(report)

print('All checks passed ✅')


fib_loop: mean=256.1µs, median=203.2µs, stdev=100.3µs, min=168.9µs, max=476.3µs
fib_reduce_correct: mean=593.6µs, median=701.1µs, stdev=153.5µs, min=364.5µs, max=757.5µs
All checks passed ✅
