# Jax jit

## Lesson Goals:

By the end of this lesson, you will know how to use `jit`, how to accurately time computations that use `jit`-ted functions, and how to identify where to apply `jit`. Along the way, we will briefly discuss functional programming and why it is useful for speeding up computations.

## Core Concepts



## Concepts In action:


- Easy: [lotka-volterra](../case_studies/lotka-volterra/README.md)

- Intermediate: [leaky_integrate_and_fire](../case_studies/leaky_integrate_and_fire/README.md)

- Advanced: [gaussian_mixture_model](../case_studies/gaussian_mixture_model/README.md)

In [None]:
import jax.numpy as jnp
import numpy as np
import jax

np.random.seed(42)

# Functional Programming?

Functional programming means many things, but for the purposes of this tutorial, it is a style of programming without side effects. Python is not a functional programming language, but you may have heard of ones that are, such as Haskell, OCaml, or Erlang.

The most common side effect is modifying some internal state. Consider the following:

In [None]:
import copy
class ShoppingCart:
    def __init__(self):
        self.items = []

    def add_item(self, item):
        self.items.append(item)  # Side effect: modifying internal state

    def __repr__(self):
        return f"ShoppingCart({self.items})"

cart = ShoppingCart()
cart.add_item("banana")
print(cart)
cart.add_item("apple")
print(cart)
print("*" * 10)

class FunctionalShoppingCart:
    def __init__(self):
        self.items = []

    def add_item(self, item):
        new_cart = FunctionalShoppingCart()
        all_items = copy.deepcopy(self.items)
        all_items.append(item)
        new_cart.items = all_items
        return new_cart

    def __repr__(self):
        return f"FunctionalShoppingCart({self.items})"

func_cart = FunctionalShoppingCart()
func_cart.add_item("banana")  # <- Returns a new cart that we discard, so the banana never makes it into func_cart!
print(func_cart)
func_cart = func_cart.add_item("apple")
print(func_cart)

## Functional Programming:

Okay, but how is this relevant? Well, functional programming allows for:

- predictable behavior: because a function's output depends only on its inputs, compilers can more easily optimize your code
- immutability: because data cannot be modified, multiple threads or processes can safely read the same data in parallel, without locks or defensive copies
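JAX takes immutability seriously: its arrays cannot be modified in place. The sketch below (the array `x` is just an illustrative placeholder) shows that item assignment fails, and that the functional alternative, `x.at[idx].set(value)`, returns a new array instead, much like `FunctionalShoppingCart.add_item` returns a new cart.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# JAX arrays are immutable: in-place item assignment raises a TypeError.
try:
    x[0] = 10.0
except TypeError as err:
    print("In-place update failed:", type(err).__name__)

# The functional update returns a *new* array and leaves x untouched.
y = x.at[0].set(10.0)
print(x)  # still all zeros
print(y)  # first entry is now 10
```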

# Jax's JIT: Supercharged functions

`jit` stands for just-in-time compilation: your code is compiled right before it is first run. `Python` is famously slow because, among other things, it is interpreted, i.e. at run-time the interpreter has to decide what to do with each line. Languages like `C++` and `Rust` are compiled ahead of time, so at run-time, the code is just... run.

So, by compiling our Jax code with `jit` (which traces the function and hands it to the XLA compiler), we can accelerate our programs. Assuming the numerical computation is the bottleneck, as is often the case in ML tasks, this speeds up the slowest part of our program.
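To peek at what `jit` actually compiles, `jax.make_jaxpr` prints the intermediate representation (a "jaxpr") that JAX records when it traces a function with abstract inputs of a given shape and dtype. A minimal sketch, using a made-up function `scale_and_shift`:

```python
import jax
import jax.numpy as jnp

def scale_and_shift(x):
    # hypothetical example function: two elementwise operations
    return 2.0 * x + 1.0

x = jnp.arange(4.0)

# Trace with abstract values of x's shape/dtype (float32[4]) and print the
# resulting jaxpr: a flat list of primitive operations such as mul and add.
jaxpr = jax.make_jaxpr(scale_and_shift)(x)
print(jaxpr)

# jit performs the same tracing step, then compiles the traced program with XLA.
fast_scale_and_shift = jax.jit(scale_and_shift)
print(fast_scale_and_shift(x))
```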

## Where does functional programming come in? 

FP makes it easier for the `jit` compiler to speed up the code. It can do things like:

- function inlining: the function call is replaced by the function itself

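- loop fusion/elimination/unrolling: because pure operations have no hidden dependencies on outside state, jax can fuse several operations into a single pass over the data, remove work whose result is never used, and unroll Python loops into one straight-line program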

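- caching: jax caches the compiled function for each combination of input shapes and dtypes, so later calls with matching inputs skip tracing and compilation entirely (note that it caches the compiled code, not results for particular input values)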
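Compilation and caching can be seen by timing two calls of the same jitted function. This is a rough sketch and the exact timings will vary by machine: the first call with a new input shape and dtype pays for tracing and compilation, while the second reuses the cached executable. The polynomial `poly` is a made-up function whose elementwise operations are candidates for XLA's fusion.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def poly(x):
    # several elementwise ops that XLA is free to fuse into one kernel
    return 3.0 * x**3 + 2.0 * x**2 + x + 1.0

x = jnp.linspace(0.0, 1.0, 1_000_000)

start = time.perf_counter()
poly(x).block_until_ready()  # first call: trace + compile + run
first_call = time.perf_counter() - start

start = time.perf_counter()
poly(x).block_until_ready()  # same shape and dtype: cached executable is reused
second_call = time.perf_counter() - start

print(f"first call:  {first_call * 1e3:.2f} ms")
print(f"second call: {second_call * 1e3:.2f} ms")
```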

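How do we use `jit`? We can either apply it as a function decorator or call it on an existing function. The decorator is more concise, while the function-call form keeps the original, un-jitted Python function around, which is handy for debugging or benchmarking.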

In [None]:
def doubler(x):
    return x * 2

@jax.jit
def jitted_doubler(x):
    return x * 2

alternative_jitted_doubler = jax.jit(doubler)

# Jit in action

Let's try a simple-ish task: generate a matrix $M \in [0, 1)^{1000 \times 1000}$, square every entry greater than 0.5, and take the square root of every other entry. Finally, multiply the resulting matrix by itself.
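Written out entrywise, the task computes the following (entries exactly equal to 0.5 fall into the square-root branch, matching the `m > 0.5` mask in the code):

$$
A_{ij} =
\begin{cases}
M_{ij}^2 & \text{if } M_{ij} > 0.5 \\
\sqrt{M_{ij}} & \text{otherwise}
\end{cases}
\qquad
\text{result} = A A
$$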

In [None]:
input_arr = np.random.rand(1000, 1000)

def func_np(m):
    mask = m > 0.5
    m = np.where(mask, m**2, np.sqrt(m))
    return m @ m

print("Numpy version")
%timeit func_np(input_arr)

input_arr_j = jnp.asarray(input_arr)

def func_jax(m):
    # TODO: implement the jax equivalent of the above
    pass

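print("Jax Non-Jit version")
%timeit func_jax(input_arr_j).block_until_ready()  # jax dispatches asynchronously, even without jit

jitted_func = ...  # TODO: Jit the function you just implemented

jitted_func(input_arr_j).block_until_ready()  # warm-up: compile once, outside the timing
print("Jax Jitted version")
%timeit jitted_func(input_arr_j).block_until_ready()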

# Jit in action Part 2

The previous task was pretty trivial for a modern computer. Let's increase the size to $M \in [0, 1)^{4000 \times 4000}$, which has 16 times as many entries.
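A quick back-of-the-envelope count shows why the larger problem is a better stress test. The elementwise part scales with the number of entries, $n^2$, while a dense $n \times n$ matrix product takes roughly $2n^3$ floating-point operations:

$$
\frac{4000^2}{1000^2} = 16,
\qquad
\frac{2 \cdot 4000^3}{2 \cdot 1000^3} = 64
$$

So the matrix multiplication grows about four times faster than the entry count.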

In [None]:
input_arr = np.random.rand(4_000, 4_000)
print("Numpy version")
%timeit func_np(input_arr)

input_arr_j = jnp.asarray(input_arr)

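print("Jax Non-Jit version")
%timeit func_jax(input_arr_j).block_until_ready()  # jax dispatches asynchronously, even without jit

jitted_func = jax.jit(func_jax)
jitted_func(input_arr_j).block_until_ready()  # warm-up: compile for the 4000 x 4000 input outside the timing

print("Jax Jitted version")
%timeit jitted_func(input_arr_j).block_until_ready()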

## Quick Aside: Benchmarking in Jax

The astute reader will have noticed the `.block_until_ready()` call. What gives? Well, JAX uses asynchronous dispatch: a call returns an array right away, before the computation has finished, so the main Python thread is not blocked. Without `.block_until_ready()`, we would mostly be timing how long it takes to dispatch the work, not how long it takes to do it. To get accurate timings when benchmarking, you can:

- call `.block_until_ready()` on the result
- convert the `jnp.array` into an `np.array` (e.g. with `np.asarray`), which waits for the computation to finish
- print the `jnp.array`, which also has to wait for its values

For more information, check out: [Jax Async Dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)
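A small sketch of the pitfall, using a hypothetical `matmul` helper: timing without blocking measures little more than dispatch, while blocking measures the actual computation. The size of the gap depends on the backend and may be small on some machines.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def matmul(a):
    return a @ a

a = jnp.ones((2_000, 2_000))
matmul(a).block_until_ready()  # warm-up so compilation is not timed

# Without blocking: mostly measures dispatching the work.
start = time.perf_counter()
out = matmul(a)
dispatch_only = time.perf_counter() - start
out.block_until_ready()  # let this run finish before the next measurement

# With blocking: waits until the result has actually been computed.
start = time.perf_counter()
out = matmul(a).block_until_ready()
full = time.perf_counter() - start

print(f"without block_until_ready: {dispatch_only * 1e3:.2f} ms")
print(f"with block_until_ready:    {full * 1e3:.2f} ms")
```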

# Exercise: Using side-effects

To drive home the difference between mutability and immutability, we're going to show why it is important for functions to be pure when you use `jit`. Define your own custom function that relies on some external state (for example, a global variable), as well as its `jit` version. Then update that state and watch the two versions diverge.
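As a warm-up that uses a different side effect than the exercise, the sketch below shows that Python side effects inside a jitted function run only while JAX traces it, not on every call. The counter `trace_count` and the function `noisy_double` are hypothetical, chosen only for this illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical external state, used only for this illustration

@jax.jit
def noisy_double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    print("tracing noisy_double with", x)  # prints an abstract tracer, not values
    return x * 2

x = jnp.ones(3)
for _ in range(3):
    noisy_double(x)
print("traces after 3 calls with the same shape:", trace_count)  # 1

noisy_double(jnp.ones(4))  # a new input shape forces a retrace
print("traces after a call with a new shape:", trace_count)  # 2
```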

In [None]:
small_rand_jnp = jnp.asarray(np.random.rand(3, 3))


In [None]:

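def impure_non_jit(x):
    # TODO: return something computed from x and some external (global) state
    raise NotImplementedError

def impure_jitted(x):
    # TODO: the same computation as impure_non_jit, compiled with jax.jit
    raise NotImplementedError


for i in range(5):
    print(impure_non_jit(small_rand_jnp))
    print(impure_jitted(small_rand_jnp))
    print("*" * 10)

    # TODO: update the external state here and see how the two versions diverge!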

# Further Exercises:

0) Read through [case_studies/leaky_integrate_and_fire/jax_leaky_integrate_and_fire_2_jit.ipynb](../case_studies/leaky_integrate_and_fire/jax_leaky_integrate_and_fire_2_jit.ipynb)

1) Read through [extras_when_not_to_jit.ipynb](./extras_when_not_to_jit.ipynb)

2) Read through [Jax AoT Compiling](https://jax.readthedocs.io/en/latest/aot.html) and take note of the limitations that come with Jax's AoT compiling
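As a starting point for exercise 2, here is a minimal sketch of the ahead-of-time flow: `jax.jit(f).lower(...)` traces and lowers the function for specific argument shapes and dtypes, and `.compile()` produces an executable without running it. The `doubler` here mirrors the one defined earlier in this lesson.

```python
import jax
import jax.numpy as jnp

def doubler(x):
    return x * 2

x = jnp.arange(3.0)

# Ahead-of-time: lower for x's shape/dtype, then compile, without executing.
lowered = jax.jit(doubler).lower(x)
compiled = lowered.compile()

# Run the precompiled executable. It only accepts arguments matching the
# shape/dtype it was compiled for; see the AoT docs for other limitations.
print(compiled(x))
```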