# Understanding PAX

This tutorial explains the inner workings of PAX. You will learn that:

1. A pytree is a tree structure used by JAX to store `ndarray`s.
2. The two most important methods of a pytree object are `tree_flatten` and `tree_unflatten`.
3. PAX modules use the dictionary `name_to_kind` to manage the structure of the pytree.

## Pytree

A pytree is a tree-like structure built from Python container objects. Here are a few examples:

In [1]:
import jax

a = 123

b = [1, 2, 3]
d = (1, 2, 3)
c = {"1": 1, "2": 2, "3": 3}

e = [(1, 2), "123", {"1": 1, "2": [4, 5]}]

leaves, treedef = jax.tree_flatten(e)
print(leaves)
print(treedef)

[1, 2, '123', 1, 4, 5]
PyTreeDef([(*, *), *, {'1': *, '2': [*, *]}])


The `jax.tree_flatten` function transforms an object into its tree representation, which consists of:

1. `leaves`: a list of tree leaves.
2. `treedef`: information about the structure of the tree.


**Note**: Even though a pytree can have any object at its leaves, many JAX functions such as ``jax.jit``, ``jax.lax.scan``, and ``jax.grad`` only support pytrees with `ndarray` leaves.
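
As a small illustration of this note, the sketch below passes a dictionary of arrays to ``jax.grad``. The names `params` and `loss` are made up for this example. The gradient comes back as a pytree with the same structure as the input.

```python
import jax
import jax.numpy as jnp

# A pytree whose leaves are all arrays (names are illustrative only).
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}


def loss(p):
    # Any scalar-valued function of the leaves works here.
    return jnp.sum(p["w"] ** 2) + p["b"] ** 2


grads = jax.grad(loss)(params)
print(grads)  # a dict with the same keys as `params`
```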

We can reverse the ``jax.tree_flatten`` transformation with ``jax.tree_unflatten``:

In [2]:
print(jax.tree_unflatten(treedef=treedef, leaves=leaves))

[(1, 2), '123', {'1': 1, '2': [4, 5]}]
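
Flattening and unflattening are most useful together: the leaves can be transformed as a flat list and then put back into the original structure. The following is a minimal sketch of that pattern. It uses the same functions under their `jax.tree_util` names, and the variable names are chosen only for this example.

```python
import jax

tree = [(1, 2), {"a": 3, "b": [4, 5]}]
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Transform the flat list of leaves, then rebuild the original structure.
doubled_leaves = [2 * x for x in leaves]
doubled_tree = jax.tree_util.tree_unflatten(treedef, doubled_leaves)
print(doubled_tree)
```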


## A simple PAX module

Now let us try to build a simple PAX module from scratch!

The core idea is that **a module is also a pytree.**

For JAX to know how to flatten and unflatten an object, two conditions must hold:

1. The object's class implements the two methods ``tree_flatten`` and ``tree_unflatten``.
2. The class is registered as a pytree node. The function ``jax.tree_util.register_pytree_node_class`` is what we need.

In [3]:
class Module:
    def __init__(self, mylist):
        self.mylist = mylist
        self.is_training = True

    def tree_flatten(self):
        children = [self.mylist]
        aux_info = {"is_training": self.is_training}
        return children, aux_info

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        new_object = cls.__new__(cls)
        new_object.mylist = children[0]
        new_object.is_training = aux_info["is_training"]
        return new_object

    def __str__(self):
        return (
            self.__class__.__name__
            + f" :: mylist={self.mylist}  is_training={self.is_training}"
        )


jax.tree_util.register_pytree_node_class(Module)

__main__.Module

Let's try to flatten and unflatten a module.

In [4]:
mod = Module([1, 2, 3])
leaves, tree_def = jax.tree_flatten(mod)
print(leaves, tree_def)
new_mod = jax.tree_unflatten(tree_def, leaves)
print(new_mod)

[1, 2, 3] PyTreeDef(CustomNode(<class '__main__.Module'>[{'is_training': True}], [[*, *, *]]))
Module :: mylist=[1, 2, 3]  is_training=True


**Note:** ``is_training`` is kept as part of the PyTreeDef.
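
To make the consequence of this note concrete, here is a small sketch with a hypothetical `TinyModule` class that stands in for the `Module` above (it stores `is_training` in a tuple instead of a dictionary). Two objects that differ only in `is_training` produce the same leaves but different tree definitions.

```python
import jax


@jax.tree_util.register_pytree_node_class
class TinyModule:
    """Hypothetical stand-in for the `Module` class above."""

    def __init__(self, mylist, is_training=True):
        self.mylist = mylist
        self.is_training = is_training

    def tree_flatten(self):
        # Children become leaves; aux data becomes part of the treedef.
        return [self.mylist], (self.is_training,)

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        return cls(children[0], is_training=aux_info[0])


train_leaves, train_def = jax.tree_util.tree_flatten(TinyModule([1, 2, 3], True))
eval_leaves, eval_def = jax.tree_util.tree_flatten(TinyModule([1, 2, 3], False))
other_leaves, other_def = jax.tree_util.tree_flatten(TinyModule([7, 8, 9], True))

print(train_leaves == eval_leaves)  # the leaves are identical
print(train_def == eval_def)        # aux data differs, so the treedefs differ
print(train_def == other_def)       # same structure and aux data
```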

## Introducing the ``register_subtree`` method

OK, but our pytree ``Module`` only supports the ``mylist`` and ``is_training`` attributes. A _real_ module for neural network training can have an arbitrary number of attributes.

Moreover, how can our module know that ``mylist`` is part of the subtree while ``is_training`` belongs to the tree definition?

One solution is:

1. Keep a set (named ``tree_part_names``) that records which attributes are part of the tree.
2. Require users to register each attribute that is part of the tree.
3. Treat any attribute that is not registered as part of the tree definition.

In [5]:
class Module:
    def __init__(self):
        self.tree_part_names = frozenset()
        self.is_training = True

    def tree_flatten(self):
        children = []
        others = []
        children_names = []

        for name, value in vars(self).items():
            if name in self.tree_part_names:
                children.append(value)
                children_names.append(name)
            else:
                others.append((name, value))
        return children, (children_names, others)

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        children_names, others = aux_info
        new_object = cls.__new__(cls)
        new_object.__dict__.update(others)
        new_object.__dict__.update(zip(children_names, children))
        return new_object

    def register_subtree(self, name, value):
        self.__dict__[name] = value
        self.tree_part_names = self.tree_part_names.union([name])

    def __init_subclass__(cls):
        jax.tree_util.register_pytree_node_class(cls)

Our new module has a ``register_subtree`` method that adds the attribute name to the ``tree_part_names`` set.

The ``tree_flatten`` method iterates over all attributes of the object and checks whether each name is in ``tree_part_names``. If it is, the value is added to the ``children`` list; otherwise, ``(name, value)`` is added to the ``others`` list.

The ``tree_unflatten`` method combines information from ``others``, ``children_names``, and ``children`` to reconstruct the object.

**Note:** 

1. We purposely use `frozenset` so that modifying `tree_part_names` in one object does not affect other objects. However, this is not guaranteed for other attributes of the module.
2. `__init_subclass__` ensures that any subclass of `Module` is also registered as a pytree.
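
The first point can be illustrated without JAX. ``tree_unflatten`` copies attribute references into the new object, much like a shallow copy, so a mutable set would be shared between the old and the new object. The sketch below uses a hypothetical `Holder` class and `copy.copy` to mimic that situation.

```python
import copy


class Holder:
    """Hypothetical container used only to illustrate the point."""

    def __init__(self, names):
        self.names = names


# With a mutable set, a shallow copy shares the same set object.
a = Holder(set())
b = copy.copy(a)
b.names.add("count")
print(a.names)  # `a` sees the change made through `b`

# With a frozenset, updating means rebinding to a new object.
c = Holder(frozenset())
d = copy.copy(c)
d.names = d.names.union(["count"])
print(c.names)  # `c` is unaffected
```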

Let's try our module:

In [6]:
class Counter(Module):
    def __init__(self):
        super().__init__()

        self.register_subtree("count", 0)

    def step(self):
        self.count = self.count + 1

    def __str__(self):
        return self.__class__.__name__ + f" :: count={self.count}"


counter = Counter()
counter.step()
print(counter)

leaves, treedef = jax.tree_flatten(counter)
print((leaves, treedef))

new_counter = jax.tree_unflatten(treedef, leaves)
print(new_counter)

Counter :: count=1
([1], PyTreeDef(CustomNode(<class '__main__.Counter'>[(['count'], [('tree_part_names', frozenset({'count'})), ('is_training', True)])], [*])))
Counter :: count=1
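
Because the counter is a pytree, generic pytree utilities should work on it as well. The sketch below repeats a condensed version of the classes above so that it runs on its own, then applies `jax.tree_util.tree_map` to a counter. Storing the auxiliary data as tuples is a choice made for this sketch, not something the code above does.

```python
import jax


class Module:
    def __init__(self):
        self.tree_part_names = frozenset()

    def tree_flatten(self):
        children, children_names, others = [], [], []
        for name, value in vars(self).items():
            if name in self.tree_part_names:
                children.append(value)
                children_names.append(name)
            else:
                others.append((name, value))
        # Tuples keep the auxiliary data hashable (a choice made for this sketch).
        return children, (tuple(children_names), tuple(others))

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        children_names, others = aux_info
        new_object = cls.__new__(cls)
        new_object.__dict__.update(others)
        new_object.__dict__.update(zip(children_names, children))
        return new_object

    def register_subtree(self, name, value):
        self.__dict__[name] = value
        self.tree_part_names = self.tree_part_names.union([name])

    def __init_subclass__(cls):
        jax.tree_util.register_pytree_node_class(cls)


class Counter(Module):
    def __init__(self):
        super().__init__()
        self.register_subtree("count", 0)


counter = Counter()
# tree_map flattens the module, applies the function to each leaf, and unflattens.
shifted = jax.tree_util.tree_map(lambda x: x + 10, counter)
print(shifted.count, counter.count)
```

The result is a new `Counter` object; the original counter is left untouched.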


## Parameter and State

However, our module still cannot distinguish trainable parameters from non-trainable states. PAX's solution is to use a dictionary instead of a set.
This dictionary, named `name_to_kind`, maps each attribute name to its kind, which can be `trainable parameter`, `non-trainable state`, or `others`.

In [7]:
from types import MappingProxyType
from typing import OrderedDict
from enum import Enum


class PaxKind(Enum):
    STATE = 1
    PARAMETER = 2
    OTHERS = -1


class Module:
    name_to_kind: OrderedDict[str, PaxKind]
    is_training: bool

    def __init__(self):
        self.name_to_kind = MappingProxyType(OrderedDict())
        self.is_training = True

    def tree_flatten(self):
        children = []
        others = []
        children_names = []

        for name, value in vars(self).items():
            if self.name_to_kind.get(name, PaxKind.OTHERS) != PaxKind.OTHERS:
                children.append(value)
                children_names.append(name)
            else:
                others.append((name, value))
        return children, (children_names, others)

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        children_names, others = aux_info
        new_object = cls.__new__(cls)
        new_object.__dict__.update(others)
        new_object.__dict__.update(zip(children_names, children))
        return new_object

    def _update_name_to_kind_dict(self, name: str, value: PaxKind):
        new_dict = OrderedDict(self.name_to_kind)
        new_dict[name] = value
        self.name_to_kind = MappingProxyType(new_dict)

    def register_state(self, name, value):
        self.__dict__[name] = value
        self._update_name_to_kind_dict(name, PaxKind.STATE)

    def register_parameter(self, name, value):
        self.__dict__[name] = value
        self._update_name_to_kind_dict(name, PaxKind.PARAMETER)

    def __init_subclass__(cls):
        jax.tree_util.register_pytree_node_class(cls)


jax.tree_util.register_pytree_node_class(Module)

__main__.Module

Our module now has two methods to register tree parts:

1. `register_parameter` registers a trainable parameter.
2. `register_state` registers a non-trainable state.

**Note**: We are using `MappingProxyType` to create a read-only view of `name_to_kind`. The method `_update_name_to_kind_dict` creates a new dictionary every time `name_to_kind` is updated.
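
The behaviour described in this note can be checked in isolation. The following sketch uses plain strings in place of `PaxKind` values; it shows that writing through the proxy fails, and that the copy-on-update pattern leaves the old view unchanged.

```python
from collections import OrderedDict
from types import MappingProxyType

kinds = MappingProxyType(OrderedDict(count="state"))

# Direct item assignment on the proxy is rejected.
try:
    kinds["bias"] = "parameter"
    write_failed = False
except TypeError:
    write_failed = True
print(write_failed)

# Copy-on-update: build a new dictionary and wrap it in a new proxy.
updated = OrderedDict(kinds)
updated["bias"] = "parameter"
new_kinds = MappingProxyType(updated)

print(list(kinds))      # the old view is unchanged
print(list(new_kinds))  # the new view contains both names
```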

Now, let's try it out.

In [8]:
class Counter(Module):
    def __init__(self):
        super().__init__()

        self.register_state("count", 0)
        self.register_parameter("bias", 0.0)

    def step(self):
        self.count = self.count + 1

    def __str__(self):
        return self.__class__.__name__ + f" :: count={self.count}  bias={self.bias}"


counter = Counter()
counter.step()
print(counter)

leaves, treedef = jax.tree_flatten(counter)
print((leaves, treedef))

new_counter = jax.tree_unflatten(treedef, leaves)
print(new_counter)

Counter :: count=1  bias=0.0
([1, 0.0], PyTreeDef(CustomNode(<class '__main__.Counter'>[(['count', 'bias'], [('name_to_kind', mappingproxy(OrderedDict([('count', <PaxKind.STATE: 1>), ('bias', <PaxKind.PARAMETER: 2>)]))), ('is_training', True)])], [*, *])))
Counter :: count=1  bias=0.0
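
One thing `name_to_kind` makes possible is selecting attributes by kind, for example to pick out only the trainable parameters. The sketch below is not PAX's actual API: `Kind`, `TinyCounter`, and `select_by_kind` are hypothetical names, and the class keeps only the bookkeeping needed for the illustration.

```python
from enum import Enum


class Kind(Enum):
    STATE = 1
    PARAMETER = 2


class TinyCounter:
    """Hypothetical stand-in that only keeps the `name_to_kind` bookkeeping."""

    def __init__(self):
        self.name_to_kind = {"count": Kind.STATE, "bias": Kind.PARAMETER}
        self.count = 0
        self.bias = 0.0


def select_by_kind(module, kind):
    # Collect the attributes whose registered kind matches `kind`.
    return {
        name: getattr(module, name)
        for name, k in module.name_to_kind.items()
        if k == kind
    }


counter = TinyCounter()
print(select_by_kind(counter, Kind.PARAMETER))
print(select_by_kind(counter, Kind.STATE))
```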


## Sub-module

Last but not least, we need a way to register and detect whether an attribute is a sub-module of the current module.
PAX provides two ways to achieve that:

1. To register: the `register_module` method registers an attribute with a new kind, `PaxKind.MODULE`.
2. To detect: if a user assigns an instance of `Module` to an attribute whose name is not in `name_to_kind`, the attribute is automatically registered as `PaxKind.MODULE`.

In [9]:
from typing import Any


class PaxKind(Enum):
    STATE = 1
    PARAMETER = 2
    MODULE = 3
    OTHERS = -1


class ModuleV2(Module):
    def register_module(self, name, value):
        self.__dict__[name] = value
        # `name_to_kind` is a read-only view, so go through the update helper.
        self._update_name_to_kind_dict(name, PaxKind.MODULE)

    def __setattr__(self, name: str, value: Any) -> None:
        # Automatically register a sub-module assigned to an unregistered name.
        if isinstance(value, ModuleV2) and name not in self.name_to_kind:
            self._update_name_to_kind_dict(name, PaxKind.MODULE)
        return super().__setattr__(name, value)
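
The automatic detection can be tried out with a stripped-down, JAX-free sketch. The classes `Kind`, `TinyModule`, `Linear`, and `Net` below are hypothetical and keep only the registration bookkeeping (no flattening), so this illustrates the mechanism rather than PAX's implementation.

```python
from collections import OrderedDict
from enum import Enum
from types import MappingProxyType


class Kind(Enum):
    PARAMETER = 2
    MODULE = 3


class TinyModule:
    """Hypothetical, stripped-down module: registration bookkeeping only."""

    def __init__(self):
        self.name_to_kind = MappingProxyType(OrderedDict())

    def _update_name_to_kind_dict(self, name, kind):
        new_dict = OrderedDict(self.name_to_kind)
        new_dict[name] = kind
        self.name_to_kind = MappingProxyType(new_dict)

    def register_parameter(self, name, value):
        self.__dict__[name] = value
        self._update_name_to_kind_dict(name, Kind.PARAMETER)

    def __setattr__(self, name, value):
        # Auto-detect sub-modules on plain attribute assignment.
        if isinstance(value, TinyModule) and name not in self.name_to_kind:
            self._update_name_to_kind_dict(name, Kind.MODULE)
        super().__setattr__(name, value)


class Linear(TinyModule):
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", 1.0)


class Net(TinyModule):
    def __init__(self):
        super().__init__()
        self.layer = Linear()  # registered as Kind.MODULE automatically
        self.name = "net"      # not a module, so it is left unregistered


net = Net()
print(dict(net.name_to_kind))
```

Only `layer` ends up in `name_to_kind`; the plain string attribute `name` is left out and would therefore belong to the tree definition.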