# 🌲 Getting started

## 🛠️ Installation

```
pip install pytreeclass
```

**Install development version**

```
pip install git+https://github.com/ASEM000/PyTreeClass
```

## 📖 Description

`PyTreeClass` is a JAX-compatible `dataclass`-like decorator to create and operate on stateful JAX PyTrees.
The package aims to achieve two goals:

1) 🔒 To maintain safe and correct behaviour by using _immutable_ modules with a _functional_ API.
2) To achieve the **most intuitive** user experience in the `JAX` ecosystem by:
   - 🏗️ Defining layers in a style similar to `PyTorch` or `TensorFlow` subclassing.
   - ☝️ Filtering/indexing layer values with boolean masks, similar to `jax.numpy.ndarray.at[].{get,set,apply,...}`.
   - 🎨 Visualizing defined layers in a plethora of ways for easier debugging and sharing of information.
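For comparison, here is the array-level `.at[]` API in plain JAX that the tree-level indexing is modelled on. This snippet uses only `jax.numpy` and is independent of PyTreeClass. Like the tree version, every update is out-of-place.

```python
import jax.numpy as jnp

x = jnp.array([4.0, 5.0, 6.0])
mask = x > 4  # boolean mask: [False, True, True]

selected = x.at[mask].get()     # gather the masked elements -> [5., 6.]
updated = x.at[mask].set(10.0)  # out-of-place update -> [4., 10., 10.]
# `x` itself is unchanged: JAX arrays are immutable
```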

## ⏩ Quick Example

### 🏗️ Simple Tree example

<div align="center">
<table>
<tr><td align="center">Code</td> <td align="center">PyTree representation</td></tr>
<tr>
<td>

```python
import jax
import jax.numpy as jnp
import pytreeclass as pytc

@pytc.treeclass
class Tree:
    a: int = 1
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()
```

</td>

<td>

```python
# leaves are parameters

Tree
    ├── a=1
    ├── b:tuple
    │   ├── [0]=2.0
    │   └── [1]=3.0
    └── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
```

</td>

</tr>
</table>
</div>
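A class can be used with `jax.tree_util`, `jax.grad` and `jax.jit` once it is registered as a pytree. The sketch below does this by hand with `jax.tree_util.register_pytree_node_class` on a frozen `dataclasses.dataclass`. It illustrates the general mechanism under that assumption and is not how `@pytc.treeclass` is implemented. `MiniTree` is a hypothetical name, and the integer field `a` is omitted so that every leaf stays differentiable.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class MiniTree:  # hypothetical hand-registered pytree, for illustration only
    b: tuple
    c: jax.Array

    def tree_flatten(self):
        # children are visible to JAX transformations; no static aux data here
        return (self.b, self.c), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __call__(self, x):
        return self.b[0] + self.c + x


tree = MiniTree(b=(2.0, 3.0), c=jnp.array([4.0, 5.0, 6.0]))
leaves = jax.tree_util.tree_leaves(tree)                 # b[0], b[1], c
doubled = jax.tree_util.tree_map(lambda v: v * 2, tree)  # a new MiniTree
out = jax.jit(lambda t, x: t(x))(tree, 1.0)              # 2.0 + [4., 5., 6.] + 1.0
```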

### 🎨 Visualize <a id="Viz"></a>

#### `tree_summary`

```python
print(pytc.tree_summary(tree, depth=1))
```

```
┌────┬──────┬─────┐
│Name│Type  │Count│
├────┼──────┼─────┤
│a   │int   │1    │
├────┼──────┼─────┤
│b   │tuple │1    │
├────┼──────┼─────┤
│c   │f32[3]│3    │
├────┼──────┼─────┤
│Σ   │Tree  │5    │
└────┴──────┴─────┘
```

```python
print(pytc.tree_summary(tree, depth=2))
```

```
┌────┬──────┬─────┐
│Name│Type  │Count│
├────┼──────┼─────┤
│a   │int   │1    │
├────┼──────┼─────┤
│b[0]│float │1    │
├────┼──────┼─────┤
│b[1]│float │1    │
├────┼──────┼─────┤
│c   │f32[3]│3    │
├────┼──────┼─────┤
│Σ   │Tree  │6    │
└────┴──────┴─────┘
```
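The `Count` column at `depth=2` appears to be the number of scalar entries in each leaf. It can be cross-checked with plain `jax.tree_util`. This is a sketch that assumes a plain dict as a stand-in for `Tree()`, and it is not how `tree_summary` computes the table.

```python
import jax
import jax.numpy as jnp
import numpy as np

# plain-dict stand-in with the same contents as `Tree()`
params = {"a": 1, "b": (2.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0])}

leaves = jax.tree_util.tree_leaves(params)               # a, b[0], b[1], c
num_leaves = len(leaves)                                 # 4
num_params = sum(int(np.size(leaf)) for leaf in leaves)  # 1 + 1 + 1 + 3 = 6
```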


#### `tree_diagram`

```python
print(pytc.tree_diagram(tree, depth=1))
```

```
Tree
├── a=1
├── b=(...)
└── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
```

```python
print(pytc.tree_diagram(tree, depth=2))
```

```
Tree
├── a=1
├── b:tuple
│   ├── [0]=2.0
│   └── [1]=3.0
└── c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
```


#### `tree_repr`

```python
print(pytc.tree_repr(tree, depth=1))
```

```
Tree(a=1, b=(...), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
```

```python
print(pytc.tree_repr(tree, depth=2))
```

```
Tree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
```


#### `tree_str`

```python
print(pytc.tree_str(tree, depth=1))
```

```
Tree(a=1, b=(...), c=[4. 5. 6.])
```

```python
print(pytc.tree_str(tree, depth=2))
```

```
Tree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])
```


### 🏃 Working with `jax` transformations

Parameters are declared at the top of the `Tree` class definition, in the same way as `dataclasses.dataclass` fields.
Let's optimize these parameters:

```python
@jax.grad
def loss_func(tree: Tree, x: jax.Array):
    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis
    return jnp.mean(preds**2)  # <--- return the mean squared error


@jax.jit
def train_step(tree: Tree, x: jax.Array):
    grads = loss_func(tree, x)
    # apply a small gradient-descent step to every differentiable leaf
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, tree, grads)


# freeze the non-differentiable parts of the tree:
# any leaf of a non-inexact type (e.g. the `int` field `a`) must be frozen
# so that the tree can be differentiated and used with jax transformations
jaxable_tree = jax.tree_util.tree_map(
    lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree
)

for epoch in range(1_000):
    jaxable_tree = train_step(jaxable_tree, jnp.ones([10, 1]))

print(jaxable_tree)
# frozen leaves are printed with a "#" prefix
# Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])

# unfreeze the tree
tree = jax.tree_util.tree_map(pytc.unfreeze, jaxable_tree, is_leaf=pytc.is_frozen)
print(tree)
# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])
```
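The same gradient step can be written with plain JAX on a dict of parameters, which shows why the integer field has to be kept out of differentiation. In this sketch, passing `a` as a separate Python argument plays the role that freezing plays above (an assumption made for illustration), and broadcasting replaces `jax.vmap`, which gives the same `[10, 3]` predictions here.

```python
import jax
import jax.numpy as jnp

# differentiable leaves only; the int `a` is passed separately and never differentiated
params = {"b": (2.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0])}
a = 1


def loss(params, a, x):
    preds = a + params["b"][0] + params["c"] + x  # x: [10, 1] + c: [3] -> [10, 3]
    return jnp.mean(preds**2)


grads = jax.grad(loss)(params, a, jnp.ones([10, 1]))  # gradient w.r.t. `params` only
new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

# differentiating with respect to an integer input is rejected by jax.grad
int_grad_failed = False
try:
    jax.grad(lambda v: v * 2.0)(1)
except TypeError:
    int_grad_failed = True
```

`b[1]` is unused by the loss, so its gradient is zero and it keeps its value, matching `b=(..., 3.0)` in the output above.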


### ☝️ Advanced Indexing with `.at[]` <a id="Indexing"></a>
_Out-of-place updates using mask, attribute name or index_

`PyTreeClass` offers three ways of indexing through `.at[]`:

1. Indexing by boolean mask.
2. Indexing by attribute name.
3. Indexing by leaf index.

**Since `treeclass`-wrapped classes are immutable, `.at[]` operations return a new instance of the tree.**
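The same immutable, out-of-place pattern can be seen with a frozen standard-library dataclass. This is an analogy only. `PlainTree` is a hypothetical stand-in for `Tree`, and `dataclasses.replace` is not what `.at[]` uses internally.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class PlainTree:  # hypothetical stand-in for the `Tree` class above
    a: int = 1
    b: tuple = (2.0, 3.0)


tree = PlainTree()
new_tree = dataclasses.replace(tree, a=10)  # returns a new instance; `tree` is untouched

mutation_rejected = False
try:
    tree.a = 5  # in-place assignment is refused on a frozen dataclass
except dataclasses.FrozenInstanceError:
    mutation_rejected = True
```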


#### Index update by boolean mask

```python
tree = Tree()
# Tree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

# create a mask that selects values > 4
mask = jax.tree_util.tree_map(lambda x: x > 4, tree)

print(mask)
# Tree(a=False, b=(False, False), c=[False  True  True])

print(tree.at[mask].get())
# Tree(a=None, b=(None, None), c=[5. 6.])

print(tree.at[mask].set(10))
# Tree(a=1, b=(2.0, 3.0), c=[ 4. 10. 10.])

print(tree.at[mask].apply(lambda x: 10))
# Tree(a=1, b=(2.0, 3.0), c=[ 4. 10. 10.])
```
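Conceptually, a masked `set` pairs each leaf with its mask leaf and keeps or replaces values. Here is a plain-JAX sketch of that idea using `jax.tree_util.tree_map` and `jnp.where` on a dict stand-in for `Tree()`. It is an assumption about the semantics, not PyTreeClass internals. One visible difference is that `jnp.where` turns the Python scalars `a` and `b` into 0-d arrays, whereas `tree.at[mask].set(10)` keeps `a=1` as is.

```python
import jax
import jax.numpy as jnp

params = {"a": 1, "b": (2.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0])}
mask = jax.tree_util.tree_map(lambda x: x > 4, params)  # same structure, boolean leaves

# keep unmasked values, write 10 where the mask is True
updated = jax.tree_util.tree_map(lambda x, m: jnp.where(m, 10, x), params, mask)
```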


#### Index update by attribute name

```python
tree = Tree()
# Tree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

print(tree.at["a"].get())
# Tree(a=1, b=(None, None), c=None)

print(tree.at["a"].set(10))
# Tree(a=10, b=(2.0, 3.0), c=[4. 5. 6.])

print(tree.at["a"].apply(lambda x: 10))
# Tree(a=10, b=(2.0, 3.0), c=[4. 5. 6.])
```
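The output of `tree.at["a"].get()` keeps the tree's shape and puts `None` in every leaf of the unselected fields. A sketch that reproduces that pattern on a dict stand-in is shown below. `get_by_name` is a hypothetical helper written for illustration and is not part of PyTreeClass.

```python
import jax
import jax.numpy as jnp


def get_by_name(tree: dict, name: str) -> dict:
    """Hypothetical helper: keep field `name`, replace every leaf of the others with None."""
    return {
        key: value if key == name else jax.tree_util.tree_map(lambda _: None, value)
        for key, value in tree.items()
    }


params = {"a": 1, "b": (2.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0])}
selected = get_by_name(params, "a")  # {"a": 1, "b": (None, None), "c": None}
```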


#### Index update by integer index

```python
tree = Tree()
# Tree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

# `.at[1]` selects the second field (`b`); `.at[0]` then selects its first element
print(tree.at[1].at[0].get())
# Tree(a=None, b=(2.0, None), c=None)

print(tree.at[1].at[0].set(10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])

print(tree.at[1].at[0].apply(lambda x: 10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
```
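Positional indexing can be read as "field number, then item number". The sketch below spells that out with a frozen dataclass. `PlainTree` and `set_by_position` are hypothetical names used only for illustration, and this is not how `.at[1].at[0]` is implemented.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class PlainTree:  # hypothetical stand-in for the `Tree` class above
    a: int = 1
    b: tuple = (2.0, 3.0)


def set_by_position(tree, field_index, item_index, value):
    """Hypothetical helper: out-of-place update of one item inside one field."""
    name = dataclasses.fields(tree)[field_index].name  # field 1 -> "b"
    items = list(getattr(tree, name))
    items[item_index] = value
    return dataclasses.replace(tree, **{name: tuple(items)})


tree = PlainTree()
new_tree = set_by_position(tree, 1, 0, 10)  # PlainTree(a=1, b=(10, 3.0))
```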
