# Improve performance

`jax.jit` eliminates almost all of the performance penalties related to Pax, but a small cost remains: `tree_flatten` and `tree_unflatten` are called on the inputs and outputs of a jitted function every time it is invoked.

In this tutorial, we measure this overhead and introduce practices that reduce it.

.. note::
    Pax's performance penalties are usually less than 1% of the training time. Most of the time, we can ignore them.

Let us start with a simple script that trains a ResNet50 classifier.

In [None]:
# uncomment the following line to install pax
# !pip install -q git+https://github.com/NTT123/pax.git

In [1]:
import pax, jax, opax
import jax.numpy as jnp
from pax.nets import ResNet50

def loss_fn(model: ResNet50, inputs):
    images, labels = inputs
    log_pr = jax.nn.log_softmax(model(images), axis=-1)
    # cross-entropy is the negative mean log-likelihood of the true labels
    loss = -jnp.mean(jnp.sum(jax.nn.one_hot(labels, num_classes=10) * log_pr, axis=-1))
    return loss, (loss, model)

@pax.jit
def update(model, optimizer, inputs):
    grads, (loss, model) = pax.grad(loss_fn, has_aux=True)(model, inputs)
    model, optimizer = pax.apply_gradients(model, optimizer, grads=grads)
    return model, optimizer, loss

net = ResNet50(3, 10)
optimizer = opax.adam(1e-4)(net.parameters())

rng_key = jax.random.PRNGKey(42)
img = jax.random.normal(rng_key,  (1, 3, 64, 64))
label = jax.random.randint(rng_key, (1,), 0, 10)
# net, optimizer, loss = update(net, optimizer, (img, label))



In [2]:
import time 
start = time.perf_counter()
for i in range(10_000):
    a, b = jax.tree_flatten((net, optimizer))
    net, optimizer = jax.tree_unflatten(b, a)
end = time.perf_counter()
print("Duration:", end-start)

Duration: 31.03061650000018


It takes 31.0 seconds to execute 10,000 iterations of `tree_flatten` and `tree_unflatten` on the tuple `(net, optimizer)`.

This is approximately the extra time we have to wait when training a ResNet50 network with an `opax.adam` optimizer for 10,000 iterations, or about 3.1 ms per iteration.
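The measurement above can be wrapped in a small reusable helper. The following is a minimal sketch, not part of Pax: the name `time_roundtrip` and the nested dictionary `toy_tree` are hypothetical stand-ins for `(net, optimizer)`, and the tree functions are taken from the `jax.tree_util` namespace.

```python
import time

import jax
import jax.numpy as jnp


def time_roundtrip(tree, num_iters=1_000):
    """Return the seconds spent on `num_iters` flatten/unflatten round trips."""
    start = time.perf_counter()
    for _ in range(num_iters):
        leaves, treedef = jax.tree_util.tree_flatten(tree)
        tree = jax.tree_util.tree_unflatten(treedef, leaves)
    return time.perf_counter() - start


# Hypothetical stand-in for (net, optimizer): a nested dict with 200 small arrays.
toy_tree = {
    f"layer{i}": {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))} for i in range(100)
}

num_iters = 1_000
duration = time_roundtrip(toy_tree, num_iters)
print(f"Total: {duration:.3f} s, per call: {duration / num_iters * 1e6:.1f} microseconds")
```

Dividing the total by the number of iterations gives the per-call overhead, which is the number to compare against the duration of one training step.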

## Flatten optimizer

One easy way to reduce the time is to use the `flatten` mode supported by `opax` optimizers.

In [3]:
optimizer = opax.adam(1e-4)(net.parameters(), flatten=True)

In this mode, the optimizer will automatically flatten the parameters and gradients to a list of leaves instead of dealing with the full tree structure. This reduces the `flatten` and `unflatten` time of the optimizer to almost zero.

However, we can no longer access the optimizer's pytree objects directly.
Fortunately, we rarely need to, and the flattened list can easily be converted back to a pytree object with the `jax.tree_unflatten` function.
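As an illustration of that conversion, here is a minimal sketch in plain JAX. The `state` dictionary is a hypothetical optimizer state, not the actual structure that `opax` uses, and the sketch assumes the tree definition was saved when the tree was flattened.

```python
import jax
import jax.numpy as jnp

# Hypothetical optimizer state stored as a nested pytree.
state = {"mu": {"w": jnp.zeros((3,)), "b": jnp.zeros(())}, "count": jnp.array(0)}

# Flatten once and keep the treedef; the leaves form a plain Python list.
leaves, treedef = jax.tree_util.tree_flatten(state)

# Rebuild the original structure whenever it needs to be inspected.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
print(restored["mu"]["w"].shape)
```

The tree definition only describes the structure, so it can be reused for any list of leaves with the same layout.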


In [4]:
import time 
start = time.perf_counter()
for i in range(10_000):
    a, b = jax.tree_flatten((net, optimizer))
    net, optimizer = jax.tree_unflatten(b, a)
end = time.perf_counter()
print("Duration:", end-start)

Duration: 7.258846299999277


In [5]:
start = time.perf_counter()
for i in range(10_000):
    a, b = jax.tree_flatten(optimizer)
    optimizer = jax.tree_unflatten(b, a)
end = time.perf_counter()
print("Duration:", end-start)

Duration: 0.38130159999855096


With `flatten=True`, the time drops to about 7.3 seconds, and the time to `flatten`/`unflatten` the `optimizer` alone is close to zero (0.38 seconds).

## Multi-step update function

Another way to reduce the `flatten`/`unflatten` time is to execute multiple update steps inside a single jitted function.

In [6]:
num_steps = 10

@pax.jit
def multistep_update(model, optimizer, inputs):
    def _step(m_o, i):
        m, o, aux = update(*m_o, i)
        return (m, o), aux
    (model, optimizer), losses = pax.utils.scan(_step, (model, optimizer), inputs)
    return model, optimizer, jnp.mean(losses)

multistep_img = jax.random.normal(rng_key,  (num_steps, 1, 3, 64, 64))
multistep_label = jax.random.randint(rng_key, (num_steps, 1,), 0, 10)
# net, optimizer, loss = multistep_update(net, optimizer, (multistep_img, multistep_label))

The `multistep_update` function executes multiple update steps in a single call, so the inputs and outputs are flattened and unflattened once per call instead of once per step.
With `num_steps=10`, the overhead is reduced by a factor of about 10.
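The same pattern can be sketched without Pax using `jax.lax.scan`. This is only an illustration under simplifying assumptions: the linear model, the fixed learning rate, and the names `sgd_step` and `multistep_sgd` are hypothetical and do not come from Pax or opax.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, batch):
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)


def sgd_step(w, batch, lr=0.1):
    # One plain SGD step; returns the new carry and the per-step output.
    loss, grad = jax.value_and_grad(loss_fn)(w, batch)
    return w - lr * grad, loss


@jax.jit
def multistep_sgd(w, batches):
    # `batches` has a leading axis of size num_steps; scan consumes one slice per step.
    w, losses = jax.lax.scan(sgd_step, w, batches)
    return w, losses


num_steps = 10
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (num_steps, 8, 4))
true_w = jax.random.normal(key_w, (4,))
ys = xs @ true_w

# One call runs all `num_steps` updates and crosses the jit boundary only once.
w, losses = multistep_sgd(jnp.zeros((4,)), (xs, ys))
print(losses.shape)
```

Stacking the batches along a leading axis is what lets a single call consume `num_steps` batches.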

.. note::
    Executing multiple update steps inside a jitted function is also very useful on TPU, where it significantly reduces the communication cost between the CPU host and the TPU cores.

## Flatten model

We have reduced the time to `flatten`/`unflatten` the optimizer to almost zero. We can do the same thing for the model too.

The idea is simple: we move `flatten` and `unflatten` inside the update function.
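Before turning to the Pax version, the idea can be sketched in plain JAX on a dictionary of parameters. The names `flat_update` and `params` and the fixed learning rate are hypothetical. The tree definition is passed as a static argument, which assumes it is hashable (JAX tree definitions are), while the caller only ever handles a flat list of leaves.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@partial(jax.jit, static_argnums=0)
def flat_update(treedef, leaves, batch):
    # Rebuild the pytree inside the jitted function.
    params = jax.tree_util.tree_unflatten(treedef, leaves)
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    # Return only the leaves so the caller keeps working with a flat list.
    return jax.tree_util.tree_leaves(params), loss


params = {"w": jnp.zeros((4,)), "b": jnp.zeros(())}
leaves, treedef = jax.tree_util.tree_flatten(params)
batch = (jnp.ones((8, 4)), jnp.ones((8,)))

leaves, loss = flat_update(treedef, leaves, batch)
print(loss)
```

Because the inputs and outputs of `flat_update` are already flat lists, flattening them at the jit boundary is cheap.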

In [7]:
from functools import partial

@partial(pax.jit, static_argnums=0)
def flatten_update(model_def, model_leaves_and_optimizer, inputs):
    model_leaves, optimizer = model_leaves_and_optimizer
    model = jax.tree_unflatten(model_def, model_leaves)
    # loss_fn takes (model, inputs), as in the `update` function above
    grads, (loss, model) = pax.grad(loss_fn, has_aux=True)(model, inputs)
    model, optimizer = pax.apply_gradients(model, optimizer, grads=grads)
    return (jax.tree_leaves(model), optimizer), loss

In [8]:
net_leaves, net_def = jax.tree_flatten(net)
# (net_leaves, optimizer), loss = flatten_update(net_def, (net_leaves, optimizer), (img, label))

In [9]:
start = time.perf_counter()
for i in range(10_000):
    a, b = jax.tree_flatten((net_leaves, optimizer))
    (net_leaves, optimizer) = jax.tree_unflatten(b, a)
end = time.perf_counter()
print("Duration:", end-start)

Duration: 0.6903635999988182


We now wait only an extra `0.69` seconds when training a ResNet50 for 10,000 steps.

However, we have to manually recreate the model from its leaves and its tree definition (`net_def`).

In [10]:
net = jax.tree_unflatten(net_def, net_leaves)

Pax provides similar functionality with `pax.nn.FlattenModule`. It creates a new module in which all parameters and states are flattened.

In [None]:
flatten_net = pax.flatten_module(net)

# to recreate the original module
net = flatten_net.unflatten() 

The functionality of `flatten_module` is limited, as it is designed for performance purposes only.