### Necessary Packages

In [1]:
import jax
import jax.numpy as jnp
import numpy as np

from jax import grad, jit, vmap, pmap

from jax import random
import matplotlib.pyplot as plt
from copy import deepcopy
from typing import Tuple, NamedTuple
import functools

## Problem of "State"

In [3]:
# We've seen that impure functions are problematic:
# jit only reads outside state (like the global `g`) while tracing.

g = 0
# Acts as a state (a module-level global read by the function below)

def impure_uses_global(x):
    """Return ``x + g``; impure because it reads the global ``g``."""
    return x+g

# First call: jit traces the function, reads g == 0 and compiles that in.
print("First Call: ", jit(impure_uses_global)(2.))

print("Updating g from ", g," to 10")
g = 10

# Still prints 2.0 (see output below), not 12.0. The value of g read during
# the first trace was baked into the compiled code as a constant. JAX caches
# compilations per function and argument shape/dtype, so this call reuses the
# cached version instead of re-tracing, and the update to g is ignored.
print("Second Call: ", jit(impure_uses_global)(2.))

First Call:  2.0
Updating g from  0  to 10
Second Call:  2.0


### Need to explicitly address the problem of state

##### Why?
- Neural networks are full of state!
- Model parameters, optimizer parameters/state, etc.
- JAX transformations such as `jit` assume pure functions, so implicit state breaks them (as shown above)

In [4]:
# Stateful class: the count lives in a mutable instance attribute, which is
# exactly the kind of hidden state that jit cannot see.
class Counter:
    """A simple counter that keeps its state in the mutable attribute ``n``."""

    def __init__(self):
        # ``n`` is the hidden, mutable state; start it at zero.
        self.reset()

    def count(self) -> int:
        """Increment the stored count in place and return the new value."""
        self.n = self.n + 1
        return self.n

    def reset(self):
        """Set the stored count back to zero (mutates ``self``)."""
        self.n = 0

In [5]:
# Plain Python (no jit): each call mutates counter.n in place,
# so successive calls print 1, 2, 3.
counter = Counter()

for _ in range(3):
    print(counter.count())

1
2
3


In [6]:
# Start again from n == 0 (the previous loop left counter.n at 3).
counter.reset()

# jit of a bound method: the Python body (including `self.n += 1`) runs only
# while tracing on the first call; the compiled version just returns the
# value observed during that trace.
fast_count = jit(counter.count)

# Prints 1 three times (see output): the in-place increment is not part of
# the compiled computation, so every call returns the traced constant 1.
for _ in range(3):
    print(fast_count())

1
1
1


- This does not work as expected: every call prints 1 instead of 1, 2, 3

- Let's inspect the jaxpr to see what's happening

In [8]:
# Reset so the trace starts from n == 0, then print the traced program.
counter.reset()
# The printed jaxpr takes no inputs and returns the constant 1 (see output):
# the increment of self.n happened in Python during tracing and is not
# recorded in the jaxpr.
print(jax.make_jaxpr(counter.count)())

{ lambda ; . let  in (1:i32[],) }


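- When `jit` traces the function, it runs the Python code exactly once
- During that single run, `self.n` is incremented from 0 to 1
- The trace records only the returned value, `1`, as a constant (the jaxpr above has no inputs and just returns `1`); the side effect on `self.n` is not recorded
- The compiled function therefore contains nothing but that constant
- So every later call returns 1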

In [9]:
# Solution:
# Represent the counter's state as a plain value (an int) that is passed in
# and returned explicitly, instead of hiding it inside the object.

CounterState = int


class CounterV2:
    """A stateless counter: the caller owns the state and threads it through."""

    def count(self, n: CounterState) -> Tuple[int, CounterState]:
        """Advance the counter by one.

        Args:
            n: the current counter state.

        Returns:
            A ``(value, new_state)`` pair. Both equal ``n + 1``; they are
            computed separately only to illustrate the general pattern of
            returning an output together with the updated state.
        """
        value = n + 1
        new_state = n + 1
        return value, new_state

    def reset(self) -> CounterState:
        """Return the initial counter state (0) instead of storing it."""
        return 0

In [11]:
# Note: this rebinds `counter` (previously a Counter) to the stateless class.
counter = CounterV2()
state = counter.reset()
# Notice how reset now returns the state instead of storing it on the object;
# the caller keeps it in `state` and must pass it back in explicitly.

In [12]:
# Thread the state explicitly: each call takes the current state and returns
# (value, new_state); rebinding `state` carries it into the next iteration.
for _ in range(3):
    value, state = counter.count(state)
    print(value)

1
2
3
