<a href="https://colab.research.google.com/github/Suvoo/Daad-Wise-prep/blob/main/JaxQuickstart-2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### The goal of this notebook is to gain the knowledge needed to build complex ML models (such as NNs) and train them in parallel on multiple devices! 💻💻💻

In [None]:
# Let's import the necessary packages
import jax
import jax.numpy as jnp
import numpy as np

from jax import grad, jit, vmap, pmap  # pmap parallelizes across devices

from jax import random
import matplotlib.pyplot as plt
from copy import deepcopy
from typing import Tuple, NamedTuple
import functools

## The Problem of State
JAX ❤️ Pure Functions => JAX "!❤️" State.

In [None]:
# 1) We've seen in the last notebook/video that impure functions are problematic.

g = 0.  # state

# We're accessing some external state in this function which causes problems
def impure_uses_globals(x):
    return x + g

# JAX captures the value of the global/state during the first run
print ("First call: ", jit(impure_uses_globals)(4.))

# Let's update the global/state!
g = 10.

# Subsequent runs may silently use the cached value of the globals/state
print ("Second call: ", jit(impure_uses_globals)(5.))



First call:  4.0
Second call:  5.0
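
For contrast, here is a minimal sketch of a pure variant of the function above, assuming we are free to change its signature. The name `pure_uses_argument` is hypothetical; the only change is that the former global is passed in as an argument, so `jit` treats it as a traced input instead of a constant captured at trace time.

```python
import jax

# Hypothetical pure counterpart of impure_uses_globals:
# g is an explicit input rather than a global read at trace time.
def pure_uses_argument(x, g):
    return x + g

fast_add = jax.jit(pure_uses_argument)

first = fast_add(4.0, 0.0)    # g = 0.
second = fast_add(5.0, 10.0)  # the updated g is picked up, as it is a real input

print("First call: ", first)
print("Second call: ", second)
```

Because both calls pass scalars of the same type, the second call reuses the compiled function and still sees the new value of `g`.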


In [None]:
# 2) We've also seen how JAX's PRNG handles state: unlike NumPy's PRNG,
# it is not stateful, so the key (the state) is passed around explicitly.

seed = 0
state = jax.random.PRNGKey(seed)

# We pass the state in, derive new states from it, and get them back as return values.
# Nothing is saved internally.
state1, state2 = jax.random.split(state)  # recall: key/subkey was the terminology we used
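
A short sketch of how the split keys are typically consumed, assuming the usual key/subkey convention. The variable names are illustrative only. Reusing a subkey reproduces the same numbers, while splitting again yields fresh ones.

```python
import jax

key = jax.random.PRNGKey(0)

# Split off a subkey to consume; keep `key` for future splits.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))
a_again = jax.random.normal(subkey, (3,))  # same subkey, so the same numbers

# Split again to obtain a fresh subkey and therefore fresh numbers.
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (3,))

print(a, a_again, b)
```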

In [None]:
# Let's now explicitly address and understand the problem of state!
# Why? 
# Well, NNs love statefulness: model params, optimizer params, BatchNorm, etc.
# and we've seen that JAX seems to have a problem with it.

class Counter:
    """A simple counter."""
    def __init__(self):
        self.n = 0
    def count(self) -> int:
        """Increments the counter and returns the new value."""
        self.n += 1
        return self.n
    def reset(self):
        """Resets the counter to zero."""
        self.n = 0
    
counter =  Counter()

for _ in range(3): # works fine
    print(counter.count())



1
2
3


In [None]:
counter.reset()
fast_count = jit(counter.count)

for _ in range(3):  # does not behave like the plain Python version
    print(fast_count())  # count is impure: jit traces it once and caches the result "1"

1
1
1


In [None]:
from jax import make_jaxpr # use jaxpr to understand why this is happening

counter.reset()
print(make_jaxpr(counter.count)())  # the jaxpr just returns the constant 1

{ lambda ; . let  in (1,) }


In [None]:
counter.reset()
counter.count() # --> modifies state to 1
fast_count = jit(counter.count)

for _ in range(3):  # still wrong: the value seen during tracing (2) is baked in
    print(fast_count())

2
2
2


In [None]:
from jax import make_jaxpr # use jaxpr to understand why this is happening

counter.reset()
counter.count()
print(make_jaxpr(counter.count)())  # the jaxpr just returns the constant 2

{ lambda ; . let  in (2,) }
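
The constant-only jaxpr above reflects that the Python body runs only while JAX traces the function. The sketch below makes that visible with a hypothetical side effect (appending to a list named `trace_log`, introduced here purely for illustration): repeated calls with the same input type do not rerun the Python code, while a new input shape triggers a retrace.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical list, used only to observe when tracing happens

def impure_logs(x):
    trace_log.append("traced")  # Python side effect: runs only during tracing
    return x * 2.0

fast_logs = jax.jit(impure_logs)

# Three calls with the same input type: traced once, then served from the cache.
for value in (1.0, 2.0, 3.0):
    fast_logs(value)
calls_after_scalars = len(trace_log)

# A different input shape forces JAX to trace the function again.
fast_logs(jnp.ones(3))
calls_after_array = len(trace_log)

print(calls_after_scalars, calls_after_array)
```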


Solution: make the state an explicit argument so the function stays pure.

In [None]:
CounterState = int

class CounterV2:
    def count(self, n: CounterState) -> Tuple[int, CounterState]:
        # You could just return n + 1, but here we separate its role as
        # the output from its role as the counter state for didactic purposes
        # (in general, the output may be an arbitrary function of the state).
        return n + 1, n + 1

    def reset(self) -> CounterState:
        return 0

counter = CounterV2()
state = counter.reset()  # reset() now returns the state: it lives outside the object

for _ in range(3):
    value, state = counter.count(state)
    print(value, state)

1 1
2 2
3 3


In [None]:
state = counter.reset()
fast_count = jit(counter.count)

for _ in range(3):
    value, state = fast_count(state)
    print(value, state)

1 1
2 2
3 3


In [None]:
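from jax import make_jaxpr  # inspect the jaxpr to see why this version works

counter.reset()  # no effect here: CounterV2 holds no internal state
print(make_jaxpr(counter.count)(10))  # the state is now a traced input, not a baked-in constant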

{ lambda ; a:i32[]. let b:i32[] = add a 1; c:i32[] = add a 1 in (b, c) }


In summary, we used the following rule to convert a stateful class:

```python
class StatefulClass:

    state: State

    def stateful_method(self, *args, **kwargs) -> Output:
        ...
```

into a class of the form:

```python
class StatelessClass:

    def stateless_method(self, state: State, *args, **kwargs) -> Tuple[Output, State]:
        ...
```
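
The rule above can be sketched with a slightly richer state than a single integer. This is an illustrative example only: `RunningMeanState`, `init_state`, and `update` are hypothetical names, and the state is a `NamedTuple`, which JAX treats as a pytree and can therefore pass through `jit`.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp


class RunningMeanState(NamedTuple):
    """Hypothetical explicit state: how many values were seen and their sum."""
    count: jnp.ndarray
    total: jnp.ndarray


def init_state() -> RunningMeanState:
    # Plays the role of reset(): it returns a fresh state instead of mutating one.
    return RunningMeanState(count=jnp.array(0.0), total=jnp.array(0.0))


@jax.jit
def update(state: RunningMeanState, x) -> Tuple[jnp.ndarray, RunningMeanState]:
    # Pure function: (state, input) -> (output, new state), no mutation.
    new_state = RunningMeanState(count=state.count + 1.0, total=state.total + x)
    mean = new_state.total / new_state.count
    return mean, new_state


state = init_state()
for x in [2.0, 4.0, 6.0]:
    mean, state = update(state, x)
    print(mean, state.count)
```

The caller owns the state and threads it through every call, exactly as with `CounterV2`; the output (`mean`) is here a genuine function of the state rather than a copy of it.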

Nice! We found an equivalent way to handle state without introducing side effects.

This brings us 1 step closer to building neural networks! 🥳

We still need to find a way to handle gradients when dealing with big NNs.