# 🥶 Freezing tree leaves

## `jax` and inexact data types
Differentiation transformations like `jax.grad` only accept `PyTree`s whose leaves have inexact data types (`float`, `complex`, or arrays of `float`/`complex`). Any other input type leads to a `TypeError`, as the following example shows.

```python
import jax


@jax.grad
def identity_grad(x):
    # x can be any pytree of inexact leaves
    return sum(x)


# valid input
identity_grad([1.0, 1.0])

# invalid input (not inexact)
try:
    identity_grad([1])
except TypeError as e:
    print(e)
```

```
grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.
```
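
The error message itself points to a `jax`-only alternative: `jax.grad(..., allow_int=True)`. The sketch below, for comparison only, uses a hypothetical function `weighted_sum` whose second input is an `int`. With `allow_int=True` the integer stays a regular leaf, and its "gradient" is a placeholder of dtype `jax.dtypes.float0` rather than a number.

```python
import jax


def weighted_sum(x):
    # x[0] is a float, x[1] is an int
    return 2.0 * x[0] + x[1]


# allow_int=True lets integer leaves through instead of raising a TypeError
grads = jax.grad(weighted_sum, allow_int=True)([1.0, 1])
print(grads[0])        # gradient w.r.t. the float input
print(grads[1].dtype)  # the int input receives a float0 placeholder
```

Unlike `freeze`, this keeps the integer visible to `jax` as a leaf, so it still shows up in the returned gradient structure.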


## Using `freeze`

However, when your function needs to receive a non-inexact data type, we can wrap the non-inexact values with `pytreeclass.freeze`.

```python
import pytreeclass as pytc


@jax.grad
def identity_grad(x):
    return 1.0


try:
    # this will fail
    identity_grad([1, 1.0])
except TypeError as e:
    print(e)

# this will work
identity_grad([pytc.freeze(1), 1.0])
```

```
grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.

[#1, Array(0., dtype=float32, weak_type=True)]
```

Notice that with `pytc.freeze` we were able to pass a non-inexact value to a `jax` transformation without `jax` complaining.
However, to use a frozen value inside the function we must unfreeze it first. If the function does not use the value, there is no need to unfreeze it, as the following example shows.

```python
import pytreeclass as pytc


@jax.grad
def identity_grad(x):
    # this function does not use the frozen value
    return x[1] ** 2


print(identity_grad([pytc.freeze(1), 1.0]))
```

```
[#1, Array(2., dtype=float32, weak_type=True)]
```


If the function does need to use a non-inexact value, we freeze it before passing it to the function and unfreeze it inside the function. The next example illustrates this.

```python
import pytreeclass as pytc


@jax.grad
def func(x):
    # this function uses the non-inexact and inexact values
    # the non-inexact value is frozen so we need to unfreeze it
    return pytc.unfreeze(x[0]) ** 2 + x[1] ** 2


print(func([pytc.freeze(1), 1.0]))
```

```
[#1, Array(2., dtype=float32, weak_type=True)]
```


The result of the previous cell reveals something interesting. We know that $\frac{d}{dx} x^2 = 2x$, but the derivative was only evaluated for the inexact `float` value, giving `Array(2.)`, while the `int` value that was frozen on input was returned unchanged. This is working as intended: besides letting us pass otherwise invalid types through `jax` transformations without raising an error, the same mechanism prevents values from being updated or differentiated with respect to. The following example shows this:

```python
import pytreeclass as pytc
import jax


@jax.grad
def func(x):
    return pytc.unfreeze(x) ** 2


print(func(pytc.freeze(1.0)))
```

```
#1.0
```
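
To see why a frozen value is invisible to `jax.grad`, here is a minimal sketch of a freeze-style wrapper built with plain `jax.tree_util`. It assumes, as a simplification, that the wrapper exposes no children and stores the wrapped value as static auxiliary data. `Frozen` and `my_unfreeze` are hypothetical names, and this is not necessarily how `pytreeclass` implements `freeze`.

```python
import jax
import jax.tree_util as jtu


class Frozen:
    """Hypothetical freeze-style wrapper (not pytreeclass's implementation)."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"#{self.value!r}"


# Flatten to *no* children: the wrapped value is kept as static aux data,
# so JAX transformations never see it as a leaf.
jtu.register_pytree_node(
    Frozen,
    lambda node: ((), node.value),                # -> (children, aux_data)
    lambda aux_data, children: Frozen(aux_data),  # rebuild from aux_data
)


def my_unfreeze(x):
    # return the wrapped value, or `x` unchanged if it is not frozen
    return x.value if isinstance(x, Frozen) else x


print(jtu.tree_leaves([Frozen(1), 3.0]))  # the frozen int is not a leaf: [3.0]


@jax.grad
def f(x):
    # the frozen int is used as a plain constant inside the function
    return my_unfreeze(x[0]) * x[1] ** 2


grads = f([Frozen(2), 3.0])
print(grads)  # frozen entry returned as-is; d/dx1 of 2 * x1**2 at 3.0 is 12
```

Keeping the value in the aux data is what makes it static. A fuller version would likely need extra care for unhashable values such as arrays, since `jax.jit` hashes tree structures.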


We can use this scheme to prevent some values (e.g. already-trained ones) from being updated when interfacing with `jax` transformations.
However, in the previous example we had to know exactly which value to `unfreeze`. If we don't have this information, we can use `tree_map` to unfreeze every frozen value for us, as in the following example.

```python
import pytreeclass as pytc
import jax


@jax.grad
def sum_grad(x):
    # unfreeze the input in case any of the values are frozen
    # this is not necessary if you know the input is not frozen
    x = jax.tree_util.tree_map(pytc.unfreeze, x, is_leaf=pytc.is_frozen)
    return sum(x)


print(sum_grad([pytc.freeze(1.0), 1.0]))
print(sum_grad([1.0, pytc.freeze(1.0)]))
print(sum_grad([pytc.freeze(1.0), pytc.freeze(1.0)]))
print(sum_grad([1.0, 1.0]))
```

```
[#1.0, Array(1., dtype=float32, weak_type=True)]
[Array(1., dtype=float32, weak_type=True), #1.0]
[#1.0, #1.0]
[Array(1., dtype=float32, weak_type=True), Array(1., dtype=float32, weak_type=True)]
```


Using the previous scheme, you can achieve low-overhead training when using `jax` and `PyTreeClass`.

## Recipes for using `freeze` with a mask
The following examples show how to effectively use `freeze` with `TreeClass` instances to freeze certain values.

```python
from __future__ import annotations
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import pytreeclass as pytc


class Tree(pytc.TreeClass):
    a: int = 1
    b: float = 2.0
    c: jax.Array = jnp.array([3.0, 4.0, 5.0])


tree = Tree()
tree
```

```
Tree(a=1, b=2.0, c=f32[3](μ=4.00, σ=0.82, ∈[3.00,5.00]))
```

### Freeze leaves by specifying a mask

```python
# let's freeze all int values
mask = jtu.tree_map(lambda x: isinstance(x, int), tree)
frozen_tree = tree.at[mask].apply(pytc.freeze)
print(frozen_tree)
# Tree(a=#1, b=2.0, c=[3. 4. 5.])

# frozen values are excluded from `tree_leaves`
print(jtu.tree_leaves(frozen_tree))
# [2.0, Array([3., 4., 5.], dtype=float32)]

# `a` does not get updated by `tree_map`
print(jtu.tree_map(lambda x: x + 100, frozen_tree))
# Tree(a=#1, b=102.0, c=[103. 104. 105.])

# unfreeze by a mask
unfrozen_tree = frozen_tree.at[mask].apply(pytc.unfreeze, is_leaf=pytc.is_frozen)
print(unfrozen_tree)
# Tree(a=1, b=2.0, c=[3. 4. 5.])
```

```
Tree(a=#1, b=2.0, c=[3. 4. 5.])
[2.0, Array([3., 4., 5.], dtype=float32)]
Tree(a=#1, b=102.0, c=[103. 104. 105.])
Tree(a=1, b=2.0, c=[3. 4. 5.])
```


### Freeze leaves by specifying the leaf name

```python
frozen_tree = tree.at["a"].apply(pytc.freeze)
print(frozen_tree)  # `a` has a prefix `#`
# Tree(a=#1, b=2.0, c=[3. 4. 5.])

# frozen values are excluded from `tree_leaves`
print(jtu.tree_leaves(frozen_tree))
# [2.0, Array([3., 4., 5.], dtype=float32)]

# `a` does not get updated by `tree_map`
print(jtu.tree_map(lambda x: x + 100, frozen_tree))
# Tree(a=#1, b=102.0, c=[103. 104. 105.])

# unfreeze `a`
unfrozen_tree = frozen_tree.at["a"].apply(pytc.unfreeze, is_leaf=pytc.is_frozen)
print(unfrozen_tree)
# Tree(a=1, b=2.0, c=[3. 4. 5.])
```

```
Tree(a=#1, b=2.0, c=[3. 4. 5.])
[2.0, Array([3., 4., 5.], dtype=float32)]
Tree(a=#1, b=102.0, c=[103. 104. 105.])
Tree(a=1, b=2.0, c=[3. 4. 5.])
```


## Using `freeze`/`unfreeze` with `tree_mask`/`tree_unmask`

`PyTreeClass` provides `tree_mask`/`tree_unmask` as convenience functions that achieve the above
tasks with less code. In essence, `tree_mask` freezes the nodes that satisfy a mask, and `tree_unmask` unfreezes them.

```python
# freeze leaves that satisfy a predicate/mask
# by default `tree_mask` uses the `is_nondiff` predicate to freeze non-differentiable values

tree = (1.0, 2)
masked_tree = pytc.tree_mask(tree)

print(masked_tree)
# (1.0, #2)  # `2` is frozen because it is an int (i.e. a non-inexact type)

# we can undo the above with `tree_unmask`
print(pytc.tree_unmask(masked_tree))
# (1.0, 2)
```

```
(1.0, #2)
(1.0, 2)
```
