# Jax scan

## Goals:

- explain why you'd want to use JAX's native loop constructs

- understanding `scan`

- static arguments (`static_argnums` / `static_argnames` in `jax.jit`)

## Concepts:

- reading Haskell-style function signatures
- partial functions
- loopless-loops

# Looping in Jax

Is it possible to have JAX `jit` more of the simulation for a larger speedup? Yes, but first we need to convert that pesky Python `for` loop into something that JAX knows how to handle. Thankfully, JAX provides:

- [jax.lax.while_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html#jax.lax.while_loop)
- [jax.lax.fori_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html)
- [jax.lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)

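Note: we don't necessarily see a speedup in runtime (although that can happen). If you structure the code well, the primary advantage of these JAX functions is reduced compilation time. A Python `for` loop inside a `jit`-ed function is unrolled during tracing, so the compiled program grows with the number of iterations. `fori_loop`, `while_loop`, and `scan`, by contrast, trace the loop body only once.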

In [None]:
import numpy as np
from typing import TypeAlias
import time

from hyperparameters import (
    _dt,
    _t_max,
    _tau_m,
    _V_reset,
    _V_thresh,
    _R,
    num_simulations
)
import jax.numpy as jnp
import jax
from tqdm.notebook import tqdm

# Load the saved weight matrix; the file handle is closed as soon as it is read.
with open('weights.npy', 'rb') as weights_file:
    W = np.load(weights_file)

# Initial conditions: one entry per row of W, every potential starting at _V_reset.
n_neurons = len(W)
_V = _V_reset * jnp.ones(n_neurons)

# Type Definitions for Clarity

In [None]:
Tensor3D: TypeAlias = jnp.ndarray
Mat: TypeAlias = jnp.ndarray
Vec: TypeAlias = jnp.ndarray 

# Fori_loop

In [None]:
from functools import partial

@jax.jit
def run_step(v_prev, v_thresh, v_reset, W, tau_m, dt, membr_R):
    """Advance all membrane potentials by one forward-Euler step.

    Neurons whose potential is at or above ``v_thresh`` count as having fired.
    They are set to ``v_reset`` before the update and clamped back to
    ``v_reset`` after it.

    Returns:
        ``(new_potentials, fired)``, where ``fired`` is the boolean mask
        ``v_prev >= v_thresh``.
    """
    fired = v_prev >= v_thresh
    v_start = jnp.where(fired, v_reset, v_prev)

    # Current delivered through W by the neurons that fired this step.
    drive = W.dot(fired)
    v_next = v_start + (dt / tau_m) * (-v_start + v_reset + membr_R * drive)

    # Neurons that fired are held at reset for this step, whatever input
    # they received.
    return jnp.where(fired, v_reset, v_next), fired

def run_simulation(W, V, tau_m, v_reset, v_thresh, membr_R, t_max, dt, num_steps=None):
    """Simulate the network with ``jax.lax.fori_loop`` and record every spike.

    Args:
        W: weight matrix; ``len(W)`` is used as the number of neurons.
        V: initial membrane potentials.
        tau_m, v_reset, v_thresh, membr_R: neuron parameters forwarded to ``run_step``.
        t_max: total simulated time; only used to derive ``num_steps`` when it is None.
        dt: integration time step.
        num_steps: number of steps to run. Defaults to ``int(t_max / dt)``.
            Previously ``t_max`` was ignored and the caller had to keep the two consistent.

    Returns:
        Boolean array of shape ``(num_steps, len(W))``. Row ``i`` holds the spike
        mask returned by ``run_step`` at step ``i``.
    """
    if num_steps is None:
        num_steps = int(t_max / dt)

    def body_func(i, val):
        V, spike_train = val
        V, spike = run_step(V, v_thresh, v_reset, W, tau_m, dt, membr_R)
        # Functional update: .at[i].set returns a new array with row i filled.
        return (V, spike_train.at[i].set(spike))

    # Preallocate the full record so the loop carry has a fixed shape.
    spike_train = jnp.zeros((num_steps, len(W)), dtype=bool)
    _, spike_train = jax.lax.fori_loop(0, num_steps, body_func, (V, spike_train))
    return spike_train

In [None]:
# Time repeated runs of the fori_loop simulation.
# JAX dispatches work asynchronously: run_simulation may return before the
# device has finished, so we must block on the result *before* stopping the
# clock. Otherwise only the dispatch is timed.
num_steps = int(_t_max / _dt)
time_arr = []
for i in range(num_simulations):
    start = time.perf_counter()
    spike_train = run_simulation(
        jnp.asarray(W),
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt, num_steps=num_steps,
    )
    spike_train.block_until_ready()
    end = time.perf_counter()
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

# Run the Simulations with `scan`

In [None]:
from functools import partial

# No static arguments here. Static args must be hashable Python values, but W
# is a JAX array, which is unhashable. When this function is called from
# inside lax.scan, every argument in the carry is also a tracer, which cannot
# be static either. Marking them static made every call raise.
@jax.jit
def run_step(v_prev, v_thresh, v_reset, W, tau_m, dt, membr_R):
    """Advance all membrane potentials by one forward-Euler step.

    Neurons at or above ``v_thresh`` count as having fired. They are reset
    to ``v_reset`` before the update and clamped back to ``v_reset`` after it.

    Returns:
        ``(new_potentials, spiked)``, where ``spiked`` is the boolean mask
        ``v_prev >= v_thresh``.
    """
    spiked = v_prev >= v_thresh
    V = jnp.where(spiked, v_reset, v_prev)

    # Synaptic current from the neurons that fired (bool mask cast for matmul).
    I_syn = W @ spiked.astype(jnp.float32)
    reset_adjustment = v_reset + membr_R * I_syn
    dV = (dt / tau_m) * (-V + reset_adjustment)
    V = V + dV

    # Hold neurons that fired at reset for this step.
    V = jnp.where(spiked, v_reset, V)
    return V, spiked


def scan_step(carry, _):
    """Scan body: step the potentials once and pass the parameters through unchanged.

    ``carry`` is ``(V, v_thresh, v_reset, W, tau_m, dt, membr_R)``. The per-step
    scan input is ignored. The emitted value is the spike mask from ``run_step``.
    """
    V, v_thresh, v_reset, W, tau_m, dt, membr_R = carry
    constants = (v_thresh, v_reset, W, tau_m, dt, membr_R)
    next_V, fired = run_step(V, *constants)
    return (next_V,) + constants, fired

def run_simulation(W, V, tau_m, v_reset, v_thresh, membr_R, t_max, dt, num_steps=None):
    """Simulate the network with ``jax.lax.scan`` and return the stacked spike masks.

    Args:
        W: weight matrix forwarded to ``run_step``.
        V: initial membrane potentials.
        tau_m, v_reset, v_thresh, membr_R: neuron parameters forwarded to ``run_step``.
        t_max: total simulated time; used to derive ``num_steps`` when it is None.
        dt: integration time step.
        num_steps: number of steps to run. Defaults to ``int(t_max / dt)``.

    Returns:
        The per-step spike masks emitted by ``scan_step``, stacked along a new
        leading axis of length ``num_steps``.
    """
    if num_steps is None:
        num_steps = int(t_max / dt)
    # scan_step ignores its per-step input, so no xs array is materialized;
    # ``length`` alone fixes the number of iterations. The first result is the
    # whole final carry tuple, which is not needed here.
    _final_carry, accum_spikes = jax.lax.scan(
        f=scan_step,
        init=(V, v_thresh, v_reset, W, tau_m, dt, membr_R),
        xs=None,
        length=num_steps,
    )
    return accum_spikes

In [None]:
# Time repeated runs of the scan-based simulation.
# JAX dispatches work asynchronously: run_simulation may return before the
# device has finished, so block on the result before stopping the clock.
time_arr = []
for i in range(num_simulations):
    start = time.perf_counter()
    spike_train = run_simulation(
        jnp.asarray(W),
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    spike_train.block_until_ready()
    end = time.perf_counter()
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")