# 🍳  Common recipes

This section introduces common recipes you might need while using `serket` to train/build models.

In [1]:
!pip install serket --quiet

## [1] Per-leaf optimization
The following recipe uses `optax.masked` to apply specific optimizers to specific leaves through boolean masks.

In [2]:
import optax
import serket as sk
import jax

@sk.autoinit
class Tree(sk.TreeClass):
    """Toy tree with three scalar float leaves `a`, `b` and `c`."""
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0


tree = Tree()

# start from a mask in which every leaf is `False`
false_mask = tree.at[...].set(False)

# one mask per field: only the named leaf is switched to `True`
a_mask = false_mask.at["a"].set(True)
b_mask = false_mask.at["b"].set(True)
c_mask = false_mask.at["c"].set(True)

optim = optax.chain(
    # update `a` with adam of learning rate 1
    optax.masked(optax.adam(learning_rate=1), a_mask),
    # update `b` with adam of learning rate -1
    optax.masked(optax.adam(learning_rate=-1), b_mask),
    # update `c` with adam of learning rate 0
    optax.masked(optax.adam(learning_rate=0), c_mask),
)
optim_state = optim.init(sk.tree_mask(tree))
# the optimizer state contains 3 masked sub-states, one per field of the tree
print(sk.tree_diagram(optim_state))

tuple
├── [0]:MaskedState
│   └── .inner_state:tuple
│       └── [0]:ScaleByAdamState
│           ├── .count=i32[](μ=0.00, σ=0.00, ∈[0,0])
│           ├── .mu:Tree
│           │   └── .a=f32[](μ=0.00, σ=0.00, ∈[0.00,0.00])
│           └── .nu:Tree
│               └── .a=f32[](μ=0.00, σ=0.00, ∈[0.00,0.00])
├── [1]:MaskedState
│   └── .inner_state:tuple
│       └── [0]:ScaleByAdamState
│           ├── .count=i32[](μ=0.00, σ=0.00, ∈[0,0])
│           ├── .mu:Tree
│           │   └── .b=f32[](μ=0.00, σ=0.00, ∈[0.00,0.00])
│           └── .nu:Tree
│               └── .b=f32[](μ=0.00, σ=0.00, ∈[0.00,0.00])
└── [2]:MaskedState
    └── .inner_state:tuple
        └── [0]:ScaleByAdamState
            ├── .count=i32[](μ=0.00, σ=0.00, ∈[0,0])
            ├── .mu:Tree
            │   └── .c=f32[](μ=0.00, σ=0.00, ∈[0.00,0.00])
            └── .nu:Tree
                └── .c=f32[](μ=0.00, σ=0.00, ∈[0.00,0.00])


## [2] Buffers
In this example, certain arrays are marked as non-trainable using `jax.lax.stop_gradient` and `field`.

In [3]:
# the cell refers to serket as `sk`, so import it under that alias
import serket as sk
import jax
import jax.numpy as jnp


@sk.autoinit
class Tree(sk.TreeClass):
    """Tree holding a non-trainable `buffer` array.

    The field is declared with `jax.lax.stop_gradient` as an `on_getattr`
    callback; the gradient printed for this tree further down is all zeros.
    """
    buffer: jax.Array = sk.field(on_getattr=[jax.lax.stop_gradient])

    def __call__(self, x):
        # element-wise power of the buffer
        return self.buffer**x


tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
# element-wise `buffer ** 2.0`
tree(2.0)  # Array([1., 4., 9.], dtype=float32)


@jax.jit
def f(tree, x):
    """Call `tree` on `x` and reduce the result to a scalar by summation."""
    outputs = tree(x)
    return jnp.sum(outputs)


# forward value: sum of `buffer ** 1.0`
print(f(tree, 1.0))
# gradient with respect to the tree: the buffer receives zeros
print(jax.grad(f)(tree, 1.0))

6.0
Tree(buffer=[0. 0. 0.])


## [3] Frozen fields
In this example, field value freezing is done at the class level using `on_getattr` and `on_setattr`. This effectively hides the field value across `jax` transformations.

In [4]:
# the cell refers to serket as `sk`, so import it under that alias
import serket as sk
import jax


@sk.autoinit
class Tree(sk.TreeClass):
    """Tree with a single field that is frozen on write and unfrozen on read."""
    # `sk.freeze` is registered for attribute setting, `sk.unfreeze` for attribute reading
    frozen_a: int = sk.field(on_getattr=[sk.unfreeze], on_setattr=[sk.freeze])

    def __call__(self, x):
        # reading `self.frozen_a` goes through `on_getattr`
        return self.frozen_a + x


tree = Tree(frozen_a=1)  # 1 is a plain Python int, i.e. a non-jaxtype
# because the field is frozen, the tree can still be used in jax transformations


@jax.jit
def f(tree, x):
    """Forward `x` through the callable `tree` and return its output unchanged."""
    result = tree(x)
    return result


# forward value: `frozen_a + 1.0`
print(f(tree, 1.0))
# the gradient tree still shows the frozen value (printed as `#1`)
print(jax.grad(f)(tree, 1.0))

2.0
Tree(frozen_a=#1)


## [4] Parameterization

In this example, a field value is [parameterized](https://pytorch.org/tutorials/intermediate/parametrizations.html) using `on_getattr`.


In [5]:
# the cell refers to serket as `sk` and annotates with `jax.Array`
import serket as sk
import jax
import jax.numpy as jnp


def symmetric(array: jax.Array) -> jax.Array:
    """Build a symmetric matrix from the upper triangle of `array`.

    The upper triangle (diagonal included) is added to its own transpose, so
    off-diagonal entries are mirrored and diagonal entries are doubled.

    Args:
        array: a matrix, or a stack of matrices in the last two axes.

    Returns:
        An array of the same shape, symmetric in its last two axes.
    """
    triangle = jnp.triu(array)  # upper triangle of the last two axes
    # `swapaxes` exchanges only the last two axes, so stacked (batched)
    # matrices work too; `.transpose(-1, -2)` is only valid for 2-D input
    return triangle + jnp.swapaxes(triangle, -1, -2)


@sk.autoinit
class Tree(sk.TreeClass):
    """Tree whose `symmetric_matrix` field has `symmetric` registered as an `on_getattr` callback."""
    symmetric_matrix: jax.Array = sk.field(on_getattr=[symmetric])


# the stored matrix is not symmetric; the value read back from the field is
tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
print(tree.symmetric_matrix)

[[ 0  1  2]
 [ 1  8  5]
 [ 2  5 16]]


## [5] `numpy` and `TreeClass`.

In this recipe, `numpy` functions operate directly on `TreeClass` instances.

In [6]:
# the cell refers to serket as `sk` and annotates with `jax.Array`
import serket as sk
import jax
import jax.numpy as jnp


@sk.leafwise  # enable math operations on leaves
@sk.autoinit  # enable __init__ from type annotations
class Tree(sk.TreeClass):
    """Tree mixing an int leaf, a tuple of floats and an array leaf."""
    a: int = 1
    b: tuple[float, ...] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

tree = Tree()

# make where work with arbitrary pytrees
tree_where = sk.bcmap(jnp.where)
# for values > 2, add 100, else set to 0
# (the scalar `0` is passed alongside two tree-shaped arguments)
print(tree_where(tree > 2, tree + 100, 0))

Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])


## [6] Validation and conversion.

### Type and number range check

In [7]:
# the cell refers to serket as `sk`, so import it under that alias
import jax
import serket as sk


# you can use any function
@sk.autoinit
class Range(sk.TreeClass):
    """Callable validator that accepts values inside the closed interval [min, max].

    Returns the value unchanged when it is in range.

    Raises:
        ValueError: if the value lies outside the interval.
    """

    min: int | float = -float("inf")
    max: int | float = float("inf")

    def __call__(self, x):
        within_bounds = self.min <= x <= self.max
        if within_bounds:
            return x
        raise ValueError(f"{x} not in range [{self.min}, {self.max}]")


@sk.autoinit
class IsInstance(sk.TreeClass):
    """Callable validator that accepts only instances of `klass`.

    Returns the value unchanged when the check passes.

    Raises:
        TypeError: if the value is not an instance of `klass`.
    """

    klass: type | tuple[type, ...]

    def __call__(self, x):
        if isinstance(x, self.klass):
            return x
        raise TypeError(f"{x} not an instance of {self.klass}")


@sk.autoinit
class Foo(sk.TreeClass):
    """Tree whose `in_dim` field is validated by `IsInstance` and `Range` when it is set."""
    # allow in_dim to be an integer between [1,100]
    in_dim: int = sk.field(on_setattr=[IsInstance(int), Range(1, 100)])


tree = Foo(1)
# no error

# 0 is an int but lies outside [1, 100]: `Range` raises ValueError
try:
    tree = Foo(0)
except ValueError as e:
    print(e)

# 1.0 is not an int: `IsInstance` raises TypeError
try:
    tree = Foo(1.0)
except TypeError as e:
    print(e)

On applying Range(min=1, max=100) for field=`in_dim`:
0 not in range [1, 100]
On applying IsInstance(klass=<class 'int'>) for field=`in_dim`:
1.0 not an instance of <class 'int'>


### Array shape and dtype check, then dtype conversion

In [8]:
# the cell refers to serket as `sk`, so import it under that alias
import serket as sk
from typing import Any
import jax
import jax.numpy as jnp


class ArrayValidator(sk.TreeClass):
    """Callable validator for the shape and dtype of an array-like value."""

    def __init__(self, shape, dtype):
        """Validate shape and dtype of input array.

        Args:
            shape: Expected shape of array. available values are int, None, ...
                use int for fixed size, None for any size, and ... for any number
                of dimensions. for example (..., 1) allows any number of dimensions
                with the last dimension being 1. (1, ..., 1) allows any number of
                dimensions with the first and last dimensions being 1.
            dtype: Expected dtype of array.

        Raises:
            ValueError: if `shape` contains more than one ellipsis.
            TypeError: if an entry of `shape` is not an int, None or ellipsis.

        Example:
            >>> x = jnp.ones((5, 5))
            >>> # any number of dimensions with last dim=5
            >>> shape = (..., 5)
            >>> dtype = jnp.float32
            >>> validator = ArrayValidator(shape, dtype)
            >>> validator(x)  # no error

            >>> # must be 2 dimensions with first dim unconstrained and last dim=5
            >>> shape = (None, 5)
            >>> validator = ArrayValidator(shape, dtype)
            >>> validator(x)  # no error
        """

        if shape.count(...) > 1:
            raise ValueError("Only one ellipsis allowed")

        for si in shape:
            if not isinstance(si, (int, type(...), type(None))):
                raise TypeError(f"Expected int or ..., got {si}")

        self.shape = shape
        self.dtype = dtype

    def __call__(self, x):
        """Return `x` unchanged if it matches the expected shape and dtype.

        Raises:
            TypeError: if `x` has no `shape`/`dtype` attributes or its dtype
                differs from the expected one.
            ValueError: if the number of dimensions or the size of a
                constrained dimension does not match.
        """
        if not (hasattr(x, "shape") and hasattr(x, "dtype")):
            raise TypeError(f"Expected array with shape {self.shape}, got {x}")

        shape = list(self.shape)
        array_shape = list(x.shape)
        array_dtype = x.dtype

        # The dtype check must not depend on the expected shape: an empty
        # shape `()` (scalar arrays) is falsy and used to skip this check.
        if array_dtype != self.dtype:
            raise TypeError(f"Dtype mismatch, {array_dtype=} != {self.dtype=}")

        if ... in shape:
            # expand the ellipsis into as many unconstrained dimensions as
            # needed to reach the rank of the array (none if there are too few)
            index = shape.index(...)
            shape = (
                shape[:index]
                + [None] * (len(array_shape) - len(shape) + 1)
                + shape[index + 1 :]
            )

        if len(shape) != len(array_shape):
            raise ValueError(f"{len(shape)=} != {len(array_shape)=}")

        for i, (li, ri) in enumerate(zip(shape, array_shape)):
            if li is None:
                # `None` leaves this dimension unconstrained
                continue
            if li != ri:
                raise ValueError(f"Size mismatch, {li} != {ri} at dimension {i}")
        return x


# any number of dimensions with first dim=3 and last dim=6
shape = (3, ..., 6)
# dtype must be float32
dtype = jnp.float32

validator = ArrayValidator(shape=shape, dtype=dtype)

# convert to half precision from float32
converter = lambda x: x.astype(jnp.float16)


@sk.autoinit
class Tree(sk.TreeClass):
    """Tree whose `array` field has `validator` and `converter` registered as `on_setattr` callbacks."""
    # NOTE(review): presumably the callbacks run in list order (validate, then convert) - confirm
    array: jax.Array = sk.field(on_setattr=[validator, converter])


# valid input: float32 with first dim=3 and last dim=6
x = jnp.ones([3, 1, 2, 6])
tree = Tree(array=x)


# wrong shape (float32, but first dim is 1 instead of 3) -> ValueError
try:
    y = jnp.ones([1, 1, 2, 3])
    tree = Tree(array=y)
except ValueError as e:
    print(e, "\n")
    # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
    # Size mismatch, 3 != 1 at dimension 0

# wrong dtype (correct shape, but float16 instead of float32) -> TypeError
try:
    z = x.astype(jnp.float16)
    tree = Tree(array=z)
except TypeError as e:
    print(e)
    # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
    # Dtype mismatch, array_dtype=dtype('float16') != self.dtype=<class 'jax.numpy.float32'>

On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
Size mismatch, 3 != 1 at dimension 0 

On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
Dtype mismatch, array_dtype=dtype('float16') != self.dtype=<class 'jax.numpy.float32'>


## [7] Lazy layers.
In this example, a `Linear` layer with a weight parameter based on the shape of the input will be created. Since this requires parameter creation (i.e., `weight`) after instance initialization, we will use `.at` to create a new instance with the added parameter.

In [9]:
# the cell refers to serket as `sk`, so import it under that alias
import serket as sk
from typing import Any
import jax
import jax.numpy as jnp
import jax.random as jr


@sk.autoinit
class LazyLinear(sk.TreeClass):
    """Linear layer whose `weight` and `bias` are created on the first call.

    The input feature size is read from `x.shape[-1]`, so only `out_features`
    is given at construction time.
    """

    out_features: int

    def param(self, name: str, value: Any):
        """Return attribute `name`, first setting it to `value` if it is missing."""
        if name not in vars(self):
            setattr(self, name, value)
        return vars(self)[name]

    # `jax.random.KeyArray` is not available in recent JAX releases and the
    # annotation is evaluated at definition time, so keys are typed `jax.Array`.
    def __call__(self, x: jax.Array, *, key: jax.Array = jr.PRNGKey(0)):
        """Compute `x @ weight + bias`, creating both parameters if needed.

        `key` is accepted but not used: the weight is initialised with ones
        and the bias with zeros.
        """
        weight = self.param("weight", jnp.ones((x.shape[-1], self.out_features)))
        bias = self.param("bias", jnp.zeros((self.out_features,)))
        return x @ weight + bias


x = jnp.ones([10, 1])

lazy_linear = LazyLinear(out_features=1)

print(f"Layer before param is set:\t{lazy_linear}")


# first call will set the parameters; calling through `.at["__call__"]`
# returns the method output together with a new instance that holds them
_, linear = lazy_linear.at["__call__"](x, key=jr.PRNGKey(0))

print(f"Layer after param is set:\t{linear}")
# subsequent calls will use the same parameters and not set them again
linear(x)

Layer before param is set:	LazyLinear(out_features=1)
Layer after param is set:	LazyLinear(out_features=1, weight=[[1.]], bias=[0.])


Array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float32)

## [8] Intermediates handling.

This example shows how to capture specific intermediate values within each function call.

In [10]:
# the cell refers to serket as `sk`, so import it under that alias
from typing import Any
import serket as sk
import jax
import optax
import jax.numpy as jnp


@sk.autoinit
class Tree(sk.TreeClass):
    """Adds the scalar `a` to its input and records the result as an intermediate."""

    a: float = 1.0

    def __call__(self, x: jax.Array, intermediate: tuple[Any, ...]):
        shifted = x + self.a
        # append this call's output to the values captured so far
        captured = intermediate + (shifted,)
        return shifted, captured


def loss_func(tree: Tree, x: jax.Array, y: jax.Array, intermediate: tuple[Any, ...]):
    """Mean squared error of `tree` on `(x, y)`, plus the updated intermediates.

    The second return value is the auxiliary output: the `intermediate` tuple
    as returned by the call to `tree`.
    """
    prediction, captured = tree(x, intermediate)
    residual = prediction - y
    return jnp.mean(residual**2), captured


@jax.jit
def train_step(
    tree: Tree,
    optim_state: optax.OptState,
    x: jax.Array,
    y: jax.Array,
    intermediate: tuple[Any, ...],
):
    """Run one optimisation step and return `(tree, optim_state, intermediate)`.

    The optimizer is read from the module-level name `optim`.
    """
    # `has_aux=True`: `loss_func` returns (loss, intermediate); only the loss is differentiated
    grad_and_aux = jax.grad(loss_func, has_aux=True)
    grads, intermediate = grad_and_aux(tree, x, y, intermediate)
    updates, optim_state = optim.update(grads, optim_state)
    return optax.apply_updates(tree, updates), optim_state, intermediate


tree = Tree()
optim = optax.adam(1e-1)
optim_state = optim.init(tree)

x = jnp.linspace(-1, 1, 5)[:, None]
y = x**2

# start with nothing captured; every `train_step` call appends one entry
intermediate = ()

for i in range(2):
    tree, optim_state, intermediate = train_step(tree, optim_state, x, y, intermediate)


print("Intermediate values:\t\n", intermediate)
print("\nFinal tree:\t\n", tree)

Intermediate values:	
 (Array([[0. ],
       [0.5],
       [1. ],
       [1.5],
       [2. ]], dtype=float32), Array([[-0.09999937],
       [ 0.40000063],
       [ 0.90000063],
       [ 1.4000006 ],
       [ 1.9000006 ]], dtype=float32))

Final tree:	
 Tree(a=0.801189)


## [9] Layers from configurations.
The next example shows how to use `serket.bcmap` to loop over a configuration dictionary that defines creation of simple linear layers.

In [11]:
# the cell refers to serket as `sk` and uses `jnp.zeros` inside `Linear`
import serket as sk
import jax
import jax.numpy as jnp


class Linear(sk.TreeClass):
    """Dense layer computing `x @ weight + bias`."""

    # `jax.random.KeyArray` is not available in recent JAX releases and the
    # annotation is evaluated at definition time, so the key is typed `jax.Array`.
    def __init__(self, in_dim: int, out_dim: int, *, key: jax.Array):
        """Create a normally distributed `(in_dim, out_dim)` weight and a zero bias.

        Args:
            in_dim: number of input features.
            out_dim: number of output features.
            key: PRNG key used to sample the weight.
        """
        self.weight = jax.random.normal(key, (in_dim, out_dim))
        self.bias = jnp.zeros((out_dim,))

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the affine map to `x`."""
        return x @ self.weight + self.bias


# construction arguments for the layers; the keys match the parameters of `Linear.__init__`
config = {
    # each layer gets a different input dimension
    "in_dim": [1, 2, 3, 4],
    # out_dim is broadcasted to all layers
    "out_dim": 1,
    # each layer gets a different key
    "key": list(jax.random.split(jax.random.PRNGKey(0), 4)),
}


# `bcmap` transforms a function that takes a single input into a function that
# accepts arbitrary pytree inputs. A non-tree (single) input is broadcasted
# to match the tree structure of the first argument
# (in our example, a list of 4 inputs)


# `jax.random.KeyArray` is not available in recent JAX releases and the
# annotation is evaluated at definition time, so the key is typed `jax.Array`.
@sk.bcmap
def build_layer(in_dim, out_dim, *, key: jax.Array):
    """Build a single `Linear` layer; `sk.bcmap` lifts this to pytree arguments."""
    return Linear(in_dim, out_dim, key=key)


# `in_dim` and `key` are lists of length 4, `out_dim` is a single int
build_layer(config["in_dim"], config["out_dim"], key=config["key"])

[Linear(
   weight=f32[1,1](μ=0.31, σ=0.00, ∈[0.31,0.31]), 
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 ),
 Linear(
   weight=f32[2,1](μ=-1.27, σ=0.33, ∈[-1.59,-0.94]), 
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 ),
 Linear(
   weight=f32[3,1](μ=0.24, σ=0.53, ∈[-0.48,0.77]), 
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 ),
 Linear(
   weight=f32[4,1](μ=-0.28, σ=0.21, ∈[-0.64,-0.08]), 
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 )]

## [10] Ensembling
In this example, simple `Linear` layers are grouped by their weight on the first axis using `jax.vmap`. This is useful if the different instances of the model are desired to run in a vectorized fashion (model ensemble).

For more check [here](http://matpalm.com/blog/ensemble_nets/)

In [12]:
# the cell refers to serket as `sk`, so import it under that alias
import jax
import jax.numpy as jnp
import jax.random as jr
import serket as sk
import functools as ft
from typing import Generic, TypeVar

T = TypeVar("T")


class Batched(Generic[T]):
    """Annotation-only generic marker, used below as `Batched[FNN]` and `Batched[jax.Array]`."""
    ...


class Linear(sk.TreeClass):
    """Dense layer computing `x @ weight + bias`, tagged with a string `name`."""

    # `jax.random.KeyArray` is not available in recent JAX releases and the
    # annotation is evaluated at definition time, so the key is typed `jax.Array`.
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        *,
        key: jax.Array,
        name: str,
    ):
        """Create a normally distributed `(in_dim, out_dim)` weight and a zero bias.

        Args:
            in_dim: number of input features.
            out_dim: number of output features.
            key: PRNG key used to sample the weight.
            name: label stored on the layer.
        """
        self.weight = jr.normal(key, (in_dim, out_dim))
        self.bias = jnp.zeros((out_dim,))
        self.name = name  # non-jax type for `tree_mask`/`tree_unmask` demonstration

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the affine map to `x`."""
        return x @ self.weight + self.bias


class FNN(sk.TreeClass):
    """Three-layer fully connected network (1 -> 10 -> 10 -> 1) with ReLU activations."""

    # `jax.random.KeyArray` is not available in recent JAX releases and the
    # annotation is evaluated at definition time, so the key is typed `jax.Array`.
    def __init__(self, key: jax.Array):
        """Split `key` three ways and build one `Linear` layer per sub-key."""
        k1, k2, k3 = jr.split(key, 3)
        self.l1 = Linear(1, 10, key=k1, name="l1")
        self.l2 = Linear(10, 10, key=k2, name="l2")
        self.l3 = Linear(10, 1, key=k3, name="l3")

    def __call__(self, x: jax.Array) -> jax.Array:
        """Run `x` through the three layers; no activation after the last one."""
        x = self.l1(x)
        x = jax.nn.relu(x)
        x = self.l2(x)
        x = jax.nn.relu(x)
        x = self.l3(x)
        return x


# `jax.random.KeyArray` is not available in recent JAX releases and the
# annotation is evaluated at definition time, so keys are typed `jax.Array`.
def build_ensemble(keys: jax.Array) -> Batched[FNN]:
    """Build one `FNN` per key in `keys`, vectorised with `jax.vmap`.

    Args:
        keys: a batch of PRNG keys, one per ensemble member.

    Returns:
        The unmasked result of the vmapped constructor.
    """

    @jax.vmap
    def build_fnn(key: jax.Array):
        # `jax.vmap` require jax-type return
        # so use `tree_mask` on return
        return sk.tree_mask(FNN(key=key))

    return sk.tree_unmask(build_fnn(keys))


def run_single_input_ensemble(fnns: Batched[FNN], x: jax.Array):
    """Evaluate every ensemble member on the same input `x`."""

    def apply_member(fnn: FNN):
        # mask the output so that `jax.vmap` only sees jax-type values
        prediction = fnn(x)
        return sk.tree_mask(prediction)

    vectorised_apply = jax.vmap(apply_member)
    return vectorised_apply(sk.tree_mask(fnns))


def run_multi_input_ensemble(fnns: Batched[FNN], x: Batched[jax.Array]):
    """Evaluate each ensemble member on its own slice of `x`."""

    def apply_member(fnn: FNN, x: jax.Array):
        # mask the output so that `jax.vmap` only sees jax-type values
        prediction = fnn(x)
        return sk.tree_mask(prediction)

    vectorised_apply = jax.vmap(apply_member)
    return vectorised_apply(sk.tree_mask(fnns), x)


# one PRNG key per ensemble member (4 models, despite the name `num_layers`)
num_layers = 4
keys = jr.split(jr.PRNGKey(0), num_layers)

# single input ensemble
# e.g. each model in the ensemble gets the same input
x = jnp.ones([10, 1])
fnns = build_ensemble(keys=keys)
y = run_single_input_ensemble(fnns, x)
print(f"Single input ensemble shape:\t{y.shape}")

# multi input ensemble
# e.g. each model in the ensemble gets a different input
# (inputs are stacked along a new leading axis of size 4)
xs = jnp.stack([x, x * 2, x * 3, x * 4])
fnns = build_ensemble(keys=keys)
ys = run_multi_input_ensemble(fnns, xs)
print(f"Multi input ensemble shape:\t{ys.shape}")

Single input ensemble shape:	(4, 10, 1)
Multi input ensemble shape:	(4, 10, 1)


## [11] Data pipelines

In this example, `AtIndexer` is used in similar fashion to [PyFunctional](https://github.com/EntilZha/PyFunctional) to work on general data pipelines.

In [13]:
from serket import AtIndexer
import jax


class Transaction:
    """Plain record of a single transaction: a `reason` label and an `amount`."""

    def __init__(self, reason, amount):
        self.reason, self.amount = reason, amount


# this example copied from  https://github.com/EntilZha/PyFunctional
transactions = [
    Transaction("github", 7),
    Transaction("food", 10),
    Transaction("coffee", 5),
    Transaction("digitalocean", 5),
    Transaction("food", 5),
    Transaction("riotgames", 25),
    Transaction("food", 10),
    Transaction("amazon", 200),
    Transaction("paycheck", -1000),
]

indexer = AtIndexer(transactions)
# boolean mask with the same structure as `transactions`: `True` for food entries.
# `jax.tree_util.tree_map` is used because the top-level `jax.tree_map` alias
# is deprecated and removed in recent JAX releases.
where = jax.tree_util.tree_map(lambda x: x.reason == "food", transactions)
# sum the amounts of the selected transactions
food_cost = indexer[where].reduce(lambda x, y: x + y.amount, initializer=0)
food_cost

25

## [12] Regularization

The following code showcases how to use the `at` functionality to select some leaves of a model, based on a boolean mask and/or a name condition, and apply weight regularization to them. For example, using the `.at[...]` functionality, the following can be achieved concisely:

### Boolean-based mask

The entries of the arrays or leaves are selected based on a tree of the same structure but with boolean (`True`/`False`) leaves. A `True` leaf marks a place where the operation is applied, while a `False` leaf indicates that the leaf should not be touched.

In [14]:
import serket as sk
import jax.numpy as jnp
import jax


class Net(sk.TreeClass):
    """Minimal tree with two integer array leaves, `weight` and `bias`."""
    def __init__(self):
        # both arrays mix negative and positive entries
        self.weight = jnp.array([-1, -2, -3, 1, 2, 3])
        self.bias = jnp.array([-1, 1])


def negative_entries_l2_loss(net: Net):
    """Sum over leaves of the mean squared value, counting only negative entries.

    Positive entries are zeroed before the reduction, so they add nothing to
    the sum of squares (they still count towards each leaf's mean).
    """
    # `jax.tree_util.tree_map` is used because the top-level `jax.tree_map`
    # alias is deprecated and removed in recent JAX releases.
    positive_mask = jax.tree_util.tree_map(lambda x: x > 0, net)
    return (
        # select all positive array entries
        net.at[positive_mask]
        # set them to zero to exclude their loss
        .set(0)
        # select all leaves
        .at[...]
        # finally reduce with l2 loss
        .reduce(lambda x, y: x + jnp.mean(y**2), initializer=0)
    )


net = Net()
# weight contributes mean([1, 4, 9, 0, 0, 0]) and bias contributes mean([1, 0])
print(negative_entries_l2_loss(net))

2.8333335


### Name-based mask

In this step, the mask is based on the path of the leaf.

In [15]:
# note that `weight` is a leaf node in this layer (see the printed repr)
# the `weight` leaf will be selected by name in the next example.
print(repr(sk.nn.Linear(1,1,key=jax.random.PRNGKey(0))))

Linear(
  in_features=(1), 
  out_features=1, 
  weight_init=glorot_uniform, 
  bias_init=zeros, 
  weight=f32[1,1](μ=0.20, σ=0.00, ∈[0.20,0.20]), 
  bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
)


In [16]:
import serket as sk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt


class Net(sk.TreeClass):
    """Four-layer fully connected network (1 -> 20 -> 20 -> 20 -> 1) with tanh activations."""

    # `jax.random.KeyArray` is not available in recent JAX releases and the
    # annotation is evaluated at definition time, so the key is typed `jax.Array`.
    def __init__(self, key: jax.Array) -> None:
        """Split `key` four ways and build one `sk.nn.Linear` layer per sub-key."""
        k1, k2, k3, k4 = jax.random.split(key, 4)
        self.linear1 = sk.nn.Linear(in_features=1, out_features=20, key=k1)
        self.linear2 = sk.nn.Linear(in_features=20, out_features=20, key=k2)
        self.linear3 = sk.nn.Linear(in_features=20, out_features=20, key=k3)
        self.linear4 = sk.nn.Linear(in_features=20, out_features=1, key=k4)

    def __call__(self, x):
        """Run `x` through the four layers; no activation after the last one."""
        x = jax.nn.tanh(self.linear1(x))
        x = jax.nn.tanh(self.linear2(x))
        x = jax.nn.tanh(self.linear3(x))
        x = self.linear4(x)
        return x


def linear_12_weight_l1_loss(net: Net):
    """L1 norm of the `weight` leaves of `linear1` and `linear2`, summed together."""
    # select desired branches (linear1, linear2 in this example)
    # and the desired leaves (weight)
    selected_weights = net.at["linear1", "linear2"]["weight"]
    # alternatively, regex can be used to do the same functionality
    # >>> import re
    # >>> net.at[re.compile("linear[12]")]["weight"]

    def add_l1(total, leaf):
        # accumulate the sum of absolute values of one selected leaf
        return total + jnp.sum(jnp.abs(leaf))

    # finally apply l1 loss
    return selected_weights.reduce(add_l1, initializer=0)


# build the network from a fixed PRNG key
net = Net(key=jax.random.PRNGKey(0))
print(linear_12_weight_l1_loss(net))

82.83809


This recipe can then be included inside the loss function, for example

``` python

def loss_fnc(net, x, y):
    l1_loss = linear_12_weight_l1_loss(net)
    loss += l1_loss
    ...
```