# Immutability

Pax allows users to modify Pax modules. However, many Jax transformations, such as `jax.jit`, `jax.grad`, `jax.vmap`, and `jax.pmap`, require pure functions: functions whose outputs depend only on their inputs and that have no side effects.
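To see why side effects clash with these transformations, consider the plain JAX sketch below (it does not use Pax). `jax.jit` runs the Python body only while tracing. Later calls with the same input shapes and dtypes reuse the compiled function, so a Python-side mutation happens once rather than on every call. The names `trace_log` and `double` are illustrative only.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side state that the jitted function mutates


@jax.jit
def double(x):
    # Side effect: this runs while JAX traces the function,
    # not each time the compiled version is called.
    trace_log.append("traced")
    return x * 2


double(jnp.ones(3))
double(jnp.ones(3))  # same shape and dtype, so the cached compilation is reused

print(len(trace_log))  # 1: the side effect ran once, not twice
```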

It is therefore good practice to keep Pax modules immutable. This prevents many unexpected, hard-to-debug behaviors caused by side effects.
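Keeping state immutable usually means returning an updated object instead of modifying one in place. The plain-Python sketch below shows that pattern with a frozen dataclass; `CounterState` and `step` are hypothetical names, and this is not Pax's API.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class CounterState:
    # frozen=True: assigning to `counter` after construction raises an error.
    counter: int = 0


def step(state: CounterState) -> CounterState:
    # Return an updated copy instead of mutating `state`.
    return replace(state, counter=state.counter + 1)


s0 = CounterState()
s1 = step(s0)
print(s0.counter, s1.counter)  # 0 1: the original state is untouched
```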

By default, Pax modules are immutable. When mutation is genuinely needed (for example, during model surgery), Pax provides the `pax.mutate` transformation, which enables mutable mode.
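One way such a default can be enforced is by intercepting attribute assignment. The sketch below is a simplified, hypothetical illustration in plain Python: the `GuardedModule` class and the `mutable` context manager are invented for this example and are not how Pax is implemented. Assignment raises unless the object is inside an explicit mutable context.

```python
import contextlib


class GuardedModule:
    """Hypothetical sketch: attribute assignment fails once the object is frozen."""

    def __init__(self):
        # object.__setattr__ bypasses our own guard below.
        object.__setattr__(self, "_frozen", False)

    def freeze(self):
        object.__setattr__(self, "_frozen", True)
        return self

    def __setattr__(self, name, value):
        if getattr(self, "_frozen", False):
            raise ValueError(f"Cannot assign `{name}` in immutable mode.")
        object.__setattr__(self, name, value)


@contextlib.contextmanager
def mutable(module):
    """Temporarily lift the guard (e.g. for model surgery), then re-freeze."""
    object.__setattr__(module, "_frozen", False)
    try:
        yield module
    finally:
        object.__setattr__(module, "_frozen", True)


class Counter(GuardedModule):
    def __init__(self):
        super().__init__()
        self.counter = 0  # allowed: the object is not frozen yet


m = Counter().freeze()
try:
    m.counter = 1
except ValueError as err:
    print("blocked:", err)

with mutable(m):
    m.counter = 1  # allowed inside the mutable context
print(m.counter)  # 1
```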

```python
# Uncomment the following line to install Pax.
# !pip install -q git+https://github.com/NTT123/pax.git
```

```python
import pax
import jax.numpy as jnp

class M1(pax.Module):
    def __init__(self):
        super().__init__()
        self.fc = pax.nn.Linear(2, 2)
        self.counter = 0

    def __call__(self, x):
        # Reassigning an attribute is a side effect.
        self.counter = self.counter + 1
        y = self.fc(x)
        return y

    def __repr__(self):
        return self.__class__.__name__ + f'(counter={self.counter})'

x = jnp.zeros((3, 2))
m1 = M1()
print(m1)
m1(x)  # raises ValueError: modules are immutable by default
print(m1)

with pax.ctx.immutable():
    m1(x)
    print(m1)
```

```
M1(counter=0)

ValueError: Cannot assign an attribute of kind `PaxFieldKind.OTHERS` in immutable mode.
```

Unfortunately, Pax cannot catch every mutation. For example, it cannot detect when a container stored as an attribute (such as a list) is modified in place, because the attribute itself is never reassigned.
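The error message above suggests the check happens on attribute assignment. An in-place method such as `list.append` never reassigns the attribute, so a check of that kind cannot see it. The plain-Python sketch below (the `Frozen` class is hypothetical, not part of Pax) shows the gap and one common mitigation: storing an immutable tuple, so that any update has to go through reassignment, which the guard does catch.

```python
class Frozen:
    """Hypothetical guard: every attribute assignment after __init__ raises."""

    def __init__(self, items):
        # Bypass the guard once to store the initial value.
        object.__setattr__(self, "items", items)

    def __setattr__(self, name, value):
        raise ValueError(f"Cannot assign `{name}`: object is frozen.")


a = Frozen([0])
a.items.append(0)  # mutates the list in place; __setattr__ is never called
print(a.items)  # [0, 0]

b = Frozen((0,))  # a tuple has no in-place methods such as append
try:
    b.items = b.items + (0,)  # updating it requires reassignment...
except ValueError as err:
    print("blocked:", err)  # ...which the guard does catch
print(b.items)  # (0,)
```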

```python
class M2(pax.Module):
    def __init__(self):
        super().__init__()
        self.fc = pax.nn.Linear(2, 2)
        self.a_list = [0]

    def __call__(self, x):
        # In-place mutation: the attribute is never reassigned, so it goes undetected.
        self.a_list.append(0)
        y = self.fc(x)
        return y

    def __repr__(self):
        return self.__class__.__name__ + f'(a_list={self.a_list})'

x = jnp.zeros((3, 2))
m2 = M2()
print(m2)
m2(x)
print(m2)

m2(x)
print(m2)
```

```
M2(a_list=[0])
M2(a_list=[0, 0])
M2(a_list=[0, 0, 0])
```

In such cases, Pax does the second-best thing: it guarantees that copies of a module (for example, those created with `.copy()`) are not affected when the original module is mutated.
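For this guarantee to hold, a copy must not share mutable containers with the original. The plain-Python sketch below contrasts a shallow copy (`copy.copy`), which shares the list, with a deep copy (`copy.deepcopy`), which does not. `M2Like` is a hypothetical stand-in, and this section does not say how Pax implements `.copy()` internally.

```python
import copy


class M2Like:
    """Hypothetical stand-in for a module that stores a mutable list."""

    def __init__(self):
        self.a_list = [0]


original = M2Like()
shallow = copy.copy(original)  # new object, but it shares the same list
deep = copy.deepcopy(original)  # new object with its own copy of the list

original.a_list.append(0)
print(shallow.a_list)  # [0, 0]: affected by the mutation
print(deep.a_list)  # [0]: unaffected
```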

```python
print("m2 before:", m2)
m3 = m2.copy()  # m3 is isolated from later changes to m2
m2(x)
print("m3       :", m3)
print("m2 after :", m2)
```

```
m2 before: M2(a_list=[0, 0, 0])
m3       : M2(a_list=[0, 0, 0])
m2 after : M2(a_list=[0, 0, 0, 0])
```


Because immutability is important for Jax transformations, Pax provides wrappers of these transformations with additional safeguards turned on: `pax.jit`, `pax.grad`, `pax.vmap`, and `pax.pmap` can be used as alternatives to their Jax counterparts.
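This section does not spell out which safeguards these wrappers enable. As a generic, hypothetical illustration of the idea (not the behavior of `pax.jit` or its siblings), the sketch below wraps a function so that it receives deep copies of its arguments, which keeps in-place mutations inside the function from leaking back to the caller. `isolate_inputs` and `impure_step` are invented names.

```python
import copy
import functools


def isolate_inputs(fn):
    """Hypothetical safeguard: call `fn` on deep copies of its arguments."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Copy positional and keyword arguments together in one pass.
        args, kwargs = copy.deepcopy((args, kwargs))
        return fn(*args, **kwargs)

    return wrapper


@isolate_inputs
def impure_step(state):
    state["calls"].append(1)  # in-place mutation of the (copied) argument
    return len(state["calls"])


state = {"calls": []}
print(impure_step(state))  # 1
print(impure_step(state))  # 1: each call sees a fresh copy
print(state["calls"])  # []: the caller's object is unchanged
```

The trade-off is cost: deep-copying every argument on every call is expensive for large objects, so a real library would likely use a cheaper mechanism.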