# Side effects

In this tutorial, we show how to control side effects when programming with PAX. You will learn:

1. PAX modules can have side effects.
2. Side effects do not survive JAX transformations.
3. We can emulate side effects by returning the modified inputs.

## PAX modules can have side effects

A PAX module is a stateful pytree. It can store and modify its internal state like a normal Python object.

Here is a simple PAX module with an internal counter that records how many times `__call__` is executed.

```python
import pax


class M(pax.Module):
    def __init__(self):
        super().__init__()
        self.counter = 0

    def __call__(self, x):
        self.counter += 1
        return x


m = M()
m(0)
print(m.counter)  # 1
m(0)
print(m.counter)  # 2
```

```
1
2
```


Everything works as expected because `m` is just a normal Python object.

## The great purifier (a.k.a. JAX's tracer)

Things get complicated when we use JAX transformations.

First, JAX transformations are very useful.
For example, `jax.grad` transforms a function into its gradient function, and
`jax.jit` compiles a Python function to machine code that can run on CPU, GPU, or TPU.
These transformations are essential for machine learning applications.

However, these transformations also require their input functions to be free of side effects.
If a function does have side effects, those side effects are silently erased, as a "side effect" of the transformation itself.

Let's see how it happens by trying to compile a simple function with `jax.jit`.

```python
import jax


@jax.jit
def forward(m: M):
    m(0)
    print(f"ID = {id(m)} Counter (inside) = {m.counter}")


print(f"ID = {id(m)} Counter (before) = {m.counter}")
forward(m)
print(f"ID = {id(m)} Counter (after)  = {m.counter}")
```

```
ID = 140424167503328 Counter (before) = 2
ID = 140422720409760 Counter (inside) = 3
ID = 140424167503328 Counter (after)  = 2
```


What is happening here? `m.counter` increased to `3` inside the `forward` function, then decreased to `2` again?

Looking closely at the printed IDs, we see that the counter was increased on a different object, not on our original `m`.

To explain this behavior, we need to know how `jax.jit` compiles a function. Roughly speaking, it traces the function by calling it with placeholder inputs (tracers) that record every operation applied to the inputs, as well as which values are returned. The result of this tracing process is a JAX expression (jaxpr), which is JAX's representation of our Python function. Finally, JAX compiles the jaxpr to machine code using the XLA compiler. Python-level effects, such as `print` calls or reassigning an attribute, happen only while tracing and are not part of the jaxpr.
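The "runs only during tracing" behavior can be observed without PAX. The minimal sketch below uses plain JAX; `trace_log` and `double` are hypothetical names for illustration. The Python `append` runs when `jax.jit` traces `double`. A second call with an argument of the same shape and dtype reuses the cached compiled version, so the Python body does not run again.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side record; only appended to while JAX traces `double`


@jax.jit
def double(x):
    trace_log.append("traced")  # Python side effect: executed during tracing only
    return x * 2


y1 = double(jnp.array(1.0))  # first call: traces, compiles, runs
y2 = double(jnp.array(2.0))  # same shape/dtype: reuses the cached compilation
print(len(trace_log))  # 1
```

The compiled computation still returns the correct values (`2.0` and `4.0`). Only the Python-level effect is missing from the second call.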

Let's print out the jaxpr of our `forward` function:

```python
def print_xpr(xpr):
    print(f"JAX xpr:\n===\n{xpr}\n===")


xpr = jax.make_jaxpr(forward)(m)
print_xpr(xpr)
```

```
ID = 140422686347328 Counter (inside) = 3
JAX xpr:
===
{ lambda  ; .
  let  = xla_call[ backend=None
                   call_jaxpr={ lambda  ; .
                                let 
                                in () }
                   device=None
                   donated_invars=(  )
                   inline=False
                   name=forward ] 
  in () }
===
```


First, a different ID and counter value were printed during the tracing process. Second, there is no `counter += 1` in the resulting JAX expression.

Indeed, `m.counter` is not part of the JAX expression because it is not an `ndarray` leaf of the `m` pytree.
To fix this, we convert `counter` to an `ndarray` value and register it as a state of the module using the `self.register_state` method.

```python
import jax.numpy as jnp


class M1(pax.Module):
    counter: jnp.ndarray

    def __init__(self):
        super().__init__()
        self.register_state("counter", jnp.array(0))

    def __call__(self, x):
        self.counter += 1
        return x


m1 = M1()


@jax.jit
def forward_1(m: M1):
    m(0)
    print(f"ID = {id(m)} Counter (inside) = {m.counter}")


xpr = jax.make_jaxpr(forward_1)(m1)
print_xpr(xpr)
```

```
ID = 140422686437728 Counter (inside) = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/1)>
JAX xpr:
===
{ lambda  ; a.
  let  = xla_call[ backend=None
                   call_jaxpr={ lambda  ; a.
                                let _ = add a 1
                                in () }
                   device=None
                   donated_invars=(False,)
                   inline=False
                   name=forward_1 ] a
  in () }
===
```


Now, during the tracing process, `m.counter` is replaced by a placeholder input (a traced array).
The jaxpr also contains `_ = add a 1`, which shows that `self.counter += 1` was traced.

Let's run our new `forward_1` function.

```python
print(f"ID = {id(m1)} Counter (before) = {m1.counter}")
forward_1(m1)
print(f"ID = {id(m1)} Counter (after)  = {m1.counter}")
```

```
ID = 140422686349344 Counter (before) = 0
ID = 140422686347856 Counter (inside) = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
ID = 140422686349344 Counter (after)  = 0
```


`m1.counter` has still not increased after `forward_1` runs. This is because JAX expressions have no side effects:
`self.counter` before and after `self.counter += 1` are two different variables. We can spot this in the `_ = add a 1` expression, where the result is assigned to `_`, not back to `a`. Because `forward_1` returns nothing, that result is simply discarded.

This is a core concept of JAX: its expressions are pure mathematical functions.

Note that purity is a **major** advantage. Because a pure function `f` is a mathematical function, we can compute `jax.grad(f)`. Because `f` has no side effects, we can vectorize it over a batch with `jax.vmap` and run it on multiple devices with `jax.pmap`. In short, pure functions are beautiful!

## Side effect strikes back

But we still want to update `counter`!

The solution is simple. We just need to return the updated `m` as an output of our `forward` function.

```python
@jax.jit
def forward_2(m: M1):
    m(0)
    print(f"ID = {id(m)} Counter (inside) = {m.counter}")
    return m


xpr = jax.make_jaxpr(forward_2)(m1)
print_xpr(xpr)
```

```
ID = 140422686439024 Counter (inside) = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/1)>
JAX xpr:
===
{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b = add a 1
                                 in (b,) }
                    device=None
                    donated_invars=(False,)
                    inline=False
                    name=forward_2 ] a
  in (b,) }
===
```


As we can see in the traced expression, the updated counter `b` (`b = a + 1`) is now included in the output.

```python
print(f"ID = {id(m1)} Counter (before) = {m1.counter}")
m1 = forward_2(m1)
print(f"ID = {id(m1)} Counter (after)  = {m1.counter}")
```

```
ID = 140422686349344 Counter (before) = 0
ID = 140422686439312 Counter (inside) = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
ID = 140424176335504 Counter (after)  = 1
```


Note that the ID of `m1` after `forward_2` differs from the ID we started with: `forward_2` returns a new Python object, which we then rebind to the name `m1`.
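The same "new object on the way out" behavior appears with any pytree returned from a jitted function. In the sketch below, which uses plain JAX, a hypothetical `State` NamedTuple stands in for the module. NamedTuples are pytrees, so JAX rebuilds the returned container, and the input keeps its old value.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    """Hypothetical state container; JAX treats NamedTuples as pytrees."""
    counter: jnp.ndarray


@jax.jit
def increment(state: State) -> State:
    # Build and return a new State instead of mutating the input
    return State(counter=state.counter + 1)


s0 = State(counter=jnp.array(0))
s1 = increment(s0)
print(s1 is s0)                          # False: the output is a new object
print(int(s0.counter), int(s1.counter))  # 0 1
```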

## Summary

- PAX modules work like normal Python objects both inside and outside a JAX transformation.
- The *PROBLEM* is that, inside and outside, they are not the same object.
- When a PAX module is passed to a transformed function, JAX creates a copy of the module and passes the copy to the function.
- As a result, any modification to the copy does not affect the original module.
- To emulate side effects, we return the modified copy.