# Jax's Loops

By the end of this lesson, you'll be able to explain why and when to use `jax`'s native `while_loop`, `fori_loop`, and `scan` instead of Python's native loops. Along the way, you'll learn how to read Haskell-like type signatures, which will be useful as you explore the `jax` library further.

In [1]:
import numpy as np
from typing import TypeAlias
import time
import jax.numpy as jnp
import jax
from tqdm.notebook import tqdm
np.random.seed(42)

In [2]:
from hyperparameters import (
    _dt,
    _t_max,
    _tau_m,
    _V_reset,
    _V_thresh,
    _R,
    num_simulations
)

with open('weights.npy', 'rb') as f:
    W = np.load(f)

# Initial conditions
n_neurons = len(W)  # Number of neurons in the network
_V = jnp.ones(n_neurons) * _V_reset  # Initial potentials

# Type Definitions for Clarity

In [3]:
Tensor3D: TypeAlias = jnp.ndarray
Mat: TypeAlias = jnp.ndarray
Vec: TypeAlias = jnp.ndarray 

# Haskell-like signatures

Type signatures are a compact way to see what a function takes and what it returns without reading its body. Let's walk through a few examples:

## Examples to work through

```haskell
map :: (a -> b) -> [a] -> [b]
```

```haskell
sum :: Num a => [a] -> a
```

```haskell
(++) :: [a] -> [a] -> [a]
```

```haskell
filter :: (a -> Bool) -> [a] -> [a]
```
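One way to get comfortable with these signatures is to write them as Python type hints. The sketch below is illustrative only: `map_list`, `filter_list`, and `concat` are hypothetical helper names (not part of any library), and `A` and `B` play the role of the lowercase type variables `a` and `b`.

```python
from typing import Callable, TypeVar

A = TypeVar("A")
B = TypeVar("B")

# map :: (a -> b) -> [a] -> [b]
def map_list(f: Callable[[A], B], xs: list[A]) -> list[B]:
    # Apply f to every element; the element type may change from A to B.
    return [f(x) for x in xs]

# filter :: (a -> Bool) -> [a] -> [a]
def filter_list(pred: Callable[[A], bool], xs: list[A]) -> list[A]:
    # Keep only the elements for which the predicate holds; the type stays A.
    return [x for x in xs if pred(x)]

# (++) :: [a] -> [a] -> [a]
def concat(xs: list[A], ys: list[A]) -> list[A]:
    # Join two lists that hold the same element type.
    return xs + ys

lengths = map_list(len, ["jax", "scan"])
evens = filter_list(lambda n: n % 2 == 0, [1, 2, 3, 4])
joined = concat([1, 2], [3])
print(lengths, evens, joined)
```

Reading right to left, the last type in a signature is the result and everything before it is an argument; an argument in parentheses is itself a function.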

## The Jax Functions we will be covering:

```haskell
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
```

```haskell
fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
```

```haskell
scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
```

## Jax's Loops

```python
def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

def fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
```
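The pure-Python definitions above describe the semantics only. As a minimal sketch of calling the real `jax.lax` versions, the toy problems below (doubling a number, summing indices, a running sum) are made up for illustration and are not part of the lesson's simulation.

```python
import jax
import jax.numpy as jnp

# while_loop :: (a -> Bool) -> (a -> a) -> a -> a
# Keep doubling the value while it is still below 100.
doubled = jax.lax.while_loop(lambda v: v < 100, lambda v: v * 2, jnp.int32(1))

# fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
# Sum the indices 0..9; the body receives the index and the running value.
total = jax.lax.fori_loop(0, 10, lambda i, acc: acc + i, jnp.int32(0))

# scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
# Running sum: the carry is the sum so far, and every step also emits it.
def running_sum(carry, x):
    carry = carry + x
    return carry, carry

final, partial_sums = jax.lax.scan(
    running_sum, jnp.float32(0.0), jnp.arange(1.0, 5.0)
)

print(doubled, total, final, partial_sums)
```

Note that in all three cases the value threaded through the loop must keep the same shape and dtype from one iteration to the next.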

# Looping in Jax

As mentioned before, you probably don't want to `jit` a function that contains a native Python `for` loop: the loop is unrolled during tracing, which increases your compilation time. Thankfully, `jax` provides:

- [jax.lax.while_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html#jax.lax.while_loop)
- [jax.lax.fori_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html)

to circumvent this issue. 

Note: these functions don't necessarily speed up runtime (although that can happen). Their primary advantage is that compilation time can be reduced. Below is an example of the `jax` `fori_loop` in action.
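One rough way to see where the compilation cost comes from is to compare the size of the traced program for both styles. The sketch below uses `jax.make_jaxpr` and counts top-level equations as a proxy for compilation work; the functions `python_loop` and `lax_loop` and the update `x * 1.01 + 1.0` are hypothetical examples, and equation count is only an approximation of actual compile time.

```python
import jax
import jax.numpy as jnp

N_ITERS = 100

def python_loop(x):
    # A native Python loop is unrolled while tracing:
    # every iteration adds its own operations to the traced program.
    for _ in range(N_ITERS):
        x = x * 1.01 + 1.0
    return x

def lax_loop(x):
    # fori_loop traces the body once and records a single loop primitive.
    return jax.lax.fori_loop(0, N_ITERS, lambda i, v: v * 1.01 + 1.0, x)

x0 = jnp.float32(1.0)
n_unrolled = len(jax.make_jaxpr(python_loop)(x0).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(lax_loop)(x0).jaxpr.eqns)

print(f"Python loop: {n_unrolled} top-level equations")
print(f"fori_loop:   {n_rolled} top-level equations")
```

The unrolled version grows with `N_ITERS`, while the `fori_loop` version stays the same size no matter how many iterations are requested.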

# Lotka Volterra

Let's consider a simple example using the following equations:

![](../assets/lotka_volterra.png)
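
For reference, the system that the code below implements can be read off `lotka_volterra_step`. Taking $x$ as the prey population and $y$ as the predator population (an interpretation assumed from the signs of the terms, not stated in the original), it is:

$$
\frac{dx}{dt} = \alpha x - \beta x y, \qquad \frac{dy}{dt} = \delta x y - \gamma y
$$

The code advances this system with a forward-Euler step of size $\Delta t$ (`dt` in the code):

$$
x_{t+1} = x_t + \Delta t \, (\alpha x_t - \beta x_t y_t), \qquad y_{t+1} = y_t + \Delta t \, (\delta x_t y_t - \gamma y_t)
$$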

In [6]:
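import jax
import jax.numpy as jnp

@jax.jit
def lotka_volterra_step(state, params):
    """Advance the two populations by one forward-Euler step."""
    x, y = state
    alpha, beta, gamma, delta, dt = params

    # Lotka-Volterra derivatives
    dxdt = alpha * x - beta * x * y
    dydt = delta * x * y - gamma * y

    # Forward-Euler update
    x_new = x + dxdt * dt
    y_new = y + dydt * dt

    return x_new, y_new

@jax.jit
def body_func(i, val):
    """Loop body for `fori_loop`: step the system and record the state at index i."""
    x, y, trajectory = val
    # `params` is read from the enclosing (global) scope when this function is traced
    x, y = lotka_volterra_step((x, y), params)
    trajectory = trajectory.at[i].set(jnp.array([x, y]))
    return x, y, trajectory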

# Parameters
alpha = 1.1
beta = 0.4
gamma = 0.4
delta = 0.1
dt = 0.1
num_steps = 20

# Initial populations
x_prev = 10.0
y_prev = 5.0

params = (alpha, beta, gamma, delta, dt)
trajectory = jnp.zeros((num_steps, 2))
trajectory = trajectory.at[0].set([x_prev, y_prev])

"""
TODO: convert the following explicit for-loop into a jax `fori_loop`
    fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
"""

# fori_loop returns the final loop state, i.e. the whole tuple (x, y, trajectory)
final_state = jax.lax.fori_loop(1, num_steps, body_func, (x_prev, y_prev, trajectory))
print(final_state)
# for i in range(1, num_steps):
#     x_new, y_new, trajectory = body_func(i, (x_prev, y_prev, trajectory))
#     x_prev, y_prev = x_new, y_new
# print(trajectory)

(Array(0.83927435, dtype=float32, weak_type=True), Array(4.9796286, dtype=float32, weak_type=True), Array([[10.        ,  5.        ],
       [ 9.1       ,  5.3       ],
       [ 8.171801  ,  5.5703    ],
       [ 7.249923  ,  5.802682  ],
       [ 6.3646545 ,  5.991265  ],
       [ 5.539473  ,  6.1329374 ],
       [ 4.7898855 ,  6.227352  ],
       [ 4.1236405 ,  6.276541  ],
       [ 3.5419528 ,  6.2843018 ],
       [ 3.0412197 ,  6.2555165 ],
       [ 2.6147778 ,  6.19554   ],
       [ 2.2544048 ,  6.109718  ],
       [ 1.9514382 ,  6.003067  ],
       [ 1.6975118 ,  5.8800907 ],
       [ 1.4849771 ,  5.7447023 ],
       [ 1.3070946 ,  5.6002216 ],
       [ 1.1580743 ,  5.449413  ],
       [ 1.0330294 ,  5.2945447 ],
       [ 0.92788583,  5.137457  ],
       [ 0.83927435,  4.9796286 ]], dtype=float32))


# Jax Scan

As we saw in the type signature, `scan` lets us carry values between iterations while also collecting an output at every step. That fits our application well: the membrane potentials are the carry, and the spikes are the per-step outputs.
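
When the step function does not consume a per-step input, `scan` can be called with `xs=None` and an explicit `length`. The following is a minimal sketch on a made-up problem (repeated halving); `decay_step` and its values are hypothetical and unrelated to the neuron model below.

```python
import jax
import jax.numpy as jnp

def decay_step(carry, _):
    # `_` is the per-step input; it is None here because xs=None.
    value, rate = carry
    new_value = value * rate
    # Return the next carry and the per-step output that scan will stack.
    return (new_value, rate), new_value

init = (jnp.float32(1.0), jnp.float32(0.5))
(final_value, _), history = jax.lax.scan(decay_step, init, xs=None, length=4)

print(final_value)  # last carried value
print(history)      # one entry per step, stacked along the leading axis
```

Constants such as `rate` can ride along in the carry unchanged, which is the same pattern `scan_step` uses for the simulation parameters.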

In [None]:
@jax.jit
def run_step(v_prev, v_thresh, v_reset, W, tau_m, dt, membr_R):
    spiked = v_prev >= v_thresh
    V = jnp.where(spiked, v_reset, v_prev)

    I_syn = W @ spiked.astype(jnp.float32)  # Synaptic current from spikes
    dV = (dt / tau_m) * (-V + v_reset + membr_R * I_syn)
    V = V + dV

    V = jnp.where(spiked, v_reset, V)
    return V, spiked

@jax.jit
def scan_step(carry, _):

    (V, v_thresh, v_reset, W, tau_m, dt, membr_R) = carry
    new_V, spike = run_step(V, v_thresh, v_reset, W, tau_m, dt, membr_R)
    return (new_V, v_thresh, v_reset, W, tau_m, dt, membr_R), spike

def run_simulation(W, V, tau_m, v_reset, v_thresh, membr_R, t_max, dt):
    """
    TODO: 
        Implement the scan step!
        
    """
    num_steps = int(t_max / dt)
    # Run the scan over the number of time steps
    final_V, accum_spikes = ...
    return accum_spikes

In [None]:
"""
No need to change the code below
"""

time_arr = []
for i in range(num_simulations):
    start = time.time()
    spike_train = run_simulation(
        jnp.asarray(W),
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    # JAX dispatches work asynchronously, so materialize the result before stopping the timer
    np.asarray(spike_train)
    end = time.time()
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

# Further Exercises

## 1) Convert the example Lotka Volterra equation above to use a `while_loop`

## 2) Work through [extras_namedtuple.ipynb](./extras_namedtuple.ipynb) 

This notebook teaches some best practices for keeping your code clean with Python's `namedtuple`.
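
As a small preview of the idea (not taken from that notebook), a `NamedTuple` can serve as the loop carry, so fields are accessed by name rather than by position. In the sketch below, `FibState` and `fib_step` are hypothetical names, and a Fibonacci recurrence stands in for a real simulation state.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class FibState(NamedTuple):
    # Hypothetical carry: two consecutive Fibonacci numbers.
    a: jnp.ndarray
    b: jnp.ndarray

def fib_step(state: FibState, _):
    # Build the next state by name instead of unpacking a positional tuple.
    new_state = FibState(a=state.b, b=state.a + state.b)
    return new_state, new_state.a

init = FibState(a=jnp.int32(0), b=jnp.int32(1))
final_state, sequence = jax.lax.scan(fib_step, init, xs=None, length=5)

print(final_state)
print(sequence)
```

Because JAX treats named tuples as pytrees, the carry that comes back from `scan` is again a `FibState`, so `final_state.a` and `final_state.b` remain available by name.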