**S03P03_tutorial_stateful_computations_in_jax.ipynb**

Arz

2024 APR 22 (MON)

reference:
https://jax.readthedocs.io/en/latest/jax-101/07-state.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit, vmap
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# motivation

in machine learning, program state most often comes in the form of:
- model parameters
- optimizer state
- stateful layers
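    - e.g. batch normalization

changing program state is one kind of side effect. JAX transformations such as jax.jit expect pure functions, so if we can't have side effects, how do we update the state?

-> functional programming: pass the state in as an argument and return the updated state.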

# a simple example: counter

In [4]:
class Counter:
    def __init__(self):
        self.n = 0

    def count(self) -> int:
        self.n += 1
        return self.n

    def reset(self):
        self.n = 0

In [5]:
counter = Counter()

for _ in range(3):
    print(counter.count())

1
2
3


count() modifies the state self.n in place, so self.n += 1 is a side effect.

## JIT test

In [6]:
counter.reset()

In [7]:
count_jit = jax.jit(counter.count)

for _ in range(3):
    print(count_jit())

1
1
1
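
the jitted counter prints 1 every time because jax.jit runs the Python body only while tracing, and later calls reuse the compiled result. the sketch below is a minimal illustration of that: TracedCounter and its trace_calls attribute are hypothetical names added here only to count how often the Python body actually runs.

```python
import jax

# hypothetical variant of Counter that also records how often its body runs
class TracedCounter:
    def __init__(self):
        self.n = 0
        self.trace_calls = 0

    def count(self) -> int:
        self.trace_calls += 1  # side effect: only happens while jax.jit traces
        self.n += 1            # side effect: also only happens during tracing
        return self.n          # a plain Python int, baked in as a constant

traced = TracedCounter()
count_jit = jax.jit(traced.count)

# call the compiled function three times
values = [int(count_jit()) for _ in range(3)]

print(values)              # [1, 1, 1]
print(traced.trace_calls)  # 1 -> the Python body ran once, during tracing
print(traced.n)            # 1 -> the state was updated once, not three times
```

the function takes no arguments, so every call matches the first trace and nothing triggers a retrace. the side effect is silently dropped after the first call.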


# the solution: explicit state

in this new version of Counter, n becomes an argument of count, and count returns a second value that represents the new, updated state.

to use this counter, the caller must now keep track of the state explicitly. in return, count is a pure function, so we can safely jax.jit it.

In [10]:
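# type alias: the counter's state is a plain int
Counter_State = int

# stateless class: the state is passed in and returned explicitly
class Counter_V2:
    def count(self, n: Counter_State) -> tuple[int, Counter_State]:
        # returns (output value, new state); the two happen to be equal here
        return n + 1, n + 1

    def reset(self) -> Counter_State:
        # returns the initial state
        return 0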

In [12]:
counter_V2 = Counter_V2()
counter_state = counter_V2.reset()

for _ in range(3):
    value, counter_state = counter_V2.count(counter_state)
    print(value)

1
2
3


## JIT test

In [13]:
counter_state = counter_V2.reset()

In [14]:
count_V2_jit = jax.jit(counter_V2.count)

for _ in range(3):
    value, counter_state = count_V2_jit(counter_state)
    print(value)

1
2
3
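
the same explicit-state pattern carries over to the model parameters mentioned in the motivation. the sketch below is a minimal illustration, not a reference implementation: the toy linear model, the data, the names loss and update, and LEARNING_RATE are all assumptions made up for this example.

```python
import jax
import jax.numpy as jnp

# hypothetical toy model: y ~ w * x + b, with the params held in a dict (a pytree)
def loss(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

LEARNING_RATE = 0.1  # assumed value, chosen only for this toy data

@jax.jit
def update(params, x, y):
    # pure function: reads the old state and returns the new state, mutates nothing
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )

# made-up data generated from w = 2, b = 1
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0

params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
initial_loss = loss(params, x, y)

for _ in range(100):
    # thread the state through the loop by hand, like counter_state above
    params = update(params, x, y)

final_loss = loss(params, x, y)
print(initial_loss, final_loss)
```

here params plays the role of counter_state: update never modifies its input, so it is safe under jax.jit, and the caller decides what to do with the returned state.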
