# Advanced Practice: Closures

These exercises deepen your understanding of Python closures: capturing state, using `nonlocal`, avoiding late-binding bugs, and building decorator-like utilities.

👉 **Instructions**
- Implement where marked `# YOUR CODE HERE`.
- Do **not** modify test cells.
- Use only the standard library.
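- Be careful with mutation vs. rebinding: mutating a captured object (e.g. `items.append(x)`) needs no declaration, but rebinding a captured name (e.g. `total += x`) requires `nonlocal`.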
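To see the difference in practice, here is a minimal sketch with two hypothetical factories, `make_logger` and `make_broken_total` (illustrations only, not part of the exercises). The first only mutates its captured list. The second rebinds a captured name without `nonlocal`, so Python treats `total` as local to the inner function and the call fails.

```python
from typing import Callable, List


def make_logger() -> Callable[[str], List[str]]:
    messages: List[str] = []  # captured list object

    def log(msg: str) -> List[str]:
        messages.append(msg)  # mutation only: `messages` is never rebound
        return list(messages)  # return a copy so callers can't alter the state

    return log


def make_broken_total() -> Callable[[int], int]:
    total = 0

    def add(x: int) -> int:
        total += x  # rebinding without nonlocal makes `total` local to add()
        return total

    return add


log = make_logger()
log("a")
print(log("b"))  # ['a', 'b']

try:
    make_broken_total()(1)
except UnboundLocalError as exc:
    print("rebinding without nonlocal raised", type(exc).__name__)
```

The error appears only when the inner function runs, not when the factory is called. That delay is why a missing `nonlocal` can go unnoticed until the closure is used.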


```python
from __future__ import annotations
from typing import Any, Callable, Dict, Iterable, List, Tuple
from time import perf_counter
```


## Problem 1 — `make_counter`

Create `make_counter(start=0, step=1)` that returns a function with no parameters. Each call returns the **current value**, then advances by `step`. The internal state must be kept in the closure (no globals). Support negative `step` too.

```python
def make_counter(start: int = 0, step: int = 1) -> Callable[[], int]:
    current = start
    def next_value() -> int:
        nonlocal current
        value = current
        current += step
        return value
    return next_value
```


```python
# Tests — do not modify
c = make_counter(10, 3)
assert [c(), c(), c()] == [10, 13, 16]
c2 = make_counter(-2, -2)
assert [c2(), c2(), c2()] == [-2, -4, -6]
print("✅ Problem 1 tests passed.")
```


✅ Problem 1 tests passed.
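Each call to `make_counter` creates a fresh scope, so two counters never share state. The sketch below repeats the solution above so it runs on its own. It uses a hypothetical helper, `closure_state`, to peek at the captured cells through the standard function attributes `__code__.co_freevars` (the captured names) and `__closure__` (the matching cell objects, each exposing `cell_contents`).

```python
from typing import Any, Callable, Dict


def make_counter(start: int = 0, step: int = 1) -> Callable[[], int]:
    current = start

    def next_value() -> int:
        nonlocal current
        value = current
        current += step
        return value

    return next_value


def closure_state(fn: Callable[..., Any]) -> Dict[str, Any]:
    """Map each free-variable name of fn to the value currently in its cell."""
    names = fn.__code__.co_freevars  # names captured from enclosing scopes
    cells = fn.__closure__ or ()     # matching cell objects, same order
    return {name: cell.cell_contents for name, cell in zip(names, cells)}


a = make_counter()
b = make_counter(100, -10)
a(); a()  # a has handed out 0 and 1
b()       # b has handed out 100
print(closure_state(a)["current"], closure_state(b)["current"])  # 2 90
```

Note that `start` does not appear among the free variables: `next_value` never references it, so only `current` and `step` are captured.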


## Problem 2 — `make_accumulator`

Write `make_accumulator(initial=0.0)` that returns a closure which **adds** each incoming value to an internal total and returns the running sum after every call.

```python
def make_accumulator(initial: float = 0.0) -> Callable[[float], float]:
    total = float(initial)
    def add(x: float) -> float:
        nonlocal total
        total += x
        return total
    return add
```


```python
# Tests — do not modify
acc = make_accumulator(5)
assert [acc(1), acc(2.5), acc(-3.5)] == [6.0, 8.5, 5.0]
print("✅ Problem 2 tests passed.")
```


✅ Problem 2 tests passed.
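The same pattern extends to several captured variables. Here is a minimal sketch of a running average, `make_averager` (a hypothetical name, not part of the exercises). It rebinds two free variables and therefore lists both in a single `nonlocal` statement.

```python
from typing import Callable


def make_averager() -> Callable[[float], float]:
    total = 0.0  # running sum
    count = 0    # number of values seen so far

    def average(x: float) -> float:
        nonlocal total, count  # both names are rebound below
        total += x
        count += 1
        return total / count

    return average


avg = make_averager()
print([avg(10), avg(11), avg(12)])  # [10.0, 10.5, 11.0]
```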


## Problem 3 — `make_tag`

Write `make_tag(tag)` that returns a closure wrapping text in an HTML tag, e.g. `make_tag('b')('hi') -> '<b>hi</b>'`. Each returned function must keep its own captured `tag`.

```python
def make_tag(tag: str) -> Callable[[str], str]:
    open_t, close_t = f"<{tag}>", f"</{tag}>"
    def wrap(text: str) -> str:
        return f"{open_t}{text}{close_t}"
    return wrap
```


```python
# Tests — do not modify
bold = make_tag('b'); ital = make_tag('i')
assert bold('x') == '<b>x</b>'
assert ital('y') == '<i>y</i>'
print("✅ Problem 3 tests passed.")
```


✅ Problem 3 tests passed.
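Every `make_tag(...)` call gets its own scope, so building wrappers in a loop through the factory captures each tag correctly. For comparison, `functools.partial` can pre-fill the `tag` argument of an ordinary two-argument function. `wrap_in` below is a hypothetical helper used only for this sketch.

```python
from functools import partial
from typing import Callable, Dict


def make_tag(tag: str) -> Callable[[str], str]:
    open_t, close_t = f"<{tag}>", f"</{tag}>"

    def wrap(text: str) -> str:
        return f"{open_t}{text}{close_t}"

    return wrap


# make_tag(t) runs immediately on each iteration, so every closure keeps its own tag.
wrappers: Dict[str, Callable[[str], str]] = {t: make_tag(t) for t in ("b", "i", "u")}
print(wrappers["b"](wrappers["i"]("x")))  # <b><i>x</i></b>


# functools.partial builds a similar pre-configured callable without a nested def.
def wrap_in(tag: str, text: str) -> str:
    return f"<{tag}>{text}</{tag}>"


bold = partial(wrap_in, "b")
print(bold("y"))  # <b>y</b>
```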


## Problem 4 — `memoize`

Write `memoize(func)` that returns a **closure** caching results per `args` and `kwargs`. Use a dict for the cache. Keep it simple: assume arguments are hashable.

Tip: a cache key can be `(args, tuple(sorted(kwargs.items())))`.

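```python
def memoize(func: Callable[..., Any]) -> Callable[..., Any]:
    cache: Dict[Tuple[Any, ...], Any] = {}
    # One shared marker separates positional args from keyword items, so that
    # f(2, 3) and f(2, b=3) get different keys. It must be created once: a fresh
    # object() per call never compares equal, so kwargs calls would never hit.
    kwd_mark = object()
    def keyify(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, ...]:
        if kwargs:
            return args + (kwd_mark,) + tuple(sorted(kwargs.items()))
        return args
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        k = keyify(args, kwargs)
        if k in cache:
            return cache[k]  # cache hit: skip calling func
        result = func(*args, **kwargs)
        cache[k] = result
        return result
    return wrapper
```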


```python
# Tests — do not modify
calls = {"n": 0}
def slow_add(a, b):
    calls["n"] += 1
    return a + b
m_add = memoize(slow_add)
assert m_add(2,3) == 5
assert m_add(2,3) == 5 and calls["n"] == 1
assert m_add(2,b=3) == 5 and calls["n"] == 2  # different key
print("✅ Problem 4 tests passed.")
```


✅ Problem 4 tests passed.
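The standard library ships a ready-made memoizer, `functools.lru_cache`. The sketch below mirrors the test above and uses `maxsize=None` for an unbounded cache. In CPython the keyword call `slow_add(2, b=3)` is cached separately from `slow_add(2, 3)`, and `cache_info()` reports the hit and miss counts.

```python
from functools import lru_cache
from typing import Dict

calls: Dict[str, int] = {"n": 0}


@lru_cache(maxsize=None)  # unbounded cache keyed on the call's arguments
def slow_add(a: int, b: int) -> int:
    calls["n"] += 1  # count real executions
    return a + b


slow_add(2, 3)    # miss: computed
slow_add(2, 3)    # hit: served from the cache
slow_add(2, b=3)  # miss: the keyword form produces a different key in CPython
print(slow_add.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=None, currsize=2)
print(calls["n"])             # 2
```

The hand-written `memoize` above is still worth writing once: it shows that `lru_cache` is, conceptually, a closure over a dictionary plus a key-building rule.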


## Problem 5 — `once`

Return a closure that executes `func(*args, **kwargs)` **only on the first call**, caches the result, and returns the same result on subsequent calls without re-invoking `func`.

```python
def once(func: Callable[..., Any]) -> Callable[..., Any]:
    done = False
    saved: Any = None
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal done, saved
        if not done:
            saved = func(*args, **kwargs)
            done = True
        return saved
    return wrapper
```


```python
# Tests — do not modify
counter = {"n": 0}
def build_value(x):
    counter["n"] += 1
    return x*10
f = once(build_value)
assert f(3) == 30 and counter["n"] == 1
assert f(999) == 30 and counter["n"] == 1
print("✅ Problem 5 tests passed.")
```


✅ Problem 5 tests passed.
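When `once` is applied with `@` syntax, decorating `wrapper` with `functools.wraps` keeps the original function's `__name__` and `__doc__`. Here is a sketch under that assumption. `load_settings` is a hypothetical stand-in for expensive setup.

```python
import functools
from typing import Any, Callable, Dict


def once(func: Callable[..., Any]) -> Callable[..., Any]:
    done = False
    saved: Any = None

    @functools.wraps(func)  # copy __name__, __doc__, etc. and set __wrapped__
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal done, saved
        if not done:
            saved = func(*args, **kwargs)  # if this raises, done stays False
            done = True
        return saved

    return wrapper


loads = {"n": 0}


@once
def load_settings() -> Dict[str, bool]:
    """Build a settings dict (a stand-in for expensive setup)."""
    loads["n"] += 1
    return {"debug": True}


first = load_settings()
second = load_settings()
print(first is second, loads["n"])  # True 1
print(load_settings.__name__)       # load_settings
```

`done` is set only after `func` returns, so an exception leaves the wrapper un-latched and the next call tries again. Whether that retry is desirable depends on the use case.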


## Problem 6 — Late binding trap: fix lambdas in a loop

Write `make_powers_bug(ns)` that (intentionally) builds a list of `lambda x: x**n` in a loop over `ns` **without** capturing `n` correctly. A closure looks up `n` when it is *called*, not when it is created, so every lambda sees the loop's final value (late binding). Then write `make_powers_fixed(ns)`, which captures each `n` properly so each lambda uses its own exponent.

Hint for the fix: default argument `n=n` or an inner-closure layer.

```python
def make_powers_bug(ns: Iterable[int]) -> List[Callable[[int], int]]:
    funcs = []
    for n in ns:
        funcs.append(lambda x: x ** n)  # late-binding bug on purpose
    return funcs

def make_powers_fixed(ns: Iterable[int]) -> List[Callable[[int], int]]:
    funcs: List[Callable[[int], int]] = []
    for n in ns:
        funcs.append(lambda x, n=n: x ** n)  # capture per-iteration
    return funcs
```


```python
# Tests — do not modify
funcs_bug = make_powers_bug([1,2,3])
assert [f(2) for f in funcs_bug] == [8,8,8]  # all captured last n=3
funcs_ok = make_powers_fixed([1,2,3])
assert [f(2) for f in funcs_ok] == [2,4,8]
print("✅ Problem 6 tests passed.")
```


✅ Problem 6 tests passed.
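The hint's other option, an inner-closure layer, looks like this in a minimal sketch (`make_powers_factory` and `make_powers_default` are illustrative names). One practical difference: the default-argument trick leaves `n` as a real parameter that a caller can override. The factory version does not accept a second argument at all.

```python
from typing import Callable, Iterable, List


def make_powers_factory(ns: Iterable[int]) -> List[Callable[[int], int]]:
    def make_power(n: int) -> Callable[[int], int]:
        # Each make_power call has its own `n`, so the lambda captures that value.
        return lambda x: x ** n

    return [make_power(n) for n in ns]


def make_powers_default(ns: Iterable[int]) -> List[Callable[..., int]]:
    # Default-argument trick: `n=n` is evaluated when each lambda is defined.
    return [lambda x, n=n: x ** n for n in ns]


print([f(2) for f in make_powers_factory([1, 2, 3])])  # [2, 4, 8]
print([f(2) for f in make_powers_default([1, 2, 3])])  # [2, 4, 8]

# Side effect of the default-argument trick: callers can override n.
print(make_powers_default([1])[0](2, 10))  # 1024
```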


## Problem 7 — `make_registry`

Create `make_registry()` that returns two closures `(register, get)`. 
- `register(name, func)` stores a callable under `name`.
- `get(name)` retrieves the callable or raises `KeyError`.

Keep the registry **private** inside the closure.

```python
def make_registry() -> Tuple[Callable[[str, Callable[..., Any]], None],
                              Callable[[str], Callable[..., Any]]]:
    store: Dict[str, Callable[..., Any]] = {}
    def register(name: str, func: Callable[..., Any]) -> None:
        store[name] = func
    def get(name: str) -> Callable[..., Any]:
        return store[name]
    return register, get
```


```python
# Tests — do not modify
reg, get = make_registry()
def add(a,b): return a+b
reg('plus', add)
assert get('plus')(2,3) == 5
try:
    get('missing')
    raise AssertionError('expected KeyError')
except KeyError:
    pass
print("✅ Problem 7 tests passed.")
```


✅ Problem 7 tests passed.
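A common variation turns `register` into a decorator factory so that functions can be registered where they are defined. The sketch below is hypothetical: `make_decorator_registry` and `names` are not part of the exercise. All three returned closures share the same private `store`.

```python
from typing import Any, Callable, Dict, List, Tuple

Func = Callable[..., Any]


def make_decorator_registry() -> Tuple[
    Callable[[str], Callable[[Func], Func]], Callable[[str], Func], Callable[[], List[str]]
]:
    store: Dict[str, Func] = {}  # shared, private state

    def register(name: str) -> Callable[[Func], Func]:
        def decorator(func: Func) -> Func:
            store[name] = func
            return func  # return unchanged so the decorated name still works

        return decorator

    def get(name: str) -> Func:
        return store[name]  # raises KeyError for unknown names

    def names() -> List[str]:
        return sorted(store)

    return register, get, names


register, get, names = make_decorator_registry()


@register("plus")
def add(a: int, b: int) -> int:
    return a + b


@register("times")
def mul(a: int, b: int) -> int:
    return a * b


print(get("times")(4, 5))  # 20
print(names())             # ['plus', 'times']
```

"Private" here means unreachable through normal attribute access. The dictionary is still visible through `get.__closure__` to anyone who goes looking.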


## Problem 8 — `time_it_closure`

Build a timing wrapper **factory** `time_it_closure(label: str = '')` that returns a closure `wrap(func)`. `wrap` in turn returns a new function that times each call to `func(*args, **kwargs)`.

In other words: with `timed = time_it_closure('F:')(func)`, calling `timed(...)` prints an elapsed-time message with the label and returns the original result.

You may print to stdout; tests will not depend on exact timing values.

```python
def time_it_closure(label: str = '') -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def make_timer(func: Callable[..., Any]) -> Callable[..., Any]:
        def inner(*args: Any, **kwargs: Any) -> Any:
            start = perf_counter()
            result = func(*args, **kwargs)
            end = perf_counter()
            if label:
                print(f"{label} elapsed: {end - start}")
            else:
                print(f"elapsed: {end - start}")
            return result
        return inner
    return make_timer
```


```python
# Tests — do not modify (non-strict; just sanity)
def mul(a,b): return a*b
wrap = time_it_closure('F:')
timed_mul = wrap(mul)
res = timed_mul(2,3)
assert res == 6
print("\n✅ Problem 8 tests passed.")
```


F: elapsed: 1.00000761449337e-06

✅ Problem 8 tests passed.
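The same three-level structure also works with decorator syntax. The sketch below, `timed`, is a hypothetical variant and not the exercise's required API. It adds `functools.wraps` and a `try`/`finally` so the elapsed time is printed even when the wrapped function raises. The output format is an arbitrary choice.

```python
import functools
from time import perf_counter
from typing import Any, Callable


def timed(label: str = "") -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    prefix = f"{label} " if label else ""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)  # keep func's __name__ and __doc__ on the wrapper
        def inner(*args: Any, **kwargs: Any) -> Any:
            start = perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                # Runs both on normal return and when func raises.
                elapsed = perf_counter() - start
                print(f"{prefix}{func.__name__} elapsed: {elapsed:.6f} s")

        return inner

    return decorator


@timed("F:")
def mul(a: int, b: int) -> int:
    return a * b


result = mul(2, 3)  # prints a line starting with "F: mul elapsed:"
print(result)       # 6
```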


## Problem 9 — `stateful_map`

Implement `stateful_map(func)` returning a closure `apply(x)` that:
- calls `func(x, index)` where `index` is the number of times `apply` has been called so far (starting at 0),
- returns the function's result,
- and increments the internal counter.

This demonstrates closures maintaining *multiple* free variables.

```python
def stateful_map(func: Callable[[Any, int], Any]) -> Callable[[Any], Any]:
    idx = 0
    def apply(x: Any) -> Any:
        nonlocal idx
        result = func(x, idx)
        idx += 1
        return result
    return apply
```


```python
# Tests — do not modify
f = stateful_map(lambda x, i: (x, i))
assert [f('a'), f('b'), f('c')] == [('a',0), ('b',1), ('c',2)]
print("✅ Problem 9 tests passed.")
```


✅ Problem 9 tests passed.
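The counter does not have to be an integer rebound with `nonlocal`. The sketch below captures an `itertools.count()` iterator instead. Calling `next()` mutates the iterator in place, so no `nonlocal` is needed. `stateful_map_count` is an illustrative name. For a one-shot pass over a whole iterable, the built-in `enumerate` produces the same indices.

```python
from itertools import count
from typing import Any, Callable


def stateful_map_count(func: Callable[[Any, int], Any]) -> Callable[[Any], Any]:
    counter = count()  # yields 0, 1, 2, ...; next() advances it in place

    def apply(x: Any) -> Any:
        return func(x, next(counter))  # mutates the iterator, no rebinding

    return apply


g = stateful_map_count(lambda x, i: f"{i}:{x}")
print([g(ch) for ch in "abc"])                    # ['0:a', '1:b', '2:c']
print([f"{i}:{x}" for i, x in enumerate("abc")])  # same result in one pass
```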
