# Pytree

By default, pytree containers can be lists, tuples, dicts, namedtuples, None, and OrderedDicts.

In [1]:
from jax import tree_util

In [6]:
# Flatten a nested list: `val` receives the leaves in order and `tree` the
# PyTreeDef that records how they were nested.
a = [1, 2, 4, [1, 2, 4]]
val, tree = tree_util.tree_flatten(a)
# Last expression of the cell, so the (leaves, treedef) pair is displayed.
val, tree

([1, 2, 4, 1, 2, 4], PyTreeDef([*, *, *, [*, *, *]]))

In [8]:
# Square every leaf, then rebuild the original nesting from the treedef.
# The leaves are materialised as a list rather than a one-shot `map` iterator,
# so `transform_val` can still be inspected or unflattened more than once.
transform_val = [x**2 for x in val]
tree_util.tree_unflatten(tree, transform_val)

[1, 4, 16, [1, 4, 16]]

In [10]:
from collections import namedtuple
import jax.numpy as jnp

# Namedtuple type used as one of the sample containers.
Points = namedtuple("Points", "x y")
# One sample pytree per container kind, plus a mixed one at the end.
examples_container = [
    [12, 34, [1, 2, 4], 3],  # list with a nested list
    {"a": 3, "b": {"c": 4, "d": 5}},  # nested dicts
    (1, 24, 5, (2, 35)),  # tuple with a nested tuple
    Points(2, 4),  # namedtuple instance
    jnp.arange(4),  # a bare array: flattens to a single leaf
    # list mixing a list, tuple, dict, namedtuple and an array
    [[1, 2], (3, 4), {"a": 5, "b": 6}, Points(7, 8), jnp.arange(8, 10)],
]

In [13]:
def show_example(structured):
    """Print a pytree next to its leaves, its treedef, and the rebuilt value.

    Args:
        structured: Any value accepted by ``tree_util.tree_flatten``.

    Returns:
        None. The report is written to stdout.
    """
    # Separate the leaf values from the description of the nesting.
    flat, tree = tree_util.tree_flatten(structured)
    # Round trip: rebuild the container from the treedef and the leaves.
    unflattened = tree_util.tree_unflatten(tree, flat)
    # `{name=}` prints both the variable name and its repr.
    report = f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}\n"
    print(report)

In [14]:
# Run the flatten/unflatten round trip on every sample container.
for sample in examples_container:
    show_example(sample)

structured=[12, 34, [1, 2, 4], 3]
  flat=[12, 34, 1, 2, 4, 3]
  tree=PyTreeDef([*, *, [*, *, *], *])
  unflattened=[12, 34, [1, 2, 4], 3]

structured={'a': 3, 'b': {'c': 4, 'd': 5}}
  flat=[3, 4, 5]
  tree=PyTreeDef({'a': *, 'b': {'c': *, 'd': *}})
  unflattened={'a': 3, 'b': {'c': 4, 'd': 5}}

structured=(1, 24, 5, (2, 35))
  flat=[1, 24, 5, 2, 35]
  tree=PyTreeDef((*, *, *, (*, *)))
  unflattened=(1, 24, 5, (2, 35))

structured=Points(x=2, y=4)
  flat=[2, 4]
  tree=PyTreeDef(CustomNode(namedtuple[Points], [*, *]))
  unflattened=Points(x=2, y=4)

structured=Array([0, 1, 2, 3], dtype=int32)
  flat=[Array([0, 1, 2, 3], dtype=int32)]
  tree=PyTreeDef(*)
  unflattened=Array([0, 1, 2, 3], dtype=int32)

structured=[[1, 2], (3, 4), {'a': 5, 'b': 6}, Points(x=7, y=8), Array([8, 9], dtype=int32)]
  flat=[1, 2, 3, 4, 5, 6, 7, 8, Array([8, 9], dtype=int32)]
  tree=PyTreeDef([[*, *], (*, *), {'a': *, 'b': *}, CustomNode(namedtuple[Points], [*, *]), *])
  unflattened=[[1, 2], (3, 4), {'a': 5, 'b': 6}, Points(x=7, y=8), Array([8, 9], dtype=int32)]

# Extending pytrees with custom node types

In [17]:
from dataclasses import dataclass

@dataclass
class NotPytree:
    """Plain two-field dataclass that is not registered as a pytree node."""

    x: int
    y: int

    def __repr__(self) -> str:
        # Compact custom repr (no space after the comma); defining it here
        # takes precedence over the repr that @dataclass would generate.
        return f"NotPytree(x={self.x},y={self.y})"


In [18]:
# Evaluating an instance displays it through the custom __repr__.
NotPytree(3,4)

NotPytree(x=3,y=4)

In [21]:
tree_util.tree_flatten(NotPytree(3,4))
#? If a type is not recognized as an internal node, it is treated as a single leaf.
#* The class must be registered as a pytree node for its internal structure to be recognized.

([NotPytree(x=3,y=4)], PyTreeDef(*))

In [25]:
@tree_util.register_pytree_node_class
@dataclass
class RegisterPytree:
    """Two-field dataclass registered with JAX as a custom pytree node.

    Registration relies on the ``tree_flatten`` / ``tree_unflatten`` pair
    defined below.
    """

    x: int
    y: int

    def __repr__(self) -> str:
        # Compact custom repr (no space after the comma).
        return f"RegisterPytree(x={self.x},y={self.y})"

    def tree_flatten(self):
        """Return ``(children, aux_data)``: both fields as children, no aux data."""
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from ``children``; ``aux_data`` is not used."""
        return cls(*children)

In [26]:
# Once registered, the instance is flattened into its fields instead of
# being treated as a single leaf.
tree_util.tree_flatten(RegisterPytree(3,4))

([3, 4], PyTreeDef(CustomNode(RegisterPytree[None], [*, *])))

In [27]:
# tree_map applies the function to each leaf and rebuilds a RegisterPytree.
tree_util.tree_map(lambda x:2*x,RegisterPytree(3,4))

RegisterPytree(x=6,y=8)