# Kind assignments

This tutorial shows how the PAX kind assignment system works.

## Direct kind assignment

The most straightforward way to assign kinds to attributes is to call the `self.set_attribute_kind` method. For example:

In [1]:
import jax
import jax.numpy as jnp
import pax
from pax import PaxKind

pax.seed_rng_key(42)


class M(pax.Module):
    def __init__(self):
        super().__init__()
        self.weight = jnp.array(0.0)
        self.counter = jnp.array(0)
        self.fc = pax.nn.Linear(3, 3)

        self.set_attribute_kind(
            weight=PaxKind.PARAMETER,
            counter=PaxKind.STATE,
            fc=PaxKind.MODULE,
        )


m = M()
print(m.pax)



PaxModuleInfo[training=True, nodes={weight:PARAMETER, counter:STATE, fc:MODULE}]


PAX also provides a more concise way to register attributes with `self.register_*` methods.

In [2]:
class M1(pax.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", jnp.array(0.0))
        self.register_state("counter", jnp.array(0))
        self.register_module("fc", pax.nn.Linear(3, 3))


m1 = M1()
print(m1.pax)

PaxModuleInfo[training=True, nodes={weight:PARAMETER, counter:STATE, fc:MODULE}]


Behind the scenes, the `self.register_*` methods do two things:

1. `self.__setattr__(name, value)`
2. `self.set_attribute_kind(**{name: kind})`
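The two steps above can be mimicked with a toy class. This is a hypothetical, pure-Python sketch (it does not import PAX). `ToyModule`, `Kind` and the `_kinds` dictionary are illustrative names, not PAX internals. It only shows that each `register_*` call amounts to an ordinary attribute assignment followed by a kind assignment.

```python
from enum import Enum


class Kind(Enum):
    # Hypothetical stand-in for pax.PaxKind.
    PARAMETER = "PARAMETER"
    STATE = "STATE"
    MODULE = "MODULE"


class ToyModule:
    """Toy model of kind bookkeeping; NOT the real pax.Module."""

    def __init__(self):
        # Assumed representation: attribute name -> Kind.
        self._kinds = {}

    def set_attribute_kind(self, **kinds):
        self._kinds.update(kinds)

    def _register(self, name, value, kind):
        setattr(self, name, value)               # step 1: set the attribute
        self.set_attribute_kind(**{name: kind})  # step 2: record its kind

    def register_parameter(self, name, value):
        self._register(name, value, Kind.PARAMETER)

    def register_state(self, name, value):
        self._register(name, value, Kind.STATE)

    def register_module(self, name, value):
        self._register(name, value, Kind.MODULE)


m = ToyModule()
m.register_parameter("weight", 0.0)
m.register_state("counter", 0)
m.register_module("fc", ToyModule())
print({name: kind.name for name, kind in m._kinds.items()})
# {'weight': 'PARAMETER', 'counter': 'STATE', 'fc': 'MODULE'}
```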

## Delayed kind assignment

### The `_default_kind` context manager

It is cumbersome to set the kind of every attribute manually.
To get rid of this inconvenience, PAX introduces the `_default_kind` context manager.
Let's see how to use it with an example.

In [3]:
class M3(pax.Module):
    def __init__(self):
        super().__init__()
        with self._default_kind(PaxKind.MODULE):
            self.fc1 = pax.nn.Linear(1, 2)
            self.fc2 = pax.nn.Linear(2, 3)
            self.fc3 = pax.nn.Linear(3, 4)

    def __call__(self, x):
        x = jax.nn.relu(self.fc1(x))
        x = jax.nn.relu(self.fc2(x))
        x = jax.nn.relu(self.fc3(x))
        return x


m3 = M3()
print(m3.pax)

PaxModuleInfo[training=True, nodes={fc3:MODULE, fc2:MODULE, fc1:MODULE}]


Behind the scenes, the `_default_kind` context manager searches for unregistered attributes created inside the context and assigns them the default kind.
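The bookkeeping described above can be approximated with a toy context manager that snapshots the attribute names on entry and assigns the default kind on exit. This is a hypothetical sketch in plain Python, not PAX's implementation: `ToyModule`, `Kind` and `_kinds` are illustrative names, and the pytree type filtering that PAX performs is deliberately left out here.

```python
from contextlib import contextmanager
from enum import Enum


class Kind(Enum):
    # Hypothetical stand-in for pax.PaxKind.
    PARAMETER = "PARAMETER"
    STATE = "STATE"
    MODULE = "MODULE"


class ToyModule:
    """Toy model of `_default_kind`; NOT the real pax.Module.
    The pytree type filtering performed by PAX is omitted here."""

    def __init__(self):
        self._kinds = {}  # attribute name -> Kind

    @contextmanager
    def _default_kind(self, kind):
        # Snapshot the attribute names that exist before the block runs.
        before = set(vars(self))
        yield
        # On exit, give the default kind to every attribute that was created
        # inside the block and has not been registered explicitly.
        for name in vars(self):
            if name not in before and name not in self._kinds:
                self._kinds[name] = kind


m = ToyModule()
with m._default_kind(Kind.MODULE):
    m.fc1 = ToyModule()
    m.fc2 = ToyModule()
print({name: kind.name for name, kind in m._kinds.items()})
# {'fc1': 'MODULE', 'fc2': 'MODULE'}
```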

.. warning:: Only pytrees of ndarrays or modules are considered by the `_default_kind` context manager. Moreover, a pytree of ndarrays is never assigned the kind `MODULE`, and a pytree of modules is never assigned the kind `PARAMETER` or `STATE`.
```python
# only consider module objects
with self._default_kind(PaxKind.MODULE): 
    self.a = 123 # not considered
    self.b = jnp.array(0) # not considered 
    self.c = pax.nn.Linear(3, 3) # considered

# only consider pytree of ndarray
with self._default_kind(PaxKind.PARAMETER): 
    self.d = 123 # not considered
    self.e = jnp.array(0.) # considered
    self.f = pax.nn.Linear(3, 3) # not considered
```
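The filtering rule in the warning can be expressed as a small predicate. The sketch below is a hypothetical, pure-Python illustration: `Array` and `Module` are stand-ins for JAX arrays and PAX modules, `tree_leaves` is a minimal flattener over lists, tuples and dicts (not `jax.tree_util`), and two behaviours are assumptions not stated in the text: an empty container is treated as "not considered", and a pytree mixing arrays and modules is rejected for every kind.

```python
from enum import Enum


class Kind(Enum):
    # Hypothetical stand-in for pax.PaxKind.
    PARAMETER = "PARAMETER"
    STATE = "STATE"
    MODULE = "MODULE"


class Array:
    """Hypothetical stand-in for a JAX ndarray."""


class Module:
    """Hypothetical stand-in for a PAX module."""


def tree_leaves(tree):
    # Minimal pytree flattening over lists, tuples and dict values.
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in tree_leaves(item)]
    if isinstance(tree, dict):
        return [leaf for item in tree.values() for leaf in tree_leaves(item)]
    return [tree]


def is_considered(kind, value):
    leaves = tree_leaves(value)
    if not leaves:
        # Empty container: nothing to decide on yet (assumed behaviour).
        return False
    if kind is Kind.MODULE:
        # MODULE only accepts pytrees whose leaves are all modules.
        return all(isinstance(leaf, Module) for leaf in leaves)
    # PARAMETER and STATE only accept pytrees whose leaves are all arrays.
    return all(isinstance(leaf, Array) for leaf in leaves)


print(is_considered(Kind.MODULE, 123))          # False
print(is_considered(Kind.MODULE, Array()))      # False
print(is_considered(Kind.MODULE, Module()))     # True
print(is_considered(Kind.PARAMETER, Array()))   # True
print(is_considered(Kind.PARAMETER, Module()))  # False
```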

### Why delay kind assignment?

Kind assignment is delayed until the end of the context. This resolves ambiguities when an attribute is created as an empty container. For example:

```python
self.a = []
```

At this point, PAX does not have enough information to decide whether `self.a` is a pytree of ndarrays/modules or just a normal list (of integers, for example).
However, once we append an ndarray to the list, the ambiguity is resolved.

```python
self.a = []
self.a.append(jnp.array(0))
```

Therefore, PAX delays the decision until the end of the context, when the attribute is fully initialized.

```python
with self._default_kind(PaxKind.STATE):
    self.a = []
    self.a.append(jnp.array(0))
```
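Why the delay matters can be seen in a toy version where the check runs only on exit. This is a hypothetical pure-Python sketch, not PAX code: `Array` stands in for a JAX ndarray, kinds are plain strings for brevity, and `is_array_list` is a deliberately simplified check that treats an empty list as undecidable.

```python
from contextlib import contextmanager


class Array:
    """Hypothetical stand-in for a JAX ndarray."""


def is_array_list(value):
    # Toy check: an array, or a NON-EMPTY list of arrays. An empty list is
    # treated as undecidable, mirroring the ambiguity described above.
    if isinstance(value, list):
        return bool(value) and all(isinstance(v, Array) for v in value)
    return isinstance(value, Array)


class ToyModule:
    """Toy model; NOT the real pax.Module."""

    def __init__(self):
        self._kinds = {}

    @contextmanager
    def _default_kind(self, kind):
        before = set(vars(self))
        yield
        # The decision is made only here, after the block has run, so an
        # attribute populated after its creation is seen in full.
        for name, value in vars(self).items():
            if name not in before and name not in self._kinds and is_array_list(value):
                self._kinds[name] = kind


m = ToyModule()
with m._default_kind("STATE"):
    m.a = []
    print(is_array_list(m.a))  # False: an eager check would skip `a` here
    m.a.append(Array())
print(m._kinds)  # {'a': 'STATE'}
```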

### Shortcuts

It is still a bit inconvenient to write `with self._default_kind(...)` in every module.
To get rid of this, PAX introduces three shortcuts:

First, the `__init__` method of every PAX module is executed inside a `self._default_kind(PaxKind.MODULE)` context.
Hence, submodules created inside `__init__` are registered automatically at the end of the `__init__` method.

In [4]:
class M3v2(pax.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = pax.nn.Linear(1, 2)
        self.fc2 = pax.nn.Linear(2, 3)
        self.fc3 = pax.nn.Linear(3, 4)

    def __call__(self, x):
        x = jax.nn.relu(self.fc1(x))
        x = jax.nn.relu(self.fc2(x))
        x = jax.nn.relu(self.fc3(x))
        return x


m3v2 = M3v2()
print(m3v2.pax)

PaxModuleInfo[training=True, nodes={fc3:MODULE, fc1:MODULE, fc2:MODULE}]


Second, PAX introduces the `pax.ParameterModule` and `pax.StateModule` classes, which execute the `__init__` method inside a `self._default_kind(PaxKind.PARAMETER)` or `self._default_kind(PaxKind.STATE)` context, respectively.
Hence, new attributes created inside the `__init__` method of these classes are registered automatically.

In [5]:
class M5(pax.ParameterModule):
    def __init__(self):
        super().__init__()
        self.weight = jnp.array(1.0)
        self.bias = jnp.array(0.0)


m5 = M5()
print("m5.pax", m5.pax)


class M6(pax.StateModule):
    def __init__(self):
        super().__init__()
        self.weight = jnp.array(1.0)
        self.bias = jnp.array(0.0)


m6 = M6()
print("m6.pax", m6.pax)

m5.pax PaxModuleInfo[training=True, nodes={bias:PARAMETER, weight:PARAMETER}]
m6.pax PaxModuleInfo[training=True, nodes={bias:STATE, weight:STATE}]
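One way to get the behaviour of the first two shortcuts is to wrap every subclass's `__init__` in a default-kind context chosen by a class attribute. The sketch below is a hypothetical pure-Python model, not how PAX is necessarily implemented: it uses `__init_subclass__` to do the wrapping, and `ToyModule`, `ToyParameterModule`, `ToyStateModule`, `_init_kind`, `_kinds` and `Array` are illustrative names. Kinds are plain strings for brevity.

```python
import functools
from contextlib import contextmanager


class Array:
    """Hypothetical stand-in for a JAX ndarray."""


class ToyModule:
    """Toy model of __init__ wrapping; NOT the real pax.Module."""

    _init_kind = "MODULE"  # default kind for attributes set in __init__

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if "__init__" in cls.__dict__:
            original_init = cls.__init__

            @functools.wraps(original_init)
            def wrapped_init(self, *args, **kw):
                # Run the user's __init__ inside the class's default-kind context.
                with self._default_kind(type(self)._init_kind):
                    original_init(self, *args, **kw)

            cls.__init__ = wrapped_init

    def __init__(self):
        pass

    @contextmanager
    def _default_kind(self, kind):
        kinds = self.__dict__.setdefault("_kinds", {})
        before = set(vars(self))
        yield
        # MODULE only picks up modules; PARAMETER/STATE only pick up arrays.
        wanted = ToyModule if kind == "MODULE" else Array
        for name, value in vars(self).items():
            if name not in before and name not in kinds and isinstance(value, wanted):
                kinds[name] = kind


class ToyParameterModule(ToyModule):
    _init_kind = "PARAMETER"


class ToyStateModule(ToyModule):
    _init_kind = "STATE"


class Net(ToyModule):
    def __init__(self):
        super().__init__()
        self.fc = ToyModule()
        self.lr = 0.1  # plain float: not registered


class Affine(ToyParameterModule):
    def __init__(self):
        super().__init__()
        self.weight = Array()
        self.bias = Array()


print(Net()._kinds)     # {'fc': 'MODULE'}
print(Affine()._kinds)  # {'weight': 'PARAMETER', 'bias': 'PARAMETER'}
```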


Lastly, PAX provides the `self.add_parameters()` and `self.add_states()` methods as shortcuts for `self._default_kind(PaxKind.PARAMETER)` and `self._default_kind(PaxKind.STATE)`, respectively.

In [6]:
class M7(pax.Module):
    def __init__(self):
        super().__init__()
        with self.add_parameters():
            self.weight = jnp.array(1.0)
            self.bias = jnp.array(0.0)

        with self.add_states():
            self.counter = jnp.array(0)


m7 = M7()
print("m7.pax", m7.pax)

m7.pax PaxModuleInfo[training=True, nodes={bias:PARAMETER, weight:PARAMETER, counter:STATE}]
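Since these methods are described as shortcuts, a toy version only needs to fix the `kind` argument of `_default_kind`. This is a hypothetical pure-Python sketch, not PAX's source: `ToyModule`, `Array` and `_kinds` are illustrative names, kinds are plain strings, and the type filter is reduced to "is an `Array`".

```python
from contextlib import contextmanager


class Array:
    """Hypothetical stand-in for a JAX ndarray."""


class ToyModule:
    """Toy model; NOT the real pax.Module."""

    def __init__(self):
        self._kinds = {}

    @contextmanager
    def _default_kind(self, kind):
        before = set(vars(self))
        yield
        for name, value in vars(self).items():
            new = name not in before and name not in self._kinds
            if new and isinstance(value, Array):
                self._kinds[name] = kind

    # The shortcuts simply fix the `kind` argument.
    def add_parameters(self):
        return self._default_kind("PARAMETER")

    def add_states(self):
        return self._default_kind("STATE")


m = ToyModule()
with m.add_parameters():
    m.weight = Array()
    m.bias = Array()
with m.add_states():
    m.counter = Array()
print(m._kinds)
# {'weight': 'PARAMETER', 'bias': 'PARAMETER', 'counter': 'STATE'}
```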
