# 🗂️  Misc recipes

In [1]:
# !pip install git+https://github.com/ASEM000/serket --quiet

This section introduces some miscellaneous recipes that are not covered in the previous sections.

## [1] Lazy layers.
In this example, a `Linear` layer whose weight shape depends on the shape of the input will be created. Since this requires creating parameters (i.e., `weight` and `bias`) after instance initialization, we will use `.at` to create a new instance with the added parameters.

In [2]:
import serket as sk
from typing import Any
import jax
import jax.numpy as jnp
import jax.random as jr


@sk.autoinit
class LazyLinear(sk.TreeClass):
    """Linear layer whose `weight` and `bias` are created on the first call.

    The parameter shapes depend on the input's last dimension, so they cannot
    be built at construction time. Creating them mutates the instance, which
    only works when the call goes through `.at["__call__"](...)`.
    """

    out_features: int

    def param(self, name: str, value: Any):
        """Return attribute `name`, first storing `value` under it if it is absent."""
        if name in vars(self):
            # already created on an earlier call: reuse it unchanged
            return vars(self)[name]
        setattr(self, name, value)
        return vars(self)[name]

    def __call__(self, x: jax.Array, *, key: jax.Array = jr.PRNGKey(0)):
        """Compute `x @ weight + bias`, lazily creating both parameters.

        `key` is accepted but not used: the parameters are initialized
        deterministically (ones for `weight`, zeros for `bias`).
        """
        in_features = x.shape[-1]
        weight = self.param("weight", jnp.ones((in_features, self.out_features)))
        bias = self.param("bias", jnp.zeros((self.out_features,)))
        return x @ weight + bias


x = jnp.ones([10, 1])

lazy_linear = LazyLinear(out_features=1)
print(f"Layer before param is set:\t{lazy_linear}")

# The first call must go through `.at["__call__"]`: it returns the call's result
# together with a *new* instance holding the freshly created parameters
# (`lazy_linear` itself is not mutated in place).
_, linear = lazy_linear.at["__call__"](x, key=jr.PRNGKey(0))
print(f"Layer after param is set:\t{linear}")

# Subsequent calls reuse the stored parameters and do not set them again,
# so a plain call works here.
linear(x)

Layer before param is set:	LazyLinear(out_features=1)
Layer after param is set:	LazyLinear(out_features=1, weight=[[1.]], bias=[0.])


Array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float32)

## [2] Intermediates handling.

This example shows how to capture specific intermediate values within each function call.

In [3]:
from typing import Any
import serket as sk
import jax
import optax
import jax.numpy as jnp


@sk.autoinit
class Tree(sk.TreeClass):
    """Toy model computing `x + a` and recording that result as an intermediate."""

    a: float = 1.0

    def __call__(self, x: jax.Array, intermediate: tuple[Any, ...]):
        """Return `(x + a, intermediate + (x + a,))`.

        Instead of mutating anything, a new extended tuple is returned, so the
        captured values travel out of the function as ordinary outputs.
        """
        shifted = x + self.a
        extended = intermediate + (shifted,)
        return shifted, extended


def loss_func(tree: Tree, x: jax.Array, y: jax.Array, intermediate: tuple[Any, ...]):
    """Mean-squared-error loss that also threads `intermediate` through.

    Returns `(loss, intermediate)` so it can be used with
    `jax.grad(..., has_aux=True)`, which differentiates only the first element
    and passes the second one through as auxiliary data.
    """
    prediction, captured = tree(x, intermediate)
    squared_error = (prediction - y) ** 2
    return jnp.mean(squared_error), captured


@jax.jit
def train_step(
    tree: Tree,
    optim_state: optax.OptState,
    x: jax.Array,
    y: jax.Array,
    intermediate: tuple[Any, ...],
):
    """Run one optimizer step and return the intermediates captured in the forward pass.

    NOTE: this uses the module-level `optim` (assigned later in this cell)
    instead of taking it as an argument; the global is looked up when the
    function is traced. Because `intermediate` grows by one entry per call,
    its pytree structure changes, so `jax.jit` retraces on every step.
    """
    value_and_aux_grad = jax.grad(loss_func, has_aux=True)
    grads, captured = value_and_aux_grad(tree, x, y, intermediate)
    updates, new_optim_state = optim.update(grads, optim_state)
    new_tree = optax.apply_updates(tree, updates)
    return new_tree, new_optim_state, captured


tree = Tree()
optim = optax.adam(1e-1)
optim_state = optim.init(tree)

x = jnp.linspace(-1, 1, 5)[:, None]
y = x**2

# Captured intermediates accumulate across steps: each step appends one entry.
intermediate = ()
num_steps = 2

for _ in range(num_steps):
    tree, optim_state, intermediate = train_step(tree, optim_state, x, y, intermediate)

print("Intermediate values:\t\n", intermediate)
print("\nFinal tree:\t\n", tree)

Intermediate values:	
 (Array([[0. ],
       [0.5],
       [1. ],
       [1.5],
       [2. ]], dtype=float32), Array([[-0.09999937],
       [ 0.40000063],
       [ 0.90000063],
       [ 1.4000006 ],
       [ 1.9000006 ]], dtype=float32))

Final tree:	
 Tree(a=0.801189)


## [3] Data pipelines

In this example, `AtIndexer` is used in a similar fashion to [PyFunctional](https://github.com/EntilZha/PyFunctional) to work on general data pipelines.

In [4]:
from serket import AtIndexer
import jax


class Transaction:
    """A single ledger entry: what it was for (`reason`) and how much (`amount`)."""

    def __init__(self, reason, amount):
        # plain attributes; the pipeline below filters on `reason` and sums `amount`
        self.reason, self.amount = reason, amount


# this example is copied from https://github.com/EntilZha/PyFunctional
transactions = [
    Transaction("github", 7),
    Transaction("food", 10),
    Transaction("coffee", 5),
    Transaction("digitalocean", 5),
    Transaction("food", 5),
    Transaction("riotgames", 25),
    Transaction("food", 10),
    Transaction("amazon", 200),
    Transaction("paycheck", -1000),
]

indexer = AtIndexer(transactions)
# Boolean mask with the same structure as `transactions`: True for food entries.
# (`jax.tree_map` is deprecated and removed in recent JAX; use `jax.tree_util.tree_map`.)
where = jax.tree_util.tree_map(lambda txn: txn.reason == "food", transactions)
# Sum the `amount` of the selected transactions only.
food_cost = indexer[where].reduce(lambda total, txn: total + txn.amount, initializer=0)
food_cost

25

## [4] Regularization

The following code shows how to use the `.at` functionality to select some leaves of a model based on a boolean mask and/or a name condition, and then apply weight regularization to them. With `.at[...]`, this can be done concisely:

### Boolean-based mask

The entries of the arrays or leaves are selected based on a tree of the same structure but with boolean (`True`/`False`) leaves. A `True` leaf marks a place where the operation is applied, while a `False` leaf indicates that this leaf should not be touched.

In [5]:
import serket as sk
import jax.numpy as jnp
import jax


class Net(sk.TreeClass):
    """Minimal tree with two integer array leaves, used to demo masked reductions."""

    def __init__(self):
        # Both leaves mix negative and positive entries, so the positivity mask is non-trivial.
        self.weight = jnp.array([-1, -2, -3, 1, 2, 3])
        self.bias = jnp.array([-1, 1])


def negative_entries_l2_loss(net: Net):
    """Sum over leaves of `mean(leaf**2)` after zeroing every positive entry.

    Positive entries are set to 0, so only non-positive entries contribute.
    The mean is still taken over all entries of each leaf, including the
    zeroed ones.
    """
    # Boolean tree with the same structure as `net`: True where an entry is positive.
    # (`jax.tree_map` is deprecated and removed in recent JAX; use `jax.tree_util.tree_map`.)
    positive_mask = jax.tree_util.tree_map(lambda x: x > 0, net)
    return (
        # select all positive array entries
        net.at[positive_mask]
        # set them to zero to exclude them from the loss
        .set(0)
        # select all leaves
        .at[...]
        # finally reduce with an l2 loss accumulated across leaves
        .reduce(lambda acc, leaf: acc + jnp.mean(leaf**2), initializer=0)
    )


net = Net()
l2_penalty = negative_entries_l2_loss(net)
print(l2_penalty)

2.8333335


### Name-based mask

In this step, the mask is based on the path of the leaf.

In [6]:
# `weight` is a leaf node of this layer; the next example selects the
# `weight` leaf by name.
linear_layer = sk.nn.Linear(1, 1, key=jax.random.PRNGKey(0))
print(repr(linear_layer))

Linear(
  in_features=(1), 
  out_features=1, 
  in_axis=(-1), 
  out_axis=-1, 
  weight_init=glorot_uniform, 
  bias_init=zeros, 
  weight=f32[1,1](μ=0.20, σ=0.00, ∈[0.20,0.20]), 
  bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
)


In [7]:
import serket as sk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt


class Net(sk.TreeClass):
    """MLP 1 -> 20 -> 20 -> 20 -> 1 with tanh after each hidden layer."""

    def __init__(self, key: jax.Array) -> None:
        k_in, k_h1, k_h2, k_out = jax.random.split(key, 4)
        self.linear1 = sk.nn.Linear(in_features=1, out_features=20, key=k_in)
        self.linear2 = sk.nn.Linear(in_features=20, out_features=20, key=k_h1)
        self.linear3 = sk.nn.Linear(in_features=20, out_features=20, key=k_h2)
        self.linear4 = sk.nn.Linear(in_features=20, out_features=1, key=k_out)

    def __call__(self, x):
        """Forward pass: three tanh hidden layers, then a linear output layer."""
        for hidden in (self.linear1, self.linear2, self.linear3):
            x = jax.nn.tanh(hidden(x))
        return self.linear4(x)


def linear_12_weight_l1_loss(net: Net):
    """L1 penalty (sum of absolute values) on the `weight` of `linear1` and `linear2`.

    The selection is purely name-based: first the branches `linear1`/`linear2`,
    then the `weight` leaf inside each. A regex can select the same leaves:

        >>> import re
        >>> net.at[re.compile("linear[12]")]["weight"]
    """
    selected_weights = net.at["linear1", "linear2"]["weight"]
    return selected_weights.reduce(
        lambda total, w: total + jnp.sum(jnp.abs(w)), initializer=0
    )


net = Net(key=jax.random.PRNGKey(0))
l1_penalty = linear_12_weight_l1_loss(net)
print(l1_penalty)

82.83809


This recipe can then be included inside the loss function, for example

``` python

def loss_fnc(net, x, y):
    loss = jnp.mean((jax.vmap(net)(x) - y) ** 2)  # data-fitting term
    loss += linear_12_weight_l1_loss(net)  # L1 regularization term
    return loss
```

## [5] Sharing/Tie Weights

In this example, a simple `AutoEncoder` whose encoder and decoder share the same `weight` is demonstrated.

In [8]:
import serket as sk
import jax
import jax.numpy as jnp
import jax.random as jr


class TiedAutoEncoder(sk.TreeClass):
    """Autoencoder 1 -> 10 -> 1 whose decoder reuses the transposed encoder weight."""

    def __init__(self, *, key: jax.Array):
        enc_key, dec_key = jr.split(key)
        self.encoder = sk.nn.Linear(1, 10, key=enc_key)
        # The decoder's own weight is never used (`_call` overwrites it with
        # `encoder.weight.T`), so set it to `None` instead of keeping it in memory.
        self.decoder = sk.nn.Linear(10, 1, key=dec_key).at["weight"].set(None)

    def _call(self, x):
        """Tie the weights, then run encoder -> relu -> decoder.

        Assigning `self.decoder.weight` mutates the instance, so this method
        only works when invoked through `.at["_call"]`; a direct call raises
        `AttributeError`.
        """
        self.decoder.weight = self.encoder.weight.T
        hidden = jax.nn.relu(self.encoder(x))
        return self.decoder(hidden)

    def __call__(self, x):
        """Run the mutating `_call` through `.at` and return only its output.

        `.at["_call"]` returns `(result, new_instance_with_mutated_state)`
        without mutating `self` in place; the new instance is not needed here.
        """
        result, _mutated_copy = self.at["_call"](x)
        return result


# Mask non-array leaves (e.g. `in_features`) so the model can pass through jax.jit/jax.grad.
autoencoder = TiedAutoEncoder(key=jr.PRNGKey(0))
tree = sk.tree_mask(autoencoder)


@jax.jit
@jax.grad
def loss_func(net, x, y):
    """MSE between `vmap(net)(x)` and `y`. Because of `@jax.grad`, calling
    `loss_func` returns the gradient w.r.t. `net`, not the loss value.

    `net` is expected to be masked (so non-array leaves such as `in_features`
    pass through `jax.jit`/`jax.grad`); it is unmasked here before use.
    """
    model = sk.tree_unmask(net)
    predictions = jax.vmap(model)(x)
    return jnp.mean((predictions - y) ** 2)


# `tree` was already masked when it was created, so it is not masked again here.
x = jnp.ones([10, 1])
y = jnp.ones([10, 1]) * 2.0
grads: TiedAutoEncoder = loss_func(tree, x, y)

grads

TiedAutoEncoder(
  encoder=Linear(
    in_features=(#1), 
    out_features=#10, 
    in_axis=(#-1), 
    out_axis=#-1, 
    weight_init=#glorot_uniform, 
    bias_init=#zeros, 
    weight=f32[10,1](μ=-0.78, σ=1.11, ∈[-2.58,0.00]), 
    bias=f32[10](μ=-0.39, σ=0.55, ∈[-1.29,0.00])
  ), 
  decoder=Linear(
    in_features=(#10), 
    out_features=#1, 
    in_axis=(#-1), 
    out_axis=#-1, 
    weight_init=#glorot_uniform, 
    bias_init=#zeros, 
    weight=None, 
    bias=f32[1](μ=-2.40, σ=0.00, ∈[-2.40,-2.40])
  )
)