In [None]:
#|default_exp utils

In [None]:
#|export
from typing import Any

from plum import dispatch
from numpy import ndarray

from fastcore.imports import *
from fastcore.foundation import *
from fastcore.utils import *

In [None]:
from __future__ import annotations
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *

# Utils

Core utility functions adapted from the `fastcore.dispatch` module.

This future-proofs the code since fastcore's dispatch module is planned for deprecation in favor of Plum.

The functions here have not been changed, except for `retain_type` which has the same functionality but now accepts the type hints as Plum provides them.

In [None]:
#|export
def get_name(o):
    "Best available name for `o`: its `__qualname__`, else its `__name__`, else the name of its class"
    # Probe the two name attributes in priority order; the first one present wins.
    for attr in ('__qualname__', '__name__'):
        if hasattr(o, attr): return getattr(o, attr)
    # Plain instances carry neither attribute, so report their class name instead.
    return o.__class__.__name__

In [None]:
#|export
def is_tuple(o):
    "True only for plain tuples; tuples exposing `_fields` (e.g. namedtuples) do not count"
    if not isinstance(o, tuple): return False
    return not hasattr(o, '_fields')

## Casting

Now that we can dispatch on types, let's make it easier to cast objects to a different type.

In [None]:
#|export
def retain_meta(x, res, as_copy=False):
    "Pass `x` to `res.set_meta` when `res` defines it, then return `res`"
    # Objects without a `set_meta` hook are handed back untouched.
    if not hasattr(res,'set_meta'): return res
    res.set_meta(x, as_copy=as_copy)
    return res
     

In [None]:
#|export
def default_set_meta(self, x, as_copy=False):
    "Give `self` the `_meta` of `x` (a `copy` of it if `as_copy`), unless `x` has none or `self` already has one"
    # Nothing to transfer, or `self` already carries its own `_meta`: leave it alone.
    if not hasattr(x, '_meta') or hasattr(self, '_meta'): return self
    self._meta = copy(x._meta) if as_copy else x._meta
    return self
     

In [None]:
#|export
@dispatch
def cast(x, typ):
    "cast `x` to type `typ` (may also change `x` inplace)"
    # Let the target type pre-process `x` when it provides a `_before_cast` hook.
    res = typ._before_cast(x) if hasattr(typ, '_before_cast') else x
    # `risinstance('ndarray', res)` is fastcore's check (presumably by class name) for arrays,
    # which are re-typed through `view`.
    if risinstance('ndarray', res): res = res.view(typ)
    # Objects exposing `as_subclass` (presumably tensors) are re-typed through that method.
    elif hasattr(res, 'as_subclass'): res = res.as_subclass(typ)
    else:
        # Try to re-class the object in place; when that is refused, build a new `typ` from it.
        try: res.__class__ = typ
        # `Exception`, not a bare `except`, so KeyboardInterrupt/SystemExit still propagate.
        except Exception: res = typ(res)
    # Carry metadata from the original `x` over to the result when the result supports it.
    return retain_meta(x, res)

This works both for plain Python classes...



In [None]:

mk_class('_T1', 'a')   # mk_class is a fastcore utility that constructs a class.
class _T2(_T1): pass

t = _T1(a=1)
t2 = cast(t, _T2)        
assert t2 is t            # t2 refers to the same object as t
assert isinstance(t, _T2) # t also changed in-place
assert isinstance(t2, _T2)

# `t2` should match a freshly built `_T2` in both value and type
test_eq_type(_T2(a=1), t2) 


...as well as for arrays and tensors.


In [None]:
# Minimal ndarray subclass to cast into (this rebinds the `_T1` name used in the previous example)
class _T1(ndarray): pass

t = array([1])
t2 = cast(t, _T1)
test_eq(array([1]), t2)   # the values are unchanged by the cast
test_eq(_T1, type(t2))    # but the result is now a `_T1`

If `old` has a `_meta` attribute, its content is passed along when casting `new` to the type of `old`. In the below example, only the attribute `a`, but not `other_attr`, is kept, because `other_attr` is not in `_meta`:

## Retain type


In [None]:
#|export
def retain_type(new, old, ret_type,as_copy=False):
    if new is None: return new
    if ret_type is NoneType: return new
    if ret_type is Any:
        if not isinstance(old, type(new)): return new
        ret_type = old if isinstance(old,type) else type(old)
    if ret_type is NoneType or isinstance(new,ret_type): return new
    return retain_meta(old, cast(new, ret_type), as_copy=as_copy)   

In [None]:
class FS(float):
    "A `float` subclass whose `str` and `repr` both render as `FS(<value>)`"
    def __repr__(self):
        return f'FS({float(self)})'

    def __str__(self):
        return f'FS({float(self)})'
    

### Return type conversion

We try to convert `new` to the return type if one is given.

In [None]:
# `FS(1.) == 1.`, so a plain `test_eq` would pass even without a cast; also check the resulting type
test_eq_type(retain_type(1., 2., FS), FS(1.))

Even if it won't work, we'll let the exception be raised:

In [None]:
# Raise error if return type is not compatible with new
try: retain_type("a", 2., FS)
except ValueError as e: print(f"Caught expected {e=}")
else: raise AssertionError("retain_type should have raised ValueError")

# `test_fail` expects a callable: `test_fail(retain_type(...))` evaluates the call eagerly,
# so the exception is raised before `test_fail` ever runs. Wrapping it in a lambda defers the call.
test_fail(lambda: retain_type("a", 2., FS), contains="could not convert string to float")

Caught expected e=ValueError("could not convert string to float: 'a'")


### Old type conversion

If the return type is `Any` then new looks at old for conversion guidance.

In [None]:
# `old` is an `FS`, so the plain float must come back as an `FS`; check the type, not just the value
test_eq_type(retain_type(1., FS(2.), Any), FS(1.))

But if the type of `old` isn't a subclass of the type of `new`, keep `new`:

In [None]:
# `old` (a plain float) is not an instance of `FS`, so `new` must stay an `FS`; check the type too
test_eq_type(retain_type(FS(1.), 2.0, Any), FS(1.))
# a float `old` gives no guidance for a `str`, so the string is returned unchanged
test_eq(retain_type("a", 2.0, Any), "a")

No casting is needed if `new` already has the type of `old`.
In that case we return the original object.

In [None]:
x = FS(1.)
# `x` already has the type of `old`, so the very same object (not a copy) must come back
test_is(retain_type(x, FS(2.), Any), x) 

### Edge cases with None

We don't convert at all if `None` is the return type annotation:

In [None]:
# `FS(1.) == 1.`, so only a type-aware check proves that no cast to `FS` happened
test_eq_type(retain_type(1., FS(2.), NoneType), 1.)

None stays None:

In [None]:
# a `None` result is returned as-is, whatever `old` and the return type are
test_eq(retain_type(None,FS(2.), Any), None)  

If old was None then we just return new.

In [None]:
# `None` is not an instance of `FS`, so `new` is returned untouched; check that it is still an `FS`
test_eq_type(retain_type(FS(1.), None, Any), FS(1.))

## Export -

In [None]:
#|hide
#|eval: false
# Run nbdev's export step (presumably writes the `#|export` cells to the module named by `#|default_exp`)
from nbdev import nbdev_export
nbdev_export()