### Necessary Packages

In [1]:
import jax
import jax.numpy as jnp
import numpy as np

from jax import grad, jit, vmap, pmap

from jax import random
import matplotlib.pyplot as plt
from copy import deepcopy
from typing import Tuple, NamedTuple
import functools

## Problem of "State"

In [3]:
# We've seen that impure functions are problematic. Here the function reads a
# global that is not one of its arguments, so jit cannot see changes to it.

g = 0
# Acts as a state: read by impure_uses_global, but never passed in as an argument

def impure_uses_global(x):
    """Return x + g, where g is read from the global scope (an impure dependency)."""
    # Under jit this body only runs while tracing; the value g has at that
    # moment is baked into the compiled computation as a constant.
    return x+g

print("First Call: ", jit(impure_uses_global)(2.))

print("Updating g from ", g," to 10")
g = 10

# Still prints 2.0 (see output below): jit reuses the trace it already made for
# this same function, so the new value of g is never read.
print("Second Call: ", jit(impure_uses_global)(2.))

First Call:  2.0
Updating g from  0  to 10
Second Call:  2.0


### We need to address the problem of state explicitly

##### Why? <br>
- Neural networks are full of state!
- Model parameters, optimizer state, etc.
- But JAX transformations such as `jit` assume pure functions, so hidden state is a major problem

In [4]:
# Stateful Class
class Counter:
    """A counter whose running total lives in mutable instance state (``self.n``)."""

    def __init__(self):
        # The hidden state: nothing outside the object passes it in or gets it back.
        self.n = 0

    def count(self) -> int:
        """Bump the stored total by one and return the new total."""
        self.n = self.n + 1
        return self.n

    def reset(self):
        """Put the stored total back to zero."""
        self.n = 0

In [5]:
counter = Counter()

# Plain Python: every call mutates counter.n, so the values climb 1, 2, 3.
for _ in range(3):
    tick = counter.count()
    print(tick)

1
2
3


In [6]:
counter.reset()

# jit-compile the bound method; the tracer cannot see the hidden self.n.
fast_count = jit(counter.count)

for _ in range(3):
    tick = fast_count()
    print(tick)

1
1
1


- This is not working as expected: every call returns 1 instead of counting up

- Let's inspect the jaxpr (the program JAX records while tracing) to see what's happening

In [8]:
counter.reset()
# make_jaxpr traces count once (running self.n += 1 as a Python side effect)
# and returns the traced program, which we then display.
count_jaxpr = jax.make_jaxpr(counter.count)()
print(count_jaxpr)

{ lambda ; . let  in (1:i32[],) }


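- When `jit` traces the function, it runs the Python body once
- During that run, `self.n` is incremented to 1
- The trace records only the returned value, the constant 1 (see the jaxpr above); the update to `self.n` is a Python side effect and is not part of the trace
- So the compiled function amounts to "return 1"
- Every later call reuses that compiled function and keeps returning 1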

In [9]:
# Solution:
# Represent the counter's state as a plain value (an int) that is passed in
# and returned explicitly, instead of being stored inside the object.

CounterState = int

class CounterV2:
    """A stateless counter: each method takes and/or returns the state explicitly."""

    def count(self, n: CounterState) -> Tuple[int, CounterState]:
        """Advance the counter from state ``n``.

        Returns ``(value, new_state)``. Both equal ``n + 1``; they are kept as
        separate outputs to make the roles of "result" and "next state"
        explicit for teaching purposes.
        """
        advanced = n + 1
        return advanced, advanced

    def reset(self) -> CounterState:
        """Return the initial counter state (zero)."""
        initial: CounterState = 0
        return initial

- Now, instead of hiding state inside objects or global variables, we pass it around explicitly as function arguments and return values

In [11]:
# reset() now hands back the initial state instead of storing it on the object.
counter = CounterV2()
state = counter.reset()

In [12]:
for _ in range(3):
    # Thread the state: feed the current state in, keep the returned one.
    step_result = counter.count(state)
    value, state = step_result
    print(value)

1
2
3


- By making state a function parameter, JAX can see it!
- Each call takes the current state and returns the new state.

In [13]:
# count only depends on its argument, so jit-compiling it is safe.
fast_count = jax.jit(counter.count)
state = counter.reset()

for _ in range(3):
    value, state = fast_count(state)
    print(value)

1
2
3


##### We have now found a way to handle state without side effects

##### We still need a way to handle gradients when dealing with big neural networks

##### Enter:

## PyTree

##### Why are gradients a problem in the first place?

In [14]:
def f(x, y, z, w):
    """Return the sum of squares x**2 + y**2 + z**2 + w**2."""
    # Imagine this, but with a billion parameters
    return x**2 + y**2 + z**2 + w**2

In [None]:
# Passing every parameter as its own positional argument does not scale:
# each one needs an entry in `argnums` and its own output variable.

x, y, z, w = [1.]*4
dfdx, dfdy, dfdz, dfdw = grad(f, argnums=(0,1,2,3))(x,y,z,w)
# With a billion parameters this would turn into something like:
# dfdx1, dfdx2, ..., dfdx_billion = grad(f, argnums=(0, 1, 2, ..., 999_999_999))(x1, x2, ..., x_billion)
print(dfdx, dfdy, dfdz, dfdw)

2.0 2.0 2.0 2.0


- There is a better way: PyTrees

#### We want to group our parameters naturally in nested containers such as lists, tuples and dictionaries
#### JAX knows how to work with such nested structures: it calls them PyTrees

In [19]:
# A deliberately mixed bag of containers, used only to illustrate PyTrees.
pytree_eg = [
    [1, 'a', object()],               # flat list: the str and the object() are leaves too
    (1, (2, 3), ()),                  # nested tuples: the empty tuple contributes no leaves
    [1, {"k1": 2, "k2": (3, 4)}, 5],  # list holding a dict holding a tuple
    {'a': 2, 'b': (2, 3)},            # dict with a tuple value
    jnp.array([1, 2, 3]),             # a whole array counts as a single leaf
]

In [20]:
# Count and list the leaves of each example PyTree
for pytree in pytree_eg:
    leaves = jax.tree.leaves(pytree)
    label = repr(pytree)
    print(f"{label:<45} has {len(leaves)} leaves: {leaves}")

[1, 'a', <object object at 0x00000265179306B0>] has 3 leaves: [1, 'a', <object object at 0x00000265179306B0>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]


#### How do we manipulate PyTrees?

In [39]:
list_of_lists = [
    {'a':3},
    [1,2,3],
    [1,2],
    [1,2,3,4],
]

# jax.tree.map applies a function to every leaf and rebuilds a PyTree with the
# same structure. With one tree the function takes one argument; given several
# trees, it receives the matching leaf from each of them.
print(jax.tree.map(lambda x: x**2, list_of_lists))

[{'a': 9}, [1, 4, 9], [1, 4], [1, 4, 9, 16]]


In [40]:
# another_list is the *same* object as list_of_lists (an alias, not a copy),
# so this adds every leaf to itself. With several trees, the mapped function
# receives one leaf from each tree at matching positions.
another_list = list_of_lists
print(jax.tree.map(lambda x, y: x + y, list_of_lists, another_list))

[{'a': 6}, [2, 4, 6], [2, 4], [2, 4, 6, 8]]


#### When tree.map is given several PyTrees, they must all have the same structure

In [41]:
# deepcopy so that appending below does not also modify list_of_lists
# (a plain assignment would only create an alias).
another_list = deepcopy(list_of_lists)
another_list.append([23])  # now 5 top-level children vs. 4: the structures differ
try:
    print(jax.tree.map(lambda x, y: x + y, list_of_lists, another_list))
except ValueError as err:
    # Expected: tree.map rejects PyTrees whose structures don't match.
    # Show the error instead of raising, so "Run All" continues past this cell.
    print(f"{type(err).__name__}: {err}")

ValueError: List arity mismatch: 5 != 4; list: [{'a': 3}, [1, 2, 3], [1, 2], [1, 2, 3, 4], [23]].

## Now we have everything we need to train a beginner MLP Model using JAX!