# Advanced Practice: Passing and Returning Functions

These exercises build skill with higher-order functions (functions that take or return other functions) using clean, production-ready patterns:

- currying/uncurrying
- function composition & pipelines
- memoization with full `*args`/`**kwargs` support
- retry wrappers using closures
- decorator factories (e.g., timing, tagging)
- predicate combinators (`all_of`, `any_of`, `negate`)
- stateful closures (counters)

👉 **Instructions**
- Implement where marked `# YOUR CODE HERE`.
- Do **not** change the test cells.
- Use only the standard library.


In [1]:
from __future__ import annotations
from typing import Callable, Any, Tuple, Dict, Iterable, List
from functools import wraps
import time


## Problem 1: Currying and Uncurrying (binary functions)

Implement `curry2(func)`, which turns a binary function `f(a, b)` into a curried version called as `f(a)(b)`, and `uncurry2(func)`, which turns a curried `g(a)(b)` back into a binary function called as `g(a, b)`.

**Notes**
- Preserve metadata with `functools.wraps`.
- The curried function should accept exactly one argument per call.


In [2]:
def curry2(func: Callable[[Any, Any], Any]) -> Callable[[Any], Callable[[Any], Any]]:
    """Curry a binary function: f(a,b) -> f(a)(b)."""
    @wraps(func)
    def curried(a):
        @wraps(func)
        def inner(b):
            return func(a, b)
        return inner
    return curried

def uncurry2(func: Callable[[Any], Callable[[Any], Any]]) -> Callable[[Any, Any], Any]:
    """Uncurry a curried binary function: g(a)(b) -> g(a,b)."""
    @wraps(func)
    def uncurried(a, b):
        return func(a)(b)
    return uncurried


In [3]:
# Tests: do not modify
def add(a, b):
    return a + b
c_add = curry2(add)
assert c_add(2)(3) == 5
u_add = uncurry2(c_add)
assert u_add(10, 7) == 17
print("✅ Problem 1 tests passed.")


✅ Problem 1 tests passed.
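`curry2` handles exactly two arguments. As an optional extension, here is a minimal sketch of an n-ary version. The name `curry_n` and its `arity` parameter are hypothetical, and when `arity` is inferred with `inspect.signature` the sketch assumes every parameter of `func` is a required positional parameter.

```python
from functools import wraps
from inspect import signature
from typing import Any, Callable, Optional, Tuple


def curry_n(func: Callable[..., Any], arity: Optional[int] = None) -> Callable[[Any], Any]:
    """Hypothetical n-ary generalization of curry2: f(a, b, c) -> f(a)(b)(c)."""
    if arity is None:
        # Assumption: every parameter is a required positional parameter.
        arity = len(signature(func).parameters)
    if arity < 1:
        raise ValueError("curry_n needs a function of at least one argument")

    def step(collected: Tuple[Any, ...]) -> Callable[[Any], Any]:
        @wraps(func)
        def take_one(arg: Any) -> Any:
            args = collected + (arg,)  # a new tuple, so partial results can be reused
            if len(args) == arity:
                return func(*args)     # all arguments collected: call through
            return step(args)          # otherwise wait for the next argument
        return take_one

    return step(())


def add3(a, b, c):
    return a + b + c


add_one = curry_n(add3)(1)              # partially applied and reusable
print(add_one(2)(3), add_one(10)(20))   # 6 31
```

Because each step builds a fresh tuple instead of mutating shared state, a partially applied function such as `add_one` can be called any number of times without interference.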


## Problem 2: Composition and Pipelines

Implement two helpers:
- `compose(*funcs)` returns a function that applies from **right to left** (`compose(f,g,h)(x) == f(g(h(x)))`).
- `pipe(*funcs)` returns a function that applies from **left to right** (`pipe(f,g,h)(x) == h(g(f(x)))`).

Both should handle single-arg functions and preserve metadata on the outer wrapper.


In [4]:
def compose(*funcs: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Compose functions right→left: compose(f, g, h)(x) == f(g(h(x)))."""
    if not funcs:
        raise ValueError("compose requires at least one function")
    def composed(x):
        val = x
        for f in reversed(funcs):
            val = f(val)
        return val
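    # Expose a readable name on the returned wrapper, built from the component functions.
    composed.__name__ = composed.__qualname__ = (
        "compose(" + ", ".join(getattr(f, "__name__", repr(f)) for f in funcs) + ")"
    )
    return composed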

def pipe(*funcs: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Compose functions left→right: pipe(f, g, h)(x) == h(g(f(x)))."""
    if not funcs:
        raise ValueError("pipe requires at least one function")
    def piped(x):
        val = x
        for f in funcs:
            val = f(val)
        return val
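    # Expose a readable name on the returned wrapper, built from the component functions.
    piped.__name__ = piped.__qualname__ = (
        "pipe(" + ", ".join(getattr(f, "__name__", repr(f)) for f in funcs) + ")"
    )
    return piped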


In [5]:
# Tests: do not modify
inc = lambda x: x + 1
dbl = lambda x: x * 2
sq = lambda x: x * x
assert compose(sq, dbl, inc)(3) == sq(dbl(inc(3))) == 64
assert pipe(inc, dbl, sq)(3) == sq(dbl(inc(3))) == 64
print("✅ Problem 2 tests passed.")


✅ Problem 2 tests passed.
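The loop-based `compose` and `pipe` above apply each function in turn. An equivalent formulation folds the functions pairwise with `functools.reduce`. This is a sketch; the names `compose_reduce` and `pipe_reduce` are hypothetical.

```python
from functools import reduce
from typing import Any, Callable

Fn = Callable[[Any], Any]


def compose_reduce(*funcs: Fn) -> Fn:
    """Right-to-left composition: compose_reduce(f, g, h)(x) == f(g(h(x)))."""
    if not funcs:
        raise ValueError("compose_reduce requires at least one function")
    # Fold each pair (f, g) into x -> f(g(x)); folding (f, g, h) gives x -> f(g(h(x))).
    return reduce(lambda f, g: lambda x: f(g(x)), funcs)


def pipe_reduce(*funcs: Fn) -> Fn:
    """Left-to-right composition: pipe_reduce(f, g, h)(x) == h(g(f(x)))."""
    if not funcs:
        raise ValueError("pipe_reduce requires at least one function")
    return compose_reduce(*reversed(funcs))


inc = lambda x: x + 1
dbl = lambda x: x * 2
print(compose_reduce(dbl, inc)(3), pipe_reduce(dbl, inc)(3))  # 8 7
```

The trade-off: the folded version nests one extra call frame per function, so a very long pipeline goes deeper into the call stack, while the loop version stays flat.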


## Problem 3: Memoization (decorator)

Implement `memoize(func)`, which caches results keyed **by the full call signature**: positional `args` plus `kwargs`. All argument values must be hashable.

**Rules**
- Use a dictionary with keys `(args, frozenset(kwargs.items()))`.
- Attach `cache_clear()` and `cache_info()` to the wrapped function; `cache_info()` reports the cache `size` and the `hits` and `misses` counts.
- Preserve function metadata with `@wraps`.


In [6]:
def memoize(func: Callable[..., Any]) -> Callable[..., Any]:
    cache: Dict[Tuple[Tuple[Any, ...], frozenset], Any] = {}
    hits = misses = 0
    @wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal hits, misses
        key = (args, frozenset(kwargs.items()))
        if key in cache:
            hits += 1
            return cache[key]
        misses += 1
        res = func(*args, **kwargs)
        cache[key] = res
        return res
    def cache_clear():
        nonlocal hits, misses
        cache.clear()
        hits = misses = 0
    def cache_info():
        return {"size": len(cache), "hits": hits, "misses": misses}
    wrapper.cache_clear = cache_clear  # type: ignore[attr-defined]
    wrapper.cache_info = cache_info    # type: ignore[attr-defined]
    return wrapper


In [7]:
# Tests: do not modify
calls = {"n": 0}
@memoize
def slow_add(a, b=0):
    calls["n"] += 1
    time.sleep(0.001)
    return a + b
assert slow_add(2, b=3) == 5
assert slow_add(2, b=3) == 5  # cached
info = slow_add.cache_info()
assert info["hits"] == 1 and info["misses"] == 1
slow_add.cache_clear()
assert slow_add.cache_info()["size"] == 0
print("✅ Problem 3 tests passed.")


✅ Problem 3 tests passed.
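Because `memoize` keys on the raw `args` tuple and the `kwargs` items, `slow_add(2, 3)` and `slow_add(2, b=3)` are cached as two separate entries even though they describe the same call. One way to normalize the key is to bind the arguments with `inspect.signature(...).bind` and apply defaults first. This is a sketch under stated assumptions: the name `memoize_normalized` is hypothetical, the decorated function must not take a `**kwargs` parameter (the collected dict would be unhashable), and all argument values must be hashable.

```python
import inspect
from functools import wraps
from typing import Any, Callable, Dict, Tuple


def memoize_normalized(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical memoize variant keyed on bound arguments with defaults applied."""
    sig = inspect.signature(func)
    cache: Dict[Tuple[Tuple[str, Any], ...], Any] = {}

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = sig.bind(*args, **kwargs)  # TypeError for an invalid call, as func would raise
        bound.apply_defaults()             # f(2) and f(2, 0) now bind identically
        # Parameter order follows the signature, so the call style no longer matters.
        key = tuple(bound.arguments.items())
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    return wrapper


seen_calls = []


@memoize_normalized
def add(a, b=0):
    seen_calls.append((a, b))
    return a + b


print(add(2, 3), add(2, b=3), add(a=2, b=3), len(seen_calls))  # 5 5 5 1
```

The cost is an extra `bind` on every call, so this variant pays off mainly when the wrapped function is expensive relative to argument binding.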


## Problem 4: Retry wrapper (closure)

Implement `with_retries(func, retries=3, delay=0.0)`, which returns a **new function** wrapping `func`. When `func` raises an exception, the wrapper retries up to `retries` more times (at most `retries + 1` attempts in total), sleeping `delay` seconds between attempts. If every attempt fails, it re-raises the last exception.

Use a closure to capture `retries` and `delay`.


In [8]:
def with_retries(func: Callable[..., Any], retries: int = 3, delay: float = 0.0) -> Callable[..., Any]:
    """Wrap func so it is attempted up to retries + 1 times before giving up."""
    if retries < 0:
        raise ValueError("retries must be >= 0")
    @wraps(func)
    def wrapped(*args, **kwargs):
        for attempt in range(retries + 1):
            try:
                return func(*args, **kwargs)
            except Exception:  # noqa: BLE001 (teaching example)
                if attempt == retries:
                    raise  # no attempts left: re-raise the last exception as-is
                if delay > 0:
                    time.sleep(delay)  # pause before the next attempt
    return wrapped


In [9]:
# Tests: do not modify
def flaky_factory(failures: int):
    state = {"left": failures}
    def fn(x):
        if state["left"] > 0:
            state["left"] -= 1
            raise RuntimeError("boom")
        return x * 2
    return fn

f = with_retries(flaky_factory(2), retries=2)
assert f(5) == 10
try:
    with_retries(flaky_factory(3), retries=2)(1)
    raise AssertionError("expected RuntimeError")
except RuntimeError:
    pass
print("✅ Problem 4 tests passed.")


✅ Problem 4 tests passed.
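`with_retries` retries on any `Exception` with a fixed pause. A common refinement is a decorator factory that retries only selected exception types and grows the pause after each failure (exponential backoff). The sketch below is one way to do it; `retry`, `backoff`, and `exceptions` are hypothetical names, not a standard-library API.

```python
import time
from functools import wraps
from typing import Any, Callable, Tuple, Type


def retry(
    retries: int = 3,
    delay: float = 0.0,
    backoff: float = 2.0,
    exceptions: Tuple[Type[BaseException], ...] = (Exception,),
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator factory: retry selected exceptions with exponential backoff."""
    if retries < 0:
        raise ValueError("retries must be >= 0")

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            wait = delay
            for attempt in range(retries + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:  # other exception types propagate immediately
                    if attempt == retries:
                        raise  # out of attempts: re-raise the last error
                    if wait > 0:
                        time.sleep(wait)
                    wait *= backoff  # pauses go delay, delay*backoff, delay*backoff**2, ...
        return wrapped

    return decorator


attempts = []


@retry(retries=2, exceptions=(ConnectionError,))
def fetch():
    attempts.append("try")
    if len(attempts) < 3:
        raise ConnectionError("transient")
    return "data"


print(fetch(), len(attempts))  # data 3
```

Errors outside `exceptions` propagate on the first failure, so a programming mistake such as a `ValueError` is not retried pointlessly.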


## Problem 5: Timing decorator **factory**

Implement `timeit_logger(print_fn=print)`, which returns a decorator. The decorator wraps a function, measures elapsed wall-clock time in seconds with `time.perf_counter()`, calls `print_fn(f"{func.__name__}: {elapsed:.6f}s")`, and returns the underlying result.

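Preserve metadata and do not swallow exceptions: the timing line is still logged when the wrapped function raises, and the exception propagates to the caller.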


In [10]:
def timeit_logger(print_fn: Callable[[str], None] = print) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed = time.perf_counter() - start
                print_fn(f"{func.__name__}: {elapsed:.6f}s")
        return wrapper
    return decorator


In [11]:
# Tests: do not modify
messages: List[str] = []
def capture(msg: str):
    # Print a fixed line instead of the measured time so the notebook output is deterministic
    if msg.startswith("sleepy:"):
        print("sleepy: 0.0100")  # deterministic print for the notebook
    messages.append(msg)

@timeit_logger(print_fn=capture)
def sleepy():
    time.sleep(0.01)
    return 42

assert sleepy() == 42
assert any(m.startswith("sleepy:") for m in messages)
print("✅ Problem 5 tests passed.")


sleepy: 0.0100
✅ Problem 5 tests passed.
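Because `timeit_logger` is a factory, it must always be called, as in `@timeit_logger()`, even when the defaults are fine. A widely used pattern lets one name work both bare (`@timed`) and with arguments (`@timed(print_fn=...)`) by checking whether a function was passed directly. The name `timed` is hypothetical; this is a sketch of the pattern, not a library API.

```python
import time
from functools import wraps
from typing import Any, Callable, Optional


def timed(func: Optional[Callable[..., Any]] = None, *,
          print_fn: Callable[[str], None] = print) -> Any:
    """Hypothetical timing decorator usable as @timed or @timed(print_fn=...)."""
    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            start = time.perf_counter()
            try:
                return f(*args, **kwargs)
            finally:
                # Runs on success and on exceptions; an exception still propagates.
                print_fn(f"{f.__name__}: {time.perf_counter() - start:.6f}s")
        return wrapper

    if func is not None:  # used bare: @timed
        return decorator(func)
    return decorator      # used with keyword arguments: @timed(print_fn=...)


log = []


@timed(print_fn=log.append)
def fail():
    raise ValueError("boom")


try:
    fail()
except ValueError:
    pass
print(log[0].startswith("fail: "))  # True
```

`print_fn` is keyword-only because a callable passed positionally would be indistinguishable from the function being decorated.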


## Problem 6: Predicate combinators

Implement three higher-order predicate helpers:
- `negate(p)` → returns a predicate that is true iff `p(x)` is false
- `all_of(*preds)` → returns a predicate that is true iff **all** predicates are true
- `any_of(*preds)` → returns a predicate that is true iff **any** predicate is true

They should accept functions of one argument and return a function of one argument.


In [12]:
def negate(p: Callable[[Any], bool]) -> Callable[[Any], bool]:
    return lambda x: not p(x)

def all_of(*preds: Callable[[Any], bool]) -> Callable[[Any], bool]:
    def combined(x):
        return all(p(x) for p in preds)
    return combined

def any_of(*preds: Callable[[Any], bool]) -> Callable[[Any], bool]:
    def combined(x):
        return any(p(x) for p in preds)
    return combined


In [13]:
# Tests: do not modify
is_even = lambda x: x % 2 == 0
is_pos  = lambda x: x > 0
is_odd  = negate(is_even)
assert is_odd(3) and not is_odd(4)
pos_even = all_of(is_pos, is_even)
assert pos_even(2) and not pos_even(-2) and not pos_even(3)
pos_or_even = any_of(is_pos, is_even)
assert pos_or_even(-2) and pos_or_even(1) and not pos_or_even(-3)
print("✅ Problem 6 tests passed.")


✅ Problem 6 tests passed.
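The combinators pay off when filtering data. Because `all()` and `any()` consume a generator lazily, `all_of` stops at the first false predicate, so later predicates can rely on earlier ones: below, `is_positive` would raise `TypeError` on a string or `None`, but `is_number` rejects those first. The helpers are repeated so the cell is self-contained; the predicate names and sample values are illustrative.

```python
from typing import Any, Callable

Pred = Callable[[Any], bool]


def negate(p: Pred) -> Pred:
    return lambda x: not p(x)


def all_of(*preds: Pred) -> Pred:
    # all() stops at the first False, so later predicates may assume earlier ones held.
    return lambda x: all(p(x) for p in preds)


def any_of(*preds: Pred) -> Pred:
    # any() stops at the first True.
    return lambda x: any(p(x) for p in preds)


is_number = lambda x: isinstance(x, (int, float))
is_positive = lambda x: x > 0  # raises TypeError for str or None on its own

positive_number = all_of(is_number, is_positive)
values = [3, -1, "abc", 2.5, None, 0]
print(list(filter(positive_number, values)))          # [3, 2.5]
print(list(filter(negate(positive_number), values)))  # [-1, 'abc', None, 0]
```

With no predicates, `all_of()` accepts everything and `any_of()` rejects everything, matching `all([])` and `any([])`.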


## Problem 7: Stateful counter (closure)

Implement `make_counter(start=0, step=1)`, which returns a **callable** `counter()`. Each call increments the internal state by `step` and returns the **new value**.

Also attach two functions to the callable: `reset(value=start)`, which sets the count (to `start` by default), and `value()`, which reads the current count **without** incrementing it.


In [14]:
def make_counter(start: int = 0, step: int = 1) -> Callable[[], int]:
    curr = start
    def counter():
        nonlocal curr
        curr += step
        return curr
    def reset(value: int = start):
        nonlocal curr
        curr = value
    def value():
        return curr
    counter.reset = reset  # type: ignore[attr-defined]
    counter.value = value  # type: ignore[attr-defined]
    return counter


In [15]:
# Tests: do not modify
c = make_counter(start=10, step=2)
assert c.value() == 10
assert c() == 12
assert c() == 14
c.reset()
assert c.value() == 10
c.reset(5)
assert c() == 7
print("✅ Problem 7 tests passed.")


✅ Problem 7 tests passed.
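For comparison, the same interface can be written as a class with `__call__`. State lives on the instance instead of in closure cells, and type checkers can see `reset` and `value` without the `type: ignore` comments that attribute assignment on a function needs. The class name `CallableCounter` is illustrative.

```python
from typing import Optional


class CallableCounter:
    """Class-based equivalent of make_counter(start, step)."""

    def __init__(self, start: int = 0, step: int = 1) -> None:
        self._start = start
        self._step = step
        self._curr = start

    def __call__(self) -> int:
        self._curr += self._step  # increment, then report the new value
        return self._curr

    def reset(self, value: Optional[int] = None) -> None:
        # None means "back to start", mirroring reset(value=start) in the closure version.
        self._curr = self._start if value is None else value

    def value(self) -> int:
        return self._curr


counter = CallableCounter(start=10, step=2)
counter()
counter()
print(counter.value())  # 14
```

Both versions keep each counter's state private to that counter; the closure is lighter, while the class is easier to extend and introspect.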


## Problem 8: Tagging decorator factory (returns a function)

Implement `tagged(prefix)`, which returns a decorator. The decorator wraps any function so that calling it returns a **tuple** `(prefix, result)`. Preserve metadata with `@wraps` and support arbitrary `*args`/`**kwargs`.


In [16]:
def tagged(prefix: str) -> Callable[[Callable[..., Any]], Callable[..., Tuple[str, Any]]]:
    def deco(func: Callable[..., Any]) -> Callable[..., Tuple[str, Any]]:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Tuple[str, Any]:
            return (prefix, func(*args, **kwargs))
        return wrapper
    return deco


In [17]:
# Tests: do not modify
@tagged("OK")
def greet(name: str, *, excited: bool = False) -> str:
    return f"Hello, {name}{'!' if excited else ''}"
p, msg = greet("Python", excited=True)
assert p == "OK" and msg == "Hello, Python!"
assert greet.__name__ == "greet"  # metadata preserved
print("✅ Problem 8 tests passed.")


✅ Problem 8 tests passed.
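Stacking decorators shows both how they compose and what `@wraps` contributes. Decorators apply bottom-up, so the one nearest the function wraps it first. `functools.wraps` also sets `__wrapped__` on each wrapper, and `inspect.unwrap` follows that chain back to the original function. `tagged` is repeated so the cell is self-contained; `hello` is an illustrative function.

```python
import inspect
from functools import wraps
from typing import Any, Callable, Tuple


def tagged(prefix: str) -> Callable[[Callable[..., Any]], Callable[..., Tuple[str, Any]]]:
    def deco(func: Callable[..., Any]) -> Callable[..., Tuple[str, Any]]:
        @wraps(func)  # copies __name__/__doc__ and sets wrapper.__wrapped__ = func
        def wrapper(*args: Any, **kwargs: Any) -> Tuple[str, Any]:
            return (prefix, func(*args, **kwargs))
        return wrapper
    return deco


@tagged("outer")  # applied second
@tagged("inner")  # applied first (closest to the function)
def hello(name: str) -> str:
    return f"hi {name}"


print(hello("Ada"))                  # ('outer', ('inner', 'hi Ada'))
print(hello.__name__)                # hello
print(inspect.unwrap(hello)("Ada"))  # hi Ada
```

Each layer's `__wrapped__` points one level down, so `hello.__wrapped__` is the `"inner"` wrapper and `inspect.unwrap` reaches the undecorated function.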
