# namedtuple

## Goals:

- understand how to use `namedtuple` to keep your code clean

## Concepts:

- namedtuple
- scan

In [None]:
import numpy as np
from typing import TypeAlias
import time

from hyperparameters import (
    _dt,
    _t_max,
    _tau_m,
    _V_reset,
    _V_thresh,
    _R,
    num_simulations
)
import jax.numpy as jnp
import jax
from tqdm.notebook import tqdm

# Load the synaptic weight matrix stored next to this notebook.
W = np.load('weights.npy')

# Initial conditions
# One neuron per row of W (presumably W is square, n_neurons x n_neurons -- TODO confirm).
n_neurons = W.shape[0]
# Every neuron starts out at the reset potential.
_V = jnp.ones(n_neurons) * _V_reset

# Type Definitions for Clarity

In [None]:
# Readability-only aliases: all three resolve to the same `jnp.ndarray` type,
# so nothing about rank or shape is enforced at runtime.
# NOTE(review): the names presumably signal intended rank (3-D, 2-D, 1-D) -- confirm.
Tensor3D: TypeAlias = jnp.ndarray
Mat: TypeAlias = jnp.ndarray
Vec: TypeAlias = jnp.ndarray

# Adding Functionality:

Our LIF code is fine, but one way we can make this slightly more biologically plausible is by adding a constant input to the neurons, a sort of environmental noise or ongoing background synaptic activity that's not exactly a part of our modeling process.

This is a simple change: we add a background-current term `I_bg` to the membrane update, so it becomes `dV = (dt / tau_m) * (-V + v_reset + membr_R * I_syn + I_bg)`.

In [None]:
def run_step(v_prev, v_thresh, v_reset, W, tau_m, dt, membr_R, I_bg):
    """Advance the membrane potentials by one Euler step.

    Args:
        v_prev: Membrane potentials at the start of the step.
        v_thresh: Potential at or above which a neuron counts as spiking.
        v_reset: Potential assigned to neurons that spiked; it also appears
            as the additive offset inside the drive term.
        W: Weight matrix applied to the spike vector to obtain the synaptic
            current.
        tau_m: Membrane time constant (divides ``dt``).
        dt: Integration time step.
        membr_R: Factor multiplying the synaptic current.
        I_bg: Background input added to the drive term.

    Returns:
        A ``(V, spiked)`` pair: the updated potentials and the boolean mask
        of neurons that were at or above threshold on entry.
    """
    fired = v_prev >= v_thresh
    # Neurons that fired are integrated starting from the reset value.
    v_start = jnp.where(fired, v_reset, v_prev)

    # Synaptic current produced by this step's spikes.
    syn_current = W @ fired.astype(jnp.float32)
    drive = -v_start + v_reset + membr_R * syn_current + I_bg
    v_next = v_start + (dt / tau_m) * drive

    # Neurons that fired are held at reset regardless of the drive.
    return jnp.where(fired, v_reset, v_next), fired


def scan_step(carry, _):
    """Adapt ``run_step`` to the ``(carry, x) -> (carry, y)`` shape used by scan.

    The carry is ``(V, v_thresh, v_reset, W, tau_m, dt, membr_R, I_bg)``.
    Only ``V`` changes between steps; the remaining entries are passed
    through untouched. The per-step input is ignored.

    Returns:
        ``(new_carry, spike)`` where ``spike`` is this step's spike mask.
    """
    V, constants = carry[0], carry[1:]
    new_V, spike = run_step(V, *constants)
    return (new_V, *constants), spike

def run_simulation(W, V, tau_m, v_reset, v_thresh, membr_R, t_max, dt, I_bg=0):
    """Simulate the network for ``t_max / dt`` steps and collect the spikes.

    Args:
        W: Weight matrix forwarded to every step.
        V: Initial membrane potentials.
        tau_m: Membrane time constant.
        v_reset: Reset potential.
        v_thresh: Spike threshold.
        membr_R: Factor multiplying the synaptic current.
        t_max: Total simulated time.
        dt: Integration time step.
        I_bg: Background input; defaults to 0.

    Returns:
        The per-step spike masks stacked by ``jax.lax.scan``. The final
        potentials are discarded.
    """
    # int() truncates, so a trailing partial step is dropped.
    num_steps = int(t_max / dt)
    initial_carry = (V, v_thresh, v_reset, W, tau_m, dt, membr_R, I_bg)
    # Only the length of this array matters: scan_step ignores its values.
    step_ids = jnp.arange(num_steps)
    _, accum_spikes = jax.lax.scan(scan_step, initial_carry, step_ids)
    return accum_spikes

In [None]:
# Benchmark the positional-argument implementation: run the whole simulation
# `num_simulations` times and report wall-clock statistics.
# NOTE: the first iteration typically also pays JAX's tracing/compilation cost.
W_jax = jnp.asarray(W)  # convert once, outside the timed region
time_arr = []
for i in range(num_simulations):
    start = time.perf_counter()
    spike_train = run_simulation(
        W_jax,
        _V,
        _tau_m, _V_reset, _V_thresh, _R,
        _t_max, _dt
    )
    # JAX dispatches work asynchronously: wait for the result *before*
    # stopping the clock, otherwise only the dispatch overhead is measured.
    spike_train.block_until_ready()
    end = time.perf_counter()
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")

# Simplifying the code using a NamedTuple

Modifying the code above was more involved than it should have been, and this kind of change can easily get out of hand once we have many more parameters.

Some of my own code, which inspired this tutorial, has over 15 parameters, and at this point I'm afraid to touch it. I wish that someone had sat me down and talked to me about using these containers.

In [None]:
from collections import namedtuple

# Immutable bundle of everything a simulation step needs besides the membrane
# potentials themselves; fields are read by name (e.g. `params.tau_m`).
NeuronParams = namedtuple(
    "NeuronParams",
    "v_thresh v_reset W tau_m dt membr_R I_bg",
)


# Run the Simulations

In [None]:

@jax.jit
def run_step(v_prev, params):
    """Advance the membrane potentials by one Euler step.

    Args:
        v_prev: Membrane potentials at the start of the step.
        params: Object exposing ``v_thresh``, ``v_reset``, ``W``, ``tau_m``,
            ``dt``, ``membr_R`` and ``I_bg`` as attributes.

    Returns:
        A ``(V, spiked)`` pair: the updated potentials and the boolean mask
        of neurons that were at or above ``params.v_thresh`` on entry.
    """
    fired = v_prev >= params.v_thresh
    # Neurons that fired are integrated starting from the reset value.
    v_start = jnp.where(fired, params.v_reset, v_prev)

    # Synaptic current produced by this step's spikes.
    syn_current = params.W @ fired.astype(jnp.float32)
    drive = -v_start + params.v_reset + params.membr_R * syn_current + params.I_bg
    v_next = v_start + (params.dt / params.tau_m) * drive

    # Neurons that fired are held at reset regardless of the drive.
    return jnp.where(fired, params.v_reset, v_next), fired


def scan_step(carry, _):
    """Adapt ``run_step`` to the ``(carry, x) -> (carry, y)`` shape used by scan.

    The carry is ``(V, params)``: the potentials are replaced every step,
    while ``params`` is handed back unchanged. The per-step input is ignored.

    Returns:
        ``(new_carry, spike)`` where ``spike`` is this step's spike mask.
    """
    potentials, params = carry
    updated, spike = run_step(potentials, params)
    next_carry = (updated, params)
    return next_carry, spike

def run_simulation(V: jnp.ndarray, t_max: float, params: NeuronParams) -> jnp.ndarray:
    """Simulate the network for ``t_max / params.dt`` steps and collect the spikes.

    Args:
        V: Initial membrane potentials.
        t_max: Total simulated time.
        params: Model parameters; ``params.dt`` sets the step size and the
            whole bundle is forwarded unchanged to every step.

    Returns:
        The per-step spike masks stacked by ``jax.lax.scan``. The final
        potentials are discarded.
    """
    # int() truncates, so a trailing partial step is dropped.
    num_steps = int(t_max / params.dt)
    # Run the scan over the number of time steps; scan_step ignores the
    # step indices, so only the length of `xs` matters.
    _, accum_spikes = jax.lax.scan(
        f=scan_step,
        init=(V, params),
        xs=jnp.arange(num_steps),
    )
    return accum_spikes

In [None]:
# Benchmark the NamedTuple-based implementation: run the whole simulation
# `num_simulations` times and report wall-clock statistics.
# NOTE: the first iteration typically also pays JAX's tracing/compilation cost.
# The parameters never change between runs, so build them once up front.
params = NeuronParams(
    v_thresh=_V_thresh,
    v_reset=_V_reset,
    W=jnp.asarray(W),
    tau_m=_tau_m,
    dt=_dt,
    membr_R=_R,
    I_bg=0.0,
)
time_arr = []
for i in range(num_simulations):
    start = time.perf_counter()
    spike_train = run_simulation(
        _V,
        _t_max,
        params
    )
    # JAX dispatches work asynchronously: wait for the result *before*
    # stopping the clock, otherwise only the dispatch overhead is measured.
    spike_train.block_until_ready()
    end = time.perf_counter()
    print(f"Iteration {i} took: {end - start} seconds")
    time_arr.append(end - start)

print(f"Average Time: {np.mean(time_arr)}")
print(f"S.Dev Time: {np.std(time_arr)}")