# Understanding PAX's module

This tutorial shows how to build a PAX-like module for neural network training from scratch.

## Pytree

In [1]:
from copy import copy
import jax
import numpy as np
import jax.numpy as jnp
from absl import logging

In [2]:
logging.set_verbosity(logging.FATAL)

First, let's talk about _pytree_.

Pytrees are tree-like structures built from nested Python containers. Here are a few examples of pytrees:

In [3]:
a = 123
b = [1, 2, 3]
d = (1, 2, 3)
c = {"1": 1, "2": 2, "3": 3}
e = [(1, 2), "123", {"1": 1, "2": [4, 5]}]

JAX provides the `jax.tree_util.tree_flatten` function, which transforms an object into its tree representation. The representation has two parts:

- `leaves`: a list of tree leaves.
- `treedef`: information about the structure of the tree.

In [4]:
leaves, treedef = jax.tree_util.tree_flatten(e)
print("Leaves:", leaves)
print("TreeDef:", treedef)

Leaves: [1, 2, '123', 1, 4, 5]
TreeDef: PyTreeDef([(*, *), *, {'1': *, '2': [*, *]}])


.. note:: Even though a pytree can have any object at its leaves, many JAX functions such as ``jax.jit``, ``jax.lax.scan``, ``jax.grad``, etc. only support pytrees whose leaves are arrays (or values that convert to arrays, such as Python numbers).
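The restriction on leaves can be checked directly. The sketch below is illustrative only: the helper `total` and the two example trees are made up for this demonstration. It runs the same jitted function on a tree of arrays and on a tree that contains a string leaf.

```python
import jax
import jax.numpy as jnp


def total(tree):
    # Sum every leaf of the pytree.
    return sum(jax.tree_util.tree_leaves(tree))


# A pytree whose leaves are all arrays works with jax.jit.
array_tree = {"a": jnp.array(1.0), "b": [jnp.array(2.0), jnp.array(3.0)]}
result = jax.jit(total)(array_tree)

# A pytree with a string leaf is a valid pytree, but jax.jit rejects it.
mixed_tree = [(1, 2), "123"]
try:
    jax.jit(total)(mixed_tree)
    jit_accepted_string = True
except TypeError:
    jit_accepted_string = False

print(result, jit_accepted_string)
```

Flattening `mixed_tree` with `jax.tree_util.tree_flatten` still works; it is the transformation that refuses the string leaf.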

We can reverse the ``jax.tree_util.tree_flatten`` transformation with ``jax.tree_util.tree_unflatten``:

In [5]:
jax.tree_util.tree_unflatten(treedef=treedef, leaves=leaves)

[(1, 2), '123', {'1': 1, '2': [4, 5]}]
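Flattening and unflattening are most useful as a pair: flatten, transform the leaves, then rebuild the same structure. The following sketch (the variable names are illustrative) doubles every leaf by hand and compares the result with `jax.tree_util.tree_map`, which packages the same pattern.

```python
import jax

tree = [(1, 2), {"a": 3, "b": [4, 5]}]

# Manual version: flatten, transform the leaves, unflatten.
leaves, treedef = jax.tree_util.tree_flatten(tree)
doubled_leaves = [2 * leaf for leaf in leaves]
doubled = jax.tree_util.tree_unflatten(treedef, doubled_leaves)

# tree_map applies a function to every leaf and keeps the structure.
doubled_with_map = jax.tree_util.tree_map(lambda leaf: 2 * leaf, tree)

print(doubled)
print(doubled_with_map)
```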

## A simple PAX module

Now let's try to build a simple PAX module. The core idea here is that:

> **A module is also a pytree.**

To let JAX know how to flatten and unflatten a _pytree_ module, two things are required:

1. The module class implements two methods: ``tree_flatten`` and ``tree_unflatten``.
2. The module class is registered as a pytree node.

In [6]:
@jax.tree_util.register_pytree_node_class
class ModuleV0:
    def __init__(self, mylist):
        self.mylist = mylist
        self.is_training = True

    def tree_flatten(self):
        # Children are traversed further by JAX; aux_info becomes part of the treedef.
        children = [self.mylist]
        aux_info = {"is_training": self.is_training}
        return children, aux_info

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        new_object = cls.__new__(cls)
        new_object.mylist = children[0]
        new_object.is_training = aux_info["is_training"]
        return new_object

    def __repr__(self):
        name = self.__class__.__name__
        info = f"mylist={self.mylist}, is_training={self.is_training}"
        return f"{name}({info})"

The function ``jax.tree_util.register_pytree_node_class`` registers `ModuleV0` as a class of pytree nodes.

Let's try to flatten and unflatten a module.

In [7]:
mod = ModuleV0([1, 2, 3])
print(mod)
leaves, tree_def = jax.tree_util.tree_flatten(mod)
print(leaves, tree_def)
new_mod = jax.tree_util.tree_unflatten(tree_def, leaves)
new_mod

ModuleV0(mylist=[1, 2, 3], is_training=True)
[1, 2, 3] PyTreeDef(CustomNode(<class '__main__.ModuleV0'>[{'is_training': True}], [[*, *, *]]))


ModuleV0(mylist=[1, 2, 3], is_training=True)

**Note:** ``is_training`` is returned as auxiliary data, so it is part of the PyTreeDef rather than a leaf.
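One consequence is that two modules with identical leaves but different auxiliary data have different tree definitions. The sketch below uses a hypothetical `TinyModule` class that mirrors `ModuleV0`, with an `is_training` constructor argument added only for this demonstration.

```python
import jax


@jax.tree_util.register_pytree_node_class
class TinyModule:
    """Hypothetical stand-in for ModuleV0, used only in this example."""

    def __init__(self, mylist, is_training=True):
        self.mylist = mylist
        self.is_training = is_training

    def tree_flatten(self):
        return [self.mylist], {"is_training": self.is_training}

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        new_object = cls.__new__(cls)
        new_object.mylist = children[0]
        new_object.is_training = aux_info["is_training"]
        return new_object


train_mod = TinyModule([1, 2, 3], is_training=True)
eval_mod = TinyModule([1, 2, 3], is_training=False)

train_leaves, train_def = jax.tree_util.tree_flatten(train_mod)
eval_leaves, eval_def = jax.tree_util.tree_flatten(eval_mod)

same_leaves = train_leaves == eval_leaves  # the leaves are identical
same_structure = train_def == eval_def  # the aux data differs
print(same_leaves, same_structure)
```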

## Introducing the ``register_subtree`` method

OK, but our pytree module only supports the ``mylist`` and ``is_training`` attributes. A _real_ module for neural network training can have an arbitrary number of attributes.

Moreover, how can our module know that ``mylist`` is part of the subtree while ``is_training`` belongs to the tree definition?

One solution is:

1. The module keeps a set (namely, ``tree_part_names``) that records which attributes are part of the tree.
2. Users _register_ each attribute that is part of the tree.
3. Any attribute that is not registered belongs to the tree definition.

In [8]:
@jax.tree_util.register_pytree_node_class
class ModuleV1(ModuleV0):
    def __init__(self):
        self.tree_part_names = frozenset()
        self.is_training = True

    def tree_flatten(self):
        children = []
        others = []
        children_names = []

        for name, value in vars(self).items():
            if name in self.tree_part_names:
                children.append(value)
                children_names.append(name)
            else:
                others.append((name, value))
        return children, (children_names, others)

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        children_names, others = aux_info
        new_object = cls.__new__(cls)
        new_object.__dict__.update(others)
        new_object.__dict__.update(zip(children_names, children))
        return new_object

    def register_subtree(self, name, value):
        self.__dict__[name] = value
        self.tree_part_names = self.tree_part_names.union([name])

    def __init_subclass__(cls):
        jax.tree_util.register_pytree_node_class(cls)

Our new module has a ``register_subtree`` method that adds an attribute's name to the ``tree_part_names`` set.

The ``tree_flatten`` method iterates over all attributes of the object and checks whether each name is in ``tree_part_names``. If it is, the value is added to the ``children`` list; otherwise, ``(name, value)`` is added to the ``others`` list.

The ``tree_unflatten`` method combines information from ``others``, ``children_names``, and ``children`` to reconstruct the module.

**Note:**

1. We purposely use `frozenset` to guarantee that any modification of `tree_part_names` in one module does not affect other modules. (However, this is not guaranteed for other attributes of the module.)
2. `__init_subclass__` ensures that any subclass of `ModuleV1` is registered as a pytree node.
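The first point can be seen in plain Python. In this sketch, `FrozenNames` and `MutableNames` are hypothetical classes written only for the comparison: after a shallow copy, adding a name to the `frozenset` version leaves the original untouched, while the mutable `set` version leaks the change back.

```python
from copy import copy


class FrozenNames:
    def __init__(self):
        self.names = frozenset()

    def register(self, name):
        # union() builds a new frozenset; the old one is never modified.
        self.names = self.names.union([name])


class MutableNames:
    def __init__(self):
        self.names = set()

    def register(self, name):
        # add() mutates the set in place, so shallow copies share the change.
        self.names.add(name)


frozen_a = FrozenNames()
frozen_a.register("weight")
frozen_b = copy(frozen_a)
frozen_b.register("bias")

mutable_a = MutableNames()
mutable_a.register("weight")
mutable_b = copy(mutable_a)
mutable_b.register("bias")

print(sorted(frozen_a.names), sorted(frozen_b.names))
print(sorted(mutable_a.names), sorted(mutable_b.names))
```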

Let's try our module with a simple counter:

In [9]:
class Counter(ModuleV1):
    def __init__(self):
        super().__init__()

        self.register_subtree("count", 0)

    def step(self):
        self.count = self.count + 1

    def __repr__(self):
        return f"{self.__class__.__name__}(count={self.count})"

In [10]:
counter = Counter()
print(counter)
counter.step()
print(counter)
leaves, treedef = jax.tree_util.tree_flatten(counter)
print((leaves, treedef))
new_counter = jax.tree_util.tree_unflatten(treedef, leaves)
print(new_counter)

Counter(count=0)
Counter(count=1)
([1], PyTreeDef(CustomNode(<class '__main__.Counter'>[(['count'], [('tree_part_names', frozenset({'count'})), ('is_training', True)])], [*])))
Counter(count=1)


## A custom `parameters` method

Our module does not have a way to select trainable parameters. We need this feature for gradient computation.

PAX's solution is to let users implement a `parameters()` method themselves. For example:

In [11]:
class ModuleV2(ModuleV1):
    def replace(self, **kwargs):
        mod = copy(self)
        for name, value in kwargs.items():
            setattr(mod, name, value)
        return mod

    def parameters(self):
        weights = {}
        for name in self.tree_part_names:
            value = getattr(self, name)
            value = value.parameters() if isinstance(value, ModuleV2) else None
            weights[name] = value
        return self.replace(**weights)

In [12]:
class Linear(ModuleV2):
    def __init__(self):
        super().__init__()
        self.register_subtree("weight", jnp.array(1.0))
        self.register_subtree("bias", jnp.array(0.0))
        self.register_subtree("count", jnp.array(0))

    def parameters(self):
        return super().parameters().replace(weight=self.weight, bias=self.bias)

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"

In [13]:
fc = Linear()
x = 2.0
y = fc(x)
print(fc)
print(fc.parameters())

Linear(weight=1.0, bias=0.0, count=1)
Linear(weight=1.0, bias=0.0, count=None)
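Replacing non-trainable attributes with `None` works because JAX treats `None` as a pytree node with no children, so it contributes no leaves. The sketch below uses a plain dictionary in place of a module; the names `params` and `predict` are illustrative and not part of PAX.

```python
import jax
import jax.numpy as jnp

# A stand-in for the output of `parameters()`: the counter is replaced by None.
params = {"weight": jnp.array(1.0), "bias": jnp.array(0.0), "count": None}

# None contributes no leaves, so only the two trainable arrays remain.
leaves = jax.tree_util.tree_leaves(params)
num_leaves = len(leaves)


def predict(p, x):
    return p["weight"] * x + p["bias"]


# Gradients have the same structure as `params`; the None entry stays None.
grads = jax.grad(predict)(params, 2.0)
print(num_leaves, grads)
```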


However, it is a bit inconvenient that we have to implement a `parameters` method ourselves.
Below is a utility function that does the job for us.

In [14]:
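def parameters_method(*trainable_weights):
    def _parameters(self):
        values = {name: getattr(self, name) for name in trainable_weights}
        # NOTE: super(self.__class__, self) assumes the method is called on an
        # instance of the class that defines it. A subclass that merely inherits
        # this method would recurse forever.
        return super(self.__class__, self).parameters().replace(**values)

    return _parameters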

In [15]:
class Linear(ModuleV2):
    def __init__(self):
        super().__init__()
        self.register_subtree("weight", jnp.array(1.0))
        self.register_subtree("bias", jnp.array(0.0))
        self.register_subtree("count", jnp.array(0))

    parameters = parameters_method("weight", "bias")

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"

## Find and register subtrees

It is inconvenient that we have to register subtrees manually.
Instead, we can add a method that detects subtree attributes for us.

In [16]:
class ModuleV3(ModuleV2):
    def find_and_register_subtree(self):
        # An attribute is a subtree if it contains at least one array or module.
        is_pytree = lambda x: isinstance(x, (np.ndarray, jnp.ndarray, ModuleV3))
        for name, value in self.__dict__.items():
            leaves, _ = jax.tree_util.tree_flatten(value, is_leaf=is_pytree)
            if any(map(is_pytree, leaves)):
                self.register_subtree(name, value)
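The detection rule can be tried outside the module. This sketch is a simplified version that only checks for NumPy arrays; `is_array` and `contains_array` are hypothetical helpers for illustration. It flattens a value with `is_leaf` and checks whether any leaf is an array.

```python
import jax
import numpy as np


def is_array(x):
    return isinstance(x, np.ndarray)


def contains_array(value):
    # Flatten the value, stopping at arrays, then look for an array leaf.
    leaves = jax.tree_util.tree_leaves(value, is_leaf=is_array)
    return any(map(is_array, leaves))


nested = {"w": np.zeros(2), "name": "layer"}
print(contains_array(nested))  # holds an array somewhere inside
print(contains_array("hello"))  # a plain string
print(contains_array(True))  # a plain flag such as is_training
```

Under this rule, attributes such as `is_training` and `tree_part_names` are never registered, because they contain no arrays or modules.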

In [17]:
class Linear(ModuleV3):
    def __init__(self):
        super().__init__()
        self.weight = jnp.array(1.0)
        self.bias = jnp.array(0.0)
        self.count = jnp.array(0)
        self.find_and_register_subtree()

    parameters = parameters_method("weight", "bias")

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"

In [18]:
fc = Linear()
fc.tree_part_names

frozenset({'bias', 'count', 'weight'})

## Metaclass

We can avoid calling `self.find_and_register_subtree()` explicitly by using a metaclass.

In [19]:
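class ModuleMetaclass(type):
    def __call__(cls, *args, **kwargs):
        # Build and initialize the instance as usual...
        module = cls.__new__(cls, *args, **kwargs)
        cls.__init__(module, *args, **kwargs)
        # ...then register subtree attributes once __init__ has finished.
        module.find_and_register_subtree()
        return module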
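The same post-initialization pattern can be shown in plain Python, without JAX. In this sketch, `PostInitMeta`, `Base`, and `Point` are hypothetical names, and the metaclass delegates to `super().__call__` instead of calling `__new__` and `__init__` by hand. The hook runs after `__init__` has finished, so it sees every attribute that the subclass set.

```python
class PostInitMeta(type):
    def __call__(cls, *args, **kwargs):
        # Default construction: runs __new__ and then __init__.
        obj = super().__call__(*args, **kwargs)
        # Hook that runs after __init__ has set all attributes.
        obj.post_init()
        return obj


class Base(metaclass=PostInitMeta):
    def post_init(self):
        # Record the attributes that __init__ created.
        self.attribute_names = sorted(vars(self))


class Point(Base):
    def __init__(self, x, y):
        self.x = x
        self.y = y


p = Point(1, 2)
print(p.attribute_names)
```

Subclasses such as `Point` never call the hook themselves, which is the same convenience that `ModuleV4` gives its subclasses.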

In [20]:
class ModuleV4(ModuleV3, metaclass=ModuleMetaclass):
    pass

In [21]:
class Linear(ModuleV4):
    def __init__(self):
        super().__init__()
        self.weight = jnp.array(1.0)
        self.bias = jnp.array(0.0)
        self.count = jnp.array(0)

    parameters = parameters_method("weight", "bias")

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"

In [22]:
fc = Linear()
fc.tree_part_names

frozenset({'bias', 'count', 'weight'})