**S03P01_tutorial_working_with_pytrees.ipynb**

Arz

2024 APR 16 (WED)

reference:
https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit, vmap
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# what is a pytree?

pytree:

a container of leaf elements and/or more pytrees.
- **container**:
    - ex) list, tuple, dict
- **leaf element**: anything that's not a pytree
    - ex) array
 
a possibly-nested standard or user-registered Python container.

In [4]:
example_trees = [[1, 'a', object()],
                (1, (2, 3), [], ()),
                [0, {'a': 1, 'b': (2, 'a')}, 3],
                {'p': (7, ''), 'q': 8},
                jnp.array([1, 2, 3])]

for tree in example_trees:
    leaves = jax.tree_util.tree_leaves(tree)
    print(f"{repr(tree):<45} has {len(leaves)} leaves: {leaves}")

[1, 'a', <object object at 0x7d2dfe2d7830>]   has 3 leaves: [1, 'a', <object object at 0x7d2dfe2d7830>]
(1, (2, 3), [], ())                           has 3 leaves: [1, 2, 3]
[0, {'a': 1, 'b': (2, 'a')}, 3]               has 5 leaves: [0, 1, 2, 'a', 3]
{'p': (7, ''), 'q': 8}                        has 3 leaves: [7, '', 8]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]


note: dict keys are not leaves. only the values are leaves; the keys belong to the tree structure.
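a minimal sketch of where the dict keys go: tree_flatten splits a pytree into a flat list of leaves plus a treedef, and the keys end up in the treedef. the variable names here are only illustrative.

```python
import jax

tree = {'p': (7, ''), 'q': 8}

# split the pytree into its leaves and its structure (treedef)
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)   # [7, '', 8]  (values only, no 'p' or 'q')
print(treedef)  # the structure, which is where the dict keys are kept

# the treedef plus the leaves is enough to rebuild the original pytree
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(rebuilt == tree)
```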

# why pytree?

- model parameters
- dataset entries
- RL agent observations

(+) working in bulk with datasets

# common pytree functions

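## jax.tree_map

note: newer JAX releases deprecate the top-level jax.tree_map alias; jax.tree_util.tree_map (or jax.tree.map) is the same function under its current name.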

In [5]:
x1 = [[1, 2, 3],
     (1, 2),
     {'a': 1}]

jax.tree_map(lambda x: 2*x, x1)

[[2, 4, 6], (2, 4), {'a': 2}]

tree_map also accepts multiple pytrees as arguments, as long as their structures match.

In [6]:
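# x2 = map(lambda x: 3*x, x1)  # unsupported: built-in map only visits top-level items (3*dict raises TypeError)
x2 = jax.tree_map(lambda x: 3*x, x1)  # works: the function is applied to every leaf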
print(x2)

jax.tree_map(lambda x, y: x + y, x1, x2)

[[3, 6, 9], (3, 6), {'a': 3}]


[[4, 8, 12], (4, 8), {'a': 4}]

In [7]:
x3 = list(x1)  # copy the outer list, so that x1 itself is not modified
x3[1] = [8, 9]
x3[2] = [1, 2]
print(x3)

# jax.tree_map(lambda x, y: x + y, x1, x3)  # forbidden: 
# pytree structures do not match (tuple vs list, dict vs list).

[[1, 2, 3], [8, 9], [1, 2]]
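a small sketch of what the forbidden call does, assuming the same x1 and x3 values as above. the exception type is caught broadly here because the exact error class and message may differ between JAX versions.

```python
import jax

x1 = [[1, 2, 3], (1, 2), {'a': 1}]
x3 = [[1, 2, 3], [8, 9], [1, 2]]

# comparing treedefs is a cheap way to check compatibility up front
same_structure = (jax.tree_util.tree_structure(x1)
                  == jax.tree_util.tree_structure(x3))
print(same_structure)  # False

# tree_map refuses to combine trees whose structures differ
try:
    jax.tree_util.tree_map(lambda x, y: x + y, x1, x3)
    mismatch_detected = False
except (ValueError, TypeError) as e:
    mismatch_detected = True
    print(f"{type(e).__name__}: {e}")
```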


# ex) ML model parameters

MLP (Multilayer Perceptron)

In [8]:
def init_MLP_params(layer_widths):
    params = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        params.append(
            dict(weight=np.random.normal(size=(n_in, n_out))*np.sqrt(2/n_in),
                bias=np.ones(shape=(n_out,))))
    return params

# params is a pytree: list of dicts.

In [9]:
params = init_MLP_params([1, 128, 64, 16, 1])

use jax.tree_map() to check parameters' shapes.

In [10]:
jax.tree_map(lambda x: x.shape, params)

[{'bias': (128,), 'weight': (1, 128)},
 {'bias': (64,), 'weight': (128, 64)},
 {'bias': (16,), 'weight': (64, 16)},
 {'bias': (1,), 'weight': (16, 1)}]
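the same idea works for counting parameters. this is a sketch that repeats the init function above so that it runs on its own; num_params is an illustrative name.

```python
import numpy as np
import jax

def init_MLP_params(layer_widths):
    params = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        params.append(
            dict(weight=np.random.normal(size=(n_in, n_out))*np.sqrt(2/n_in),
                 bias=np.ones(shape=(n_out,))))
    return params

params = init_MLP_params([1, 128, 64, 16, 1])

# every leaf is an array, so summing .size over the leaves counts all parameters
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(num_params)  # 9569 = 9360 weights + 209 biases
```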

## train

- params: neural net parameters [pytree: list of dicts]
- x: input data [batched array]
- y: output data [batched array]

jax.grad() supports pytree arguments.

so applying grad to the loss function gives the gradients of the loss with respect to params (the first argument). the resulting grads is a pytree with the same structure as params.

In [11]:
def model(params, x):
    *hidden_layer_params, output_layer_param = params
    for layer_param in hidden_layer_params:
        x = jax.nn.relu(x@layer_param["weight"] + layer_param["bias"])
    return x@output_layer_param["weight"] + output_layer_param["bias"]

def loss_function(params, x, y):
    return jnp.mean((model(params, x) - y)**2)

learning_rate = 0.0001

@jit
def update(params, x, y):
    grads = grad(loss_function)(params, x, y)
    return jax.tree_map(lambda p, g: p - learning_rate*g, params, grads)
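a quick check of the claim that grads mirrors params, using a tiny linear model. toy_params and toy_loss are made-up names for this sketch, not part of the tutorial's model.

```python
import jax
import jax.numpy as jnp

# a two-layer toy network: list of dicts, like params above
toy_params = [{"weight": jnp.ones((2, 3)), "bias": jnp.zeros((3,))},
              {"weight": jnp.ones((3, 1)), "bias": jnp.zeros((1,))}]

def toy_loss(params, x):
    for layer in params:
        x = x @ layer["weight"] + layer["bias"]
    return jnp.sum(x ** 2)

# grad differentiates with respect to the first argument (the pytree)
toy_grads = jax.grad(toy_loss)(toy_params, jnp.ones((4, 2)))

same_structure = (jax.tree_util.tree_structure(toy_grads)
                  == jax.tree_util.tree_structure(toy_params))
print(same_structure)  # True

# each gradient leaf has the shape of the matching parameter leaf
grad_shapes = [{k: v.shape for k, v in layer.items()} for layer in toy_grads]
print(grad_shapes)
```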

## test

In [12]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'iframe'

In [13]:
import plotly.graph_objects as go

In [14]:
x = np.random.normal(size=(100, 1))
y = x**2

In [15]:
num_iterations = 1000
for _ in range(num_iterations):
    params = update(params, x, y)

In [16]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=x.squeeze(), y=y.squeeze(),
                        mode="markers"))
fig.add_trace(go.Scatter(x=x.squeeze(), y=model(params, x).squeeze(),
                        mode="markers"))

fig.show()

# key paths

each leaf has a **key path**: a tuple of keys

- length of key path == depth of the leaf in the pytree
- **key**: a hashable object that represents an index into the corresponding pytree node.
    - key type: depends on the pytree node type.

<br />

- **default key types : built-in python node types**
    - SequenceKey(idx: int) : list, tuple
    - DictKey(key: Hashable) : dict
    - GetAttrKey(name: str) : namedtuple, custom pytree node 

for built-in pytree node types, the set of keys for any pytree node instance is unique. for a pytree comprising nodes with this property, the key path for each leaf is unique.

## ex) track key paths for all values in a pytree

In [17]:
import collections

In [18]:
a = collections.namedtuple("a", ["school", "name", "age"])

# note: ((0)) is just the int 0; parentheses alone do not make a tuple.
tree = [7, {"key1": 3, "key2": 8}, [2], ((0)), a("MSS", "Arz", "27")]

In [19]:
key_path_and_keys, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in key_path_and_keys:
    print(f'value at key path: tree{jax.tree_util.keystr(key_path)} is {value}')

value at key path: tree[0] is 7
value at key path: tree[1]['key1'] is 3
value at key path: tree[1]['key2'] is 8
value at key path: tree[2][0] is 2
value at key path: tree[3] is 0
value at key path: tree[4].school is MSS
value at key path: tree[4].name is Arz
value at key path: tree[4].age is 27


In [20]:
for key_path, value in key_path_and_keys:
    print(f'key path: tree{jax.tree_util.keystr(key_path)} has key types\n {repr(key_path)}')

key path: tree[0] has key types
 (SequenceKey(idx=0),)
key path: tree[1]['key1'] has key types
 (SequenceKey(idx=1), DictKey(key='key1'))
key path: tree[1]['key2'] has key types
 (SequenceKey(idx=1), DictKey(key='key2'))
key path: tree[2][0] has key types
 (SequenceKey(idx=2), SequenceKey(idx=0))
key path: tree[3] has key types
 (SequenceKey(idx=3),)
key path: tree[4].school has key types
 (SequenceKey(idx=4), GetAttrKey(name='school'))
key path: tree[4].name has key types
 (SequenceKey(idx=4), GetAttrKey(name='name'))
key path: tree[4].age has key types
 (SequenceKey(idx=4), GetAttrKey(name='age'))
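key paths are also useful for treating some leaves differently from others. a sketch using jax.tree_util.tree_map_with_path, which calls the function with (key_path, leaf); the toy params and the zero_bias helper are hypothetical.

```python
import jax
from jax.tree_util import DictKey, tree_map_with_path

params = [{"weight": 2.0, "bias": 1.0},
          {"weight": 3.0, "bias": 1.0}]

def zero_bias(key_path, leaf):
    # the last key in the path tells us how the leaf is stored in its parent
    last_key = key_path[-1]
    if isinstance(last_key, DictKey) and last_key.key == "bias":
        return 0.0
    return leaf

result = tree_map_with_path(zero_bias, params)
print(result)
```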


# custom pytree nodes

a user-defined container class is treated as a leaf by default, since it is not a registered pytree node type (such as list, tuple, dict).

In [21]:
class My_Container:
    """user-defined container: is a leaf"""

    def __init__(self, name: str, level: int, rankings: list[int]):
        self.name = name
        self.level = level
        self.rankings = rankings

In [22]:
players_including_mutant = [My_Container("Arz", 64, [92925, 91501, 875]),
                            My_Container("Ssr", 64, [101973, 97323, 1212]),
                            ("mutant", 100, [0, 0, 0]),
                            My_Container("Nov", 57, [124953, 134056, 7249])]

jax.tree_util.tree_leaves(players_including_mutant)

[<__main__.My_Container at 0x7d2cccd4a8d0>,
 <__main__.My_Container at 0x7d2cccd0f380>,
 'mutant',
 100,
 0,
 0,
 0,
 <__main__.My_Container at 0x7d2d38287f50>]

as you can see, My_Container is classified as a leaf, while list and tuple are not.

My_Container is not a pytree node, so pytree functions such as jax.tree_map() do not look inside it: the whole object is passed to the mapped function.

In [23]:
try:
    jax.tree_map(lambda x: x + 1, players_including_mutant)
except TypeError as e:
    print(f"TypeError: {e}")

TypeError: unsupported operand type(s) for +: 'My_Container' and 'int'


## register custom node

### method #1: define flatten and unflatten

In [24]:
from typing import Iterable

In [25]:
def flatten_My_Container(my_container) -> tuple[tuple[int, Iterable[int]], str]:
    # children: the parts that pytree functions should traverse
    flat_contents = (my_container.level, my_container.rankings)
    # aux_data: static data kept in the tree structure, not traversed
    aux_data = my_container.name
    return flat_contents, aux_data

def unflatten_My_Container(aux_data: str, flat_contents: tuple[int, Iterable[int]]) -> My_Container:
    return My_Container(aux_data, *flat_contents)

In [26]:
jax.tree_util.register_pytree_node(My_Container, flatten_My_Container, unflatten_My_Container)

In [27]:
players = [My_Container("Arz", 64, [92925, 91501, 875]),
           My_Container("Ssr", 64, [101973, 97323, 1212]),
           My_Container("Nov", 57, [124953, 134056, 7249])]

players_including_mutant = [My_Container("Arz", 64, [92925, 91501, 875]),
                            My_Container("Ssr", 64, [101973, 97323, 1212]),
                            ("mutant", 100, [0, 0, 0]),
                            My_Container("Nov", 57, [124953, 134056, 7249])]

In [28]:
jax.tree_util.tree_leaves(players_including_mutant)

[64,
 92925,
 91501,
 875,
 64,
 101973,
 97323,
 1212,
 'mutant',
 100,
 0,
 0,
 0,
 57,
 124953,
 134056,
 7249]

In [29]:
players_updated = jax.tree_map(lambda x: x + 1, players)

jax.tree_util.tree_leaves(players_updated)

[65, 92926, 91502, 876, 65, 101974, 97324, 1213, 58, 124954, 134057, 7250]
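a self-contained sketch of the round trip, to show that aux_data survives tree_map untouched. Player is a hypothetical stand-in for My_Container so that the block can run on its own.

```python
import jax

class Player:
    """hypothetical stand-in for My_Container"""
    def __init__(self, name, level, rankings):
        self.name = name
        self.level = level
        self.rankings = rankings

jax.tree_util.register_pytree_node(
    Player,
    lambda p: ((p.level, p.rankings), p.name),       # flatten: (children, aux_data)
    lambda name, children: Player(name, *children))  # unflatten: (aux_data, children)

p = Player("Arz", 64, [1, 2, 3])

# only the children become leaves; the name is stored as aux_data
leaves = jax.tree_util.tree_leaves(p)
print(leaves)  # [64, 1, 2, 3]

# tree_map rebuilds a Player via unflatten, so the name is carried over as-is
q = jax.tree_util.tree_map(lambda x: x + 1, p)
print(q.name, q.level, q.rankings)  # Arz 65 [2, 3, 4]
```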

in this case, each child of the custom node gets a key of type **FlattenedIndexKey**, whose index is the child's position among the flattened children (the **flat index**).

In [30]:
key_path_and_keys, _ = jax.tree_util.tree_flatten_with_path(players_including_mutant)

for key_path, value in key_path_and_keys:
    print(f'key path: tree{jax.tree_util.keystr(key_path)} has key types\n {repr(key_path)}')

key path: tree[0][<flat index 0>] has key types
 (SequenceKey(idx=0), FlattenedIndexKey(key=0))
key path: tree[0][<flat index 1>][0] has key types
 (SequenceKey(idx=0), FlattenedIndexKey(key=1), SequenceKey(idx=0))
key path: tree[0][<flat index 1>][1] has key types
 (SequenceKey(idx=0), FlattenedIndexKey(key=1), SequenceKey(idx=1))
key path: tree[0][<flat index 1>][2] has key types
 (SequenceKey(idx=0), FlattenedIndexKey(key=1), SequenceKey(idx=2))
key path: tree[1][<flat index 0>] has key types
 (SequenceKey(idx=1), FlattenedIndexKey(key=0))
key path: tree[1][<flat index 1>][0] has key types
 (SequenceKey(idx=1), FlattenedIndexKey(key=1), SequenceKey(idx=0))
key path: tree[1][<flat index 1>][1] has key types
 (SequenceKey(idx=1), FlattenedIndexKey(key=1), SequenceKey(idx=1))
key path: tree[1][<flat index 1>][2] has key types
 (SequenceKey(idx=1), FlattenedIndexKey(key=1), SequenceKey(idx=2))
key path: tree[2][0] has key types
 (SequenceKey(idx=2), SequenceKey(idx=0))
key path: tree[2]

### method #2: extended method #1 with keys

In [31]:
class My_Container_With_Key(My_Container):
    pass

def flatten_My_Container_With_Key(my_container_with_key) -> tuple[tuple[int, Iterable[int]], str]:
    # each child is paired with its key. GetAttrKey is a common choice.
    flat_contents = ((jax.tree_util.GetAttrKey("level"), my_container_with_key.level), 
                     (jax.tree_util.GetAttrKey("rankings"), my_container_with_key.rankings))
    aux_data = my_container_with_key.name
    return flat_contents, aux_data

# unflatten receives the children without their keys.
def unflatten_My_Container_With_Key(aux_data: str, flat_contents: tuple[int, Iterable[int]]) -> My_Container_With_Key:
    return My_Container_With_Key(aux_data, *flat_contents)

In [32]:
jax.tree_util.register_pytree_with_keys(My_Container_With_Key, flatten_My_Container_With_Key, unflatten_My_Container_With_Key)

In [33]:
players = [My_Container_With_Key("Arz", 64, [92925, 91501, 875]),
           My_Container_With_Key("Ssr", 64, [101973, 97323, 1212]),
           My_Container_With_Key("Nov", 57, [124953, 134056, 7249])]

players_including_mutant = [My_Container_With_Key("Arz", 64, [92925, 91501, 875]),
                            My_Container_With_Key("Ssr", 64, [101973, 97323, 1212]),
                            ("mutant", 100, [0, 0, 0]),
                            My_Container_With_Key("Nov", 57, [124953, 134056, 7249])]

In [34]:
jax.tree_util.tree_leaves(players_including_mutant)

[64,
 92925,
 91501,
 875,
 64,
 101973,
 97323,
 1212,
 'mutant',
 100,
 0,
 0,
 0,
 57,
 124953,
 134056,
 7249]

In [35]:
players_updated = jax.tree_map(lambda x: x + 1, players)

jax.tree_util.tree_leaves(players_updated)

[65, 92926, 91502, 876, 65, 101974, 97324, 1213, 58, 124954, 134057, 7250]

in this case, the key for the custom node has key type: **GetAttrKey**.

In [36]:
key_path_and_keys, _ = jax.tree_util.tree_flatten_with_path(players_including_mutant)

for key_path, value in key_path_and_keys:
    print(f'key path: tree{jax.tree_util.keystr(key_path)} has key types\n {repr(key_path)}')

key path: tree[0].level has key types
 (SequenceKey(idx=0), GetAttrKey(name='level'))
key path: tree[0].rankings[0] has key types
 (SequenceKey(idx=0), GetAttrKey(name='rankings'), SequenceKey(idx=0))
key path: tree[0].rankings[1] has key types
 (SequenceKey(idx=0), GetAttrKey(name='rankings'), SequenceKey(idx=1))
key path: tree[0].rankings[2] has key types
 (SequenceKey(idx=0), GetAttrKey(name='rankings'), SequenceKey(idx=2))
key path: tree[1].level has key types
 (SequenceKey(idx=1), GetAttrKey(name='level'))
key path: tree[1].rankings[0] has key types
 (SequenceKey(idx=1), GetAttrKey(name='rankings'), SequenceKey(idx=0))
key path: tree[1].rankings[1] has key types
 (SequenceKey(idx=1), GetAttrKey(name='rankings'), SequenceKey(idx=1))
key path: tree[1].rankings[2] has key types
 (SequenceKey(idx=1), GetAttrKey(name='rankings'), SequenceKey(idx=2))
key path: tree[2][0] has key types
 (SequenceKey(idx=2), SequenceKey(idx=0))
key path: tree[2][1] has key types
 (SequenceKey(idx=2), Sequ

## NamedTuple

NamedTuple is already a supported pytree type.

In [37]:
from typing import NamedTuple, Any

In [38]:
class My_NamedTuple_Container(NamedTuple):
    name: str
    id_number: int
    pets: Any    

In [39]:
butlers = [My_NamedTuple_Container("butler1", 1, ["cat1", "cat2", "dog1"]),
          My_NamedTuple_Container("butler2", 2, ["cat1", "owl1"])]

In [40]:
jax.tree_util.tree_leaves(butlers)

['butler1', 1, 'cat1', 'cat2', 'dog1', 'butler2', 2, 'cat1', 'owl1']

but now *name* appears as a leaf.

### shortcut: register_static

use register_static to keep *name* as aux_data rather than a child.

In [41]:
@jax.tree_util.register_static
class Static_String(str):
    pass

class My_NamedTuple_Container_With_Static_Name(NamedTuple):
    name: Static_String
    id_number: int
    pets: Any

In [42]:
butlers = [My_NamedTuple_Container_With_Static_Name(Static_String("butler1"), 1, ["cat1", "cat2", "dog1"]),
          My_NamedTuple_Container_With_Static_Name(Static_String("butler2"), 2, ["cat1", "owl1"])]

In [43]:
jax.tree_util.tree_leaves(butlers)

[1, 'cat1', 'cat2', 'dog1', 2, 'cat1', 'owl1']

now *name* no longer appears among the leaves.

# common pytree gotchas and patterns

## gotchas

### mistaking nodes for leaves

In [44]:
a = [jnp.zeros((3, 1)), jnp.zeros((2, 6))]

shapes_of_a = jax.tree_map(lambda x: x.shape, a)

b = jax.tree_map(jnp.ones, shapes_of_a)

In [45]:
print(a)
print(shapes_of_a)
print(b)

[Array([[0.],
       [0.],
       [0.]], dtype=float32), Array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]], dtype=float32)]
[(3, 1), (2, 6)]
[(Array([1., 1., 1.], dtype=float32), Array([1.], dtype=float32)), (Array([1., 1.], dtype=float32), Array([1., 1., 1., 1., 1., 1.], dtype=float32))]


.shape returns a tuple, and a tuple is a pytree node, not a leaf. so the second tree_map applies jnp.ones to each element of the tuple instead of to the whole shape.

**two example approaches to solve this issue**

- method #1: rewrite the code to avoid the intermediate tree_map.
- method #2: convert the tuple into np.array or jnp.array, which makes the entire sequence a leaf.
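a sketch of method #1, plus one further option that is not in the list above: the is_leaf argument of tree_map, which here tells it to stop at tuples. the names b_direct and b_is_leaf are illustrative.

```python
import jax
import jax.numpy as jnp

a = [jnp.zeros((3, 1)), jnp.zeros((2, 6))]

# method #1: no intermediate tree of shapes, build the ones directly
b_direct = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape), a)
print([arr.shape for arr in b_direct])  # [(3, 1), (2, 6)]

# further option: keep the tree of shape tuples, but treat each tuple as a leaf
shapes_of_a = jax.tree_util.tree_map(lambda x: x.shape, a)
b_is_leaf = jax.tree_util.tree_map(
    jnp.ones, shapes_of_a, is_leaf=lambda node: isinstance(node, tuple))
print([arr.shape for arr in b_is_leaf])  # [(3, 1), (2, 6)]
```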

ex) method #2

In [46]:
shapes_of_a = jax.tree_map(lambda x: jnp.array(x.shape), a)

b = jax.tree_map(jnp.ones, shapes_of_a)

In [47]:
print(shapes_of_a)
print(b)

[Array([3, 1], dtype=int32), Array([2, 6], dtype=int32)]
[Array([[1.],
       [1.],
       [1.]], dtype=float32), Array([[1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.]], dtype=float32)]


### handling of None

jax.tree_util treats None as a node without children, not as a leaf.

In [48]:
jax.tree_util.tree_leaves([None, None, None])

[]
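if None should count as a leaf, the is_leaf argument can say so. a short sketch with made-up values:

```python
import jax

tree = [None, 1, None]

# default: None is an empty node, so only the int is a leaf
default_leaves = jax.tree_util.tree_leaves(tree)
print(default_leaves)  # [1]

# with is_leaf, None is reported as a leaf too
all_leaves = jax.tree_util.tree_leaves(tree, is_leaf=lambda x: x is None)
print(all_leaves)  # [None, 1, None]

# tree_map skips None by default and leaves it in place
mapped = jax.tree_util.tree_map(lambda x: x + 1, tree)
print(mapped)  # [None, 2, None]
```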

## patterns

### transposing trees

ex) list of trees ---> tree of lists

In [49]:
def transpose_list_of_trees(list_of_trees):
    """convert a list of trees of identical structure into a single tree of lists."""
    return jax.tree_map(lambda *x: list(x), *list_of_trees)

In [50]:
episode_steps = [dict(t=0, obs=(0, 0)),
                dict(t=1, obs=(1, 2)),
                dict(t=2, obs=(3, -1))]
print(episode_steps)

transpose_list_of_trees(episode_steps)

[{'t': 0, 'obs': (0, 0)}, {'t': 1, 'obs': (1, 2)}, {'t': 2, 'obs': (3, -1)}]


{'obs': ([0, 1, 3], [0, 2, -1]), 't': [0, 1, 2]}

**for more complicated transposes**

use jax.tree.transpose (also available as jax.tree_util.tree_transpose), which is more verbose but lets you specify the structure of the inner and outer pytree for more flexibility.

In [53]:
jax.tree.transpose(
    outer_treedef=jax.tree.structure([0 for e in episode_steps]),
    inner_treedef=jax.tree.structure(episode_steps[0]),
    pytree_to_transpose=episode_steps)

{'obs': ([0, 1, 3], [0, 2, -1]), 't': [0, 1, 2]}
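a closely related pattern is to build arrays instead of lists while transposing, which gives one batched array per leaf. this is a sketch under the assumption that all trees share the same structure and their leaves can be stacked; stacked is an illustrative name.

```python
import jax
import jax.numpy as jnp

episode_steps = [dict(t=0, obs=(0, 0)),
                 dict(t=1, obs=(1, 2)),
                 dict(t=2, obs=(3, -1))]

# same trick as transpose_list_of_trees, but each group of leaves becomes an array
stacked = jax.tree_util.tree_map(lambda *xs: jnp.array(xs), *episode_steps)

print(stacked['t'])       # the t values of all steps in one array
print(stacked['obs'][0])  # first obs component of all steps
print(stacked['obs'][1])  # second obs component of all steps
```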