In [3]:
import numpy as np
from typing import TypeAlias
import time

from hyperparameters import (
    _dt,
    _t_max,
    _tau_m,
    _V_reset,
    _V_thresh,
    _R,
    num_simulations
)

# Load the synaptic weight matrix. The simulation computes W.dot(spikes) and adds
# the result to a per-neuron voltage vector, so W must be square (n_neurons x n_neurons).
W = np.load('weights.npy')
assert W.ndim == 2 and W.shape[0] == W.shape[1], (
    f"weights.npy must contain a square 2-D matrix, got shape {W.shape}"
)


# Initial conditions
n_neurons = W.shape[0]  # Number of neurons in the network
_V = np.full(n_neurons, _V_reset, dtype=float)  # Initial potentials: every neuron starts at reset
# One column per time step of np.arange(0, _t_max, _dt); run_simulation uses the same grid.
_spike_train = np.zeros((n_neurons, len(np.arange(0, _t_max, _dt))))

from tqdm.notebook import tqdm

SyntaxError: invalid syntax (4099582326.py, line 5)

# Type Definitions for Clarity

In [3]:
Tensor3D: TypeAlias = np.ndarray
Mat: TypeAlias = np.ndarray
Vec: TypeAlias = np.ndarray 
CurrentArray: TypeAlias = np.ndarray

# Run and Time the Simulations

In [4]:
def run_simulation(
    W: Mat,
    spike_train: Mat,
    V: Vec,
    # Neuron Parameters
    tau_m: float,
    v_reset: float,
    v_thresh: float,
    membr_R: float,
    # How long do we run for?
    t_max: float,
    dt: float,
) -> Mat:
    """Simulate a leaky integrate-and-fire network with forward-Euler steps.

    Parameters
    ----------
    W : Mat
        Synaptic weight matrix. ``W.dot(spikes)`` is the input current to each
        neuron, so it is expected to be ``(n_neurons, n_neurons)``.
    spike_train : Mat
        Initial spike record, ``(n_neurons, >= n_steps)``. It is copied, not
        modified. Any spikes already present are fed back as synaptic input
        one step later.
    V : Vec
        Initial membrane potentials, length ``n_neurons``. Copied, not modified.
    tau_m, v_reset, v_thresh, membr_R : float
        Membrane time constant, reset potential (also used as the resting
        potential of the leak term), firing threshold and membrane resistance.
    t_max, dt : float
        Simulation length and time step;
        ``n_steps = len(np.arange(0, t_max, dt))``.

    Returns
    -------
    Mat
        A new array: the input spike record with a 1 written at ``[n, i]`` for
        every neuron ``n`` that crossed threshold at step ``i``.
    """
    # Work on private copies so callers can reuse the same initial-state
    # arrays across runs without one run leaking into the next.
    V = V.copy()
    spike_train = spike_train.copy()

    n_steps = len(np.arange(0, t_max, dt))
    decay = dt / tau_m  # loop-invariant Euler factor

    # Step 0 is the initial condition; step i reads the spikes of step i-1.
    for i in range(1, n_steps):
        fired = V >= v_thresh
        V[fired] = v_reset  # Reset voltage if threshold is crossed
        spike_train[fired, i] = 1  # Record spike times

        # Synaptic current from the previous step's spikes (one-step delay)
        I_syn = W.dot(spike_train[:, i - 1])
        V += decay * (-V + v_reset + membr_R * I_syn)

        # Hold neurons that fired this step at v_reset, discarding this
        # step's synaptic input to them.
        V[fired] = v_reset
    return spike_train

# At least one run is needed: the save cell below writes the last run's spike_train.
assert num_simulations > 0, "num_simulations must be at least 1"

time_arr = []
for i in range(num_simulations):
    # Hand each run fresh copies of the initial state (outside the timed region)
    # so every run starts identically, even if run_simulation writes into its
    # inputs in place.
    spike_init = _spike_train.copy()
    v_init = _V.copy()

    start = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
    spike_train = run_simulation(
        W,
        spike_init, v_init,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    elapsed = time.perf_counter() - start
    print(f"Iteration {i} took: {elapsed} seconds")
    time_arr.append(elapsed)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

Iteration 0 took: 5.18075704574585 seconds
Iteration 1 took: 5.004680871963501 seconds
Iteration 2 took: 4.873786211013794 seconds
Iteration 3 took: 5.169120788574219 seconds
Iteration 4 took: 4.585036039352417 seconds
Iteration 5 took: 4.613517999649048 seconds
Iteration 6 took: 4.438601970672607 seconds
Iteration 7 took: 4.990131855010986 seconds
Iteration 8 took: 5.531838893890381 seconds
Iteration 9 took: 5.351461887359619 seconds
Average Time: 4.973893356323242
S.Dev Time: 0.3337091298627308


# Save the spike train from the last run for comparison

In [5]:
# Persist the spike train returned by the last timed run for later comparison.
np.save('spike_res.npy', spike_train)