# The Leaky Integrate-and-Fire Neuron

<img src="./lif_formulation.png" alt="drawing" width="400"/>

## Notebook Description

Here we've restructured the code from the previous notebook, [jax_leaky_integrate_and_fire_2.ipynb](./jax_leaky_integrate_and_fire_2.ipynb), to express the time-stepping loop with [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html).


See [exercises/exe_03_primitives.ipynb](../../exercises/exe_03_primitives.ipynb) for more information.

## Core Concepts:

- `scan` for loop-less loops
- `jit`

**Note**: we do not expect to see a speedup over the previous notebook. `scan` exists to keep compilation fast: the loop body is traced and compiled once instead of being unrolled. This notebook merely shows you **how** to speed up slow compilations. Again, see the linked exercise notebook for more information.

In [None]:
import math
import time
from functools import partial
from typing import TypeAlias

import jax
import jax.numpy as jnp
import numpy as np

from lif_hparams import (
    _dt,
    _t_max,
    _tau_m,
    _V_reset,
    _V_thresh,
    _R,
    num_simulations,
)

# Load the synaptic weight matrix from disk and convert it to a JAX array.
W = jnp.asarray(np.load('weights.npy'))


# Initial conditions
n_neurons = len(W)  # number of neurons in the network (length of W's first axis)
_V = jnp.ones(n_neurons) * _V_reset  # every neuron starts at the reset potential

from tqdm.notebook import tqdm

# Type Definitions for Clarity

In [None]:
Mat: TypeAlias = jnp.ndarray
Vec: TypeAlias = jnp.ndarray 

In [None]:
@partial(jax.jit, static_argnames=["v_thresh", "v_reset", "dt", "tau_m", "membr_R"])
def _update_step(
        carry,
        _,  # The second arg here should be None
        W,
        # Hyperparameters
        v_thresh, v_reset, dt, tau_m, membr_R):
    V = carry
    fired = V >= v_thresh
    V = jnp.where(fired, v_reset, V)
    
    # Update voltages
    I_syn = W.dot(fired)  # Synaptic current from spikes
    dV = (dt / tau_m) * (-V + v_reset + membr_R * I_syn)
    V += dV

    # No self-inputs; neurons cannot spike themselves in this timestep
    V = jnp.where(fired, v_reset, V)
    return (V), fired


def run_simulation(
    W: Mat,
    V: Vec,

    # Neuron Parameters
    tau_m: float,
    v_reset: float,
    v_thresh: float,
    membr_R: float,

    # How long do we run for?
    t_max: float,
    dt: float,

) -> jnp.ndarray:
    """Simulate the LIF network for ``t_max`` time units using ``jax.lax.scan``.

    Args:
        W: Synaptic weight matrix passed to ``_update_step``.
        V: Initial membrane potentials, used as the scan carry.
        tau_m: Membrane time constant.
        v_reset: Reset potential.
        v_thresh: Spike threshold.
        membr_R: Membrane resistance.
        t_max: Total simulated time.
        dt: Integration time step.

    Returns:
        The per-step ``fired`` masks emitted by ``_update_step``, stacked by
        ``lax.scan`` along a leading time axis of length ``n_steps``. The
        final membrane potentials are discarded.
    """
    # `t_max // dt` floors in binary floating point, so e.g. 1.0 // 0.1 == 9.0
    # and 1000 // 0.1 == 9999.0, which is one step short. Snap to the nearest
    # integer when t_max / dt is within rounding error of one; otherwise floor.
    ratio = t_max / dt
    nearest = round(ratio)
    n_steps = nearest if math.isclose(ratio, nearest) else math.floor(ratio)

    # Bind everything except the carry and the (unused) scan input.
    update_step = partial(
        _update_step,
        W=W,
        v_thresh=v_thresh,
        v_reset=v_reset,
        dt=dt,
        tau_m=tau_m,
        membr_R=membr_R,
    )
    # xs=None plus an explicit length: the body only threads V through the carry.
    _, spike_train = jax.lax.scan(
        f=update_step,
        init=V,
        xs=None,
        length=n_steps,
    )
    return spike_train

time_arr = []
for _ in range(num_simulations):
    # perf_counter is monotonic and high-resolution; time.time() is a wall
    # clock that can jump and is not intended for measuring durations.
    start = time.perf_counter()
    spike_train = run_simulation(
        W,
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    # JAX dispatches work asynchronously. Converting to a NumPy array blocks
    # until the result is ready (and copies it to the host), so the timing
    # covers the actual computation.
    np.asarray(spike_train)
    end = time.perf_counter()
    time_arr.append(end - start)

# NOTE(review): these timings include any tracing/compilation JAX performs
# inside run_simulation (at least on the first call), so they are not
# steady-state runtime only.
print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")