# Core Concepts

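`PyTrees`: nested containers (tuples, lists, dicts, namedtuples) whose leaves are arrays or scalars. JAX transformations such as `jax.jit` and `jax.lax.scan` accept them as arguments, which is why the simulation parameters below are bundled into a single namedtuple.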

In [1]:
import numpy as np
from typing import TypeAlias
import time

from hyperparameters import (
    _dt,
    _t_max,
    _tau_m,
    _V_reset,
    _V_thresh,
    _R,
    num_simulations
)
import jax.numpy as jnp
import jax
from tqdm.notebook import tqdm

# Read the connectivity matrix from disk.
with open('weights.npy', 'rb') as f:
    W = np.load(f)

# Initial conditions: one membrane potential per row of W, every neuron
# starting at the reset potential.
n_neurons = W.shape[0]
_V = _V_reset * jnp.ones(n_neurons)

# Type Definitions for Clarity

In [2]:
Tensor3D: TypeAlias = jnp.ndarray
Mat: TypeAlias = jnp.ndarray
Vec: TypeAlias = jnp.ndarray 

In [3]:
from collections import namedtuple

# Immutable bundle of the constants read by the update step:
#   v_thresh - potential at or above which a neuron counts as spiking
#   v_reset  - potential assigned to a neuron that spiked
#   W        - weight matrix applied to the spike vector
#   tau_m    - divisor of dt in the update factor dt / tau_m
#   dt       - integration step size
#   membr_R  - factor scaling the synaptic current
NeuronParams = namedtuple(
    "NeuronParams",
    ("v_thresh", "v_reset", "W", "tau_m", "dt", "membr_R"),
)


# Run the Simulations

In [6]:

@jax.jit
def run_step(v_prev, params):
    """Advance every membrane potential by one explicit Euler step.

    Args:
        v_prev: Membrane potentials at the start of the step.
        params: Object exposing ``v_thresh``, ``v_reset``, ``W``, ``tau_m``,
            ``dt`` and ``membr_R``.

    Returns:
        A ``(V, spiked)`` pair: the updated potentials and a boolean mask of
        the neurons whose incoming potential was at or above threshold.
    """
    fired = v_prev >= params.v_thresh
    # Neurons that fired start the update from the reset potential.
    v_start = jnp.where(fired, params.v_reset, v_prev)

    # Synaptic current: the weight matrix applied to the 0/1 spike vector.
    syn_current = params.W @ fired.astype(jnp.float32)
    # Level the potential relaxes towards during this step.
    drive = params.v_reset + params.membr_R * syn_current
    step_factor = params.dt / params.tau_m
    v_next = v_start + step_factor * (-v_start + drive)

    # Neurons that fired are pinned to the reset potential for this step,
    # regardless of the input they received.
    v_next = jnp.where(fired, params.v_reset, v_next)
    return v_next, fired


def scan_step(carry, _):
    """Adapt ``run_step`` to the ``(carry, x) -> (carry, y)`` form used by ``jax.lax.scan``.

    Args:
        carry: ``(V, params)`` pair; ``params`` is threaded through unchanged.
        _: Per-step scan input, ignored.

    Returns:
        ``((new_V, params), spike)``: the next carry and this step's spike mask.
    """
    potentials, params = carry[0], carry[1]
    stepped = run_step(potentials, params)
    next_carry = (stepped[0], params)
    return next_carry, stepped[1]

def run_simulation(V: jnp.ndarray, t_max: float, params: NeuronParams) -> jnp.ndarray:
    """Simulate the network for ``t_max`` time units and return the spike record.

    Args:
        V: Initial membrane potentials passed as the first element of the
            scan carry.
        t_max: Total simulated duration, in the same units as ``params.dt``.
        params: Model constants; ``params.dt`` sets the step size.

    Returns:
        The per-step spike outputs of ``scan_step`` stacked by
        ``jax.lax.scan`` along a new leading axis of length ``num_steps``.
    """
    ratio = t_max / params.dt
    nearest = int(round(ratio))
    # Float division can land just below an exact multiple
    # (0.3 / 0.1 == 2.9999999999999996), and a plain int() would then drop
    # the final step. Snap to the nearest integer when the quotient is
    # within rounding noise of it; otherwise truncate as before.
    if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
        num_steps = nearest
    else:
        num_steps = int(ratio)

    # The step function ignores its per-step input, so pass only the length
    # instead of materialising an index array.
    _, accum_spikes = jax.lax.scan(
        f=scan_step,
        init=(V, params),
        xs=None,
        length=num_steps,
    )
    return accum_spikes

In [7]:
# The parameters are identical for every run, so build them (and convert W
# to a JAX array) once, outside the timed region.
params = NeuronParams(v_thresh=_V_thresh, v_reset=_V_reset, W=jnp.asarray(W), tau_m=_tau_m, dt=_dt, membr_R=_R)

time_arr = []
for i in range(num_simulations):
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    spike_train = run_simulation(
        _V,
        _t_max,
        params
    )
    # JAX dispatches work asynchronously: the call above can return before
    # the result exists. Wait for it so the timing covers the computation
    # and not just the dispatch.
    spike_train.block_until_ready()
    end = time.perf_counter()
    # NOTE: the first iteration may additionally include tracing/compilation.
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

Iteration 0 took: 0.7021489143371582 seconds
Iteration 1 took: 0.6558258533477783 seconds
Iteration 2 took: 0.65932297706604 seconds
Iteration 3 took: 0.6642570495605469 seconds
Iteration 4 took: 0.6738781929016113 seconds
Iteration 5 took: 0.6349859237670898 seconds
Iteration 6 took: 0.6387691497802734 seconds
Iteration 7 took: 0.6425280570983887 seconds
Iteration 8 took: 0.6432950496673584 seconds
Iteration 9 took: 0.6407830715179443 seconds
Average Time: 0.6555794239044189
S.Dev Time: 0.01959538753031299
