In [1]:
#|default_exp transform

In [2]:
#|export
from typing import Any

from fastcore.imports import *
from fastcore.foundation import *
from fastcore.utils import *

from plum.function import Function
from plum import NotFoundLookupError

from fasttransform.utils import get_name, is_tuple, retain_type

SyntaxError: multiple exception types must be parenthesized (transform.py, line 104)

In [None]:
from __future__ import annotations
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *

from plum import AmbiguousLookupError

# Transforms

> Definition of `Transform` and `Pipeline`

The classes here provide functionality for creating a composition of *partially reversible functions*. By "partially reversible" we mean that a transform can be `decode`d, creating a form suitable for display. This is not necessarily identical to the original form (e.g. a transform that changes a byte tensor to a float tensor does not recreate a byte tensor when decoded, since that may lose precision, and a float tensor can be displayed already).

Classes are also provided for composing transforms and mapping them over collections. `Pipeline` is a transform which composes several `Transform`s, knowing how to decode them or show an encoded item.

The goal of this module is to replace `fastcore.Transform` by using the package Plum for multiple dispatch rather than the `fastcore.dispatch` module. Plum is a well maintained library, that provides better dispatch functionality.

## Transform -

### The main `Transform` features:

- **Type dispatch** - Type annotations are used to determine if a transform should be applied to the given argument. It also gives an option to provide several implementations and it chooses the one to run based on the type. This is useful for example when running both independent and dependent variables through the pipeline where some transforms only make sense for one and not the other. Another use case is designing a transform that handles different data formats. Note that if a transform takes multiple arguments only the type of the first one is used for dispatch. 
- **Handling of tuples** - When a tuple (or a subclass of tuple) of data is passed to a transform it will get applied to each element separately. You can opt out of this behavior by passing a list or an `L`, as only tuples get this specific behavior. An alternative is to use `ItemTransform` defined below, which will always take the input as a whole.
- **Reversibility** - A transform can be made reversible by implementing the <code>decodes</code> method. This is mainly used to turn something like a category which is encoded as a number back into a label understandable by humans for showing purposes. Like the regular call method, the `decode` method that is used to decode will be applied over each element of a tuple separately.
- **Type propagation** - Whenever possible a transform tries to return data of the same type it received. Mainly used to maintain semantics of things like `ArrayImage` which is a thin wrapper of pytorch's `Tensor`. You can opt out of this behavior by adding `->None` return type annotation.
- **Preprocessing** - The `setup` method can be used to perform any one-time calculations to be later used by the transform, for example generating a vocabulary to encode categorical data.
- **Filtering based on the dataset type** - By setting the `split_idx` flag you can make the transform be used only in a specific `DataSource` subset like in training, but not validation.
- **Ordering** - You can set the `order` attribute which the `Pipeline` uses when it needs to merge two lists of transforms.
- **Appending new behavior with decorators** - You can easily extend an existing `Transform` by creating <code>encodes</code> or <code>decodes</code> methods for new data types. You can put those new methods outside the original transform definition and decorate them with the class you wish them patched into. This can be used by the fastai library users to add their own behavior, or multiple modules contributing to the same transform.

### Utils -

In [None]:
#|export
def merge_funcs(*fs):
    "Merge multiple plum Functions into one; on a signature clash the earliest argument wins"
    # Later registrations replace earlier ones with the same signature, so register from
    # last to first: the first Function passed ends up with the final say.
    ordered = list(reversed(fs))
    merged = Function(ordered[-1].methods[0].implementation)
    for fn in ordered:
        for meth in fn.methods: merged.dispatch(meth.implementation)
    return merged

In [None]:
# Two plum Functions with a clashing `int` method (f1 in `f`, f4 in `g`)
def f1(x:int): return 'int1'
def f2(x:float): return 'float2'
def f3(x:str): return 'str3' 
def f4(x:int): return 'int4'

f = Function(f1).dispatch(f1).dispatch(f2)
g = Function(f3).dispatch(f3).dispatch(f4)

h = merge_funcs(f,g)
assert h(1) == 'int1'    # on a signature clash, the earlier argument (`f`) wins
assert h('a') == 'str3'  # methods only present in `g` are kept
assert h(1.) == 'float2'

In [None]:
#|export
_tfm_methods = 'encodes','decodes','setups'
def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)

In [None]:
#|export
def _has_self_arg(f) -> bool:
    try: return f.__code__.co_varnames[0] == 'self'
    except (AttributeError, IndexError): return False

In [None]:
#|export
class _TfmDict(dict):
    "Class-body namespace in which repeated `encodes`/`decodes`/`setups` definitions accumulate in one plum `Function`"
    def __setitem__(self, k, v):
        if _is_tfm_method(k, v):
            # First definition creates the Function; every definition is then registered as a method
            if k not in self: super().__setitem__(k, Function(v))
            self[k].dispatch(v)
            return
        return super().__setitem__(k, v)

In [None]:
#|export
class _TfmMeta(type):
    "Metaclass that collects transform methods into plum `Function`s and lets a `Transform` subclass act as a method decorator"
    @classmethod
    def __prepare__(cls, name, bases):
        # The class body runs in a `_TfmDict`, so repeated `encodes`/`decodes`/`setups` defs are merged
        return _TfmDict()

    def __call__(cls, *args, **kwargs):
        decorating = (issubclass(cls, Transform) and len(args) == 1
                      and _has_self_arg(args[0]) and len(kwargs) == 0)
        if not decorating: return super().__call__(*args, **kwargs)
        # `@SomeTransform` applied to `def encodes(self, ...)`: register the function on the class itself
        impl = args[0]
        meth_name = impl.__name__
        if meth_name not in _tfm_methods: raise RuntimeError(f"{meth_name} not in {_tfm_methods}")
        if hasattr(cls, meth_name): getattr(cls, meth_name).dispatch(impl)
        else: setattr(cls, meth_name, Function(impl).dispatch(impl))
        return cls

    def __new__(cls, name, bases, namespace):
        new_cls = super().__new__(cls, name, bases, namespace)
        for meth_name in _tfm_methods:
            if not hasattr(new_cls, meth_name): continue
            # The class's own Function comes first so it takes precedence over the bases' versions
            candidates = [getattr(new_cls, meth_name), *(getattr(b, meth_name, None) for b in bases)]
            present = [c for c in candidates if c]
            if present: setattr(new_cls, meth_name, merge_funcs(*present))
        return new_cls

In [None]:
#|export
def _ensure_named(f):
    "Return `f` unchanged if it has a `__name__`, else a thin wrapper that has one (plum requires `__name__`)"
    if hasattr(f, '__name__'): return f
    def _named(*args, **kwargs): return f(*args, **kwargs)
    _named.__name__ = get_name(f)
    return _named

class Transform(metaclass=_TfmMeta):
    "Delegates (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) if `split_idx` matches"
    split_idx,init_enc,order,train_setup = None,None,0,None

    def __init__(self,enc=None,dec=None, split_idx=None, order=None):
        # Explicit arguments override the class-level defaults
        self.split_idx = ifnone(split_idx, self.split_idx)
        self.order = ifnone(order, getattr(self, 'order', 0))
        if not is_listy(enc):
            # A single encoder's own `order` attribute takes precedence over the `order` argument
            self.order = getattr(enc,'order',self.order)
            # Input types come from the first *parameter* annotation; the `return` hint is not an input type
            arg_hints = [t for k,t in type_hints(enc).items() if k != 'return']
            if arg_hints: self.input_types = union2tuple(arg_hints[0])
        if enc:=L(enc):
            self._name = get_name(enc[0])
            # Every encoder (not just the first) must carry a `__name__` for plum, e.g. `attrgetter`
            enc = enc.map(_ensure_named)
            self.encodes = Function(enc[0])
        for e in enc: self.encodes.dispatch(e)
        if dec:=L(dec):
            dec = dec.map(_ensure_named)
            self.decodes = Function(dec[0])
        for d in dec: self.decodes.dispatch(d)

    @property
    def name(self): return getattr(self, '_name', get_name(self))
    def __repr__(self):
        enc = len(self.encodes.methods) if hasattr(self, 'encodes') else 0
        dec = len(self.decodes.methods) if hasattr(self, 'decodes') else 0
        return f'{self.name}(enc:{enc},dec:{dec})'
    def __call__(self,*args,split_idx=None, **kwargs): return self._call('encodes', *args, split_idx=split_idx, **kwargs)
    def decode(self, *args,split_idx=None, **kwargs): return self._call('decodes', *args, split_idx=split_idx, **kwargs)
    def setup(self, items=None, train_setup=False):
        train_setup = train_setup if self.train_setup is None else self.train_setup
        items = getattr(items, 'train', items) if train_setup else items
        # Check for a missing `setups` up front instead of catching AttributeError, so that
        # AttributeErrors raised *inside* a user's `setups` are not silently swallowed.
        if not hasattr(self, 'setups'): return None
        try: return self.setups(items)
        except NotFoundLookupError: return None  # no `setups` method matches the type of `items`

    def _call(self, nm, *args, split_idx=None, **kwargs):
        # Skip (return the input unchanged) when restricted to another split or when `nm` is undefined
        if split_idx!=self.split_idx and self.split_idx is not None: return args[0]
        if not hasattr(self, nm): return args[0]
        return self._do_call(nm, *args, **kwargs)

    def _do_call(self, nm, *args, **kwargs):
        # Tuples are processed element-wise (recursively); the tuple's own type is retained
        if is_tuple(x:=args[0]):
            res = tuple(self._do_call(nm, x_, *args[1:], **kwargs) for x_ in x)
            return retain_type(res, x, Any)
        f = getattr(self,nm)
        # Class-level methods come back bound: resolve on the underlying plum Function with `self` prepended
        if isinstance(f,MethodType): f, f_args = f._f, (self,)+args
        else: f_args = args
        try: method, ret_type = f._resolve_method_with_cache(f_args)
        except NotFoundLookupError: return x  # no matching method: pass the input through unchanged
        return retain_type(method(*f_args,**kwargs), x, ret_type)

add_docs(Transform, decode="Delegate to decodes to undo transform", setup="Delegate to setups to set up transform")

### Accepting operators -

In [None]:
# `attrgetter` objects have no `__name__` (which plum needs); Transform wraps them for us
f = Transform(attrgetter('a')) # does not raise error

### Input types -

In [None]:
#|hide
# `input_types` is read from the parameter annotation; a union becomes a tuple of types
def enc(x:int|float): return x*2
f = Transform(enc)
test_eq(f.input_types, (int, float))

### Order - 

In [None]:
#|hide
# Without an explicit `order`, the class-level default of 0 applies
class A(Transform):
    def encodes(self, x): return x**(.5)
test_eq(A.order, 0)

In [None]:
#|hide
# A class attribute `order` overrides the default
class A(Transform):
    order = -1
    def encodes(self, x): return x**(.5)
test_eq(A.order, -1)

In [None]:
#|hide
# Instances pick up the class-level `order`
a = A()
test_eq(a.order, -1)

In [None]:
#|hide
# An explicit `order` argument overrides the class attribute
a = A(order=-2)
test_eq(a.order, -2)

In [None]:
#|hide
# An `order` attribute on the wrapped function is picked up by the Transform
def enc(x): return x+1
enc.order = -2
a = Transform(enc)
test_eq(a.order, -2)

In [None]:
#|hide
# The function's own `order` takes precedence over the `order` argument
def enc(x): return x+1
enc.order = -2
a = Transform(enc,order=-1)
test_eq(a.order, -2)

### Transform: Encode and Decode Data

A Transform encodes (transforms) data while optionally providing a decode operation to convert back. Encoding is useful in machine learning preprocessing pipelines, for example for category encoding. 

Transforms optionally provide a way to decode the transform, this may be useful where human-readable display is needed.

In [None]:
# A Transform built from an encoder/decoder pair
def enc(x): return x*2
def dec(x): return x/2
f = Transform(enc=enc, dec=dec)
f

In [None]:
# Calling the transform encodes (here: doubles) the input
f(5.0)

In [None]:
# `decode` applies the decoder (here: halves the input)
f.decode(10.0)

In [None]:
# Round trip: decoding undoes encoding
f.decode(f(5.0))

### Defining a `Transform`

There are a few ways to create a transform with different ratios of simplicity to flexibility.
- **Passing methods to the constructor** - Instantiate the `Transform` class and pass your functions as `enc` and `dec` arguments.
- **@Transform decorator** - Turn any function into a `Transform` by just adding a decorator - very straightforward if all you need is a single <code>encodes</code> implementation.
- **Extending the `Transform` class** - Use inheritance to implement the methods you want.
- **Passing a function to fastai APIs** - Same as above, but when passing a function to other transform aware classes like `Pipeline` or `TfmdDS` you don't even need a decorator. Your function will get converted to a `Transform` automatically.

#### Passing methods to the constructor

A simple way to create a `Transform` is to pass a function to the constructor. In the below example, we pass an anonymous function that multiplies its input by 2:

In [None]:
# Any callable, including a lambda, can be wrapped
f = Transform(lambda o: o*2)

If you call this transform, it will apply the transformation:

In [None]:
# The result keeps the input's type (int)
test_eq_type(f(2), 4)

#### @Transform decorator

You can define a Transform also by using the `@Transform` decorator directly.

In [None]:
# `@Transform` on a plain function (first parameter not named `self`) creates a Transform instance
@Transform
def f(x:str): return f"hello {x}!"
test_eq(f("Alex"), "hello Alex!")

#### Define with classmethod

In [None]:
# A bound classmethod can be used as the encoder
class B:
    @classmethod
    def create(cls, x:int): return x+1
test_eq(Transform(B.create)(1), 2)

### Important attributes of Transform

#### Type dispatch

Type dispatch in Transforms uses type annotations to automatically select the appropriate implementation for different input types.
This lets a single Transform handle multiple data formats without explicit conditional logic.

In [None]:
# Several encoders: plum picks the one whose annotation matches the input type
def enc1(x: int): return x*2
def enc2(x: str): return f"hello {x}!"

f = Transform(enc=[enc1, enc2])

test_eq_type(f(2), 4)
test_eq(f("Alex"), "hello Alex!")

#### Return self if no type hint was found

`fastcore.transform.Transform` has a rule that a transform will return its first argument if there's no method that fits the input types with which the function is called. 
By default, Plum would raise a `NotFoundLookupError`; we catch this error and return the first argument to stay consistent with the old implementation, as it's a useful default for Transforms in data pipelines.

In [None]:
# return arg[0] if no encodes has been defined
f3 = Transform()
test_eq(f3(2), 2)

In [None]:
# return arg[0] if no matching type has been found
def enc(x:str): return "str!"
f = Transform(enc)    
test_eq(f3(2), 2)

#### Ambiguous vs NoFound lookups

A difference with `fastcore.transform.Transform` is that this version is stricter about ambiguous lookups.

That's because Plum has a better underlying system for allocating the inputs to the right function.

In [None]:
# Both encoders accept `str`, so a string input is ambiguous for plum
def enc1(x: int|str): return f"E INT|STR {x=}!"
def enc2(x: float|str): return f"E FLOAT|STR {x=}!"

e = Transform(enc=[enc1, enc2])

test_eq(e(5), "E INT|STR x=5!")
test_eq(e(.5), "E FLOAT|STR x=0.5!")
test_eq(e([1]), [1])  # NotFoundLookupError is caught: the input is returned unchanged

try: e("hi there")  # could be either encodes function
except AmbiguousLookupError: print("Caught an expected AmbiguousLookupError")

#### Type inheritance for input types is supported

You can bring your own types:

In [None]:
# A float subclass with a distinctive repr, used below to observe type preservation
class FloatSubclass(float):
    def __repr__(self): return f'FloatSubclass({super().__repr__()})'
    def __str__(self): return f'{super().__str__()}'

In [None]:
# Annotating with the custom type directly
def enc1(x: int|FloatSubclass): return x/2
h = Transform(enc1)
test_eq(h(FloatSubclass(5.0)), 2.5)

And type inheritance is supported

In [None]:
# A `float` annotation also matches the `FloatSubclass` subclass
def enc1(x: int|float): return x/2
h = Transform(enc=enc1)
test_eq(h(FloatSubclass(5.0)), 2.5)

## Return type casting

Without any intervention it is easy for operations to change types in Python. For example, `FloatSubclass` (defined above) becomes a `float` after performing multiplication:

In [None]:
# Plain arithmetic on a float subclass yields a bare float
test_eq_type(FloatSubclass(3.0) * 2, 6.0)

This behavior is often not desirable when performing transformations on data.  Therefore, `Transform` will attempt to cast the output to be of the same type as the input by default.  In the below example, the output will be cast to a `FloatSubclass` type to match the type of the input:

### Without type annotations

In [None]:
# With no annotations, the result is cast back to the input's type
@Transform
def f(x): return x*2

test_eq_type(f(FloatSubclass(3.0)), FloatSubclass(6.0))

We can optionally turn off casting by annotating the transform function with a return type of None:

### Return type None

In [None]:
# `-> None` disables casting the result back to the input type
@Transform
def f(x)-> None: return x*2 # Same transform as above, but with a -> None annotation

test_eq_type(f(FloatSubclass(3.0)), 6.0)  # Casting is turned off because of -> None annotation
 

However, Transform will only cast output back to the input type when the input is a subclass of the output. In the below example, the input is of type FloatSubclass which is not a subclass of the output which is of type str. Therefore, the output doesn't get cast back to FloatSubclass and stays as type str:

In [3]:
@Transform
def f(x): return str(x)

# `FloatSubclass` is not a subclass of `str`, so the result stays a plain str
test_eq_type(f(FloatSubclass(2.)), '2.0')

NameError: name 'Transform' is not defined

Transform will attempt to convert the function output to the return type annotation.

### Specific return types

If a return type annotation is given, Transform will convert it to that type:

In [None]:
# The return annotation drives the conversion of the result
@Transform
def f(x)->FloatSubclass: return float(x)

# Output is converted to FloatSubclass because its a subtype of float
test_eq(f(1.), FloatSubclass(1.))

If the function returns a subclass of the annotated return type, that more specific type will be preserved since it's already compatible with the annotation:

In [None]:
# A result that is already a subclass of the annotation is left as is
@Transform
def f(x)->float: return FloatSubclass(x)

# FloatSubclass output is kept because more specific than float
test_eq(f(1.), FloatSubclass(1.))

When return types are given, the conversion will even happen if the output type is not a subclass of the return type annotation:

In [None]:
# A `str` annotation converts even an unrelated result type
@Transform
def f(x)->str: return FloatSubclass(x)

test_eq(f(1.), "FloatSubclass(1.0)")

And here we get an expected error because it's not possible to match the explicit return type:

In [None]:
# Converting "foo" to the annotated `int` is impossible, so the call raises
@Transform
def f(x)->int: return str(x)

try: f("foo")
except Exception as e: print(f"Caught Exception: {e=}")

### Type annotation with Decode

Just like encodes, the decodes method will cast outputs to match the input type in the same way. In the below example, the output of decodes remains of type FloatSubclass:

In [None]:
# decode's output keeps the type of its (encoded) input
def enc(x): return FloatSubclass(x+1)
def dec(x): return x-1

f = Transform(enc,dec)
t = f(1.0)  # t will be FloatSubclass
test_eq_type(f.decode(t), FloatSubclass(1.0))

### Transforms on Lists

Transform operates on lists as a whole, **not element-wise:**



In [None]:
# A list is passed to the encoder as a single object (no element-wise mapping)
def enc(x): return dict(x)
def dec(x): return list(x.items())
    
f = Transform(enc,dec)
_inp = [(1,2), (3,4)]
t = f(_inp)

test_eq(t, dict(_inp))
test_eq(f.decodes(t), _inp)

In [None]:
#|hide
# With `split_idx` set, the transform only runs when a matching `split_idx` is passed
f.split_idx = 1
test_eq(f(_inp, split_idx=1), dict(_inp))
test_eq(f(_inp, split_idx=0), _inp)

If you want a transform to operate on a list elementwise, you must implement this appropriately in the encodes and decodes methods:

In [None]:
# Element-wise list handling must be written inside the encoder/decoder
def enc(x): return [x_+1 for x_ in x]
def dec(x): return [x_-1 for x_ in x]

f = Transform(enc,dec)
t = f([1,2])

test_eq(t, [2,3])
test_eq(f.decode(t), [1,2])
     

### Transforms on Tuples

Unlike lists, Transform operates on tuples element-wise.

In [None]:
# Tuples are mapped element-wise automatically
def neg_int(x): return -x
f = Transform(neg_int)

test_eq(f((1,2,3)), (-1,-2,-3))

Transforms will also apply type dispatch element-wise on tuples when an input type annotation is specified. In the below example, the values 1.0 and 3.0 are passed through unchanged because they are of type float, not int:

In [None]:
# Only int elements match `neg_int`; floats are returned unchanged
def neg_int(x:int): return -x
f = Transform(neg_int)

test_eq(f((1.0, 2, 3.0)), (1.0, -2, 3.0))

In [None]:
#|hide
test_eq(f((1,)), (-1,))
test_eq(f((1.,)), (1.,))
test_eq(f.decode((1,2)), (1,2))  # no decoder defined: input returned unchanged
# `input_types` is set from the annotation of `neg_int`'s first parameter
test_eq(f.input_types, int)

Another example of how Transform can use type dispatch with tuples is shown below:

In [None]:
# `enc3` is unannotated, so it catches any type that `enc1`/`enc2` don't handle
def enc1(x: int): return x+1
def enc2(x: str): return x+'hello'
def enc3(x): return str(x)+'!'
f = Transform(enc=[enc1, enc2, enc3])

If the input is not an int or str, the third encodes method will apply:



In [None]:
# Lists are neither int nor str, so the fallback `enc3` applies to the whole list
test_eq(f([1]), '[1]!')
test_eq(f([1.0]), '[1.0]!')

However, if the input is a tuple, then the appropriate method will apply according to the type of each element in the tuple:

In [None]:
# For tuples, each element is dispatched on its own type
test_eq(f(('1',)), ('1hello',))
test_eq(f((1,2)), (2,3))
test_eq(f(('a',1.0)), ('ahello','1.0!'))

In [None]:
#|hide
# A decode-only Transform: type dispatch applies to `decode` as well; the transform is picklable
def dec(x: int): return x-1
f = Transform(dec=dec)
test_eq(f.decode((2,)), (1,))
test_eq(f.decode(('2',)), ('2',))
assert pickle.loads(pickle.dumps(f))

Dispatching over tuples works recursively, by the way:

In [None]:
# Nested tuples are traversed recursively, in both directions
def enc1(x:int): return x+1
def enc2(x:str): return x+'_hello'
def dec1(x:int): return x-1
def dec2(x:str): return x.replace('_hello', '')

f = Transform(enc=[enc1, enc2], dec=[dec1, dec2])
start = (1.,(2,'3'))
t = f(start)
test_eq_type(t, (1.,(3,'3_hello')))
test_eq(f.decode(t), start)

Dispatching also works with abstract base classes such as `numbers.Integral`:



In [None]:
# `int` values match the `numbers.Integral` ABC; the str element is left unchanged
@Transform
def f(x:numbers.Integral): return x+1

t = f((1,'1',1))
test_eq(t, (2, '1', 2))

### Apply Transforms on subsets with `split_idx`

In [None]:
# Restrict this transform to `split_idx == 1`
def enc(x): return x+1
def dec(x): return x-1
f = Transform(enc,dec)
f.split_idx = 1

The transformations are applied when a matching split_idx parameter is passed:

In [None]:
# Matching split_idx: encode/decode are applied
test_eq(f(1, split_idx=1),2)
test_eq(f.decode(2, split_idx=1),1)

     
On the other hand, transformations are ignored when the split_idx parameter does not match:

In [None]:
# Non-matching split_idx: the input is returned unchanged
test_eq(f(1, split_idx=0), 1)
test_eq(f.decode(2, split_idx=0), 2)

## Extending Transform

### Limitation of calling Transform directly

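However, a Transform created this way is not extensible: it holds only the function it was created from, and decorating another function with `@Transform` creates a new, separate Transform instead of adding to this one: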

In [None]:
# A decorated plain function yields a Transform with exactly one encodes method
@Transform
def g(x:int): return x*3

test_eq(g(2), 6)
test_eq(g('a'), 'a')  # <- resorts to returning self
test_eq(len(g.encodes.methods), 1)

For extensible Transforms, take a look at the "Subclassing Transform" section below.


### Subclassing Transform

When you subclass Transform you can define multiple encodes as methods directly.

In [None]:
# Repeated `encodes` definitions in a class body accumulate as dispatch methods instead of overwriting
class A(Transform):
    def encodes(self, x:int): return x*2
    def encodes(self, x:str): return f'hello {x}!'
test_eq(len(A.encodes.methods), 2)

In [None]:
# Each input type is routed to its own method
a = A()
test_eq(a(2), 4)
test_eq(a('Alex'), "hello Alex!")

Continued inheritance is supported

In [None]:
# B overrides A's int method, adds a float method and inherits A's str method: 3 in total
class B(A):
    def encodes(self, x:int): return x*4
    def encodes(self, x:float): return x/2
test_eq(len(B.encodes.methods), 3)

In [None]:
# int -> B's override, str -> inherited from A, float -> B's new method
b = B()
test_eq(b(2), 8)
test_eq(b('Alex'), 'hello Alex!')
test_eq(b(5.), 2.5)

As is multiple inheritance:

In [None]:
# Multiple inheritance: methods are merged, earlier classes winning on clashing signatures
class A(Transform):
    def encodes(self, x:int): return x*2
    def encodes(self, x:str): return f'hello {x}!'

class B(Transform):
    def encodes(self, x:int): return x*4
    def encodes(self, x:float): return x/2

class C(B,A):  # C is preferred over B, which is preferred over A
    def encodes(self, x:float): return x/4

test_eq(len(A.encodes.methods), 2)
test_eq(len(B.encodes.methods), 2)
test_eq(len(C.encodes.methods), 3)

In [None]:
# Each type resolves to the highest-priority class that defines it
c = C()
test_eq(c('Alex'), 'hello Alex!')  # A's str method
test_eq(c(5), 20)  # B's int method
test_eq(c(10.), 2.5)  # C's float method

### Extensions with decorators

Another way to define a Transform is to extend the `Transform` class:

In [None]:
# An empty Transform subclass, to be extended with decorators below
class A(Transform): pass

And then use decorators:

In [None]:
# Decorating a `self`-taking function named `encodes` with `A` registers it on class `A`
@A
def encodes(self, x:int): return x*2

In [None]:
# Same for `decodes`
@A
def decodes(self,x:int): return x//2

In [None]:
# One method registered for each
test_eq(len(A.encodes.methods),1)
test_eq(len(A.decodes.methods),1)

In [None]:
# Instantiate the decorated class
a = A()

In [None]:
# Encode doubles; decode (integer) halves
test_eq(a(5),10)
test_eq(a.decode(a(5)),5)

Note that adding a method to a class (A) after instantiating the object (a):

In [None]:
# Register another `encodes` on the class after `a` was created
@A
def encodes(self, x:str): return f'hello {x}!'

Will result in the method being accessible in both:

In [None]:
# The new method is registered on the class, so the existing instance sees it too
test_eq(len(A.encodes.methods),2)
test_eq(len(a.encodes.methods),2)

In [None]:
#|hide
# The instance's bound `encodes` wraps the very same plum Function stored on the class
test_is(A.encodes,a.encodes._f)

## Predefined Transform extensions

Below are some Transforms that may be useful as reusable components

### InplaceTransform 

In [None]:
#|export
class InplaceTransform(Transform):
    "A `Transform` that modifies in-place and just returns whatever it's passed"
    def _call(self, fn, *args, split_idx=None, **kwargs):
        original = args[0]
        # Run the transform only for its side effects on `original`...
        super()._call(fn, *args, split_idx=split_idx, **kwargs)
        # ...and hand back the same (now mutated) object, ignoring the encoder's result
        return original

In [None]:
#|hide
import pandas as pd

In [None]:
# The encoder mutates the Series in place and returns None; the Series itself is what comes back
class A(InplaceTransform): pass

@A
def encodes(self, x:pd.Series): x.fillna(10, inplace=True)
    
f = A()

test_eq_type(f(pd.Series([1,2,None])),pd.Series([1,2,10],dtype=np.float64)) #fillna fills with floats.

### DisplayedTransform

In [None]:
#|export
class DisplayedTransform(Transform):
    "A transform with a `__repr__` that shows its attrs"

    @property
    def name(self):
        # Append the attributes recorded in `__stored_args__` (as populated by fastcore's `store_attr`)
        base = super().name
        stored = getattr(self, '__stored_args__', {})
        return f"{base} -- {stored}\n"

Transforms normally are represented by just their class name and a number of encodes and decodes implementations:



In [None]:
# A plain Transform's repr: name plus counts of encodes/decodes methods
class A(Transform): encodes,decodes = noop,noop
f = A()
f

A DisplayedTransform will in addition show the attributes recorded by `store_attr` (kept in `__stored_args__`):

In [None]:
# `store_attr()` records `a` and `b`, which then appear in the repr
class A(DisplayedTransform):
    encodes = noop
    def __init__(self, a, b=2):
        super().__init__()
        store_attr()
    
A(a=1,b=2)

### ItemTransform

In [None]:
#|export
class ItemTransform(Transform):
    "A transform that always takes tuples as whole items"
    _retain = True  # set to False to return the result without converting it back to the input's tuple type
    def __call__(self, x, **kwargs): return self._call1(x, '__call__', **kwargs)
    def decode(self, x, **kwargs):   return self._call1(x, 'decode', **kwargs)
    def _call1(self, x, name, **kwargs):
        parent_fn = getattr(super(), name)
        # Non-tuples go through the regular Transform machinery untouched
        if not is_tuple(x): return parent_fn(x, **kwargs)
        # Pass the tuple as a list so the base class does not map over its elements
        out = parent_fn(list(x), **kwargs)
        if not self._retain: return out
        if is_listy(out) and not isinstance(out, tuple): out = tuple(out)
        return retain_type(out, x, Any)
     

ItemTransform is the class to use to opt out of the default behavior of Transform.

In [None]:
# The whole (x, y) tuple reaches `encodes`/`decodes` as a single item
class AIT(ItemTransform): 
    def encodes(self, xy): x,y=xy; return (x+y,y)
    def decodes(self, xy): x,y=xy; return (x-y,y)
    
f = AIT()
test_eq(f((1,2)), (3,2))
test_eq(f.decode((3,2)), (1,2))   

If you pass a special tuple subclass, the usual retain type behavior of Transform will keep it:

In [None]:
# The tuple subclass of the input is restored on the output
class _T(tuple): pass
x = _T((1,2))
test_eq_type(f(x), _T((3,2)))     

In [None]:
#|hide
# With `split_idx` set, the ItemTransform only runs when a matching `split_idx` is passed
f.split_idx = 0
test_eq_type(f((1,2)), (1,2))
test_eq_type(f((1,2), split_idx=0), (3,2))
test_eq_type(f.decode((1,2)), (1,2))
test_eq_type(f.decode((3,2), split_idx=0), (1,2))

In [None]:
#|hide
# `_retain = False`: the result is returned as-is, not converted back to the input's tuple type
class Get(ItemTransform):
    _retain = False
    def encodes(self, x): return x[0]
    
g = Get()
test_eq(g([1,2,3]), 1)
test_eq(g(L(1,2,3)), 1)
test_eq(g(np.array([1,2,3])), 1)
test_eq_type(g((['a'], ['b', 'c'])), ['a'])    

In [None]:
#|hide
# A tuple subclass returned by `decodes` survives `decode`
class A(ItemTransform): 
    def encodes(self, x): return _T((x,x))
    def decodes(self, x): return _T(x)
    
f = A()
test_eq(type(f.decode((1,1))), _T)

### Func

In [None]:
#|export
def get_func(t, name, *args, **kwargs):
    "Get the `t.name` (potentially partial-ized with `args` and `kwargs`) or `noop` if not defined"
    fn = nested_callable(t, name)
    # Only wrap in a partial when there is something to bind
    if args or kwargs: return partial(fn, *args, **kwargs)
    return fn

This works for any kind of t supporting getattr, so a class or a module.



In [None]:
# Works on modules and on types; a missing name falls back to `noop`
test_eq(get_func(operator, 'neg', 2)(), -2)
test_eq(get_func(operator.neg, '__call__')(2), -2)
test_eq(get_func(list, 'foobar')([2]), [2])
a = [2,1]
get_func(list, 'sort')(a)
test_eq(a, [1,2])

Transforms are built with multiple-dispatch: a given function can have several methods depending on the type of the object received. This is done with the Plum module and type-annotation in Transform, but you can also use the following class.

In [None]:
#|export
class Func():
    "Basic wrapper around a `name` with `args` and `kwargs` to call on a given type"
    def __init__(self, name, *args, **kwargs):
        self.name = name
        self.args = args
        self.kwargs = kwargs
    def __repr__(self): return f'sig: {self.name}({self.args}, {self.kwargs})'
    def _get(self, t): return get_func(t, self.name, *self.args, **self.kwargs)
    def __call__(self, t):
        # `mapped` applies `_get` per element when `t` is a collection of types/modules
        return mapped(self._get, t)

You can call the Func object on any module name or type, even a list of types. It will return the corresponding function (with a default to noop if nothing is found) or list of functions.



In [None]:
# Looks up `math.sqrt`
test_eq(Func('sqrt')(math), math.sqrt)

In [None]:
#|export
class _Sig():
    def __getattr__(self,k):
        def _inner(*args, **kwargs): return Func(k, *args, **kwargs)
        return _inner

Sig = _Sig()
     

In [None]:
# Render documentation for the `Sig` helper object
show_doc(Sig, name="Sig")


### Sig

`Sig` is just syntactic sugar to create a `Func` object more easily, with the syntax `Sig.name(*args, **kwargs)`.

In [None]:
# `Sig.sqrt()` is equivalent to `Func('sqrt')`
f = Sig.sqrt()
test_eq(f(math), math.sqrt)

## Export -

In [None]:
#|hide
#|eval: false
# Export the `#|export` cells of this notebook into the library module
from nbdev import nbdev_export
nbdev_export()