In [None]:
# | default_exp _components.basics

In [None]:
# | export

import builtins
import copy as cp
import functools
import sys
import types

from types import FunctionType, MethodType, UnionType
from typing import Union, TypeVar, Any, Callable, Type
from functools import partial

# Fastcore dependencies

In [None]:
# |exporti


def test_eq(a: Any, b: Any) -> None:
    "`test` that `a==b`"
    if a != b:
        raise ValueError(f"{a} != {b}")

## Patching

> copied from https://github.com/fastai/fastcore/blob/master/nbs/01_basics.ipynb

In [None]:
# |exporti
F = TypeVar("F", bound=Callable[..., Any])


def copy_func(f: FunctionType) -> FunctionType:
    """Copy a non-builtin function (NB `copy.copy` does not work for this).

    A new `FunctionType` is built from `f`'s code, globals, name, defaults and
    closure, and `f`'s metadata is then copied onto it. Anything that is not a
    plain Python function is delegated to `copy.copy`.

    Args:
        f: the function to copy.

    Returns:
        A distinct function object that behaves like `f`, whose attributes
        (e.g. `__qualname__`, keyword-only defaults) can be changed without
        affecting `f`.
    """
    if not isinstance(f, FunctionType):
        return cp.copy(f)
    fn = FunctionType(
        f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__
    )
    # Copy the dict rather than sharing it: otherwise editing the copy's
    # keyword-only defaults would also change the original's.
    fn.__kwdefaults__ = None if f.__kwdefaults__ is None else dict(f.__kwdefaults__)
    # The constructor call above does not take these from `f`, so carry them
    # over explicitly (keeps values assigned to `f` after its definition).
    fn.__doc__ = f.__doc__
    fn.__module__ = f.__module__
    fn.__dict__.update(f.__dict__)
    fn.__annotations__.update(f.__annotations__)
    fn.__qualname__ = f.__qualname__
    return fn

In [None]:
def foo():
    """ Test doc """
    pass

c = copy_func(foo)
# The copy is a distinct object that keeps the original's name and docstring.
assert c is not foo
test_eq(c.__name__, "foo")
test_eq(c.__doc__, foo.__doc__)

In [None]:
# `copy.copy` and `copy.deepcopy` both hand back the very same function object.
def foo():
    pass


a = cp.copy(foo)
b = cp.deepcopy(foo)

a.someattr = "hello"  # a and b are the same object, so an attribute set through a is visible through b
test_eq(b.someattr, "hello")

assert a is foo and b is foo

By contrast, `copy_func` builds a genuinely new function object instead of returning another reference to the original:

In [None]:
c = copy_func(foo)  # c is an independent object
assert c is not foo

In [None]:
# `y` is keyword-only, so this check relies on copy_func carrying over `__kwdefaults__`.
def g(x, *, y=3):
    return x + y


test_eq(copy_func(g)(4), 7)

In [None]:
# |exporti


def patch_to(cls: Type, as_prop: bool = False, cls_method: bool = False) -> Callable[[F], F]:
    """Decorator: add `f` to `cls` (a class, or a tuple/list of classes).

    Each target class receives its own copy of `f` (via `copy_func`), with the
    `WRAPPER_ASSIGNMENTS` attributes copied from `f` and `__qualname__` set to
    `"<Class>.<name>"`.

    Args:
        cls: the class, or tuple/list of classes, to patch.
        as_prop: install the function as a `property`.
        cls_method: install the function as a `classmethod`; its first argument
            is then the class it is looked up on (subclasses receive themselves).

    Returns:
        A decorator. The decorator does NOT return the patched function: it
        returns whatever this module's globals (or else the builtins) already
        bind to `f.__name__`, or None, so the decorated name does not clobber
        an existing function.
    """
    if not isinstance(cls, (tuple, list)):
        cls = (cls,)

    def _inner(f: F) -> F:
        # Read the name before the loop so it is bound even when `cls` is empty.
        nm = f.__name__
        for c_ in cls:
            nf = copy_func(f)
            # `functools.update_wrapper` when passing patched function to `Pipeline`, so we do it manually
            for o in functools.WRAPPER_ASSIGNMENTS:
                setattr(nf, o, getattr(f, o))
            nf.__qualname__ = f"{c_.__name__}.{nm}"
            if cls_method:
                # `classmethod` binds to the class actually used for lookup,
                # unlike a method pre-bound to `c_`.
                setattr(c_, nm, classmethod(nf))
            else:
                setattr(c_, nm, property(nf) if as_prop else nf)
        # Avoid clobbering existing functions
        return globals().get(nm, builtins.__dict__.get(nm, None))

    return _inner

In [None]:
class _T3(int):
    pass

@patch_to(_T3)
def foo(self):
    """ Test doc """
    pass

# `patch_to` returns the name's pre-existing binding (or None), not the patched
# method, so report the docstring actually installed on the class on failure.
assert _T3(1).foo.__doc__ == """ Test doc """, _T3(1).foo.__doc__

     
The `@patch_to` decorator lets you monkey-patch a function into a class as a method:

In [None]:
class _T3(int):
    pass


@patch_to(_T3)
def func1(self, a):
    return self + a


t = _T3(1)  # we initialized `t` as an int subclass with value 1
test_eq(t.func1(2), 3)  # we add 2 to `t`, so 2 + 1 = 3

     
You can access instance attributes in the usual way via `self`:

In [None]:
class _T4:
    def __init__(self, g):
        self.g = g


@patch_to(_T4)
def greet(self, x):
    return self.g + x  # the patched method reads instance attributes through `self`


t = _T4("hello ")  # this sets self.g = 'hello '
test_eq(
    t.greet("world"), "hello world"
)  # t.greet('world') will append 'world' to 'hello '

     
You can instead specify that the method should be a class method by setting cls_method=True:

In [None]:
class _T5(int):
    attr = 3  # attr is a class attribute we will access in a later method


@patch_to(_T5, cls_method=True)
def func(cls, x):
    # with cls_method=True the first argument receives the class, not an instance
    return cls.attr + x  # you can access class attributes in the normal way


test_eq(_T5.func(4), 7)

In [None]:
# Additionally you can specify that the function you want to patch should be a property with as_prop=True:
# (the patched function is then read like an attribute, without calling it)


@patch_to(_T5, as_prop=True)
def add_ten(self):
    return self + 10


t = _T5(4)
test_eq(t.add_ten, 14)

     
Instead of passing one class to the `@patch_to` decorator, you can pass several classes in a tuple to patch all of them with the same method simultaneously:

In [None]:
class _T6(int):
    pass


class _T7(int):
    pass


# The same function is installed on every class in the tuple.
@patch_to((_T6, _T7))
def func_mult(self, a):
    return self * a


t = _T6(2)
test_eq(t.func_mult(4), 8)
t = _T7(2)
test_eq(t.func_mult(4), 8)

In [None]:
# | exporti


def eval_type(t, glb, loc):
    "`eval` a type or collection of types, if needed, for annotations in py3.10+"
    if isinstance(t, str):
        if "|" in t:
            return Union[eval_type(tuple(t.split("|")), glb, loc)]
        return eval(t, glb, loc)
    if isinstance(t, (tuple, list)):
        return type(t)([eval_type(c, glb, loc) for c in t])
    return t


def union2tuple(t):
    if getattr(t, "__origin__", None) is Union or (
        UnionType and isinstance(t, UnionType)
    ):
        return t.__args__
    return t


def get_annotations_ex(obj, *, globals=None, locals=None):
    """Backport of py3.10 `get_annotations` that returns globals/locals.

    Returns a tuple `(ann, globals, locals)`:
      * `ann`: a new dict copied from `obj`'s raw `__annotations__`
        (values are not evaluated; an absent/None `__annotations__` gives `{}`);
      * `globals`/`locals`: the namespaces in which string annotations should
        be evaluated. Explicit `globals`/`locals` arguments take precedence.

    Inferred namespaces, by kind of `obj`:
      * class: globals of the module named by `obj.__module__`; locals are a
        copy of the class namespace;
      * module: the module's `__dict__`; no locals;
      * callable: `__globals__` of the function reached by unwrapping
        `__wrapped__` / `functools.partial`; no locals.

    Raises:
        TypeError: if `obj` is not a module, class, or callable.
        ValueError: if `obj.__annotations__` is neither a dict nor None.
    """
    if isinstance(obj, type):
        # Read the class's own namespace so base-class annotations are not picked up.
        obj_dict = getattr(obj, "__dict__", None)
        if obj_dict and hasattr(obj_dict, "get"):
            ann = obj_dict.get("__annotations__", None)
            # A getset descriptor here is attribute machinery, not an annotations dict.
            if isinstance(ann, types.GetSetDescriptorType):
                ann = None
        else:
            ann = None

        obj_globals = None
        module_name = getattr(obj, "__module__", None)
        if module_name:
            module = sys.modules.get(module_name, None)
            if module:
                obj_globals = getattr(module, "__dict__", None)
        obj_locals = dict(vars(obj))
        unwrap = obj
    elif isinstance(obj, types.ModuleType):
        ann = getattr(obj, "__annotations__", None)
        obj_globals = getattr(obj, "__dict__")
        obj_locals, unwrap = None, None
    elif callable(obj):
        ann = getattr(obj, "__annotations__", None)
        obj_globals = getattr(obj, "__globals__", None)
        obj_locals, unwrap = None, obj
    else:
        raise TypeError(f"{obj!r} is not a module, class, or callable.")

    if ann is None:
        ann = {}
    if not isinstance(ann, dict):
        raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")

    if unwrap is not None:
        # Follow decorator wrappers and partials down to the underlying
        # function; its globals are where its annotations were written.
        while True:
            if hasattr(unwrap, "__wrapped__"):
                unwrap = unwrap.__wrapped__
                continue
            if isinstance(unwrap, functools.partial):
                unwrap = unwrap.func
                continue
            break
        if hasattr(unwrap, "__globals__"):
            obj_globals = unwrap.__globals__

    if globals is None:
        globals = obj_globals
    if locals is None:
        locals = obj_locals

    return dict(ann), globals, locals

In [None]:
# | export


def patch(f: F = None, *, as_prop: bool = False, cls_method: bool = False) -> F:
    """Decorator: add `f` to the first parameter's class (based on f's type annotations).

    The target class(es) come from `f`'s annotations:
      * with `cls_method=True`, the annotation of the parameter named `cls`;
      * otherwise, the first annotated *parameter* (the return annotation is
        never used).

    A union annotation (`A | B`) patches every member class. The decorator can
    be used bare (`@patch`) or with options (`@patch(as_prop=True)`).

    Returns:
        Whatever `patch_to(...)(f)` returns (not the patched function itself).
        When called without `f`, returns a decorator with the options pre-bound.

    Raises:
        KeyError: `cls_method=True` but `f` has no annotated `cls` parameter.
        TypeError: `cls_method=False` and no parameter of `f` is annotated.
    """
    if f is None:
        # Called with options only, e.g. `@patch(as_prop=True)`.
        return partial(patch, as_prop=as_prop, cls_method=cls_method)
    ann, glb, loc = get_annotations_ex(f)
    if cls_method:
        target = ann.pop("cls")
    else:
        # Skip the return annotation; otherwise `def f(self) -> int` would
        # silently patch `int` instead of failing.
        param_types = [v for k, v in ann.items() if k != "return"]
        if not param_types:
            raise TypeError(
                f"@patch cannot find the class to patch: annotate the first parameter of {f.__name__!r}"
            )
        target = param_types[0]
    cls = union2tuple(eval_type(target, glb, loc))
    return patch_to(cls, as_prop=as_prop, cls_method=cls_method)(f)

In [None]:
class _T8(int):
    pass


@patch
def func(self: _T8, a):
    """ Test doc """
    return self + a

# `patch` returns the name's pre-existing binding (or None), not the patched
# method, so report the docstring actually installed on the class on failure.
assert _T8().func.__doc__ == """ Test doc """, _T8().func.__doc__

---

### _T8.func

>      _T8.func (a)

Test doc

     
`@patch` is an alternative to `@patch_to` that lets you monkey-patch one or more classes in the same way, taking the target class from the first parameter's type annotation:

In [None]:
class _T8(int):
    pass


@patch
def func(self: _T8, a):
    return self + a


t = _T8(1)  # we initialized `t` as an int subclass with value 1
test_eq(t.func(3), 4)  # we add 3 to `t`, so 3 + 1 = 4
test_eq(t.func.__qualname__, "_T8.func")  # the patched copy is renamed after its class

     
As with `patch_to`, you can patch several classes at once, here by annotating the first parameter with a union of classes (`_T8 | _T9`) instead of a single class:

In [None]:
class _T9(int):
    pass


@patch
def func2(x: _T8 | _T9, a):
    return x * a  # will patch both _T8 and _T9


# Each class gets its own copy, with a __qualname__ naming that class.
t = _T8(2)
test_eq(t.func2(4), 8)
test_eq(t.func2.__qualname__, "_T8.func2")

t = _T9(2)
test_eq(t.func2(4), 8)
test_eq(t.func2.__qualname__, "_T9.func2")

     
Just as with the `patch_to` decorator, you can pass the `as_prop` and `cls_method` parameters to the `patch` decorator:

In [None]:
# as_prop=True: `t.add_ten` is read as an attribute, without calling it.
@patch(as_prop=True)
def add_ten(self: _T5):
    return self + 10


t = _T5(4)
test_eq(t.add_ten, 14)

In [None]:
class _T5(int):
    attr = 3  # attr is a class attribute we will access in a later method


# With cls_method=True, @patch takes the target class from the `cls` parameter's annotation.
@patch(cls_method=True)
def func(cls: _T5, x):
    return cls.attr + x  # you can access class attributes in the normal way


test_eq(_T5.func(4), 7)

In [None]:
print("ok")  # final marker: reached only if the preceding cells ran without raising