# Transform with Plum

**⚠️ Context for Jeremy:**

My previous implementation of fastcore.Transform was done by reimplementing
`fastcore.typedispatch.TypeDispatch` with an implementation based on Plum.

In this notebook I re-build Transform from scratch directly using Plum.
The benefit will be that fastcore/fasttransform will be even simpler.
Users will know to use plum directly for typedispatch and see Transform as a tool that uses it.

I start by focusing on the `encodes` method.
Right now I have two implementations:

- A simpler implementation, which runs into complications with global state if classes get redefined.
- A more complicated implementation using metaclasses, which addresses those issues.

I think now's a good time to get some input again.
Do you think I'm on the right track? If so, I'll continue by adding more of the features mentioned in Transform, i.e.:

- decodes and setups support
- implement return type casting (see note at the bottom of this notebook about that)

In [None]:
from types import MethodType

from plum import NotFoundLookupError, AmbiguousLookupError
from plum.dispatcher import dispatch, Dispatcher
from plum.function import Function

from fastcore.test import *

In [None]:
# This should work in plum
from plum.dispatcher import dispatch

class MyClass(int): pass

@dispatch
def plum_func(x:MyClass|float): return x/2

@dispatch
def plum_func(x:str|list): return str(x)+'_1'

assert plum_func(MyClass(5)) == 2.5

# Transforms

> Definition of `Transform` and `Pipeline`

The classes here provide functionality for creating a composition of *partially reversible functions*. By "partially reversible" we mean that a transform can be `decode`d, creating a form suitable for display. This is not necessarily identical to the original form (e.g. a transform that changes a byte tensor to a float tensor does not recreate a byte tensor when decoded, since that may lose precision, and a float tensor can be displayed already).

Classes are also provided for composing transforms and mapping them over collections. `Pipeline` is a transform which composes several `Transform`s, knowing how to decode them or show an encoded item.
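
As a rough illustration of "partially reversible" composition, here is a minimal stdlib-only sketch. The classes `ToFloat`, `Normalize` and `MiniPipeline` are hypothetical stand-ins, not the fastcore API: encoding runs the transforms in order, decoding runs them in reverse, and a transform with no meaningful inverse simply passes its input through.

```python
# Minimal sketch with hypothetical classes (not the fastcore implementation).
class ToFloat:
    def encode(self, x): return float(x)
    def decode(self, x): return x  # a float is already displayable, so no inverse

class Normalize:
    def __init__(self, mean, std): self.mean, self.std = mean, std
    def encode(self, x): return (x - self.mean) / self.std
    def decode(self, x): return x * self.std + self.mean

class MiniPipeline:
    def __init__(self, *tfms): self.tfms = tfms
    def encode(self, x):
        for t in self.tfms: x = t.encode(x)            # apply in order
        return x
    def decode(self, x):
        for t in reversed(self.tfms): x = t.decode(x)  # undo in reverse order
        return x

pipe = MiniPipeline(ToFloat(), Normalize(mean=1.0, std=2.0))
encoded = pipe.encode(3)        # int goes in
decoded = pipe.decode(encoded)  # a float comes back, not the original int
```

The decoded value equals the input numerically but has a different type, which is the sense in which the round trip is only partial.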

## Transform -

### The main `Transform` features:

- **Type dispatch** - Type annotations are used to determine if a transform should be applied to the given argument. It also gives an option to provide several implementations and it chooses the one to run based on the type. This is useful, for example, when running both independent and dependent variables through the pipeline where some transforms only make sense for one and not the other. Another use case is designing a transform that handles different data formats. Note that if a transform takes multiple arguments, only the type of the first one is used for dispatch.
- **Handling of tuples** - When a tuple (or a subclass of tuple) of data is passed to a transform it will get applied to each element separately. You can opt out of this behavior by passing a list or an `L`, as only tuples get this specific behavior. An alternative is to use `ItemTransform` defined below, which will always take the input as a whole.
- **Reversibility** - A transform can be made reversible by implementing the <code>decodes</code> method. This is mainly used to turn something like a category which is encoded as a number back into a label understandable by humans for showing purposes. Like the regular call method, the `decode` method that is used to decode will be applied over each element of a tuple separately.
- **Type propagation** - Whenever possible a transform tries to return data of the same type it received. Mainly used to maintain semantics of things like `ArrayImage` which is a thin wrapper of PyTorch's `Tensor`. You can opt out of this behavior by adding `->None` return type annotation.
- **Preprocessing** - The `setup` method can be used to perform any one-time calculations to be later used by the transform, for example generating a vocabulary to encode categorical data.
- **Filtering based on the dataset type** - By setting the `split_idx` flag you can make the transform be used only in a specific `DataSource` subset like in training, but not validation.
- **Ordering** - You can set the `order` attribute which the `Pipeline` uses when it needs to merge two lists of transforms.
- **Appending new behavior with decorators** - You can easily extend an existing `Transform` by creating <code>encodes</code> or <code>decodes</code> methods for new data types. You can put those new methods outside the original transform definition and decorate them with the class you wish them patched into. This can be used by the fastai library users to add their own behavior, or multiple modules contributing to the same transform.
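
The tuple-handling rule from the list above can be sketched in a few lines. This is a simplified stdlib-only illustration, and `apply_per_item` is a hypothetical helper rather than a fastcore function; it also ignores type retention for tuple subclasses.

```python
def apply_per_item(fn, x):
    """Apply `fn` to each element of a tuple; treat anything else as a single item."""
    if isinstance(x, tuple):
        return tuple(apply_per_item(fn, o) for o in x)
    return fn(x)

double = lambda o: o * 2

as_tuple = apply_per_item(double, (1, 2))  # applied element-wise
as_list = apply_per_item(double, [1, 2])   # the list is handled as a whole
```

Passing a list opts out of the per-element behavior, so `double` sees the whole list and repeats it instead of doubling each number.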

### Defining a `Transform`

There are a few ways to create a transform with different ratios of simplicity to flexibility.
- **Extending the `Transform` class** - Use inheritance to implement the methods you want.
- **Passing methods to the constructor** - Instantiate the `Transform` class and pass your functions as `enc` and `dec` arguments. 
- **@Transform decorator** - Turn any function into a `Transform` by just adding a decorator - very straightforward if all you need is a single <code>encodes</code> implementation.
- **Passing a function to fastai APIs** - Same as above, but when passing a function to other transform aware classes like `Pipeline` or `TfmdDS` you don't even need a decorator. Your function will get converted to a `Transform` automatically.

**⚠️ Question to Jeremy:** I haven't seen the second type of instantiation, specifically with multiple (`enc=`, `dec=`) functions, used in the docs or in fastai. Should it be kept?

## Simpler implementation

Here's an implementation which does not rely on `fastcore.dispatch`.
I show some tests and edge cases below.

In [None]:
class Transform:
    _instances = {}
    def __init_subclass__(cls):
        if hasattr(cls,'encodes') and not isinstance(cls.encodes, Function):
            cls.encodes = dispatch(cls.encodes)
    
    @staticmethod
    def _new_transform(f):
        fname = f.__name__
        inst = Transform._instances.get(fname)
        if not inst:
            inst = Transform.__new__(Transform)
            inst.encodes = Function(f)
            if fname !="<lambda>":
                Transform._instances[fname] = inst
        inst.encodes = inst.encodes.dispatch(f)
        return inst
    
    @staticmethod
    def _new_transform_subclass(cls, f):
        # Update qualname to make it appear as a class method
        # Plum needs this to register it correctly
        f.__qualname__ = f"{cls.__name__}.{f.__name__}"
        if not hasattr(cls, "encodes"):
            cls.encodes = f
            cls.encodes = dispatch(cls.encodes)
        else:
            cls.encodes = cls.encodes.dispatch(f)
        return cls

    def __new__(cls, f=None, *args, **kwargs):
        if callable(f) and not isinstance(f, type) and len(args) == len(kwargs) == 0:
            if cls is Transform: return cls._new_transform(f)
            else: return cls._new_transform_subclass(cls, f)
        return super().__new__(cls)        
    
    def __call__(self, *args, **kwargs): 
        if not hasattr(self, "encodes"): return args[0]
        try:
            return self.encodes(*args, **kwargs) 
        except NotFoundLookupError: 
            return args[0]
    
    

A simple way to create a `Transform` is to pass a function to the constructor. In the example below, we pass an anonymous function that formats its argument into a string:

### Lambda usage

In [None]:
f = Transform(lambda o: f"f OBJ {o=}!")

If you call this transform, it will apply the transformation:

In [None]:
test_eq_type(f(2), "f OBJ o=2!")

All dispatching information is stored in the `.encodes` attribute, which is also where it is used:

In [None]:
f.encodes

<multiple-dispatch function <lambda> (with 1 registered and 0 pending method(s))>

Another way to define a Transform is to extend the `Transform` class:

### Subclass decorator usage

In [None]:
class A(Transform): pass

However, to enable your transform to do something, you have to define an <code>encodes</code> method.  Note that we can use the class name as a decorator to add this method to the original class.

In [None]:
@A
def encodes(self, x:int): return f"A INT {x=}!"

In [None]:
@A
def encodes(self, x:str): return f"A STR {x=}!"

In [None]:
a1 = A()
test_eq(a1(1), "A INT x=1!")
test_eq(a1('a'), "A STR x='a'!")

Note how the dispatch is stored globally in a class namespace

In [None]:
len(a1.encodes.methods)

2

### Class decorator usage

In [None]:
@Transform
def g(x:str): return f"g OBJ {x=}!"


@Transform
def g(x:int): return f"g INT {x=}!"

test_eq(g("a"), "g OBJ x='a'!")
test_eq(g(5), "g INT x=5!")

Note how the instances created directly from `Transform` are tracked in `Transform._instances`, so we know where to add a new method if the `@Transform` decorator is called again with the same function name.

In [None]:
Transform._instances

{'g': <__main__.Transform>}

**⚠️ Question for Jeremy:** 

I'm not tracking `Transform(<lambda>)` functions because re-use there would probably be unexpected by the user. Do you agree?

### Classmethod usage

In [None]:
class B:
    @classmethod
    def create(cls, x:int): return x+1
test_eq(Transform(B.create)(1), 2)

### Extending method defined in the class

In [None]:
class C(Transform):
    def encodes(self, x): return 'obj'

@C
def encodes(self, x:int): return 'int'

c = C()
test_eq(c.encodes(0), 'int')
test_eq(c.encodes(0.0), 'obj')

### Initiating with multiple methods

**⚠️ Question for Jeremy:** Do you agree with the trade-off below?

In fastcore.Transform you could do this:

```
class D(Transform):
    def encodes(self, x): return 'obj'
    def encodes(self, x:int): return 'int'
```

To get this working it required the use of metaclasses.

With Plum we could do the following:

```
class D(Transform):
    @dispatch
    def encodes(self, x): return 'obj'
    @dispatch
    def encodes(self, x:int): return 'int'
```

But for this to work we need to store all multiple-dispatch information in the global `dispatch` object, which, as we'll see further down in the notebook, has a nasty downstream effect.

...But in return we get a simpler Transform class without metaclasses. I also don't mind that we're more explicit about dispatch happening here.

In [None]:
class D(Transform):
    @dispatch
    def encodes(self, x): return 'obj'
    @dispatch
    def encodes(self, x:int): return 'int'

d = D()
test_eq(d.encodes(0), 'int')
test_eq(d.encodes(0.0), 'obj')

### Ambiguous vs NotFound lookups

In [None]:
class E(Transform): pass

@E
def encodes(self, x:int|str): return f"E INT|STR {x=}!"

@E
def encodes(self, x:float|str): return f"E FLOAT|STR {x=}!"

e = E()

test_eq(e(5), "E INT|STR x=5!")
test_eq(e(.5), "E FLOAT|STR x=0.5!")
test_eq(e([1]), [1])  # NotFoundLookupError is caught: the input is returned unchanged

try:
    e("hi there")  # could be either encodes function
except AmbiguousLookupError:
    print("Caught an expected AmbiguousLookupError")

Caught an expected AmbiguousLookupError


**⚠️ Question for Jeremy:** 
I think this makes sense. The original intended behavior of Transform, returning the input unchanged when no matching `encodes` is defined, is kept. But if the `encodes` methods are internally inconsistent, we raise an error.
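
To make the two outcomes concrete without depending on Plum, here is a deliberately simplified stdlib sketch. The names `AmbiguousMatch`, `resolve` and `call_or_passthrough` are hypothetical, and unlike a real dispatcher it treats any two matching signatures as ambiguous instead of picking the most specific one.

```python
class AmbiguousMatch(TypeError):
    """Raised when more than one registered implementation accepts the argument."""

def resolve(registry, x):
    # registry is a list of (accepted_types, function) pairs
    matches = [fn for types, fn in registry if isinstance(x, types)]
    if not matches:
        return None                      # nothing registered for this type
    if len(matches) > 1:
        raise AmbiguousMatch(f"{len(matches)} implementations match {type(x).__name__}")
    return matches[0]

def call_or_passthrough(registry, x):
    fn = resolve(registry, x)
    return x if fn is None else fn(x)    # no match: hand the input back unchanged

registry = [
    ((int, str), lambda x: "int|str"),
    ((float, str), lambda x: "float|str"),
]

results = [call_or_passthrough(registry, v) for v in (5, 0.5, [1])]

try:
    call_or_passthrough(registry, "hi there")
    ambiguous_raised = False
except AmbiguousMatch:
    ambiguous_raised = True
```

A missing implementation is treated as "this transform does not apply here", while overlapping implementations are treated as a mistake in the transform's definition.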

### Further Tests

In [None]:
# inherited transforms
class F(Transform): pass
@F
def encodes(self, x:int): return "INT"

class G(F): pass
@G
def encodes(self, x:str): return "STR"

g = G()

test_eq(len(g.encodes.methods), 2)

In [None]:
# return args[0] if no encodes has been defined
class _Tst(Transform): pass 
f3 = _Tst() 
test_eq(f3(2), 2)

In [None]:
# Type inheritance
# class MyClass(int): pass

class H(Transform):
    def encodes(self, x:MyClass|float): return x/2

@H
def encodes(self, x:str|list): return str(x)+'_1'

h = H()

test_eq(h(MyClass(5)), 2.5)

In [None]:
assert len(h.encodes.methods) == 2

### Breaking case

In [None]:
class Q(Transform):
    def encodes(self, x:float): return x/2
@Q
def encodes(self, x:str): return str(x)+'_1'

q = Q()

assert len(q.encodes.methods) == 2

class Q(Transform):
    def encodes(self, x:int): return x/2
@Q
def encodes(self, x:str): return str(x)+'_1'

q = Q()

assert len(q.encodes.methods) > 2   # 🚨 This passes, but it shouldn't: methods from the old Q are still registered.

**⚠️ Note for Jeremy:**  We keep adding methods to Q even when the class is redefined. 
This happens because we use the global dispatch registry to enable this pattern:

```
class D(Transform):
    @dispatch
    def encodes(self, x): return 'obj'
    @dispatch
    def encodes(self, x:int): return 'int'
```

I think the best way to go ahead is to use metaclasses and to bring back the scope of encodes to a dispatch object inside the class (D in this case). 

This unexpected behavior would not be pleasant for the user I think.
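
The accumulation described above can be reproduced without Plum. The sketch below uses a hypothetical `register` decorator and a module-level `_registry` dict keyed by qualified name, which is an assumption about the general mechanism and not Plum's actual internals.

```python
_registry = {}  # qualified name -> list of registered implementations

def register(f):
    """Hypothetical stand-in for a global dispatcher keyed by `__qualname__`."""
    _registry.setdefault(f.__qualname__, []).append(f)
    return f

class Q:
    @register
    def encodes(self, x: float): return x / 2

count_after_first = len(_registry[Q.encodes.__qualname__])

class Q:  # redefinition, e.g. re-running a notebook cell
    @register
    def encodes(self, x: int): return x / 2

# Both class objects share one qualified name, so the old entry is still there.
count_after_second = len(_registry[Q.encodes.__qualname__])
```

Because the registry outlives the class object, redefining `Q` adds to the existing entry instead of starting fresh.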


## Metaclass approach

Attempt at implementing Transform that both:

1. Lets you define multiple encodes during class definition
2. Does not overload the global dispatch when classes are redefined
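
Both goals rest on one Python mechanism: a metaclass can supply a custom namespace through `__prepare__`, and that namespace sees every assignment in the class body, including repeated `def encodes`. A minimal stdlib sketch follows, with hypothetical names `_CollectingDict`, `_CollectingMeta` and `_impls`; it collects plain functions in a list where the notebook stores a Plum `Function`.

```python
class _CollectingDict(dict):
    """Class-body namespace that remembers every definition of `encodes`."""
    def __setitem__(self, key, value):
        if key == "encodes" and callable(value):
            impls = self.get("_impls")
            if impls is None:
                impls = []
                super().__setitem__("_impls", impls)
            impls.append(value)          # keep earlier definitions too
        super().__setitem__(key, value)

class _CollectingMeta(type):
    @classmethod
    def __prepare__(mcls, name, bases, **kwargs):
        return _CollectingDict()         # a fresh namespace for every class body
    def __new__(mcls, name, bases, namespace, **kwargs):
        return super().__new__(mcls, name, bases, dict(namespace), **kwargs)

class Base(metaclass=_CollectingMeta): pass

class Q(Base):
    def encodes(self, x: float): return x / 2
    def encodes(self, x: str): return x + "_1"

first_count = len(Q._impls)

class Q(Base):  # redefinition starts from an empty namespace
    def encodes(self, x: int): return x / 2

second_count = len(Q._impls)
```

Since the collected methods live on the class rather than in a module-level registry, redefining the class discards the old ones.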

In [None]:
# clear global dispatch
from plum.dispatcher import Dispatcher
dispatch = Dispatcher()

In [None]:
class _TfmMeta(type):
    def __new__(cls, name, bases, dict):
        res = super().__new__(cls, name, bases, dict)
        for nm in _tfm_methods:
            base_td = [getattr(b,nm,None) for b in bases]
            if nm in res.__dict__: getattr(res,nm).bases = base_td
        return res

    @classmethod
    def __prepare__(cls, name, bases): 
        return _TfmDict()
     

In [None]:
_tfm_methods = 'encodes','decodes','setups'

def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)

class _TfmDict(dict):
    def __setitem__(self, k, v):
        if not _is_tfm_method(k, v): return super().__setitem__(k,v)
        # Create the Function once; the line below registers each definition exactly once
        if k not in self: super().__setitem__(k,Function(v))
        self[k].dispatch(v)
     

In [None]:
class Transform(metaclass=_TfmMeta):
    _instances = {}
    def __init_subclass__(cls):
        if hasattr(cls,'encodes') and not isinstance(cls.encodes, Function):
            cls.encodes = dispatch(cls.encodes)
    
    @staticmethod
    def _new_transform(f):
        fname = f.__name__
        inst = Transform._instances.get(fname)
        if not inst:
            inst = Transform.__new__(Transform)
            inst.encodes = Function(f)
            # don't register lambdas; they're intended to be single use
            if fname !="<lambda>":
                Transform._instances[fname] = inst
        inst.encodes = inst.encodes.dispatch(f)
        return inst
    
    @staticmethod
    def _new_transform_subclass(cls, f):
        # Update qualname to make it appear as a class method
        # Plum needs this to register it correctly
        f.__qualname__ = f"{cls.__name__}.{f.__name__}"
        if not hasattr(cls, "encodes"):
            cls.encodes = f
            cls.encodes = Function(cls.encodes).dispatch(cls.encodes)
        else:
            cls.encodes = cls.encodes.dispatch(f)
        return cls

    def __new__(cls, f=None, *args, **kwargs):
        if callable(f) and not isinstance(f, type) and len(args) == len(kwargs) == 0:
            if cls is Transform: return cls._new_transform(f)
            else: return cls._new_transform_subclass(cls, f)
        return super().__new__(cls)        
    
    def __call__(self, *args, **kwargs): 
        if not hasattr(self, "encodes"): return args[0]
        try:
            return self.encodes(*args, **kwargs) 
        except NotFoundLookupError: 
            return args[0]
    
    

### Try with previously failing test

In [None]:
class Q(Transform):
    def encodes(self, x:float): return x/2
@Q
def encodes(self, x:str): return str(x)+'_1'

q = Q()

assert len(q.encodes.methods) == 2

class Q(Transform):
    def encodes(self, x:int): return x/2
@Q
def encodes(self, x:str): return str(x)+'_1'

q = Q()

assert len(q.encodes.methods) == 2   # 🎉 Now it has the expected number of methods

### Old tests that did pass

In [None]:
# overwrite lambdas
f = Transform(lambda o:o//2)
f = Transform(lambda o:o//5)

test_eq(len(f.encodes.methods), 1)

In [None]:
@Transform
def g(x): return "obj"

@Transform
def g(x:int): return "int"

test_eq(len(g.encodes.methods), 2)

In [None]:
class A(Transform):
    def encodes(self, x:int): return "BOO!"
@A
def encodes(self, x): return x+1

f1 = A()
assert len(f1.encodes.methods) == 2

In [None]:
class A(Transform): pass

@A
def encodes(self, x): return x+1

f1 = A()
test_eq(len(f1.encodes.methods), 1)

In [None]:
class _Tst(Transform): pass 

f3 = _Tst() # no encodes method has been defined
test_eq_type(f3(2), 2)

In [None]:
class A:
    @classmethod
    def create(cls, x:int): return x+1
test_eq(Transform(A.create)(1), 2)

In [None]:
class A(Transform):
    def encodes(self, x): return 'obj'

@A
def encodes(self, x:int): return 'int'

a = A()
test_eq(a.encodes(0), 'int')
test_eq(a.encodes(0.0), 'obj')

In [None]:
@Transform
def f(x:int): return x//2
test_eq_type(f(2), 1)

@Transform
def f(x:float): return x*2
test_eq_type(f(2), 1)
test_eq_type(f(2.), 4.)

In [None]:
class A(Transform): pass

@A
def encodes(self, x:int): return "INT!"

@A
def encodes(self, x:float): return "FLOAT!"

a= A()

test_eq(a(5), "INT!")
test_eq(a(5.), "FLOAT!")

In [None]:
class A(Transform):
    def encodes(self, x:MyClass|float): return x/2
    def encodes(self, x:str|list): return str(x)+'_1'

f = A()

test_eq(len(f.encodes.methods), 2)
test_eq(f(MyClass(2)), 1.) # input is of type MyClass 
test_eq(f(6.0), 3.0) # input is of type float
test_eq(f('a'), 'a_1') # input is of type str
test_eq(f(['a','b','c']), "['a', 'b', 'c']_1") # input is of type list

## Return type casting

From fastcore.transform:

> Without any intervention it is easy for operations to change types in Python. For example, `FloatSubclass` (defined below) becomes a `float` after performing multiplication:

```python
class FloatSubclass(float): pass
test_eq_type(FloatSubclass(3.0) * 2, 6.0)
```

> This behavior is often not desirable when performing transformations on data.  Therefore, `Transform` will attempt to cast the output to be of the same type as the input by default.  In the below example, the output will be cast to a `FloatSubclass` type to match the type of the input:

```python
@Transform
def f(x): return x*2

test_eq_type(f(FloatSubclass(3.0)), FloatSubclass(6.0))
```

> We can optionally turn off casting by annotating the transform function with a return type of `None`:  

```python
@Transform
def f(x)-> None: return x*2 # Same transform as above, but with a -> None annotation

test_eq_type(f(FloatSubclass(3.0)), 6.0)  # Casting is turned off because of -> None annotation
```

**⚠️ Question to Jeremy:**

Plum does things differently. It tries to cast the output to the return type annotation of the function.
I haven't thought yet about how to address this but I'm curious what your input is on this matter.

My current thinking:

1. Check if we can turn off return type casting from Plum
2. Check if we can leverage Plum's return type casting to actually cast to the first arg type

In [None]:
@dispatch
def foo(x)->None: return x*2

In [None]:
try:
    foo(5)
except TypeError:
    pass
# Without the try/except this raises: TypeError: Cannot convert `10` to `NoneType`.
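
For reference while weighing the two options, here is a minimal stdlib-only sketch of the fastcore-style behavior quoted above: cast the result back to the type of the first argument unless the function is annotated `-> None`. The decorator `retain_input_type` is hypothetical and does not use Plum or fastcore; it assumes annotations are real objects, not strings as under `from __future__ import annotations`.

```python
import functools

_MISSING = object()  # distinguishes "no return annotation" from "-> None"

def retain_input_type(f):
    """Hypothetical decorator: cast the result to type(first arg) unless `-> None`."""
    ret = f.__annotations__.get("return", _MISSING)

    @functools.wraps(f)
    def wrapper(x, *args, **kwargs):
        res = f(x, *args, **kwargs)
        if ret is None:                  # explicit opt-out
            return res
        if isinstance(res, type(x)):     # already the right type
            return res
        if not isinstance(x, type(res)): # only cast down to a subclass of the result type
            return res
        return type(x)(res)
    return wrapper

class FloatSubclass(float): pass

@retain_input_type
def double(x): return x * 2

@retain_input_type
def double_no_cast(x) -> None: return x * 2

kept = double(FloatSubclass(3.0))              # cast back to FloatSubclass
dropped = double_no_cast(FloatSubclass(3.0))   # stays a plain float
```

In this sketch the return annotation is read by the wrapper itself, so it would only coexist with Plum if Plum's own conversion of the return value could be bypassed, which is the open question in point 1.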