# Core Concepts

`jit`: JIT: just-in-time compiler

In [1]:
import numpy as np
from typing import TypeAlias
import time

from hyperparameters import (
    _dt,
    _t_max,
    _tau_m,
    _V_reset,
    _V_thresh,
    _R,
    num_simulations
)
import jax.numpy as jnp
import jax
from tqdm.notebook import tqdm

with open('weights.npy', 'rb') as f:
    W = np.load(f)

# Initial conditions
n_neurons = len(W)# Number of neurons in the network
_V = jnp.ones(n_neurons) * _V_reset  # Initial potentials

# Type Definitions for Clarity

In [2]:
Tensor3D: TypeAlias = jnp.ndarray
Mat: TypeAlias = jnp.ndarray
Vec: TypeAlias = jnp.ndarray 

# Run the Simulations

In [3]:
def run_simulation(
    W: Mat,
    V: Vec,

    # Neuron Parameters
    tau_m: float,
    v_reset: float,
    v_thresh: float,
    membr_R: float,

    # How long do we run for? 
    t_max: float,
    dt: float, 

):
    # Simulation

    spike_train = []
    for i, t in enumerate(np.arange(0, t_max, dt)):
        if i == 0:
            continue
        V, spike = run_step(V, v_thresh, v_reset, W, tau_m, dt, membr_R)
        spike_train.append(spike)

    return spike_train

In [4]:
@jax.jit
def run_step(
    v_prev,
    v_thresh,
    v_reset,

    W,
    tau_m,
    dt,
    membr_R,
):
    spiked = v_prev >= v_thresh
    V = jnp.where(spiked, v_reset, v_prev)

    # Update voltages
    I_syn = W.dot(spiked)  # Synaptic current from spikes
    dV = (dt / tau_m) * (-V + v_reset + membr_R * I_syn)
    V = V + dV

    # No self-inputs; neurons cannot spike themselves in this timestep
    V = jnp.where(spiked, v_reset, V)
    return V, spiked



In [5]:
time_arr = []
for i in range(num_simulations):
    start = time.time()
    spike_train = run_simulation(
        W,
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    end = time.time()
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)
    if i > 2:
        print("Breaking out - point proven")
        break

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

Iteration 0 took: 2.084556818008423 seconds
Iteration 1 took: 2.0548107624053955 seconds
Iteration 2 took: 2.0027401447296143 seconds
Iteration 3 took: 1.9824838638305664 seconds
Breaking out - point proven
Average Time: 2.0311478972434998
S.Dev Time: 0.04058211599907389


# What gives? 

`jax.jit` doesn't play nicely with numpy! There are times where calling `jnp.asarray` is necessary

In [6]:
time_arr = []
for i in range(num_simulations):
    start = time.time()
    spike_train = run_simulation(
        jnp.asarray(W),
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    end = time.time()
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

Iteration 0 took: 0.704719066619873 seconds
Iteration 1 took: 0.6860570907592773 seconds
Iteration 2 took: 0.7029411792755127 seconds
Iteration 3 took: 0.6984376907348633 seconds
Iteration 4 took: 0.7066440582275391 seconds


KeyboardInterrupt: 

# Helping out the compiler

In [20]:
from functools import partial

partial(jax.jit, static_argnums=(1, 2, 4, 5, 6))  # Indexes of v_thresh, v_reset, dt
#@jax.jit
def run_step(
    v_prev,
    v_thresh,
    v_reset,

    W,
    tau_m,
    dt,
    membr_R,
):
    spiked = v_prev >= v_thresh
    V = jnp.where(spiked, v_reset, v_prev)

    # Update voltages
    I_syn = W.dot(spiked)  # Synaptic current from spikes
    dV = (dt / tau_m) * (-V + v_reset + membr_R * I_syn)
    V = V + dV

    # No self-inputs; neurons cannot spike themselves in this timestep
    V = jnp.where(spiked, v_reset, V)
    return V, spiked

@jax.jit
def run_simulation(W, V, tau_m, v_reset, v_thresh, membr_R, t_max, dt, num_steps):
    def body_fun(i, val):
        V, spike_train = val
        V, spike = run_step(V, v_thresh, v_reset, W, tau_m, dt, membr_R)
        return (V, spike_train.at[i].set(spike))


    spike_train = jnp.zeros((num_steps, len(W)), dtype=bool)
    _, spike_train = jax.lax.fori_loop(0, num_steps, body_fun, (V, spike_train))
    return spike_train

In [21]:
time_arr = []
for i in range(num_simulations):
    start = time.time()
    spike_train = run_simulation(
        jnp.asarray(W),
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt, num_steps = int(_t_max / _dt)
    )
    end = time.time()
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 2500).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function run_simulation at /var/folders/nr/plbmjfs57_98f06_km7x4ny40000gn/T/ipykernel_57255/2643218216.py:27 for jit. This concrete value was not available in Python because it depends on the value of the argument num_steps.