In [1]:
import jax
import jax.numpy as jnp

A pytree is a container of leaf elements and/or other pytrees. Standard containers include lists, tuples, and dictionaries, and user-registered container types are also allowed. A leaf element is anything that is not a pytree (an array, for instance, is a single leaf).

In [2]:
# Sample pytrees: nested lists/tuples, dicts, and a bare array (which is itself a single leaf).
example_trees = [
    [1, 2, ('a', 'b')],
    (11, (12, 13), ()),
    [dict(evergreen='fir', deciduous=('oak', 'maple')), dict(mammal='wolverine')],
    dict(a=2, b=(2, 3)),
    jnp.array([1, 2, 3]),
]

In [3]:
# Show each pytree next to its flattened leaves (dict leaves come out in sorted-key order).
for pytree in example_trees:
  leaves = jax.tree_util.tree_leaves(pytree)
  print("{:<45} has {} leaves: {}".format(repr(pytree), len(leaves), leaves))

[1, 2, ('a', 'b')]                            has 4 leaves: [1, 2, 'a', 'b']
(11, (12, 13), ())                            has 3 leaves: [11, 12, 13]
[{'evergreen': 'fir', 'deciduous': ('oak', 'maple')}, {'mammal': 'wolverine'}] has 4 leaves: ['oak', 'maple', 'fir', 'wolverine']
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]


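`jax.tree_util.tree_map` (also reachable through the older, now-deprecated alias `jax.tree_map`) applies a function to every leaf and returns a new pytree with the same structure.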

In [4]:
# A mixed pytree (nested list, tuple, dict) reused by the cells below.
init_tree = [
    [1, 2, [3, 4]],
    (10, 20),
    dict(k1=5, k2=3),
]

In [5]:
# Square every leaf; the result keeps init_tree's exact structure.
# Use jax.tree_util.tree_map: the top-level alias jax.tree_map is deprecated
# (and removed in recent JAX releases).
modified_tree = jax.tree_util.tree_map(lambda x: x**2, init_tree)
modified_tree

[[1, 4, [9, 16]], (100, 400), {'k1': 25, 'k2': 9}]

We can also map over several trees at once. The function receives one leaf from each tree at the same position, so all inputs must have the same structure.

In [6]:
# With several trees, the function gets the matching leaf from each tree
# (jax.tree_util.tree_map replaces the deprecated jax.tree_map alias).
jax.tree_util.tree_map(lambda x, y: 5 * x + y, init_tree, modified_tree)

[[6, 14, [24, 36]], (150, 500), {'k1': 50, 'k2': 24}]

We can also inspect the pytree structure of arbitrary objects. The structure is encoded in a treedef (`PyTreeDef`) object, which prints each leaf position as `*`.

In [7]:
# The treedef shows the containers of init_tree, with '*' in place of each leaf.
print(jax.tree_util.tree_structure(tree=init_tree))

PyTreeDef([[*, *, [*, *]], (*, *), {'k1': *, 'k2': *}])


In [8]:
# A tuple of two arrays; each array is a leaf, not a container.
arr = tuple([jnp.zeros(shape=(2, 3)), jnp.ones(shape=(3, 4))])

In [9]:
# Because arrays are leaves, the structure is just a 2-tuple of leaf slots.
print(jax.tree_util.tree_structure(tree=arr))

PyTreeDef((*, *))


When JAX flattens a pytree, it produces a list of leaves and a treedef object that encodes the structure of the original value. The treedef can then be used to rebuild a value with the same structure after the leaves have been transformed.

In [10]:
# Split the tree into its leaves (left to right, depth first) and a treedef used to rebuild it.
flat, tree = jax.tree_util.tree_flatten(init_tree)
print("Leaves {}, Tree {}".format(flat, tree))

Leaves [1, 2, 3, 4, 10, 20, 5, 3], Tree PyTreeDef([[*, *, [*, *]], (*, *), {'k1': *, 'k2': *}])


In [11]:
# Scale every flattened leaf by 7. A list comprehension is the idiomatic
# replacement for list(map(lambda ...)).
transformed_flat = [leaf * 7 for leaf in flat]
print(f"{transformed_flat=}")

transformed_flat=[7, 14, 21, 28, 70, 140, 35, 21]


In [12]:
# Rebuild the original container layout around the transformed leaves, then show it.
print(reconstructed := jax.tree_util.tree_unflatten(tree, transformed_flat))

[[7, 14, [21, 28]], (70, 140), {'k1': 35, 'k2': 21}]


We can also define our own custom pytree containers to serve as nodes. The definition must specify how to flatten the container into its children (plus any auxiliary data that is not a child) and how to unflatten it back.

In [13]:
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class MyNode:
  """A small custom pytree node.

  ``a``, ``b`` and ``c`` are the node's children: they become leaves when the
  tree is flattened. ``name`` is carried as auxiliary data, so it is part of
  the tree structure rather than a leaf.

  The type hints describe the intended construction. Functions such as
  ``tree_map`` will put whatever values the mapped function returns into
  these slots.
  """

  def __init__(self, name: str, a: int, b: int, c: str):
    self.name = name
    self.a = a
    self.b = b
    self.c = c

  def __repr__(self):
    # type(self).__name__ keeps the repr accurate for subclasses.
    return "{}(Name={}, a={}, b={}, c={})".format(
        type(self).__name__, self.name, self.a, self.b, self.c)

  def tree_flatten(self):
    """Split the node into ``(children, aux_data)`` for JAX.

    Returns:
      A pair ``([a, b, c], name)``: the children JAX traverses, and the
      auxiliary data it stores alongside the tree structure.
    """
    flat_contents = [self.a, self.b, self.c]
    aux_data = self.name  # The name is metadata, not a child.
    return flat_contents, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, flat_contents):
    """Rebuild a node from the output of ``tree_flatten``.

    Args:
      aux_data: The node's name, as returned by ``tree_flatten``.
      flat_contents: The (possibly transformed) children ``a``, ``b``, ``c``.

    Returns:
      A new instance of ``cls``. Using ``cls`` instead of hard-coding
      ``MyNode`` means a subclass (registered with its own
      ``@register_pytree_node_class``) is rebuilt as that subclass.
    """
    return cls(aux_data, *flat_contents)

In [14]:
# MyNode defines no __str__, so print() falls back to its __repr__.
print(repr(MyNode(name='Alice', a=1, b=2, c='Flax')))

MyNode(Name=Alice, a=1, b=2, c=Flax)


In [15]:
# Each node contributes its (a, b, c) children, in order; the names are not leaves.
jax.tree_util.tree_leaves([
    MyNode(name, a, b, c)
    for name, a, b, c in [('Alice', 1, 2, 'Flax'), ('Bob', 4, 5, 'Linen')]
])

[1, 2, 'Flax', 4, 5, 'Linen']

In [16]:
# Combine two nodes leaf by leaf: ints are summed and strings concatenated.
# tree_map requires the trees to have matching structure, and for a custom node
# that includes its aux_data, so both inputs must share the name 'Alice'.
# (Despite the variable name, no 'Bob' node is involved.)
# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
Alice_and_Bob = jax.tree_util.tree_map(
    lambda x, y: x + y,
    MyNode('Alice', 1, 2, 'Flax'),
    MyNode('Alice', 10, 20, 'Linen'),
)
jax.tree_util.tree_leaves(Alice_and_Bob)

[11, 22, 'FlaxLinen']

In [17]:
# The aux_data survives tree_map: the combined node keeps the shared name.
getattr(Alice_and_Bob, "name")

'Alice'

`jax.tree_util` treats `None` as a node without children, not as a leaf. Like the empty tuple, it contributes no leaves, whereas `False` is an ordinary leaf:

In [18]:
# None and () are childless nodes and contribute nothing; False is a regular leaf.
jax.tree_util.tree_leaves(tree=[None, (), False])

[False]