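# 🥶 Freezing tree leaves

Freezing a leaf marks it (shown with a `#` prefix when the tree is printed) so that JAX tree utilities skip it. A frozen value is excluded from `jax.tree_util.tree_leaves` and left unchanged by `jax.tree_util.tree_map`. To restore frozen leaves, use `pytc.unfreeze` together with `is_leaf=pytc.is_frozen`.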

In [1]:
from __future__ import annotations
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import pytreeclass as pytc


class Tree(pytc.TreeClass):
    """Small example pytree with three fields, used to demonstrate freezing.

    Fields are declared via annotations with defaults, so ``Tree()`` builds an
    instance whose values are exactly the defaults below.
    """

    a: int = 1  # Python int field
    b: float = 2.0  # Python float field
    # Array field. The default expression is evaluated once, at class-definition
    # time. NOTE(review): presumably this same array is shared as the default by
    # every instance. That is harmless because JAX arrays are immutable.
    c: jax.Array = jnp.array([3.0, 4.0, 5.0])


# Instantiate with all default field values. The bare name on the last line
# makes Jupyter display the tree's repr as the cell output.
tree = Tree()
tree

Tree(a=1, b=2.0, c=f32[3](μ=4.00, σ=0.82, ∈[3.00,5.00]))

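### Freeze leaves by specifying a mask

The mask is a tree of booleans with the same structure as `tree`. Leaves whose mask value is `True` are frozen.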

In [2]:
# Let's freeze every `int` leaf. The mask mirrors the structure of `tree` and
# holds True wherever the leaf is an `int`. Note that `bool` values would match
# too, since `bool` subclasses `int`. Here only `a` matches.
mask = jtu.tree_map(lambda leaf: isinstance(leaf, int), tree)
frozen_tree = tree.at[mask].apply(pytc.freeze)
print(frozen_tree)  # the frozen field is printed with a `#` prefix
# Tree(a=#1, b=2.0, c=[3. 4. 5.])

# Frozen values are excluded from `tree_leaves`.
print(jtu.tree_leaves(frozen_tree))
# [2.0, Array([3., 4., 5.], dtype=float32)]

# `tree_map` skips the frozen `a`; only `b` and `c` are shifted by 100.
print(jtu.tree_map(lambda leaf: leaf + 100, frozen_tree))
# Tree(a=#1, b=102.0, c=[103. 104. 105.])

# Unfreeze with the same mask. `is_leaf=pytc.is_frozen` makes the frozen value
# count as a leaf, so the mask can select it. Otherwise it would be skipped,
# just as `tree_leaves` skipped it above.
unfrozen_tree = frozen_tree.at[mask].apply(pytc.unfreeze, is_leaf=pytc.is_frozen)
print(unfrozen_tree)
# Tree(a=1, b=2.0, c=[3. 4. 5.])

Tree(a=#1, b=2.0, c=[3. 4. 5.])
[2.0, Array([3., 4., 5.], dtype=float32)]
Tree(a=#1, b=102.0, c=[103. 104. 105.])
Tree(a=1, b=2.0, c=[3. 4. 5.])


### Freeze leaves by specifying the leaf name

In [3]:
# Select the field to freeze by its name rather than by a mask.
frozen_tree = tree.at["a"].apply(pytc.freeze)
print(frozen_tree)  # the frozen `a` is printed with a `#` prefix
# Tree(a=#1, b=2.0, c=[3. 4. 5.])

# Frozen values do not appear in `tree_leaves`.
print(jtu.tree_leaves(frozen_tree))
# [2.0, Array([3., 4., 5.], dtype=float32)]

# `tree_map` leaves the frozen `a` untouched; only `b` and `c` change.
print(jtu.tree_map(lambda leaf: leaf + 100, frozen_tree))
# Tree(a=#1, b=102.0, c=[103. 104. 105.])

# Unfreeze `a` by name. `is_leaf=pytc.is_frozen` makes the frozen value count
# as a leaf so that it can be selected.
unfrozen_tree = frozen_tree.at["a"].apply(pytc.unfreeze, is_leaf=pytc.is_frozen)
print(unfrozen_tree)
# Tree(a=1, b=2.0, c=[3. 4. 5.])

Tree(a=#1, b=2.0, c=[3. 4. 5.])
[2.0, Array([3., 4., 5.], dtype=float32)]
Tree(a=#1, b=102.0, c=[103. 104. 105.])
Tree(a=1, b=2.0, c=[3. 4. 5.])
