# JAX transformations

In this tutorial, we offer some advice on mixing PAX with JAX transformations.


JAX transformations affect a function in much the same way as `pax.pure` does: inside the transformed function we only have access to a copy of the inputs, so any modification of that copy does not affect the original inputs.

In [1]:
import jax
import jax.numpy as jnp
import pax
from typing import Dict


def print_id_and_value(c: Dict[str, int], msg=""):
    print(f'({msg}) id {id(c)}  counter {c["count"]}')


@jax.jit
def increase_counter(c):
    c["count"] += 1  # increase counter
    print_id_and_value(c, "inside")


c = {"count": 1}
print_id_and_value(c, "before")
increase_counter(c)
print_id_and_value(c, "after ")



(before) id 140557305163584  counter 1
(inside) id 140555907601088  counter Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
(after ) id 140557305163584  counter 1


Note that inside the jitted function, `c` is a different object (it has a different id) from the `c` outside the function, and its value is a tracer rather than the integer `1`. Therefore, modifying `c` inside `increase_counter` does not affect `c` outside.
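Because the jitted function only sees a copy, the usual way to get an update out is to return it. Below is a minimal sketch in plain JAX that reuses the dictionary counter from above. Returning `c` is our addition for illustration; the original `increase_counter` returns nothing.

```python
import jax


@jax.jit
def increase_counter(c):
    # Inside jit, `c` is a new dict holding traced values,
    # so this in-place update only touches the copy.
    c["count"] += 1
    # Returning the copy is how the new value reaches the caller.
    return c


c = {"count": 1}
new_c = increase_counter(c)
print(c["count"], int(new_c["count"]))  # 1 2
```

The caller's dictionary still holds `1`; the incremented value exists only in the returned dictionary.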

This behavior is very similar to that of `pax.pure`. In fact, `pax.pure` mimics this behavior of JAX transformations.
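To make the analogy concrete, here is a pure-Python sketch of copy-on-call semantics. `pure_like` is a hypothetical helper written only for this illustration. It is not how `pax.pure` is implemented; it merely reproduces the effect that the wrapped function works on copies of its inputs.

```python
import copy


def pure_like(f):
    """Hypothetical helper (not PAX's implementation): call `f` on deep copies
    of its arguments so in-place changes cannot reach the caller's objects."""
    def wrapper(*args, **kwargs):
        args, kwargs = copy.deepcopy((args, kwargs))
        return f(*args, **kwargs)
    return wrapper


@pure_like
def increase_counter(c):
    c["count"] += 1  # only the copy is modified
    return c         # hand the modified copy back explicitly


c = {"count": 1}
new_c = increase_counter(c)
print(c["count"], new_c["count"])  # 1 2
```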


Things get complicated when we use JAX transformations inside a function decorated by `pax.pure`.

In [2]:
class Counter(pax.Module):
    counter: jnp.ndarray

    def __init__(self):
        super().__init__()
        self.register_state("counter", jnp.array(0))

    def __call__(self, x):
        self.counter = self.counter + 1
        return x


class RNN(pax.Module):
    def __init__(self):
        super().__init__()
        self.counter = Counter()

    def __call__(self, xs):
        def _f(c: Counter, x):
            y = c(x)
            return c, y

        _, y = jax.lax.scan(_f, init=self.counter, xs=xs)
        return y


rnn = RNN()
xs = jnp.arange(0, 10)

# ys = rnn(xs) # <-- this does not work because PAX modules are immutable by default
# rnn, ys = pax.module_and_value(rnn)(xs) # <-- this also does not work because `jax.lax.scan` uses an immutable copy of self.counter

Like `jax.jit`, `jax.lax.scan` prevents side effects from happening by copying its input modules. In this case, we have to call `pax.module_and_value` inside `_f`, return the updated module as the carry, and assign the final carry back to `self.counter`. Below is the correct implementation:

In [3]:
class RNNv2(RNN):
    def __call__(self, xs):
        def _f(c: Counter, x):
            c, y = pax.module_and_value(c)(x)
            return c, y

        self.counter, y = jax.lax.scan(_f, init=self.counter, xs=xs)
        return y


rnn = RNNv2()
xs = jnp.arange(0, 10)

rnn, ys = pax.module_and_value(rnn)(xs)  # <-- this works correctly!
print(f"Counter = {rnn.counter.counter}")

Counter = 10
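The same carry-threading pattern works without PAX. In the plain-JAX sketch below, a dictionary holding an integer array stands in for the `Counter` module (an assumption for illustration). The step function builds a new carry and returns it, and `jax.lax.scan` hands the final carry back to the caller.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Build a new carry instead of mutating the one scan passed in.
    carry = {"count": carry["count"] + 1}
    # scan feeds the returned carry into the next step; `x` is the per-step output.
    return carry, x


init = {"count": jnp.array(0)}
final, ys = jax.lax.scan(step, init, jnp.arange(0, 10))
print(int(final["count"]), int(init["count"]))  # 10 0
```

As with `RNNv2`, the count only survives because the final carry is kept; discarding it (as `RNN` does with `_`) would lose the update.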


Leaked tracers can occur if we execute JAX transformations with mutability turned on, for example when a module is called through `pax.module_and_value`.

In [4]:
class RNNv3(RNN):
    def __call__(self, xs):
        @jax.jit
        def _f(x):
            y = self.counter(x)  #  <-- bad
            return y

        y = _f(xs[0])
        return y


rnn = RNNv3()
xs = jnp.arange(0, 10)

# rnn, ys = pax.module_and_value(rnn)(xs)  <-- throws an exception
# [...]
# Exception: Leaked sublevel 1. Leaked tracer(s): [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/1)>].

The statement `y = self.counter(x)` raises a leaked-tracer exception.
JAX detects that `self.counter.counter` is modified inside the jitted function `_f` even though it is not an input of `_f`, so a traced value would escape the transformation. A correct implementation passes `self.counter` to `_f` as an argument and returns the updated module.
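Stripped of PAX, the rule is the same for any jitted function: state must enter as an argument and leave as a return value. A minimal pure-JAX sketch, with an integer array standing in for the `Counter` module (an assumption for illustration); `good` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

# Anti-pattern (left commented out): a jitted closure that writes a traced
# value into an object it did not receive as an argument leaks a tracer.
# state = {"count": jnp.array(0)}
# @jax.jit
# def bad(x):
#     state["count"] = state["count"] + 1  # tracer escapes into `state`
#     return x


@jax.jit
def good(count, x):
    # `count` is an explicit input, so it is traced rather than captured.
    return count + 1, x  # return the update instead of storing it


count = jnp.array(0)
count, y = good(count, jnp.array(3))
print(int(count), int(y))  # 1 3
```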

In [5]:
class RNNv4(RNN):
    def __call__(self, xs):
        @jax.jit
        def _f(c: Counter, x):
            c, y = pax.module_and_value(c)(x)  # get the updated module and the output
            return c, y  # return c to make updates available outside of _f

        self.counter, y = _f(self.counter, xs[0])
        return y


rnn = RNNv4()
xs = jnp.arange(0, 10)

print(f"Counter = {rnn.counter.counter}")
rnn, ys = pax.module_and_value(rnn)(xs)
print(f"Counter = {rnn.counter.counter}")

Counter = 0
Counter = 1
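The two fixes compose: a jitted function can scan over the whole sequence, not just `xs[0]`, as long as the state enters as an argument and leaves as a return value. Below is a plain-JAX sketch, again with an integer array standing in for the `Counter` module (an illustrative assumption); `run` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def run(count, xs):
    def body(c, x):
        # Increment the carried counter once per element of xs.
        return c + 1, x

    # Return the final carry so the caller receives the updated count.
    return jax.lax.scan(body, count, xs)


count = jnp.array(0)
count, ys = run(count, jnp.arange(0, 10))
print(int(count))  # 10
```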
