In [1]:
#|default_exp dispatch

In [2]:
#|export
from __future__ import annotations
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.utils import *

from collections import defaultdict

from plum.function import Function
from plum.signature import Signature
from plum import NotFoundLookupError, AmbiguousLookupError

In [3]:
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *

## Utilities

In [4]:
#|export
# TODO(Rens): find better spot for this?
def _get_name(f):
    """Get the name of a function or callable object"""
    return getattr(f, '__name__', getattr(f.__class__, '__name__', str(f)))

# TypeDispatch

> Basic single and dual parameter dispatch

Type dispatch, or [Multiple dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia), allows you to change the way a function behaves based upon the input types it receives. This is a prominent feature in some programming languages like Julia. For example, this is a [conceptual example](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia) of how multiple dispatch works in Julia, returning different values depending on the input types of x and y:

```julia
collide_with(x::Asteroid, y::Asteroid) = ... 
# deal with asteroid hitting asteroid

collide_with(x::Asteroid, y::Spaceship) = ... 
# deal with asteroid hitting spaceship

collide_with(x::Spaceship, y::Asteroid) = ... 
# deal with spaceship hitting asteroid

collide_with(x::Spaceship, y::Spaceship) = ... 
# deal with spaceship hitting spaceship
```

Type dispatch can be especially useful in data science, where you might allow different input types (e.g. NumPy arrays and pandas DataFrames) to be passed to a function that processes data. Type dispatch allows you to have a common API for functions that do similar tasks.

The `TypeDispatch` class allows us to achieve type dispatch in Python. It registers functions according to the types in their type annotations (using plum's `Function` under the hood), which ensures that the proper function is called for the inputs it is passed.

In [5]:
#|export
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, funcs=(), bases=()):
        """Create a dispatcher from `funcs` (one callable or an iterable of callables).

        `bases` is a `TypeDispatch` or an iterable of them. They are searched in order
        when no function registered on `self` matches. `None` entries in `bases` are dropped.
        """
        self.func = None
        # True while `self.func` is the plum `Function` of a base (see below); `add` must
        # then create a fresh `Function` rather than registering into the base's one.
        self._borrowed_func = False
        self.bases = [b for b in (bases if isinstance(bases, Iterable) else (bases,))
                      if b is not None]
        self.inst = None   # instance this dispatcher was last accessed through (see `__get__`)
        self.owner = None  # class this dispatcher was last accessed through (see `__get__`)

        # Without functions of our own, expose the first base's `func` as `self.func`
        if not funcs and self.bases:
            for base in self.bases:
                if base.func is not None:
                    self.func = base.func
                    self._borrowed_func = True
                    break

        funcs = funcs if isinstance(funcs, Iterable) else (funcs,)
        for f in funcs: self.add(f)

    def add(self, f):
        "Add function `f`, dispatching on the types in its annotations (`None` is skipped)"
        if f is None: return
        if self.func is None or self._borrowed_func:
            # TODO(Rens): extract to separate func?
            # plum's `Function` needs a `__name__`; wrap callables that lack one
            if not hasattr(f, '__name__'):
                orig_f = f
                def wrapped(*args, **kwargs): return orig_f(*args, **kwargs)
                wrapped.__name__ = _get_name(f)
                f = wrapped
            # Own `Function` from now on, so registrations never leak into a base
            self.func = Function(f)
            self._borrowed_func = False
        self.func.dispatch(f)

    def __call__(self, *args, **kwargs):
        "Call the function matching the types of the first two arguments; identity if none matches"
        # TODO(Rens): handle staticfunctions/classmethods
        # When accessed through an instance/class (see `__get__`), pass that object first.
        # Compare with `None` explicitly: a valid instance may be falsy (e.g. an empty container).
        bound = self.inst if self.inst is not None else self.owner
        if bound is not None: args = (bound,) + args

        # Dispatch only on the types of the first two positional arguments
        # TODO(Rens): remove this!?
        ts = tuple(type(a) for a in args[:2])

        f = self[ts]
        # No match: return the first argument that was not added as the bound object
        if not f: return args[1] if bound is not None else args[0]
        return f(*args, **kwargs)

    def __getitem__(self, k):
        "Find the function whose annotations match the type(s) `k`, searching `bases` next; `None` if no match"
        if self.func is None and not self.bases: return None

        # Accept a `Signature` as-is, an iterable of types, or a single type
        if not isinstance(k, Signature):
            k = Signature(*k) if isinstance(k, Iterable) else Signature(k)

        if self.func is not None:
            self.func._resolve_pending_registrations()
            try:
                return self.func._resolver.resolve(k).implementation
            except NotFoundLookupError:
                pass

        # Not found locally: try each base in order
        for base in self.bases:
            if (res := base[k]): return res

        return None

    def __get__(self, instance, owner):
        # Record the object this dispatcher is accessed through; `__call__` prepends it
        self.inst = instance
        self.owner = owner
        return self

    def returns(self, x):
        "Get the return type of annotation of `x`."
        return anno_ret(self[type(x)])

To demonstrate how `TypeDispatch` works, we define a set of functions that accept a variety of input types, specified with different type annotations:

In [6]:
def f2(x:int, y:float): return x+y              #int and float for 2nd arg
def f_nin(x:numbers.Integral)->int:  return x+1 #integral numeric
def f_ni2(x:int): return x                      #integer
def f_sol(x:str|list): return x                 #str or list
def f_num(x:numbers.Number): return x           #Number (root of numerics)          

We can optionally initialize `TypeDispatch` with a list of functions we want to search. `TypeDispatch` does not currently define a custom `repr`, so displaying an instance shows the default object representation:

In [7]:
# Construct a dispatcher from annotated functions; the list also contains `None`
t = TypeDispatch([f_nin,f_ni2,f_num,f_sol,None])
t

<__main__.TypeDispatch at 0x10df53c20>

Note that only the types of the first two arguments are used for dispatch when a `TypeDispatch` is called. The `None` entry in the list above does not correspond to any function; it is simply accepted when constructing the `TypeDispatch`.

`TypeDispatch` is a dictionary-like object, which means that you can retrieve a function by the associated type annotation.  For example, the statement:

```py
t[float]
```
will return `f_num`, because it is the matching function whose type annotation, `numbers.Number`, is a superclass of `float`:

In [8]:
# `float` is a subclass of `numbers.Number`, so `f_num` is the match
assert issubclass(float, numbers.Number)
test_eq(t[float], f_num)

The same is true for other types as well:

In [9]:
# Lookups match each function's annotation via subclass checks
test_eq(t[np.int32], f_nin)  # np.int32 is a numbers.Integral, but not an int
test_eq(t[str], f_sol)       # matched by `str|list`
test_eq(t[list], f_sol)      # matched by `str|list`
test_eq(t[int], f_ni2)       # `int` is the most specific annotation for int

If you try to get a type that doesn't match, `TypeDispatch` will return `None`:

In [10]:
# `tuple` matches none of the annotations, so the lookup returns None
test_eq(t[tuple], None)

In [11]:
# Render the API documentation for `TypeDispatch.add`
show_doc(TypeDispatch.add)

---

[source](https://github.com/AnswerDotAI/fasttransform/blob/main/fasttransform/dispatch.py#L48){target="_blank" style="float:right; font-size:smaller"}

### TypeDispatch.add

>      TypeDispatch.add (f)

*Add type `t` and function `f`*

This method allows you to add an additional function to an existing `TypeDispatch` instance:

In [12]:
# `typing.Collection` covers `dict`, which none of the earlier functions handled
def f_col(x:typing.Collection): return x
t.add(f_col)
test_eq(t[dict], f_col)
t

<__main__.TypeDispatch at 0x10df53c20>

If you accidentally add the same function more than once things will still work as expected:

In [13]:
# Re-adding `f_ni2` keeps `int` dispatching to it
t.add(f_ni2) 
test_eq(t[int], f_ni2)

However, if you add a function whose type annotations are identical to those of an existing function, the most recently added function takes precedence:

In [14]:
def f_ni3(z:int): return z # collides with f_ni2 with same type annotations
t.add(f_ni3) 
test_eq(t[int], f_ni3)     # the most recently added function wins

#### Using `bases`:

The argument `bases` can optionally accept a single instance of `TypeDispatch` or a collection (i.e. a tuple or list) of `TypeDispatch` objects.  This can provide functionality similar to multiple inheritance. 

These are searched for matching functions if no match is found in your list of functions:

In [15]:
# Redefine `f_col` so its result shows that it ran (the body only works for str inputs)
def f_col(x:typing.Collection): return x+'1'

t = TypeDispatch([f_nin,f_ni2,f_num,f_sol,None])
t2 = TypeDispatch(f_col, bases=t) # you can optionally supply a list of TypeDispatch objects for `bases`.
t2

<__main__.TypeDispatch at 0x10eb4f2f0>

In [16]:
test_eq(t2[int], f_ni2)       # searches `t` b/c not found in `t2`
test_eq(t2[np.int32], f_nin)  # searches `t` b/c not found in `t2`
test_eq(t2[float], f_num)     # searches `t` b/c not found in `t2`
test_eq(t2[dict], f_col)      # found in `t2` (dict is a Collection)
test_eq(t2('a'), 'a1')        # found in `t2` (str is a Collection), and uses __call__

In [17]:
o = np.int32(1)
# NOTE(review): expected to dispatch to `f_nin` (from base `t`) and return 2, but this
# appears to raise a TypeError (see TODO below), which is swallowed here for now.
try:
    test_eq(t2(o), 2)             # found in `t` (via `bases`) and uses __call__
except TypeError:
    pass

# TODO(Rens): Address this type error

#### Up To Two Arguments

`TypeDispatch` supports up to two arguments when searching for the appropriate function.  The following functions `f1` and `f2` both have two parameters:

In [18]:
# Two-parameter functions: f1 accepts any Integral first argument, f2 only (int, float)
def f1(x:numbers.Integral, y): return x+1  #Integral is a numeric type
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])
t

<__main__.TypeDispatch at 0x10eb33860>

You can look up functions from a `TypeDispatch` instance with two parameters like this:

In [19]:
# A single-type lookup such as `t[np.int32]` does not find `f1`, because both functions
# expect two arguments, so look up with two types instead.
# TODO(Rens): decide whether single-type lookups should match two-argument functions
test_eq(t[np.int32, float], f1)
test_eq(t[int,float], f2)

Keep in mind that anything beyond the first two parameters is ignored, and any collisions will be resolved in favor of the most recently added function. In the example below, `f1` is ignored in favor of `f2` because the first two parameters have identical type hints:

In [20]:
# `f1` takes three parameters and `f2` two; a (str, int) lookup returns `f2`
def f1(a:str, b:int, c:list): return a
def f2(a: str, b:int): return b
t = TypeDispatch([f1,f2])
test_eq(t[str, int], f2)
t

<__main__.TypeDispatch at 0x10eb31b20>

#### Matching

`TypeDispatch` matches types with functions according to whether the supplied class is a subclass of, or the same class as, the type annotation(s) of the associated functions.

Let's consider an example where we try to retrieve the function corresponding to types of `[np.int32, float]`.

In this scenario, `f2` will not be matched. This is because the first type annotation of `f2`, `int`, is not a superclass (or the same class) of `np.int32`:

In [21]:
def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])

# np.int32 is not a subclass of int, so f2 cannot match an np.int32 first argument
assert not issubclass(np.int32, int)

Instead, `f1` is a valid match, as its first argument is annotated with the type `numbers.Integral`, of which `np.int32` is a subclass:

In [22]:
assert issubclass(np.int32, numbers.Integral)
test_eq(t[np.int32,float], f1)   # only f1's `numbers.Integral` accepts np.int32

In `f1`, the 2nd parameter `y` is not annotated, which means `TypeDispatch` will match anything whose first argument matches `int` and that is not matched by anything else:

In [23]:
assert issubclass(int, numbers.Integral) # int is a subclass of numbers.Integral
test_eq(t[int, str], f1)   # `y` is unannotated, so any 2nd type is accepted by f1
test_eq(t[int,int], f1)    # f2 requires a float 2nd argument, so f1 is used

If no match is possible, `None` is returned:

In [24]:
# float is neither an Integral (f1) nor an int (f2)
test_eq(t[float,float], None)

In [25]:
# Render the API documentation for `TypeDispatch.__call__`
show_doc(TypeDispatch.__call__)

---

[source](https://github.com/AnswerDotAI/fasttransform/blob/main/fasttransform/dispatch.py#L61){target="_blank" style="float:right; font-size:smaller"}

### TypeDispatch.__call__

>      TypeDispatch.__call__ (*args, **kwargs)

*Call self as a function.*

`TypeDispatch` is also callable.  When you call an instance of `TypeDispatch`, it will execute the relevant function:

In [26]:
def f_arr(x:np.ndarray): return x.sum()
def f_int(x:np.int32): return x+1
t = TypeDispatch([f_arr, f_int])

arr = np.array([5,4,3,2,1])
test_eq(t(arr), 15) # dispatches to f_arr

o = np.int32(1)
test_eq(t(o), 2) # dispatches to f_int

# TODO(Rens): Not implemented
# `TypeDispatch` does not define a `first` method yet, so the AttributeError is expected
try:
    assert t.first() is not None 
except AttributeError:
    pass

You can also call an instance of `TypeDispatch` when there are two parameters:

In [27]:
def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])

test_eq(t(3,2.0), 5)   # (int, float) -> f2: 3+2.0
test_eq(t(3,2), 4)     # (int, int) -> f1: 3+1

When no match is found, a `TypeDispatch` instance acts as an identity function. fastai relies on this default behavior in its data transformations, so they have a sensible result when no matching function is found.

In [28]:
# str matches neither function, so the first argument is returned unchanged
test_eq(t('a'), 'a')

In [29]:
# Render the API documentation for `TypeDispatch.returns`
show_doc(TypeDispatch.returns)

---

[source](https://github.com/AnswerDotAI/fasttransform/blob/main/fasttransform/dispatch.py#L101){target="_blank" style="float:right; font-size:smaller"}

### TypeDispatch.returns

>      TypeDispatch.returns (x)

*Get the return type of annotation of `x`.*

You can optionally pass an object to `TypeDispatch.returns` and get the return type annotation back:

In [30]:
def f1(x:int) -> np.ndarray: return np.array(x)
def f2(x:str) -> float: return List   # only the return annotation matters here
def f3(x:float): return List # f3 has no return type annotation

t = TypeDispatch([f1, f2, f3])

test_eq(t.returns(1), np.ndarray)  # dispatched to f1
test_eq(t.returns('Hello'), float) # dispatched to f2
test_eq(t.returns(1.0), None)      # dispatched to f3

class _Test: pass
_test = _Test()
test_eq(t.returns(_test), None) # type `_Test` not found, so None returned

#### Using TypeDispatch With Methods

You can use `TypeDispatch` when defining methods as well:

In [31]:
# Method-style functions: the first parameter plays the role of `self`
def m_str(self, x:str): return str(x)+'1'
def m_num(self, x:int): return x*2
def m_lst(self, x:list): self.foo='a'

t = TypeDispatch([m_str,m_num,m_lst])
class A: pass
a = A()
# Called directly, so the instance (or None) is passed explicitly as the first argument
t(None, "5"), t(a, [5,]), t(None, 5), a.foo

('51', None, 10, 'a')

In [32]:
# `m_lst` set this attribute on `a`
a.foo

'a'

**Note:** Calling `t` directly works when the instance (or `None`) is passed explicitly as the first argument.

In [33]:
class A: f = t # set class attribute `f` equal to a TypeDispatch instance
    
# Accessing `a.f` goes through `TypeDispatch.__get__`, which records `a` as the instance
a = A()

**Note:** Next, `t` is assigned as a class attribute. Accessing it through an instance should then pass that instance as the first argument automatically:

In [34]:
# `a` is passed as the first argument automatically
a.f("5"), a.f(5)

('51', 10)

In [35]:
assert not hasattr(a,"foo")   # `a` is a new instance of the redefined `A`
a.f([])                       # dispatches to m_lst, which sets `self.foo`
assert a.foo == "a"

As discussed in `TypeDispatch.__call__`, when there is not a match, `TypeDispatch.__call__` becomes an identity function.  In the below example, a tuple does not match any type annotations so a tuple is returned:

In [36]:
# tuple matches none of m_str/m_num/m_lst, so the argument is returned unchanged
test_eq(a.f(()), ()) 

We extend the previous example by using `bases` to add an additional method that supports tuples:

In [37]:
def m_tup(self, x:tuple): return x+(1,)
t2 = TypeDispatch(m_tup, bases=t)   # falls back to `t` for str/int/list

class A2: f = t2
a2 = A2()

In [38]:
test_eq(a2.f(1), 2)         # int -> m_num, found via base `t`
test_eq(a2.f(1.), 1.)       # not found
test_is(a2.f.inst, a2)      # finds instance
test_eq(a2.f(False), 0)     # bool is integer so 0*2=0
test_is(a2.f([1,]), None)   # doesn't return anything
test_eq(a2.foo, 'a')        # ..but does set foo to a
test_eq(a2.f(()), (1,))     # new tuple function

#### Using TypeDispatch With Class Methods

You can use `TypeDispatch` when defining class methods too:

In [39]:
# Class-method-style functions: the first parameter receives the owner class
def m_str(cls, x:str|float): return str(x)+'1'
def m_num(cls, x:int): return x*2
def m_lst(cls, x:list): cls.foo='a'

t = TypeDispatch([m_str,m_num,m_lst])
class A: f = t # set class attribute `f` equal to a TypeDispatch

test_eq(A.f(1), 2)         #dispatch to m_num
test_eq(A.f(1.), "1.01")   #dispatch to m_str
test_is(A.f.owner, A)

A.f([1,2]) # this triggers m_lst to run, which sets A.foo to 'a'
test_eq(A.foo, 'a')

## typedispatch Decorator


In [40]:
#|export
# Previous registry implementations, kept commented out for reference; the plum-based
# `FastDispatcher` in the next cell is what `typedispatch` currently uses.

# this works in 09 vision, but not in 03 data core
# class DispatchReg:
#     "A global registry for `TypeDispatch` objects keyed by function name"
#     def __init__(self): self.d = defaultdict(TypeDispatch)
#     def __call__(self, f):
#         if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}'
#         else: nm = f'{f.__qualname__}'
#         if isinstance(f, classmethod): f=f.__func__
#         self.d[nm].add(f)
#         return self.d[nm]

# typedispatch = DispatchReg()

# # this works in 03 data core, but not with 09 vision
# from plum.dispatcher import Dispatcher
# typedispatch = Dispatcher()

In [41]:
#|export
from plum.dispatcher import Dispatcher, is_in_class

class FastFunction(Function):
    "plum `Function` supporting `f[types]` lookup that returns the implementation, or `None` if there is no match"
    def __getitem__(self, k):
        self._resolve_pending_registrations()

        # If a single Signature is passed, use it directly
        if not isinstance(k,Signature):
            k = Signature(*k) if isinstance(k, Iterable) else Signature(k)

        try:
            return self._resolver.resolve(k).implementation
        except NotFoundLookupError:
            pass

        # TODO(Rens): this is a temporary fix for fastai 
        # which uses show_batch[object] while the function has 3 args
        # so technically show_batch[object, object, object] would be ok.
        # if k.types == (object,):
        if self.__name__ == "show_batch" and len(k.types) == 1:
            return self.__getitem__(k.types + (object, object))

        return None


class FastDispatcher(Dispatcher):
    "plum `Dispatcher` that registers functions into `FastFunction`s"
    def _get_function(self, method: Callable) -> FastFunction:
        # `get_class` is not imported from plum at module level. Through the fastcore star
        # imports it may instead resolve to fastcore's unrelated class-factory `get_class`,
        # so import plum's helper explicitly (under an alias, to avoid shadowing).
        from plum.util import get_class as _plum_get_class

        # If a class is the owner, use a namespace specific for that class. Otherwise,
        # use the global namespace.
        if is_in_class(method):
            owner = _plum_get_class(method)
            if owner not in self.classes:
                self.classes[owner] = {}
            namespace = self.classes[owner]
        else:
            owner = None
            namespace = self.functions

        # Create a new function only if the function does not already exist.
        name = method.__name__
        if name not in namespace:
            namespace[name] = FastFunction(
                method,
                owner=owner,
                warn_redefinition=self.warn_redefinition,
            )

        return namespace[name]

typedispatch = FastDispatcher()

In [42]:
# Register four overloads of `f_td_test`; each call uses the most specific match
@typedispatch
def f_td_test(x, y): return f'{x}{y}'
@typedispatch
def f_td_test(x:numbers.Integral|int, y): return x+1
@typedispatch
def f_td_test(x:int, y:float): return x+y
@typedispatch
def f_td_test(x:int, y:int): return x*y

test_eq(f_td_test(3,2.0), 5)        # (int, float) -> x+y
assert issubclass(int, numbers.Integral)
test_eq(f_td_test(3,2), 6)          # (int, int) -> x*y

test_eq(f_td_test('a','b'), 'ab')   # unannotated fallback

In [43]:
# Indexing a dispatched function with a type returns the matching implementation
@typedispatch
def f1(x:int): return "1 INT!"
@typedispatch
def f1(x:str): return "1 STR!"
@typedispatch
def f2(x:int): return "2 INT!"
@typedispatch
def f2(x:str): return "2 STR!"

assert f2[str]("foo") == "2 STR!"
assert f2[int](5) == "2 INT!"
assert f1[str]("foo") == "1 STR!"
assert f1[int](5) == "1 INT!"

#### Using typedispatch With other decorators

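In principle, `typedispatch` can be combined with the `classmethod` and `staticmethod` decorators, but this does not yet work with the plum-based implementation (see the TODO below):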

In [44]:
# TODO(Rens): Doesnt work w Plum yet...see if remove or fix
# NOTE(review): kept commented out until `typedispatch` supports methods wrapped in
# `classmethod`/`staticmethod`; the `test_eq` lines show the intended results.

# class A:
#     @typedispatch
#     def f_td_test(self, x:numbers.Integral, y): return x+1
#     @typedispatch
#     @classmethod
#     def f_td_test(cls, x:int, y:float): return x+y
#     @typedispatch
#     @staticmethod
#     def f_td_test(x:int, y:int): return x*y
    
# test_eq(A.f_td_test(3,2), 6)
# test_eq(A.f_td_test(3,2.0), 5)
# test_eq(A().f_td_test(3,'2.0'), 4)

## Casting

Now that we can dispatch on types, let's make it easier to cast objects to a different type.

In [45]:
#|export
# nbdev: `_all_` adds the listed names to the exported module's `__all__`
_all_=['cast']

In [46]:
#|export
def retain_meta(x, res, as_copy=False):
    "Let `res` pull metadata from `x` through `res.set_meta`, when it defines one; return `res`"
    if not hasattr(res, 'set_meta'): return res
    res.set_meta(x, as_copy=as_copy)
    return res

In [47]:
#|export
def default_set_meta(self, x, as_copy=False):
    "Give `self` the `_meta` of `x` (optionally a shallow copy), unless `self` already has one; return `self`"
    # Nothing to copy, or `self` already carries its own metadata
    if not hasattr(x, '_meta') or hasattr(self, '_meta'): return self
    src_meta = x._meta
    self._meta = copy(src_meta) if as_copy else src_meta
    return self

In [48]:
#|export
@typedispatch
def cast(x, typ):
    "cast `x` to type `typ` (may also change `x` inplace)"
    # Optional hook on the target type to preprocess `x` before casting
    res = typ._before_cast(x) if hasattr(typ, '_before_cast') else x
    if risinstance('ndarray', res): res = res.view(typ)
    elif hasattr(res, 'as_subclass'): res = res.as_subclass(typ)
    else:
        # Prefer changing the class in place. If `__class__` assignment is not allowed
        # (e.g. for tuple), build a new `typ`. Catch `Exception` rather than using a bare
        # `except:` so KeyboardInterrupt/SystemExit are not swallowed.
        try: res.__class__ = typ
        except Exception: res = typ(res)
    return retain_meta(x, res)

This works both for plain Python classes...

In [49]:
mk_class('_T1', 'a')   # mk_class is a fastcore utility that constructs a class.
class _T2(_T1): pass

t = _T1(a=1)
t2 = cast(t, _T2)        
assert t2 is t            # t2 refers to the same object as t
assert isinstance(t, _T2) # t also changed in-place
assert isinstance(t2, _T2)

test_eq_type(_T2(a=1), t2) 


...as well as for arrays and tensors.

In [50]:
# ndarray subclasses are cast with `.view(typ)` rather than by reassigning `__class__`
class _T1(ndarray): pass

t = array([1])
t2 = cast(t, _T1)
test_eq(array([1]), t2)
test_eq(_T1, type(t2))

To customize casting for other types, define a separate `cast` function with `typedispatch` for your type.

In [51]:
#|export
def retain_type(new, old=None, typ=None, as_copy=False):
    "Cast `new` to `typ`, or to the type of `old` when that is a subclass of `new`'s type"
    if new is None: return None
    assert old is not None or typ is not None
    if typ is None:
        # e.g. `old` is a TensorImage and `new` a plain Tensor; unrelated types are left alone
        if not isinstance(old, type(new)): return new
        typ = old if isinstance(old, type) else type(old)
    # Already of the requested type (or NoneType requested): nothing to do
    already_ok = typ == NoneType or isinstance(new, typ)
    return new if already_ok else retain_meta(old, cast(new, typ), as_copy=as_copy)

In [52]:
class _T(tuple): pass
a = _T((1,2))
b = tuple((1,2))
# Cast the plain tuple `b` to `_T` by passing `typ` explicitly
c = retain_type(b, typ=_T)
test_eq_type(c, a)

If `old` has a `_meta` attribute, its content is passed when casting `new` to the type of `old`.  In the below example, only the attribute `a`, but not `other_attr` is kept, because `other_attr` is not in `_meta`:

In [53]:
class _A():
    set_meta = default_set_meta
    def __init__(self, t): self.t=t

class _B1(_A):
    def __init__(self, t, a=1):
        super().__init__(t)
        self._meta = {'a':a}
        self.other_attr = 'Hello' # will not be kept after casting.
        
x = _B1(1, a=2)
b = _A(1)
# `x` (a `_B1`) is an instance of `b`'s type `_A`, so `b` is cast to `_B1` and receives `x._meta`
c = retain_type(b, old=x)
test_eq(c._meta, {'a': 2})
assert not getattr(c, 'other_attr', None)

In [54]:
#|export
def retain_types(new, old=None, typs=None):
    "Recursively cast each item of `new` to the type of the matching item in `old`, or per the nested `typs`"
    if not is_listy(new): return retain_type(new, old, typs)
    if typs is None:
        # Keep `old`'s container type only when it is a subclass of `new`'s
        container = type(old) if old is not None and isinstance(old, type(new)) else type(new)
    elif isinstance(typs, dict):
        # `{container_type: item_types}`, the format produced by `explode_types`
        container = first(typs.keys())
        typs = typs[container]
    else:
        container, typs = typs, None
    return container(L(new, old, typs).map_zip(retain_types, cycled=True))

In [55]:
class T(tuple): pass

# Types taken from a template `old`: nested tuples become `T` wherever `old` has a `T`
t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))

# Same result from an explicit nested `typs` description
t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))

In [56]:
#|export
def explode_types(o):
    "Return `type(o)`; for listy `o`, a `{type(o): [exploded items]}` dict, built recursively"
    if is_listy(o): return {type(o): list(map(explode_types, o))}
    return type(o)

In [57]:
# `explode_types` produces the nested `typs` format accepted by `retain_types`
test_eq(explode_types((2,T((2,T((3,4)))))), {tuple: [int, {T: [int, {T: [int,int]}]}]})

## Export -

In [58]:
#|hide
#|eval: false
from nbdev import nbdev_export
# Write this notebook's `#|export` cells out to the Python module
nbdev_export()