# Intro to Jax, a.k.a. swapping `np.X` for `jnp.X`

## Lesson Goals:

By the end of this lesson, you will understand how to migrate from `numpy` to `jax` and have a feel for how similar the two libraries can be.

```python
import numpy as np
from typing import TypeAlias
import time
import jax.numpy as jnp
from tqdm.notebook import tqdm

np.random.seed(42)
```

# What is Jax?

To put it simply, Jax is NumPy for hardware accelerators (CPUs, GPUs, and TPUs). It offers much more than that, though: it compiles code through a different backend (XLA), supports automatic differentiation, and provides composable function transformations such as `jax.jit` and `jax.grad`.

From the website:

> JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
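A minimal sketch of the automatic differentiation and XLA compilation mentioned above, applied to a function written exactly as it would be with NumPy. `sum_of_squares` is a hypothetical example function, not part of the lesson.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # Plain NumPy-style code, just with jnp in place of np.
    return jnp.sum(x ** 2)


grad_fn = jax.grad(sum_of_squares)  # derivative of sum(x^2) is 2x
fast_fn = jax.jit(sum_of_squares)   # traced and compiled by XLA on the first call

x = jnp.arange(4.0)
print(grad_fn(x))  # [0. 2. 4. 6.]
print(fast_fn(x))  # 14.0
```

Both transformations take a function and return a new function, which is why they compose (for example, `jax.jit(jax.grad(sum_of_squares))`).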

Despite these capabilities, not all concepts and idioms from NumPy translate directly, and there are certain ‼️sharp edges‼️ of which you should be aware.
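Two sharp edges that often trip up NumPy users are immutable arrays and explicit random keys. A minimal sketch of both (the variable names are illustrative only):

```python
import jax
import jax.numpy as jnp

# Sharp edge 1: JAX arrays are immutable, so NumPy-style item assignment fails.
x = jnp.zeros(3)
try:
    x[0] = 1.0
    assignment_failed = False
except TypeError:
    assignment_failed = True
print("in-place assignment raised TypeError:", assignment_failed)

# The functional replacement returns a *new* array and leaves x untouched.
y = x.at[0].set(1.0)
print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]

# Sharp edge 2: there is no hidden global RNG state (no np.random.seed equivalent);
# randomness comes from explicit keys that you split and pass around.
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
sample = jax.random.uniform(subkey, (3,))
print(sample)  # the same subkey always yields the same numbers
```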

## Sample Exercises

Below, we provide some exercises to help you become familiar with Jax and NumPy. The solutions are more or less what you would expect from a drop-in replacement.
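The "drop-in" behaviour rests on two facts worth seeing in isolation: `jnp` functions accept NumPy arrays as input, and they return a `jax.Array` that is float32 by default. A minimal sketch, using random inputs shaped like those in the exercises:

```python
import numpy as np
import jax
import jax.numpy as jnp

v = np.random.rand(10)       # NumPy float64 inputs
M = np.random.rand(10, 5)

out = jnp.dot(v, M)          # jnp functions accept NumPy arrays directly...
print(isinstance(out, jax.Array))  # True: ...but return a jax.Array
print(out.dtype)             # float32 unless 64-bit mode (jax_enable_x64) is turned on

back = np.asarray(out)       # convert back to a NumPy array when needed
# Exact equality would be too strict after the float32 round trip, hence allclose.
print(np.allclose(back, np.dot(v, M), rtol=1e-5))
```

The float32 default is why comparisons against NumPy results use `allclose` rather than `==`.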

```python
def dot_product():
    v = np.random.rand(10)
    M = np.random.rand(10, 5)

    expected_result = np.dot(v, M)
    actual_result = jnp.dot(v, M)

    assert jnp.allclose(expected_result, actual_result)
    print("Dot product passed")


def is_even_filter():
    to_filter_np = np.asarray([1, 2, 3, 5, 10, 20])
    expected_result = to_filter_np[to_filter_np % 2 == 0]

    to_filter_jnp = jnp.asarray(to_filter_np)
    actual_result = to_filter_jnp[to_filter_jnp % 2 == 0]

    assert jnp.allclose(expected_result, actual_result)
    print("is_even_filter passed")


def top_n_of_norm_squared():
    M = np.random.rand(10, 5)
    TOP_N = 5

    expected_result = np.sort(np.linalg.norm(M @ M.T, axis=1))[::-1][:TOP_N]

    jnp_M = jnp.asarray(M)
    actual_result = jnp.sort(jnp.linalg.norm(jnp_M @ jnp_M.T, axis=1))[::-1][:TOP_N]

    assert jnp.allclose(expected_result, actual_result)
    print("top_n_of_norm_squared passed")


def hadamard():
    M = np.random.rand(10, 5)
    expected_result = M * M

    jnp_M = jnp.asarray(M)
    actual_result = jnp_M * jnp_M  # element-wise (Hadamard) product

    assert jnp.allclose(expected_result, actual_result)
    print("hadamard passed")


dot_product()
is_even_filter()
top_n_of_norm_squared()
hadamard()
```

```text
Dot product passed
is_even_filter passed
top_n_of_norm_squared passed
hadamard passed
```
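The `is_even_filter` exercise passes because it runs eagerly. Boolean-mask indexing produces an output whose length depends on the data, so the same expression fails once the function is compiled with `jax.jit`. A minimal sketch of the failure and two fixed-shape alternatives (the function names and the `-1` sentinel are hypothetical choices for illustration):

```python
import jax
import jax.numpy as jnp

x = jnp.asarray([1, 2, 3, 5, 10, 20])


@jax.jit
def even_filter(x):
    return x[x % 2 == 0]  # output length depends on the values in x


failed_under_jit = False
try:
    even_filter(x)
except jax.errors.NonConcreteBooleanIndexError:
    failed_under_jit = True
print("boolean mask under jit failed:", failed_under_jit)


@jax.jit
def even_or_sentinel(x):
    # Keep the shape fixed: even entries pass through, odd entries become -1.
    return jnp.where(x % 2 == 0, x, -1)


@jax.jit
def first_three_evens(x):
    # Or gather a fixed number of matches; `size` must be a static Python int.
    (idx,) = jnp.nonzero(x % 2 == 0, size=3, fill_value=0)
    return x[idx]


print(even_or_sentinel(x))   # [-1  2 -1 -1 10 20]
print(first_three_evens(x))  # [ 2 10 20]
```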


# Returning to the LIF Example

<img src="../assets/lif_formulation.png" alt="drawing" width="400"/>

We now return to the LIF model we discussed earlier. Simply put, we want to replace the `np.X` calls with `jnp.X`.
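For reference, this is the discrete update that `run_simulation` applies at each time step, transcribed from the code further down. It is a forward-Euler step in which the leak pulls $V$ toward $V_{\text{reset}}$; the figure above may present the continuous-time form instead. Here $s$ is the vector of spikes, $W$ the weight matrix, $R$ the membrane resistance (`membr_R`), and $\Delta t$ the time step (`dt`).

$$
\begin{aligned}
s_i &= \mathbb{1}\left[V_i \ge V_{\text{thresh}}\right] \\
I_{\text{syn}} &= W s \\
V_i &\leftarrow
\begin{cases}
V_{\text{reset}} & \text{if } s_i = 1 \\
V_i + \dfrac{\Delta t}{\tau_m}\left(-V_i + V_{\text{reset}} + R\, I_{\text{syn},i}\right) & \text{if } s_i = 0
\end{cases}
\end{aligned}
$$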

```python
from hyperparameters import (
    _dt,
    _t_max,
    _tau_m,
    _V_reset,
    _V_thresh,
    _R,
    num_simulations
)

with open('weights.npy', 'rb') as f:
    W = np.load(f)

# Initial conditions
n_neurons = len(W)  # Number of neurons in the network
_V = jnp.ones(n_neurons) * _V_reset  # Initial potentials
```

# Type Definitions for Clarity

```python
Tensor3D: TypeAlias = jnp.ndarray
Mat: TypeAlias = jnp.ndarray
Vec: TypeAlias = jnp.ndarray
```
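In recent JAX releases `jnp.ndarray` is an alias for `jax.Array`, which is the type JAX's documentation suggests for annotations. A minimal sketch that annotates with `jax.Array` directly; `synaptic_current` is a hypothetical helper used only to show the annotations in use, not part of the lesson:

```python
import jax
import jax.numpy as jnp

# jax.Array is the array type JAX's documentation recommends for annotations.
Mat = jax.Array
Vec = jax.Array


def synaptic_current(W: Mat, fired: Vec) -> Vec:
    # Hypothetical helper: weighted sum of incoming spikes.
    return W @ fired.astype(W.dtype)


I_syn = synaptic_current(jnp.eye(3), jnp.array([True, False, True]))
print(isinstance(I_syn, jax.Array))  # True
print(I_syn)  # [1. 0. 1.]
```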

# Run the Simulations

```python
import jax  # for jax.block_until_ready


def run_simulation(
    W: Mat,
    V: Vec,

    # Neuron Parameters
    tau_m: float,
    v_reset: float,
    v_thresh: float,
    membr_R: float,

    # How long do we run for?
    t_max: float,
    dt: float,
):
    # np.load returns a NumPy array; convert once so W.dot runs in JAX
    W = jnp.asarray(W)

    # Simulation
    spike_train = []
    # Iterating over a jnp array yields one device scalar per step
    for i, t in enumerate(jnp.arange(0, t_max, dt)):
        if i == 0:
            continue

        fired = V >= v_thresh
        V = jnp.where(fired, v_reset, V)

        # Record spike times
        spike_train.append(fired)

        # Update voltages
        I_syn = W.dot(fired)  # Synaptic current from spikes
        dV = (dt / tau_m) * (-V + v_reset + membr_R * I_syn)
        V += dV  # JAX arrays are immutable: this rebinds V to a new array

        # Neurons that fired this step are held at the reset potential
        V = jnp.where(fired, v_reset, V)
    return spike_train


time_arr = []
for i in range(num_simulations):
    start = time.perf_counter()
    spike_train = run_simulation(
        W,
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    # JAX dispatches work asynchronously; wait for it to finish before stopping the clock
    jax.block_until_ready(spike_train)
    end = time.perf_counter()
    # print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")
```

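The loop above can be slow because every step is dispatched from Python one operation at a time. A common next step, which goes beyond the lesson's drop-in swap and is not the author's code, is to express the time loop with `jax.lax.scan` and compile it with `jax.jit`. A minimal sketch, assuming the same per-step update as `run_simulation`; `run_simulation_scan`, `num_steps`, and the toy inputs are hypothetical:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def num_steps(t_max, dt):
    # Mirrors the original loop: one step per point of arange(0, t_max, dt),
    # skipping the first. (Assumption: np.arange gives the same length as the
    # float32 jnp.arange used in the loop, which can differ at rounding edges.)
    return len(np.arange(0, t_max, dt)) - 1


@partial(jax.jit, static_argnames=("n_steps",))
def run_simulation_scan(W, V, tau_m, v_reset, v_thresh, membr_R, dt, n_steps):
    # Same per-step update as run_simulation, but the loop is traced once and
    # compiled by XLA instead of being dispatched op by op from Python.

    def step(V, _):
        fired = V >= v_thresh
        V = jnp.where(fired, v_reset, V)
        I_syn = W @ fired.astype(W.dtype)  # synaptic current from spikes
        dV = (dt / tau_m) * (-V + v_reset + membr_R * I_syn)
        V = jnp.where(fired, v_reset, V + dV)  # fired neurons stay at reset
        return V, fired  # (new carry, per-step output)

    # scan carries V through n_steps iterations and stacks each step's `fired`.
    _, spikes = jax.lax.scan(step, V, xs=None, length=n_steps)
    return spikes  # bool array of shape (n_steps, n_neurons)


# Toy inputs with hypothetical values, not the lesson's weights or hyperparameters:
# neuron 0 starts above threshold, and W routes its spike into neuron 1.
W_toy = jnp.array([[0.0, 0.0, 0.0],
                   [4.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0]])
V0 = jnp.array([1.5, 0.0, 0.0])
spikes = run_simulation_scan(
    W_toy, V0,
    1.0, 0.0, 1.0, 1.0,  # tau_m, v_reset, v_thresh, membr_R
    0.5,                 # dt
    n_steps=num_steps(2.5, 0.5),
)
jax.block_until_ready(spikes)
print(spikes.astype(int))  # neuron 0 fires at step 1, neuron 1 at step 2
```

Applied to the lesson, the call would look like `run_simulation_scan(W, _V, _tau_m, _V_reset, _V_thresh, _R, _dt, n_steps=num_steps(_t_max, _dt))`. The first call includes compilation time, so for a fair comparison with the loop above, time later calls and wrap them in `jax.block_until_ready`. Returning one stacked array instead of a Python list of per-step arrays is what lets the whole loop live inside a single compiled function.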