# 🏟️ Fields

In [1]:
# !pip install git+https://github.com/ASEM000/serket --quiet

This section introduces common recipes for fields. A `field` is a type hinted class variable that adds certain functionality to the class. For example, a `field` can be used to validate the input data, or to provide a default value. The notebook provides examples for common use cases.

## [1] Buffers
In this example, a certain array is marked as non-trainable using `jax.lax.stop_gradient` and `field`.

The standard way to mark an array as a buffer (e.g. non-trainable) is to write something like this:
```python
class Tree(sk.TreeClass):
    def __init__(self, buffer: jax.Array):
        self.buffer = buffer

    def __call__(self, x: jax.Array) -> jax.Array:
        return x + jax.lax.stop_gradient(self.buffer)
```
However, if you access this buffer from other methods, then another `jax.lax.stop_gradient` must be written inside each of those methods:

```python
class Tree(sk.TreeClass):
    def method_1(self, x: jax.Array) -> jax.Array:
        return x + jax.lax.stop_gradient(self.buffer)
        .
        .
        .
    def method_n(self, x: jax.Array) -> jax.Array:
        return x + jax.lax.stop_gradient(self.buffer)
```

Similarly, if you access `buffer` defined for `Tree` instances, from another context, you need to use `jax.lax.stop_gradient` again:

```python
tree = Tree(buffer=...)
def func(tree: Tree):
    buffer = jax.lax.stop_gradient(tree.buffer)
    ...    
```

This becomes cumbersome if this process is repeated multiple times. Instead, applying `jax.lax.stop_gradient` on `__getattr__` using `on_getattr` is simpler, because you need to define it only once.

The following example demonstrates this point.

In [None]:
import serket as sk
import jax
import jax.numpy as jnp


def buffer_field(**kwargs):
    """Create a serket field whose value is read through ``jax.lax.stop_gradient``.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to ``sk.field``.

    Returns:
        Whatever ``sk.field`` returns for the given options.
    """
    # callbacks applied to the stored value on attribute access
    getattr_callbacks = [jax.lax.stop_gradient]
    return sk.field(on_getattr=getattr_callbacks, **kwargs)


@sk.autoinit  # autoinit generates `__init__` from the annotated fields
class Tree(sk.TreeClass):
    """Tree holding a single array that acts as a non-trainable buffer."""

    # every read of `buffer` passes through `jax.lax.stop_gradient`
    buffer: jax.Array = buffer_field()

    def __call__(self, x):
        """Return the buffer raised to the power ``x``."""
        base = self.buffer
        return base**x


tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
# bare expression in the middle of a cell: evaluated, but its value is not displayed
tree(2.0)  # Array([1., 4., 9.], dtype=float32)


@jax.jit
def f(tree: Tree, x: jax.Array):
    """Reduce ``tree(x)`` to a scalar by summing all of its entries."""
    outputs = tree(x)
    return jnp.sum(outputs)


print(f(tree, 1.0))  # sum of buffer**1.0
# gradient with respect to `tree`: the buffer entries come out as zeros
print(jax.grad(f)(tree, 1.0))

6.0
Tree(buffer=[0. 0. 0.])


## [2] Frozen field

In this example, field value freezing is done at the class level using `on_getattr` and `on_setattr`. This effectively hides the field value across `jax` transformations.

Hiding a field means that this field value does not get traced/updated by `jax` internals.

In [4]:
import serket as sk
import jax


def frozen_field(**kwargs):
    """Create a serket field that is frozen when set and unfrozen when read.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to ``sk.field``.

    Returns:
        Whatever ``sk.field`` returns for the given options.
    """
    read_callbacks = [sk.unfreeze]  # applied on attribute access
    write_callbacks = [sk.freeze]  # applied on attribute assignment
    return sk.field(on_getattr=read_callbacks, on_setattr=write_callbacks, **kwargs)


@sk.autoinit
class Tree(sk.TreeClass):
    """Tree with a single field that is stored frozen."""

    # frozen on assignment and unfrozen on access (see `frozen_field`)
    frozen_a: int = frozen_field()

    def __call__(self, x):
        """Return ``x`` shifted by the frozen value."""
        offset = self.frozen_a
        return offset + x


tree = Tree(frozen_a=1)  # 1 is a plain Python int, i.e. a non-jaxtype


@jax.jit
def f(tree, x):
    """Apply ``tree`` to ``x`` under ``jax.jit`` and return the result."""
    result = tree(x)
    return result


# the frozen int can be passed through the jitted function
print(f(tree, 1.0))

# gradient with respect to `tree`: the frozen field is displayed as `#1`
print(jax.grad(f)(tree, 1.0))

# not visible to `jax.tree_util...`
print(jax.tree_util.tree_leaves(tree))

2.0
Tree(frozen_a=#1)
[]


To unfreeze the frozen values, use `tree_unmask`:

In [5]:
# after `sk.tree_unmask` the previously frozen value shows up as a leaf again
print(jax.tree_util.tree_leaves(sk.tree_unmask(tree)))

[1]


## [3] Range+Type `validator`

In [9]:
import jax
import serket as sk


# you can use any function
@sk.autoinit
class Range(sk.TreeClass):
    """Validator accepting values inside the closed interval ``[min, max]``."""

    min: int | float = -float("inf")  # inclusive lower bound, unbounded by default
    max: int | float = float("inf")  # inclusive upper bound, unbounded by default

    def __call__(self, x):
        """Return ``x`` unchanged; raise ``ValueError`` when it is out of range."""
        within_bounds = self.min <= x <= self.max
        if within_bounds:
            return x
        raise ValueError(f"{x} not in range [{self.min}, {self.max}]")


@sk.autoinit
class IsInstance(sk.TreeClass):
    """Validator accepting only instances of ``klass``."""

    # a single type or a tuple of types, as accepted by `isinstance`
    klass: type | tuple[type, ...]

    def __call__(self, x):
        """Return ``x`` unchanged; raise ``TypeError`` when the type does not match."""
        if isinstance(x, self.klass):
            return x
        raise TypeError(f"{x} not an instance of {self.klass}")


@sk.autoinit
class Foo(sk.TreeClass):
    # allow in_dim to be an integer between [1,100]
    # the callbacks listed in `on_setattr` are applied to the value being set:
    # `IsInstance(int)` rejects non-int values, `Range(1, 100)` rejects out-of-range ones
    in_dim: int = sk.field(on_setattr=[IsInstance(int), Range(1, 100)])


tree = Foo(1)
# no error

try:
    # 0 is an int but lies outside [1, 100] -> ValueError from `Range`
    tree = Foo(0)
except ValueError as e:
    print(e)

try:
    # 1.0 is a float, not an int -> TypeError from `IsInstance`
    tree = Foo(1.0)
except TypeError as e:
    print(e)

On applying Range(min=1, max=100) for field=`in_dim`:
0 not in range [1, 100]
On applying IsInstance(klass=<class 'int'>) for field=`in_dim`:
1.0 not an instance of <class 'int'>


## [4] `Array` validator

In [10]:
import serket as sk
from typing import Any
import jax
import jax.numpy as jnp


class ArrayValidator(sk.TreeClass):
    """Validate shape and dtype of input array.

    Args:
        shape: Expected shape of array. available values are int, None, ...
            use int for fixed size, None for any size, and ... for any number
            of dimensions. for example (..., 1) allows any number of dimensions
            with the last dimension being 1. (1, ..., 1) allows any number of
            dimensions with the first and last dimensions being 1.
        dtype: Expected dtype of array. Use None to skip the dtype check.

    Raises:
        ValueError: On construction, if ``shape`` contains more than one
            ellipsis. On call, if the number of dimensions or the size of a
            constrained dimension does not match.
        TypeError: On construction, if a ``shape`` entry is not an int, None
            or ``...``. On call, if the input has no ``shape``/``dtype``
            attributes or its dtype differs from the expected one.

    Example:
        >>> x = jnp.ones((5, 5))
        >>> # any number of dimensions with last dim=5
        >>> shape = (..., 5)
        >>> dtype = jnp.float32
        >>> validator = ArrayValidator(shape, dtype)
        >>> validator(x)  # no error

        >>> # must be 2 dimensions with first dim unconstrained and last dim=5
        >>> shape = (None, 5)
        >>> validator = ArrayValidator(shape, dtype)
        >>> validator(x)  # no error
    """

    def __init__(self, shape, dtype):
        if shape.count(...) > 1:
            raise ValueError("Only one ellipsis allowed")

        for si in shape:
            if not isinstance(si, (int, type(...), type(None))):
                # None is a valid entry too, so mention it in the message
                raise TypeError(f"Expected int, None or ..., got {si}")

        self.shape = shape
        self.dtype = dtype

    def __call__(self, x):
        """Return ``x`` unchanged if it matches the expected shape and dtype."""
        if not (hasattr(x, "shape") and hasattr(x, "dtype")):
            raise TypeError(f"Expected array with shape {self.shape}, got {x}")

        shape = list(self.shape)
        array_shape = list(x.shape)
        array_dtype = x.dtype

        # guard on the dtype itself (not on the shape): an empty expected shape
        # must not silently disable the dtype check, and dtype=None means
        # "accept any dtype"
        if self.dtype is not None and array_dtype != self.dtype:
            raise TypeError(f"Dtype mismatch, {array_dtype=} != {self.dtype=}")

        if ... in shape:
            # expand the single ellipsis into as many unconstrained (None)
            # dimensions as needed to match the rank of the array; a negative
            # count yields no padding and is caught by the rank check below
            index = shape.index(...)
            shape = (
                shape[:index]
                + [None] * (len(array_shape) - len(shape) + 1)
                + shape[index + 1 :]
            )

        if len(shape) != len(array_shape):
            raise ValueError(f"{len(shape)=} != {len(array_shape)=}")

        for i, (li, ri) in enumerate(zip(shape, array_shape)):
            if li is None:
                # unconstrained dimension
                continue
            if li != ri:
                raise ValueError(f"Size mismatch, {li} != {ri} at dimension {i}")
        return x


# any number of dimensions with first dim=3 and last dim=6
shape = (3, ..., 6)
# dtype must be float32
dtype = jnp.float32

validator = ArrayValidator(shape=shape, dtype=dtype)

# convert to half precision from float32
converter = lambda x: x.astype(jnp.float16)


@sk.autoinit
class Tree(sk.TreeClass):
    # on assignment the value is checked by `validator` and cast by `converter`
    # NOTE(review): callbacks presumably run in list order (validate, then cast) - confirm
    array: jax.Array = sk.field(on_setattr=[validator, converter])


# first dim=3 and last dim=6: accepted by the validator
x = jnp.ones([3, 1, 2, 6])
tree = Tree(array=x)


try:
    # first dim is 1 instead of 3 -> the shape check raises ValueError
    y = jnp.ones([1, 1, 2, 3])
    tree = Tree(array=y)
except ValueError as e:
    print(e, "\n")
    # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
    # Size mismatch, 3 != 1 at dimension 0

try:
    # valid shape but float16 dtype -> the dtype check raises TypeError
    z = x.astype(jnp.float16)
    tree = Tree(array=z)
except TypeError as e:
    print(e)
    # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
    # Dtype mismatch, array_dtype=dtype('float16') != self.dtype=<class 'jax.numpy.float32'>

On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
Size mismatch, 3 != 1 at dimension 0 

On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
Dtype mismatch, array_dtype=dtype('float16') != self.dtype=<class 'jax.numpy.float32'>


## [5] Parameterization field

In this example, the field value is [parameterized](https://pytorch.org/tutorials/intermediate/parametrizations.html) using `on_getattr`.


In [6]:
import serket as sk
import jax.numpy as jnp


def symmetric(array: jax.Array) -> jax.Array:
    """Build a symmetric matrix from the upper triangle of ``array``.

    The upper triangle (diagonal included) is added to its own transpose, so
    off-diagonal entries are mirrored and diagonal entries are doubled.

    Args:
        array: A matrix, or a stack of matrices with leading batch dimensions;
            the last two axes are treated as the matrix axes.

    Returns:
        An array of the same shape that is symmetric in its last two axes.
    """
    triangle = jnp.triu(array)  # upper triangle, diagonal included
    # swap only the last two axes so leading batch dimensions are preserved;
    # `triangle.transpose(-1, -2)` requires a full permutation and fails for ndim > 2
    return triangle + jnp.swapaxes(triangle, -1, -2)


@sk.autoinit
class Tree(sk.TreeClass):
    # `symmetric` is applied when the attribute is read, so the accessed value is
    # the symmetrized form of the array passed at construction
    symmetric_matrix: jax.Array = sk.field(on_getattr=[symmetric])


# pass a non-symmetric 3x3 matrix holding 0..8
tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
# attribute access returns the symmetrized matrix
print(tree.symmetric_matrix)

[[ 0  1  2]
 [ 1  8  5]
 [ 2  5 16]]
