In [None]:
#default_exp dispatch

In [None]:
#export
from fastcore.imports import *
from fastcore.foundation import *
from fastcore.utils import *
from warnings import warn

In [None]:
from nbdev.showdoc import *
from fastcore.test import *

# Type dispatch

> Basic single and dual parameter dispatch

## Helpers

In [None]:
#exports
def type_hints(f):
    "Type hints of `f` from `typing.get_type_hints`, or `{}` when `f` is not one of the object kinds in `typing._allowed_types`"
    # Guard first: anything outside `typing._allowed_types` simply has no hints for our purposes
    if not isinstance(f, typing._allowed_types): return {}
    return typing.get_type_hints(f)

In [None]:
#export
def anno_ret(func):
    "Return annotation of `func`; `None` when `func` is falsy, has no hints, or has no return annotation"
    hints = type_hints(func) if func else None
    return hints.get('return') if hints else None

In [None]:
#hide
# A return annotation is handed back unchanged, including `typing` generics
def f(x) -> float: return x
test_eq(anno_ret(f), float)
def f(x) -> typing.Tuple[float,float]: return x
test_eq(anno_ret(f), typing.Tuple[float,float])
# `-> None` is reported as `NoneType`, which is distinct from having no annotation at all
def f(x) -> None: return x
test_eq(anno_ret(f), NoneType)
# No return annotation, or no function, both give `None`
def f(x): return x
test_eq(anno_ret(f), None)
test_eq(anno_ret(None), None)

In [None]:
#export
def _cmp_types(a, b):
    "Three-way comparison of types: 0 if equal, 1 if `a` is a subclass of `b`, otherwise -1"
    if a==b: return 0
    return 1 if issubclass(a,b) else -1

# Sort key built from `_cmp_types`, so subclasses sort after their superclasses
cmp_instance = functools.cmp_to_key(_cmp_types)

In [None]:
# Sorting the keys with `cmp_instance` places each superclass ahead of its subclasses
td = {int:1, numbers.Number:2, numbers.Integral:3}
test_eq(sorted(td, key=cmp_instance), [numbers.Number, numbers.Integral, int])

In [None]:
#export
def _p2_anno(f):
    "First two parameter annotations of `f` (unannotated parameters are skipped), padded with `object`"
    annos = [typ for name,typ in type_hints(f).items() if name!='return']
    # Pad on the right, then trim, so the result always has exactly two entries
    return (annos + [object,object])[:2]

In [None]:
# Without annotations both slots default to `object`
def _f(a): pass
test_eq(_p2_anno(_f), (object,object))
def _f(a, b): pass
test_eq(_p2_anno(_f), (object,object))
# The return annotation is never counted; a `None` annotation shows up as `NoneType`
def _f(a:None, b)->str: pass
test_eq(_p2_anno(_f), (NoneType,object))
def _f(a:str, b)->float: pass
test_eq(_p2_anno(_f), (str,object))
def _f(a:None, b:str)->float: pass
test_eq(_p2_anno(_f), (NoneType,str))
def _f(a:int, b:int)->float: pass
test_eq(_p2_anno(_f), (int,int))
# An unannotated leading parameter (e.g. `self`) is skipped
def _f(self, a:int, b:int): pass
test_eq(_p2_anno(_f), (int,int))
def _f(a:int, b:str)->float: pass
test_eq(_p2_anno(_f), (int,str))
# Objects that `type_hints` does not handle fall back to the defaults
test_eq(_p2_anno(attrgetter('foo')), (object,object))

## TypeDispatch -

The following class is the basis that allows us to do type dispatch with type annotations. It contains a dictionary type -> functions and ensures that the proper function is called when passed an object (depending on its type).

In [None]:
#export
class _TypeDict:
    def __init__(self): self.d,self.cache = {},{}

    def _reset(self):
        self.d = {k:self.d[k] for k in sorted(self.d, key=cmp_instance, reverse=True)}
        self.cache = {}

    def add(self, t, f):
        "Add type `t` and function `f`"
        if not isinstance(t,tuple): t=tuple(L(t))
        for t_ in t: self.d[t_] = f
        self._reset()

    def all_matches(self, k):
        "Find first matching type that is a super-class of `k`"
        if k not in self.cache:
            types = [f for f in self.d if k==f or (isinstance(k,type) and issubclass(k,f))]
            self.cache[k] = [self.d[o] for o in types]
        return self.cache[k]

    def __getitem__(self, k):
        "Find first matching type that is a super-class of `k`"
        res = self.all_matches(k)
        return res[0] if len(res) else None

    def __repr__(self): return self.d.__repr__()
    def first(self): return first(self.d.values())

In [None]:
#export
class TypeDispatch:
    "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`"
    def __init__(self, funcs=(), bases=()):
        # `funcs` is a two-level lookup: 1st-arg type -> `_TypeDict` of 2nd-arg type -> function.
        # `bases` (minus any `None`) are other dispatchers consulted when nothing here matches.
        self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))
        for o in L(funcs): self.add(o)
        self.inst = None

    def add(self, f):
        "Register `f` under the types of its first two annotated parameters (see `_p2_anno`)"
        if f and getattr(f,'__defaults__',None): warn(f"Function {f.__name__} has defaults parameters. These will be ignored.")
        a0,a1 = _p2_anno(f)
        # A tuple annotation means "any of these types". Handle each first-arg type on its own so
        # that functions already registered for that type are kept, and so the types do not end up
        # sharing one second-level dict (which would leak later additions from one type to another).
        for typ in (a0 if isinstance(a0,tuple) else (a0,)):
            t = self.funcs.d.get(typ)
            if t is None:
                t = _TypeDict()
                self.funcs.add(typ, t)
            t.add(a1, f)

    def first(self):
        "First function of the first entry in `funcs`"
        return self.funcs.first().first()

    def returns(self, x):
        "Return annotation of the function that `type(x)` dispatches to"
        return anno_ret(self[type(x)])

    def returns_none(self, x):
        "`NoneType` if the function for `type(x)` is annotated as returning `None`, otherwise `None`"
        r = anno_ret(self[type(x)])
        return r if r == NoneType else None

    def _attname(self,k): return getattr(k,'__name__',str(k))
    def __repr__(self):
        # Read each second-level dict straight from `funcs.d`; going through `self.funcs[k]`
        # applies subclass matching and could pick up another type's entries.
        r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", v.__class__.__name__)}'
             for k,t in self.funcs.d.items() for l,v in t.d.items()]
        r = r + [o.__repr__() for o in self.bases]
        return '\n'.join(r)

    def __call__(self, *args, **kwargs):
        "Call the function matching the types of the first two `args`; if none matches, return `args[0]` unchanged"
        ts = L(args).map(type)[:2]
        f = self[tuple(ts)]
        if not f: return args[0]
        # When reached through an instance attribute (see `__get__`), bind the function to it
        if self.inst is not None: f = MethodType(f, self.inst)
        return f(*args, **kwargs)

    def __get__(self, inst, owner):
        "Descriptor hook: remember the instance we were accessed through so `__call__` can bind it"
        self.inst = inst
        return self

    def __getitem__(self, k):
        "Function registered for `k` (one type or a pair of types), falling back to `bases`; `None` if no match"
        k = list(L(k))  # work on a copy so the padding below never alters the caller's object
        while len(k)<2: k.append(object)
        r = self.funcs.all_matches(k[0])
        for t in r:
            o = t[k[1]]
            if o is not None: return o
        for base in self.bases:
            res = base[k]
            if res is not None: return res
        return None

In [None]:
def f_col(x:typing.Collection): return x
def f_nin(x:numbers.Integral)->int:  return x+1
def f_ni2(x:int): return x
# A tuple annotation registers the function for each of the listed types
def f_bll(x:(bool,list)): return x
def f_num(x:numbers.Number): return x
t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])

t.add(f_ni2) #Should work even if we add the same function twice.
# The most specific registered type wins: `int` -> f_ni2, other integrals -> f_nin
test_eq(t[int], f_ni2)
test_eq(t[np.int32], f_nin)
test_eq(t[str], None)
test_eq(t[float], f_num)
test_eq(t[bool], f_bll)
test_eq(t[list], f_bll)
# Functions can be added after construction; `str` now matches through `typing.Collection`
t.add(f_col)
test_eq(t[str], f_col)
test_eq(t[np.int32], f_nin)
# Calling the dispatcher runs the matched function; `returns` gives its return annotation
o = np.int32(1)
test_eq(t(o), 2)
test_eq(t.returns(o), int)
assert t.first() is not None
t

(list,object) -> f_bll
(typing.Collection,object) -> f_col
(bool,object) -> f_bll
(int,object) -> f_ni2
(Integral,object) -> f_nin
(Number,object) -> f_num
(object,object) -> NoneType

If `bases` is set to a collection of `TypeDispatch` objects, then they are searched for matching functions if no match is found in this object.

In [None]:
def f_str(x:str): return x+'1'

# `t2` only knows `f_str` itself; every other lookup falls through to its base `t`
t2 = TypeDispatch(f_str, bases=t)
test_eq(t2[int], f_ni2)
test_eq(t2[np.int32], f_nin)
test_eq(t2[float], f_num)
test_eq(t2[bool], f_bll)
# A match found in `t2` takes precedence over the base (where `str` maps to `f_col`)
test_eq(t2[str], f_str)
test_eq(t2('a'), 'a1')
test_eq(t2[np.int32], f_nin)
test_eq(t2(o), 2)
test_eq(t2.returns(o), int)

In [None]:
# Dispatch also works for methods: the unannotated `self` is skipped when reading annotations
def m_nin(self, x:(str,numbers.Integral)): return str(x)+'1'
def m_bll(self, x:bool): self.foo='a'
def m_num(self, x:numbers.Number): return x*2

t = TypeDispatch([m_nin,m_num,m_bll])
class A: f = t
a = A()
test_eq(a.f(1), '11')
test_eq(a.f(1.), 2.)
# Attribute access goes through `__get__`, which records the instance on the dispatcher
test_is(a.f.inst, a)
a.f(False)
test_eq(a.foo, 'a')
# No function matches a tuple, so the argument comes back unchanged
test_eq(a.f(()), ())

In [None]:
def m_tup(self, x:tuple): return x+(1,)
# Same as the previous cell, but with `t` as a base and a tuple handler added on top
t2 = TypeDispatch(m_tup, t)
class A2: f = t2
a2 = A2()
test_eq(a2.f(1), '11')
test_eq(a2.f(1.), 2.)
test_is(a2.f.inst, a2)
a2.f(False)
test_eq(a2.foo, 'a')
# Tuples are now handled by `m_tup` instead of being returned unchanged
test_eq(a2.f(()), (1,))

In [None]:
# Dispatch on the types of the first two arguments
def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])

# A missing second type is looked up as `object`
test_eq(t[int], f1)
test_eq(t[int,int], f1)
test_eq(t[int,float], f2)
test_eq(t[float,float], None)
test_eq(t[np.int32,float], f1)
test_eq(t(3,2.0), 5)
test_eq(t(3,2), 4)
# With no matching function the first argument is returned as-is
test_eq(t('a'), 'a')
t

(int,float) -> f2
(Integral,object) -> f1

## typedispatch Decorator

In [None]:
#export
class DispatchReg:
    "Registry holding one `TypeDispatch` per function `__qualname__`; an instance is used as a decorator"
    def __init__(self): self.d = defaultdict(TypeDispatch)
    def __call__(self, f):
        "Add `f` to the `TypeDispatch` stored under its qualified name (created on first use) and return that dispatcher"
        dispatcher = self.d[f'{f.__qualname__}']
        dispatcher.add(f)
        return dispatcher

typedispatch = DispatchReg()

In [None]:
# Each decorated definition with the same name is added to one shared `TypeDispatch`
@typedispatch
def f_td_test(x, y): return f'{x}{y}'
@typedispatch
def f_td_test(x:numbers.Integral, y): return x+1
@typedispatch
def f_td_test(x:int, y:float): return x+y
@typedispatch
def f_td_test(x:int, y:int): return x*y

test_eq(f_td_test(3,2.0), 5)
test_eq(f_td_test(3,2), 6)
test_eq(f_td_test('a','b'), 'ab')

In [None]:
# Registering a function that has default parameter values emits the warning from `TypeDispatch.add`
def outer():
    @typedispatch
    def f_td_test(x:int,y:int=10): return x*y
test_warns(outer,True)



## Casting

Now that we can dispatch on types, let's make it easier to cast objects to a different type.

In [None]:
#export
# NOTE(review): presumably tells nbdev to list `cast` in the generated module's `__all__` - confirm
_all_=['cast']

In [None]:
#export
def retain_meta(x, res, copy_meta=False):
    "Pass `x` to `res.set_meta` (forwarding `copy_meta`) when `res` has that method; always return `res`"
    if not hasattr(res,'set_meta'): return res
    res.set_meta(x, copy_meta=copy_meta)
    return res

In [None]:
#export
def default_set_meta(self, x, copy_meta=False):
    "Give `self` the `_meta` of `x` (a `copy` of it if `copy_meta`) unless `self` already has one or `x` has none"
    if hasattr(self, '_meta') or not hasattr(x, '_meta'): return self
    self._meta = copy(x._meta) if copy_meta else x._meta
    return self

In [None]:
#export
@typedispatch
def cast(x, typ):
    "cast `x` to type `typ` (may also change `x` inplace)"
    # Let `typ` pre-process the input when it defines a `_before_cast` hook
    res = typ._before_cast(x) if hasattr(typ, '_before_cast') else x
    if isinstance(res, ndarray): res = res.view(typ)
    elif hasattr(res, 'as_subclass'): res = res.as_subclass(typ)
    else:
        # First try to re-label the existing object (this is what can change `x` in place);
        # if that is refused, fall back to constructing a new `typ` from the value.
        # `Exception` rather than a bare `except`, so KeyboardInterrupt/SystemExit still propagate.
        try: res.__class__ = typ
        except Exception: res = typ(res)
    return retain_meta(x, res)

This works both for plain python classes:...

In [None]:
mk_class('_T1', 'a')
class _T2(_T1): pass

# Casting a `_T1` instance to its subclass `_T2` gives an object equal in value and type to `_T2(a=1)`
t = _T1(a=1)
t2 = cast(t, _T2)
test_eq_type(_T2(a=1), t2)

...as well as for arrays and tensors.

In [None]:
class _T1(ndarray): pass

# Arrays are cast with `view`, so the values are kept and only the type changes
t = array([1])
t2 = cast(t, _T1)
test_eq(array([1]), t2)
test_eq(_T1, type(t2))

To customize casting for other types, define a separate `cast` function with `typedispatch` for your type.

In [None]:
#export
def retain_type(new, old=None, typ=None, copy_meta=False):
    "Cast `new` to `typ`, or to the type of `old` when `old` is an instance of `type(new)`, then apply `retain_meta` from `old`"
    if new is None: return
    assert old is not None or typ is not None
    if typ is None:
        # No explicit target: only cast when `old` is an instance of `new`'s type (e.g. `old` is a
        # TensorImage and `new` a plain Tensor); otherwise hand `new` back untouched
        if not isinstance(old, type(new)): return new
        typ = type(old) if not isinstance(old,type) else old
    # Nothing to do when the target is `NoneType` or `new` is already an instance of it
    already_ok = typ==NoneType or isinstance(new, typ)
    if already_ok: return new
    casted = cast(new, typ)
    return retain_meta(old, casted, copy_meta=copy_meta)

In [None]:
class _T(tuple): pass
a = _T((1,2))
b = tuple((1,2))
# With an explicit `typ`, a plain tuple is cast to the subclass `_T`
test_eq_type(retain_type(b, typ=_T), a)

If `old` has a `_meta` attribute, its content is passed when casting `new` to the type of `old`.

In [None]:
class _A():
    # `retain_meta` calls `set_meta` when present; here it is `default_set_meta`
    set_meta = default_set_meta
    def __init__(self, t): self.t=t

class _B1(_A):
    def __init__(self, t, a=1):
        super().__init__(t)
        self._meta = {'a':a}
        
# `b` has no `_meta` of its own, so after the cast it receives the one from `x`
x = _B1(1, a=2)
b = _A(1)
test_eq(retain_type(b, old=x)._meta, {'a': 2})

In [None]:
# Shape of a `typs` spec as read by `retain_types`: the single key is the container type,
# its value lists the types of the items
a = {L: [int, tuple]}
first(a.keys())

In [None]:
#export
def retain_types(new, old=None, typs=None):
    "Recursively `retain_type` the items of `new` against the matching items of `old` (or the spec in `typs`), rebuilding each container"
    if not is_listy(new): return retain_type(new, old, typs)
    if typs is None:
        # No spec: use `old`'s container type when `old` is an instance of `new`'s type, else keep `new`'s
        container = type(old) if (old is not None and isinstance(old,type(new))) else type(new)
        item_typs = None
    elif isinstance(typs, dict):
        # Spec of the form `{container: [item specs]}`
        container = first(typs.keys())
        item_typs = typs[container]
    else: container,item_typs = typs,None
    return container(L(new, old, item_typs).map_zip(retain_types, cycled=True))

In [None]:
class T(tuple): pass

# Types taken from a matching `old` structure, level by level
t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))

# The same result from an explicit nested `typs` spec
t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))

In [None]:
#export
def explode_types(o):
    "Type of `o`; for a listy `o`, a one-key dict `{type(o): [exploded types of its items]}`, nested recursively"
    if is_listy(o): return {type(o): list(map(explode_types, o))}
    return type(o)

In [None]:
# The result has the same nested shape as the `typs` spec used with `retain_types`
test_eq(explode_types((2,T((2,T((3,4)))))), {tuple: [int, {T: [int, {T: [int,int]}]}]})

## Export -

In [None]:
#hide
# NOTE(review): presumably regenerates the library modules from the notebooks' exported cells - confirm
from nbdev.export import notebook2script
notebook2script()