# 🏟️ Fields

In [1]:
!pip install sepes

This section introduces common recipes for fields. A `sepes.field` is a class variable that adds certain functionality to the class. The recipes below use `jax` and `numpy`, but the same approach can work with any other framework.

A field is declared like this:

```python
class MyClass:
    my_field: Any = sepes.field()
```
For example, a `field` can be used to validate the input data, or to provide a default value. The notebook provides examples for common use cases.

`sepes.field` is implemented as a [python descriptor](https://docs.python.org/3/howto/descriptor.html), which means that it can be used in any class, not necessarily a `sepes` class. Refer to the [python documentation](https://docs.python.org/3/howto/descriptor.html) for more information on descriptors and how they work.
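
To make the descriptor idea concrete, the following is a minimal pure-Python sketch of a descriptor that runs callbacks when an attribute is read or assigned. `HookedField` and `Celsius` are hypothetical names used only for illustration; this is not the `sepes.field` implementation, just the general mechanism that descriptors make available.

```python
class HookedField:
    """Hypothetical descriptor that applies callbacks on get and set."""

    def __init__(self, on_getattr=(), on_setattr=()):
        self.on_getattr = tuple(on_getattr)
        self.on_setattr = tuple(on_setattr)

    def __set_name__(self, owner, name):
        # remember the attribute name so the raw value can live in the instance dict
        self.name = name

    def __get__(self, instance, owner=None):
        if instance is None:
            return self  # accessed on the class itself
        value = instance.__dict__[self.name]
        for func in self.on_getattr:
            value = func(value)
        return value

    def __set__(self, instance, value):
        for func in self.on_setattr:
            value = func(value)
        instance.__dict__[self.name] = value


class Celsius:
    # convert to float when set, round to one decimal place when read
    degrees = HookedField(on_setattr=[float], on_getattr=[lambda v: round(v, 1)])

    def __init__(self, degrees):
        self.degrees = degrees


reading = Celsius("21.37")
print(reading.degrees)  # 21.4
print(reading.__dict__["degrees"])  # 21.37
```

The stored value is only transformed once on assignment, while the read-side callbacks run on every access, which is the same split that the `on_setattr` and `on_getattr` arguments express in the recipes below.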

## [1] Buffers
In this example, a certain array will be marked as non-trainable using `jax.lax.stop_gradient` and `field`.

The standard way to mark an array as a buffer (i.e. non-trainable) is to write something like this:
```python
class Tree(sp.TreeClass):
    def __init__(self, buffer: jax.Array):
        self.buffer = buffer

    def __call__(self, x: jax.Array) -> jax.Array:
        return x + jax.lax.stop_gradient(self.buffer)
```
However, if you access this buffer from other methods, then `jax.lax.stop_gradient` has to be repeated inside every one of them:

```python
class Tree(sp.TreeClass):
    def method_1(self, x: jax.Array) -> jax.Array:
        return x + jax.lax.stop_gradient(self.buffer)

    # ... further methods, each repeating the same stop_gradient call ...

    def method_n(self, x: jax.Array) -> jax.Array:
        return x + jax.lax.stop_gradient(self.buffer)
```

Similarly, if you access the `buffer` of a `Tree` instance from another context, you need to use `jax.lax.stop_gradient` again:

```python
tree = Tree(buffer=...)
def func(tree: Tree):
    buffer = jax.lax.stop_gradient(tree.buffer)
    ...    
```

This becomes **cumbersome** if the process is repeated multiple times. Alternatively, `sepes.field` can apply `jax.lax.stop_gradient` to the `buffer` automatically whenever the buffer is accessed. The next example demonstrates this.

In [2]:
import sepes as sp
import jax
import jax.numpy as jnp


def buffer_field(**kwargs):
    return sp.field(on_getattr=[jax.lax.stop_gradient], **kwargs)


@sp.autoinit  # autoinit constructs `__init__` from fields
class Tree(sp.TreeClass):
    buffer: jax.Array = buffer_field()

    def __call__(self, x):
        return self.buffer**x


tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
tree(2.0)  # Array([1., 4., 9.], dtype=float32)


@jax.jit
def f(tree: Tree, x: jax.Array):
    return jnp.sum(tree(x))


print(f(tree, 1.0))
print(jax.grad(f)(tree, 1.0))

6.0
Tree(buffer=[0. 0. 0.])
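
As a sanity check on why the gradient above is all zeros, the same effect can be reproduced with plain `jax`, without `sepes`. This is a small sketch; the function `scaled_sum` and its arguments are made up for illustration.

```python
import jax
import jax.numpy as jnp


def scaled_sum(weight, buffer):
    # buffer is treated as a constant during differentiation
    return jnp.sum(weight * jax.lax.stop_gradient(buffer))


weight = jnp.array([1.0, 2.0, 3.0])
buffer = jnp.array([4.0, 5.0, 6.0])

# differentiate with respect to both arguments
grad_weight, grad_buffer = jax.grad(scaled_sum, argnums=(0, 1))(weight, buffer)
print(grad_weight)  # [4. 5. 6.]
print(grad_buffer)  # [0. 0. 0.]
```

The gradient with respect to `weight` is the buffer itself, while the gradient with respect to `buffer` is zero because `stop_gradient` blocks it. The `buffer_field` recipe applies that call on every attribute access.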


## [2] Masked field

`sepes` provides a simple wrapper to *mask* data. Masking here means that the data yields no leaves when flattened. This is useful in frameworks like `jax` to hide certain values from a transformation.

**Flattening a masked value**

In [3]:
import sepes as sp
import jax

tree = [1, sp.tree_mask(2, cond=lambda _: True)]
print(tree)
print(jax.tree_util.tree_leaves(tree))  # note that 2 is removed from the leaves

[1, #2]
[1]
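
One way to picture "yields no leaves" is a custom pytree node that keeps its payload in the static (auxiliary) part of the tree. The sketch below uses only `jax.tree_util`; the `Hidden` class is a hypothetical stand-in and is not how `sp.tree_mask` is implemented.

```python
import jax


@jax.tree_util.register_pytree_node_class
class Hidden:
    """Hypothetical wrapper whose value is stored as static aux data."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # no children, so no leaves; the value travels as hashable aux data
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


tree = {"mode": Hidden("training"), "alpha": 2.0}
print(jax.tree_util.tree_leaves(tree))  # [2.0]


@jax.grad
def loss(tree, x):
    # the string is visible to Python but invisible to the transformation
    if tree["mode"].value == "training":
        return x ** tree["alpha"]
    return x


grads = loss(tree, 2.0)
print(grads["mode"].value)  # training
```

Because the string never appears among the leaves, `jax.grad` only differentiates with respect to `alpha` and passes the wrapped value through untouched.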


**Using masking with `jax` transformations**

The next example demonstrates how to use masking to work with data types that are not supported by `jax`.

In [4]:
import sepes as sp
import jax


def mask_field(**kwargs):
    return sp.field(
        # unmask when the value is accessed
        on_getattr=[lambda x: sp.tree_unmask(x, cond=lambda node: True)],
        # mask when the value is set
        on_setattr=[lambda x: sp.tree_mask(x, cond=lambda node: True)],
        **kwargs,
    )

Now we can use this custom `field` to mark some class attributes as masked. Masking a value will effectively hide it from `jax` transformations.

**Without masking the `str` type**

In [5]:
@sp.autoinit
class Tree(sp.TreeClass):
    training_mode: str  # <- will throw error with jax transformations.
    alpha: float

    def __call__(self, x):
        if self.training_mode == "training":
            return x**self.alpha
        return x


@jax.grad
def loss_func(tree, input):
    return tree(input)


tree = Tree(training_mode="training", alpha=2.0)
print(loss_func(tree, 2.0))  # <- will throw error with jax transformations.

TypeError: Argument 'training' of type <class 'str'> is not a valid JAX type.

The error occurs because `jax` recognizes only numerical values. The next example demonstrates how to modify the class to mask the `str` type.

In [None]:
@sp.autoinit
class Tree(sp.TreeClass):
    training_mode: str = mask_field()  # hide the field from jax transformations
    alpha: float

    def __call__(self, x):
        if self.training_mode == "training":
            return x**self.alpha
        return x


tree = Tree(training_mode="training", alpha=2.0)
print(loss_func(tree, 2.0))

## [3] Validator fields

The following example shows how to use `sepes.field` to validate input data. A validator is a callable that checks whether the input is valid and raises an exception if it is not. This example is inspired by the [validator class example in the official Python docs](https://docs.python.org/3/howto/descriptor.html#validator-class).

### Range+Type validator

In [None]:
import jax
import sepes as sp


# you can use any function
@sp.autoinit
class Range(sp.TreeClass):
    min: int | float = -float("inf")
    max: int | float = float("inf")

    def __call__(self, x):
        if not (self.min <= x <= self.max):
            raise ValueError(f"{x} not in range [{self.min}, {self.max}]")
        return x


@sp.autoinit
class IsInstance(sp.TreeClass):
    klass: type | tuple[type, ...]

    def __call__(self, x):
        if not isinstance(x, self.klass):
            raise TypeError(f"{x} not an instance of {self.klass}")
        return x


@sp.autoinit
class Foo(sp.TreeClass):
    # allow in_dim to be an integer between [1,100]
    in_dim: int = sp.field(on_setattr=[IsInstance(int), Range(1, 100)])


tree = Foo(1)
# no error

try:
    tree = Foo(0)
except ValueError as e:
    print(e)

try:
    tree = Foo(1.0)
except TypeError as e:
    print(e)
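
The list passed to `on_setattr` can be read as a pipeline. The sketch below assumes the callbacks run in order with each return value fed to the next. That assumption is consistent with the examples in this notebook, but the code is plain Python and is not taken from the `sepes` source. `apply_callbacks`, `is_int`, and `in_range` are hypothetical helpers.

```python
from functools import reduce


def apply_callbacks(value, callbacks):
    # feed the value through each callback from left to right
    return reduce(lambda acc, func: func(acc), callbacks, value)


def is_int(x):
    if not isinstance(x, int):
        raise TypeError(f"{x!r} is not an int")
    return x


def in_range(low, high):
    def check(x):
        if not (low <= x <= high):
            raise ValueError(f"{x} not in range [{low}, {high}]")
        return x

    return check


callbacks = [is_int, in_range(1, 100)]
print(apply_callbacks(1, callbacks))  # 1

for bad in (0, 1.0):
    try:
        apply_callbacks(bad, callbacks)
    except (TypeError, ValueError) as e:
        print(type(e).__name__, e)
```

Under this ordering the type check runs before the range check, so a float such as `1.0` is rejected for its type even though it lies inside the allowed range.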

### Array validator

In [None]:
import sepes as sp
from typing import Any
import jax
import jax.numpy as jnp


class ArrayValidator(sp.TreeClass):
    """Validate shape and dtype of input array.

    Args:
        shape: Expected shape of array. available values are int, None, ...
            use int for fixed size, None for any size, and ... for any number
            of dimensions. for example (..., 1) allows any number of dimensions
            with the last dimension being 1. (1, ..., 1) allows any number of
            dimensions with the first and last dimensions being 1.
        dtype: Expected dtype of array.

    Example:
        >>> x = jnp.ones((5, 5))
        >>> # any number of dimensions with last dim=5
        >>> shape = (..., 5)
        >>> dtype = jnp.float32
        >>> validator = ArrayValidator(shape, dtype)
        >>> validator(x)  # no error

        >>> # must be 2 dimensions with first dim unconstrained and last dim=5
        >>> shape = (None, 5)
        >>> validator = ArrayValidator(shape, dtype)
        >>> validator(x)  # no error
    """

    def __init__(self, shape, dtype):
        if shape.count(...) > 1:
            raise ValueError("Only one ellipsis allowed")

        for si in shape:
            if not isinstance(si, (int, type(...), type(None))):
                raise TypeError(f"Expected int, None, or ..., got {si}")

        self.shape = shape
        self.dtype = dtype

    def __call__(self, x):
        if not (hasattr(x, "shape") and hasattr(x, "dtype")):
            raise TypeError(f"Expected array with shape {self.shape}, got {x}")

        shape = list(self.shape)
        array_shape = list(x.shape)
        array_dtype = x.dtype

        if self.shape and array_dtype != self.dtype:
            raise TypeError(f"Dtype mismatch, {array_dtype=} != {self.dtype=}")

        if ... in shape:
            index = shape.index(...)
            shape = (
                shape[:index]
                + [None] * (len(array_shape) - len(shape) + 1)
                + shape[index + 1 :]
            )

        if len(shape) != len(array_shape):
            raise ValueError(f"{len(shape)=} != {len(array_shape)=}")

        for i, (li, ri) in enumerate(zip(shape, array_shape)):
            if li is None:
                continue
            if li != ri:
                raise ValueError(f"Size mismatch, {li} != {ri} at dimension {i}")
        return x


# any number of dimensions with first dim=3 and last dim=6
shape = (3, ..., 6)
# dtype must be float32
dtype = jnp.float32

validator = ArrayValidator(shape=shape, dtype=dtype)

# convert to half precision from float32
converter = lambda x: x.astype(jnp.float16)


@sp.autoinit
class Tree(sp.TreeClass):
    array: jax.Array = sp.field(on_setattr=[validator, converter])


x = jnp.ones([3, 1, 2, 6])
tree = Tree(array=x)


try:
    # correct dtype (float32) but wrong shape -> ValueError
    y = jnp.ones([1, 1, 2, 3])
    tree = Tree(array=y)
except ValueError as e:
    print(e, "\n")
    # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
    # Size mismatch, 3 != 1 at dimension 0

try:
    # correct shape but wrong dtype (float16) -> TypeError
    z = x.astype(jnp.float16)
    tree = Tree(array=z)
except TypeError as e:
    print(e)
    # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
    # Dtype mismatch, array_dtype=dtype('float16') != self.dtype=<class 'jax.numpy.float32'>
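
The core of the shape check is expanding the single `...` into the right number of unconstrained dimensions. The following standalone sketch isolates that step on plain tuples; `expand_ellipsis` and `shape_matches` are hypothetical helper names that mirror the logic in `ArrayValidator.__call__`.

```python
def expand_ellipsis(pattern, ndim):
    """Replace a single ``...`` in pattern with ``None`` entries to reach ndim."""
    pattern = list(pattern)
    if ... not in pattern:
        return pattern
    index = pattern.index(...)
    # the ellipsis itself occupies one slot in the pattern, hence the +1
    fill = [None] * (ndim - len(pattern) + 1)
    return pattern[:index] + fill + pattern[index + 1 :]


def shape_matches(pattern, shape):
    expanded = expand_ellipsis(pattern, len(shape))
    if len(expanded) != len(shape):
        return False
    # None means "any size" for that dimension
    return all(p is None or p == s for p, s in zip(expanded, shape))


print(expand_ellipsis((3, ..., 6), 4))  # [3, None, None, 6]
print(shape_matches((3, ..., 6), (3, 1, 2, 6)))  # True
print(shape_matches((3, ..., 6), (1, 1, 2, 3)))  # False
print(shape_matches((None, 5), (5, 5)))  # True
```

Returning a boolean instead of raising keeps the sketch short; the validator in the cell above raises a `ValueError` at the first mismatching dimension so that the message can name it.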

## [4] Parameterization field

In this example, the field value is [parameterized](https://pytorch.org/tutorials/intermediate/parametrizations.html) using `on_getattr`.


In [None]:
import jax  # needed for the `jax.Array` annotations below
import sepes as sp
import jax.numpy as jnp


def symmetric(array: jax.Array) -> jax.Array:
    triangle = jnp.triu(array)  # upper triangle
    return triangle + triangle.transpose(-1, -2)


@sp.autoinit
class Tree(sp.TreeClass):
    symmetric_matrix: jax.Array = sp.field(on_getattr=[symmetric])


tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
print(tree.symmetric_matrix)
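
Note that `symmetric` above adds the full upper triangle to its transpose, so the result is symmetric but each diagonal entry is counted twice. A common variant keeps the diagonal unchanged by excluding it from the transposed part. The sketch below uses only `jax.numpy`; `symmetric_keep_diag` is a hypothetical name and could be passed to `on_getattr` in the same way.

```python
import jax.numpy as jnp


def symmetric_keep_diag(array):
    upper = jnp.triu(array)  # upper triangle, diagonal included
    strict_upper = jnp.triu(array, k=1)  # upper triangle, diagonal excluded
    return upper + jnp.swapaxes(strict_upper, -1, -2)


raw = jnp.arange(9).reshape(3, 3)
sym = symmetric_keep_diag(raw)
print(sym)
# [[0 1 2]
#  [1 4 5]
#  [2 5 8]]
print(bool(jnp.array_equal(sym, sym.T)))  # True
```

Either way, only the upper triangle of the stored array influences the value that is read back, which is the point of the parameterization.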