In [1]:
import jax
import jax.numpy as jnp

### By default, the built-in pytree containers are lists, tuples, dicts, namedtuples, `None`, and `OrderedDict`s. All other values, including numeric scalars and ndarrays, are treated as leaves:

In [2]:
# A nested pytree: an outer list holding an int, a tuple and an inner list.
custom_structure = [1, (2, 3), [4, 5]]
# tree_flatten returns the leaves as one flat list together with a PyTreeDef
# that records the container structure.
jax.tree_util.tree_flatten(custom_structure)

([1, 2, 3, 4, 5], PyTreeDef([*, (*, *), [*, *]]))

In [3]:
from collections import namedtuple
# A namedtuple type; the flattened output below shows it is handled as a container node.
Point = namedtuple('Point', ['x', 'y'])

# One example per kind of value; each entry is flattened on its own further down.
example_containers = [
    (1., [2., 3.]),
    # dict values come out in sorted-key order ('a' before 'b') when flattened
    (1., {'b': 2., 'a': 3.}),
    # a bare Python float is a leaf
    1.,
    # None contributes no leaves at all
    None,
    jnp.zeros(2), # an array is a single leaf (it is not the only leaf type: the float `1.` above is one too)
    Point(1., 2.)
]
def show_example(structured):
    """Print a flatten / unflatten round trip for one pytree.

    The value is flattened with ``jax.tree_util.tree_flatten``, rebuilt with
    ``jax.tree_util.tree_unflatten``, and the input, the leaves, the treedef
    and the rebuilt value are printed, one per line.
    """
    flat, tree = jax.tree_util.tree_flatten(structured)
    unflattened = jax.tree_util.tree_unflatten(tree, flat)
    # The `=` specifier prints the variable name itself, so the local names
    # above are part of the output and must stay as they are.
    report_lines = [
        f"{structured=}",
        f"{flat=}",
        f"{tree=}",
        f"{unflattened=}",
    ]
    print("\n  ".join(report_lines))

# Flatten the whole list first, then each of its entries on its own.
print('example_containers=')  # plain literal: there are no placeholders, so no f-prefix
show_example(example_containers)
for structured in example_containers:
    print('\n')  # two newlines: leaves two blank lines between examples
    show_example(structured)

example_containers=
structured=[(1.0, [2.0, 3.0]), (1.0, {'b': 2.0, 'a': 3.0}), 1.0, None, Array([0., 0.], dtype=float32), Point(x=1.0, y=2.0)]
  flat=[1.0, 2.0, 3.0, 1.0, 3.0, 2.0, 1.0, Array([0., 0.], dtype=float32), 1.0, 2.0]
  tree=PyTreeDef([(*, [*, *]), (*, {'a': *, 'b': *}), *, None, *, CustomNode(namedtuple[Point], [*, *])])
  unflattened=[(1.0, [2.0, 3.0]), (1.0, {'a': 3.0, 'b': 2.0}), 1.0, None, Array([0., 0.], dtype=float32), Point(x=1.0, y=2.0)]


structured=(1.0, [2.0, 3.0])
  flat=[1.0, 2.0, 3.0]
  tree=PyTreeDef((*, [*, *]))
  unflattened=(1.0, [2.0, 3.0])


structured=(1.0, {'b': 2.0, 'a': 3.0})
  flat=[1.0, 3.0, 2.0]
  tree=PyTreeDef((*, {'a': *, 'b': *}))
  unflattened=(1.0, {'a': 3.0, 'b': 2.0})


structured=1.0
  flat=[1.0]
  tree=PyTreeDef(*)
  unflattened=1.0


structured=None
  flat=[]
  tree=PyTreeDef(None)
  unflattened=None


structured=Array([0., 0.], dtype=float32)
  flat=[Array([0., 0.], dtype=float32)]
  tree=PyTreeDef(*)
  unflattened=Array([0., 0.], dtyp

Python objects: https://docs.python.org/3/reference/datamodel.html#basic-customization, https://stackoverflow.com/questions/73409385/object-class-documentation-in-python, https://docs.python.org/3/library/functions.html#object

### To extend the set of types that are recognized as pytree nodes, we can register new classes (here with `jax.tree_util.register_pytree_node_class`)

In [4]:
# create a dummy model: an abstract base class (ABC) plus a dataclass that implements it
from dataclasses import dataclass
from abc import ABC, abstractmethod

# AbstractLayer cannot be instantiated directly because it declares an abstract method.
# Abstract methods ensure that every concrete subclass provides its own implementation.
# (Decorating the abstract class itself with @dataclass would serve no purpose: it has no fields.)
class AbstractLayer(ABC):
    """Interface for layers: a callable that takes a single input ``x``."""

    @abstractmethod
    def __call__(self, x):
        """Apply the layer to ``x``; concrete subclasses must override this."""

In [5]:
@jax.tree_util.register_pytree_node_class
@dataclass
class PyTreeLayer(AbstractLayer):
    """Affine layer registered as a custom pytree node.

    ``weight`` and ``bias`` are handed to JAX as the children of the node,
    while ``name`` travels alongside them as auxiliary data.
    """
    weight: jax.Array
    bias: jax.Array
    name: str

    def __call__(self, x):
        """Return ``weight @ x + bias``."""
        projected = self.weight @ x
        return projected + self.bias

    def tree_flatten(self):
        """Split the layer into a ``(children, aux_data)`` pair.

        ``children`` holds the values that belong to the tree (weight, bias);
        ``aux_data`` holds what is not part of the tree structure (the name).
        """
        return (self.weight, self.bias), (self.name,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a layer from the pieces produced by ``tree_flatten``."""
        # Children come first, then aux data: the same order as the dataclass fields.
        init_args = (*children, *aux_data)
        return cls(*init_args)

In [9]:
# One fixed PRNG key per parameter keeps the example reproducible.
weight_key = jax.random.PRNGKey(0)
bias_key = jax.random.PRNGKey(1)

w = jax.random.normal(weight_key, (3, 2))
b = jax.random.normal(bias_key, (3,))

layer = PyTreeLayer(weight=w, bias=b, name='layer1')
print(layer)

PyTreeLayer(weight=Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), bias=Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32), name='layer1')


In [7]:
# jax.tree_util.tree_flatten returns (leaves, treedef): the second value is a
# PyTreeDef describing the structure, not the raw aux_data tuple.
leaves, treedef = jax.tree_util.tree_flatten(layer)
print(leaves)
print(treedef, '\n')

# Round trip through the registry: the treedef records the node type together
# with its aux_data, so this call succeeds and rebuilds the PyTreeLayer.
layer_reconstructed = jax.tree_util.tree_unflatten(treedef, leaves)
print(layer_reconstructed)

[Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32)]
PyTreeDef(CustomNode(PyTreeLayer[('layer1',)], [*, *])) 

PyTreeLayer(weight=Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), bias=Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32), name='layer1')


In [8]:
# Calling the class's own methods directly: tree_flatten here returns the raw
# (children, aux_data) pair, i.e. a tuple of the two arrays and ('layer1',).
leaves, aux_data = layer.tree_flatten()
print(leaves)
print(aux_data)

layer_reconstructed = PyTreeLayer.tree_unflatten(aux_data, leaves) # works: rebuilds the layer via cls(*children, *aux_data)
print(layer_reconstructed)

(Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32))
('layer1',)
PyTreeLayer(weight=Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), bias=Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32), name='layer1')


### JAX provides several functions for manipulating pytrees, all available under `jax.tree_util`: `tree_map`, `tree_flatten`, `tree_unflatten`, `tree_leaves`, `tree_structure`, `tree_transpose`, `tree_reduce`, and `tree_all`. Recent JAX versions also expose them under the `jax.tree` namespace (e.g. `jax.tree.map`). The old top-level aliases such as `jax.tree_map` are deprecated, and names like `jax.tree_multimap`, `jax.tree_any` and `jax.tree_pmap` are not part of the current API.

### For example, `jax.tree_util.tree_map` applies a function to each leaf of a pytree and returns a new pytree with the same structure. Here we square each leaf:

In [12]:
def _square(leaf):
    """Return the square of a single pytree leaf."""
    return leaf ** 2


# Only the children (weight, bias) are mapped over; the name is aux data and
# comes back unchanged.
print(layer)
layer_squared = jax.tree_util.tree_map(_square, layer)
print(layer_squared)

PyTreeLayer(weight=Array([[ 0.18784384, -1.2833426 ],
       [ 0.6494181 ,  1.2490594 ],
       [ 0.24447003, -0.11744965]], dtype=float32), bias=Array([ 0.17269018, -0.64765567,  1.2229712 ], dtype=float32), name='layer1')
PyTreeLayer(weight=Array([[0.03528531, 1.6469682 ],
       [0.4217439 , 1.5601494 ],
       [0.0597656 , 0.01379442]], dtype=float32), bias=Array([0.0298219 , 0.41945785, 1.4956585 ], dtype=float32), name='layer1')
