In [22]:
import sys
sys.path.insert(1, './..')

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from matplotlib.patches import Rectangle
from math import prod
from tqdm import tqdm

import spark
from spark.nn.interfaces import (PoissonSpiker, LinearSpiker, TopologicalPoissonSpiker, TopologicalLinearSpiker,
                                 ExponentialIntegrator, Merger, Sampler)
from spark.nn.components import SimpleSynapses, TracedSynapses, NDelays, N2NDelays, ALIFSoma
from spark.nn.neurons import ALIFNeuron
from spark import SpikeArray, CurrentArray, FloatArray, Variable, Constant, IntegerArray
from spark.core.specs import ModuleSpecs, PortSpecs, PortMap

In [None]:
# Note that the input map can be seen as the output of a virtual node (the environment).
input_map = {
    'drive': PortSpecs(payload_type=FloatArray, shape=(4,), dtype=jnp.float16),
}
# Requested outputs: module name -> port name -> port specification.
output_map = {
    'integrator': {
        'output': PortSpecs(payload_type=FloatArray, shape=(2,), dtype=jnp.float16),
    },
}


def _spike_ports(*origins):
    """Return one PortMap per origin module, each reading that module's 'spikes' port."""
    return [PortMap(origin=origin, port='spikes') for origin in origins]


def _adaptive_population_specs(name, sources, seed):
    """Specs shared by the 256-unit ALIFNeuron populations ('A_ex' and 'B_ex').

    name: module name.
    sources: names of the modules whose 'spikes' ports feed 'input_spikes', in order.
    seed: integer used to build the PRNG key for the per-unit threshold parameters.
    """
    # NOTE(review): 'threshold_tau' and 'threshold_delta' are both scaled from the
    # same squared uniform draw (one key), so they are perfectly correlated per unit.
    # Confirm this is intended; independent draws would need a split key.
    squared_draw = jax.random.uniform(jax.random.key(seed), shape=(256,), dtype=jnp.float16)**2
    return ModuleSpecs(
        name=name,
        module_cls=ALIFNeuron,
        inputs={'input_spikes': _spike_ports(*sources)},
        init_args={
            'input_shape': (4*128+256+64,),
            'output_shape': (256,),
            'synapses_params': {'kernel_scale': 3.0},
            'soma_params': {
                'threshold_tau': 25.0 * squared_draw,
                'threshold_delta': 250.0 * squared_draw,
                'cooldown': 2.0,
            },
            'inhibitory_rate': 0.0,
            'max_delay': 1,
        },
    )


def _fixed_population_specs(name, sources):
    """Specs shared by the 64-unit ALIFNeuron populations ('A_in' and 'B_in').

    name: module name.
    sources: names of the modules whose 'spikes' ports feed 'input_spikes', in order.
    """
    return ModuleSpecs(
        name=name,
        module_cls=ALIFNeuron,
        inputs={'input_spikes': _spike_ports(*sources)},
        init_args={
            'input_shape': (4*128+256,),
            'output_shape': (64,),
            'synapses_params': {'kernel_scale': 4.0},
            'soma_params': {
                'threshold_tau': 1.0,
                'threshold_delta': 0.0,
                'cooldown': 2.0,
            },
            'inhibitory_rate': 1.0,
            'max_delay': 1,
        },
    )


# Wiring: the spiker receives the external 'drive'; each *_ex population reads the
# spiker, itself and the *_in population of the other group; each *_in population
# reads the spiker and the *_ex population of its own group; the integrator reads
# both *_ex populations.
modules_map = {
    'spiker': ModuleSpecs(
        name='spiker',
        module_cls=TopologicalLinearSpiker,
        # '__call__' as origin refers to the argument passed when the brain is called.
        inputs={'drive': [PortMap(origin='__call__', port='drive')]},
        init_args={
            'input_shape': (4,),
            'glue': jnp.array(0),
            'mins': jnp.array(-1),
            'maxs': jnp.array(1),
            'resolution': 128,
            'max_freq': 200.0,
            'tau': 30.0,
        },
    ),
    'A_ex': _adaptive_population_specs('A_ex', ('spiker', 'A_ex', 'B_in'), seed=43),
    'A_in': _fixed_population_specs('A_in', ('spiker', 'A_ex')),
    'B_ex': _adaptive_population_specs('B_ex', ('spiker', 'B_ex', 'A_in'), seed=42),
    'B_in': _fixed_population_specs('B_in', ('spiker', 'B_ex')),
    'integrator': ModuleSpecs(
        name='integrator',
        module_cls=ExponentialIntegrator,
        inputs={'input_spikes': _spike_ports('A_ex', 'B_ex')},
        init_args={
            'input_shape': (256+256,),
            'output_dim': 2,
        },
    ),
}

brain = spark.Brain(input_map=input_map, output_map=output_map, modules_map=modules_map)

In [None]:
@jax.jit
def run_model(graph, state, x):
    """Run one call of the model and hand back its outputs and updated state.

    The module is rebuilt from ``graph`` and ``state`` with ``nnx.merge``, called
    with ``x`` as its ``drive`` input, and split again so the caller can carry the
    new state into the next call.

    Returns:
        (out, spikes, state): the two values returned by the model call, followed
        by the state part of ``nnx.split`` applied to the module after the call.
    """
    module = nnx.merge(graph, state)
    outputs, emitted_spikes = module(drive=x)
    # nnx.split returns (graph definition, state); only the state is needed here.
    updated_state = nnx.split(module)[1]
    return outputs, emitted_spikes, updated_state

def process_obs(x):
    """Scale an observation channel-wise and clamp every entry to [-1, 1].

    Args:
        x: four values ordered as CartPos, CartSpeed, PoleAngle, PoleAngSpeed.

    Returns:
        Array of the same length: each entry divided by its channel divisor and
        then clipped to the closed interval [-1, 1].
    """
    # One divisor per channel, in the order CartPos, CartSpeed, PoleAngle, PoleAngSpeed.
    channel_divisors = np.array([2.4, 2.5, 0.2095, 3.5])
    return np.clip(x / channel_divisors, -1, 1)

def compute_real_reward(x, x_prev, r_prev, terminated):
    """Shaped reward that favours moving the cart position and pole angle toward zero.

    Args:
        x: current observation (CartPos, CartSpeed, PoleAngle, PoleAngSpeed).
        x_prev: previous observation, same layout.
        r_prev: reward returned for the previous step.
        terminated: when truthy the reward is 0 regardless of the observations.

    Returns:
        0 if ``terminated``; otherwise ``0.5 * r_prev`` plus twice the decrease in
        squared CartPos and squared PoleAngle, clipped to [-1, 1].
    """
    if terminated:
        return 0
    # Positive when the squared value shrank between the previous and current step.
    position_gain = x_prev[0]**2 - x[0]**2
    angle_gain = x_prev[2]**2 - x[2]**2
    shaped = 0.5 * r_prev + 2 * (position_gain + angle_gain)
    return np.clip(shaped, -1, 1)

In [None]:
import time
import gymnasium as gym
import ale_py
import numpy as np
# Register the environments provided by ale_py with gymnasium.
gym.register_envs(ale_py)

env_name =  'CartPole-v1'

# Create the environment and take a first, seeded observation, normalised by process_obs.
env = gym.make(env_name)
next_obs, _ = env.reset(seed=42)
next_obs = process_obs(next_obs)

# Build a fresh model from the specs and call it once with an all-zero drive before
# splitting it (presumably so that lazily created state exists -- TODO confirm).
model = spark.Brain(input_map=input_map, output_map=output_map, modules_map=modules_map)
model(drive=FloatArray(jnp.zeros((4,), dtype=jnp.float16)))
# Split into a graph definition and a state so that run_model can merge them per call.
graph, state = nnx.split((model))
#starting_kernel = model.neurons.synapses.get_flat_kernel()
# Number of run_model calls made for every environment step in the rollout loop.
brain_steps_per_env_step = 10

# Disabled closed-loop run: one model call per environment step, for 5000 steps,
# recording the shaped reward of every step. Change the condition to True to run it.
if False:
    reward = 0
    reward_array = []
    for i in tqdm(range(5000)):
        prev_obs = next_obs
        # Model logic
        out, model_spikes, state = run_model(graph, state, FloatArray(jnp.array(next_obs, dtype=jnp.float16)))
        # Environment logic.
        next_action = int(np.argmax(out['integrator.output'].value))
        next_obs, _, terminated, truncated, info = env.step(next_action)
        # An episode ends on termination OR truncation (time limit); in both cases
        # the environment has to be reset before it is stepped again.
        episode_over = terminated or truncated
        if episode_over:
            next_obs, _ = env.reset()
            # Flush model: feed an all-zero drive for a few calls between episodes.
            silent_drive = FloatArray(jnp.zeros_like(next_obs, dtype=jnp.float16))
            # A dedicated loop variable avoids clobbering the outer step counter `i`.
            for flush_step in range(16):
                _, _, state = run_model(graph, state, silent_drive)
        next_obs = process_obs(next_obs)
        # After a reset the observation is discontinuous, so the reward is zeroed.
        reward = compute_real_reward(next_obs, prev_obs, reward, episode_over)
        reward_array.append(reward)

# Recorded rollout: 100 environment steps, each preceded by
# `brain_steps_per_env_step` model calls on the same observation.
outs = []        # integrator output after every model call
spikes = []      # flattened, concatenated spikes after every model call
obs = []         # processed observation after every environment step
breaks = []      # model-call index (x position in the raster) of each episode end
break_obs = []   # raw observation returned by the step that ended an episode
actions = []     # action selected after every model call
reward = 0
next_obs, _ = env.reset(seed=42+1)
next_obs = process_obs(next_obs)
for i in tqdm(range(100)):
    prev_obs = next_obs
    # Model logic
    # The observation does not change within an environment step, so build the drive once.
    drive = FloatArray(jnp.array(next_obs, dtype=jnp.float16))
    for _ in range(brain_steps_per_env_step):
        out, model_spikes, state = run_model(graph, state, drive)
        outs.append(out['integrator.output'].value)
        spikes.append(jnp.concatenate([s.value.reshape(-1) for s in model_spikes]))
        # Environment logic.
        next_action = int(np.argmax(out['integrator.output'].value))
        actions.append(next_action)
    # Only the action chosen after the last model call is applied to the environment.
    next_obs, _, terminated, truncated, info = env.step(next_action)
    # An episode ends on termination OR truncation (time limit); in both cases
    # the environment has to be reset before it is stepped again.
    episode_over = terminated or truncated
    if episode_over:
        break_obs.append(next_obs)
        next_obs, _ = env.reset()
        breaks.append(brain_steps_per_env_step*i)
        # Flush model: feed an all-zero drive for a few calls between episodes.
        silent_drive = FloatArray(jnp.zeros_like(next_obs, dtype=jnp.float16))
        # A dedicated loop variable avoids clobbering the outer step counter `i`.
        for flush_step in range(50):
            _, _, state = run_model(graph, state, silent_drive)
    next_obs = process_obs(next_obs)
    # After a reset the observation is discontinuous, so the reward is zeroed.
    reward = compute_real_reward(next_obs, prev_obs, reward, episode_over)
    obs.append(next_obs)

# Rebuild a model object that carries the state reached at the end of the rollout.
model = nnx.merge(graph, state)

# Raster of the recorded spikes (top) and the selected actions (bottom).
# After this line `spikes` has one row per model call and one column per unit.
spikes = np.abs(np.array(spikes))
n_calls, n_units = spikes.shape
fig, ax = plt.subplots(2,1,figsize=(20,10), height_ratios=(8,2))
# Transposed so that x is the model-call index and y is the unit index; 1-spikes
# makes spikes dark on a light background.
ax[0].imshow(1-spikes.T, cmap='gray', aspect='auto', interpolation='none')
# Episode ends: vertical lines must span the unit axis (n_units), not the time axis.
for b in breaks:
    ax[0].plot([b,b], [0-0.5,n_units-0.5], 'r--', alpha=0.1)
# Horizontal guides every 128 rows, spanning the whole time axis.
for i in range(3):
    ax[0].plot([0-0.5,n_calls-0.5], [128*(i+1), 128*(i+1)], 'g--', alpha=0.1)
# Overlay each observation channel on its own 128-row band: values in [-1, 1]
# are mapped to rows [0, 128) of the band via 64*value + 64.
env_step_positions = brain_steps_per_env_step*np.arange(n_calls//brain_steps_per_env_step)
for channel, trace in enumerate(np.array(obs).T):
    ax[0].plot(env_step_positions, 64*trace+64+128*channel, alpha=0.4)
ax[0].set_ylabel('unit index')
ax[1].plot(actions)
ax[1].set_xlim(0, len(actions))
ax[1].set_xlabel('model call')
ax[1].set_ylabel('action')
plt.tight_layout()
plt.show()
# Disabled: compares kernels before and after the run.
# NOTE(review): `starting_kernel` and `final_kernel` are not defined in the visible
# notebook (the assignment of `starting_kernel` is commented out), so enabling this
# as-is raises NameError.
if False:
    plt.imshow(starting_kernel, aspect='auto', interpolation='none')
    plt.colorbar()
    plt.show()
    plt.imshow(final_kernel, aspect='auto', interpolation='none')
    plt.colorbar()
    plt.show()

In [1]:
%gui qt
import sys
sys.path.insert(1, './..')
import spark

In [2]:
# Create the graph editor object; the Qt GUI integration is enabled by the
# `%gui qt` magic in the previous cell.
editor = spark.SparkGraphEditor()
# In an interactive kernel __name__ is "__main__", so this launches the editor when
# the cell is run; the guard only matters if this code is imported as a module.
if __name__ == "__main__":
    editor.launch()

<class 'spark.nn.components.somas.ALIFSomaConfig'> False
<class 'spark.nn.components.synapses.SimpleSynapsesConfig'> False
<class 'spark.nn.initializers.kernel.KernelInitializerConfig'> True
hi-ho
<class 'spark.nn.components.delays.N2NDelaysConfig'> False
<class 'spark.nn.initializers.delay.DelayInitializerConfig'> True
hi-ho
<class 'spark.nn.components.learning_rules.HebbianLearningConfig'> False
