# Advanced Practice: Class Decorators & Monkey Patching (with Solutions)

This notebook contains **advanced, interview-style problems** about:

- class decorators (including decorator factories)
- safe monkey patching
- injecting methods & descriptors at runtime
- wrapping class methods **without breaking** `@staticmethod`, `@classmethod`, and `@property`
- implementing comparison operators safely (`NotImplemented`, no infinite recursion)
- preserving metadata with `functools.wraps`

Each exercise includes a **reference solution** and **assert-based tests**.

> Best practice: run each solution cell as soon as you finish it. The tests sit at the bottom of the same cell, so a cell that runs without raising has passed; a failing check raises `AssertionError`.
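As a refresher on the last bullet, here is a minimal sketch of what `functools.wraps` preserves and what is lost without it. The decorators `shout` and `shout_plain` are hypothetical names made up for this illustration.

```python
from functools import wraps
from typing import Any, Callable


def shout(fn: Callable[..., str]) -> Callable[..., str]:
    """Hypothetical decorator: upper-cases the wrapped function's result."""
    @wraps(fn)  # copies __name__, __qualname__, __doc__, ... and sets __wrapped__ = fn
    def wrapper(*args: Any, **kwargs: Any) -> str:
        return fn(*args, **kwargs).upper()
    return wrapper


def shout_plain(fn: Callable[..., str]) -> Callable[..., str]:
    """The same decorator without functools.wraps, for comparison."""
    def wrapper(*args: Any, **kwargs: Any) -> str:
        return fn(*args, **kwargs).upper()
    return wrapper


@shout
def greet(name: str) -> str:
    """Return a greeting."""
    return f"hello {name}"


@shout_plain
def greet_plain(name: str) -> str:
    """Return a greeting."""
    return f"hello {name}"


print(greet("ada"), greet.__name__, greet.__doc__)  # HELLO ADA greet Return a greeting.
print(greet_plain.__name__, greet_plain.__doc__)    # wrapper None
print(greet.__wrapped__("ada"))                     # hello ada
```

The `__wrapped__` attribute is also what `inspect.signature` follows, which is why wrapped methods can still report their original parameters.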

In [1]:
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from functools import wraps
import contextlib
import inspect
import threading
from typing import Any, Callable, Mapping, MutableMapping, Optional, Type, TypeVar

T = TypeVar("T")

In [2]:
def _fixed_clock(dt: datetime) -> Callable[[], datetime]:
    """Return a clock() function that always returns dt (useful for deterministic tests)."""
    def clock() -> datetime:
        return dt
    return clock

## Exercise 1 — `@debug_info`: inject a deterministic `debug()` method

Write a **decorator factory** `debug_info(...)` that adds a `debug()` instance method to a class.

### Requirements
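- Signature: `debug_info(*, clock=lambda: datetime.now(timezone.utc), tz=timezone.utc, include_private=False, override=False)`. The default clock is timezone-aware on purpose: bare `datetime.now()` returns naive *local* time.
- The injected `debug()` returns a **list of strings** describing:
  - `time: ...` (from the provided `clock()`, with a naive result treated as UTC, converted to `tz`)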
  - `class: ...` (class name)
  - `id: ...` (hex id of object)
  - then one line per attribute in **sorted order**: `name: value`
- If `include_private=False`, skip attributes whose names start with `_`.
- Must support both `__dict__` objects and `__slots__` objects.
- If the class already defines `.debug`, do not overwrite unless `override=True`.

### Stretch (optional)
- If attribute access raises, include `name: <error ...>` instead of crashing.
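The `__slots__` requirement has a subtlety worth seeing in isolation: slotted instances have no `__dict__`, and a subclass's `__slots__` lists only the slots it *adds*. The sketch below uses made-up classes `Base`/`Child` and a hypothetical helper `slot_names` to collect slot names across the whole MRO.

```python
from typing import Any


class Base:
    __slots__ = ("a",)


class Child(Base):
    __slots__ = ("b",)  # only the *new* slot; "a" is declared on Base


def slot_names(obj: Any) -> list[str]:
    """Hypothetical helper: slot names declared anywhere in the MRO."""
    names: list[str] = []
    for klass in type(obj).__mro__:
        slots = klass.__dict__.get("__slots__", ())
        if isinstance(slots, str):  # __slots__ = "x" is legal and means one slot
            slots = (slots,)
        names.extend(s for s in slots if s not in ("__dict__", "__weakref__"))
    return sorted(names)


c = Child()
c.a, c.b = 1, 2
print(hasattr(c, "__dict__"))  # False: there is no instance dict to inspect
print(Child.__slots__)         # ('b',)
print(slot_names(c))           # ['a', 'b']
```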

In [3]:
def debug_info(
    *,
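    clock: Callable[[], datetime] = lambda: datetime.now(timezone.utc),
    tz: timezone = timezone.utc,
    include_private: bool = False,
    override: bool = False,
) -> Callable[[Type[T]], Type[T]]:
    """Class decorator factory that injects a safe, deterministic debug() method.

    The default clock is timezone-aware because bare datetime.now() returns naive
    *local* time; a naive datetime from a custom clock is treated as UTC.
    """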

    def decorator(cls: Type[T]) -> Type[T]:
        if hasattr(cls, "debug") and not override:
            return cls

        def _iter_attr_names(obj: Any) -> list[str]:
            names: set[str] = set()

            # __dict__ attributes
            d = getattr(obj, "__dict__", None)
            if isinstance(d, dict):
                names.update(d.keys())

            # __slots__ attributes from every class in the MRO: a subclass's __slots__
            # lists only the slots it adds (each entry is a str or an iterable of str)
            for klass in type(obj).__mro__:
                slots = klass.__dict__.get("__slots__", ())
                if isinstance(slots, str):
                    slots = (slots,)
                for s in slots:
                    # slots may include __dict__ / __weakref__
                    if s in ("__dict__", "__weakref__"):
                        continue
                    names.add(s)

            # Filter private if requested
            if not include_private:
                names = {n for n in names if not n.startswith("_")}

            return sorted(names)

        def debug(self: Any) -> list[str]:
            ts = clock()
            # ensure tz-aware and converted
            if ts.tzinfo is None:
                ts = ts.replace(tzinfo=timezone.utc)
            ts = ts.astimezone(tz)

            out: list[str] = [
                f"time: {ts.isoformat()}",
                f"class: {self.__class__.__name__}",
                f"id: {hex(id(self))}",
            ]

            for name in _iter_attr_names(self):
                try:
                    value = getattr(self, name)
                    out.append(f"{name}: {value!r}")
                except Exception as ex:  # noqa: BLE001 (teaching context)
                    out.append(f"{name}: <error {type(ex).__name__}: {ex}>")

            return out

        cls.debug = debug  # type: ignore[attr-defined]
        return cls

    return decorator


# --- Tests (deterministic) ---
@debug_info(clock=_fixed_clock(datetime(2020, 1, 1, 12, 0, 0, tzinfo=timezone.utc)))
class A:
    def __init__(self) -> None:
        self.x = 1
        self._y = 2


a = A()
lines = a.debug()
assert lines[0] == "time: 2020-01-01T12:00:00+00:00"
assert "class: A" in lines[1]
assert lines[3:] == ["x: 1"]  # private skipped

@debug_info(clock=_fixed_clock(datetime(2020, 1, 1, 12, 0, 0)), include_private=True)
class Slots:
    __slots__ = ("x", "_y")
    def __init__(self) -> None:
        self.x = 10
        self._y = 20

s = Slots()
lines = s.debug()
assert lines[0].endswith("+00:00")  # naive assumed UTC, then made aware
assert "x: 10" in lines
assert "_y: 20" in lines

## Exercise 2 — `@safe_total_ordering`: derive missing comparisons safely

Implement a class decorator `safe_total_ordering` that:

- Requires the class to define `__eq__` and **at least one** of: `__lt__`, `__le__`, `__gt__`, `__ge__`.
- Adds the missing comparison dunders using **only dunder calls**, not operators, so you can inspect `NotImplemented`.
- Must return `NotImplemented` when the base operation returns `NotImplemented`.
- Must avoid infinite recursion when both objects disagree about comparison support.

### Notes
- You do **not** need to re-implement every nuance of `functools.total_ordering`, but your implementation should handle `NotImplemented` correctly.
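- Use identities like the ones below (pick a consistent set):
  - `a <= b` iff `a < b` or `a == b`
  - `a > b` iff not (`a < b` or `a == b`); this is equivalent to `b < a`, but evaluating `b < a` calls `other`'s method, which can call straight back into yours
  - `a >= b` iff `a > b` or `a == b`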
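The recursion risk in the reflected identity is easy to reproduce. `X`, `Y` and `SafeX` are toy classes made up for this sketch: `X` and `Y` each derive one operator by calling the *other* object's method, while `SafeX` derives `__lt__` only from its own `__gt__` and `__eq__`, in the spirit of what `functools.total_ordering` does.

```python
from typing import Any


class X:
    """Toy class: __gt__ is real; __lt__ is derived by reflecting onto `other`."""

    def __gt__(self, other: Any):
        return NotImplemented  # X cannot compare itself with Y

    def __lt__(self, other: Any):
        # "a < b iff b > a", evaluated by calling the *other* object's method
        return type(other).__gt__(other, self)


class Y:
    """Toy class: __lt__ is real; __gt__ is derived by reflecting onto `other`."""

    def __lt__(self, other: Any):
        return NotImplemented

    def __gt__(self, other: Any):
        # "a > b iff b < a", again delegating to the other object
        return type(other).__lt__(other, self)


class SafeX:
    """Derives __lt__ only from its *own* __gt__ and __eq__."""

    def __eq__(self, other: Any):
        return NotImplemented

    def __gt__(self, other: Any):
        return NotImplemented

    def __lt__(self, other: Any):
        r = type(self).__gt__(self, other)
        if r is NotImplemented:
            return NotImplemented
        eq = type(self).__eq__(self, other)
        if eq is NotImplemented:
            return NotImplemented
        return not r and not eq


def outcome(a: Any, b: Any) -> str:
    try:
        a < b
    except RecursionError:
        return "RecursionError"
    except TypeError:
        return "TypeError"
    return "ok"


print(outcome(X(), Y()))      # RecursionError: X.__lt__ -> Y.__gt__ -> X.__lt__ -> ...
print(outcome(SafeX(), Y()))  # TypeError: both sides end in NotImplemented, no loop
```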

In [4]:
def safe_total_ordering(cls: Type[T]) -> Type[T]:
    """A compact total-ordering decorator that respects NotImplemented.

    Requires __eq__ and at least one ordering method defined on the class itself.
    Every derived method calls only type(self)'s own dunders (never other's methods
    and never an operator), so two classes that cannot compare with each other end
    in NotImplemented (and a TypeError) instead of recursing forever.
    """
    if "__eq__" not in cls.__dict__:
        raise TypeError("safe_total_ordering requires __eq__ defined on the class")

    roots = [op for op in ("__lt__", "__le__", "__gt__", "__ge__") if op in cls.__dict__]
    if not roots:
        raise TypeError("safe_total_ordering requires one of __lt__/__le__/__gt__/__ge__")
    base = roots[0]

    def _call(self: Any, meth: str, other: Any):
        return getattr(type(self), meth)(self, other)

    def _lt_eq(self: Any, other: Any):
        """Return (self < other, self == other) from base + __eq__, or NotImplemented."""
        r = _call(self, base, other)
        if r is NotImplemented:
            return NotImplemented
        eq = _call(self, "__eq__", other)
        if eq is NotImplemented:
            return NotImplemented
        r, eq = bool(r), bool(eq)
        if base == "__lt__":
            lt = r
        elif base == "__le__":
            lt = r and not eq        # a < b  <=>  a <= b and a != b
        elif base == "__gt__":
            lt = not r and not eq    # a < b  <=>  not (a > b) and a != b
        else:  # base == "__ge__"
            lt = not r               # a < b  <=>  not (a >= b)
        return lt, eq

    # How each operator follows from (lt, eq) in a total order.
    rules: dict[str, Callable[[bool, bool], bool]] = {
        "__lt__": lambda lt, eq: lt,
        "__le__": lambda lt, eq: lt or eq,
        "__gt__": lambda lt, eq: not (lt or eq),
        "__ge__": lambda lt, eq: not lt,
    }

    def _make(name: str, rule: Callable[[bool, bool], bool]):
        def method(self: Any, other: Any):
            res = _lt_eq(self, other)
            if res is NotImplemented:
                return NotImplemented
            return rule(*res)

        method.__name__ = name
        method.__qualname__ = f"{cls.__qualname__}.{name}"
        return method

    for name, rule in rules.items():
        if name not in cls.__dict__:  # never overwrite a user-defined comparison
            setattr(cls, name, _make(name, rule))

    return cls


# --- Tests ---
@safe_total_ordering
class ByLen:
    def __init__(self, s: str) -> None:
        self.s = s
    def __repr__(self) -> str:
        return f"ByLen({self.s!r})"
    def __eq__(self, other: Any):
        if isinstance(other, ByLen):
            return len(self.s) == len(other.s)
        return NotImplemented
    def __lt__(self, other: Any):
        if isinstance(other, ByLen):
            return len(self.s) < len(other.s)
        return NotImplemented

a, b, c = ByLen("a"), ByLen("bb"), ByLen("cc")
assert a < b
assert b > a
assert b >= c
assert not (b < c) and b == c
assert a.__lt__(123) is NotImplemented
assert a.__ge__(123) is NotImplemented

# base-method variant: only __ge__ + __eq__
@safe_total_ordering
class RevInt:
    def __init__(self, x: int) -> None:
        self.x = x
    def __eq__(self, other: Any):
        if isinstance(other, RevInt):
            return self.x == other.x
        return NotImplemented
    def __ge__(self, other: Any):
        if isinstance(other, RevInt):
            return self.x <= other.x  # reversed ordering
        return NotImplemented

r1, r2 = RevInt(1), RevInt(2)
assert r1 >= r2  # because 1 <= 2 (reversed)
assert r1 > r2   # derived
assert not (r1 < r2)

## Exercise 3 — `@freeze_new_attributes`: allow mutation, forbid *new* attributes after `__init__`

Write a class decorator `freeze_new_attributes` that:

- Wraps `__init__` so that after initialization completes, the instance is "sealed".
- After sealing, attempts to set a **new** attribute name (not already present) should raise `AttributeError`.
- Existing attributes may still be modified (this is *not* full immutability).
- Must work with both `__dict__` and `__slots__`.
- Should not break inheritance (subclasses may call `super().__init__()`).

### Hints
- Keep a private flag such as `_sealed` on the instance: `False` while `__init__` runs, `True` once it returns.
- Use `object.__setattr__` to avoid recursion.
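Why the second hint matters: inside a custom `__setattr__`, calling `setattr(self, ...)` re-enters the same method. The sketch contrasts that with delegating to `object.__setattr__`. `Naive` and `Sealed` are made-up classes, and `Sealed` is a deliberately minimal version of the idea, not the full exercise solution.

```python
from typing import Any


class Naive:
    def __setattr__(self, name: str, value: Any) -> None:
        setattr(self, name, value)  # calls Naive.__setattr__ again, forever


class Sealed:
    """Minimal sketch: new attribute names are allowed only while _sealed is False."""

    def __init__(self, x: int) -> None:
        object.__setattr__(self, "_sealed", False)
        self.x = x
        object.__setattr__(self, "_sealed", True)

    def __setattr__(self, name: str, value: Any) -> None:
        if self._sealed and name not in self.__dict__:
            raise AttributeError(f"cannot add {name!r}")
        object.__setattr__(self, name, value)  # bypasses this method: no recursion


try:
    Naive().x = 1
    naive_result = "ok"
except RecursionError:
    naive_result = "RecursionError"
print(naive_result)   # RecursionError

s = Sealed(1)
s.x = 2               # allowed: "x" already exists
try:
    s.y = 3           # rejected: new name after sealing
    sealed_result = "added"
except AttributeError as exc:
    sealed_result = str(exc)
print(sealed_result)  # cannot add 'y'
```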

In [5]:
def freeze_new_attributes(cls: Type[T]) -> Type[T]:
    """Seal instances after __init__: existing attributes may change, new ones are forbidden.

    If instances have no __dict__ at all (every class in the MRO declares __slots__
    without "__dict__"), they are already sealed by design, so we leave the class unchanged.
    """

    def _has_instance_dict(klass: type) -> bool:
        # getattr(cls, "__slots__") would also see a *base* class's slots, so walk the MRO.
        for k in klass.__mro__:
            if k is object:
                continue
            slots = k.__dict__.get("__slots__")
            if slots is None:
                return True  # a class without __slots__ gives its instances a __dict__
            if "__dict__" in ((slots,) if isinstance(slots, str) else slots):
                return True
        return False

    if not _has_instance_dict(cls):
        # Slots-only: instances can't gain new attributes anyway.
        return cls

    orig_init = cls.__init__

    @wraps(orig_init)
    def __init__(self, *args, **kwargs):
        # Seal only when this wrapper is the outermost __init__, so a subclass whose own
        # __init__ calls super().__init__() can keep assigning attributes afterwards
        # (decorate the subclass too if its instances should be sealed).
        outermost = type(self).__init__ is __init__
        if outermost:
            self.__dict__["_sealed"] = False  # store flag in instance dict
        try:
            orig_init(self, *args, **kwargs)
        finally:
            if outermost:
                self.__dict__["_sealed"] = True

    def __setattr__(self, name: str, value: Any) -> None:
        if name == "_sealed":
            self.__dict__[name] = bool(value)
            return

        sealed = self.__dict__.get("_sealed", False)
        if not sealed:
            object.__setattr__(self, name, value)
            return

        # After sealing, allow names that already exist: instance dict, a set slot,
        # or a class-level attribute such as a method or property.
        if name in self.__dict__ or hasattr(self, name):
            object.__setattr__(self, name, value)
            return

        raise AttributeError(f"Cannot set new attribute {name!r} on sealed {self.__class__.__name__}")

    cls.__init__ = __init__  # type: ignore[assignment]
    cls.__setattr__ = __setattr__  # type: ignore[assignment]
    return cls


# --- Tests ---
@freeze_new_attributes
class User:
    def __init__(self, name: str) -> None:
        self.name = name

u = User("Ada")
u.name = "Grace"  # ok
try:
    u.age = 10
    raise AssertionError("Expected AttributeError for new attribute")
except AttributeError:
    pass


# Slots-only class: already sealed; decorator should not break it.
@freeze_new_attributes
class SlotUser:
    __slots__ = ("name",)
    def __init__(self, name: str) -> None:
        self.name = name

su = SlotUser("Ada")
su.name = "Grace"  # ok
try:
    su.age = 10  # type: ignore[attr-defined]
    raise AssertionError("Expected AttributeError for new slot")
except AttributeError:
    pass


# Slots + dict: decorator should seal dynamic attributes after init.
@freeze_new_attributes
class SlotDictUser:
    __slots__ = ("name", "__dict__")
    def __init__(self, name: str) -> None:
        self.name = name
        self.extra = 1  # allowed during init

sdu = SlotDictUser("Ada")
sdu.extra = 2  # ok (existing)
try:
    sdu.new = 99
    raise AssertionError("Expected AttributeError for new attribute post-init")
except AttributeError:
    pass

## Exercise 4 — `@log_methods`: wrap methods without breaking descriptors

Create a decorator factory `log_methods(logger=print, *, skip_dunder=True)` that wraps methods on a class so every call logs:

`<ClassName>.<method>(args..., kwargs...) -> <return repr>`

### Requirements
- Wrap **instance methods**, `@classmethod`s, and `@staticmethod`s.
- Do **not** wrap `@property` objects (leave them untouched).
- If `skip_dunder=True`, skip any attribute whose name starts and ends with `__`.
- Use `functools.wraps` so metadata (e.g., `__name__`, `__doc__`) is preserved.
- Avoid double-wrapping if the decorator is applied twice.

### Hint
- `classmethod` and `staticmethod` store the underlying function in `.__func__`.
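Because the decorator iterates `cls.__dict__`, it sees the raw descriptor objects rather than what attribute access returns. The sketch below, with a made-up `Demo` class, shows where `.__func__` lives and why `callable(attr)` alone is too broad a test: a nested class is callable too.

```python
import inspect


class Demo:
    def method(self) -> str:
        return "m"

    @staticmethod
    def smethod() -> str:
        return "s"

    @classmethod
    def cmethod(cls) -> str:
        return cls.__name__

    @property
    def prop(self) -> str:
        return "p"

    class Nested:  # callable, but not a method
        pass


raw = vars(Demo)  # the class namespace, before any descriptor __get__ runs
print(type(raw["smethod"]).__name__, type(Demo.smethod).__name__)  # staticmethod function
print(raw["cmethod"].__func__(Demo))                               # Demo
print(callable(raw["Nested"]), inspect.isfunction(raw["Nested"]))  # True False
print(inspect.isfunction(raw["method"]), isinstance(raw["prop"], property))  # True True
```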

In [6]:
def log_methods(
    logger: Callable[[str], None] = print,
    *,
    skip_dunder: bool = True,
) -> Callable[[Type[T]], Type[T]]:
    """Wrap methods so calls are logged, while preserving descriptors and metadata."""

    def _make_wrapper(cls_name: str, name: str, fn: Callable[..., Any], *, bound: bool) -> Callable[..., Any]:
        # bound=True: the first positional argument is self/cls, which is not logged
        @wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            result = fn(*args, **kwargs)
            shown = args[1:] if bound else args
            parts = [repr(a) for a in shown] + [f"{k}={v!r}" for k, v in kwargs.items()]
            logger(f"{cls_name}.{name}({', '.join(parts)}) -> {result!r}")
            return result

        # Marker checked by later applications to avoid double wrapping.
        wrapper.__logged__ = True  # type: ignore[attr-defined]
        return wrapper

    def decorator(cls: Type[T]) -> Type[T]:
        for name, attr in list(cls.__dict__.items()):
            if skip_dunder and name.startswith("__") and name.endswith("__"):
                continue
            if isinstance(attr, staticmethod):
                fn, rewrap, bound = attr.__func__, staticmethod, False
            elif isinstance(attr, classmethod):
                fn, rewrap, bound = attr.__func__, classmethod, True
            elif inspect.isfunction(attr):
                # Plain functions only: properties, nested classes (which are callable)
                # and other descriptors are left untouched.
                fn, rewrap, bound = attr, None, True
            else:
                continue
            if getattr(fn, "__logged__", False):
                continue  # already wrapped by an earlier application
            wrapper = _make_wrapper(cls.__name__, name, fn, bound=bound)
            setattr(cls, name, rewrap(wrapper) if rewrap is not None else wrapper)
        return cls

    return decorator


# --- Tests (capture logs) ---
_logs: list[str] = []

def _capture(msg: str) -> None:
    _logs.append(msg)

@log_methods(_capture)
class Calc:
    def add(self, a: int, b: int) -> int:
        return a + b

    @staticmethod
    def mul(a: int, b: int) -> int:
        return a * b

    @classmethod
    def make(cls, x: int) -> "Calc":
        inst = cls()
        inst.x = x
        return inst

    @property
    def x2(self) -> int:
        return self.x * 2  # type: ignore[attr-defined]

c = Calc()
assert c.add(1, 2) == 3
assert Calc.mul(2, 3) == 6
inst = Calc.make(10)
assert inst.x2 == 20

assert any("Calc.add" in m for m in _logs)
assert any("Calc.mul" in m for m in _logs)
assert any("Calc.make" in m for m in _logs)
assert not any("x2" in m for m in _logs)  # property not wrapped

# Metadata preserved
assert Calc.add.__name__ == "add"
assert "a" in str(inspect.signature(Calc.add))

## Exercise 5 — `patch_attr`: safe monkey patching as a context manager

Implement `patch_attr(obj, name, value)` as a context manager that temporarily sets `obj.name = value` and restores the original value on exit.

### Requirements
- Works for both classes and instances.
- If the attribute didn't exist, remove it on exit.
- Must restore even if an exception is raised inside the context.
- Do not patch dunder attributes (raise `ValueError` if `name` looks like `__dunder__`).

### Example
```python
with patch_attr(SomeClass, "method", new_method):
    ...
# old method is back
```
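Restoring "the original value" hides a trap: `getattr` returns the value *after* descriptor binding, and it also finds inherited or class-level attributes the object does not own. With a made-up `Tools` class, the sketch shows that restoring a `staticmethod` through its `getattr` value silently turns it into an ordinary method, while the raw entry from `vars(obj)` round-trips cleanly.

```python
class Tools:
    @staticmethod
    def version() -> str:
        return "1.0"


looked_up = getattr(Tools, "version")  # descriptor already applied: a plain function
own = vars(Tools)["version"]           # the raw staticmethod object stored on the class
print(type(looked_up).__name__, type(own).__name__)  # function staticmethod

# "Restoring" the getattr() value makes version an ordinary method,
# so calling it on an instance passes self as an unexpected argument.
Tools.version = looked_up
try:
    Tools().version()
    restored_ok = True
except TypeError:
    restored_ok = False
print(restored_ok)                     # False

Tools.version = own                    # putting the raw object back is lossless
print(Tools().version())               # 1.0
```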

In [7]:
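@contextlib.contextmanager
def patch_attr(obj: Any, name: str, value: Any):
    if name.startswith("__") and name.endswith("__"):
        raise ValueError("Refusing to patch dunder attributes")

    sentinel = object()
    # Save the object's *own* attribute, not what getattr() finds: getattr() applies
    # descriptors (a staticmethod comes back as a plain function, a method as a bound
    # method) and also sees inherited/class-level values that must not be copied onto obj.
    try:
        old = vars(obj).get(name, sentinel)
    except TypeError:  # no __dict__ (e.g. a __slots__ instance): fall back to getattr
        old = getattr(obj, name, sentinel)
    had = old is not sentinel

    setattr(obj, name, value)
    try:
        yield
    finally:
        if had:
            setattr(obj, name, old)
        else:
            # attribute was created (or shadowed) by us; remove it
            try:
                delattr(obj, name)
            except AttributeError:
                pass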


# --- Tests ---
class Greeter:
    def hi(self) -> str:
        return "hi"

g = Greeter()

def new_hi(self) -> str:
    return "HELLO"

assert g.hi() == "hi"
with patch_attr(Greeter, "hi", new_hi):
    assert g.hi() == "HELLO"
assert g.hi() == "hi"

with patch_attr(Greeter, "new_attr", 123):
    assert Greeter.new_attr == 123  # type: ignore[attr-defined]
assert not hasattr(Greeter, "new_attr")

## Exercise 6 — `@auto_repr`: generate `__repr__` from `__init__` signature

Write a class decorator `auto_repr` that generates `__repr__` if the class doesn't already define one.

### Requirements
- Use `inspect.signature(cls.__init__)` to determine parameter names (excluding `self`).
- The generated repr should look like: `ClassName(x=..., y=...)` in the same parameter order.
- Fetch values from attributes with the same names (e.g., `self.x` for parameter `x`).
- If an attribute is missing, show `<missing>`.
- Must preserve an explicitly defined `__repr__` (do nothing if present in `cls.__dict__`).
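Before generating a repr it helps to look at what `inspect.signature(cls.__init__)` actually reports. It includes `self` and any `*args`/`**kwargs`, and a class without its own `__init__` reports `object.__init__`. `Point`, `NoInit` and the helper `field_names` are hypothetical names for this sketch; dropping the first parameter by position and skipping variadic parameters is one reasonable policy, not the only one.

```python
import inspect


class Point:
    def __init__(self, x: int, y: int = 0, *extra: int, label: str = "", **opts: str) -> None:
        self.x, self.y, self.label = x, y, label


class NoInit:
    pass


def field_names(cls: type) -> list[str]:
    """Hypothetical helper: named __init__ parameters, skipping self, *args and **kwargs."""
    params = list(inspect.signature(cls.__init__).parameters.values())[1:]  # drop self
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    return [p.name for p in params if p.kind not in variadic]


print(list(inspect.signature(Point.__init__).parameters))  # ['self', 'x', 'y', 'extra', 'label', 'opts']
print(field_names(Point))                                  # ['x', 'y', 'label']
print(field_names(NoInit))                                 # []
```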

In [8]:
def auto_repr(cls: Type[T]) -> Type[T]:
    """Generate a repr based on __init__ parameters, unless __repr__ is already defined."""
    if "__repr__" in cls.__dict__:
        return cls

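    # Drop the first parameter (self) by position rather than by name, and skip *args/**kwargs:
    # a class without its own __init__ reports object.__init__(self, /, *args, **kwargs).
    params = list(inspect.signature(cls.__init__).parameters.values())[1:]
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    param_names = [p.name for p in params if p.kind not in variadic]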
    _MISSING = object()

    def __repr__(self) -> str:
        parts = []
        for n in param_names:
            v = getattr(self, n, _MISSING)
            if v is _MISSING:
                parts.append(f"{n}=<missing>")
            else:
                parts.append(f"{n}={v!r}")
        return f"{self.__class__.__name__}({', '.join(parts)})"

    cls.__repr__ = __repr__  # type: ignore[assignment]
    return cls


# --- Tests ---
@auto_repr
class Pair:
    def __init__(self, x: int, y: int = 0) -> None:
        self.x = x
        self.y = y

p = Pair(1, 2)
assert repr(p) == "Pair(x=1, y=2)"

@auto_repr
class Weird:
    def __init__(self, x: int, y: int) -> None:
        self.x = x  # forget y

w = Weird(1, 2)
assert "y=<missing>" in repr(w)

# Ensure manual __repr__ is respected
class Manual:
    def __init__(self, x: int) -> None:
        self.x = x
    def __repr__(self) -> str:
        return "MANUAL"
Manual2 = auto_repr(Manual)
assert repr(Manual2(1)) == "MANUAL"

## Exercise 7 — `@validate_init`: type-check selected `__init__` arguments using signature binding

Create a decorator factory `validate_init(spec)` where `spec` is a mapping of parameter name -> expected type or tuple of types.

### Requirements
- Works as a **class decorator**: `@validate_init({"x": int, "name": str})`.
- Wrap `__init__` and use `inspect.signature(...).bind_partial(...)` (or `bind`) to map args/kwargs to parameter names.
- If a specified parameter is provided and its value is not an instance of the expected type(s), raise `TypeError`
  with a message containing the parameter name.
- If parameter is omitted (uses default), do not error.
- Preserve `__init__` metadata (`wraps`).

### Note
This is runtime validation (not static typing).
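Two details of signature binding shape this exercise. `bind_partial` records only the arguments that were actually passed, which is exactly why omitted defaults are not validated, and `isinstance` treats `bool` as an `int`. The function `make` and the helper `strict_int` below are hypothetical, for illustration only.

```python
import inspect


def make(x: int, name: str = "anon", *, flag: bool = False) -> None:
    pass


sig = inspect.signature(make)
bound = sig.bind_partial(5, flag=True)
before = dict(bound.arguments)
print(before)                  # {'x': 5, 'flag': True}  (defaults are not filled in)

bound.apply_defaults()
print(dict(bound.arguments))   # {'x': 5, 'name': 'anon', 'flag': True}

# bool is a subclass of int, so an isinstance-based check accepts True for an int parameter
print(isinstance(True, int))   # True


def strict_int(value: object) -> bool:
    """Hypothetical stricter check that rejects bools for int parameters."""
    return isinstance(value, int) and not isinstance(value, bool)
```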

In [9]:
def validate_init(spec: Mapping[str, type | tuple[type, ...]]) -> Callable[[Type[T]], Type[T]]:
    def decorator(cls: Type[T]) -> Type[T]:
        orig_init = cls.__init__
        sig = inspect.signature(orig_init)

        @wraps(orig_init)
        def __init__(self, *args, **kwargs):
            bound = sig.bind_partial(self, *args, **kwargs)
            # drop self
            arguments = {k: v for k, v in bound.arguments.items() if k != "self"}

            for name, expected in spec.items():
                if name in arguments:
                    val = arguments[name]
                    if not isinstance(val, expected):
                        exp_name = (
                            expected.__name__
                            if isinstance(expected, type)
                            else " or ".join(t.__name__ for t in expected)
                        )
                        raise TypeError(f"__init__ arg {name!r} must be {exp_name}, got {type(val).__name__}")
            orig_init(self, *args, **kwargs)

        cls.__init__ = __init__  # type: ignore[assignment]
        return cls
    return decorator


# --- Tests ---
@validate_init({"x": int, "name": (str, type(None))})
class Thing:
    def __init__(self, x: int, name: Optional[str] = None) -> None:
        self.x = x
        self.name = name

Thing(1)
Thing(2, "ok")
try:
    Thing("bad")  # type: ignore[arg-type]
    raise AssertionError("Expected TypeError")
except TypeError as e:
    assert "x" in str(e)

## Exercise 8 — `@register`: build a plugin registry with a class decorator

Implement a decorator factory `register(registry, *, key=None)` that:

- Adds the class to a mutable mapping `registry` under `key` (or `cls.__name__` if key is `None`).
- Raises `KeyError` if the key is already taken.
- Returns the class unchanged (so normal instantiation works).

### Why this matters
Class decorators are commonly used to build plugin systems without requiring manual registration calls.
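To make the payoff concrete, here is the consumer side of such a registry. It is a sketch under simplifying assumptions: `PARSERS`, `plugin` (a stripped-down stand-in for `register`, with no duplicate check) and `parse` are all hypothetical names.

```python
from typing import Any, Callable, Type

PARSERS: dict[str, Type[Any]] = {}  # hypothetical registry for this sketch


def plugin(key: str) -> Callable[[Type[Any]], Type[Any]]:
    """Simplified stand-in for register(): record the class under `key`."""
    def decorator(cls: Type[Any]) -> Type[Any]:
        PARSERS[key] = cls
        return cls
    return decorator


@plugin("csv")
class CsvParser:
    def parse(self, text: str) -> list[str]:
        return text.split(",")


@plugin("lines")
class LineParser:
    def parse(self, text: str) -> list[str]:
        return text.splitlines()


def parse(fmt: str, text: str) -> list[str]:
    # Registration already happened when each class statement executed,
    # so dispatch is a plain dictionary lookup with no central list to maintain.
    try:
        parser_cls = PARSERS[fmt]
    except KeyError:
        raise ValueError(f"no parser registered for {fmt!r}") from None
    return parser_cls().parse(text)


print(sorted(PARSERS))        # ['csv', 'lines']
print(parse("csv", "a,b,c"))  # ['a', 'b', 'c']
```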

In [10]:
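def register(
    registry: MutableMapping[str, Type[Any]], *, key: Optional[str] = None
) -> Callable[[Type[Any]], Type[Any]]:
    def decorator(cls: Type[Any]) -> Type[Any]:
        k = key if key is not None else cls.__name__  # `key or ...` would silently ignore key=""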
        if k in registry:
            raise KeyError(f"Duplicate registration for key {k!r}")
        registry[k] = cls
        return cls
    return decorator


# --- Tests ---
REG: dict[str, Type[Any]] = {}

@register(REG)
class CSVParser:
    pass

@register(REG, key="json")
class JSONParser:
    pass

assert REG["CSVParser"] is CSVParser
assert REG["json"] is JSONParser

try:
    @register(REG, key="json")
    class Other:  # noqa: F811
        pass
    raise AssertionError("Expected KeyError")
except KeyError:
    pass

## Exercise 9 — `@singleton`: enforce exactly one instance (thread-safe)

Write a class decorator `singleton` that turns a class into a singleton:

- Calling `Cls(...)` always returns the same instance.
- `__init__` should run only once (first call).
- Must be **thread-safe** (use a lock).
- Do not break `isinstance(obj, Cls)`.

### Hint
A common pattern is to override `__new__` and guard initialization with a flag.
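The initialization guard is needed because `type.__call__` runs `__init__` on whatever `__new__` returns, as long as it is an instance of the class. A made-up `Cached` class that reuses one instance in `__new__` but has no guard shows the effect:

```python
from typing import Any


class Cached:
    """Made-up class: __new__ reuses one instance but does nothing about __init__."""
    _instance = None
    init_calls = 0

    def __new__(cls, *args: Any, **kwargs: Any) -> "Cached":
        if cls._instance is None:
            cls._instance = super().__new__(cls)  # object.__new__ gets no extra args
        return cls._instance

    def __init__(self, value: int) -> None:
        type(self).init_calls += 1
        self.value = value


a = Cached(1)
b = Cached(2)
print(a is b, Cached.init_calls, a.value)  # True 2 2  (the second call re-ran __init__)
```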

In [11]:
def singleton(cls: Type[T]) -> Type[T]:
    lock = threading.Lock()
    orig_new = cls.__new__
    orig_init = cls.__init__
    instance_attr = "_singleton_instance"
    init_attr = "_singleton_initialized"

    def __new__(c, *args, **kwargs):
        # vars(c) rather than getattr(c, ...): a subclass must not inherit the base's instance
        inst = vars(c).get(instance_attr)
        if inst is not None:
            return inst
        with lock:
            inst = vars(c).get(instance_attr)  # double-checked locking
            if inst is None:
                # object.__new__ rejects extra arguments once __new__ is overridden,
                # but a user-defined __new__ may need them
                if orig_new is object.__new__:
                    inst = orig_new(c)
                else:
                    inst = orig_new(c, *args, **kwargs)
                setattr(c, instance_attr, inst)
        return inst

    @wraps(orig_init)
    def __init__(self, *args, **kwargs):  # type: ignore[override]
        if getattr(self, init_attr, False):
            return
        with lock:
            if getattr(self, init_attr, False):
                return
            orig_init(self, *args, **kwargs)
            setattr(self, init_attr, True)

    # __new__ assigned after class creation is not converted to a staticmethod automatically
    cls.__new__ = staticmethod(__new__)  # type: ignore[assignment]
    cls.__init__ = __init__  # type: ignore[assignment]
    return cls


# --- Tests ---
@singleton
class Config:
    def __init__(self, value: int) -> None:
        self.value = value

c1 = Config(1)
c2 = Config(2)
assert c1 is c2
assert c2.value == 1  # __init__ ran only once
assert isinstance(c1, Config)