In [None]:
import os
os.environ['CC'] = 'gcc'
os.environ['CXX'] = 'g++'
import brian2 as b2
import brian2cuda
b2.set_device("cuda_standalone", build_on_run=False)
b2.prefs.codegen.cpp.extra_compile_args_gcc = ['-std=c++17']
b2.prefs.devices.cuda_standalone.cuda_backend.extra_compile_args_nvcc = ['-std=c++17']

import numpy as np
import time
import tqdm

n = 1024 
n_ext = 256 
# Synaptic base weights (for scaling)
w_ext_base = 1.5  # Base weight for external
w_rec_base = 0.5  # Base weight for recurrent
sigma = n / 10  # Standard deviation of the Gaussian
# --- Kernels ---
neuron_indices = np.arange(n)
dist_i, dist_j = np.meshgrid(neuron_indices, neuron_indices)
distance = np.abs(dist_i - dist_j)
distance = np.minimum(distance, n - distance)
W_rec_matrix = np.exp(-0.5 * (distance / sigma)**2)
np.fill_diagonal(W_rec_matrix, 0)
W_rec_matrix = (W_rec_matrix*W_rec_matrix) * np.random.rand(n, n) * w_rec_base * (np.random.rand(n,n) < (256/n))
W_ext_matrix = (np.random.rand(n_ext, n)**2) * w_ext_base * (np.random.rand(n_ext,n) < (64/n_ext))
t = 50000
dt = 0.1
kernel_rec = W_rec_matrix
kernel_ext = W_ext_matrix

# Reset scope
b2.start_scope()
t = t * b2.ms
dt = dt * b2.ms
b2.defaultclock.dt = dt
# LIF Neuron parameters
tau = 20 * b2.ms
v_rest = -70 * b2.mV
v_thresh = -50 * b2.mV
v_reset = -70 * b2.mV
tau_ref = 5 * b2.ms
ker_scale = 

# External input
f_ext = 15 * b2.Hz

# Neurons
eqs_in = """
	dv/dt = (-(v-v_rest) +delta_T*exp((v-v_rheobase)/delta_T)+ R * input_current(t) - R * w)/(tau_m) : volt
	dw/dt = (a*(v-v_rest)-w)/tau_w : amp
"""
eqs = """
	dv/dt = (-(v-v_rest) +delta_T*exp((v-v_rheobase)/delta_T) - R * w)/(tau_m) : volt
	dw/dt = (a*(v-v_rest)-w)/tau_w : amp
"""
n_A1 = b2.NeuronGroup(
	n,
	model=eqs_in, 
	reset='v=v_reset;w+=b', 
	threshold='v>firing_threshold',
	method='euler'
)
n_A1.v = v_rest
n_B1 = b2.NeuronGroup(
	n,
	model=eqs, 
	reset='v=v_reset;w+=b', 
	threshold='v>firing_threshold',
	method='euler'
)
n_B1.v = v_rest
n_B2 = b2.NeuronGroup(
	n,
	model=eqs, 
	reset='v=v_reset;w+=b', 
	threshold='v>firing_threshold',
	method='euler'
)
n_B2.v = v_rest
n_B3 = b2.NeuronGroup(
	n,
	model=eqs, 
	reset='v=v_reset;w+=b', 
	threshold='v>firing_threshold',
	method='euler'
)
n_B3.v = v_rest
n_C1 = b2.NeuronGroup(
	n,
	model=eqs, 
	reset='v=v_reset;w+=b', 
	threshold='v>firing_threshold',
	method='euler'
)
n_C1.v = v_rest
n_C2 = b2.NeuronGroup(
	n,
	model=eqs, 
	reset='v=v_reset;w+=b', 
	threshold='v>firing_threshold',
	method='euler'
)
n_C2.v = v_rest
n_C3 = b2.NeuronGroup(
	n,
	model=eqs, 
	reset='v=v_reset;w+=b', 
	threshold='v>firing_threshold',
	method='euler'
)
n_C3.v = v_rest
n_D1 = b2.NeuronGroup(
	n,
	model=eqs, 
	reset='v=v_reset;w+=b', 
	threshold='v>firing_threshold',
	method='euler'
)
n_D1.v = v_rest

# Synapses

ker_A1_B1 = (W_rec_matrix*W_rec_matrix) * np.random.rand(n, n) * w_rec_base * (np.random.rand(n,n) < (256/n))
s_A1_B1 = b2.Synapses(n_A1, n_B1, model='w_syn : volt', on_pre='v += w_syn')
s_A1_B1_src, s_A1_B1_tgt = ker_A1_B1.nonzero()
s_A1_B1.connect(i=s_A1_B1_src, j=s_A1_B1_tgt)
s_A1_B1.w_syn = ker_A1_B1[s_A1_B1_src, s_A1_B1_tgt] * b2.mV

ker_B1_B2 = (W_rec_matrix*W_rec_matrix) * np.random.rand(n, n) * w_rec_base * (np.random.rand(n,n) < (256/n))
s_B1_B2 = b2.Synapses(n_B1, n_B2, model='w_syn : volt', on_pre='v += w_syn')
s_B1_B2_src, s_B1_B2_tgt = ker_B1_B2.nonzero()
s_B1_B2.connect(i=s_B1_B2_src, j=s_B1_B2_tgt)
s_B1_B2.w_syn = ker_B1_B2[s_B1_B2_src, s_B1_B2_tgt] * b2.mV

ker_B2_B3 = (W_rec_matrix*W_rec_matrix) * np.random.rand(n, n) * w_rec_base * (np.random.rand(n,n) < (256/n))
s_B2_B3 = b2.Synapses(n_B2, n_B3, model='w_syn : volt', on_pre='v += w_syn')
s_B2_B3_src, s_B2_B3_tgt = ker_B2_B3.nonzero()
s_B2_B3.connect(i=s_B2_B3_src, j=s_B2_B3_tgt)
s_B2_B3.w_syn = ker_B2_B3[s_B2_B3_src, s_B2_B3_tgt] * b2.mV

ker_B3_D1 = (W_rec_matrix*W_rec_matrix) * np.random.rand(n, n) * w_rec_base * (np.random.rand(n,n) < (256/n))
s_B3_D1 = b2.Synapses(n_B3, n_D1, model='w_syn : volt', on_pre='v += w_syn')
s_B3_D1_src, s_B3_D1_tgt = ker_B3_D1.nonzero()
s_B3_D1.connect(i=s_B3_D1_src, j=s_B3_D1_tgt)
s_B3_D1.w_syn = ker_B3_D1[s_B3_D1_src, s_B3_D1_tgt] * b2.mV

ker_A1_C1 = (W_rec_matrix*W_rec_matrix) * np.random.rand(n, n) * w_rec_base * (np.random.rand(n,n) < (256/n))
s_A1_C1 = b2.Synapses(n_A1, n_C1, model='w_syn : volt', on_pre='v += w_syn')
s_A1_C1_src, s_A1_C1_tgt = ker_A1_C1.nonzero()
s_A1_C1.connect(i=s_A1_C1_src, j=s_A1_C1_tgt)
s_A1_C1.w_syn = ker_A1_C1[s_A1_C1_src, s_A1_C1_tgt] * b2.mV

ker_C1_C2 = (W_rec_matrix*W_rec_matrix) * np.random.rand(n, n) * w_rec_base * (np.random.rand(n,n) < (256/n))
s_C1_C2 = b2.Synapses(n_C1, n_C2, model='w_syn : volt', on_pre='v += w_syn')
s_C1_C2_src, s_C1_C2_tgt = ker_C1_C2.nonzero()
s_C1_C2.connect(i=s_C1_C2_src, j=s_C1_C2_tgt)
s_C1_C2.w_syn = ker_C1_C2[s_C1_C2_src, s_C1_C2_tgt] * b2.mV

ker_C2_C3 = (W_rec_matrix*W_rec_matrix) * np.random.rand(n, n) * w_rec_base * (np.random.rand(n,n) < (256/n))
s_C2_C3 = b2.Synapses(n_C2, n_C3, model='w_syn : volt', on_pre='v += w_syn')
s_C2_C3_src, s_C2_C3_tgt = ker_C2_C3.nonzero()
s_C2_C3.connect(i=s_C2_C3_src, j=s_C2_C3_tgt)
s_C2_C3.w_syn = ker_C2_C3[s_C2_C3_src, s_C2_C3_tgt] * b2.mV

ker_C3_D1 = (W_rec_matrix*W_rec_matrix) * np.random.rand(n, n) * w_rec_base * (np.random.rand(n,n) < (256/n))
s_C3_D1 = b2.Synapses(n_C3, n_D1, model='w_syn : volt', on_pre='v += w_syn')
s_C3_D1_src, s_C3_D1_tgt = ker_C3_D1.nonzero()
s_C3_D1.connect(i=s_C3_D1_src, j=s_C3_D1_tgt)
s_C3_D1.w_syn = ker_C3_D1[s_C3_D1_src, s_C3_D1_tgt] * b2.mV

#b2.device.build(clean=True)

start = time.time()
b2.run(t)
end = time.time()
print(end - start)

In [15]:
import sys
sys.path.append('..')

import jax
import jax.numpy as jnp
import spark
import tqdm
from functools import partial

DT = 0.1
MEMBRANE_TIME_SCALE_TAU_M = 5.0
MEMBRANE_RESISTANCE_R = 500.0
V_REST = -70.0
V_RESET = -51.0
RHEOBASE_THRESHOLD_V_RH = -50.0
SHARPNESS_DELTA_T = 2.0
ADAPTATION_VOLTAGE_COUPLING_A = 0.5
ADAPTATION_TIME_CONSTANT_TAU_W = 100.0
SPIKE_TRIGGERED_ADAPTATION_INCREMENT_B = 7.0
FIRING_THRESHOLD_V_SPIKE = -30.0
INPUT_SPIKE_TIMES = [10, 20, 30, 32, 34, 36, 38, 55, 60, 62, 70, 84]
SYNAPSE_STRENGTH = 75.0


In [8]:
def adex_brain_config(
	dt,
	synapse_strength,
	potential_rest,
	potential_reset,
	potential_tau,
	resistance, # M立 -> G立
	threshold,
	rheobase_threshold,
	spike_slope,
	adaptation_tau,
	adaptation_delta,
	adaptation_subthreshold,
) -> spark.nn.BrainConfig:

	adex_config = spark.nn.neurons.AdExNeuronConfig(
		_s_units = (1024,),
		_s_dt = dt,
		_s_dtype = jnp.float16,
		inhibitory_rate = 0.0,
		soma_config = spark.nn.somas.AdaptiveExponentialSomaConfig(
			potential_rest = potential_rest,
			potential_reset = potential_reset,
			potential_tau = potential_tau,
			resistance = resistance / 1000, # M立 -> G立
			threshold = threshold,
			rheobase_threshold = rheobase_threshold,
			spike_slope = spike_slope,
			adaptation_tau = adaptation_tau,
			adaptation_delta = adaptation_delta,
			adaptation_subthreshold = adaptation_subthreshold,
		),
		synapses_config = spark.nn.synapses.LinearSynapsesConfig(
			units = (1,),
			kernel_initializer = spark.nn.initializers.ConstantInitializerConfig(scale=synapse_strength),
		),
		delays_config = None,
		learning_rule_config = None,
	)

	def adex_specs(name, origin, port) -> spark.ModuleSpecs:
		return spark.ModuleSpecs(
			name = name, 
			module_cls = spark.nn.neurons.AdExNeuron, 
			inputs = {
				'in_spikes': [
					spark.PortMap(origin=o, port=p) for o, p in zip(origin, port)
				]
			},
			config = adex_config
		)
	
	input_map = {
		'spikes': spark.InputSpec(
			payload_type=spark.FloatArray, 
			shape=(1024,), 
			dtype=jnp.float16,
		)
	}
	output_map = {
		'spikes': {
			'input': spark.PortMap(
				origin='n_D1',
				port='out_spikes'
			),
			'spec': spark.OutputSpec(
				payload_type=spark.SpikeArray,
				shape=(1024,),
				dtype=jnp.float16
			)
		}
	}
	modules_map = {
		'n_A1': adex_specs('n_A1', ['__call__'], ['spikes']),
		'n_B1': adex_specs('n_B1', ['n_A1'], ['out_spikes']),
		'n_B2': adex_specs('n_B2', ['n_B1'], ['out_spikes']),
		'n_B3': adex_specs('n_B3', ['n_B2'], ['out_spikes']),
		'n_C1': adex_specs('n_C1', ['n_A1'], ['out_spikes']),
		'n_C2': adex_specs('n_C2', ['n_C1'], ['out_spikes']),
		'n_C3': adex_specs('n_C3', ['n_C2'], ['out_spikes']),
		'n_D1': adex_specs('n_D1', ['n_B3', 'n_C3'], ['out_spikes', 'out_spikes']),
	}

	return spark.nn.BrainConfig(input_map=input_map, output_map=output_map, modules_map=modules_map)


brain_config = adex_brain_config(
	dt = DT,
	synapse_strength = SYNAPSE_STRENGTH,
	potential_rest = V_REST,
	potential_reset = V_RESET,
	potential_tau = MEMBRANE_TIME_SCALE_TAU_M,
	resistance = MEMBRANE_RESISTANCE_R,
	threshold = FIRING_THRESHOLD_V_SPIKE,
	rheobase_threshold = RHEOBASE_THRESHOLD_V_RH,
	spike_slope = SHARPNESS_DELTA_T,
	adaptation_tau = ADAPTATION_TIME_CONSTANT_TAU_W,
	adaptation_delta = SPIKE_TRIGGERED_ADAPTATION_INCREMENT_B ,
	adaptation_subthreshold = ADAPTATION_VOLTAGE_COUPLING_A,	
)

In [12]:
brain = spark.nn.Brain(config=brain_config)
brain(spikes=spark.SpikeArray( jnp.zeros((1024,)) ))

{'spikes': SpikeArray(value=Array([0., 0., 0., ..., 0., 0., 0.], dtype=float16), async_spikes=False)}

In [None]:
brain(spikes=spark.SpikeArray( jnp.ones((1024,)) ))

{'spikes': SpikeArray(value=Array([0., 0., 0., ..., 0., 0., 0.], dtype=float16), async_spikes=False)}

In [18]:
graph, state = spark.split((brain))

In [17]:
@partial(jax.jit, static_argnames='k')
def run_model_k_steps(graph, state, k, **inputs):
    model = spark.merge(graph, state)
    for i in range(k):
        outputs = model(**inputs)
    _, state = spark.split((model))
    return outputs, state

In [20]:
outputs, state = run_model_k_steps(graph, state, 10, spikes=spark.SpikeArray( jnp.ones((1024,)) ))

In [22]:
import time

In [23]:
start = time.time()
for i in range(1000):
    outputs, state = run_model_k_steps(graph, state, 10, spikes=spark.SpikeArray( jnp.ones((1024,)) ))
end = time.time()
print(end-start)

0.4832925796508789
