# Immutability

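Many JAX transformations, such as `jax.jit`, `jax.grad`, `jax.vmap`, and `jax.pmap`, require pure functions.

It is therefore good practice to keep Pax modules immutable. A module that updates its own attributes while being transformed has side effects. Those side effects may run only once during tracing, or may not show up in the result at all, which leads to subtle, hard-to-debug behavior.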
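The sketch below uses plain JAX, not Pax, to show why side effects are a problem under `jax.jit`. The names `add_one` and `trace_log` are hypothetical. The Python-side `append` runs only while JAX traces the function. It does not run on every call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical Python-side state, mutated as a side effect

@jax.jit
def add_one(x):
    # This append is a Python side effect: it runs while JAX traces the
    # function, not every time the compiled function is called.
    trace_log.append(x.shape)
    return x + 1

add_one(jnp.zeros(3))
add_one(jnp.zeros(3))  # same shape and dtype: cached compilation, no retrace
print(len(trace_log))  # 1

add_one(jnp.zeros(5))  # new shape: JAX traces the function again
print(len(trace_log))  # 2
```

A module that increments `self.counter` inside a jitted call runs into the same issue: the update happens during tracing, not once per call.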

By default, Pax modules are immutable. When mutation is genuinely needed (for example, during model surgery), Pax provides the `pax.mutate` transformation, which enables mutable mode.

```python
# Uncomment the following line to install pax
# !pip install -q git+https://github.com/NTT123/pax.git
```

```python
import pax
import jax.numpy as jnp

class M1(pax.Module):
    def __init__(self):
        super().__init__()
        self.fc = pax.nn.Linear(2, 2)
        self.counter = 0

    def __call__(self, x):
        # Assigning to an attribute mutates the module; this fails in immutable mode.
        self.counter = self.counter + 1
        y = self.fc(x)
        return y

    def __repr__(self):
        return self.__class__.__name__ + f'(counter={self.counter})'

x = jnp.zeros((3, 2))
m1 = M1()
print(m1)
m1(x)
print(m1)

with pax.ctx.immutable():
    m1(x)
    print(m1)
```



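```
M1(counter=0)

ValueError: Cannot assign an attribute of kind `PaxFieldKind.OTHERS` in immutable mode.
```

The first call `m1(x)` already raises, because `__call__` tries to assign `self.counter` while the module is immutable. The second `print(m1)` is never reached.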
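The usual way to keep such state pure is to return an updated copy instead of assigning in place. The sketch below uses a plain-Python frozen dataclass. It is an analogy only, not Pax's API, and `CounterState` and `step` are hypothetical names.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class CounterState:
    counter: int = 0

    def step(self):
        # Return an updated copy instead of assigning to self.counter.
        return dataclasses.replace(self, counter=self.counter + 1)

s0 = CounterState()
s1 = s0.step()
print(s0.counter, s1.counter)  # 0 1

try:
    s0.counter = 5  # frozen dataclasses reject attribute assignment
    rejected = False
except dataclasses.FrozenInstanceError:
    rejected = True
print(rejected)  # True
```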

Unfortunately, Pax cannot catch every mutation. For example, it cannot detect in-place modification of a container, such as a list stored as a module attribute.
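Any guard that works by intercepting attribute assignment has this blind spot, and Pax's check presumably works along those lines. Python's own frozen dataclasses show the same limitation. The sketch below is a plain-Python illustration with a hypothetical `Holder` class. Rebinding the attribute is rejected, but appending to the list it holds goes through unnoticed.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class Holder:
    a_list: list = dataclasses.field(default_factory=lambda: [0])

h = Holder()

try:
    h.a_list = [1]  # rebinding goes through __setattr__, so it is rejected
    rebind_rejected = False
except dataclasses.FrozenInstanceError:
    rebind_rejected = True

h.a_list.append(0)  # in-place mutation never calls __setattr__ on h
print(rebind_rejected, h.a_list)  # True [0, 0]
```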

```python
class M2(pax.Module):
    def __init__(self):
        super().__init__()
        self.fc = pax.nn.Linear(2, 2)
        self.a_list = [0]

    def __call__(self, x):
        # In-place list mutation: not caught by the immutability check.
        self.a_list.append(0)
        y = self.fc(x)
        return y

    def __repr__(self):
        return self.__class__.__name__ + f'(a_list={self.a_list})'

x = jnp.zeros((3, 2))
m2 = M2()
print(m2)
m2(x)
print(m2)

m2(x)
print(m2)
```

```
M2(a_list=[0])
M2(a_list=[0, 0])
M2(a_list=[0, 0, 0])
```

Each call appends to `a_list`, and no error is raised.


To catch such cases, Pax records the tree definition (treedef) of a module right after initialization. The module's current treedef can then be compared against this record whenever a check is needed.
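The sketch below mimics the idea with plain `jax.tree_util`. It is not Pax's implementation. `TinyModule` is a hypothetical class that stores its array as a pytree leaf and keeps a snapshot of its non-array field in the treedef's auxiliary data, as the `a_list` entry in Pax's treedef printout suggests Pax does. Appending to the list changes the treedef, so comparing against the treedef recorded at construction reveals the mutation.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class TinyModule:
    def __init__(self, weight, a_list):
        self.weight = weight  # array: becomes a pytree leaf
        self.a_list = a_list  # plain Python list: stored in the treedef's aux data

    def tree_flatten(self):
        children = (self.weight,)
        # A hashable tuple snapshot, so the recorded treedef does not alias the list.
        aux_data = (tuple(self.a_list),)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], list(aux_data[0]))

m = TinyModule(jnp.zeros((2, 2)), [0])
treedef_at_init = jax.tree_util.tree_structure(m)  # recorded right after construction
unchanged_before = jax.tree_util.tree_structure(m) == treedef_at_init

m.a_list.append(0)  # in-place container mutation
print(unchanged_before, jax.tree_util.tree_structure(m) == treedef_at_init)  # True False
```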

Pax provides the `check_treedef` method to check whether a module's treedef has changed since initialization.

```python
m2.check_treedef()
```

```
ValueError: The module `M2(a_list=[0, 0, 0])` has its treedef modified.
--- PyTreeDef(CustomNode(<class '__main__.M2'>[(['fc'], {'_name': None, '_name_to_kind': mappingproxy(OrderedDict([('fc', <PaxFieldKind.MODULE: 3>)])), '_training': True, 'a_list': [0]})], [CustomNode(<class 'pax._src.nn.linear.Linear'>[(['weight', 'bias'], {'_name': None, '_name_to_kind': mappingproxy(OrderedDict([('weight', <PaxFieldKind.PARAMETER: 2>), ('bias', <PaxFieldKind.PARAMETER: 2>)])), '_training': True, 'in_dim': 2, 'out_dim': 2, 'with_bias': True})], [*, *])]))
+++ PyTreeDef(CustomNode(<class '__main__.M2'>[(['fc'], {'_name': None, '_name_to_kind': mappingproxy(OrderedDict([('fc', <PaxFieldKind.MODULE: 3>)])), '_training': True, 'a_list': [0, 0, 0]})], [CustomNode(<class 'pax._src.nn.linear.Linear'>[(['weight', 'bias'], {'_name': None, '_name_to_kind': mappingproxy(OrderedDict([('weight', <PaxFieldKind.PARAMETER: 2>), ('bias', <PaxFieldKind.PARAMETER: 2>)])), '_training': True, 'in_dim': 2, 'out_dim': 2, 'with_bias': True})], [*, *])]))
================
Differences:
{'_na[117 chars]': [0], 'fc': Linear[in_dim=2, out_dim=2, with_bias=True]} != {'_na[117 chars]': [0, 0, 0], 'fc': Linear[in_dim=2, out_dim=2[13 chars]rue]}
  {'_name': None,
   '_name_to_kind': mappingproxy(OrderedDict([('fc', <PaxFieldKind.MODULE: 3>)])),
   '_training': True,
-  'a_list': [0],
+  'a_list': [0, 0, 0],
?             ++++++

   'fc': Linear[in_dim=2, out_dim=2, with_bias=True]}
```

The diff pinpoints the change: `a_list` grew from `[0]` to `[0, 0, 0]`.

Pax provides thin wrappers around JAX's transformations that run `check_treedef` and enable additional safeguards. Use `pax.jit`, `pax.grad`, `pax.vmap`, and `pax.pmap` as alternatives to the corresponding JAX transformations.
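The following hypothetical sketch shows what such a safeguard could look like, using only `jax.tree_util.tree_structure`. `with_treedef_check`, `step`, and `pure_step` are illustrative names, and this is not how `pax.jit` and the other wrappers are actually implemented. The sketch detects a change only after the wrapped function has run, so it reports the mutation but does not undo it.

```python
import functools
import jax

def with_treedef_check(fn):
    """Hypothetical wrapper: raise if `fn` changes the treedef of its first argument."""
    @functools.wraps(fn)
    def wrapper(module, *args, **kwargs):
        before = jax.tree_util.tree_structure(module)
        out = fn(module, *args, **kwargs)
        after = jax.tree_util.tree_structure(module)
        if after != before:
            raise ValueError(f"treedef modified:\n--- {before}\n+++ {after}")
        return out
    return wrapper

@with_treedef_check
def step(state, x):
    state["history"].append(x)  # in-place container mutation
    return x * 2

@with_treedef_check
def pure_step(state, x):
    return x * 2  # leaves `state` untouched, so the check passes

state = {"history": [0.0]}
print(pure_step(state, 3.0))  # 6.0

try:
    step(state, 1.0)
    caught = False
except ValueError as err:
    caught = True
    print(str(err).splitlines()[0])  # treedef modified:
```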