# 📜 Stateful computations<a id="StatefulComputation"></a>

In this notebook, we demonstrate how to handle internal state in immutable `pytreeclass` classes using a functional API.

First, [under `jax.jit`, JAX requires state to be explicit](https://jax.readthedocs.io/en/latest/jax-101/07-state.html?highlight=state). This means that for any class instance, the variables need to be separated from the class and passed explicitly. However, when using `pytc.TreeClass`, there is no need to separate the instance variables; instead, the whole instance is passed as the state.

To give a concrete example, the following code demonstrates the difference between the _explicit parameters_ approach and the _class instance as state parameter_ approach:

<table>

<tr><td align="center">Passing explicit parameters</td><td align="center">Passing state as parameter</td></tr>

<tr>
<td>

```python
import jax 
import jax.numpy as jnp
import jax.random as jr
import jax.nn as jnn
import jax.tree_util as jtu

def init_params(layers):
  """Initialize MLP parameters as a list of per-layer ``{'W', 'B'}`` dicts.

  Args:
    layers: layer widths, e.g. ``[1, 5, 5, 1]``; each consecutive pair
      gives one layer's ``(n_in, n_out)``.

  Returns:
    A list of ``len(layers) - 1`` dicts where ``W`` has shape
    ``(n_in, n_out)`` (He-normal init) and ``B`` has shape ``(n_out,)``
    (uniform init).
  """
  keys = jr.split(jr.PRNGKey(0), len(layers) - 1)
  init_func = jnn.initializers.he_normal()
  params = []

  for key, n_in, n_out in zip(
    keys, layers[:-1], layers[1:]
  ):
    # Split the layer key so W and B draw from independent random streams;
    # reusing a single JAX key for both would correlate them.
    w_key, b_key = jr.split(key)
    W = init_func(w_key, (n_in, n_out))
    B = jr.uniform(b_key, shape=(n_out,))
    params.append({'W': W, 'B': B})
  return params

def fwd(params,x):
  """Forward pass: tanh on every hidden layer, linear output layer.

  Args:
    params: list of ``{'W', 'B'}`` dicts, one per layer, in forward order.
    x: input batch; each layer computes ``h @ W + B``.

  Returns:
    The output of the final affine layer (no activation applied).
  """
  *hidden_layers, output_layer = params
  h = x
  for p in hidden_layers:
    pre_activation = h @ p['W'] + p['B']
    h = jnn.tanh(pre_activation)
  return h @ output_layer['W'] + output_layer['B']

@jax.value_and_grad
def loss_func(params,x,y):
  """Mean-squared error of ``fwd(params, x)`` against ``y``.

  Wrapped by ``jax.value_and_grad``, so calling it returns ``(loss, grads)``
  where ``grads`` has the same structure as ``params``.
  """
  residual = fwd(params,x) - y
  return jnp.mean(residual**2)

@jax.jit
def update(params,x,y):
  """One jitted SGD step (learning rate 1e-3) on the MSE loss.

  Returns:
    ``(loss, new_params)``: the loss evaluated at the incoming ``params``
    and the parameters after one gradient step.
  """
  def sgd_step(p,g):
    return p - 1e-3*g

  # loss and gradient w.r.t. params in one pass
  loss, grads = loss_func(params,x,y)
  new_params = jtu.tree_map(sgd_step, params, grads)
  return loss, new_params

# Toy regression data: 100 points in [0, 1] as a (100, 1) column, target x**2 - 1.
x = jnp.linspace(0,1,100).reshape(100,1)
y = x**2 -1

# 1 input, four hidden layers of width 5, 1 output.
params = init_params([1] + [5]*4 + [1])

epochs = 10_000
for epoch in range(1, epochs + 1):
  value, params = update(params, x, y)

  # print loss and epoch info every 1,000 epochs
  if epoch % 1_000 == 0:
    print(f'Epoch={epoch}\tloss={value:.3e}')
```


</td>


<td>

```python
import jax 
import jax.numpy as jnp 
import jax.random as jr
import jax.nn as jnn
import jax.tree_util as jtu
import pytreeclass as pytc


class MLP(pytc.TreeClass):
  """Multi-layer perceptron whose parameters are stored on the instance.

  The whole instance is passed to jitted/differentiated functions as the
  state, instead of threading a separate ``params`` list around.
  """
  # One {'W', 'B'} dict per layer, in forward order.
  Layers : list

  def __init__(self,layers):
    """Initialize weights (He-normal) and biases (uniform) for each layer.

    Args:
      layers: layer widths; each consecutive pair gives one layer's
        ``(n_in, n_out)``.
    """
    keys = jr.split(jr.PRNGKey(0), len(layers) - 1)
    init_func = jnn.initializers.he_normal()
    self.Layers = []

    for key, n_in, n_out in zip(
      keys, layers[:-1], layers[1:]
    ):
      # Give W and B independent keys; reusing one JAX key for both
      # would correlate their random draws.
      w_key, b_key = jr.split(key)
      W = init_func(w_key, (n_in, n_out))
      B = jr.uniform(b_key, shape=(n_out,))
      self.Layers.append({'W': W, 'B': B})

  def __call__(self,x):
    """Apply tanh hidden layers followed by a linear output layer."""
    *hidden,last = self.Layers
    for layer in hidden :
      x = jnn.tanh(x@layer['W']+layer['B'])
    return x@last['W'] + last['B']

@jax.value_and_grad
def loss_func(model,x,y):
  """Mean-squared error of ``model(x)`` against ``y``.

  ``jax.value_and_grad`` differentiates w.r.t. ``model`` itself, so the
  gradient comes back as a pytree with the same structure as ``model``.
  """
  error = model(x) - y
  return jnp.mean(error**2)

@jax.jit
def update(model,x,y):
  """One jitted SGD step (learning rate 1e-3) applied to the model itself.

  Returns:
    ``(loss, new_model)``: the loss at the incoming model and a new model
    whose leaves have taken one gradient step.
  """
  # gradient w.r.t. the model (same pytree structure as ``model``)
  value, grads = loss_func(model,x,y)
  # ``tree_map`` works on any pytree, so this does not depend on the class
  # opting into leaf-wise arithmetic operators (``model - lr * grads``).
  model = jtu.tree_map(lambda p,g: p - 1e-3*g, model, grads)
  return value, model

# Toy regression data: 100 points in [0, 1] as a (100, 1) column, target x**2 - 1.
x = jnp.linspace(0,1,100).reshape(100,1)
y = x**2 -1

# 1 input, four hidden layers of width 5, 1 output.
model = MLP([1] + [5]*4 + [1])

epochs = 10_000
for epoch in range(1, epochs + 1):
  value, model = update(model, x, y)

  # print loss and epoch info every 1,000 epochs
  if epoch % 1_000 == 0:
    print(f'Epoch={epoch}\tloss={value:.3e}')
```

</td>

</tr>

</table>


Using the following pattern, state can be updated **functionally** under `jax.jit`:

In [1]:
import jax
import pytreeclass as pytc


class Counter(pytc.TreeClass):
    """Toy stateful object that counts how many times ``increment`` ran."""

    # number of ``increment`` calls applied so far; defaults to 0
    calls: int = 0

    def increment(self):
        """Increase ``calls`` by one; returns nothing.

        This mutates ``self``, so it is meant to be invoked functionally via
        ``counter.at["increment"]()``, which returns a new, updated instance.
        """
        self.calls = self.calls + 1


counter = Counter()  # Counter(calls=0): `calls` starts at its class default of 0

Here, we define the update function. Since the `increment` method mutates the internal state, we need to update the state functionally using `.at`. To achieve this, we can use `.at[method_name].__call__(*args, **kwargs)`; this functional call returns the value of the method call and a _new_ instance with the updated state.

In [2]:
@jax.jit
def update(counter):
    """Return a new counter with ``increment`` applied once.

    ``counter.at["increment"]()`` calls the method functionally and returns
    ``(method_return_value, new_instance_with_updated_state)``.
    """
    # ``increment`` has no return statement, so the first element is None
    # and is discarded.
    _, new_counter = counter.at["increment"]()
    return new_counter


# Each jitted call returns a new Counter; rebind it to carry the state forward.
for _ in range(10):
    counter = update(counter)

print(counter.calls)  # 10

10
