# Limitations

In this tutorial, we show some limitations of `pax.Module` and the rationale behind them.

## No unregistered `ndarray`s

A Pax module should not contain any unregistered `ndarray`.
This ensures that every `ndarray` is a leaf of the pytree.
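Why does this matter? JAX transformations (`jax.grad`, `jax.jit`, `jax.tree_util.tree_map`, ...) only see the leaves of a pytree. The sketch below uses plain `jax.tree_util`, not Pax internals, and a hypothetical `Box` class: an array stored as a registered child is a leaf, while an array tucked away in static auxiliary data is silently skipped.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map


@register_pytree_node_class
class Box:
    """Hypothetical container (not a Pax class) with one visible and one hidden field."""

    def __init__(self, registered, hidden):
        self.registered = registered  # flattened as pytree children -> leaves
        self.hidden = hidden          # kept as static auxiliary data -> not leaves

    def tree_flatten(self):
        return (self.registered,), (self.hidden,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        (hidden,) = aux
        (registered,) = children
        return cls(registered, hidden)


box = Box(registered=[jnp.array(0.0)], hidden=[jnp.array(1.0)])

# Only the registered array is a leaf.
n_leaves = len(tree_leaves(box))

# A transformation over the tree updates the registered array but never touches the hidden one.
bumped = tree_map(lambda x: x + 1.0, box)
```

Here `n_leaves` is 1, and `bumped.hidden[0]` is still `1.0`; a gradient or optimizer update would skip the hidden array in the same way.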

In [None]:
# uncomment the following line to install pax
# !pip install -q git+https://github.com/NTT123/pax.git

In [1]:
import pax
import jax
import jax.numpy as jnp
import opax

In [2]:
class M1(pax.Module):
    def __init__(self):
        super().__init__()
        self.a = [jnp.array(0.), jnp.array(1)]
        
net = M1()



ValueError: Cannot assign ndarray to an attribute directly. Use `self.register_*` methods to assign your value.

Whenever we assign a pytree that contains an `ndarray` leaf to an attribute, Pax raises a `ValueError`
telling us to register it as part of the pytree.

The correct implementation should be:

In [3]:
class FixedM1(pax.Module):
    def __init__(self):
        super().__init__()
        self.register_states("a", [jnp.array(0.), jnp.array(1)])
        
net = FixedM1()
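To make the rule concrete, here is a toy version of such a check. `ToyModule` and its `register` method are hypothetical and only mimic the behavior shown above; this is not how Pax is implemented. Assignment inspects the value's pytree leaves and rejects arrays unless the attribute name was registered first.

```python
import jax
import jax.numpy as jnp


class ToyModule:
    """Hypothetical, simplified stand-in for a module base class (not Pax's code)."""

    def __init__(self):
        # Names of attributes explicitly registered as pytree state.
        object.__setattr__(self, "_registered", set())

    def register(self, name, value):
        # Hypothetical registration method: opt `name` in, then store the value.
        self._registered.add(name)
        object.__setattr__(self, name, value)

    def __setattr__(self, name, value):
        # Reject unregistered values whose pytree leaves include an array.
        leaves = jax.tree_util.tree_leaves(value)
        has_array = any(isinstance(leaf, jnp.ndarray) for leaf in leaves)
        if has_array and name not in self._registered:
            raise ValueError("Cannot assign ndarray to an attribute directly.")
        object.__setattr__(self, name, value)


m = ToyModule()
try:
    m.a = [jnp.array(0.0), jnp.array(1)]  # unregistered arrays: rejected
    rejected = False
except ValueError:
    rejected = True

m.register("a", [jnp.array(0.0), jnp.array(1)])  # explicit registration: accepted
m.info = 0  # plain Python values are still allowed
```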

## Immutability in the forward pass

It is not recommended to modify a module's attributes during its forward pass.
This ensures that the module keeps its original tree structure after a forward pass.
If this constraint is violated, an error (here, an `AssertionError`) is raised when the model is updated during training.
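In JAX, a tree structure (`PyTreeDef`) includes each custom node's auxiliary data, so changing a non-leaf attribute produces a different structure, and code that pairs two trees, such as gradients with parameters, expects them to match. The sketch assumes, for illustration only, that a plain attribute like `info` is stored as auxiliary data; `Model` is a hypothetical class, not Pax's.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_structure


@register_pytree_node_class
class Model:
    """Hypothetical model: `w` is a leaf, `info` is static auxiliary data (an assumption)."""

    def __init__(self, w, info):
        self.w = w
        self.info = info

    def tree_flatten(self):
        return (self.w,), self.info

    @classmethod
    def tree_unflatten(cls, info, children):
        return cls(children[0], info)


model = Model(jnp.zeros((1, 1)), info=0)
before = tree_structure(model)
model.info = model.info + 1  # the same kind of mutation as in M2.__call__
after = tree_structure(model)

# The leaves are unchanged, but the structures no longer compare equal.
same_structure = before == after
```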

In [4]:
class M2(pax.Module):
    def __init__(self):
        super().__init__()
        self.fc = pax.nn.Linear(1, 1)
        self.info = 0
    
    def __call__(self, x):
        self.info = self.info + 1
        return self.fc(x)

net = M2()
optimizer = opax.adam(1e-4)(net.parameters())

def loss_fn(model, inputs):
    y = model(inputs[0])
    loss = jnp.mean(jnp.square(y - inputs[1]))
    return loss, (loss, model)

def update_fn(model, optimizer, inputs):
    grads, (loss, model) = jax.grad(loss_fn, has_aux=True, allow_int=True)(model, inputs)
    model, optimizer = pax.apply_gradients(model, optimizer, grads=grads)
    return model, optimizer, loss
    
x = jax.random.normal(jax.random.PRNGKey(42), (8, 1))
y = jax.random.normal(jax.random.PRNGKey(44), (8, 1))



In [5]:
for i in range(3):
    print(net.info)
    update_fn(net, optimizer, (x, y))

0


AssertionError: {'_na[57 chars]nd.MODULE: 3>)])), '_training': True, '_name': None, 'info': 0} != {'_na[57 chars]nd.MODULE: 3>)])), '_training': True, '_name': None, 'info': 1}
  {'_name': None,
   '_name_to_kind': mappingproxy(OrderedDict([('fc', <PaxFieldKind.MODULE: 3>)])),
   '_training': True,
-  'info': 0}
?          ^

+  'info': 1}
?          ^


We can easily see from the assertion error that `info` has changed.

Unfortunately, in some cases the error goes silent, for example, when you append to a list.
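In plain Python, an in-place mutation such as `list.append` changes an object that every reference shares, so a comparison against a shallow snapshot cannot see it. The toy class below is hypothetical, and whether Pax's own check misses the change for exactly this reason is our assumption.

```python
class Toy:
    """Hypothetical stand-in for M3: `info` is a mutable list."""

    def __init__(self):
        self.info = [0]

    def __call__(self, x):
        self.info.append(0)  # in-place mutation: no attribute is reassigned
        return x


m = Toy()
snapshot = dict(vars(m))  # shallow copy: shares the same list object
m(1.0)

looks_unchanged = snapshot == vars(m)  # True: the change is invisible
same_list = snapshot["info"] is m.info  # True: both names point at one list
```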

In [6]:
class M3(pax.Module):
    def __init__(self):
        super().__init__()
        self.fc = pax.nn.Linear(1, 1)
        self.info = [0]
    
    def __call__(self, x):
        self.info.append(0)
        return self.fc(x)

net = M3()
optimizer = opax.adam(1e-4)(net.parameters())

for i in range(3):
    print(net.info)
    update_fn(net, optimizer, (x, y))

[0]
[0, 0]
[0, 0, 0]


Our model now has a side effect, which is usually undesirable. For example, the side effect _somehow_ disappears when we use `jax.jit`.

In [7]:
net = M3()
optimizer = opax.adam(1e-4)(net.parameters())
fast_update_fn = jax.jit(update_fn)
for i in range(3):
    print(net.info)
    fast_update_fn(net, optimizer, (x, y))

[0]
[0, 0]
[0, 0]


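`net.info` gains a new element after the first step. However, once the function is compiled, the side effect disappears! The Python code in `update_fn`, including the `append`, runs only while JAX traces the function on the first call; later calls with the same input structure reuse the compiled computation.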
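The same tracing behavior can be reproduced with plain `jax.jit`, independent of Pax. In this sketch, `double` and `trace_log` are illustrative names; the list records each time the Python body actually runs.

```python
import jax
import jax.numpy as jnp

trace_log = []


@jax.jit
def double(x):
    trace_log.append("traced")  # Python side effect: runs only while tracing
    return x * 2


for _ in range(3):
    double(jnp.ones(3))
n_traces = len(trace_log)  # 1: the 2nd and 3rd calls reuse the compiled function

double(jnp.ones(4))  # a new input shape forces a retrace
n_traces_after_new_shape = len(trace_log)  # 2
```

This is why side effects inside a jitted function are unreliable: they run once per trace, not once per call.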


This is one of the reasons we introduce
`pax.grad`, an alternative to `jax.grad`.

`pax.grad` behaves the same as `jax.grad`; however, it turns on
Pax's immutable mode to catch these side effects.
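Conceptually, the error output of the next cell looks like a before/after comparison of the module's attributes. A minimal stdlib sketch of that idea, using a hypothetical `checked_call` helper (not Pax's API) and assuming the attributes are plain Python values that can be deep-copied and compared with `==`:

```python
import copy


def checked_call(fn, module, *args):
    """Hypothetical helper: raise if calling `fn` changes `module`'s attributes."""
    before = copy.deepcopy(vars(module))  # deep copy, so in-place edits are visible
    out = fn(module, *args)
    after = vars(module)
    if before != after:
        raise AssertionError(f"{before} != {after}")
    return out


class Appender:
    """Hypothetical stand-in for M3."""

    def __init__(self):
        self.info = [0]

    def __call__(self, x):
        self.info.append(0)
        return x


def forward(module, x):
    return module(x)


try:
    checked_call(forward, Appender(), 1.0)
    detected = False
except AssertionError:
    detected = True

# A function that leaves the module untouched passes the check.
clean = checked_call(lambda module, x: x * 2, Appender(), 1.0)
```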

In [8]:
def update_fn(model, optimizer, inputs):
    grads, (loss, model) = pax.grad(loss_fn, has_aux=True, allow_int=True, io_check=False)(model, inputs)
    model, optimizer = pax.apply_gradients(model, optimizer, grads=grads)
    return model, optimizer, loss
    
net = M3()
optimizer = opax.adam(1e-4)(net.parameters())

for i in range(10):
    print(net.info)
    update_fn(net, optimizer, (x, y))

[0]


AssertionError: {'_name_to_kind': mappingproxy(OrderedDict([('fc[74 chars] [0]} != {'_name': None, '_name_to_kind': mappingproxy(Or[77 chars], 0]}
  {'_name': None,
   '_name_to_kind': mappingproxy(OrderedDict([('fc', <PaxFieldKind.MODULE: 3>)])),
   '_training': True,
-  'info': [0]}
+  'info': [0, 0]}
?           +++


With `pax.grad`, we now detect that `info` has grown during the forward pass. We recommend using `pax.grad`, `pax.jit`, `pax.vmap`, etc. as replacements for `jax.grad`, `jax.jit`, `jax.vmap`, etc.
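If a model genuinely needs to track something across steps, a common pattern in functional JAX code is to return the updated value explicitly instead of mutating `self`. This stdlib sketch uses a hypothetical frozen-dataclass `Counter`, not a Pax API: each call returns the output together with a new model.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Counter:
    """Hypothetical module-like value; `info` counts forward passes."""

    scale: float
    info: int = 0

    def __call__(self, x):
        # Return a new object instead of mutating `self` (frozen forbids mutation anyway).
        return x * self.scale, replace(self, info=self.info + 1)


initial = Counter(scale=2.0)
model = initial
for _ in range(3):
    y, model = model(1.0)
```

After the loop, `model.info` is 3 while `initial.info` is still 0, so the change is visible in the function's outputs rather than hidden in a side effect.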