# 🍳  Common recipes

This section introduces common recipes you might need while using `PyTreeClass` to train/build models.

## [1] Add a leaf to the instance after instantiation.

The following recipe adds a method `add_leaf` that sets a leaf by name and value. Because this method mutates the internal state of the instance, `.at['add_leaf']` is used to apply it functionally: the call returns the method's return value together with a **new** instance, and the original instance is left unchanged.

In [1]:
import pytreeclass as pytc


class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0

    def add_leaf(self, name: str, value):
        setattr(self, name, value)


tree = Tree()
# Tree(a=1.0, b=2.0, c=3.0)

_, tree_with_d = tree.at["add_leaf"]("d", 4.0)

tree_with_d

Tree(a=1.0, b=2.0, c=3.0, d=4.0)
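Conceptually, applying a method through `.at[...]` behaves like copying the instance, running the mutating method on the copy, and returning both the method's result and the copy. The sketch below mimics that idea in plain Python with `copy.deepcopy`. `PlainTree` and `call_functionally` are hypothetical names used only for illustration; they are not part of `PyTreeClass`, whose actual implementation may differ.

```python
import copy


class PlainTree:
    """A plain mutable object standing in for a tree (illustration only)."""

    def __init__(self):
        self.a = 1.0
        self.b = 2.0

    def add_leaf(self, name, value):
        # mutates `self` in place and returns nothing
        setattr(self, name, value)


def call_functionally(obj, method_name, *args, **kwargs):
    """Hypothetical helper: run a mutating method on a copy of `obj`."""
    new_obj = copy.deepcopy(obj)
    result = getattr(new_obj, method_name)(*args, **kwargs)
    return result, new_obj


tree = PlainTree()
result, tree_with_d = call_functionally(tree, "add_leaf", "d", 4.0)

print(result)  # the method's return value
print(vars(tree))  # attributes of the original object
print(vars(tree_with_d))  # attributes of the copy
```

The original object keeps only `a` and `b`, while the returned copy also carries `d`. This is the same contract as the `_, tree_with_d = ...` pattern in the recipe.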

## [2] Customize per-leaf optimizer updates using `PyTreeClass` masks + `Optax`.
In the following recipe, `optax.masked` is used to apply specific optimizers to specific leaves through masking.

In [2]:
import optax
import pytreeclass as pytc
import jax


class Tree(pytc.TreeClass):
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0


tree = Tree()

false_mask = tree.at[...].set(False)

a_mask = false_mask.at["a"].set(True)
b_mask = false_mask.at["b"].set(True)
c_mask = false_mask.at["c"].set(True)

optim = optax.chain(
    # update `a` with sgd of learning rate 1
    optax.masked(optax.sgd(learning_rate=1), a_mask),
    # update `b` with sgd of learning rate -1
    optax.masked(optax.sgd(learning_rate=-1), b_mask),
    # update `c` with sgd of learning rate 0
    optax.masked(optax.sgd(learning_rate=0), c_mask),
)


# freeze non-differentiable parameters
# in this case all parameters are differentiable,
# but we do it in case we add a non-differentiable parameter later
tree = tree.at[jax.tree_util.tree_map(pytc.is_nondiff, tree)].apply(pytc.freeze)

optim_state = optim.init(tree)
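To see what the chained, masked transforms do to a gradient, the arithmetic can be traced by hand. The sketch below is plain Python over dictionaries, not `Optax` itself. `masked_sgd` and `chain` are hypothetical stand-ins, assuming that SGD turns a gradient `g` into the update `-learning_rate * g` and that a masked transform passes unmasked leaves through unchanged.

```python
def masked_sgd(learning_rate, mask):
    """Hypothetical stand-in for `optax.masked(optax.sgd(learning_rate), mask)`."""

    def update(updates):
        # scale only the leaves selected by the mask, pass the rest through
        return {
            name: -learning_rate * value if mask[name] else value
            for name, value in updates.items()
        }

    return update


def chain(*transforms):
    """Hypothetical stand-in for `optax.chain`: apply transforms in order."""

    def update(updates):
        for transform in transforms:
            updates = transform(updates)
        return updates

    return update


false_mask = {"a": False, "b": False, "c": False}
a_mask = {**false_mask, "a": True}
b_mask = {**false_mask, "b": True}
c_mask = {**false_mask, "c": True}

optim = chain(
    masked_sgd(1, a_mask),
    masked_sgd(-1, b_mask),
    masked_sgd(0, c_mask),
)

grads = {"a": 0.5, "b": 0.5, "c": 0.5}
updates = optim(grads)
print(updates)
```

Under these assumptions each leaf is transformed by exactly one optimizer: the update for `a` points against its gradient, the update for `b` points along it, and the update for `c` is zero.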

## [3] Use `numpy` functions on `TreeClass` instances.
`jax.numpy` functions can be applied to a `TreeClass` instance by wrapping the `numpy` function with the `bcmap` function transformation and enabling the feature through `leafwise=True`. `leafwise=True` additionally enables per-leaf math operations; for example, `tree + 1` adds 1 to all leaves.

In [3]:
import jax
import pytreeclass as pytc
import jax.numpy as jnp


class Tree(pytc.TreeClass, leafwise=True):
    a: int = 1
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])


tree = Tree()

# make where work with arbitrary pytrees
tree_where = pytc.bcmap(jnp.where)

print(tree_where(tree > 2, tree + 100, 0))
# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])

print(tree.at[tree > 1].apply(lambda x: x + 100))
# Tree(a=1, b=(102.0, 103.0), c=[104. 105. 106.])

mask = tree_where(tree > 1, True, False)
print(tree.at[mask].apply(lambda x: x + 100))
# Tree(a=1, b=(102.0, 103.0), c=[104. 105. 106.])

Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])
Tree(a=1, b=(102.0, 103.0), c=[104. 105. 106.])
Tree(a=1, b=(102.0, 103.0), c=[104. 105. 106.])
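The broadcasting idea behind wrapping `jnp.where` can be illustrated without JAX. The sketch below is a dict-only analogue, assuming that tree-like arguments are matched leaf by leaf and that scalar arguments are reused for every leaf. `broadcast_map` and the scalar `where` are hypothetical helpers for illustration, not the `bcmap` implementation.

```python
def broadcast_map(func):
    """Hypothetical, dict-only analogue of a broadcasting tree map."""

    def wrapper(*args):
        # use the first dict argument as the reference structure
        reference = next(arg for arg in args if isinstance(arg, dict))
        return {
            key: func(*(arg[key] if isinstance(arg, dict) else arg for arg in args))
            for key in reference
        }

    return wrapper


def where(condition, x, y):
    # scalar version of `where`: pick x if the condition holds, else y
    return x if condition else y


tree = {"a": 1, "b": 2.0, "c": 4.0}
condition = {key: value > 2 for key, value in tree.items()}
shifted = {key: value + 100 for key, value in tree.items()}

tree_where = broadcast_map(where)
result = tree_where(condition, shifted, 0)
print(result)
```

Here the scalar `0` is reused for every key, while `condition` and `shifted` are indexed key by key. This mirrors the `tree_where(tree > 2, tree + 100, 0)` call in the recipe.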


## [4] Use visualization tools with arbitrary pytrees

In [4]:
import jax
import pytreeclass as pytc

tree = [1, [2, 3], 4]

print(pytc.tree_diagram(tree, depth=1))
print(pytc.tree_diagram(tree, depth=2))
print(pytc.tree_summary(tree, depth=1))
print(pytc.tree_summary(tree, depth=2))

list
├── [0]=1
├── [1]=[...]
└── [2]=4
list
├── [0]=1
├── [1]:list
│   ├── [0]=2
│   └── [1]=3
└── [2]=4
┌────┬────┬─────┐
│Name│Type│Count│
├────┼────┼─────┤
│[0] │int │1    │
├────┼────┼─────┤
│[1] │list│1    │
├────┼────┼─────┤
│[2] │int │1    │
├────┼────┼─────┤
│Σ   │list│3    │
└────┴────┴─────┘
┌──────┬────┬─────┐
│Name  │Type│Count│
├──────┼────┼─────┤
│[0]   │int │1    │
├──────┼────┼─────┤
│[1][0]│int │1    │
├──────┼────┼─────┤
│[1][1]│int │1    │
├──────┼────┼─────┤
│[2]   │int │1    │
├──────┼────┼─────┤
│Σ     │list│4    │
└──────┴────┴─────┘


## [5] Using `callbacks` to validate/convert inputs

In [5]:
import jax
import pytreeclass as pytc


# any callable that takes a value and returns it (or raises) can be used as a callback
class Range(pytc.TreeClass):
    min: int | float = -float("inf")
    max: int | float = float("inf")

    def __call__(self, x):
        if not (self.min <= x <= self.max):
            raise ValueError(f"{x} not in range [{self.min}, {self.max}]")
        return x


class IsInstance(pytc.TreeClass):
    klass: type | tuple[type, ...]

    def __call__(self, x):
        if not isinstance(x, self.klass):
            raise TypeError(f"{x} not an instance of {self.klass}")
        return x


class Foo(pytc.TreeClass):
    # allow in_dim to be an integer between [1,100]
    in_dim: int = pytc.field(callbacks=[IsInstance(int), Range(1, 100)])


tree = Foo(1)
# no error

try:
    tree = Foo(0)
except ValueError as e:
    print(e)

try:
    tree = Foo(1.0)
except TypeError as e:
    print(e)

On applying Range(min=1, max=100) for field=`in_dim`:
0 not in range [1, 100]
On applying IsInstance(klass=<class 'int'>) for field=`in_dim`:
1.0 not an instance of <class 'int'>
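The recipe above only validates, but because each callback returns a value, a callback can also convert its input. The sketch below shows that pattern in plain Python, assuming callbacks run in the listed order and each receives the previous callback's return value. `apply_callbacks`, `to_int`, and `in_range_1_100` are hypothetical names for illustration and do not reproduce how `pytc.field` is implemented.

```python
from functools import reduce


def apply_callbacks(value, callbacks):
    """Hypothetical helper: feed `value` through each callback in order."""
    return reduce(lambda acc, callback: callback(acc), callbacks, value)


def to_int(x):
    # conversion callback: returns a new value instead of only checking it
    return int(x)


def in_range_1_100(x):
    # validation callback: raises on bad input, otherwise returns x unchanged
    if not (1 <= x <= 100):
        raise ValueError(f"{x} not in range [1, 100]")
    return x


converted = apply_callbacks("42", [to_int, in_range_1_100])
print(converted)

try:
    apply_callbacks("0", [to_int, in_range_1_100])
except ValueError as e:
    error_message = str(e)
    print(error_message)
```

Placing the conversion first means the validator always sees an `int`, so the order of the list matters.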


## [6] Freeze custom parameters using `.at` manually/with mask

Some classes, such as `Dropout`, contain leaves that are differentiable but that we do not want to update.
In the `Dropout` example, `drop_rate` is a float that should not be updated by the optimizer.
The following recipe shows how to handle such values.


In [6]:
import pytreeclass as pytc
import jax
import jax.numpy as jnp


class Dropout(pytc.TreeClass):
    drop_rate: float = 0.0  # dropout rate, 0 means no dropout

    def __call__(self, x, *, key):
        keep_rate = 1.0 - self.drop_rate
        mask = jax.random.bernoulli(key, keep_rate, x.shape)
        return jnp.where(mask, x / keep_rate, 0.0)


x = jnp.arange(10)
dropout = Dropout(drop_rate=0.5)
dropout(x, key=jax.random.PRNGKey(0))


@jax.grad
def f(layer: Dropout, x: jax.Array):
    return layer(x, key=jax.random.PRNGKey(0)).sum()


print(f(dropout, x))
# Dropout(drop_rate=108.0)  # <--- this is the gradient which is undesired


# let's fix this by freezing the dropout rate
class Dropout(pytc.TreeClass):
    drop_rate: float = pytc.field(callbacks=[pytc.freeze], default=0.0)

    def __call__(self, x, *, key):
        keep_rate = 1.0 - self.drop_rate
        mask = jax.random.bernoulli(key, keep_rate, x.shape)
        return jnp.where(mask, x / keep_rate, 0.0)


x = jnp.arange(10)
dropout = Dropout(drop_rate=0.5)

dropout
# Dropout(drop_rate=#0.5)  # -> dropout rate is frozen, to call dropout layer we need to unfreeze it first


@jax.grad
def f(layer: Dropout, x: jax.Array):
    layer = jax.tree_util.tree_map(pytc.unfreeze, layer, is_leaf=pytc.is_frozen)
    return layer(x, key=jax.random.PRNGKey(0)).sum()


f(dropout, x)
# Dropout(drop_rate=#0.5)  # <- dropout rate is not updated, can be used safely with optax


# let's say that for evaluation we want to set the dropout rate to 0.0;
# then we can do the following

disable_dropout = dropout.at["drop_rate"].set(0.0, is_leaf=pytc.is_frozen)
print(disable_dropout)
# Dropout(drop_rate=0.0)  # now the dropout rate is 0.0 and unfrozen.
# this layer is now safe to use for evaluation without special handling (like `eval` in PyTorch)

Dropout(drop_rate=108.0)
Dropout(drop_rate=0.0)


## [7] Use `PyTreeClass` with `Flax`/`Equinox`
The following recipe adds `at` support for `Flax` and `Equinox`. Note: for `Equinox`, use `eqx.Module` instead of `struct.PyTreeNode`.

In [7]:
import jax
import pytreeclass as pytc
from flax import struct

# note that flax nodes are registered with `jax.tree_util.register_pytree_with_keys`;
# for arbitrary objects you need to do the key registration yourself


class FlaxTree(struct.PyTreeNode):
    a: int = 1
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jax.numpy.array([4.0, 5.0, 6.0])

    def __repr__(self) -> str:
        return pytc.tree_repr(self)

    def __str__(self) -> str:
        return pytc.tree_str(self)

    @property
    def at(self):
        return pytc.AtIndexer(self, where=())


flax_tree = FlaxTree()

print(f"{flax_tree!r}")
print(f"{flax_tree!s}")
print(pytc.tree_diagram(flax_tree))
print(pytc.tree_summary(flax_tree))

flax_tree.at["a"].set(10)
# FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

FlaxTree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
FlaxTree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])
FlaxTree
├── .a=1
├── .b:tuple
│   ├── [0]=2.0
│   └── [1]=3.0
└── .c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
┌─────┬────────┬─────┐
│Name │Type    │Count│
├─────┼────────┼─────┤
│.a   │int     │1    │
├─────┼────────┼─────┤
│.b[0]│float   │1    │
├─────┼────────┼─────┤
│.b[1]│float   │1    │
├─────┼────────┼─────┤
│.c   │f32[3]  │3    │
├─────┼────────┼─────┤
│Σ    │FlaxTree│6    │
└─────┴────────┴─────┘


FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

## [8] `named_parameters()`-like iteration in `PyTreeClass`

In [8]:
import pytreeclass as pytc
import jax


class Tree(pytc.TreeClass):
    a: int = 1
    b: tuple[float, float] = (2.0, 3.0)


tree = Tree()

for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
    print(path, leaf)

(NamedSequenceKey(idx=0, key='a'),) 1
(NamedSequenceKey(idx=1, key='b'), SequenceKey(idx=0)) 2.0
(NamedSequenceKey(idx=1, key='b'), SequenceKey(idx=1)) 3.0
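If string names are preferred over key tuples, the path entries can be joined into a single name per leaf. The sketch below shows the idea in plain Python on nested dicts, lists, and tuples rather than on JAX key paths. `named_leaves` is a hypothetical helper for illustration, and the naming scheme (`b[0]`, `b[1]`) is an assumption, not something `PyTreeClass` or JAX prescribes.

```python
def named_leaves(tree, prefix=""):
    """Hypothetical helper: yield `(name, leaf)` pairs from nested containers."""
    if isinstance(tree, dict):
        for key, value in tree.items():
            # join nested dict keys with a dot
            yield from named_leaves(value, f"{prefix}.{key}" if prefix else str(key))
    elif isinstance(tree, (list, tuple)):
        for index, value in enumerate(tree):
            # mark sequence positions with an index suffix
            yield from named_leaves(value, f"{prefix}[{index}]")
    else:
        # anything else is treated as a leaf
        yield prefix, tree


tree = {"a": 1, "b": (2.0, 3.0)}
pairs = list(named_leaves(tree))
for name, leaf in pairs:
    print(name, leaf)
```

The same joining step could be applied to the key tuples printed above, at the cost of choosing a naming convention for sequence indices.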


## [9] Initialize parameters based on input
In this example, a `Linear` layer is created whose weight parameter depends on the shape of the input.
Since this requires creating a parameter (i.e. `weight`) after instance initialization, we use `.at` to create a new instance with the added parameter.

In [9]:
import pytreeclass as pytc
from typing import Any
import jax
import jax.numpy as jnp
import jax.random as jr


class LazyLinear(pytc.TreeClass):
    out_features: int

    def param(self, name: str, value: Any):
        # return the value if it exists, otherwise set it and return it
        if name not in vars(self):
            setattr(self, name, value)
        return vars(self)[name]

    def __call__(self, x: jax.Array, *, key: jax.Array = jr.PRNGKey(0)):
        weight = self.param("weight", jnp.ones((x.shape[-1], self.out_features)))
        bias = self.param("bias", jnp.zeros((self.out_features,)))
        return x @ weight + bias


x = jnp.ones([10, 1])

lazy_linear = LazyLinear(out_features=1)

lazy_linear
print(f"Layer before param is set:\t{lazy_linear}")


# first call will set the parameters
_, linear = lazy_linear.at["__call__"](x, key=jr.PRNGKey(0))

print(f"Layer after param is set:\t{linear}")
# subsequent calls will use the same parameters and not set them again
linear(x)

Layer before param is set:	LazyLinear(out_features=1)
Layer after param is set:	LazyLinear(out_features=1, weight=[[1.]], bias=[0.])


Array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float32)