# Working with pytrees
JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — in JAX these are called pytrees. This section will explain how to use them, provide useful code examples, and point out common “gotchas” and patterns.

## What is a pytree?
A pytree is a container-like structure built out of container-like Python objects — “leaf” pytrees and/or more pytrees. A pytree can include lists, tuples, and dicts. A leaf is anything that’s not a pytree, such as an array, but a single leaf is also a pytree.

In the context of machine learning (ML), a pytree can contain:

* Model parameters

* Dataset entries

* Reinforcement learning agent observations

When working with datasets, you will often come across pytrees (such as lists of lists of dicts).

Below are a few examples of simple pytrees. In JAX, you can use `jax.tree.leaves()` to extract the flattened leaves from a tree, as demonstrated here:

In [1]:
import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Print how many leaves the pytrees have.
for pytree in example_trees:
  # `jax.tree.leaves()` extracts the flattened leaves from a pytree.
  leaves = jax.tree.leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

[1, 'a', <object object at 0x000001A0C6B17F40>] has 3 leaves: [1, 'a', <object object at 0x000001A0C6B17F40>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]


Any tree-like structure built out of container-like Python objects can be treated as a pytree in JAX. Classes are considered container-like if they are in the pytree registry, which by default includes lists, tuples, and dicts. Any object whose type is not in the pytree container registry will be treated as a leaf node in the tree.

The pytree registry can be extended to include user-defined container classes by registering the class with functions that specify how to flatten the tree; see Custom pytree nodes below.

## Common pytree functions
JAX provides a number of utilities to operate over pytrees. These can be found in the `jax.tree_util` subpackage; for convenience many of these have aliases in the `jax.tree` module.
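
As a brief illustration of these utilities, the sketch below flattens a pytree into its leaves plus a treedef, then rebuilds a tree of the same structure. The example tree is made up for illustration; `jax.tree.flatten()`, `jax.tree.unflatten()`, and `jax.tree.structure()` are assumed to be available as `jax.tree` aliases, which holds for recent JAX releases.

```python
import jax

# A small example pytree (hypothetical data, chosen for illustration).
tree = {'a': [1, 2], 'b': (3, {'c': 4})}

# Split the tree into a flat list of leaves and a treedef describing its structure.
leaves, treedef = jax.tree.flatten(tree)
print(leaves)  # [1, 2, 3, 4]

# Transform the leaves and rebuild a tree with the same structure.
doubled = jax.tree.unflatten(treedef, [2 * x for x in leaves])
print(doubled)  # {'a': [2, 4], 'b': (6, {'c': 8})}

# The `jax.tree_util` function and its `jax.tree` alias give the same result.
same_leaves = jax.tree_util.tree_leaves(tree) == jax.tree.leaves(tree)
```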

### Common function: jax.tree.map
The most commonly used pytree function is `jax.tree.map()`. It works analogously to Python’s native `map`, but transparently operates over entire pytrees.

In [2]:
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree.map(lambda x: x*2, list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

In [3]:
another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
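
The multi-argument call above works because both inputs have exactly the same structure. The following sketch, using made-up lists, shows what to expect when the structures differ. In the JAX releases this was written against, the mismatch surfaces as a `ValueError`; the exact message may vary between versions.

```python
import jax

left = [[1, 2, 3], [1, 2]]
right = [[10, 20, 30], [10, 20]]

# Matching structures: the function receives corresponding leaves.
summed = jax.tree.map(lambda x, y: x + y, left, right)
print(summed)  # [[11, 22, 33], [11, 22]]

# Mismatched structures: the second inner list has one extra element.
mismatched = [[10, 20, 30], [10, 20, 99]]
try:
  jax.tree.map(lambda x, y: x + y, left, mismatched)
  raised = False
except ValueError as err:
  raised = True
  print(f"ValueError: {err}")
```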

In [4]:
import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])

jax.tree.map(lambda x: x.shape, params)

[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]

In [5]:
# Define the forward pass.
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

# Define the loss function.
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

# Set the learning rate.
LEARNING_RATE = 0.0001

# Define the parameter update function using stochastic gradient descent (SGD).
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
  # Calculate the gradients with `jax.grad`.
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of many JAX functions that has
  # built-in support for pytrees.
  # This is useful - you can apply the SGD update using JAX pytree utilities.
  return jax.tree.map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )
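
The comments in `update` state that `grads` has the same pytree structure as `params`. A minimal, self-contained check of that claim is sketched below. The tiny two-layer `params`, the inputs `x` and `y`, and the step size `0.01` are hypothetical stand-ins chosen only to keep the example small; they are not the values used above.

```python
import jax
import jax.numpy as jnp

# Hypothetical small parameter tree, standing in for `params` above.
params = [dict(weights=jnp.ones((1, 4)), biases=jnp.zeros((4,))),
          dict(weights=jnp.ones((4, 1)), biases=jnp.zeros((1,)))]

def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

# Toy data: fit y = x^2 on a few points.
x = jnp.linspace(-1.0, 1.0, 8).reshape(8, 1)
y = x ** 2

# `jax.grad` returns a pytree that mirrors the structure of `params`.
grads = jax.grad(loss_fn)(params, x, y)
same_structure = jax.tree.structure(grads) == jax.tree.structure(params)

# Because the structures match, one SGD step is a single two-tree map.
new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
print(jax.tree.map(lambda a: a.shape, new_params))
```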

## Custom pytree nodes
This section explains how in JAX you can extend the set of Python types that will be considered *internal nodes* in pytrees (pytree nodes) by using `jax.tree_util.register_pytree_node()` with `jax.tree.map()`.

Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else as pytree leaves. If you define your own container class, JAX treats it as a pytree leaf unless you register it, even if the class has trees inside it. For example:

In [6]:
class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

jax.tree.leaves([
    Special(0, 1),
    Special(2, 4),
])

[<__main__.Special at 0x1a0c6dcaf30>, <__main__.Special at 0x1a0c9337e30>]

In [7]:
jax.tree.map(lambda x: x + 1,
  [
    Special(0, 1),
    Special(2, 4)
  ])

TypeError: unsupported operand type(s) for +: 'Special' and 'int'

In [8]:
from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: The value of the registered type to flatten.
  Returns:
    A pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, for example, for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: The opaque data that was specified during flattening of the
      current tree definition.
    children: The unflattened children

  Returns:
    A reconstructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # Instruct JAX what are the children nodes.
    special_unflatten   # Instruct JAX how to pack back into a `RegisteredSpecial`.
)

In [9]:
jax.tree.map(lambda x: x + 1,
  [
   RegisteredSpecial(0, 1),
   RegisteredSpecial(2, 4),
  ])

[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]
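
The registration above uses free functions. As an alternative that this section does not cover, `jax.tree_util.register_pytree_node_class` can be applied as a class decorator, in which case the class itself supplies `tree_flatten` and `tree_unflatten`. The sketch below assumes a hypothetical `Pair` class that mirrors `RegisteredSpecial`.

```python
import jax
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Pair:  # Hypothetical class, analogous to `RegisteredSpecial`.
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return f"Pair(x={self.x}, y={self.y})"

  def tree_flatten(self):
    # Children to flatten recursively, plus auxiliary data (none needed here).
    return (self.x, self.y), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # Rebuild the object from its (possibly transformed) children.
    return cls(*children)

result = jax.tree.map(lambda v: v + 1, [Pair(0, 1), Pair(2, 4)])
print(result)  # [Pair(x=1, y=2), Pair(x=3, y=5)]
```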

In [10]:
from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])

['Alice', 1, 2, 3, 'Bob', 4, 5, 6]
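
Note that in the output above the `name` field shows up as a leaf, because every element of a `NamedTuple` is a child. If a field should be static metadata rather than a leaf, one option is to register a class manually and carry that field in `aux_data`. The sketch below assumes a hypothetical `NamedContainer` class; the auxiliary data should be hashable and comparable, which a string is.

```python
import jax
from jax.tree_util import register_pytree_node

class NamedContainer:  # Hypothetical class for illustration.
  def __init__(self, name, a, b):
    self.name = name
    self.a = a
    self.b = b

def flatten_named(c):
  # `a` and `b` are children; `name` travels as auxiliary data.
  return (c.a, c.b), c.name

def unflatten_named(name, children):
  return NamedContainer(name, *children)

register_pytree_node(NamedContainer, flatten_named, unflatten_named)

leaves = jax.tree.leaves([
    NamedContainer('Alice', 1, 2),
    NamedContainer('Bob', 3, 4),
])
print(leaves)  # [1, 2, 3, 4]

# The name survives a map untouched because it is not a leaf.
doubled = jax.tree.map(lambda v: v * 2, NamedContainer('Alice', 1, 2))
print(doubled.name, doubled.a, doubled.b)  # Alice 2 4
```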

## Explicit key paths
In a pytree each leaf has a key path. A key path for a leaf is a sequence of keys (represented as a `tuple`), where the length of the sequence is equal to the depth of the leaf in the pytree. Each key is a hashable object that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for `dict`s is different from the type of keys for `tuple`s.

For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique.

JAX has the following `jax.tree_util.*` methods for working with key paths:

* `jax.tree_util.tree_flatten_with_path()`: Works similarly to `jax.tree.flatten()`, but returns key paths.

* `jax.tree_util.tree_map_with_path()`: Works similarly to `jax.tree.map()`, but the function also takes key paths as arguments.

* `jax.tree_util.keystr()`: Given a general key path, returns a reader-friendly string expression.

In [11]:
import collections

ATuple = collections.namedtuple("ATuple", ('name'))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in flattened:
  print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')

Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo
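
The example above covers `tree_flatten_with_path()` and `keystr()`. A sketch of `tree_map_with_path()` follows, using a made-up parameter dict: the mapped function receives the key path as its first argument, so it can treat leaves differently depending on where they sit in the tree.

```python
import jax
from jax.tree_util import DictKey, tree_map_with_path

# Hypothetical parameter tree for illustration.
tree = {'weights': [1.0, 2.0], 'bias': 3.0}

def scale_weights(path, x):
  # `path` is a tuple of keys; the first key says which dict entry we are in.
  first = path[0]
  if isinstance(first, DictKey) and first.key == 'weights':
    return x * 10
  return x

result = tree_map_with_path(scale_weights, tree)
print(result)  # {'bias': 3.0, 'weights': [10.0, 20.0]}
```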


To express key paths, JAX provides a few default key types for the built-in pytree node types, namely:

* `SequenceKey(idx: int)`: For lists and tuples.

* `DictKey(key: Hashable)`: For dictionaries.

* `GetAttrKey(name: str)`: For `namedtuple`s and preferably custom pytree nodes (more in the next section).

You are free to define your own key types for your custom nodes. They will work with `jax.tree_util.keystr()` as long as their `__str__()` method is also overridden with a reader-friendly expression.

In [12]:
for key_path, _ in flattened:
  print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')

Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))
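
For custom nodes, key paths can be supplied at registration time. The sketch below uses `jax.tree_util.register_pytree_with_keys()` with `GetAttrKey` on a hypothetical `Point` class; the flatten function returns `(key, child)` pairs instead of bare children.

```python
from jax.tree_util import (GetAttrKey, keystr, register_pytree_with_keys,
                           tree_flatten_with_path)

class Point:  # Hypothetical class for illustration.
  def __init__(self, x, y):
    self.x = x
    self.y = y

def flatten_with_keys(p):
  # Pair each child with the key that locates it inside the node.
  children = ((GetAttrKey('x'), p.x), (GetAttrKey('y'), p.y))
  return children, None

def unflatten(aux_data, children):
  return Point(*children)

register_pytree_with_keys(Point, flatten_with_keys, unflatten)

flattened, _ = tree_flatten_with_path([Point(1, 2)])
paths = [keystr(path) for path, _ in flattened]
print(paths)  # ['[0].x', '[0].y']
```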


## Common pytree gotchas
This section covers some of the most common problems (“gotchas”) encountered when using JAX pytrees.

### Mistaking pytree nodes for leaves
A common gotcha to look out for is accidentally introducing tree nodes instead of leaves:

In [13]:
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)

[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
 (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]

What happened here is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it’s called on `2` and `3`.

The solution will depend on the specifics, but there are two broadly applicable options:

* Rewrite the code to avoid the intermediate `jax.tree.map()`.

* Convert the tuple into a NumPy array (`np.array`) or a JAX NumPy array (`jnp.array`), which makes the entire sequence a leaf.
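
Both options can be sketched on the same `a_tree` as above. This is one possible way to apply each suggestion, not the only one; `jnp.ones_like` is used here as a convenient way to avoid the intermediate tree of shapes.

```python
import jax
import jax.numpy as jnp
import numpy as np

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Option 1: skip the intermediate tree of shapes entirely.
ones_direct = jax.tree.map(jnp.ones_like, a_tree)

# Option 2: wrap each shape in a NumPy array so that it is a single leaf.
shapes = jax.tree.map(lambda x: np.array(x.shape), a_tree)
ones_from_shapes = jax.tree.map(jnp.ones, shapes)

print([x.shape for x in ones_direct])       # [(2, 3), (3, 4)]
print([x.shape for x in ones_from_shapes])  # [(2, 3), (3, 4)]
```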

### Handling of `None` by `jax.tree_util`
`jax.tree_util` functions treat `None` as the absence of a pytree node, not as a leaf:

In [14]:
jax.tree.leaves([None, None, None])

[]

In [15]:
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)

[None, None, None]
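
The same behavior affects `jax.tree.map()`: by default the mapped function is never called on `None`, and the `None` entries are passed through unchanged. A short sketch with made-up values:

```python
import jax

tree = [1, None, 2]

# By default, `None` is an empty node: the function only sees 1 and 2.
mapped = jax.tree.map(lambda x: x + 1, tree)
print(mapped)  # [2, None, 3]

# With `is_leaf`, `None` becomes a leaf and the function must handle it.
filled = jax.tree.map(
    lambda x: 0 if x is None else x + 1,
    tree,
    is_leaf=lambda x: x is None,
)
print(filled)  # [2, 0, 3]
```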

### Custom pytrees and initialization with unexpected values
Another common gotcha with user-defined pytree objects is that JAX transformations occasionally initialize them with unexpected values, so that any input validation done at initialization may fail. For example:

In [16]:
class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to `MyTree`.

TypeError: Value '<object object at 0x000001A0CA464D00>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

In [17]:
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to `MyTree`.

TypeError: Value '<object object at 0x000001A0D0018BD0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

**Potential solution 1:**

The `__init__` and `__new__` methods of custom pytree classes should generally avoid doing any array conversion or other input validation, or else anticipate and handle these special cases. For example:

In [19]:
class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a

**Potential solution 2:**

Structure your custom `tree_unflatten` function so that it avoids calling `__init__`. If you choose this route, make sure that your `tree_unflatten` function stays in sync with `__init__` if and when the code is updated. Example:

In [20]:
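def tree_unflatten(aux_data, children):
  del aux_data  # Unused in this class.
  (a,) = children  # `MyTree` has exactly one child.
  obj = object.__new__(MyTree)  # Create the instance without calling `__init__`.
  obj.a = a
  return obj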
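
To see the second approach end to end, the sketch below registers `MyTree` with an unflatten function that bypasses `__init__` and then runs the `jax.vmap` call that failed earlier. The helper name `tree_flatten` is chosen here for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

class MyTree:
  def __init__(self, a):
    # Validation/conversion stays in `__init__` for user-facing construction.
    self.a = jnp.asarray(a)

def tree_flatten(tree):
  return (tree.a,), None

def tree_unflatten(aux_data, children):
  del aux_data  # Unused in this class.
  (a,) = children
  # Bypass `__init__`, so placeholder values are stored without conversion.
  obj = object.__new__(MyTree)
  obj.a = a
  return obj

register_pytree_node(MyTree, tree_flatten, tree_unflatten)

tree = MyTree(jnp.arange(5.0))
out = jax.vmap(lambda t: t)(tree)  # No longer raises a TypeError.
print(type(out).__name__, out.a.shape)  # MyTree (5,)
```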

## Common pytree patterns
This section covers some of the most common patterns with JAX pytrees.

### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose`
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: `jax.tree.map()` (more basic) and `jax.tree.transpose()` (more flexible, complex and verbose).

**Option 1**: Use `jax.tree.map()`. Here’s an example:

In [21]:
def tree_transpose(list_of_trees):
  """
  Converts a list of trees of identical structure into a single tree of lists.
  """
  return jax.tree.map(lambda *xs: list(xs), *list_of_trees)

# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)

{'obs': [3, 4], 't': [1, 2]}
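
A closely related variant of Option 1, shown here as a sketch with hypothetical per-step records, stacks array leaves with `jnp.stack` instead of collecting them into Python lists. The result is a single tree whose leaves carry a leading batch dimension.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-step records whose leaves are arrays.
steps = [
    dict(obs=jnp.array([0.0, 1.0]), reward=jnp.array(1.0)),
    dict(obs=jnp.array([2.0, 3.0]), reward=jnp.array(0.5)),
    dict(obs=jnp.array([4.0, 5.0]), reward=jnp.array(0.0)),
]

# Stack corresponding leaves across all records along a new leading axis.
batched = jax.tree.map(lambda *xs: jnp.stack(xs), *steps)
print(jax.tree.map(lambda x: x.shape, batched))  # {'obs': (3, 2), 'reward': (3,)}
```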

**Option 2**: For more complex transposes, use `jax.tree.transpose()`, which is more verbose, but allows you to specify the structure of the inner and outer pytrees for more flexibility. For example:

In [23]:
jax.tree.transpose(
  outer_treedef = jax.tree.structure([0 for e in episode_steps]),
  inner_treedef = jax.tree.structure(episode_steps[0]),
  pytree_to_transpose = episode_steps
)

{'obs': [3, 4], 't': [1, 2]}