In [None]:
#|default_exp core

In [None]:
#|export
from typing import Any

from fastcore.imports import *
from fastcore.foundation import *
from fastcore.utils import *
from fastcore.dispatch import retain_meta, cast  # move to fasttransform

from plum.function import Function
from plum import NotFoundLookupError
from fastcore.dispatch import retain_type

In [None]:
from __future__ import annotations
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *

from plum import AmbiguousLookupError

# Transforms

> Definition of `Transform` and `Pipeline`

The classes here provide functionality for creating a composition of *partially reversible functions*. By "partially reversible" we mean that a transform can be `decode`d, creating a form suitable for display. This is not necessarily identical to the original form (e.g. a transform that changes a byte tensor to a float tensor does not recreate a byte tensor when decoded, since that may lose precision, and a float tensor can be displayed already).

Classes are also provided for composing transforms and mapping them over collections. `Pipeline` is a transform which composes several `Transform`s, knowing how to decode them or show an encoded item.

## Utils

In [None]:
#|export
def retain_type(new, old, ret_type,as_copy=False):
    if new is None: return new
    if ret_type is NoneType: return new
    if ret_type is Any:
        if not isinstance(old, type(new)): return new
        ret_type = old if isinstance(old,type) else type(old)
    if ret_type is NoneType or isinstance(new,ret_type): return new
    # fastcore.retain_meta and cast are used because
    # the retain_meta logic is embedded in fastai (and torch itself?)
    # see 00_torch_core set_meta functions.
    return retain_meta(old, cast(new, ret_type), as_copy=as_copy)
    

In [None]:
# `float` subclass whose str/repr name the type, so a lost or restored subclass is visible in test output
class FS(float):
    def __str__(self): return f'FS({float(self)})'
    def __repr__(self): return f'FS({float(self)})'
    
# `None` stays `None`: there is nothing to cast
test_eq(retain_type(None,FS(2.), Any), None)  

In [None]:
# Don't convert when the return type annotation is `None` (ret_type is NoneType): `new` is returned as-is
test_eq(retain_type(1., FS(2.), NoneType), 1.)  

In [None]:
# Use the return type annotation if given: the plain float `1.` is cast to `FS`
test_eq(retain_type(1., 2., FS), FS(1.))

In [None]:
# Raise error if return type is not compatible with new
try: retain_type("a", 2., FS)
except ValueError as e: print(f"Caught expected {e=}")

# `test_fail` expects a callable: `test_fail(retain_type(...))` evaluates the call (and raises)
# before `test_fail` even runs, which is why it could not catch the error. Wrap it in a lambda.
test_fail(lambda: retain_type("a", 2., FS), contains="could not convert")

Caught expected e=ValueError("could not convert string to float: 'a'")


In [None]:
# When ret_type is Any: try to convert to old's type
#   If old isn't an instance of new's type (old's type is not a subclass of new's), keep new
test_eq(retain_type(FS(1.), 2.0, Any), FS(1.))
test_eq(retain_type("a", 2.0, Any), "a")
#   No cast needed if new is already of old's type
test_eq(retain_type(FS(1.), FS(2.), Any), FS(1.)) 
#   Return new if old was None
test_eq(retain_type(FS(1.), None, Any), FS(1.))
#   Convert: old is an FS, a subclass of new's type (float), so new is cast to FS
test_eq(retain_type(1., FS(2.), Any), FS(1.))

## Transform -

In [None]:
#|export
_tfm_methods = 'encodes','decodes','setups'

def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)

class _TfmDict(dict):
    "Class-body namespace that gathers every definition of a `_tfm_methods` name into one plum `Function`"
    def __setitem__(self, k, v):
        # Ordinary class attributes are stored unchanged
        if not _is_tfm_method(k, v): return super().__setitem__(k,v)
        existing = self.get(k)
        # A later definition of the same name adds an overload to the existing dispatcher
        if isinstance(existing, Function): existing.dispatch(v)
        # The first definition (or one replacing a non-dispatcher value) creates the
        # dispatcher and registers `v` exactly once
        else: super().__setitem__(k, Function(v).dispatch(v))
     

In [None]:
#|export
class _TfmMeta(type):
    "Metaclass that gives every class body a `_TfmDict` namespace"
    @classmethod
    def __prepare__(metacls, name, bases):
        # Each assignment in the class body goes through `_TfmDict.__setitem__`,
        # which gathers repeated transform-method definitions into one dispatcher
        namespace = _TfmDict()
        return namespace

### The main `Transform` features:

- **Type dispatch** - Type annotations are used to determine if a transform should be applied to the given argument. It also gives an option to provide several implementations and it chooses the one to run based on the type. This is useful for example when running both independent and dependent variables through the pipeline where some transforms only make sense for one and not the other. Another use case is designing a transform that handles different data formats. Note that if a transform takes multiple arguments only the type of the first one is used for dispatch. 
- **Handling of tuples** - When a tuple (or a subclass of tuple) of data is passed to a transform it will get applied to each element separately. You can opt out of this behavior by passing a list or an `L`, as only tuples get this specific behavior. An alternative is to use `ItemTransform` defined below, which will always take the input as a whole.
- **Reversibility** - A transform can be made reversible by implementing the <code>decodes</code> method. This is mainly used to turn something like a category which is encoded as a number back into a label understandable by humans for showing purposes. Like the regular call method, the `decode` method that is used to decode will be applied over each element of a tuple separately.
- **Type propagation** - Whenever possible a transform tries to return data of the same type it received. Mainly used to maintain semantics of things like `ArrayImage` which is a thin wrapper of pytorch's `Tensor`. You can opt out of this behavior by adding `->None` return type annotation.
- **Preprocessing** - The `setup` method can be used to perform any one-time calculations to be later used by the transform, for example generating a vocabulary to encode categorical data.
- **Filtering based on the dataset type** - By setting the `split_idx` flag you can make the transform be used only in a specific `DataSource` subset like in training, but not validation.
- **Ordering** - You can set the `order` attribute which the `Pipeline` uses when it needs to merge two lists of transforms.
- **Appending new behavior with decorators** - You can easily extend an existing `Transform` by creating <code>encodes</code> or <code>decodes</code> methods for new data types. You can put those new methods outside the original transform definition and decorate them with the class you wish them patched into. This can be used by the fastai library users to add their own behavior, or multiple modules contributing to the same transform.

In [None]:
#|export
def _has_self_arg(f) -> bool:
    "Check if function `f` has 'self' as first positional parameter"
    try:
        code = f.__code__
        # `co_varnames` lists the parameters first and then the other local variables, so only
        # its first `co_argcount` entries are positional parameters (a local named `self` in a
        # zero-argument function must not count).
        params = code.co_varnames[:code.co_argcount]
    # AttributeError for objects without a code object (non-functions, builtins, None, ...)
    except AttributeError: return False
    return bool(params) and params[0] == 'self'

In [None]:
#|export
def _subclass_decorator(cls, f):
    "Register `f` as a dispatch method of `cls.<f.__name__>` (without touching parent classes) and return `cls`"
    nm = f.__name__
    # NOTE: setting `f.__qualname__ = f"{cls.__name__}.{nm}"` may be needed for plum's
    # owner-based method resolution; it is not used here.
    if nm in vars(cls):
        # `cls` already owns a dispatcher for `nm` (from its body or an earlier decorator): extend it
        getattr(cls, nm).dispatch(f)
        return cls
    fn = Function(f)
    inherited = getattr(cls, nm, None)
    if isinstance(inherited, Function):
        # Copy the inherited methods into a fresh dispatcher owned by `cls`, so extending a
        # subclass does not also register `f` on the parent class's dispatcher.
        for m in inherited.methods: fn.register(m.implementation, m.signature)
    # Registered after the inherited methods, so an identical signature overrides the parent's
    setattr(cls, nm, fn.dispatch(f))
    return cls

In [None]:
#|export
class Transform(metaclass=_TfmMeta):
    "Delegates (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) if `split_idx` matches"
    # NOTE(review): no `split_idx` filtering is implemented in this class yet, and `setup` is a stub.

    def __init_subclass__(cls, **kwargs):
        # Cooperate with other bases' `__init_subclass__` and forward any class keyword arguments
        super().__init_subclass__(**kwargs)
        # convert _tfm_methods that aren't plum.Functions yet
        for nm in _tfm_methods:
            if hasattr(cls, nm) and not isinstance(getattr(cls, nm), Function):
                f = getattr(cls, nm)
                setattr(cls, nm, Function(f).dispatch(f))

        def _bind(self):
            # Bind each class-level dispatcher to the instance, so calling `self.encodes(x)`
            # passes `self` as the first argument of the registered methods
            for nm in _tfm_methods:
                if hasattr(self.__class__, nm):
                    setattr(self, nm, MethodType(getattr(self.__class__, nm), self))

        # Keep a user-defined `__init__` instead of silently replacing it: either the one in this
        # class body, or the one wrapped by a parent Transform subclass (stored as `_tfm_user_init`).
        user_init = cls.__dict__.get('__init__')
        if user_init is None: user_init = getattr(super(cls, cls).__init__, '_tfm_user_init', None)

        if user_init is None:
            # No user initializer anywhere: instances take no arguments, they only get their methods bound
            def __init__(self): _bind(self)
        else:
            def __init__(self, *args, **kwargs):
                # Bind first so the user's initializer can already call `self(...)`/`self.decode(...)`
                _bind(self)
                user_init(self, *args, **kwargs)
            __init__._tfm_user_init = user_init
        cls.__init__ = __init__

    def __new__(cls, enc=None, dec=None, *args, **kwargs):
        # subclass of Transform decorator usage: `@SubClass` on a `def encodes(self, ...)`-style function
        # patches it into the subclass and returns the class. `cls` is always Transform or a subclass
        # here, so the base class is excluded explicitly -- otherwise `@Transform` on such a function
        # would patch the base class shared by every transform.
        if (
            cls is not Transform and
            _has_self_arg(enc) and
            enc.__name__ in _tfm_methods and
            dec is None
        ): return _subclass_decorator(cls, enc)
        # default usecase; extra arguments are left for `__init__` (e.g. a user-defined one)
        return super().__new__(cls)

    def __init__(self,enc=None,dec=None):
        # `enc`/`dec` are wrapped with `L`, so a single function or a list of them is accepted;
        # each becomes an overload of the instance-level dispatcher
        enc = L(enc)
        if enc: self.encodes = Function(enc[0])
        for e in enc: self.encodes.dispatch(e)

        dec = L(dec)
        if dec: self.decodes = Function(dec[0])
        for d in dec: self.decodes.dispatch(d)

    def __call__(self,*args,**kwargs):
        return self._do_call('encodes',*args,**kwargs)
    
    def decode(self, *args, **kwargs):
        return self._do_call('decodes',*args, **kwargs)
    
    def setup(self, *args, **kwargs):
        # NOTE(review): not implemented yet, although `add_docs` below describes it as delegating to `setups`
        raise NotImplementedError()
        
    def _do_call(self, nm, *args, **kwargs): 
        # No method `nm` at all: pass the first argument through unchanged
        if not hasattr(self, nm): return args[0]
        # Subclass methods are declared with `self`; a plain Transform wraps bare functions
        f_args = args if type(self) is Transform else (self,)+args
        try:
            method, ret_type = getattr(self,nm)._resolve_method_with_cache(f_args)
        except NotFoundLookupError: 
            # No overload matches these argument types: pass the first argument through
            return args[0]
        res = method(*f_args,**kwargs)
        # Cast the result according to the return annotation, or back to the input's type
        return retain_type(res, args[0], ret_type)
    
add_docs(Transform, decode="Delegate to decodes to undo transform", setup="Delegate to setups to set up transform")

###  The purpose of transforms: encodes, decodes

Transforms help with transforming data.
They can always encode it, and optionally decode it.

In [None]:
# `dec` undoes `enc`: encoding doubles, decoding halves
def enc(x): return x*2
def dec(x): return x/2

f = Transform(enc=enc, dec=dec)

In [None]:
# Encoding doubles the input
f(5.0)

10.0

In [None]:
# Decoding the encoded value recovers the original
f.decode(f(5.0))

5.0

### Defining a `Transform`

There are a few ways to create a transform with different ratios of simplicity to flexibility.
- **Extending the `Transform` class** - Use inheritance to implement the methods you want.
- **Passing methods to the constructor** - Instantiate the `Transform` class and pass your functions as `enc` and `dec` arguments.
- **@Transform decorator** - Turn any function into a `Transform` by just adding a decorator - very straightforward if all you need is a single <code>encodes</code> implementation.
- **Passing a function to fastai APIs** - Same as above, but when passing a function to other transform aware classes like `Pipeline` or `TfmdDS` you don't even need a decorator. Your function will get converted to a `Transform` automatically.

#### Define with lambda function

A simple way to create a `Transform` is to pass a function to the constructor. In the example below, we pass an anonymous function that formats its input into a string:

In [None]:
# A plain `Transform` wraps the function as its `encodes`
f = Transform(lambda o: f"f OBJ {o=}!")

If you call this transform, it will apply the transformation:

In [None]:
# Calling the transform applies the wrapped function
test_eq_type(f(2), "f OBJ o=2!")

#### Define with subclass decorator

Another way to define a Transform is to extend the `Transform` class:

In [None]:
# An empty subclass: until an `encodes` is added, calling it returns the input unchanged
class A(Transform): pass

However, to enable your transform to do something, you have to define an <code>encodes</code> method.  Note that we can use the class name as a decorator to add this method to the original class.

In [None]:
# `@A` registers this function as an `encodes` overload of `A` for `int` inputs.
# NOTE: the decorator returns the class itself, so the module-level name `encodes` is rebound to `A`.
@A
def encodes(self, x:int): return f"A INT {x=}!"

In [None]:
# A second overload on the same class, dispatched for `str` inputs
@A
def encodes(self, x:str): return f"A STR {x=}!"

In [None]:
# `decodes` is extended the same way; without an annotation it accepts any input
@A
def decodes(self,x): return x*2

In [None]:
# Dispatch picks the overload matching the argument's type
a1 = A()
test_eq(a1(1), "A INT x=1!")
test_eq(a1('a'), "A STR x='a'!")

In [None]:
# `decodes` is bound to the instance, so it can be called directly
test_eq(a1.decodes(2), 4)

#### Define with subclass method(s)

You can define multiple encodes methods when you subclass from Transform, and they'll be picked up automatically.

In [None]:
# Both `encodes` definitions are collected into one dispatcher by the class-body namespace,
# so the second definition adds an overload instead of replacing the first
class C(Transform):
    def encodes(self, x): return f'C OBJ {x=}!'
    def encodes(self, x:int): return f'C INT {x=}!'

You can still extend your subclass by using the decorator.

In [None]:
# Adds a `float` overload to the methods defined in `C`'s body
@C
def encodes(self, x:float): return f'C FLOAT {x=}!'

In [None]:
# The most specific overload wins; the unannotated `x` acts as the fallback
c = C()
test_eq(c.encodes(0), 'C INT x=0!')
test_eq(c.encodes("a"), "C OBJ x='a'!")
test_eq(c.encodes(0.0), 'C FLOAT x=0.0!')

#### Define with class decorator

You can define a Transform also by using the `@Transform` decorator directly.

In [None]:
# `@Transform` on a function without `self` builds a `Transform` instance named `g`
@Transform
def g(x:str): return f"g OBJ {x=}!"
test_eq(g("a"), "g OBJ x='a'!")

However, in that case it is not extensible: decorating another function with `@Transform` creates a new transform, and the previous implementation is lost:

In [None]:
# Redefining `g` creates a brand-new Transform that only knows the `int` overload
@Transform
def g(x:int): return f"g INT {x=}!"

test_eq(g(5), "g INT x=5!")
test_eq(g('a'), 'a')  # <- no `str` overload any more: the input is returned unchanged
test_eq(len(g.encodes.methods), 1)

#### Define with classmethod

In [None]:
# A bound classmethod can be passed as `enc` like any other callable
class B:
    @classmethod
    def create(cls, x:int): return x+1
test_eq(Transform(B.create)(1), 2)

### Important attributes of Transform

#### NotFound lookups return the input

`fastcore.transform.Transform` has a rule that a transform will return its first argument if no method fits the types of the arguments it is called with. By default, Plum would raise a `NotFoundLookupError`; we catch this error and return the first argument to stay consistent with the old implementation, as it's a useful default for transforms in data pipelines.

In [None]:
# No `encodes` defined at all: the first argument (arg[0]) is returned unchanged
class _Tst(Transform): pass 
f3 = _Tst() 
test_eq(f3(2), 2)

In [None]:
# `encodes` exists but has no overload for `int`: the first argument (arg[0]) is returned unchanged
class _Tst(Transform):
    def encodes(self, x:str): return "str!"
f3 = _Tst() 
test_eq(f3(2), 2)

#### Ambiguous vs NotFound lookups

A difference with `fastcore.transform.Transform` is that this version is stricter about ambiguous lookups.

That's because Plum has a better underlying system for allocating the inputs to the right function.

In [None]:
# `str` matches both overloads below and neither is more specific, so plum raises AmbiguousLookupError
class E(Transform): pass

@E
def encodes(self, x:int|str): return f"E INT|STR {x=}!"

@E
def encodes(self, x:float|str): return f"E FLOAT|STR {x=}!"

e = E()

test_eq(e(5), "E INT|STR x=5!")
test_eq(e(.5), "E FLOAT|STR x=0.5!")
test_eq(e([1]), [1])  # no matching overload (NotFoundLookupError): the input is returned unchanged

try: e("hi there")  # could be either encodes function
except AmbiguousLookupError: print("Caught an expected AmbiguousLookupError")

Caught an expected AmbiguousLookupError


#### Multi-level inheritance is supported

In [None]:
# inherited transforms: each subclass's `encodes` covers the overloads of its ancestors plus its own
class F(Transform): pass
@F
def encodes(self, x:int): return "INT"

class G(F): pass
@G
def encodes(self, x:str): return "STR"

g = G()

test_eq(len(g.encodes.methods), 2)  # g has two encodes methods: F's `int` and G's `str`

class H(G): pass

@H
def encodes(self, x:list): return "LIST"

h = H()

test_eq(len(h.encodes.methods),3)  # `int`, `str` and H's own `list`

#### Type inheritance for input types is supported

In [None]:
# Type inheritance: `MyClass` (an `int` subclass) is matched by the `MyClass|float` overload.
# The float result is not cast back, because a MyClass input is not an instance of `float`.
class MyClass(int): pass

class H(Transform):
    def encodes(self, x:MyClass|float): return x/2

@H
def encodes(self, x:str|list): return str(x)+'_1'

h = H()

test_eq(h(MyClass(5)), 2.5)

In [None]:
# One overload from H's class body plus one added by the `@H` decorator
assert len(h.encodes.methods) == 2

## Return type casting

Without any intervention it is easy for operations to change types in Python. For example, `FloatSubclass` (defined below) becomes a `float` after performing multiplication:

In [None]:
# NOTE: `float` defines no `__str__` of its own (CPython), so `super().__str__()` falls back to
# `object.__str__`, which calls `__repr__`: str(FloatSubclass(1.)) is 'FloatSubclass(1.0)'
# (the `->str` example further down relies on this).
class FloatSubclass(float):
    def __repr__(self): return f'FloatSubclass({super().__repr__()})'
    def __str__(self): return f'{super().__str__()}'
    

In [None]:
# Plain arithmetic on a float subclass returns a plain `float`
test_eq_type(FloatSubclass(3.0) * 2, 6.0)

This behavior is often not desirable when performing transformations on data.  Therefore, `Transform` will attempt to cast the output to be of the same type as the input by default.  In the below example, the output will be cast to a `FloatSubclass` type to match the type of the input:

### Without type annotations

In [None]:
# No return annotation: the `float` result is cast back to the input's type (FloatSubclass)
@Transform
def f(x): return x*2

test_eq_type(f(FloatSubclass(3.0)), FloatSubclass(6.0))

We can optionally turn off casting by annotating the transform function with a return type of None:

### Return type None

In [None]:
# `-> None` disables the cast back to the input type
@Transform
def f(x)-> None: return x*2 # Same transform as above, but with a -> None annotation

test_eq_type(f(FloatSubclass(3.0)), 6.0)  # Casting is turned off because of -> None annotation
 

However, Transform will only cast the output back to the input type when the input's type is a subclass of the output's type. In the example below, the input is of type `Float` (a `float` subclass), which is not a subclass of the output type `str`. Therefore, the output doesn't get cast back to `Float` and stays a `str`:

In [None]:
# NOTE(review): `Float` is presumably fastcore's `float` subclass (via the star imports), not
# `FloatSubclass`. A `Float` input is not an instance of `str`, so the `str` result is kept as-is.
@Transform
def f(x): return str(x)
    
test_eq_type(f(Float(2.)), '2.0')

Transform will attempt to convert the function output to the return type annotation.

### Specific return types

If a return type annotation is given, Transform will convert it to that type:

In [None]:
# The annotated return type is applied even though the input is a plain float
@Transform
def f(x)->FloatSubclass: return float(x)

# Output is converted to FloatSubclass: the returned float is not already one, so it's cast
test_eq(f(1.), FloatSubclass(1.))

If the function returns a subclass of the annotated return type, that more specific type will be preserved since it's already compatible with the annotation:

In [None]:
# A result that already is an instance of the annotated type is returned unchanged
@Transform
def f(x)->float: return FloatSubclass(x)

# FloatSubclass output is kept because it is more specific than float
test_eq(f(1.), FloatSubclass(1.))

When return types are given, the conversion will even happen if the output type is not a subclass of the return type annotation:

In [None]:
# The FloatSubclass result is converted to the annotated `str`; its string form shows the repr
@Transform
def f(x)->str: return FloatSubclass(x)

test_eq(f(1.), "FloatSubclass(1.0)")

And here we get an expected error because it's not possible to match the explicit return type:

In [None]:
# 'foo' cannot be converted to the annotated `int`, so the cast raises (a ValueError, as printed below)
@Transform
def f(x)->int: return str(x)

try: f("foo")
except Exception as e: print(f"Caught Exception: {e=}")

Caught Exception: e=ValueError("invalid literal for int() with base 10: 'foo'")


### Type annotation with Decode

Just like encodes, the decodes method will cast outputs to match the input type in the same way. In the below example, the output of decodes remains of type MySubclass:

In [None]:
# `decode` casts like `__call__`: `dec` returns a plain int, which is cast back to the input's type
class MySubclass(int): pass

def enc(x): return MySubclass(x+1)
def dec(x): return x-1

f = Transform(enc,dec)
t = f(1) # t is of type MySubclass
test_eq_type(f.decode(t), MySubclass(1)) # the output of decode is cast to MySubclass to match the input type.
 