# 📏 Weight regularization

The following code showcases how to use the `.at` functionality to select some leaves of a model, based on a boolean mask and/or a name condition, and apply weight regularization to them. For example, using the `.at[...]` functionality, the following can be achieved concisely:

**Boolean mask based selection for branches and leaves**

In [1]:
!pip install git+https://github.com/ASEM000/serket --quiet

In [1]:
import serket as sk
import jax.numpy as jnp
import jax


class Net(sk.TreeClass):
    """Toy model holding two small integer arrays that mix negative and positive entries."""

    def __init__(self) -> None:
        # Three negative and three positive values.
        self.weight = jnp.array([-1, -2, -3, 1, 2, 3])
        # One negative and one positive value.
        self.bias = jnp.array([-1, 1])


def negative_entries_l2_loss(net: Net):
    """Return an L2 penalty that only counts the negative entries of ``net``.

    Positive entries are first replaced by zero so they contribute nothing to
    the penalty. The mean of squares of every (masked) leaf is then
    accumulated into a single scalar; note that the mean is taken over all
    entries of a leaf, including the zeroed ones.

    Args:
        net: Tree whose leaves are penalized.

    Returns:
        The sum, over all leaves, of the mean squared value of the masked leaf.
    """
    # Boolean mask with the same tree structure as ``net``: True where an
    # entry is positive. ``jax.tree_util.tree_map`` is used instead of the
    # top-level ``jax.tree_map`` alias, which is deprecated/removed in newer
    # JAX releases.
    positive_mask = jax.tree_util.tree_map(lambda x: x > 0, net)
    return (
        # select all positive array entries
        net.at[positive_mask]
        # set them to zero to exclude their loss
        .set(0)
        # select all leaves
        .at[...]
        # finally reduce with l2 loss
        .reduce(lambda x, y: x + jnp.mean(y**2), initializer=0)
    )


# Instantiate the toy model and print its negative-entries L2 penalty.
net = Net()
print(negative_entries_l2_loss(net))

2.8333335


**Name-based selection for branches and leaves**

In [1]:
import serket as sk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import re


class Net(sk.TreeClass):
    """Four-layer fully connected network (1 -> 20 -> 20 -> 20 -> 1).

    The first three linear layers are each followed by ``tanh``; the last
    layer's output is returned as is.
    """

    # ``jax.Array`` is used for the key annotation because
    # ``jax.random.KeyArray`` was removed from JAX, and the annotation is
    # evaluated when the class body is executed.
    def __init__(self, key: jax.Array) -> None:
        """Create the four linear layers, each initialized from its own subkey.

        Args:
            key: PRNG key that is split into one subkey per layer.
        """
        k1, k2, k3, k4 = jax.random.split(key, 4)
        self.linear1 = sk.nn.Linear(in_features=1, out_features=20, key=k1)
        self.linear2 = sk.nn.Linear(in_features=20, out_features=20, key=k2)
        self.linear3 = sk.nn.Linear(in_features=20, out_features=20, key=k3)
        self.linear4 = sk.nn.Linear(in_features=20, out_features=1, key=k4)

    def __call__(self, x):
        """Apply the network to ``x`` and return the output of the last layer."""
        # Hidden layers share the same tanh activation.
        for hidden in (self.linear1, self.linear2, self.linear3):
            x = jax.nn.tanh(hidden(x))
        # No activation on the output layer.
        return self.linear4(x)


def linear_12_weight_l1_loss(net: Net):
    """Return the L1 penalty of the ``weight`` leaves of ``linear1`` and ``linear2``.

    Args:
        net: Network whose ``linear1``/``linear2`` weights are penalized.

    Returns:
        The sum of absolute values over the selected ``weight`` leaves.
    """
    # Branches are selected by name with a regular expression. The character
    # class is ``[12]`` (not ``[1,2]``): inside brackets a comma is a literal
    # character, so ``[1,2]`` would also match the name ``linear,``.
    branches = re.compile(r"linear[12]")
    return (
        # select desired branches (linear1, linear2 in this example),
        # then select the `weight` leaf of each
        net.at[branches]["weight"]
        # apply l1 loss
        .reduce(lambda x, y: x + jnp.sum(jnp.abs(y)), initializer=0)
    )


# Build the network from a fixed PRNG seed (0) and print the L1 penalty
# over the selected weights.
net = Net(key=jax.random.PRNGKey(0))
print(linear_12_weight_l1_loss(net))

82.83809
