## The performance cost of using Pax

Even though `jax.jit` can eliminate almost all performance penalties related to Pax, a small cost remains: every call to a jitted function runs `tree_flatten` on its inputs and `tree_unflatten` on its outputs.
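To make this concrete, here is a small pure-JAX sketch (it does not use Pax) of the round trip that happens at the boundary of a jitted call. The nested dict `state` is a hypothetical stand-in for `(model, optimizer)`. Pax modules are custom pytree nodes, so their flatten/unflatten logic runs as Python code, which is likely where most of the cost measured below comes from.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for `(model, optimizer)`: a nested pytree of arrays.
state = {
    "model": {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)},
    "optimizer": {"count": jnp.array(0), "mu": [jnp.zeros((3, 2)), jnp.zeros(2)]},
}


@jax.jit
def step(state):
    # jax.jit flattens `state` into leaves before running the compiled code
    # and rebuilds the output pytree from the returned leaves afterwards.
    return jax.tree_util.tree_map(lambda x: x + 1, state)


# The same flatten/unflatten round trip, done explicitly.
leaves, treedef = jax.tree_util.tree_flatten(state)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

print(len(leaves))  # 5 array leaves
print(treedef)      # the structure that is traversed on every call
new_state = step(state)
```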

In this tutorial, we measure the performance cost of using Pax and introduce practices that reduce it.

**Note**: This cost is usually less than 1% of the training time, so most of the time we can ignore it.

Let us start with simple code for training a ResNet50 classifier.

In [None]:
import pax, jax, opax
import jax.numpy as jnp
from pax.nets import ResNet50

def loss_fn(params: ResNet50, model: ResNet50, inputs):
    images, labels = inputs
    model = model.update(params)
    log_pr = jax.nn.log_softmax(model(images), axis=-1)
    # Cross-entropy loss: the negative mean log-likelihood of the true labels.
    loss = -jnp.mean(jnp.sum(jax.nn.one_hot(labels, num_classes=10) * log_pr, axis=-1))
    return loss, (loss, model)

@pax.jit
def update(model: ResNet50, optimizer: opax.GradientTransformation, inputs):
    grads, (loss, model) = pax.grad(loss_fn, has_aux=True)(model.parameters(), model, inputs)
    model = model.update(
        optimizer.step(grads, model.parameters()),
    )
    return loss, model, optimizer

net = ResNet50(3, 10)
optimizer = opax.adam(1e-4)(net.parameters())

rng_key = jax.random.PRNGKey(42)
img = jax.random.normal(rng_key, (1, 3, 64, 64))
label = jax.random.randint(rng_key, (1,), 0, 10)
# loss, net, optimizer = update(net, optimizer, (img, label))



In [None]:
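import time

start = time.perf_counter()
for _ in range(10_000):
    # One round trip: the work done at the boundary of every jitted call.
    leaves, treedef = jax.tree_util.tree_flatten((net, optimizer))
    (net, optimizer) = jax.tree_util.tree_unflatten(treedef, leaves)
end = time.perf_counter()
print("Duration:", end - start)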

Duration: 39.52964387499999


It takes about 39.5 seconds to execute 10,000 iterations of `tree_flatten` and `tree_unflatten` on the tuple `(net, optimizer)`, or roughly 4 ms per iteration.

This is the per-step cost of using Pax with a ResNet50 network and an `opax.adam` optimizer.
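To compare configurations, it can help to report the cost per iteration rather than as a total. The helper below (`roundtrip_seconds` is a hypothetical name, not part of Pax or JAX) wraps the same loop; the toy tree of 50 nested dicts is an assumption standing in for a real model. For the measurement above, 39.53 s / 10,000 ≈ 3.95 ms per round trip.

```python
import time

import jax
import jax.numpy as jnp


def roundtrip_seconds(tree, num_iters=1_000):
    """Average wall-clock seconds of one tree_flatten + tree_unflatten pair."""
    start = time.perf_counter()
    for _ in range(num_iters):
        leaves, treedef = jax.tree_util.tree_flatten(tree)
        tree = jax.tree_util.tree_unflatten(treedef, leaves)
    return (time.perf_counter() - start) / num_iters


# Hypothetical stand-in for a model: 50 layers, each with a weight and a bias.
toy_tree = {f"layer_{i}": {"w": jnp.zeros((4, 4)), "b": jnp.zeros(4)} for i in range(50)}

per_iter = roundtrip_seconds(toy_tree)
print(f"{per_iter * 1e3:.4f} ms per flatten/unflatten round trip")
```

Calling `roundtrip_seconds((net, optimizer))` would express the measurement above in per-step units.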

### Flatten optimizer

One easy way to reduce the cost is to use the `flatten` mode supported by `opax` optimizers.

In [None]:
optimizer = opax.adam(1e-4)(net.parameters(), flatten=True)

In this mode, the optimizer will automatically flatten the parameters and gradients to a list of leaves instead of dealing with the full tree structure. This reduces the `flatten` and `unflatten` cost of the optimizer to almost zero.
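The idea behind this mode can be sketched in pure JAX: keep the optimizer state as a plain list with one buffer per parameter leaf, so there is no tree structure to traverse. This is a minimal sketch, not opax's implementation; it uses SGD with momentum instead of Adam, and `init_flat_momentum` and `flat_momentum_step` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def init_flat_momentum(params):
    # One momentum buffer per parameter leaf, stored as a plain list.
    return [jnp.zeros_like(p) for p in jax.tree_util.tree_leaves(params)]


def flat_momentum_step(grads, params, velocity, lr=0.1, beta=0.9):
    # Flatten params and grads once; all optimizer math runs on lists of leaves.
    param_leaves, treedef = jax.tree_util.tree_flatten(params)
    grad_leaves = jax.tree_util.tree_leaves(grads)
    new_velocity = [beta * v + g for v, g in zip(velocity, grad_leaves)]
    new_leaves = [p - lr * v for p, v in zip(param_leaves, new_velocity)]
    # Only the updated parameters are rebuilt into the original structure.
    return jax.tree_util.tree_unflatten(treedef, new_leaves), new_velocity


params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
velocity = init_flat_momentum(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
params, velocity = flat_momentum_step(grads, params, velocity)
```

Because `velocity` is a list of arrays, flattening it at a jit boundary only walks a single built-in list node.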

However, we can no longer access the optimizer's internal states as pytree objects.
Fortunately, we rarely need them, and the flattened list can easily be converted back to a pytree with the `jax.tree_util.tree_unflatten` function.
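A minimal sketch of that conversion, assuming the flat state is a list with one buffer per parameter leaf in `tree_leaves` order (an assumption about the layout; check opax for its actual format). `flat_state` here is hypothetical data.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Hypothetical flat optimizer state: one buffer per parameter leaf, in leaf order.
flat_state = [jnp.full_like(p, 0.5) for p in jax.tree_util.tree_leaves(params)]

# The parameters' tree structure tells us how to rebuild a per-parameter view.
params_def = jax.tree_util.tree_structure(params)
state_tree = jax.tree_util.tree_unflatten(params_def, flat_state)
```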


In [None]:
import time

start = time.perf_counter()
for _ in range(10_000):
    leaves, treedef = jax.tree_util.tree_flatten((net, optimizer))
    (net, optimizer) = jax.tree_util.tree_unflatten(treedef, leaves)
end = time.perf_counter()
print("Duration:", end - start)

Duration: 9.955761487000004


In [None]:
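start = time.perf_counter()
for _ in range(10_000):
    # Time the optimizer alone to see how much of the cost it still contributes.
    leaves, treedef = jax.tree_util.tree_flatten(optimizer)
    optimizer = jax.tree_util.tree_unflatten(treedef, leaves)
end = time.perf_counter()
print("Duration:", end - start)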

Duration: 0.40090021199999626


With `flatten=True`, the total time drops to about 10 seconds (9.96 s). Flattening and unflattening the `optimizer` alone takes almost no time (0.40 seconds), so nearly all of the remaining cost comes from `net`.

### Multi-step update function

Another solution to reduce Pax's cost is to execute multiple update steps inside a jitted function.

In [None]:
num_steps = 10

@pax.jit
def multistep_update(model, optimizer, inputs):
    def scan_loop(prev, batch):
        model, optimizer = prev
        loss, model, optimizer = update(model, optimizer, batch)
        return (model, optimizer), loss
    
    (model, optimizer), losses = pax.utils.scan(scan_loop, (model, optimizer), inputs)
    # Report the average loss over the `num_steps` inner steps.
    loss = jnp.mean(losses)
    return loss, model, optimizer

multistep_img = jax.random.normal(rng_key, (num_steps, 1, 3, 64, 64))
multistep_label = jax.random.randint(rng_key, (num_steps, 1,), 0, 10)
# loss, net, optimizer = multistep_update(net, optimizer, (multistep_img, multistep_label))

The `multistep_update` function executes `num_steps` update steps in a single call, so the flatten/unflatten cost is paid once per call instead of once per step.
With `num_steps=10`, Pax's per-step cost drops by a factor of 10.
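The same pattern can be sketched in pure JAX without Pax. Here a toy linear least-squares model trained with plain SGD (both assumptions, chosen only for illustration) is updated `n_steps` times by `jax.lax.scan` inside one jitted call, so the pytree bookkeeping at the jit boundary happens once per call.

```python
import jax
import jax.numpy as jnp


def toy_loss_fn(w, batch):
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)


def sgd_step(w, batch, lr=0.1):
    # One plain SGD update; returns (new carry, per-step output) as scan expects.
    loss, grad = jax.value_and_grad(toy_loss_fn)(w, batch)
    return w - lr * grad, loss


@jax.jit
def multistep_sgd(w, batches):
    # `batches` has a leading axis of size n_steps; scan applies one update
    # per slice, all inside a single compiled call.
    w, losses = jax.lax.scan(sgd_step, w, batches)
    return w, losses


n_steps = 10
xs = jax.random.normal(jax.random.PRNGKey(0), (n_steps, 32, 3))
ys = xs @ jnp.array([1.0, -2.0, 0.5])  # noise-free targets, shape (n_steps, 32)
w, losses = multistep_sgd(jnp.zeros(3), (xs, ys))
```

`losses` holds one value per inner step, which is why `multistep_update` above averages them.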

**Note**: Executing multiple update steps inside a jitted function is also very useful on TPUs, because it significantly reduces the communication cost between the CPU host and the TPU cores.

### Flatten model

We have reduced Pax's cost significantly with little effort. This final solution reduces the cost to almost zero. However, it has downsides too.

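The idea is simple: move `flatten` and `unflatten` inside the update function. We pass the model as a static tree definition plus a list of leaves and rebuild the module inside the jitted function, so only a plain list of arrays crosses the jit boundary.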
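A pure-JAX sketch of this idea, assuming a toy linear model stored in a dict (`sgd_on_leaves` is a hypothetical name). The tree definition is marked static so it becomes part of the compilation cache key, and the function takes and returns only a flat list of leaves.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def sgd_on_leaves(treedef, leaves, x, y, lr=0.1):
    # Rebuild the structured parameters inside the traced function.
    params = jax.tree_util.tree_unflatten(treedef, leaves)

    def loss_fn(p):
        return jnp.mean((x @ p["w"] + p["b"] - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    # Return only the leaves, so the caller just passes a flat list around.
    return loss, jax.tree_util.tree_leaves(new_params)


params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
leaves, treedef = jax.tree_util.tree_flatten(params)
x = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

for _ in range(100):
    loss, leaves = sgd_on_leaves(treedef, leaves, x, y)

# Rebuild the structured parameters only when they are actually needed.
params = jax.tree_util.tree_unflatten(treedef, leaves)
```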

In [None]:
from functools import partial

@partial(pax.jit, static_argnums=0)
def flatten_update(model_def, model_leaves, optimizer: opax.GradientTransformation, inputs):
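    # Rebuild the module from its static tree definition and its leaves.
    model = jax.tree_util.tree_unflatten(model_def, model_leaves)
    grads, (loss, model) = pax.grad(loss_fn, has_aux=True)(model.parameters(), model, inputs)
    model = model.update(
        optimizer.step(grads, model.parameters()),
    )
    # Return only the leaves; the tree definition stays the same across steps.
    return loss, jax.tree_util.tree_leaves(model), optimizer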

In [None]:
net_leaves, net_def = jax.tree_util.tree_flatten(net)

# for i in range(10_000):
#     loss, net_leaves, optimizer = flatten_update(net_def, net_leaves, optimizer, (img, label))

In [None]:
start = time.perf_counter()
for _ in range(10_000):
    leaves, treedef = jax.tree_util.tree_flatten((net_leaves, optimizer))
    (net_leaves, optimizer) = jax.tree_util.tree_unflatten(treedef, leaves)
end = time.perf_counter()
print("Duration:", end - start)

Duration: 0.7107760669999976


We have reduced the cost to almost zero. However, we now have to rebuild the model manually from its leaves and its tree definition (`net_def`) whenever we need the module itself.

In [None]:
net = jax.tree_util.tree_unflatten(net_def, net_leaves)