# Basics of pytrees

In [58]:
import jax
import jax.numpy as jnp

import jax.tree_util as tree_util

# A few example pytrees: nested lists, tuples and dicts (pytree nodes) whose
# remaining items are leaves.
example_trees = [
    [1, 'a', object()],               # an arbitrary object is kept as a leaf
    (1, (2, 3), ()),                  # the empty tuple contributes no leaves
    [1, {'k1': 2, 'k2': (3, 4)}, 5],  # dict values are traversed as children
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),             # a JAX array is one leaf, not three
]

# Report each pytree's type and its leaves. `tree_util.tree_leaves` flattens a
# pytree and returns its leaves as a flat Python list.
for pytree in example_trees:
  leaves = tree_util.tree_leaves(pytree)
  print(type(pytree))
  print(f"{pytree!r:<45} has {len(leaves)} leaves: {leaves}")

<class 'list'>
[1, 'a', <object object at 0x7f0238394720>]   has 3 leaves: [1, 'a', <object object at 0x7f0238394720>]
<class 'tuple'>
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
<class 'list'>
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
<class 'dict'>
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
<class 'jaxlib.xla_extension.ArrayImpl'>
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]


The most commonly used pytree function is `tree_util.tree_map`. It works analogously to Python’s built-in `map`, but applies the function to every leaf of a pytree (however deeply nested) and returns a new pytree with the same structure.

In [59]:
# Three inner lists of different lengths, nested inside an outer list.
list_of_lists = [[1, 2, 3], [1, 2], [1, 2, 3, 4]]

# Double every leaf; the result mirrors the nested structure of the input.
tree_util.tree_map(lambda leaf: leaf * 2, list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

# Custom pytree nodes

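The following examples show how to register a user-defined class as a new pytree node type. Without registration, JAX treats each instance of such a class as a single, opaque leaf.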

In [60]:
class Special:
  """A plain user-defined container that is NOT registered as a pytree node.

  Attributes:
    x: First stored value.
    y: Second stored value.
  """

  def __init__(self, x, y):
    self.x, self.y = x, y

# Unregistered objects are not traversed: each `Special` instance comes back
# whole as a single leaf instead of having `x` and `y` extracted.
tree_util.tree_leaves([Special(0, 1), Special(2, 4)])


[<__main__.Special at 0x7f023833b010>, <__main__.Special at 0x7f023833bbe0>]

In [61]:
# Because `Special` is unregistered, `tree_map` passes each whole instance to
# the function as a leaf. `Special` defines no `+`, so `x + 1` raises a
# TypeError. Catch and display it so the notebook still runs top to bottom.
try:
  tree_util.tree_map(lambda x: x + 1,
    [
      Special(0, 1),
      Special(2, 4)
    ])
except TypeError as err:
  print(f"{type(err).__name__}: {err}")

In [62]:
class RegisteredSpecial(Special):
  """A `Special` subclass that this cell registers as a pytree node.

  Once registered with `special_flatten` / `special_unflatten`, its `x` and
  `y` attributes are treated as children, so pytree utilities traverse into
  them instead of treating the whole object as one leaf.
  """

  def __repr__(self):
    return f"RegisteredSpecial(x={self.x}, y={self.y})"

def special_flatten(v):
  """Flattening recipe for `RegisteredSpecial`.

  Params:
    v: The instance of the registered type to flatten.

  Returns:
    A ``(children, aux_data)`` pair. ``children`` is the tuple ``(v.x, v.y)``,
    which JAX flattens further, recursively. ``aux_data`` is ``None`` because
    this type carries no static metadata. Whatever is returned as
    ``aux_data`` is stored in the treedef and passed back to the unflattening
    recipe; a dict-like type could use it to hold its keys, for example.
  """
  return (v.x, v.y), None

def special_unflatten(aux_data, children):
  """Unflattening recipe for `RegisteredSpecial`; the inverse of `special_flatten`.

  Params:
    aux_data: The auxiliary data the flattening recipe stored in the treedef.
      `special_flatten` returns ``None`` here, so it is not needed.
    children: The (possibly transformed) children, in the order the flattening
      recipe produced them: ``x`` first, then ``y``.

  Returns:
    A new `RegisteredSpecial` built from the children.
  """
  del aux_data  # No static metadata is needed to rebuild the object.
  return RegisteredSpecial(*children)

# Register `RegisteredSpecial` globally, so that every pytree utility flattens
# and rebuilds its instances with the recipes defined above.
tree_util.register_pytree_node(
    RegisteredSpecial,   # the node type being registered
    special_flatten,     # instance -> (children, aux_data)
    special_unflatten,   # (aux_data, children) -> instance
)

In [63]:
# `RegisteredSpecial` is now a pytree node, so `tree_map` reaches into each
# instance. It applies the function to the `x` and `y` children, then rebuilds
# the objects with `special_unflatten`.
tree_util.tree_map(
    lambda leaf: leaf + 1,
    [RegisteredSpecial(0, 1), RegisteredSpecial(2, 4)],
)

[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]

Modern Python comes equipped with helpful tools to make defining containers easier. Some will work with JAX out-of-the-box, but others require more care.

For instance, a Python `NamedTuple` subclass doesn’t need to be registered to be considered a pytree node type:

In [64]:
from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  """A named tuple holding a name plus three arbitrary payload fields."""
  name: str
  a: Any
  b: Any
  c: Any

# JAX treats NamedTuple subclasses as pytree nodes without any registration.
# Every field becomes a leaf, including the `name` string, so a numeric
# function mapped over these containers would be applied to the name too.
tree_util.tree_leaves([MyOtherContainer('Alice', 1, 2, 3),
                       MyOtherContainer('Bob', 4, 5, 6)])

['Alice', 1, 2, 3, 'Bob', 4, 5, 6]