# The Leaky Integrate-and-Fire

<img src="./lif_formulation.png" alt="drawing" width="400"/>

## Notebook Description

Here we've restructured the code from the previous notebook, [jax_leaky_integrate_and_fire_1.ipynb](./jax_leaky_integrate_and_fire_1.ipynb), to use JAX's `jit`. Note that we
extracted the "update" step into its own function and moved the spike storage outside of it. This is because `jit` expects a pure function: it should not mutate external state (such as appending to a list), and JAX arrays themselves are immutable.

See [exercises/exe_02_jit.ipynb](../../exercises/exe_02_jit.ipynb) for more information.
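
As a minimal sketch of that restructuring (not the notebook's actual model), the toy example below keeps the jitted function pure and does the bookkeeping outside it. The names `toy_step` and `history`, and the toy dynamics, are assumptions made purely for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_step(v, threshold):
    # Pure function: it returns new values and mutates nothing outside itself.
    fired = v >= threshold
    v_next = jnp.where(fired, 0.0, v + 0.5)  # reset fired units, otherwise charge
    return v_next, fired


v = jnp.zeros(3)
history = []  # the side effect (list append) lives outside the jitted function
for _ in range(4):
    v, fired = toy_step(v, 1.0)
    history.append(fired)
```

The loop carries `v` forward explicitly, which mirrors how `update_step` below returns both the new voltages and the spikes.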

## Core Concepts:

- `jit`
- immutability
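
To make the immutability point concrete, here is a small hedged sketch. The variable names (`v`, `v_new`, `in_place_allowed`) are illustrative only; the `.at[...].set(...)` syntax is JAX's functional alternative to in-place assignment.

```python
import jax.numpy as jnp

v = jnp.zeros(3)

try:
    v[0] = 1.0  # in-place item assignment is not supported on JAX arrays
    in_place_allowed = True
except TypeError:
    in_place_allowed = False

# The functional update returns a new array and leaves `v` untouched.
v_new = v.at[0].set(1.0)
```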

In [1]:
import jax
import jax.numpy as jnp

from typing import TypeAlias
import time
import numpy as np

from lif_hparams import (
    _dt,
    _t_max,
    _tau_m,
    _V_reset,
    _V_thresh,
    _R,
    num_simulations
)

with open('weights.npy', 'rb') as f:
    W = np.load(f)


# Initial conditions
n_neurons = len(W)  # Number of neurons in the network
_V = jnp.ones(n_neurons) * _V_reset  # Initial potentials

from tqdm.notebook import tqdm

## Type Definitions for Clarity

In [2]:
Mat: TypeAlias = jnp.ndarray
Vec: TypeAlias = jnp.ndarray 

In [3]:
@jax.jit
def update_step(
        V, W,
        # Hyperparameters
        v_thresh, v_reset, dt, tau_m, membr_R):
    fired = V >= v_thresh
    V = jnp.where(fired, v_reset, V)
    
    # Update voltages
    I_syn = W.dot(fired)  # Synaptic current from spikes
    dV = (dt / tau_m) * (-V + v_reset + membr_R * I_syn)
    V += dV

    # Neurons that fired are held at the reset potential for this timestep,
    # so input arriving in the same step (including self-input) has no effect on them
    V = jnp.where(fired, v_reset, V)
    return V, fired


def run_simulation(
    W: Mat,
    V: Vec,

    # Neuron Parameters
    tau_m: float,
    v_reset: float,
    v_thresh: float,
    membr_R: float,

    # How long do we run for? 
    t_max: float,
    dt: float, 

):
    # Simulation

    spike_train = []
    for i, t in enumerate(jnp.arange(0, t_max, dt)):
        if i == 0:
            continue
        V, fired = update_step(V, W, v_thresh, v_reset, dt, tau_m, membr_R)
        spike_train.append(fired)

    return spike_train

time_arr = []
for i in range(num_simulations):
    start = time.time()
    spike_train = run_simulation(
        W,
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    # Convert to NumPy; this also waits for JAX's asynchronous computations to finish
    np.asarray(spike_train)
    end = time.time()
    #print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

Average Time: 30.854262351989746
S.Dev Time: 7.38322676024436


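‼️‼️**Wait**, this took more time than the NumPy version. What gives? `W` was loaded with `np.load`, so it is still an `np.ndarray`. A jitted function accepts it, but the array has to be converted to a JAX array on every call to `update_step`, which likely accounts for the overhead. So, let's convert `W` to a `jnp.ndarray` once, up front.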
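
The distinction between the two array types can be checked directly. The sketch below uses a small stand-in matrix and a hypothetical `matvec` function rather than the notebook's `W` and `update_step`. It shows only that both inputs are accepted and give the same result, not how the timings differ.

```python
import numpy as np
import jax
import jax.numpy as jnp

W_np = np.eye(4, dtype=np.float32)  # stand-in for the weights loaded with np.load
W_jax = jnp.asarray(W_np)           # one-time conversion to a JAX array


@jax.jit
def matvec(W, x):
    return W.dot(x)


x = jnp.ones(4)
out_np = matvec(W_np, x)    # accepted, but W_np is converted on each call
out_jax = matvec(W_jax, x)  # already a JAX array, no per-call conversion
```

Either way the result is a JAX array with the same values; only the per-call conversion cost differs.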

In [4]:

time_arr = []
for i in range(num_simulations):
    start = time.time()
    spike_train = run_simulation(
        jnp.asarray(W),
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    np.asarray(spike_train)
    end = time.time()
    #print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")


Average Time: 6.209884564081828
S.Dev Time: 0.06889911224045832
