<a href="https://colab.research.google.com/github/ASEM000/pytreeclass/blob/main/assets/intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pytreeclass --quiet

## Construct a Tree

In [2]:
import jax
import jax.numpy as jnp
import pytreeclass as tc


@tc.autoinit
class Tree(tc.TreeClass):
    """A small pytree holding an int, a pair of floats and an array.

    ``tc.autoinit`` builds the constructor from the annotated fields below
    (presumably dataclass-style), so ``Tree()`` uses the defaults shown here.
    """

    a: int = 1
    # Two floats: a bare ``tuple[float]`` would mean a tuple of exactly one float.
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        """Return ``a + b[0] + c + x``; ``x`` is broadcast against the array ``c``."""
        return self.a + self.b[0] + self.c + x


tree = Tree()



## Visualize the pytree

In [3]:
# Render the tree at two levels of detail with each renderer:
# depth=1 stops at the direct fields (a, b, c); depth=2 also expands
# the tuple stored in `b`.
for render in (tc.tree_summary, tc.tree_diagram):
    for depth in (1, 2):
        print(render(tree, depth=depth))

┌────┬──────┬─────┬──────┐
│Name│Type  │Count│Size  │
├────┼──────┼─────┼──────┤
│.a  │int   │1    │      │
├────┼──────┼─────┼──────┤
│.b  │tuple │2    │      │
├────┼──────┼─────┼──────┤
│.c  │f32[3]│3    │12.00B│
├────┼──────┼─────┼──────┤
│Σ   │Tree  │6    │12.00B│
└────┴──────┴─────┴──────┘
┌─────┬──────┬─────┬──────┐
│Name │Type  │Count│Size  │
├─────┼──────┼─────┼──────┤
│.a   │int   │1    │      │
├─────┼──────┼─────┼──────┤
│.b[0]│float │1    │      │
├─────┼──────┼─────┼──────┤
│.b[1]│float │1    │      │
├─────┼──────┼─────┼──────┤
│.c   │f32[3]│3    │12.00B│
├─────┼──────┼─────┼──────┤
│Σ    │Tree  │6    │12.00B│
└─────┴──────┴─────┴──────┘
Tree
├── .a=1
├── .b=(...)
└── .c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
Tree
├── .a=1
├── .b:tuple
│   ├── [0]=2.0
│   └── [1]=3.0
└── .c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])


## Working with `jax` transformations

In [4]:
@jax.grad
def loss_func(tree: Tree, x: jax.Array):
    """Mean of the squared tree outputs over a batch, wrapped by ``jax.grad``.

    Because of the ``@jax.grad`` decorator, calling ``loss_func(tree, x)``
    returns the gradient of this scalar with respect to ``tree`` (the first
    argument), not the scalar itself.

    Args:
        tree: a ``Tree`` whose non-differentiable leaves may be wrapped with
            ``tc.freeze`` so that it can pass through ``jax.grad``.
        x: input batch; the tree call is vectorized over its leading axis.
    """
    # Unwrap any frozen values so that ``Tree.__call__`` sees plain values.
    thawed_tree = tree.at[...].apply(tc.unfreeze, is_leaf=tc.is_frozen)
    # One tree call per row of ``x``.
    batched_outputs = jax.vmap(thawed_tree)(x)
    # There is no target here: this is the mean of the squared outputs
    # (an MSE against zero), not an MSE against labels.
    return jnp.mean(batched_outputs**2)


@jax.jit
def train_step(tree: Tree, x: jax.Array, learning_rate: float = 1e-3):
    """Take one plain gradient-descent step on ``tree`` and return the new tree.

    Args:
        tree: tree whose non-differentiable leaves are frozen (``tc.freeze``).
        x: input batch forwarded to ``loss_func``.
        learning_rate: step size; defaults to the previously hard-coded 1e-3.

    Returns:
        A tree of the same structure where each leaf visited by ``tree_map``
        becomes ``leaf - learning_rate * grad``. Frozen values are carried
        through unchanged (the trained tree still prints ``a=#1``).
    """
    grads = loss_func(tree, x)
    # Use distinct lambda argument names so the batch ``x`` is not shadowed.
    return jax.tree_util.tree_map(
        lambda leaf, grad: leaf - learning_rate * grad, tree, grads
    )


# Let's freeze the non-differentiable parts of the tree: any leaf that is not
# of an inexact (floating/complex) type should be frozen so that the tree is
# differentiable and works with jax transformations.
jaxable_tree = jax.tree_util.tree_map(
    lambda x: tc.freeze(x) if tc.is_nondiff(x) else x, tree
)

# The training batch never changes, so build it once instead of every step.
x_train = jnp.ones([10, 1])
for _ in range(1_000):
    jaxable_tree = train_step(jaxable_tree, x_train)

print(jaxable_tree)
# Frozen params are shown with a "#" prefix (last float digits may vary slightly):
# Tree(a=#1, b=(-4.282653, 3.0), c=[2.3924797 2.905778  3.4190807])


# Unfreeze the tree to get back plain values.
tree = jax.tree_util.tree_map(tc.unfreeze, jaxable_tree, is_leaf=tc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.282653, 3.0), c=[2.3924797 2.905778  3.4190807])
# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])

Tree(a=#1, b=(-4.282653, 3.0), c=[2.3924797 2.905778  3.4190807])
Tree(a=1, b=(-4.282653, 3.0), c=[2.3924797 2.905778  3.4190807])


## `at` indexing

In [5]:
tree = Tree()
# Tree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

# Let's create a boolean mask that is True where a value is > 4.
mask = jax.tree_util.tree_map(lambda x: x > 4, tree)

print(mask)
# Tree(a=False, b=(False, False), c=[False  True  True])

# `get` keeps only the masked values: unmasked scalar leaves become None and
# arrays keep only their selected elements.
print(tree.at[mask].get())
# Tree(a=None, b=(None, None), c=[5. 6.])

# `set` replaces the masked values with a constant.
print(tree.at[mask].set(10))
# Tree(a=1, b=(2.0, 3.0), c=[ 4. 10. 10.])

# `apply` replaces the masked values with the result of a function.
print(tree.at[mask].apply(lambda x: 10))
# Tree(a=1, b=(2.0, 3.0), c=[ 4. 10. 10.])

Tree(a=False, b=(False, False), c=[False  True  True])
Tree(a=None, b=(None, None), c=[5. 6.])
Tree(a=1, b=(2.0, 3.0), c=[ 4. 10. 10.])
Tree(a=1, b=(2.0, 3.0), c=[ 4. 10. 10.])
