# Understanding PAX's module

This tutorial shows how to build a PAX-like module for neural network training from scratch.

## Pytree

In [1]:
import jax
from types import MappingProxyType
from typing import OrderedDict, Any
from enum import Enum

First, let's talk about _pytrees_.

Pytrees are tree-like structures built from Python containers such as lists, tuples, and dicts. Here are a few examples of pytrees:

In [2]:
a = 123
b = [1, 2, 3]
d = (1, 2, 3)
c = {"1": 1, "2": 2, "3": 3}
e = [(1, 2), "123", {"1": 1, "2": [4, 5]}]

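JAX provides the `jax.tree_flatten` function (also available as `jax.tree_util.tree_flatten`, the name to use on recent JAX releases), which transforms an object into its tree representation. The representation consists of: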

- `leaves`: a list of tree leaves.
- `treedef`: information about the structure of the tree.

In [3]:
leaves, treedef = jax.tree_flatten(e)
print("Leaves:", leaves)
print("TreeDef:", treedef)

Leaves: [1, 2, '123', 1, 4, 5]
TreeDef: PyTreeDef([(*, *), *, {'1': *, '2': [*, *]}])


**Note:** Even though a pytree can have any object at its leaves, many JAX functions such as ``jax.jit``, ``jax.lax.scan``, ``jax.grad``, etc. only support pytrees with `ndarray` leaves.
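
As a minimal sketch of the point in this note, the following passes a dict pytree with array leaves through `jax.grad`. The `params` dict and the `loss` function are made up for illustration; the gradient comes back as a pytree with the same structure as `params`.

```python
import jax
import jax.numpy as jnp

# A pytree (dict) whose leaves are all arrays.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}


def loss(p, x):
    # A toy scalar function of the parameters (illustrative only).
    return jnp.sum(p["w"] * x) + p["b"]


# jax.grad differentiates with respect to the first argument, the whole pytree.
grads = jax.grad(loss)(params, jnp.array([3.0, 4.0]))
print(grads)
```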

We can reverse the ``jax.tree_flatten`` transformation with ``jax.tree_unflatten``:

In [4]:
jax.tree_unflatten(treedef=treedef, leaves=leaves)

[(1, 2), '123', {'1': 1, '2': [4, 5]}]
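
The flatten and unflatten pair is what makes leaf-wise transformations possible. The sketch below doubles every leaf by hand and compares the result with `jax.tree_util.tree_map`. It uses the `jax.tree_util` spellings of the functions on the assumption that they are available in the installed JAX version, and the `tree` value is made up for illustration.

```python
import jax

tree = [(1, 2), {"a": 3, "b": [4, 5]}]

# Manual route: flatten, transform the leaves, rebuild with the same treedef.
leaves, treedef = jax.tree_util.tree_flatten(tree)
doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])

# tree_map does the same thing in one call.
same = jax.tree_util.tree_map(lambda x: 2 * x, tree)

print(doubled)
print(same)
```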

## A simple PAX module

Now let's try to build a simple PAX module. The core idea here is that:

> **A module is also a pytree.**

To let JAX know how to flatten and unflatten a module as a _pytree_, the module's class must satisfy two requirements:

1. It implements two methods: ``tree_flatten`` and ``tree_unflatten``.
2. It is registered as a pytree node.

In [5]:
@jax.tree_util.register_pytree_node_class
class Module:
    def __init__(self, mylist):
        self.mylist = mylist
        self.is_training = True

    def tree_flatten(self):
        # Children are the subtrees JAX recurses into; aux_info is static metadata.
        children = [self.mylist]
        aux_info = {"is_training": self.is_training}
        return children, aux_info

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        # Bypass __init__ and restore the attributes directly.
        new_object = cls.__new__(cls)
        new_object.mylist = children[0]
        new_object.is_training = aux_info["is_training"]
        return new_object

    def __str__(self):
        name = self.__class__.__name__
        info = f"mylist={self.mylist}, is_training={self.is_training}"
        return f"{name}({info})"

The function ``jax.tree_util.register_pytree_node_class`` registers `Module` as a class of pytree nodes.

Let's try to flatten and unflatten a module.

In [6]:
mod = Module([1, 2, 3])
print(mod)
leaves, tree_def = jax.tree_flatten(mod)
print(leaves, tree_def)
new_mod = jax.tree_unflatten(tree_def, leaves)
print(new_mod)

Module(mylist=[1, 2, 3], is_training=True)
[1, 2, 3] PyTreeDef(CustomNode(<class '__main__.Module'>[{'is_training': True}], [[*, *, *]]))
Module(mylist=[1, 2, 3], is_training=True)


**Note:** ``is_training`` is stored as auxiliary data, so it is part of the PyTreeDef rather than a leaf.
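
One consequence, sketched below: two modules with identical leaves but different ``is_training`` flags produce different tree definitions. The class is a trimmed copy of the one above, with an `is_training` constructor argument added purely for this illustration, and it uses the `jax.tree_util` spellings of the flatten functions.

```python
import jax


@jax.tree_util.register_pytree_node_class
class Module:
    # is_training as a constructor argument is an assumption made for this sketch.
    def __init__(self, mylist, is_training=True):
        self.mylist = mylist
        self.is_training = is_training

    def tree_flatten(self):
        return [self.mylist], {"is_training": self.is_training}

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        new_object = cls.__new__(cls)
        new_object.mylist = children[0]
        new_object.is_training = aux_info["is_training"]
        return new_object


train_mod = Module([1, 2, 3], is_training=True)
eval_mod = Module([1, 2, 3], is_training=False)

train_leaves, train_def = jax.tree_util.tree_flatten(train_mod)
eval_leaves, eval_def = jax.tree_util.tree_flatten(eval_mod)

print(train_leaves == eval_leaves)  # the leaves are the same
print(train_def == eval_def)        # the tree definitions are not

# The flag travels with the treedef, not with the leaves.
rebuilt = jax.tree_util.tree_unflatten(eval_def, train_leaves)
print(rebuilt.is_training)
```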

## Introducing the ``register_subtree`` method

OK, but our pytree module only supports the ``mylist`` and ``is_training`` attributes. A _real_ module for neural network training can have an arbitrary number of attributes.

Moreover, how can our module know that ``mylist`` is part of the subtree while ``is_training`` belongs to the tree definition?

One solution is:

1. The module keeps a set (namely, ``tree_part_names``) of the attribute names that are part of the tree.
2. Users _register_ each attribute that should be part of the tree.
3. Any attribute that is not registered belongs to the tree definition.

In [7]:
@jax.tree_util.register_pytree_node_class
class Module:
    def __init__(self):
        self.tree_part_names = frozenset()
        self.is_training = True

    def tree_flatten(self):
        children = []
        others = []
        children_names = []

        for name, value in vars(self).items():
            if name in self.tree_part_names:
                children.append(value)
                children_names.append(name)
            else:
                others.append((name, value))
        return children, (children_names, others)

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        children_names, others = aux_info
        new_object = cls.__new__(cls)
        new_object.__dict__.update(others)
        new_object.__dict__.update(zip(children_names, children))
        return new_object

    def register_subtree(self, name, value):
        self.__dict__[name] = value
        self.tree_part_names = self.tree_part_names.union([name])

    def __init_subclass__(cls):
        jax.tree_util.register_pytree_node_class(cls)

Our new module has a ``register_subtree`` method that adds the attribute's name to the ``tree_part_names`` set.

The ``tree_flatten`` method iterates over all attributes of the object and checks whether each name is in ``tree_part_names``. If it is, the value is added to the ``children`` list; otherwise, ``(name, value)`` is added to the ``others`` list.

The ``tree_unflatten`` method combines information from ``others``, ``children_names``, and ``children`` to reconstruct the module.

**Note:** 

1. We purposely use `frozenset` to guarantee that modifying `tree_part_names` in one module does not affect other modules. 
(However, this is not guaranteed for other attributes of the module.)
2. `__init_subclass__` ensures that any subclass of `Module` is registered as a pytree node.
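
To see why the first point matters, recall that ``tree_unflatten`` builds a new object that reuses the attribute values stored in ``others``. The sketch below uses a shallow copy as a stand-in for that sharing; the two small classes are hypothetical and exist only to contrast a mutable `set` with a `frozenset`.

```python
import copy


class MutableNames:
    def __init__(self):
        self.names = set()

    def register(self, name):
        self.names.add(name)  # mutates the shared set in place


class FrozenNames:
    def __init__(self):
        self.names = frozenset()

    def register(self, name):
        # union() returns a new frozenset; only this object's attribute is rebound
        self.names = self.names.union([name])


a = MutableNames()
b = copy.copy(a)  # a and b share the same set object
b.register("count")
print(a.names)    # a was modified through b

c = FrozenNames()
d = copy.copy(c)  # c and d share the same (immutable) frozenset
d.register("count")
print(c.names)    # c is unaffected
```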

Let's try our module with a simple counter:

In [8]:
class Counter(Module):
    def __init__(self):
        super().__init__()

        self.register_subtree("count", 0)

    def step(self):
        self.count = self.count + 1

    def __str__(self):
        return f"{self.__class__.__name__}(count={self.count})"

In [9]:
counter = Counter()
print(counter)
counter.step()
print(counter)
leaves, treedef = jax.tree_flatten(counter)
print((leaves, treedef))
new_counter = jax.tree_unflatten(treedef, leaves)
print(new_counter)

Counter(count=0)
Counter(count=1)
([1], PyTreeDef(CustomNode(<class '__main__.Counter'>[(['count'], [('tree_part_names', frozenset({'count'})), ('is_training', True)])], [*])))
Counter(count=1)


## Parameter and State

However, our module still cannot distinguish trainable parameters from non-trainable state.
PAX's solution is a dictionary, `name_to_kind`, that maps each attribute's name to its kind: _trainable parameter_, _non-trainable state_, or _unknown_.

In [10]:
class PaxKind(Enum):
    PARAMETER = 1
    STATE = 2
    UNKNOWN = -1

In [11]:
@jax.tree_util.register_pytree_node_class
class Module:
    name_to_kind: OrderedDict[str, PaxKind]
    is_training: bool

    def __init__(self):
        self.name_to_kind = MappingProxyType(OrderedDict())
        self.is_training = True

    def tree_flatten(self):
        children = []
        others = []
        children_names = []

        for name, value in vars(self).items():
            if self.name_to_kind.get(name, PaxKind.UNKNOWN) != PaxKind.UNKNOWN:
                children.append(value)
                children_names.append(name)
            else:
                others.append((name, value))
        return children, (children_names, others)

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        children_names, others = aux_info
        new_object = cls.__new__(cls)
        new_object.__dict__.update(others)
        new_object.__dict__.update(zip(children_names, children))
        return new_object

    def _update_name_to_kind_dict(self, name: str, value: PaxKind):
        new_dict = OrderedDict(self.name_to_kind)
        new_dict[name] = value
        self.name_to_kind = MappingProxyType(new_dict)

    def register_state(self, name, value):
        self.__dict__[name] = value
        self._update_name_to_kind_dict(name, PaxKind.STATE)

    def register_parameter(self, name, value):
        self.__dict__[name] = value
        self._update_name_to_kind_dict(name, PaxKind.PARAMETER)

    def __init_subclass__(cls):
        jax.tree_util.register_pytree_node_class(cls)

Our module now has two methods to register tree parts:

1. `register_parameter` registers a trainable parameter.
2. `register_state` registers a non-trainable state.

**Note**: We are using `MappingProxyType` to create a read-only view of `name_to_kind`. The method `_update_name_to_kind_dict` creates a new dictionary every time `name_to_kind` is updated.
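
The copy-on-update pattern can be sketched in isolation with the standard library alone. The string values below stand in for `PaxKind` members to keep the example short.

```python
from collections import OrderedDict
from types import MappingProxyType

kinds = MappingProxyType(OrderedDict(count="STATE"))

# Writing through the proxy is rejected.
try:
    kinds["delta"] = "PARAMETER"
    write_failed = False
except TypeError:
    write_failed = True
print(write_failed)

# To "update", copy into a new dictionary and wrap it in a new proxy.
new_dict = OrderedDict(kinds)
new_dict["delta"] = "PARAMETER"
new_kinds = MappingProxyType(new_dict)

print(list(kinds))      # the old view is unchanged
print(list(new_kinds))  # the new view has both entries
```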

Now, let's try it out.

In [12]:
class TrainableCounter(Module):
    def __init__(self, delta=1.0):
        super().__init__()

        self.register_state("count", 0)
        self.register_parameter("delta", delta)

    def step(self):
        self.count = self.count + self.delta

    def __str__(self):
        return f"{self.__class__.__name__}(count={self.count}, delta={self.delta})"

In [13]:
counter = TrainableCounter(delta=1.0)
print(counter)
counter.step()
print(counter)
leaves, treedef = jax.tree_flatten(counter)
print((leaves, treedef))
new_counter = jax.tree_unflatten(treedef, leaves)
print(new_counter)

TrainableCounter(count=0, delta=1.0)
TrainableCounter(count=1.0, delta=1.0)
([1.0, 1.0], PyTreeDef(CustomNode(<class '__main__.TrainableCounter'>[(['count', 'delta'], [('name_to_kind', mappingproxy(OrderedDict([('count', <PaxKind.STATE: 2>), ('delta', <PaxKind.PARAMETER: 1>)]))), ('is_training', True)])], [*, *])))
TrainableCounter(count=1.0, delta=1.0)
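
With `name_to_kind` in place, picking out only the parameters or only the state becomes a dictionary lookup. The sketch below is one hypothetical way to do that: `values_of_kind` is a helper invented for this illustration and is not part of PAX, and `Module` is trimmed down to the registration logic so the example runs without JAX.

```python
from collections import OrderedDict
from enum import Enum
from types import MappingProxyType


class PaxKind(Enum):
    PARAMETER = 1
    STATE = 2
    UNKNOWN = -1


class Module:
    # Trimmed copy of the tutorial's Module: registration logic only.
    def __init__(self):
        self.name_to_kind = MappingProxyType(OrderedDict())

    def _update_name_to_kind_dict(self, name, value):
        new_dict = OrderedDict(self.name_to_kind)
        new_dict[name] = value
        self.name_to_kind = MappingProxyType(new_dict)

    def register_state(self, name, value):
        self.__dict__[name] = value
        self._update_name_to_kind_dict(name, PaxKind.STATE)

    def register_parameter(self, name, value):
        self.__dict__[name] = value
        self._update_name_to_kind_dict(name, PaxKind.PARAMETER)


class TrainableCounter(Module):
    def __init__(self, delta=1.0):
        super().__init__()
        self.register_state("count", 0)
        self.register_parameter("delta", delta)


def values_of_kind(module, kind):
    # Hypothetical helper: collect the attributes registered with the given kind.
    return {
        name: getattr(module, name)
        for name, k in module.name_to_kind.items()
        if k == kind
    }


counter = TrainableCounter(delta=1.0)
print(values_of_kind(counter, PaxKind.PARAMETER))
print(values_of_kind(counter, PaxKind.STATE))
```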


## Submodules

Last but not least, we need a way to register submodules.

PAX provides the `register_module` method that registers an attribute with the new kind `PaxKind.MODULE`.

In [14]:
class PaxKind(Enum):
    PARAMETER = 1
    STATE = 2
    MODULE = 3
    UNKNOWN = -1

In [15]:
@jax.tree_util.register_pytree_node_class
class Module:
    name_to_kind: OrderedDict[str, PaxKind]
    is_training: bool

    def __init__(self):
        self.name_to_kind = MappingProxyType(OrderedDict())
        self.is_training = True

    def tree_flatten(self):
        children = []
        others = []
        children_names = []

        for name, value in vars(self).items():
            if self.name_to_kind.get(name, PaxKind.UNKNOWN) != PaxKind.UNKNOWN:
                children.append(value)
                children_names.append(name)
            else:
                others.append((name, value))
        return children, (children_names, others)

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        children_names, others = aux_info
        new_object = cls.__new__(cls)
        new_object.__dict__.update(others)
        new_object.__dict__.update(zip(children_names, children))
        return new_object

    def _update_name_to_kind_dict(self, name: str, value: PaxKind):
        new_dict = OrderedDict(self.name_to_kind)
        new_dict[name] = value
        self.name_to_kind = MappingProxyType(new_dict)

    def register_state(self, name, value):
        self.__dict__[name] = value
        self._update_name_to_kind_dict(name, PaxKind.STATE)

    def register_parameter(self, name, value):
        self.__dict__[name] = value
        self._update_name_to_kind_dict(name, PaxKind.PARAMETER)

    def register_module(self, name, value):
        self.__dict__[name] = value
        self._update_name_to_kind_dict(name, PaxKind.MODULE)

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, Module) and name not in self.name_to_kind:
            self.register_module(name, value)
        else:
            return super().__setattr__(name, value)

    def __init_subclass__(cls):
        jax.tree_util.register_pytree_node_class(cls)

For convenience, if users assign a `Module` instance to an attribute whose name is not in `name_to_kind`, the attribute is automatically registered as `PaxKind.MODULE`.

Let's give it a try:

In [16]:
class TrainableCounter(Module):
    def __init__(self, delta=1.0):
        super().__init__()

        self.register_state("count", 0)
        self.register_parameter("delta", delta)

    def step(self):
        self.count = self.count + self.delta

    def __str__(self):
        return f"{self.__class__.__name__}(count={self.count}, delta={self.delta})"

In [17]:
class TwoCounters(Module):
    def __init__(self):
        super().__init__()

        self.counter_1 = TrainableCounter(delta=1.0)
        self.counter_2 = TrainableCounter(delta=2.0)

    def step(self):
        self.counter_1.step()
        self.counter_2.step()

In [18]:
two_counters = TwoCounters()
print(two_counters.name_to_kind)
print(two_counters.counter_1, two_counters.counter_2)

OrderedDict([('counter_1', <PaxKind.MODULE: 3>), ('counter_2', <PaxKind.MODULE: 3>)])
TrainableCounter(count=0, delta=1.0) TrainableCounter(count=0, delta=2.0)


In [19]:
leaves, treedef = jax.tree_flatten(two_counters)
print(leaves, treedef)
new_mod = jax.tree_unflatten(treedef, leaves)
print(new_mod.name_to_kind)
print(new_mod.counter_1, new_mod.counter_2)

[0, 1.0, 0, 2.0] PyTreeDef(CustomNode(<class '__main__.TwoCounters'>[(['counter_1', 'counter_2'], [('name_to_kind', mappingproxy(OrderedDict([('counter_1', <PaxKind.MODULE: 3>), ('counter_2', <PaxKind.MODULE: 3>)]))), ('is_training', True)])], [CustomNode(<class '__main__.TrainableCounter'>[(['count', 'delta'], [('name_to_kind', mappingproxy(OrderedDict([('count', <PaxKind.STATE: 2>), ('delta', <PaxKind.PARAMETER: 1>)]))), ('is_training', True)])], [*, *]), CustomNode(<class '__main__.TrainableCounter'>[(['count', 'delta'], [('name_to_kind', mappingproxy(OrderedDict([('count', <PaxKind.STATE: 2>), ('delta', <PaxKind.PARAMETER: 1>)]))), ('is_training', True)])], [*, *])]))
OrderedDict([('counter_1', <PaxKind.MODULE: 3>), ('counter_2', <PaxKind.MODULE: 3>)])
TrainableCounter(count=0, delta=1.0) TrainableCounter(count=0, delta=2.0)
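
Because submodules are registered as children, the whole nested module behaves as a single pytree. As a rough sketch, the code below applies `jax.tree_util.tree_map` to a module containing two counters. It uses a condensed copy of the `Module` class above, with a single `_register` helper that is an assumption of this sketch and stands in for the three `register_*` methods.

```python
import jax
from collections import OrderedDict
from enum import Enum
from types import MappingProxyType


class PaxKind(Enum):
    PARAMETER = 1
    STATE = 2
    MODULE = 3
    UNKNOWN = -1


@jax.tree_util.register_pytree_node_class
class Module:
    # Condensed copy of the tutorial's Module (same flatten/unflatten logic).
    def __init__(self):
        self.name_to_kind = MappingProxyType(OrderedDict())

    def tree_flatten(self):
        children, names, others = [], [], []
        for name, value in vars(self).items():
            if self.name_to_kind.get(name, PaxKind.UNKNOWN) != PaxKind.UNKNOWN:
                children.append(value)
                names.append(name)
            else:
                others.append((name, value))
        return children, (names, others)

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        names, others = aux_info
        new_object = cls.__new__(cls)
        new_object.__dict__.update(others)
        new_object.__dict__.update(zip(names, children))
        return new_object

    def _register(self, name, value, kind):
        # Hypothetical helper replacing register_state/parameter/module.
        self.__dict__[name] = value
        new_dict = OrderedDict(self.name_to_kind)
        new_dict[name] = kind
        self.name_to_kind = MappingProxyType(new_dict)

    def __setattr__(self, name, value):
        if isinstance(value, Module) and name not in self.name_to_kind:
            self._register(name, value, PaxKind.MODULE)
        else:
            super().__setattr__(name, value)

    def __init_subclass__(cls):
        jax.tree_util.register_pytree_node_class(cls)


class Counter(Module):
    def __init__(self, delta):
        super().__init__()
        self._register("count", 0.0, PaxKind.STATE)
        self._register("delta", delta, PaxKind.PARAMETER)


class TwoCounters(Module):
    def __init__(self):
        super().__init__()
        self.counter_1 = Counter(1.0)
        self.counter_2 = Counter(2.0)


two = TwoCounters()
# tree_map reaches the leaves of both submodules and returns a new module.
scaled = jax.tree_util.tree_map(lambda x: x * 10, two)
print(scaled.counter_1.delta, scaled.counter_2.delta)
print(two.counter_1.delta, two.counter_2.delta)  # the original is untouched
```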
