# Immutability

Many JAX transformations, such as `jax.jit`, `jax.grad`, `jax.vmap`, and `jax.pmap`, require pure functions.

Therefore, it is good practice to keep Pax modules immutable. This prevents many undefined behaviors caused by side effects.
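To see why purity matters, here is a small plain-JAX example (independent of Pax) of a side effect inside `jax.jit`. The Python-level counter update runs only while JAX traces the function, so a second call with the same input shape and dtype reuses the compiled version and silently skips the side effect.

```python
import jax
import jax.numpy as jnp

calls = {"count": 0}


@jax.jit
def add_one(x):
    # Side effect: this runs only while JAX traces the function, not on every call.
    calls["count"] += 1
    return x + 1


add_one(jnp.zeros(3))
add_one(jnp.ones(3))  # same shape and dtype: the cached trace is reused
print(calls["count"])  # 1, not 2
```

A mutable module behaves the same way: any state it updates in Python during a jitted call is updated once at trace time, not once per call.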

By default, Pax modules are immutable. However, Pax provides the `pax.mutate` context manager, which enables mutable mode when needed (for example, when doing model surgery).

In [None]:
# uncomment the following line to install pax
# !pip install -q git+https://github.com/NTT123/pax.git

In [1]:
import pax
import jax.numpy as jnp

class M1(pax.Module):
    def __init__(self):
        super().__init__()
        self.fc = pax.nn.Linear(2, 2)
        self.counter = 0
    
    def __call__(self, x):
        # Reassigning an attribute of a frozen module is rejected by Pax.
        self.counter = self.counter + 1
        return self.fc(x)
    
    def __repr__(self):
        return self.__class__.__name__ + f'(counter={self.counter})'

x = jnp.zeros((3, 2))
m1 = M1()
print(m1)
m1(x)  # raises ValueError: m1 is frozen by default

# Not reached: the call above already raises.
with pax.ctx.immutable():
    m1(x)
    print(m1)



M1(counter=0)


ValueError: Cannot assign an attribute of kind `PaxFieldKind.OTHERS` of a frozen module.

In this case, Pax detects that `self.counter` is being modified and raises a `ValueError` to prevent it.
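The attribute check above can be illustrated with a small, hypothetical class (`FrozenSketch` is not part of Pax). It assumes the guard lives in `__setattr__`: once a `_frozen` flag is set, any attribute reassignment raises. This is only a sketch of the idea, not Pax's implementation. It also shows why in-place mutation such as `list.append` slips past this kind of guard: no attribute is reassigned.

```python
class FrozenSketch:
    """Hypothetical frozen object: attribute assignment fails once `_frozen` is True."""

    def __init__(self):
        # Bypass our own __setattr__ to set the flag before any other attribute.
        object.__setattr__(self, "_frozen", False)
        self.counter = 0
        self.a_list = [0]
        object.__setattr__(self, "_frozen", True)  # freeze after initialization

    def __setattr__(self, name, value):
        if self._frozen:
            raise ValueError(f"Cannot assign attribute `{name}` of a frozen object.")
        object.__setattr__(self, name, value)


obj = FrozenSketch()
try:
    obj.counter = obj.counter + 1  # reassignment goes through __setattr__
except ValueError as e:
    print(e)  # Cannot assign attribute `counter` of a frozen object.

obj.a_list.append(0)  # in-place mutation never calls __setattr__
print(obj.a_list)  # [0, 0]
```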

Unfortunately, Pax cannot catch every mutation directly. For example, it cannot detect in-place modification of a list (such as `append`), because the attribute holding the list is never reassigned.

To handle these cases, Pax records the module's original tree definition (treedef) after initialization and checks for modifications every time a method is called.
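The treedef check can be sketched with `jax.tree_util` alone. The example below assumes a hypothetical `TinyModule` pytree whose `weight` is a leaf and whose `a_list` is stored as auxiliary (static) data, roughly mirroring how the treedef in the output below contains `'a_list': [0]`. The hypothetical `checked_call` helper compares the treedef before and after a call; it illustrates the technique and is not Pax's code.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class TinyModule:
    """Hypothetical module: `weight` is a leaf, `a_list` lives in the treedef."""

    def __init__(self, weight, a_list):
        self.weight = weight
        self.a_list = a_list

    def tree_flatten(self):
        # Store a tuple copy of the list as aux data so the treedef snapshots it.
        return (self.weight,), tuple(self.a_list)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], list(aux))

    def __call__(self, x):
        self.a_list.append(0)  # in-place mutation: no attribute is reassigned
        return x @ self.weight


def checked_call(module, x):
    """Hypothetical helper: raise if the module's treedef changed during the call."""
    before = jax.tree_util.tree_structure(module)
    out = module(x)
    after = jax.tree_util.tree_structure(module)
    if before != after:
        raise ValueError(f"treedef modified:\n--- {before}\n+++ {after}")
    return out


m = TinyModule(jnp.eye(2), [0])
try:
    checked_call(m, jnp.zeros((3, 2)))
except ValueError:
    print("modification detected:", m.a_list)  # modification detected: [0, 0]
```

Because the comparison happens after the method runs, the module has already been modified when the error is raised, which matches the `a_list: [0, 0]` in the error message below.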

In [2]:
class M2(pax.Module):
    def __init__(self):
        super().__init__()
        self.fc = pax.nn.Linear(2, 2)
        self.a_list = [0]
    
    def __call__(self, x):
        # In-place mutation: no attribute is reassigned, so only the treedef check catches it.
        self.a_list.append(0)
        return self.fc(x)
    
    def __repr__(self):
        return self.__class__.__name__ + f'(a_list={self.a_list})'

x = jnp.zeros((3, 2))
m2 = M2()
print(m2)
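m2(x)  # raises ValueError: the call modifies `a_list`, changing the treedef

# Not reached: the call above already raises.
print(m2)

m2(x)
print(m2)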

M2(a_list=[0])


ValueError: The module `M2(a_list=[0, 0])` has its treedef modified.
--- PyTreeDef(CustomNode(<class '__main__.M2'>[{'_pax': PaxModuleInfo[name=None, training=True, frozen=False, nodes=[fc:MODULE]], 'a_list': [0]}], [CustomNode(<class 'pax._src.nn.linear.Linear'>[{'_pax': PaxModuleInfo[name=None, training=True, frozen=False, nodes=[weight:PARAMETER, bias:PARAMETER]], 'in_dim': 2, 'out_dim': 2, 'with_bias': True}], [*, *])]))
+++ PyTreeDef(CustomNode(<class '__main__.M2'>[{'_pax': PaxModuleInfo[name=None, training=True, frozen=False, nodes=[fc:MODULE]], 'a_list': [0, 0]}], [CustomNode(<class 'pax._src.nn.linear.Linear'>[{'_pax': PaxModuleInfo[name=None, training=True, frozen=False, nodes=[weight:PARAMETER, bias:PARAMETER]], 'in_dim': 2, 'out_dim': 2, 'with_bias': True}], [*, *])]))
================
Differences:
{'_pa[83 chars]st': [0], 'fc': Linear[in_dim=2, out_dim=2, with_bias=True]} != {'_pa[83 chars]st': [0, 0], 'fc': Linear[in_dim=2, out_dim=2, with_bias=True]}
  {'_pax': PaxModuleInfo[name=None, training=True, frozen=False, nodes=[fc:MODULE]],
-  'a_list': [0],
+  'a_list': [0, 0],
?             +++

   'fc': Linear[in_dim=2, out_dim=2, with_bias=True]}

In this case, a modification is detected when the `__call__` method of `m2` is executed.

If you want to modify a module, use the `pax.mutate` context manager, which unfreezes the module.

The following cell shows how to append a new element to `a_list` correctly.

In [3]:
m2 = M2()
print(m2)

with pax.mutate(m2):
    m2.a_list.append(0)

print(m2)

M2(a_list=[0])
M2(a_list=[0, 0])
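Conceptually, a context manager like `pax.mutate` only needs to lift a frozen flag for the duration of the `with` block. The sketch below assumes a hypothetical `Frozen` class with a `_frozen` flag and a hypothetical `unfrozen` helper built on `contextlib.contextmanager`. Refreezing on exit is a design assumption of this sketch, not a description of how `pax.mutate` behaves internally.

```python
import contextlib


class Frozen:
    """Hypothetical object that rejects attribute assignment while frozen."""

    def __init__(self):
        object.__setattr__(self, "_frozen", False)
        self.a_list = [0]
        object.__setattr__(self, "_frozen", True)  # freeze after initialization

    def __setattr__(self, name, value):
        if self._frozen:
            raise ValueError(f"Cannot assign attribute `{name}` of a frozen object.")
        object.__setattr__(self, name, value)


@contextlib.contextmanager
def unfrozen(obj):
    """Hypothetical helper: unfreeze `obj` inside the block, refreeze afterwards."""
    object.__setattr__(obj, "_frozen", False)
    try:
        yield obj
    finally:
        # Refreeze even if the block raises.
        object.__setattr__(obj, "_frozen", True)


m = Frozen()
with unfrozen(m):
    m.a_list = m.a_list + [0]  # reassignment is allowed inside the block
print(m.a_list)  # [0, 0]
```

The `try`/`finally` guarantees the object is frozen again even when the code inside the block fails, so a failed edit cannot leave a module permanently mutable.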
