**S03P03_tutorial_stateful_computations_in_jax.ipynb**

Arz

2024 APR 22 (MON)

reference:
https://jax.readthedocs.io/en/latest/jax-101/07-state.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit, vmap
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# motivation

in machine learning, program state most often comes in the form of:
- model parameters
- optimizer state
- stateful layers
    - ex) batch normalization
 
changing program state is one kind of side effect. so, if we can’t have side effects, how do we update that state?

-> functional programming

# a simple example: counter

In [4]:
class Counter:
    def __init__(self):
        self.n = 0

    def count(self) -> int:
        self.n += 1
        return self.n

    def reset(self):
        self.n = 0

In [5]:
counter = Counter()

for _ in range(3):
    print(counter.count())

1
2
3


count() modifies the state n stored on the object, so self.n += 1 is a side effect.

## JIT test

In [6]:
counter.reset()

In [7]:
count_jit = jax.jit(counter.count)

for _ in range(3):
    print(count_jit())

1
1
1
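
the repeated 1 above is consistent with how jax.jit works: the Python body of count runs only while the function is being traced, and the compiled result is reused on later calls. the following is a minimal sketch of that behaviour; Traced_Counter and trace_log are hypothetical names introduced only for this illustration.

```python
import jax

# hypothetical helper: records every time the Python body of count() runs
trace_log = []

class Traced_Counter:
    def __init__(self):
        self.n = 0

    def count(self) -> int:
        # side effect on a Python list, invisible to the compiled code
        trace_log.append("python body executed")
        self.n += 1
        return self.n

traced_counter = Traced_Counter()
traced_count_jit = jax.jit(traced_counter.count)

# call the jitted method three times and collect the results as Python ints
results = [int(traced_count_jit()) for _ in range(3)]

print(results)           # the value seen during tracing is baked in as a constant
print(len(trace_log))    # how many times the Python body actually ran
print(traced_counter.n)  # the attribute is only incremented while tracing
```

the side effect happens once, at trace time, so the compiled function has no way to see or update self.n afterwards.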


# the solution: explicit state

in this new version of Counter, we moved n to be an argument of count, and added another return value that represents the new, updated, state. 

to use this counter, we now need to keep track of the state explicitly. but in return, we can now safely jax.jit this counter.

In [8]:
# typedef
Counter_State = int

# class
class Counter_V2:
    def count(self, n: Counter_State) -> tuple[int, Counter_State]:
        return n + 1, n + 1

    def reset(self) -> Counter_State:
        return 0

In [9]:
counter_V2 = Counter_V2()
counter_state = counter_V2.reset()

for _ in range(3):
    value, counter_state = counter_V2.count(counter_state)
    print(value)

1
2
3


## JIT test

In [10]:
counter_state = counter_V2.reset()

In [11]:
count_V2_jit = jax.jit(counter_V2.count)

for _ in range(3):
    value, counter_state = count_V2_jit(counter_state)
    print(value)

1
2
3


# a general strategy

**common functional programming pattern**
- the way that state is handled in all JAX programs.

we can apply the same process to any stateful method to convert it into a stateless one.

In [12]:
from typing import Any

In [13]:
State_Type = Any

## from

In [14]:
class My_Stateful_Class:
    def __init__(self, initial_state: State_Type):
        self.state = initial_state

    def my_stateful_method(self, *args, **kwargs) -> Any:
        # some operation on self.state
        value = ...  # placeholder for the method's output
        return value

## to

In [15]:
class My_Stateless_Class:
    def my_stateless_method(self, state: State_Type, *args, **kwargs) -> tuple[Any, State_Type]:  # (value, new state)
        # some operation on the argument state
        value = ...  # placeholder for the method's output
        return value, state
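
as a concrete instance of this conversion, here is a minimal sketch of a running mean written in the stateless style. Mean_State, init_mean and update_mean are hypothetical names chosen for this illustration, not part of JAX.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# hypothetical state container: everything the running mean needs to remember
class Mean_State(NamedTuple):
    total: jnp.ndarray
    count: jnp.ndarray

def init_mean() -> Mean_State:
    # the initial state is returned instead of being stored on an object
    return Mean_State(total=jnp.zeros(()), count=jnp.zeros(()))

@jax.jit
def update_mean(state: Mean_State, x) -> tuple[jnp.ndarray, Mean_State]:
    # build a new state instead of mutating the old one
    new_state = Mean_State(total=state.total + x, count=state.count + 1)
    # return (value, new state), matching the pattern above
    return new_state.total / new_state.count, new_state

mean_state = init_mean()
for sample in [2.0, 4.0, 6.0]:
    mean, mean_state = update_mean(mean_state, sample)

print(mean)  # running mean after the three samples
```

the caller owns the state and threads it through every call, which is what makes the function safe to jit.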

# simple worked example: linear regression

in this machine learning example, the only program state we deal with is:
- model parameters

In [16]:
from typing import NamedTuple

In [17]:
class Params(NamedTuple):
    weight: jnp.ndarray
    bias: jnp.ndarray

def initialize(key) -> Params:
    weight_key, bias_key = jax.random.split(key)
    weight = jax.random.normal(weight_key, ())
    bias = jax.random.normal(bias_key, ())
    return Params(weight, bias)

def loss_function(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    y_pred = params.weight*x + params.bias
    return jnp.mean((y - y_pred)**2)

learning_rate = 0.005

@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> tuple[Params, jnp.ndarray]:
    # compute the loss and its gradients with respect to params on the given batch
    loss, grads = jax.value_and_grad(loss_function)(params, x, y)

    # one step of gradient descent; tree_map applies the update to every leaf of params
    new_params = jax.tree_util.tree_map(lambda param, g: param - g*learning_rate, params, grads)

    return new_params, loss
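
the parameter update relies on the fact that a NamedTuple is a pytree, so a tree map can walk the parameters and their gradients leaf by leaf. a minimal sketch of that single step, with hand-picked numbers; Demo_Params, demo_grads and demo_learning_rate are hypothetical names and values used only here.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# hypothetical parameter container with the same fields as Params
class Demo_Params(NamedTuple):
    weight: jnp.ndarray
    bias: jnp.ndarray

demo_params = Demo_Params(weight=jnp.array(2.0), bias=jnp.array(-1.0))
# made-up gradients with the same structure as the parameters
demo_grads = Demo_Params(weight=jnp.array(10.0), bias=jnp.array(4.0))
demo_learning_rate = 0.5

# the lambda receives matching leaves from both trees: (weight, d_weight), (bias, d_bias)
stepped = jax.tree_util.tree_map(
    lambda param, g: param - demo_learning_rate*g, demo_params, demo_grads
)

print(stepped)  # a new Demo_Params; demo_params itself is left unchanged
```

the result has the same structure as the input, which is why update can return it directly as the new state.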

## data & setting

In [18]:
x = np.random.normal(size=(100,))
noise = np.random.normal(scale=0.1, size=(100,))

y = 3*x - 1 + noise

**initialize parameters.**

In [19]:
params = initialize(jax.random.key(123))

## training

In [21]:
for i in range(1000):
    params, loss = update(params, x, y)

    if i%100 == 0:
        print(f"epoch {i:3d}, loss: {loss:.3f}")
        print(loss)

epoch   0, loss: 10.810
10.809915
epoch 100, loss: 1.238
1.2383049
epoch 200, loss: 0.165
0.16494352
epoch 300, loss: 0.031
0.03065598
epoch 400, loss: 0.012
0.011662046
epoch 500, loss: 0.009
0.008665814
epoch 600, loss: 0.008
0.008154484
epoch 700, loss: 0.008
0.0080628805
epoch 800, loss: 0.008
0.0080460245
epoch 900, loss: 0.008
0.008042876
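
one way to sanity-check a training loop like this is to compare the learned parameters with a closed-form least-squares fit. the sketch below is self-contained and makes its own assumptions: it regenerates data of the same form with a fixed NumPy seed, starts from zero parameters instead of a random key, and uses the hypothetical names x_check, y_check, mse and sgd_step.

```python
import numpy as np
import jax
import jax.numpy as jnp

# synthetic data with the same structure as in this notebook (the seed is an assumption)
rng = np.random.default_rng(0)
x_check = rng.normal(size=(100,))
y_check = 3*x_check - 1 + rng.normal(scale=0.1, size=(100,))

def mse(params, x, y):
    weight, bias = params
    return jnp.mean((y - (weight*x + bias))**2)

@jax.jit
def sgd_step(params, x, y):
    # same update rule and learning rate as the update function above
    grads = jax.grad(mse)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - 0.005*g, params, grads)

# explicit state: the parameters are passed in and returned at every step
params_check = (jnp.zeros(()), jnp.zeros(()))
for _ in range(1000):
    params_check = sgd_step(params_check, x_check, y_check)

# closed-form least-squares line for comparison: returns (slope, intercept)
slope, intercept = np.polyfit(x_check, y_check, deg=1)

print(float(params_check[0]), slope)      # learned weight vs. least-squares slope
print(float(params_check[1]), intercept)  # learned bias vs. least-squares intercept
```

if the two pairs agree closely, the remaining loss is mostly the noise in the data rather than an unconverged fit.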


In [22]:
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'iframe'

In [23]:
fig = px.scatter(x=x, y=y)
fig_model = px.line(x=x, y=params.weight*x + params.bias)
fig_model.data[0].line.color = "#e02a19"
fig.add_trace(fig_model.data[0])

fig.show()

# taking it further

❓ how do we handle many parameters, optimizer state, and stateful layers? see the neural network libraries built on JAX:

https://github.com/google/jax#neural-network-libraries