In [1]:
import sys
sys.path.insert(1, './..')

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from matplotlib.patches import Rectangle
from math import prod
from tqdm import tqdm
import spark

In [2]:
# Model inputs: one 4-element float16 signal.
input_map = {
    'signal': spark.PortSpecs(payload_type=spark.FloatArray, shape=(4,), dtype=jnp.float16),
}

# Model outputs: the 2-element float16 signal read from the integrator module.
output_map = {
    'integrator': {
        'signal': spark.PortSpecs(payload_type=spark.FloatArray, shape=(2,), dtype=jnp.float16),
    },
}


def _alif_module(name, sources, units, target_units, inhibitory_rate):
    """Return the ``spark.ModuleSpecs`` of one ALIF neuron population.

    Args:
        name: module name (also its key in ``modules_map``).
        sources: ``(origin, port)`` pairs wired, in order, into ``in_spikes``.
        units: value passed as ``units`` to ``ALIFNeuronConfig``.
        target_units: value passed as ``_s_target_units``.
        inhibitory_rate: value passed as ``inhibitory_rate``.

    Every population built here uses ``_s_async_spikes=True``.
    """
    in_spikes = [spark.PortMap(origin=origin, port=port) for origin, port in sources]
    return spark.ModuleSpecs(
        name=name,
        module_cls=spark.nn.neurons.ALIFNeuron,
        inputs={'in_spikes': in_spikes},
        config=spark.nn.neurons.ALIFNeuronConfig(
            units=units,
            _s_target_units=target_units,
            inhibitory_rate=inhibitory_rate,
            _s_async_spikes=True,
        ),
    )


# Wiring: spiker -> {A_ex, A_in, B_ex, B_in}; A_ex -> A_in, B_ex -> B_in,
# A_in -> B_ex, B_in -> A_ex, each *_ex also feeds itself; A_ex + B_ex -> integrator.
# NOTE(review): `units` appears to be the total presynaptic width
# (4*128 spiker + 256 from an *_ex population + 64 from an *_in population) — TODO confirm.
modules_map = {
    'spiker': spark.ModuleSpecs(
        name='spiker',
        module_cls=spark.nn.interfaces.TopologicalLinearSpiker,
        inputs={'signal': [spark.PortMap(origin='__call__', port='signal')]},
        config=spark.nn.interfaces.TopologicalLinearSpikerConfig(
            glue=jnp.array(0),
            mins=jnp.array(-1),
            maxs=jnp.array(1),
            resolution=128,
            max_freq=200,
            tau=30.0,
        ),
    ),
    'A_ex': _alif_module(
        'A_ex',
        [('spiker', 'spikes'), ('A_ex', 'out_spikes'), ('B_in', 'out_spikes')],
        units=(4*128+256+64,), target_units=(256,), inhibitory_rate=0.0,
    ),
    'A_in': _alif_module(
        'A_in',
        [('spiker', 'spikes'), ('A_ex', 'out_spikes')],
        units=(4*128+256,), target_units=(64,), inhibitory_rate=1.0,
    ),
    'B_ex': _alif_module(
        'B_ex',
        [('spiker', 'spikes'), ('B_ex', 'out_spikes'), ('A_in', 'out_spikes')],
        units=(4*128+256+64,), target_units=(256,), inhibitory_rate=0.0,
    ),
    'B_in': _alif_module(
        'B_in',
        [('spiker', 'spikes'), ('B_ex', 'out_spikes')],
        units=(4*128+256,), target_units=(64,), inhibitory_rate=1.0,
    ),
    'integrator': spark.ModuleSpecs(
        name='integrator',
        module_cls=spark.nn.interfaces.ExponentialIntegrator,
        inputs={
            'spikes': [
                spark.PortMap(origin='A_ex', port='out_spikes'),
                spark.PortMap(origin='B_ex', port='out_spikes'),
            ],
        },
        config=spark.nn.interfaces.ExponentialIntegratorConfig(num_outputs=2),
    ),
}

brain_config = spark.nn.BrainConfig(input_map=input_map, output_map=output_map, modules_map=modules_map)

In [3]:

# Assemble the brain configuration from the three maps (same call as in the config cell).
brain_config = spark.nn.BrainConfig(
    input_map=input_map,
    output_map=output_map,
    modules_map=modules_map,
)

# Instantiate the Brain module from its configuration.
brain = spark.nn.Brain(config=brain_config)

# Run one step on an all-zero float16 input of shape (4,).
# NOTE(review): the recorded output of this cell is
# "TypeError: mul got incompatible shapes for broadcasting: (256,), (832,)".
# 832 == 4*128+256+64 is the `units` of A_ex/B_ex and 256 their `_s_target_units`,
# so those ALIF configs may be inconsistent — TODO investigate.
brain(signal=spark.FloatArray(jnp.zeros((4,), dtype=jnp.float16)))


TypeError: mul got incompatible shapes for broadcasting: (256,), (832,).

In [5]:
# Scratch check: is the Brain's class object itself an instance of `type`
# (i.e. an ordinary Python class)? Recorded result: True.
isinstance(brain.__class__, type)

True

In [7]:
# Scratch check: display the module specs stored on the config.
# NOTE(review): rendering this raised
# "AttributeError: 'SimpleSynapsesConfig' object has no attribute 'target_units'"
# (see recorded output) — TODO investigate how the `_s_`-prefixed ALIF kwargs reach the synapse config.
brain_config.modules_map

AttributeError: 'SimpleSynapsesConfig' object has no attribute 'target_units'

In [6]:
# Scratch check: build a SimpleSynapsesConfig directly to inspect its fields and defaults
# (the recorded repr shows `target_units` and `async_spikes` are accepted).
spark.nn.synapses.SimpleSynapsesConfig(target_units=(10,), async_spikes=True)

SimpleSynapsesConfig(seed=1712792851, dtype=<class 'jax.numpy.float16'>, dt=1.0, target_units=(10,), async_spikes=True, kernel_initializer=SparseUniformKernelInitializerConfig(name='sparse_uniform_kernel_initializer', dtype=<class 'jax.numpy.float16'>, density=0.2))

In [9]:
x = {'a': 1, 'b': 2, 'c': 3}
y = {'d': 4}


def test(**kwargs):
    """Print the received keyword arguments as a dict."""
    print(kwargs)


# Merge the two dicts (keys of `y` win on collision) and forward them as keyword arguments.
test(**(x | y))

{'a': 1, 'b': 2, 'c': 3, 'd': 4}


In [None]:
@jax.jit
def run_model(graph, state, x):
    """Run one jitted step of a split nnx model.

    Rebuilds the model from ``graph`` and ``state``, calls it with ``drive=x``,
    then splits it again and returns ``(out, spikes, new_state)``, where
    ``new_state`` is the model state after the call.
    """
    # NOTE(review): the Brain in the "Build the model" cell is called with `signal=`;
    # confirm `drive=` matches the API of the model passed in here.
    module = nnx.merge(graph, state)
    outputs, spike_trains = module(drive=x)
    _graph, new_state = nnx.split(module)
    return outputs, spike_trains, new_state

def process_obs(x):
    """Scale a CartPole observation per component and clip it to [-1, 1].

    Components, in order: cart position, cart speed, pole angle, pole angular
    speed. They are divided by 2.4, 2.5, 0.2095 and 3.5 respectively.
    """
    bounds = np.array([2.4, 2.5, 0.2095, 3.5])
    scaled = x / bounds
    return np.clip(scaled, a_min=-1, a_max=1)

def compute_real_reward(x, x_prev, r_prev, terminated):
    """Shaped reward for CartPole, smoothed with the previous reward.

    The reward is positive when the squared cart position (index 0) and squared
    pole angle (index 2) shrink between ``x_prev`` and ``x``. The step value is
    ``0.5 * r_prev + 2 * gain``, clipped to [-1, 1]. Returns ``0`` when
    ``terminated`` is truthy.
    """
    if not terminated:
        cart_gain = x_prev[0]**2 - x[0]**2
        pole_gain = x_prev[2]**2 - x[2]**2
        smoothed = 0.5 * r_prev + 2 * (cart_gain + pole_gain)
        return np.clip(smoothed, a_min=-1, a_max=1)
    return 0

In [None]:
import time
import gymnasium as gym
import ale_py
import numpy as np
gym.register_envs(ale_py)

# --- Experiment configuration ------------------------------------------------
env_name = 'CartPole-v1'
brain_steps_per_env_step = 10   # brain updates per environment step
RUN_WARMUP = False              # set True to run the long closed-loop warm-up first
N_WARMUP_ENV_STEPS = 5000
N_EVAL_ENV_STEPS = 100
WARMUP_FLUSH_STEPS = 16         # zero-input brain steps after each reset (warm-up)
EVAL_FLUSH_STEPS = 50           # zero-input brain steps after each reset (evaluation)
SPIKER_RESOLUTION = 128         # spiker `resolution`; one raster band per observation component
PLOT_KERNELS = False


def _flush_model(graph, state, obs_like, n_steps):
    """Advance the model ``n_steps`` times on an all-zero input and return the new state.

    ``obs_like`` only supplies the input shape. This is called after every
    environment reset, presumably to let activity from the previous episode decay.
    """
    zero_input = spark.FloatArray(jnp.zeros_like(obs_like, dtype=jnp.float16))
    for _ in range(n_steps):
        _, _, state = run_model(graph, state, zero_input)
    return state


env = gym.make(env_name)
next_obs, _ = env.reset(seed=42)
next_obs = process_obs(next_obs)

# NOTE(review): the "Build the model" cell uses `spark.nn.Brain(config=...)` called with
# `signal=`; confirm `spark.Brain(input_map=...)`, `drive=` and the
# 'integrator.output' key below still match the current spark API.
model = spark.Brain(input_map=input_map, output_map=output_map, modules_map=modules_map)
model(drive=spark.FloatArray(jnp.zeros((4,), dtype=jnp.float16)))
graph, state = nnx.split(model)
#starting_kernel = model.neurons.synapses.get_flat_kernel()

if RUN_WARMUP:
    reward = 0
    reward_array = []
    for _ in tqdm(range(N_WARMUP_ENV_STEPS)):
        prev_obs = next_obs
        # Model logic
        out, model_spikes, state = run_model(graph, state, spark.FloatArray(jnp.array(next_obs, dtype=jnp.float16)))
        # Environment logic.
        next_action = int(np.argmax(out['integrator.output'].value))
        next_obs, _, terminated, truncated, info = env.step(next_action)
        # A time-limit truncation also ends the episode; the env must be reset either way.
        episode_over = terminated or truncated
        if episode_over:
            next_obs, _ = env.reset()
            state = _flush_model(graph, state, next_obs, WARMUP_FLUSH_STEPS)
        next_obs = process_obs(next_obs)
        reward = compute_real_reward(next_obs, prev_obs, reward, episode_over)
        reward_array.append(reward)

outs = []
spikes = []
obs = []        # processed observation fed to the model during each env step
breaks = []     # brain-step index at which each new episode starts
break_obs = []  # raw observation that ended each episode
actions = []
reward = 0
next_obs, _ = env.reset(seed=42+1)
next_obs = process_obs(next_obs)
for env_step in tqdm(range(N_EVAL_ENV_STEPS)):
    prev_obs = next_obs
    # Record the model input for this env step, so obs[k] lines up with
    # brain steps [k*K, (k+1)*K) in the raster below.
    obs.append(prev_obs)
    # Model logic
    model_input = spark.FloatArray(jnp.array(next_obs, dtype=jnp.float16))
    for _ in range(brain_steps_per_env_step):
        out, model_spikes, state = run_model(graph, state, model_input)
        outs.append(out['integrator.output'].value)
        spikes.append(jnp.concatenate([s.value.reshape(-1) for s in model_spikes]))
        # Environment logic: greedy action from the integrator output.
        next_action = int(np.argmax(out['integrator.output'].value))
        actions.append(next_action)
    next_obs, _, terminated, truncated, info = env.step(next_action)
    episode_over = terminated or truncated
    if episode_over:
        break_obs.append(next_obs)
        next_obs, _ = env.reset()
        # The finished episode's last brain step is (env_step+1)*K - 1.
        breaks.append(brain_steps_per_env_step * (env_step + 1))
        state = _flush_model(graph, state, next_obs, EVAL_FLUSH_STEPS)
    next_obs = process_obs(next_obs)
    reward = compute_real_reward(next_obs, prev_obs, reward, episode_over)

model = nnx.merge(graph, state)

spikes = np.abs(np.array(spikes))
n_brain_steps, n_neurons = spikes.shape
obs_arr = np.array(obs)
n_components = obs_arr.shape[1]
half_band = SPIKER_RESOLUTION // 2
obs_x = brain_steps_per_env_step * np.arange(len(obs_arr))

fig, ax = plt.subplots(2, 1, figsize=(20, 10), height_ratios=(8, 2))
# Spike raster: rows are neurons, columns are brain steps (black = spike).
ax[0].imshow(1 - spikes.T, cmap='gray', aspect='auto', interpolation='none')
# Episode boundaries: vertical lines spanning the neuron axis.
for b in breaks:
    ax[0].plot([b, b], [-0.5, n_neurons - 0.5], 'r--', alpha=0.1)
# Separators between spiker bands, one band per observation component.
for k in range(1, n_components):
    ax[0].plot([-0.5, n_brain_steps - 0.5], [SPIKER_RESOLUTION * k] * 2, 'g--', alpha=0.1)
# Overlay each normalized observation component, centred in its band.
for k in range(n_components):
    ax[0].plot(obs_x, half_band * obs_arr[:, k] + half_band + SPIKER_RESOLUTION * k, alpha=0.4)
ax[0].set(xlabel='brain step', ylabel='neuron index', title='Spike raster with observation overlay')
ax[1].plot(actions)
ax[1].set_xlim(0, len(actions))
ax[1].set(xlabel='brain step', ylabel='action')
plt.tight_layout()
plt.show()

# NOTE(review): `starting_kernel` (commented out above) and `final_kernel` are not
# defined anywhere in this notebook, so enabling this block raises NameError.
if PLOT_KERNELS:
    plt.imshow(starting_kernel, aspect='auto', interpolation='none')
    plt.colorbar()
    plt.show()
    plt.imshow(final_kernel, aspect='auto', interpolation='none')
    plt.colorbar()
    plt.show()

In [4]:
# Enable IPython's Qt event-loop integration so the editor window stays
# responsive without blocking the kernel (must run before the editor launches).
%gui qt
# Repeat the path setup and import from the first cell so this cell can run on its own.
import sys
sys.path.insert(1, './..')

import spark
editor = spark.SparkGraphEditor()
# Start the editor on the main thread. In a notebook kernel `__name__` is
# '__main__', so this guard always passes here.
if __name__ == '__main__':
	editor.launch()

In [4]:
# Pull the graph built in the editor. Note: this rebinds the notebook-level
# input_map / output_map / modules_map names used by the config cells.
input_map, output_map, modules_map = editor.compile_model()
model_def = dict(
    input_map=input_map,
    output_map=output_map,
    modules_map=modules_map,
)

In [None]:
from spark.core.specs import ModuleSpecs, PortSpecs, PortMap
from spark.core.registry import REGISTRY
from spark.core.module import SparkModule
import typing as tp
import copy

def from_json(model_json: dict[str, dict]) -> dict[str, tp.Any]:
    """Rebuild ``input_map`` / ``output_map`` / ``modules_map`` from a JSON-like description.

    Args:
        model_json: dict with keys ``'input_map'``, ``'output_map'`` and ``'modules_map'``:
            - ``input_map[name]``: keyword arguments for ``PortSpecs``.
            - ``output_map[name]``: ``PortSpecs`` keyword arguments plus ``'port_maps'``
              (list of ``{'origin', 'port'}`` dicts; only the first is used) and an
              optional ``'is_optional'`` entry, which is discarded.
            - ``modules_map[name]``: ``'module_cls'`` (registry key), ``'inputs'``
              (port name -> list of ``{'origin', 'port'}`` dicts) and ``'config'``
              (keyword arguments for the module's default config class).
            The argument itself is not modified.

    Returns:
        dict with ``'input_map'`` (name -> PortSpecs), ``'output_map'``
        (origin -> port -> PortSpecs) and ``'modules_map'`` (name -> ModuleSpecs).

    Raises:
        ValueError: if an output entry has an empty ``'port_maps'`` list.
        KeyError: if a module class is not found in ``REGISTRY.MODULES``.
    """
    # Deepcopy so the pops below never mutate the caller's JSON.
    model_json = copy.deepcopy(model_json)

    # Reconstruct input_map.
    input_map: dict[str, PortSpecs] = {
        name: PortSpecs(**spec) for name, spec in model_json['input_map'].items()
    }

    # Reconstruct output_map, keyed by the origin module and port that produce each output.
    output_map: dict[str, dict[str, PortSpecs]] = {}
    for name, spec in model_json['output_map'].items():
        port_maps = spec.pop('port_maps')
        if not port_maps:
            raise ValueError(f'Output {name!r} has no port maps.')
        first = port_maps[0]
        # Dropped before building PortSpecs; tolerate it being absent.
        spec.pop('is_optional', None)
        output_map.setdefault(first['origin'], {})[first['port']] = PortSpecs(**spec)

    # Reconstruct modules_map.
    modules_map: dict[str, spark.ModuleSpecs] = {}
    for name, spec in model_json['modules_map'].items():
        entry = REGISTRY.MODULES.get(spec['module_cls'])
        if entry is None:
            raise KeyError(f"Module class {spec['module_cls']!r} is not registered.")
        class_ref: type[SparkModule] = entry.class_ref
        inputs: dict[str, list[spark.PortMap]] = {
            port_name: [spark.PortMap(origin=pm['origin'], port=pm['port']) for pm in connections]
            for port_name, connections in spec['inputs'].items()
        }
        modules_map[name] = spark.ModuleSpecs(
            name=name,
            module_cls=class_ref,
            inputs=inputs,
            config=class_ref.get_default_config_class()(**spec['config']),
        )

    return {
        'input_map': input_map,
        'output_map': output_map,
        'modules_map': modules_map,
    }

# Round-trip the editor's compiled model description through the JSON loader.
from_json(model_json=model_def)

In [None]:
# Adjacent f-string literals are concatenated at compile time.
x = (
    f'abc {1}'
    f'fgh {2}'
)
x