# PyTrees

PyTrees are a tree-like data structure built out of containers: intermediate nodes are containers, and the contents of a container are its children. An element that is not a container becomes a leaf node; an element that is a container becomes yet another internal node.

I think pytrees were created to make gradient calculations easier. Without pytrees we would have to do this:
```python
dW1, db1, dW2, db2 = grad(loss, argnums=(0, 1, 2, 3))(W1, b1, W2, b2, X, y)
```

With pytrees we can do this:
```python
param_grads = grad(loss)(params, X, y)
```
where `params` could be, for example:
```python
params = [
    {
        "W1": W1,
        "b1": b1
    },
    {
        "W2": W2,
        "b2": b2
    }
]
```

This works because `grad` differentiates with respect to the first argument by default. Here the first argument is `params`, which is a pytree, so the output is another pytree with the same structure, holding the corresponding gradients.

![pytrees_1](./imgs/pytree_2.png)
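To make this concrete, here is a minimal runnable sketch. The two-layer `loss`, the shapes, and the random initialisation are assumptions made up for illustration; the point is only that `jax.grad(loss)(params, X, y)` returns a pytree with exactly the same structure and leaf shapes as `params`.

```python
import jax
import jax.numpy as jnp

def loss(params, X, y):
    # Hypothetical two-layer network; params has the list-of-dicts layout above.
    h = jnp.tanh(X @ params[0]["W1"] + params[0]["b1"])
    pred = h @ params[1]["W2"] + params[1]["b2"]
    return jnp.mean((pred - y) ** 2)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = [
    {"W1": jax.random.normal(k1, (4, 8)), "b1": jnp.zeros(8)},
    {"W2": jax.random.normal(k2, (8, 1)), "b2": jnp.zeros(1)},
]
X = jax.random.normal(k3, (16, 4))
y = jnp.zeros((16, 1))

# Differentiate w.r.t. the first argument, i.e. the whole params pytree.
param_grads = jax.grad(loss)(params, X, y)

# The gradient pytree mirrors params: same containers, same keys, same shapes.
same = jax.tree_util.tree_structure(param_grads) == jax.tree_util.tree_structure(params)
print(same)  # True
print(jax.tree_util.tree_map(lambda g: g.shape, param_grads))
# [{'W1': (4, 8), 'b1': (8,)}, {'W2': (8, 1), 'b2': (1,)}]
```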

I can even represent the training batch as a pytree by packing it into any of the registered containers, e.g.:
```python
batch = (X, y)
```
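A minimal sketch of what that looks like, assuming a hypothetical linear model whose `loss` takes `(params, batch)`: the batch tuple is unpacked inside the loss, and `grad` still differentiates only with respect to `params`, the first argument.

```python
import jax
import jax.numpy as jnp

def loss(params, batch):
    # batch is a tuple pytree (X, y); unpack it like any Python tuple.
    X, y = batch
    pred = X @ params["W"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"W": jnp.ones((3, 1)), "b": jnp.zeros(1)}  # hypothetical linear model
batch = (jnp.arange(6.0).reshape(2, 3), jnp.zeros((2, 1)))

print(len(jax.tree_util.tree_leaves(batch)))  # 2: the tuple holds two array leaves
grads = jax.grad(loss)(params, batch)         # gradients only w.r.t. params
print(jax.tree_util.tree_map(lambda g: g.shape, grads))  # {'W': (3, 1), 'b': (1,)}
```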

#### Which data structures are PyTree containers and which are PyTree leaves?

  * Native Python containers like `list`, `dict`, and `tuple` are already registered with JAX as containers. Even an empty container is treated as a node with no children, not as a leaf!
  * `jax.Array` and `numpy.ndarray` are treated as leaves.
  * `None` is treated as an empty node.
  * Any object that is not registered as a container will be treated as a leaf. I can of course register my own custom class as a container.
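A quick check of these rules, as a sketch: count how many leaves `tree_leaves` finds in each kind of value. `Opaque` is a hypothetical class that has not been registered.

```python
import jax
import jax.numpy as jnp
import numpy as np

class Opaque:
    """Hypothetical class that has NOT been registered as a pytree node."""

examples = {
    "empty containers": [[], {}, ()],
    "None": None,
    "jax.Array": jnp.zeros((2, 3)),
    "numpy.ndarray": np.zeros((2, 3)),
    "unregistered object": Opaque(),
}
leaf_counts = {name: len(jax.tree_util.tree_leaves(tree)) for name, tree in examples.items()}
print(leaf_counts)
# {'empty containers': 0, 'None': 0, 'jax.Array': 1, 'numpy.ndarray': 1, 'unregistered object': 1}
```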

The rest of this notebook demonstrates these utility functions:
  * `tree_leaves()`
  * `tree_structure()`
  * `tree_flatten()`
  * `tree_unflatten()`
  * `tree_map()`

```python
import jax
import jax.numpy as jnp
import jax.tree_util as tu
import numpy as np
from random import random as pyrandom
```

```python
rng = np.random.default_rng()
key = jax.random.PRNGKey(0)
```

The following experiment demonstrates that NumPy and JAX arrays are **not** containers: the entire array object is treated as a single leaf. Python lists, on the other hand, are containers, so each element becomes its own leaf.

```python
params = {
    "W1": jax.random.uniform(key, shape=(3,)),
    "W2": rng.random(size=(3,)),
    "W3": [pyrandom(), pyrandom(), pyrandom()]
}
tu.tree_leaves(params)
```

```
[Array([0.9653214 , 0.31468165, 0.63302994], dtype=float32),
 array([0.88037383, 0.04700075, 0.79224713]),
 0.21992445929098692,
 0.3532498327415212,
 0.4444931535522526]
```
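Because a whole array counts as one leaf, the number of leaves is not the number of scalar values. A common idiom, sketched here with made-up values, is to sum the sizes of the leaves to get the total parameter count.

```python
import jax
import jax.numpy as jnp
import numpy as np

params = {
    "W1": jnp.ones((3,)),   # one leaf holding 3 values
    "W2": np.ones((3,)),    # one leaf holding 3 values
    "W3": [0.1, 0.2, 0.3],  # three leaves, one value each
}
leaves = jax.tree_util.tree_leaves(params)
n_leaves = len(leaves)
n_values = sum(np.size(leaf) for leaf in leaves)  # np.size works on arrays and Python floats
print(n_leaves, n_values)  # 5 9
```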

```python
container = ["a", "b", ("Anika", "Baboodi")]
tu.tree_leaves(container)
```

```
['a', 'b', 'Anika', 'Baboodi']
```

```python
tu.tree_structure(container)
```

```
PyTreeDef([*, *, (*, *)])
```

```python
tu.tree_structure(params)
```

```
PyTreeDef({'W1': *, 'W2': *, 'W3': [*, *, *]})
```
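A `PyTreeDef` records only the containers and how they nest, not the leaf values. A small sketch with arbitrary values: two trees with different leaves but the same containers have equal structures, while swapping a list for a tuple gives a different one.

```python
import jax.tree_util as tu

a = tu.tree_structure([1, 2, (3, 4)])
b = tu.tree_structure(["x", "y", ("z", "w")])  # different leaves, same containers
c = tu.tree_structure((1, 2, [3, 4]))          # same leaves, different containers

print(a == b)  # True: only containers and nesting matter
print(a == c)  # False: list vs tuple is a different structure
```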

`tree_flatten` returns both the leaves and the structure in one call.

```python
ptree = {
    "x": 1,
    "y": (2., 3.),
    "z": [4., 5., 6.]
}
leaves, struct = tu.tree_flatten(ptree)
print("Leaves: ", leaves)
print("Structure: ", struct)
```

```
Leaves:  [1, 2.0, 3.0, 4.0, 5.0, 6.0]
Structure:  PyTreeDef({'x': *, 'y': (*, *), 'z': [*, *, *]})
```
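`tree_unflatten` is the inverse of `tree_flatten`: given the `PyTreeDef` and a flat list of leaves in the same order, it rebuilds the nested containers. A sketch; the replacement leaves `10 ... 60` are arbitrary.

```python
import jax.tree_util as tu

ptree = {"x": 1, "y": (2., 3.), "z": [4., 5., 6.]}
leaves, struct = tu.tree_flatten(ptree)

# Round trip: flatten then unflatten reproduces the original tree.
roundtrip = tu.tree_unflatten(struct, leaves)
print(roundtrip == ptree)  # True

# Reuse the structure with new leaves (same count, same order).
rebuilt = tu.tree_unflatten(struct, [10, 20, 30, 40, 50, 60])
print(rebuilt)  # {'x': 10, 'y': (20, 30), 'z': [40, 50, 60]}
```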


```python
newtree = tu.tree_map(lambda leaf: leaf ** 2, ptree)
print(newtree)
leaves, struct = tu.tree_flatten(newtree)
print("Leaves: ", leaves)
print("Structure: ", struct)
```

```
{'x': 1, 'y': (4.0, 9.0), 'z': [16.0, 25.0, 36.0]}
Leaves:  [1, 4.0, 9.0, 16.0, 25.0, 36.0]
Structure:  PyTreeDef({'x': *, 'y': (*, *), 'z': [*, *, *]})
```


```python
newertree = tu.tree_map(lambda leaf1, leaf2: leaf1 + leaf2, ptree, newtree)
newertree
```

```
{'x': 2, 'y': (6.0, 12.0), 'z': [20.0, 30.0, 42.0]}
```
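Mapping over several trees at once is what turns a parameter update into a one-liner. A minimal sketch of plain SGD; the learning rate and the toy `params`/`grads` values are assumptions, and every tree passed to `tree_map` must share the same structure.

```python
import jax.numpy as jnp
import jax.tree_util as tu

lr = 0.1  # hypothetical learning rate
params = {"W": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"W": jnp.array([10.0, -10.0]), "b": jnp.array(1.0)}

# Leaf by leaf: new_p = p - lr * g, pairing leaves by their position in the tree.
new_params = tu.tree_map(lambda p, g: p - lr * g, params, grads)
print(new_params)
```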

```python
leaves, struct = tu.tree_flatten([1, 2, (), None])
print(f"Leaves: {leaves}")
print(f"Structure: {struct}")
```

```
Leaves: [1, 2]
Structure: PyTreeDef([*, *, (), None])
```
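Since `None` and `()` are empty nodes rather than leaves, the function passed to `tree_map` never sees them, and they come back unchanged in the output. A small sketch with made-up values; this is handy when some entries of a tree are optional.

```python
import jax.tree_util as tu

tree = {"a": 1.0, "b": None, "c": [2.0, ()]}
doubled = tu.tree_map(lambda leaf: leaf * 2, tree)  # never called on None or ()
print(doubled)  # {'a': 2.0, 'b': None, 'c': [4.0, ()]}
```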


#### Custom Class as PyTree

```python
from dataclasses import dataclass

@dataclass
class Triplet:
    name: str
    x: float
    y: float
    z: float

    def __repr__(self):
        return f"<Triplet(name={self.name} x={self.x} y={self.y} z={self.z})>"
```

```python
obj = Triplet("Cookies", 10, 20, 30)
obj
```

```
<Triplet(name=Cookies x=10 y=20 z=30)>
```

```python
leaves, struct = tu.tree_flatten([1, jnp.arange(3), ["hello", np.arange(3)], obj])
print("Leaves: ", leaves)
print("Struct: ", struct)
```

```
Leaves:  [1, Array([0, 1, 2], dtype=int32), 'hello', array([0, 1, 2]), <Triplet(name=Cookies x=10 y=20 z=30)>]
Struct:  PyTreeDef([*, *, [*, *], *])
```


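Define two functions that tell JAX how to flatten and unflatten the object. The flatten function returns the children (here `x`, `y`, `z`) together with auxiliary data (here `name`), which is stored in the tree structure rather than treated as a leaf; the unflatten function rebuilds the object from the auxiliary data and the children.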

```python
def triplet_flatten(triplet):
    leaves = (triplet.x, triplet.y, triplet.z)
    auxdata = triplet.name
    return (leaves, auxdata)

def triplet_unflatten(auxdata, leaves):
    x, y, z = leaves
    return Triplet(name=auxdata, x=x, y=y, z=z)

tu.register_pytree_node(Triplet, triplet_flatten, triplet_unflatten)
```
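JAX also provides a decorator, `jax.tree_util.register_pytree_node_class`, which expects the class to define a `tree_flatten` method and a `tree_unflatten` classmethod with the same contracts as the two functions above. A sketch of the same registration in that style; `DecoratedTriplet` is a hypothetical name chosen so it does not clash with the `Triplet` already registered.

```python
from dataclasses import dataclass
import jax.tree_util as tu

@tu.register_pytree_node_class
@dataclass
class DecoratedTriplet:  # hypothetical variant of Triplet
    name: str
    x: float
    y: float
    z: float

    def tree_flatten(self):
        # (children, aux data): children become leaves, name stays in the structure.
        return (self.x, self.y, self.z), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from the aux data and the children.
        return cls(aux_data, *children)

leaves, struct = tu.tree_flatten([1, DecoratedTriplet("Cookies", 10, 20, 30)])
print(leaves)  # [1, 10, 20, 30]
```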

```python
leaves, struct = tu.tree_flatten([1, jnp.arange(3), ["hello", np.arange(3)], obj])
print("Leaves: ", leaves)
print("Struct: ", struct)
```

```
Leaves:  [1, Array([0, 1, 2], dtype=int32), 'hello', array([0, 1, 2]), 10, 20, 30]
Struct:  PyTreeDef([*, *, [*, *], CustomNode(Triplet[Cookies], [*, *, *])])
```
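Once registered, the class also works with JAX transformations, not just the tree utilities. As a sketch (the class is redefined and registered under the hypothetical name `GradTriplet` so the block runs on its own), `jax.grad` differentiates with respect to the leaves `x`, `y`, `z`, while `name`, the auxiliary data, passes through untouched.

```python
from dataclasses import dataclass
import jax
import jax.tree_util as tu

@dataclass
class GradTriplet:  # hypothetical stand-in for Triplet
    name: str
    x: float
    y: float
    z: float

tu.register_pytree_node(
    GradTriplet,
    lambda t: ((t.x, t.y, t.z), t.name),          # flatten: (children, aux data)
    lambda name, xyz: GradTriplet(name, *xyz),    # unflatten(aux data, children)
)

def f(t):
    # Toy scalar function of the three leaves: df/dx = y, df/dy = x, df/dz = 1.
    return t.x * t.y + t.z

g = jax.grad(f)(GradTriplet("Cookies", 1.0, 2.0, 3.0))
print(g.name, float(g.x), float(g.y), float(g.z))  # Cookies 2.0 1.0 1.0
```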
