# JAX transformations

This tutorial offers advice on using JAX transformations together with PAX.


In [1]:
import jax
import jax.numpy as jnp
import pax
from typing import Dict
from absl import logging

In [2]:
%xmode Minimal
logging.set_verbosity(logging.FATAL)

Exception reporting mode: Minimal


A JAX transformation affects a function in much the same way as `pax.pure` does:
the function only has access to a copy of its inputs, and any modification of that copy leaves the original inputs untouched.

Let's try with a simple example:

In [3]:
def print_id_and_value(c: Dict[str, int], msg=""):
    print(f'({msg}) id {id(c)}  counter {c["count"]}')

In [4]:
@jax.jit
def increase_counter(c):
    c["count"] += 1  # increase counter
    print_id_and_value(c, "inside")

In [5]:
c = {"count": 1}
print_id_and_value(c, "before")
increase_counter(c)
print_id_and_value(c, "after ")

(before) id 140438671691264  counter 1
(inside) id 140437229406400  counter Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
(after ) id 140438671691264  counter 1


Inside the jitted function `increase_counter`, `c` is a different object (note the different _id_) from the `c` outside the function, and its value is a tracer rather than a concrete integer. Modifying `c` inside `increase_counter` therefore does not affect the `c` outside.

This behavior is very similar to `pax.pure`. In fact, `pax.pure` mimics this behavior from JAX transformations.
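
To get an updated value out of a jitted function, the usual approach is to return it instead of relying on in-place modification. The following is a minimal sketch in plain JAX (no PAX involved); it reuses the name `increase_counter` from the cell above, and `new_c` is a name introduced here for illustration.

```python
import jax


@jax.jit
def increase_counter(c):
    # `c` is a traced copy of the caller's dict; build and return a new dict
    # rather than mutating the copy in place.
    return {**c, "count": c["count"] + 1}


c = {"count": 1}
new_c = increase_counter(c)

print(c["count"])           # the original dict is unchanged
print(int(new_c["count"]))  # the returned copy carries the update
```

The caller decides what to do with the returned value, for example rebinding `c = increase_counter(c)`.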

Now, things get complicated when we use JAX transformations inside a function decorated by `pax.pure`.

In the following toy example, an `RNN` module uses `jax.lax.scan` to _scan_ over its inputs with the function `scan_fn`.
`scan_fn` also updates the internal state of a `Counter` module.

In [6]:
class Counter(pax.StateModule):
    def __init__(self):
        self.count = jnp.array(0)

    def __call__(self, x):
        self.count = self.count + 1
        return x

In [7]:
class RNN(pax.Module):
    def __init__(self):
        self.counter = Counter()

    def __call__(self, xs):
        def scan_fn(c: Counter, x):
            y = c(x)
            return c, y

        _, y = jax.lax.scan(scan_fn, init=self.counter, xs=xs)
        return y

In [8]:
rnn = RNN()
xs = jnp.arange(0, 10)
rnn, ys = pax.module_and_value(rnn)(xs)

ValueError: Cannot modify a module in immutable mode.
Please do this computation inside a function decorated by `pax.pure`.

Oops! PAX prevents us from updating the counter even though we ran `rnn` with `pax.module_and_value`.

This is because `jax.lax.scan`, similar to `jax.jit`, executes the function `scan_fn` on a copy of its input modules. Moreover, this copy is immutable in our case.

.. note::
    Only input modules of functions decorated by `pax.pure` are mutable. A _copy_ of an input module is still immutable.

In this case, we have to use `pax.module_and_value` inside `scan_fn` and assign the final carry returned by `jax.lax.scan` back to `self.counter`. Below is a working implementation:

In [9]:
class RNN(pax.Module):
    def __init__(self):
        self.counter = Counter()

    def __call__(self, xs):
        def scan_fn(c: Counter, x):
            c, y = pax.module_and_value(c)(x)
            return c, y

        self.counter, y = jax.lax.scan(scan_fn, init=self.counter, xs=xs)
        return y

In [10]:
rnn = RNN()
xs = jnp.arange(0, 10)
rnn, ys = pax.module_and_value(rnn)(xs)
print(f"Count = {rnn.counter.count}")

Count = 10
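
For comparison, here is the same carry pattern in plain JAX, with a bare array standing in for the `Counter` module. This is only a sketch to show where the updated state comes from: `jax.lax.scan` returns the final carry, and the caller has to store it. The names `final_count` and `ys` are introduced here for illustration.

```python
import jax
import jax.numpy as jnp


def scan_fn(count, x):
    # Return the updated carry instead of mutating anything in place.
    return count + 1, x


xs = jnp.arange(0, 10)
# The first returned value is the carry after the last step.
final_count, ys = jax.lax.scan(scan_fn, init=jnp.array(0), xs=xs)

print(f"Count = {int(final_count)}")
```

In the `RNN` module above, `self.counter` plays the role of `count`, and `pax.module_and_value` is what turns the stateful call `c(x)` into a function that returns the updated module.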


Now, let's try another example. Here, a jitted function `fn` calls `self.counter`, which it captures from the enclosing scope instead of receiving it as an argument.

In [11]:
class BadModule(pax.Module):
    def __init__(self):
        self.counter = Counter()

    def __call__(self, x):
        @jax.jit
        def fn(x):
            y = self.counter(x)
            return y

        y = fn(x)
        return y

In [12]:
mod = BadModule()
x = jnp.array(0.0)
mod, y = pax.module_and_value(mod)(x)

ValueError: Cannot modify a module in immutable mode.
Please do this computation inside a function decorated by `pax.pure`.

In this example, PAX also prevents `fn` from modifying `self.counter`.
This is PAX's mechanism for preventing leaks,
which occur when a traced function at a higher level of abstraction
tries to modify a module that was created at a lower level of abstraction.

.. note:: 
    All modules created at lower levels of abstraction than the current level are immutable.
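
To see what such a leak looks like without PAX, consider the plain JAX sketch below. It assumes that the leak PAX guards against is of this kind: a jitted function writes a traced value into state that lives outside the function. The names `state`, `leaky_increase` and `leaked` are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp

# State created outside the jitted function.
state = {"count": jnp.array(0)}


@jax.jit
def leaky_increase(step):
    # Side effect on closed-over state: a traced value is stored outside
    # the function being traced.
    state["count"] = state["count"] + step
    return step


leaky_increase(jnp.array(1))

# After the call, the outer state holds a tracer, not a concrete array.
leaked = isinstance(state["count"], jax.core.Tracer)
print(f"Leaked tracer: {leaked}")
```

JAX does not stop this by default, so the stale tracer only causes trouble later, when it is used. Raising an error at the point of modification, as PAX does, surfaces the problem immediately.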

A correct implementation passes `self.counter` to `fn` as an argument and assigns the updated counter that `fn` returns back to `self.counter`.

In [13]:
class GoodModule(pax.Module):
    def __init__(self):
        self.counter = Counter()

    def __call__(self, x):
        @jax.jit
        def fn(c: Counter, x):
            c, y = pax.module_and_value(c)(x)
            return c, y

        self.counter, y = fn(self.counter, x)
        return y

In [14]:
mod = GoodModule()
x = jnp.array(0.0)
print(f"Count = {mod.counter.count}")
mod, y = pax.module_and_value(mod)(x)
print(f"Count = {mod.counter.count}")

Count = 0
Count = 1
