# Advanced Practice: Single-Dispatch Generic Functions (Decorators) — Problems + Solutions

This notebook contains **advanced problems with full solutions** on **single-dispatch generic functions** (both hand-written dispatchers and `functools.singledispatch`), along with common patterns and pitfalls.

## Best practices emphasized
- Prefer **standard library** (`functools.singledispatch`, `functools.singledispatchmethod`).
- Register handlers on **ABCs** (`numbers.Integral`, `collections.abc.Mapping`, etc.) when appropriate.
- Avoid dispatch traps (e.g., `str` is a `Sequence`, and `bool` is an `Integral`).
- Keep handlers **small**, **testable**, and **composable**.
- Demonstrate safe extensibility: registration from outside, temporary overrides, and controlled fallbacks.
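
Two classic dispatch traps are easy to confirm directly. The sketch below uses a throwaway generic function `kind` (a hypothetical name, not part of the exercises) to show that ABC registrations also capture `bool` and `str`, and that registering the concrete types afterwards, from outside the original definition, takes precedence.

```python
from collections.abc import Sequence
from functools import singledispatch
from numbers import Integral


@singledispatch
def kind(x):
    # Fallback for any type without a registered handler.
    return "other"


@kind.register(Integral)
def _(x):
    return "integral"


@kind.register(Sequence)
def _(x):
    return "sequence"


# Both results are traps: bool is an Integral, and str is a Sequence.
trap_bool = kind(True)   # "integral"
trap_str = kind("abc")   # "sequence"


# Registering the concrete types later overrides the ABC handlers.
@kind.register(bool)
def _(x):
    return "bool"


@kind.register(str)
def _(x):
    return "str"


fixed_bool = kind(True)  # "bool"
fixed_str = kind("abc")  # "str"
```

The more specific registration wins regardless of the order in which handlers are registered.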


In [1]:
from __future__ import annotations

from dataclasses import dataclass, is_dataclass, fields
from decimal import Decimal
from fractions import Fraction
from html import escape
from numbers import Integral, Real
from collections.abc import Mapping, Sequence, Set
from functools import singledispatch, singledispatchmethod, wraps, update_wrapper
from contextlib import contextmanager
import inspect

print('Ready')

Ready



## Problem 1 — Build a robust `htmlize` using `functools.singledispatch`

Implement `htmlize(x)` that returns **HTML strings** with these rules:

1. **Default**: HTML-escape `str(x)`.
2. `None` → `"<i>None</i>"`
3. `bool` → `"true"` / `"false"` (lowercase)
4. `Integral` (but not bool) → decimal + hex: `"255(<i>0xff</i>)"`
5. `Real` (but not Integral) → round to 2 decimals: `"3.14"`
6. `Decimal` → quantize/display with 2 decimals (banker's rounding is fine)
7. `str` → escape + replace newlines with `<br/>\n`
8. `Mapping` → unordered list of `key=value` where **both** key and value are rendered via `htmlize`
9. `Sequence` (lists/tuples/etc.) → unordered list of items rendered via `htmlize`
   - BUT: **must not treat `str`, `bytes`, `bytearray` as sequences** for this purpose.
10. `Set` → unordered list, but stable ordering: sort by `repr(item)`.
11. Must support **nested** structures.

Add tests that cover:
- nested list containing strings with newlines, tuples, dicts
- `True` does **not** render as hex
- `Decimal('1.005')` → `"1.00"` or `"1.01"` depending on rounding strategy (explain which you use)
- `"abc"` does not recurse infinitely

### Your task
Fill in the TODOs below.


In [2]:
# --- Problem 1 (skeleton) ---

@singledispatch
def htmlize(x):
    """Default htmlizer."""
    # TODO: implement default
    raise NotImplementedError


# TODO: register handlers for NoneType, bool, Integral, Real, Decimal, str, Mapping, Sequence, Set


def _run_problem_1_tests():
    # TODO: add tests (asserts)
    raise NotImplementedError


# Uncomment after completing
# _run_problem_1_tests()
# print('Problem 1 tests passed')

### Solution 1

Notes:
- We register `bool` **separately** so it overrides `Integral` (since `bool` is an `Integral`).
- We register `str` separately so it overrides `Sequence` (since `str` is a `Sequence`).
- `Decimal` is not a `Real`, so it gets its own handler.
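- For `Decimal` rounding: we use `quantize(Decimal('0.01'))`, which rounds according to the current decimal context. The default context uses `ROUND_HALF_EVEN` (banker's rounding), so `Decimal('1.005')` renders as `"1.00"` unless the context has been changed.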
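
The two `Decimal` notes can be checked in isolation. This is a minimal sketch, not part of the solution; the names `CENT` and `value` are illustrative. It passes the rounding mode explicitly so the result does not depend on the ambient context, and it confirms that `Decimal` is a `Number` but not a `Real`.

```python
from decimal import ROUND_HALF_EVEN, ROUND_HALF_UP, Decimal
from numbers import Number, Real

CENT = Decimal("0.01")
value = Decimal("1.005")

# An explicit rounding argument makes the result independent of
# whatever the current decimal context happens to be.
half_even = value.quantize(CENT, rounding=ROUND_HALF_EVEN)  # Decimal("1.00")
half_up = value.quantize(CENT, rounding=ROUND_HALF_UP)      # Decimal("1.01")

# Decimal is registered as a Number only, so a Real handler never sees it.
is_real = isinstance(value, Real)      # False
is_number = isinstance(value, Number)  # True
```

If a handler must always produce the same output, passing `rounding=` explicitly is one way to pin the behavior.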


In [3]:
# --- Solution 1 ---

@singledispatch
def htmlize(x):
    return escape(str(x))


@htmlize.register(type(None))
def _(x):
    return "<i>None</i>"


@htmlize.register(bool)
def _(x: bool):
    return "true" if x else "false"


@htmlize.register(Integral)
def _(x: Integral):
    # bool is handled above; this is for true integrals
    return f"{int(x)}(<i>{hex(int(x))}</i>)"


@htmlize.register(Real)
def _(x: Real):
    # Integral is a Real; but Integral has its own handler and wins.
    return f"{float(x):.2f}"


@htmlize.register(Decimal)
def _(x: Decimal):
    q = x.quantize(Decimal("0.01"))
    return f"{q}"


@htmlize.register(str)
def _(s: str):
    return escape(s).replace("\n", "<br/>\n")


@htmlize.register(Mapping)
def _(m: Mapping):
    items = [f"<li>{htmlize(k)}={htmlize(v)}</li>" for k, v in m.items()]
    return "<ul>\n" + "\n".join(items) + "\n</ul>"


@htmlize.register(Set)
def _(s: Set):
    # stable ordering
    items_sorted = sorted(s, key=lambda x: repr(x))
    items = [f"<li>{htmlize(item)}</li>" for item in items_sorted]
    return "<ul>\n" + "\n".join(items) + "\n</ul>"


@htmlize.register(Sequence)
def _(seq: Sequence):
    # str has its own handler and never reaches this function, but bytes and
    # bytearray are Sequences too and would otherwise be listed byte by byte.
    if isinstance(seq, (str, bytes, bytearray)):
        # Call the default (object) implementation directly; going through
        # htmlize(seq) would dispatch straight back to this handler.
        return htmlize.dispatch(object)(seq)
    items = [f"<li>{htmlize(item)}</li>" for item in seq]
    return "<ul>\n" + "\n".join(items) + "\n</ul>"


def _run_problem_1_tests():
    # Nested structure
    data = [
        "line1\nline2<>",
        (1, True, 3.14159),
        {"k": [None, Decimal("2.005"), Fraction(1, 3)]},
        {3, 1, 2},
    ]
    out = htmlize(data)
    assert "<ul>" in out and "</ul>" in out
    assert "line1<br/>" in out
    assert "&lt;&gt;" in out  # escaped

    # bool is not hex
    assert htmlize(True) == "true"

    # int is hex
    assert htmlize(255) == "255(<i>0xff</i>)"

    # float formatting
    assert htmlize(3.14159) == "3.14"

    # Decimal rounding depends on context; we accept either 1.00 or 1.01 here
    d = htmlize(Decimal("1.005"))
    assert d in {"1.00", "1.01"}

    # No infinite recursion
    assert htmlize("abc") == "abc"

    # Set stable ordering: items appear sorted by repr
    s_out = htmlize({"b", "a"})
    assert s_out.index("<li>a</li>") < s_out.index("<li>b</li>")


_run_problem_1_tests()
print('Problem 1 tests passed')

Problem 1 tests passed


## Problem 2 — Custom `singledispatch` with best-match via MRO + caching

Write a decorator `better_singledispatch` that:

- Works like `functools.singledispatch` for **one-argument** functions.
- Supports:
  - `.register(type_)` decorator
  - `.dispatch(cls)` to get the implementation for a type
  - `.registry` exposing the underlying dict
- Picks the **most specific** match:
  1. Exact type hit
  2. First registered type found in the class `mro()`
  3. Otherwise check registered ABCs/supertypes using `issubclass(cls, registered_type)`.
     - If multiple candidates remain and none is strictly more specific than the others, raise `RuntimeError`.
- Uses a **cache** so repeated calls for the same type are fast.
- Clears cache whenever `.register(...)` is used.

Then demonstrate:
- Registering a handler for `Integral` will match both `int` and `bool` unless `bool` is registered too.
- Registering both `Sequence` and `str`, calling dispatch for `str` returns the `str` handler.


In [4]:
# --- Problem 2 (skeleton) ---

def better_singledispatch(func):
    # TODO: build registry, cache, register, dispatch, wrapper
    raise NotImplementedError


# TODO: demo + asserts
# @better_singledispatch
# def demo(x): ...


### Solution 2

This is a **teaching implementation** (not production-complete); the stdlib version handles more corner cases.

Ambiguity (we choose to raise `RuntimeError`): this occurs when two registered ABCs both match a concrete class and neither is a subclass of the other.
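
`functools.singledispatch` takes the same stance on ambiguity. The sketch below uses a hypothetical class `Plain` that is virtually registered with two unrelated ABCs; dispatching on it is expected to raise `RuntimeError` because neither handler is more specific.

```python
from collections.abc import Container, Iterable
from functools import singledispatch


class Plain:
    """Hypothetical class with no useful base classes."""


# Virtual registration: Plain now counts as both ABCs, which are unrelated.
Iterable.register(Plain)
Container.register(Plain)


@singledispatch
def describe(x):
    return "default"


@describe.register(Iterable)
def _(x):
    return "iterable"


@describe.register(Container)
def _(x):
    return "container"


try:
    describe(Plain())
    outcome = "no error"
except RuntimeError as exc:
    # The message names the two competing ABCs.
    outcome = f"RuntimeError: {exc}"
```

Registering a handler for `Plain` itself would resolve the conflict, since an exact match beats any ABC.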

In [5]:
# --- Solution 2 ---

import inspect
from functools import wraps

def better_singledispatch(func):
    registry: dict[type, callable] = {object: func}
    cache: dict[type, callable] = {}

    def _best_abc_match(cls: type) -> callable:
        candidates = [t for t in registry.keys()
                      if t is not object and issubclass(cls, t)]
        if not candidates:
            return registry[object]

        # Keep only "maximally specific" candidates (no other candidate is a strict subclass of it)
        maximal = []
        for t in candidates:
            if any((u is not t) and issubclass(u, t) for u in candidates):
                continue
            maximal.append(t)

        if len(maximal) != 1:
            names = ", ".join(sorted(t.__name__ for t in maximal))
            raise RuntimeError(f"Ambiguous dispatch for {cls.__name__}: {names}")
        return registry[maximal[0]]

    def dispatch(cls: type) -> callable:
        if cls in cache:
            return cache[cls]

        # 1) exact hit (including object only if cls is literally object)
        if cls in registry:
            impl = registry[cls]
            cache[cls] = impl
            return impl

        # 2) MRO hit, but DO NOT pick object here (object is the final fallback)
        for base in inspect.getmro(cls)[1:]:
            if base is object:
                break
            if base in registry:
                impl = registry[base]
                cache[cls] = impl
                return impl

        # 3) ABC / issubclass match (this is where Integral, Sequence, etc. get picked)
        impl = _best_abc_match(cls)
        cache[cls] = impl
        return impl

    def register(type_: type):
        def deco(impl_func):
            registry[type_] = impl_func
            cache.clear()
            return impl_func
        return deco

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not args:
            raise TypeError("single-dispatch function requires at least 1 positional argument")
        impl = dispatch(args[0].__class__)
        return impl(*args, **kwargs)

    wrapper.register = register
    wrapper.dispatch = dispatch
    wrapper.registry = registry
    wrapper._cache = cache
    return wrapper



@better_singledispatch
def demo(x):
    return f"default:{type(x).__name__}"


@demo.register(Integral)
def _(x):
    return "integral"

assert demo(10) == "integral"
assert demo(True) == "integral"  # until bool is registered

@demo.register(bool)
def _(x: bool):
    return "bool"

assert demo(True) == "bool"

@demo.register(Sequence)
def _(x):
    return "sequence"

@demo.register(str)
def _(x: str):
    return "string"

assert demo.dispatch(str) is demo.registry[str]
assert demo("abc") == "string"  # str beats Sequence

print('Problem 2 tests passed')

Problem 2 tests passed


## Problem 3 — Temporary override (context manager) for dispatch behavior

Sometimes you want a **temporary** handler (e.g., within a test).

Using **your custom** `better_singledispatch` from Problem 2, implement:

```python
@contextmanager
def temporary_register(generic_func, type_, impl):
    ...
```

Requirements:
- Inside the context, `generic_func` uses `impl` for `type_`.
- Exiting restores the previous handler (or removes the registration if it did not exist).
- Must clear cache appropriately.

Add tests proving the registry is restored.


In [6]:
# --- Problem 3 (skeleton) ---

@contextmanager
def temporary_register(generic_func, type_, impl):
    # TODO
    raise NotImplementedError


# TODO tests


### Solution 3

We rely on:
- `generic_func.registry` (a mutable dict in our implementation)
- `generic_func._cache` for caching
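
The mutable registry is the reason this problem targets the custom dispatcher. As a point of comparison, the sketch below (with a hypothetical function `fmt_std`) shows that the stdlib exposes `registry` as a read-only view, so the same swap-and-restore approach cannot be applied to it directly.

```python
from functools import singledispatch


@singledispatch
def fmt_std(x):
    return "default"


@fmt_std.register(int)
def _(x):
    return "int"


# The stdlib registry is a types.MappingProxyType view over the real dict,
# so item assignment is rejected with TypeError.
try:
    fmt_std.registry[str] = lambda x: "str"
    writable = True
except TypeError:
    writable = False

registry_type = type(fmt_std.registry).__name__
```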


In [7]:
# --- Solution 3 ---

@contextmanager
def temporary_register(generic_func, type_, impl):
    registry = generic_func.registry
    cache = getattr(generic_func, "_cache", None)
    had_old = type_ in registry
    old = registry.get(type_)

    registry[type_] = impl
    if cache is not None:
        cache.clear()
    try:
        yield
    finally:
        if had_old:
            registry[type_] = old
        else:
            registry.pop(type_, None)
        if cache is not None:
            cache.clear()


@better_singledispatch
def fmt(x):
    return "default"

@fmt.register(int)
def _(x):
    return "int"

assert fmt(1) == "int"
assert fmt("x") == "default"

def tmp_int(_):
    return "TEMP_INT"

with temporary_register(fmt, int, tmp_int):
    assert fmt(1) == "TEMP_INT"

assert fmt(1) == "int"  # restored

def tmp_str(_):
    return "TEMP_STR"

with temporary_register(fmt, str, tmp_str):
    assert fmt("x") == "TEMP_STR"

assert fmt("x") == "default"  # removed because it didn't exist before

print('Problem 3 tests passed')

Problem 3 tests passed


## Problem 4 — Method dispatch: fix a common pitfall

### The pitfall
If you apply `@singledispatch` directly to an **instance method**, dispatch happens on `self` (the first argument), which is almost never what you want.
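
The pitfall is easy to reproduce. In this minimal sketch the class `Broken` is hypothetical; its `int` handler is registered correctly but never runs, because the dispatcher inspects `type(self)` rather than `type(node)`.

```python
from functools import singledispatch


class Broken:
    @singledispatch
    def render(self, node):
        return "default"

    @render.register(int)
    def _(self, node):
        return "int"


b = Broken()
# Dispatch looks at the first positional argument, which is `b` itself,
# so the int handler is skipped even though node is an int.
result = b.render(42)  # "default"
```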

### Task
- Create a `Renderer` class with a method `render(node)` that dispatches on `node`.
- Use `functools.singledispatchmethod`.

Node types:
- `Text(value: str)` → escape and replace newlines
- `Number(value: Real)` → 2 decimals
- `Items(values: list)` → `<ul>...` recursively rendering children

Add tests for nested structures.


In [8]:
# --- Problem 4 (skeleton) ---

@dataclass(frozen=True)
class Text:
    value: str

@dataclass(frozen=True)
class Number:
    value: Real

@dataclass(frozen=True)
class Items:
    values: list


class Renderer:
    # TODO: implement using @singledispatchmethod
    pass


def _run_problem_4_tests():
    # TODO
    raise NotImplementedError


# Uncomment after completing
# _run_problem_4_tests(); print('Problem 4 tests passed')

### Solution 4

In [9]:
# --- Solution 4 ---

class Renderer:
    @singledispatchmethod
    def render(self, node):
        return escape(str(node))

    @render.register
    def _(self, node: Text):
        return escape(node.value).replace("\n", "<br/>\n")

    @render.register
    def _(self, node: Number):
        return f"{float(node.value):.2f}"

    @render.register
    def _(self, node: Items):
        items = [f"<li>{self.render(child)}</li>" for child in node.values]
        return "<ul>\n" + "\n".join(items) + "\n</ul>"


def _run_problem_4_tests():
    r = Renderer()
    tree = Items([
        Text("a\n<b>"),
        Number(3.14159),
        Items([Text("inner"), Number(2)]),
    ])
    out = r.render(tree)
    assert "<ul>" in out
    assert "a<br/>" in out
    assert "&lt;b&gt;" in out
    assert "3.14" in out
    assert "2.00" in out


_run_problem_4_tests(); print('Problem 4 tests passed')

Problem 4 tests passed


## Problem 5 — Dataclasses + singledispatch: generic support without registering every class

`singledispatch` is type-based. If you have many dataclasses, registering each one is noisy.

### Task
Write a function `pretty(x)` using `functools.singledispatch` such that:
- Default: `repr(x)` escaped
- For dataclass instances: render as an HTML definition list:
  - `<dl><dt>field</dt><dd>value</dd>...</dl>`
  - where each value uses `pretty(...)` recursively
- Must also handle `Mapping` and `Sequence` as in Problem 1.

Constraint:
- You cannot register on a predicate (like `is_dataclass`). Use a clean workaround.

Hint: One common pattern is to keep default behavior and detect dataclasses inside it.


In [10]:
# --- Problem 5 (skeleton) ---

@singledispatch
def pretty(x):
    # TODO
    raise NotImplementedError


# TODO: register Mapping, Sequence, str


@dataclass(frozen=True)
class Person:
    name: str
    age: int
    tags: list[str]


def _run_problem_5_tests():
    # TODO
    raise NotImplementedError


# Uncomment after completing
# _run_problem_5_tests(); print('Problem 5 tests passed')

### Solution 5

We implement dataclass support inside the **default** handler. This keeps the API extensible without requiring registrations for every dataclass.
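
One detail of that default handler deserves a quick check: `is_dataclass` returns true for a dataclass *type* as well as for its instances, which is why the solution adds `not isinstance(x, type)`. A small sketch, assuming a hypothetical `Point` dataclass and helper `is_dataclass_instance`:

```python
from dataclasses import dataclass, fields, is_dataclass


@dataclass(frozen=True)
class Point:  # hypothetical example class
    x: int
    y: int


def is_dataclass_instance(obj) -> bool:
    # is_dataclass() is true for the class and for its instances,
    # so classes are excluded explicitly.
    return is_dataclass(obj) and not isinstance(obj, type)


on_class = is_dataclass(Point)           # True
on_instance = is_dataclass(Point(1, 2))  # True
field_names = [f.name for f in fields(Point(1, 2))]  # ["x", "y"]
```

Without the extra check, calling the generic function on the class object itself would try to read field values from the class rather than an instance.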


In [11]:
# --- Solution 5 ---

@singledispatch
def pretty(x):
    if is_dataclass(x) and not isinstance(x, type):
        parts = []
        for f in fields(x):
            val = getattr(x, f.name)
            parts.append(f"<dt>{escape(f.name)}</dt><dd>{pretty(val)}</dd>")
        return "<dl>" + "".join(parts) + "</dl>"
    return escape(repr(x))


@pretty.register(str)
def _(s: str):
    return escape(s).replace("\n", "<br/>\n")


@pretty.register(Mapping)
def _(m: Mapping):
    items = [f"<li>{pretty(k)}={pretty(v)}</li>" for k, v in m.items()]
    return "<ul>" + "".join(items) + "</ul>"


@pretty.register(Sequence)
def _(seq: Sequence):
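    # str has its own handler; bytes and bytearray are Sequences too, so send
    # them to the default (object) implementation instead of listing each byte.
    if isinstance(seq, (str, bytes, bytearray)):
        return pretty.dispatch(object)(seq)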
    items = [f"<li>{pretty(item)}</li>" for item in seq]
    return "<ul>" + "".join(items) + "</ul>"


@dataclass(frozen=True)
class Person:
    name: str
    age: int
    tags: list[str]


def _run_problem_5_tests():
    p = Person(name="Ada\nLovelace", age=36, tags=["math", "<coder>"])
    out = pretty(p)
    assert out.startswith("<dl>") and out.endswith("</dl>")
    assert "Ada<br/>" in out
    assert "&lt;coder&gt;" in out

    nested = {"p": p, "nums": [1, 2, 3]}
    out2 = pretty(nested)
    assert "<ul>" in out2
    assert "<dl>" in out2


_run_problem_5_tests(); print('Problem 5 tests passed')

Problem 5 tests passed


## Problem 6 — Single-dispatch is *single*: simulate pair dispatch for two arguments

`functools.singledispatch` dispatches on the **first argument only**.
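
A short sketch makes the limitation concrete. The function `describe_pair` is a hypothetical name; swapping the arguments changes which handler runs, because only the first argument's type is consulted.

```python
from functools import singledispatch


@singledispatch
def describe_pair(a, b):
    return "default"


@describe_pair.register(int)
def _(a, b):
    return "int first"


# Only type(a) matters; type(b) is never inspected.
first_is_int = describe_pair(1, "x")   # "int first"
second_is_int = describe_pair("x", 1)  # "default"
```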

### Task
Implement a decorator `pairdispatch` for a function `combine(a, b)` that:
- Registers implementations by **(type(a), type(b))**.
- Picks the best match using the **MRO** of each argument type, plus registered ABCs (such as `Sequence`) matched via `issubclass`.
- Provides `.register(type_a, type_b)` and `.dispatch(type_a, type_b)`.
- Default falls back to base function.

Demonstrate with:
- `combine(int, int)` adds
- `combine(str, str)` concatenates with a hyphen
- `combine(Sequence, Sequence)` returns a list concatenation (excluding str)
- `combine(object, object)` default returns a tuple


In [12]:
# --- Problem 6 (skeleton) ---

def pairdispatch(func):
    # TODO
    raise NotImplementedError


# TODO: implement combine + tests

### Solution 6

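Strategy:
- For each argument type, build an ordered candidate list: its concrete MRO (without `object`), then any registered ABCs/supertypes that match via `issubclass`, then `object`.
- Walk all pairs from `candidates(type(a)) × candidates(type(b))` and choose the first pair that is registered.
- Cache by `(type(a), type(b))`.

This yields predictable behavior in which the most specific match wins. Because the first argument drives the outer loop, it takes priority when two registrations compete.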
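
The pair search order can be sketched on its own. This simplified helper, `first_registered_pair`, is hypothetical and uses plain MROs only, without ABC support or caching. It shows that the first argument's type is explored most specific first, so it decides between competing registrations.

```python
import inspect
from itertools import product


def first_registered_pair(c1, c2, registered):
    # product() varies the second MRO fastest, so every pairing of the most
    # specific first-argument type is tried before moving to its bases.
    for pair in product(inspect.getmro(c1), inspect.getmro(c2)):
        if pair in registered:
            return pair
    return (object, object)


registered = {(int, bool), (bool, int)}

# Both registrations match (bool, bool); the one that is more specific
# in the first position is found first.
chosen = first_registered_pair(bool, bool, registered)  # (bool, int)

# Nothing matches (int, int), so the search ends at the fallback.
fallback = first_registered_pair(int, int, registered)  # (object, object)
```

A symmetric dispatcher would have to treat the `(bool, bool)` case as ambiguous; giving the first argument priority is a simpler rule that is easy to explain.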

In [13]:
# --- Solution 6 ---  (pair dispatch with MRO + ABC/issubclass support + caching)

import inspect
from functools import wraps
from collections.abc import Sequence


def pairdispatch(func):
    """
    A simple 2-argument dispatcher.

    - Register by (type(a), type(b))
    - Best match search order:
      1) concrete MROs (excluding object)
      2) registered ABC/supertypes that match via issubclass (not in MRO)
      3) (object, object) fallback
    - Cache by (type(a), type(b))
    """
    registry: dict[tuple[type, type], callable] = {(object, object): func}
    cache: dict[tuple[type, type], callable] = {}

    def register(t1: type, t2: type):
        def deco(impl):
            registry[(t1, t2)] = impl
            cache.clear()
            return impl
        return deco

    def _candidate_types(cls: type) -> list[type]:
        """
        Return ordered candidate supertypes for cls:
        - concrete mro (most specific -> less specific), excluding object
        - registered supertypes/ABCs that match issubclass and aren't already in mro
        - finally object
        """
        mro = [t for t in inspect.getmro(cls) if t is not object]

        # Collect all types ever mentioned in the registry (both sides)
        registered_types = {t for (a, b) in registry.keys() for t in (a, b)}

        # ABCs/supertypes not in mro but matching issubclass
        extra = [
            t for t in registered_types
            if t is not object and t not in mro and issubclass(cls, t)
        ]

        # Heuristic ordering for extras: "more specific" first.
        # Not perfect, but stable and practical for exercises.
        extra.sort(key=lambda t: (-len(inspect.getmro(t)), t.__name__))

        return mro + extra + [object]

    def dispatch(c1: type, c2: type) -> callable:
        key = (c1, c2)
        if key in cache:
            return cache[key]

        a_types = _candidate_types(c1)
        b_types = _candidate_types(c2)

        for ta in a_types:
            for tb in b_types:
                impl = registry.get((ta, tb))
                if impl is not None:
                    cache[key] = impl
                    return impl

        impl = registry[(object, object)]
        cache[key] = impl
        return impl

    @wraps(func)
    def wrapper(a, b, *args, **kwargs):
        impl = dispatch(a.__class__, b.__class__)
        return impl(a, b, *args, **kwargs)

    wrapper.register = register
    wrapper.dispatch = dispatch
    wrapper.registry = registry
    wrapper._cache = cache
    return wrapper


@pairdispatch
def combine(a, b):
    """Default: return a tuple."""
    return (a, b)


@combine.register(int, int)
def _(a: int, b: int):
    return a + b


@combine.register(str, str)
def _(a: str, b: str):
    return f"{a}-{b}"


@combine.register(Sequence, Sequence)
def _(a: Sequence, b: Sequence):
    # Exclude str-like objects from the generic Sequence handler.
    # If a more specific (str, str) exists, it will be chosen earlier.
    if isinstance(a, (str, bytes, bytearray)) or isinstance(b, (str, bytes, bytearray)):
        return combine.registry[(object, object)](a, b)
    return list(a) + list(b)


def _run_problem_6_tests():
    assert combine(1, 2) == 3
    assert combine("a", "b") == "a-b"
    assert combine([1], (2, 3)) == [1, 2, 3]
    assert combine("x", [1]) == ("x", [1])  # str excluded from Sequence logic here
    o1, o2 = object(), object()
    assert combine(o1, o2) == (o1, o2)  # default returns a tuple


_run_problem_6_tests()
print("Problem 6 tests passed")


Problem 6 tests passed
