# Jax jit

## Lesson Goals:

By the end of this lesson, you will know how to use `jax.jit`, how to accurately time `jit`-ted functions, and how to identify which parts of your code to `jit`. Along the way, we will briefly discuss functional programming and why it is useful for speeding up computations.

In [1]:
from typing import TypeAlias
import time
import jax.numpy as jnp
import numpy as np
import jax
from tqdm.notebook import tqdm

np.random.seed(42)

# Functional Programming?

Functional programming is many things, but for the purposes of this tutorial, it is a style of programming without side effects. Python is not a functional programming language, but you may have heard of others such as `haskell`, `ocaml`, or `erlang`.

The most common kind of side effect is modifying some internal state. Consider the following:

In [2]:
import copy
class ShoppingCart:
    def __init__(self):
        self.items = []

    def add_item(self, item):
        self.items.append(item)  # Side effect: modifying internal state

    def __repr__(self):
        return f"ShoppingCart({self.items})"

cart = ShoppingCart()
cart.add_item("banana")
print(cart)
cart.add_item("apple")
print(cart)
print("*" * 10)

class FunctionalShoppingCart:
    def __init__(self):
        self.items = []

    def add_item(self, item):
        new_cart = FunctionalShoppingCart()
        all_items = copy.deepcopy(self.items)
        all_items.append(item)
        new_cart.items = all_items
        return new_cart

    def __repr__(self):
        return f"FunctionalShoppingCart({self.items})"

func_cart = FunctionalShoppingCart()
func_cart.add_item("banana")  # <- The banana was not added!
print(func_cart)
func_cart = func_cart.add_item("apple")
print(func_cart)

ShoppingCart(['banana'])
ShoppingCart(['banana', 'apple'])
**********
FunctionalShoppingCart([])
FunctionalShoppingCart(['apple'])


## Functional Programming:

Okay, but how is this relevant? Well, functional programming allows for:

- predictable behavior: a function's output depends only on its inputs, so compilers can more easily optimize your code
- immutability: the data cannot be modified, so each thread or process can work from the original data asynchronously without worrying that another one has changed it
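
Jax arrays follow the same immutability rule as the `FunctionalShoppingCart` above. The following is a minimal sketch (the `prices` array is made up for illustration): in-place assignment is rejected, and the `.at[...].set(...)` syntax returns a new array instead of modifying the old one.

```python
import jax.numpy as jnp

prices = jnp.array([1.0, 2.0, 3.0])  # illustrative data, not from the lesson

# In-place assignment is a side effect, so jax arrays refuse it.
try:
    prices[0] = 10.0
    mutated = True
except TypeError:
    mutated = False

# The functional alternative: build a *new* array with the change applied.
updated = prices.at[0].set(10.0)

print(mutated)   # False
print(prices)    # the original array is untouched
print(updated)   # the new array carries the update
```

Just like `func_cart.add_item("banana")`, the result of `.at[...].set(...)` is lost unless you bind it to a name.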

# Jax's JIT: Supercharged functions

`jit` stands for just-in-time compilation. `Python` is famously slow because, among other things, the code is interpreted, i.e. at run-time the interpreter has to decide what to do. Languages like `C++` and `Rust` are compiled, so at run-time the code is just... run.

So, by compiling our Jax code via the `jit`, we can accelerate our programs. Assuming the numerical computation is the bottleneck, as is often the case in ML tasks, this means that we speed up the slowest part of our program.

## Where does functional programming come in? 

FP makes it easier for the `jit` compiler to speed up the code. It can do things like:

- function inlining: the function call is replaced by the body of the function itself

- loop fusion/elimination/unrolling: with no hidden dependencies between calls, jax can merge, remove, or unroll loops

- caching: because the output depends only on the inputs, jax can cache the compiled function and reuse it whenever it sees inputs with the same shapes and dtypes again
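
One way to see the caching, and why side effects do not mix with `jit`, is to record every time Python actually runs the function body. This is a rough sketch: `trace_log` and `scaled_sum` are hypothetical names introduced only for this illustration. What jax reuses is the compiled function, keyed on the input's shape and dtype, not a stored result.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: filled only when Python runs the body

@jax.jit
def scaled_sum(x):
    # Side effect: this line runs while jax traces the function,
    # not on every call of the compiled version.
    trace_log.append(x.shape)
    return jnp.sum(x * 2.0)

scaled_sum(jnp.ones(3))   # first call: traced and compiled
scaled_sum(jnp.zeros(3))  # same shape and dtype: compiled version is reused
scaled_sum(jnp.ones(4))   # new shape: traced and compiled again

print(trace_log)  # [(3,), (4,)]
```

Three calls, but only two entries in the log: the side effect silently disappears on the cached call, which is why jitted functions should be pure.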

In [3]:
input_arr = np.random.rand(1000, 1000)

def func_np(m):
    mask = m > 0.5
    m = np.where(mask, m**2, np.sqrt(m))
    return m @ m

print("Numpy version")
%timeit func_np(input_arr)

input_arr_j = jnp.asarray(input_arr)

def func_jax(m):
    mask = m > 0.5
    mod_m = jnp.where(mask, m**2, jnp.sqrt(m))
    return mod_m @ mod_m

print("Jax Non-Jit version")
%timeit func_jax(input_arr_j)

jitted_func = jax.jit(func_jax)

print("Jax Jitted version")
%timeit jitted_func(input_arr_j).block_until_ready()

Numpy version
25.5 ms ± 2.89 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Jax Non-Jit version
5.26 ms ± 50.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Jax Jitted version
4.67 ms ± 311 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Quick Aside: Benchmarking in Jax

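The astute may have noticed the `.block_until_ready()` call. What gives? Jax dispatches work asynchronously: a call returns a future-like array right away instead of blocking the main Python thread, so a timer can stop before the computation has actually finished. That is why the jitted timing above uses `.block_until_ready()`. Note that the non-jitted `%timeit` above does not block, so its number may understate the real cost. To ensure that you get accurate timings when benchmarking, you can:

- use `.block_until_ready()`
- convert the `jnp.array` into an `np.array`, which waits for the future
- print the `jnp.array`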

For more information check out: [Jax Async Dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)
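
Outside of a notebook, where `%timeit` is not available, the same idea can be sketched with `time.perf_counter`. The function `matmul_square` and the array size are assumptions chosen for illustration. The warm-up call matters because the first call of a jitted function also pays for tracing and compilation.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def matmul_square(m):  # hypothetical example function
    return (m * m) @ m

x = jnp.ones((256, 256))

# Warm-up: the first call includes tracing and compilation time.
matmul_square(x).block_until_ready()

start = time.perf_counter()
out = matmul_square(x).block_until_ready()  # wait for the result before stopping the clock
elapsed = time.perf_counter() - start

print(f"One call took {elapsed:.6f} seconds")
```

Dropping `.block_until_ready()` from the timed line would measure only how long it takes to dispatch the work, not how long the work takes.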

In [4]:


from hyperparameters import (
    _dt,
    _t_max,
    _tau_m,
    _V_reset,
    _V_thresh,
    _R,
    num_simulations
)


with open('weights.npy', 'rb') as f:
    W = np.load(f)

# Initial conditions
n_neurons = len(W)  # Number of neurons in the network
_V = jnp.ones(n_neurons) * _V_reset  # Initial potentials

# Type Definitions for Clarity

In [5]:
Tensor3D: TypeAlias = jnp.ndarray
Mat: TypeAlias = jnp.ndarray
Vec: TypeAlias = jnp.ndarray 

# Run the Simulations

In [8]:
@jax.jit
def run_step(
    v_prev,
    v_thresh,
    v_reset,

    W,
    tau_m,
    dt,
    membr_R):
    
    spiked = v_prev >= v_thresh
    V = jnp.where(spiked, v_reset, v_prev)
    I_syn = W.dot(spiked)

    dV = (dt / tau_m) * (-V + v_reset + membr_R * I_syn)
    V = V + dV
    V = jnp.where(spiked, v_reset, V)
    return V, spiked

def run_simulation(
    W: Mat,
    V: Vec,

    # Neuron Parameters
    tau_m: float,
    v_reset: float,
    v_thresh: float,
    membr_R: float,

    # How long do we run for? 
    t_max: float,
    dt: float, 

):
    """
    TODO while keeping the function signature the same, abstract out 
        the contents of the for-loop into a jit-ted function that you 
        can call
    """
    # Simulation

    spike_train = []
    for i, t in enumerate(jnp.arange(0, t_max, dt)):
        if i == 0:
            continue
        
        V, spiked = run_step(V, v_thresh, v_reset, W, tau_m, dt, membr_R)
        spike_train.append(spiked)
    return spike_train

time_arr = []
for i in range(num_simulations):
    start = time.time()
    spike_train = run_simulation(
        W,
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        t_max=_t_max, dt=_dt
    )
    np.asarray(spike_train)
    end = time.time()
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)
    if i == 1:
        print("Breaking out - point proven")
        break

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

Iteration 0 took: 18.002535820007324 seconds


KeyboardInterrupt: 

# What gives? 

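`jax.jit` doesn't always play nicely with numpy! Here `W` was loaded with `np.load`, so it is still a `np.ndarray`, and every call to the jit-ted `run_step` has to convert it into a jax array before it can run. That cost is paid again on every time step. There are times where calling `jnp.asarray` yourself is necessary: converting `W` once, before the loop, removes the repeated conversion.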

In [9]:
time_arr = []
for i in range(num_simulations):
    start = time.time()
    spike_train = run_simulation(
        jnp.asarray(W),
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    np.asarray(spike_train)
    end = time.time()
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

Iteration 0 took: 5.72046685218811 seconds
Iteration 1 took: 5.695834159851074 seconds
Iteration 2 took: 5.726809740066528 seconds
Average Time: 5.714370250701904
S.Dev Time: 0.01336034068134031
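
The pattern behind this fix, reduced to a toy example: convert the NumPy array once, then hand the resulting jax array to the jitted function on every call. This is only a sketch; `apply_weights` and the 4x4 random matrix are stand-ins for `run_step` and the real `W`, not the lesson's actual data.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for the weights loaded from disk (illustrative values only).
W_np = np.random.default_rng(0).random((4, 4)).astype(np.float32)

# Convert once, up front, instead of implicitly on every jitted call.
W_jax = jnp.asarray(W_np)

@jax.jit
def apply_weights(W, v):  # hypothetical stand-in for run_step
    return W.dot(v)

v = jnp.ones(4)
out_np = apply_weights(W_np, v)    # works, but W_np is converted on each call
out_jax = apply_weights(W_jax, v)  # W_jax is already a jax array

print(isinstance(W_np, jax.Array), isinstance(W_jax, jax.Array))  # False True
```

Both calls give the same numbers. The difference is only in how much conversion work happens per call, and that adds up when the call sits inside a long Python loop.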


# Further Exercises:

1) Read through [extras_when_not_to_jit.ipynb](./extras_when_not_to_jit.ipynb)

2) Read through [Jax AoT Compiling](https://jax.readthedocs.io/en/latest/aot.html) and take note of the limitations that come with Jax's AoT compiling