In [1]:
import os
# Restrict this notebook to GPU 3; this must be set before JAX is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

JAX has built-in support for nested container structures, which it calls pytrees.

A pytree is a container-like structure built out of other pytrees and Python objects. Each leaf of a pytree is an object that is not itself a pytree node.
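A minimal sketch of what "built out of other pytrees" means in practice: `jax.tree.flatten` splits a pytree into its leaves plus a structure description (a `PyTreeDef`), and `jax.tree.unflatten` rebuilds the same structure from a new list of leaves. The tree below is an arbitrary illustrative example.

```python
import jax
import jax.numpy as jnp

# A small nested pytree: a dict holding a tuple and a list.
tree = {'a': (1, 2), 'b': [jnp.array([3.0, 4.0]), 5]}

# Flatten into a list of leaves plus a PyTreeDef describing the structure.
# Dict keys are visited in sorted order, so 'a' comes before 'b'.
leaves, treedef = jax.tree.flatten(tree)
print(leaves)  # [1, 2, Array([3., 4.], dtype=float32), 5]
print(treedef.num_leaves)

# Rebuild the same structure from a different list of leaves.
rebuilt = jax.tree.unflatten(treedef, [10, 20, jnp.zeros(2), 50])
print(rebuilt)
```

The round trip is what lets JAX transformations operate on arbitrary nested inputs: they work on the flat leaf list and restore the structure afterwards.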

In [2]:
import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

In [3]:
jax.tree.leaves(example_trees)

[1,
 'a',
 <object at 0x7faa2a785260>,
 1,
 2,
 3,
 1,
 2,
 3,
 4,
 5,
 2,
 2,
 3,
 Array([1, 2, 3], dtype=int32)]

In [4]:
# Any object whose type is not in the pytree registry is treated as a leaf.
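To make the rule concrete, here is a small sketch with a hypothetical `Point` class that is never registered. JAX does not look inside it, so the whole instance counts as one leaf.

```python
import jax

class Point:
    # Hypothetical plain Python class; it is NOT registered as a pytree node.
    def __init__(self, x, y):
        self.x = x
        self.y = y

p = Point(1.0, 2.0)

# Because Point is unregistered, the entire object is a single leaf.
leaves = jax.tree.leaves([p, 3.0])
print(len(leaves))  # 2
```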

In [5]:
# The pytree registry can be extended with custom classes by specifying how to flatten and unflatten them.
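A minimal sketch of registering a custom class with `jax.tree_util.register_pytree_node`. The class `Vec2` and the helper names are hypothetical. The flatten function returns `(children, aux_data)`; children are traversed as part of the tree, while `aux_data` is static metadata (unused here, so `None`). The unflatten function receives them back in the order `(aux_data, children)`.

```python
import jax
import jax.numpy as jnp

class Vec2:
    # Hypothetical 2-D vector class used only for illustration.
    def __init__(self, x, y):
        self.x = x
        self.y = y

def vec2_flatten(v):
    # Children are traversed by JAX; aux_data (None here) is kept static.
    return (v.x, v.y), None

def vec2_unflatten(aux_data, children):
    # Rebuild a Vec2 from the children produced by vec2_flatten.
    return Vec2(*children)

jax.tree_util.register_pytree_node(Vec2, vec2_flatten, vec2_unflatten)

v = Vec2(jnp.array(1.0), jnp.array(2.0))
print(jax.tree.leaves(v))  # the two arrays, not the Vec2 object

# tree.map now reaches inside Vec2 and returns a new Vec2.
doubled = jax.tree.map(lambda t: t * 2, v)
print(doubled.x, doubled.y)
```

Compare this with an unregistered class, which stays a single opaque leaf.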

In [6]:
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree.map(lambda x: x*2, list_of_lists)


[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

In [8]:
# jax.tree.map also accepts several trees; the function receives the matching leaf from each one.
jax.tree.map(lambda x, y: x+y, list_of_lists, list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
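A typical use of multi-tree mapping is a gradient-descent style update, where parameters and gradients share the same structure. The sketch below uses made-up values and a hypothetical `learning_rate`. The additional trees must have the same structure as the first (or have the first tree as a prefix).

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters and gradients with identical structure.
params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
grads = {'w': jnp.array([0.1, 0.2]), 'b': jnp.array(0.5)}

learning_rate = 0.1  # illustrative value

# Leaves at the same position in both trees are passed together to the lambda.
new_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)
print(new_params)  # 'w' becomes [0.99, 1.98], 'b' becomes 0.45
```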

In [9]:
# NamedTuple subclasses are registered as pytree nodes automatically; plain @dataclass classes are not.
# You can use jax.tree_util.register_pytree_node to register any class as a pytree node.
# You can use jax.tree_util.register_dataclass() to register a dataclass.
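A quick sketch of the difference, using two hypothetical classes with the same fields: the NamedTuple is flattened into its fields, while the unregistered dataclass is a single leaf.

```python
from dataclasses import dataclass
from typing import NamedTuple

import jax

class PairNT(NamedTuple):
    # NamedTuple subclasses are pytree nodes out of the box.
    a: float
    b: float

@dataclass
class PairDC:
    # A plain (unregistered) dataclass is treated as one leaf.
    a: float
    b: float

print(jax.tree.leaves(PairNT(1.0, 2.0)))  # [1.0, 2.0]
print(jax.tree.leaves(PairDC(1.0, 2.0)))  # [PairDC(a=1.0, b=2.0)]
```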

In [12]:
from typing import NamedTuple, Any
from dataclasses import dataclass
import functools

@functools.partial(jax.tree_util.register_dataclass,
                   data_fields=['a', 'b', 'c'],
                   meta_fields=['name'])
@dataclass
class MyDataclassContainer(object):
  name: str
  a: Any
  b: Any
  c: Any

# MyDataclassContainer is now a pytree node.
jax.tree.leaves([
  MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])),
  MyDataclassContainer('banana', jnp.array([3, 4]), -1., 0.)
])

[5.3,
 1.2,
 Array([0., 0., 0., 0.], dtype=float32),
 Array([3, 4], dtype=int32),
 -1.0,
 0.0]

Because `name` is declared as a meta field, it does not appear among the leaves. This means you can pass an instance of this dataclass to a jit-compiled function and `name` will be treated as a static value (meta fields should therefore be hashable).
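A sketch of what "static" buys you, reusing the container definition from above; the function `scaled_sum` and its scaling rule are hypothetical. Since `name` is part of the tree structure rather than a traced leaf, it is a concrete Python string at trace time and can drive ordinary Python control flow inside `jax.jit`.

```python
import functools
from dataclasses import dataclass
from typing import Any

import jax
import jax.numpy as jnp

@functools.partial(jax.tree_util.register_dataclass,
                   data_fields=['a', 'b', 'c'],
                   meta_fields=['name'])
@dataclass
class MyDataclassContainer:
    name: str
    a: Any
    b: Any
    c: Any

@jax.jit
def scaled_sum(container):
    # `name` is metadata, so this is a plain Python comparison, not a traced one.
    scale = 2.0 if container.name == 'apple' else 1.0
    return scale * (container.a + container.b + jnp.sum(container.c))

print(scaled_sum(MyDataclassContainer('apple', 1.0, 2.0, jnp.ones(3))))   # 12.0
print(scaled_sum(MyDataclassContainer('banana', 1.0, 2.0, jnp.ones(3))))  # 6.0
```

Because meta fields are part of the tree structure, calling the function with a different `name` triggers a fresh trace and compilation.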

All JAX transformations can be applied to functions that take pytrees as inputs and return pytrees as outputs.
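For example, `jax.grad` of a function whose first argument is a dict returns a dict of gradients with the same keys. The linear model and data below are illustrative only.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Simple linear model; params is a dict pytree (illustrative only).
    pred = params['w'] * x + params['b']
    return jnp.mean((pred - y) ** 2)

params = {'w': jnp.array(2.0), 'b': jnp.array(0.0)}
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([1.0, 2.0, 3.0])

# Gradient w.r.t. the first argument, returned with the same dict structure.
grads = jax.grad(loss)(params, x, y)
print(grads)  # {'b': ..., 'w': ...}
```

Here the residual `pred - y` equals `x`, so d/dw = mean(2 * x * x) = 28/3 and d/db = mean(2 * x) = 4.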

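Some JAX transformations take arguments that specify how input and output values should be treated, for example `in_axes` and `out_axes` in `jax.vmap`. These arguments can also be pytrees, as long as their structure matches (or is a prefix of) the structure of the values they describe.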
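A small sketch with `jax.vmap`: the single argument is a dict, and `in_axes` mirrors that dict to say that `'x'` is batched along axis 0 while `'scale'` is shared (`None` means not mapped). The function `weighted` is hypothetical.

```python
import jax
import jax.numpy as jnp

def weighted(inputs):
    # inputs is a dict pytree: 'x' is batched, 'scale' is shared.
    return inputs['x'] * inputs['scale']

batch = {'x': jnp.arange(6.0).reshape(3, 2), 'scale': jnp.array(10.0)}

# in_axes is a tuple with one entry per positional argument; that entry is a
# dict with the same keys as the argument itself.
out = jax.vmap(weighted, in_axes=({'x': 0, 'scale': None},))(batch)
print(out.shape)  # (3, 2)
```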

In a pytree, each node has a key path. The key path of a leaf is a sequence of keys (a tuple in practice) whose length equals the depth of the leaf in the tree. Each key is a hashable object that indexes into the corresponding node, such as a list position or a dict key.
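Key paths can be inspected with `jax.tree_util.tree_flatten_with_path`, which returns `(key_path, leaf)` pairs, and rendered with `jax.tree_util.keystr`. The tree below is the third entry of `example_trees` from earlier.

```python
import jax

tree = [1, {'k1': 2, 'k2': (3, 4)}, 5]

# Each entry pairs a key path (a tuple of key objects) with its leaf.
path_and_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
for path, leaf in path_and_leaves:
    # keystr renders a key path as a readable string such as "[1]['k2'][0]";
    # len(path) is the depth of the leaf.
    print(jax.tree_util.keystr(path), len(path), leaf)
```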

`jax.tree_util` functions treat `None` as the absence of a pytree node (a node with no children) rather than as a leaf.

In [13]:
jax.tree.leaves([None, None, None])

[]

In [14]:
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)

[None, None, None]
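The same rule shows up in `jax.tree.map`: by default the function is never called on `None`, which is carried through unchanged. Passing `is_leaf` makes `None` visible to the function, for example to replace it with a default. The tree and the default value `0.0` are illustrative.

```python
import jax

tree = {'weights': 1.0, 'bias': None}

# None has no leaves, so the lambda is only applied to 'weights'.
mapped = jax.tree.map(lambda x: x + 1, tree)
print(mapped)  # 'bias' stays None, 'weights' becomes 2.0

# With is_leaf, None is passed to the function like any other leaf.
filled = jax.tree.map(lambda x: 0.0 if x is None else x, tree,
                      is_leaf=lambda x: x is None)
print(filled)  # 'bias' becomes 0.0
```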