Replication of Florian 2007 XOR gate experiments.
* Rate based input coding
* Temporal pattern coding

https://www.florian.io/papers/2007_Florian_Modulated_STDP.pdf

Florian R (2007) Reinforcement Learning Through Modulation of Spike-Timing-Dependent Synaptic Plasticity. Neural Computation 19(6). https://doi.org/10.1162/neco.2007.19.6.1468

In [1]:
import numpy as np

from spikey.snn import *
from spikey.core import GenericLoop, RLCallback
from spikey.RL import Logic
from spikey.viz import print_rates

np.random.seed(0)

In [2]:
def print_w_diffs(callback, training_params, layer_cutoff=None):
    network = callback.network
    info = callback.info

    layer_cutoff = layer_cutoff or network._n_inputs

    original_w = info['weights_original']
    final_w = network.synapses.weights.matrix

    print(f"{np.sum(original_w[:, :layer_cutoff]):.0f} -> {np.sum(final_w[:, :layer_cutoff]):.0f}")
    print(f"{np.sum(original_w[:, layer_cutoff:]):.0f} -> {np.sum(final_w[:, layer_cutoff:]):.0f}")

In [3]:
def print_success(callback, training_params):
    info = callback.info

    states = np.array(info['step_states']).reshape((-1, 2))
    inrates = np.array(info['step_inrates']).reshape((-1))
    sysrates = np.array(info['step_sysrates']).reshape((-1))
    outrates = np.array(info['step_outrates']).reshape((-1))

    HIGH = [[False, True], [True, False]]
    LOW =  [[False, False], [True, True]]

    relevant_timeframe = training_params['eval_steps'] // 4

    high_rate = min([np.mean(outrates[np.all(states == state, axis=1)][-relevant_timeframe:]) for state in HIGH])
    low_rate = max([np.mean(outrates[np.all(states == state, axis=1)][-relevant_timeframe:]) for state in LOW])

    florian_win = high_rate > low_rate + .05

    correct = 0
    for i in range(training_params['eval_steps']):
        state = states[-i]
        rate = outrates[-i]

        if np.sum(state) % 2:
            correct += int(rate > low_rate)
        else:
            correct += int(rate < high_rate)

    florian_accuracy = correct / training_params['eval_steps']

    print(f"Florian - Win: {florian_win}, Accuracy: {florian_accuracy}")

In [4]:
def print_runtime(callback):
    print(f"{callback.results['total_time']:.2f}s")

## Rate Coding

In [5]:
training_params = {
    'n_episodes': 1,
    'len_episode': 800,
    'eval_steps': 50, 
}

In [6]:
N_INPUTS = 60
N_NEURONS = 61
N_OUTPUTS = 1
PROCESSING_TIME = 500

w_matrix = np.vstack((  # Feedforward w/ 1 hidden layer
    np.hstack((
        np.random.uniform(0, .2, (N_INPUTS, N_NEURONS - N_OUTPUTS)),
        np.zeros((N_INPUTS, N_OUTPUTS)))),
    np.hstack((
        np.zeros((N_NEURONS - N_OUTPUTS, N_NEURONS - N_OUTPUTS)),
        np.random.uniform(0, .2, (N_NEURONS - N_OUTPUTS, N_OUTPUTS)))),
    np.zeros((N_OUTPUTS, N_NEURONS)),
))

class network_template(ContinuousRLNetwork):
    parts = {
        'inputs': input.RateMap,
        'neurons': neuron.Neuron,
        'synapses': synapse.RLSTDPET,
        'weights': weight.Manual,
        'readout': readout.Threshold,
        'rewarder': reward.MatchExpected,
    }
    keys = {
        "n_inputs": N_INPUTS,
        'n_neurons': N_NEURONS,
        "n_outputs": N_OUTPUTS,
        'matrix': w_matrix,

        'input_pct_inhibitory': .5,
        'neuron_pct_inhibitory': 0,
        'magnitude': 1,
        'firing_threshold': 16,
        'refractory_period': 0,  # Gutig, Aharonov, Rotter, & Sompolinsky 2003
        'prob_rand_fire': .15,  # Seemingly 0 in paper but >0 was needed
        'potential_decay': .05,  # Decay constant Tau=20ms, lambda=e^(-t/T)
        'trace_decay': .04,  # T_z = 25, lambda = e^(-1/T_z)
        "punish_mult": 1,

        'processing_time': PROCESSING_TIME,
        'learning_rate': .625 / 25,  # gamma_0 = gamma / Tau_z
        'max_weight': 5,
        'stdp_window': 20,  # Tau_+ = Tau_- = 20ms
        'action_threshold': 0,  # Makes network always output True

        'continuous_rwd_action': lambda *a: True,
        'state_rate_map': [0, 40 / PROCESSING_TIME],   # 0hz in False groups, 40hz in True groups
    }

In [7]:
# Control, without learning
training_loop = GenericLoop(network_template, Logic(preset="XOR"), measure_rates=True, **training_params)
training_loop.reset(**{'learning_rate': 0, 'len_episode': 50})
e_output = training_loop()

callback = training_loop.callback
print_rates(callback=callback)
print_w_diffs(callback, training_params, layer_cutoff=None)
print_success(callback, training_params)
print_runtime(callback)

(False, False): 0.00 -> 0.15
(False, True): 0.04 -> 0.15
(True, False): 0.04 -> 0.16
(True, True): 0.08 -> 0.14
362 -> 362
6 -> 6
Florian - Win: False, Accuracy: 0.56
17.53s


In [8]:
# Real experiment
training_loop = GenericLoop(network_template, Logic(preset="XOR"), measure_rates=True, **training_params)
e_output = training_loop()

callback = training_loop.callback
print_rates(callback=callback)
print_w_diffs(callback, training_params, layer_cutoff=None)
print_success(callback, training_params)
print_runtime(callback)

(False, False): 0.00 -> 0.30
(False, True): 0.04 -> 0.96
(True, False): 0.04 -> 0.97
(True, True): 0.08 -> 0.30
362 -> 7510
6 -> 296
Florian - Win: True, Accuracy: 1.0
306.01s


## Temporal Coding

In [9]:
training_params = {
    'n_episodes': 1,
    'len_episode': 800,
    'eval_steps': 50,
}

In [10]:
N_INPUTS = 2
N_NEURONS = 21
N_OUTPUTS = 1
PROCESSING_TIME = 500

spike_train_map = {  # Static 100hz spike trains in response to stimulus
    False: np.int_(np.random.uniform(0, 1, (PROCESSING_TIME, N_INPUTS // 2)) <= 50 * .0001),
    True: np.int_(np.random.uniform(0, 1, (PROCESSING_TIME, N_INPUTS // 2)) <= 50 * .0001),
}
input_map = {
    (A, B): np.hstack((spike_train_map[A], spike_train_map[B]))
    for A in [False, True] for B in [False, True]
}

N_HIDDEN = N_NEURONS - N_OUTPUTS
w_matrix = np.vstack((  # Feedforward w/ 1 hidden layer
    np.hstack((np.random.uniform(0, .4, (N_INPUTS, N_HIDDEN)), np.zeros((N_INPUTS, N_OUTPUTS)))),
    np.hstack((np.zeros((N_HIDDEN, N_HIDDEN)), np.random.uniform(0, .4, (N_HIDDEN, N_OUTPUTS)))),
    np.zeros((N_OUTPUTS, N_NEURONS)),
))

class network_template(ContinuousRLNetwork):
    parts = {
        'inputs': input.StaticMap,
        'neurons': neuron.Neuron,
        'synapses': synapse.RLSTDPET,
        'weights': weight.Manual,
        'readout': readout.Threshold,
        'rewarder': reward.MatchExpected,
    }
    keys = {
        "n_inputs": N_INPUTS,
        'n_neurons': N_NEURONS,
        "n_outputs": N_OUTPUTS,
        'matrix': w_matrix,

        'input_pct_inhibitory': .5,
        'neuron_pct_inhibitory': 0,
        'magnitude': 1,
        'firing_threshold': 16,
        'refractory_period': 0,  # Gutig, Aharonov, Rotter, & Sompolinsky 2003
        'prob_rand_fire': .15,
        'potential_decay': .05,  # Decay constant Tau=20ms, lambda=e^(-t/T)
        'trace_decay': .04,  # T_z = 25, lambda = e^(-1/T_z)
        "punish_mult": 1,

        'processing_time': PROCESSING_TIME,
        'learning_rate': .25 / 25,  # gamma_0 = gamma / Tau_z
        'max_weight': 5,
        'stdp_window': 20,  # Tau_+ = Tau_- = 20ms
        'action_threshold': 0,  # Makes network always output True

        'continuous_rwd_action': lambda *a: True,
        'state_spike_map': input_map,
    }

In [11]:
# Control, without learning
training_loop = GenericLoop(network_template, Logic(preset="XOR"), measure_rates=True, **training_params)
training_loop.reset(**{'learning_rate': 0, 'len_episode': 50})
e_output = training_loop()

callback = training_loop.callback
print_rates(callback=callback)
print_w_diffs(callback, training_params, layer_cutoff=None)
print_success(callback, training_params)
print_runtime(callback)

(False, False): 0.00 -> 0.15
(False, True): 0.00 -> 0.15
(True, False): 0.00 -> 0.15
(True, True): 0.00 -> 0.14
1 -> 1
11 -> 11
Florian - Win: False, Accuracy: 0.5
10.35s


In [12]:
# Real experiment
training_loop = GenericLoop(network_template, Logic(preset="XOR"), measure_rates=True, **training_params)
e_output = training_loop()

callback = training_loop.callback
print_rates(callback=callback)
print_w_diffs(callback, training_params, layer_cutoff=None)
print_success(callback, training_params)
print_runtime(callback)

(False, False): 0.00 -> 0.21
(False, True): 0.00 -> 0.56
(True, False): 0.00 -> 0.57
(True, True): 0.00 -> 0.21
1 -> 2
11 -> 191
Florian - Win: True, Accuracy: 1.0
166.01s
