Replication of Florian 2007 XOR gate experiments.
* Rate based input coding
* Temporal pattern coding

https://www.florian.io/papers/2007_Florian_Modulated_STDP.pdf

Florian R (2007) Reinforcement Learning Through Modulation of Spike-Timing-Dependent Synaptic Plasticity. Neural Computation 19(6). https://doi.org/10.1162/neco.2007.19.6.1468

In [1]:
import numpy as np

import spikey  # binds the package name used by FlorianReward's base class below
from spikey.snn import *
from spikey.games import Logic
from spikey.viz import print_rates

from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, RunConfig

np.random.seed(0)



In [2]:
def print_w_diffs(original_w, final_w, layer_cutoff):
    """Print summed weights before -> after training, split at column layer_cutoff."""
    print(f"{np.sum(original_w[:, :layer_cutoff]):.0f} -> {np.sum(final_w[:, :layer_cutoff]):.0f}")
    print(f"{np.sum(original_w[:, layer_cutoff:]):.0f} -> {np.sum(final_w[:, layer_cutoff:]):.0f}")

In [3]:
def print_success(step_states, step_inrates, step_sysrates, step_outrates, training_params):
    states = np.array(step_states).reshape((-1, 2))
    inrates = np.array(step_inrates).reshape((-1))
    sysrates = np.array(step_sysrates).reshape((-1))
    outrates = np.array(step_outrates).reshape((-1))

    # XOR-true inputs (target output 1) and XOR-false inputs (target output 0)
    HIGH = [[False, True], [True, False]]
    LOW =  [[False, False], [True, True]]

    # Average the output rate over the last few presentations of each state
    relevant_timeframe = training_params['eval_steps'] // 4

    high_rate = min([np.mean(outrates[np.all(states == state, axis=1)][-relevant_timeframe:]) for state in HIGH])
    low_rate = max([np.mean(outrates[np.all(states == state, axis=1)][-relevant_timeframe:]) for state in LOW])

    print(high_rate, low_rate)
    # Win if every XOR-true state fires faster than every XOR-false state by more than .02
    florian_win = high_rate > low_rate + .02

    # Accuracy over the last eval_steps presentations (indices -1 ... -eval_steps)
    correct = 0
    for i in range(1, training_params['eval_steps'] + 1):
        state = states[-i]
        rate = outrates[-i]

        if np.sum(state) % 2:
            correct += int(rate > low_rate)
        else:
            correct += int(rate < high_rate)

    florian_accuracy = correct / training_params['eval_steps']

    print(f"Florian - Win: {florian_win}, Accuracy: {florian_accuracy}")

In [4]:
class FlorianReward(spikey.snn.reward.template.Reward):
    def __call__(self, state, action, state_next):
        if sum(state) % 2 == 1:  # (0, 1) and (1, 0)
            return self._reward_mult if action == True else 0
        else:  # (0, 0) and (1, 1)
            return -self._punish_mult if action == True else 0

## Rate Coding

The goal of this experiment is to train a spiking neural network to mimic an XOR gate, meaning it takes two binary inputs and returns one binary output. The desired input-output mapping is as follows:
```
0, 0 -> 0
0, 1 -> 1
1, 0 -> 1
1, 1 -> 0
```
In this experiment the inputs are rate coded. There are two input groups, one for each boolean input. If an input is 0, its group does not fire at all; otherwise it fires at a rate of 40 Hz. There are 60 input neurons (30 per group), 60 hidden neurons and 1 output neuron, with each layer fully connected to the next. Each input pattern is presented to the network for 500 ms, with 800 patterns shown in total. While a pattern is shown, if the correct output is 1, the network receives a reward of 1 whenever its output neuron fires. Otherwise, it receives a reward of -1 whenever its output neuron fires.

Converting this description for use in the framework is straightforward, though it helps to have a frame of reference the first time through.

1. Divide experiment into network and game mechanics.
Splitting the network and game mechanics is simple here: the game gives two random boolean inputs at every game step and the network responds to them.
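
To make the split concrete, here is a minimal sketch of the game side only. `ToyXORGame` is a hypothetical stand-in, not spikey's `Logic` class; it only mimics the `state_next, reward, done, info` return shape that the training loops in this notebook unpack from `game.step()`. The network is replaced by a placeholder action.

```python
import random


class ToyXORGame:
    """Hypothetical stand-in for an XOR game: emits random boolean input pairs."""

    def __init__(self, len_episode=100, seed=None):
        self.len_episode = len_episode
        self.rng = random.Random(seed)
        self.t = 0

    def _sample(self):
        return (self.rng.random() < 0.5, self.rng.random() < 0.5)

    def reset(self):
        self.t = 0
        self.state = self._sample()
        return self.state

    def correct_output(self, state):
        # XOR: True exactly when one of the two inputs is True
        return state[0] != state[1]

    def step(self, action):
        # The game itself returns no reward; in the notebook, rewards come
        # from the network's rewarder, so the second value is always 0 here.
        self.t += 1
        self.state = self._sample()
        done = self.t >= self.len_episode
        return self.state, 0, done, {}


game = ToyXORGame(len_episode=10, seed=0)
state = game.reset()
history = []
done = False
while not done:
    action = True  # placeholder for the network's response to `state`
    history.append((state, game.correct_output(state)))
    state, _, done, _ = game.step(action)
```

The game never needs to know how the network computes its answer, which is what lets the same game drive both the rate-coded and temporally coded networks.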

2. Set up network inputs.
We use the RateMap input type for this experiment. The main parameter of this input type is `state_rate_map`, a dictionary we construct as follows:
```
LOW_RATE = 0
HIGH_RATE = 40 / 500  # frequency / steps per pattern
state_rate_map = {
    (0, 0): [LOW_RATE, LOW_RATE],
    (0, 1): [LOW_RATE, HIGH_RATE],
    (1, 0): [HIGH_RATE, LOW_RATE],
    (1, 1): [HIGH_RATE, HIGH_RATE],
}
```
RateMap splits the input neurons into two groups and drives each group at the rate given by `state_rate_map[current_state]`.
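
As a rough illustration of what a rate map does, the sketch below turns a state into a spike raster by splitting the inputs into two equal groups and letting each neuron fire independently on each step with probability equal to its group's rate (a Bernoulli approximation of Poisson firing). `rate_map_spikes` is a hypothetical helper, not spikey's `RateMap` implementation, which may differ in details such as how inhibitory inputs are signed.

```python
import numpy as np


def rate_map_spikes(state, state_rate_map, n_inputs, n_steps, rng):
    """Hypothetical helper: (n_steps, n_inputs) 0/1 spike raster for one state."""
    # (False, True) also finds the key (0, 1), since False == 0 in Python
    rates = np.asarray(state_rate_map[tuple(state)], dtype=float)
    n_groups = len(rates)
    # Per-neuron firing probability: each group gets an equal slice of neurons
    per_neuron = np.repeat(rates, n_inputs // n_groups)
    return (rng.random((n_steps, n_inputs)) < per_neuron).astype(int)


LOW_RATE = 0
HIGH_RATE = 40 / 500
state_rate_map = {
    (0, 0): [LOW_RATE, LOW_RATE],
    (0, 1): [LOW_RATE, HIGH_RATE],
    (1, 0): [HIGH_RATE, LOW_RATE],
    (1, 1): [HIGH_RATE, HIGH_RATE],
}

rng = np.random.default_rng(0)
spikes = rate_map_spikes((0, 1), state_rate_map, n_inputs=60, n_steps=500, rng=rng)
print(spikes[:, :30].mean(), spikes[:, 30:].mean())  # 0.0 and roughly 0.08
```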

3. Set topology of network.
For this we use a manually configured network. We give it one matrix for each layer: the input-hidden and hidden-output layers.
Each initial weight value is sampled uniformly between 0 and .2, with these parameters chosen by trial and error.
```
w_matrix = [
    np.random.uniform(0, .2, (N_INPUTS, N_HIDDEN)),
    np.random.uniform(0, .2, (N_HIDDEN, N_OUTPUTS)),
]
```
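
A quick sanity check on this initialization: a U(0, 0.2) weight has mean 0.1, so the 60 × 60 input-hidden matrix should sum to about 60 · 60 · 0.1 = 360 and the 60 × 1 hidden-output matrix to about 6. These totals are a useful baseline when comparing summed weights before and after training. The snippet re-draws matrices with the shapes used in this section (N_INPUTS = N_HIDDEN = 60, N_OUTPUTS = 1) from a seeded generator, purely for illustration.

```python
import numpy as np

N_INPUTS, N_HIDDEN, N_OUTPUTS = 60, 60, 1
rng = np.random.default_rng(0)

w_matrix = [
    rng.uniform(0, .2, (N_INPUTS, N_HIDDEN)),   # input -> hidden
    rng.uniform(0, .2, (N_HIDDEN, N_OUTPUTS)),  # hidden -> output
]

# Expected total = number of weights * mean of U(0, 0.2)
expected = [w.size * 0.1 for w in w_matrix]
actual = [w.sum() for w in w_matrix]
print(expected, [round(float(a), 1) for a in actual])
```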

4. Set reward scheme and network readout.
This is typically the most complex part of constructing an experiment with spikey, and this example is no different.
Here we use the ActiveRLNetwork base class so that the reward function is called at every network update; RLNetwork, by contrast, calls it once per game step, i.e. every PROCESSING_TIME network steps.
Alongside it we use the `continuous_rwd_action` parameter in place of the readout, since the readout is only applied once per game step. `continuous_rwd_action` tells the rewarder whether the output neuron fired, via the `action` parameter.
Finally, we use the custom `FlorianReward` rewarder (defined in the cell above), which implements the original experiment's scheme: if the expected output is 1, a reward is given on every output spike; otherwise a punishment is given on every output spike.
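
Per-tick rewarding can be sketched in plain Python. `florian_reward` below restates the logic of the `FlorianReward` class as a free function, and the output spike train is random stand-in data rather than a simulated network; both are assumptions for illustration only. Because the rewarder sees every tick, the total reward over one 500-step pattern is simply plus or minus the number of output spikes.

```python
import random


def florian_reward(state, fired, reward_mult=1, punish_mult=1):
    """Restatement of FlorianReward: reward spikes on XOR-true states, punish otherwise."""
    if not fired:
        return 0
    return reward_mult if sum(state) % 2 == 1 else -punish_mult


rng = random.Random(0)
PROCESSING_TIME = 500
# Stand-in output spike train for one pattern presentation
output_spikes = [rng.random() < 0.1 for _ in range(PROCESSING_TIME)]

for state in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    # The rewarder is consulted on every tick, as with ActiveRLNetwork
    per_tick_total = sum(florian_reward(state, fired) for fired in output_spikes)
    print(state, per_tick_total)  # +n_spikes for (0, 1)/(1, 0), -n_spikes otherwise
```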

5. Set other parameters.
Most of the other parameters are taken directly from the paper or are intuitively chosen.
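
Several constants in the parameter cell below are per-step values derived from the paper's time constants, assuming one network step corresponds to 1 ms (as the comments next to those parameters imply). If `potential_decay` and `trace_decay` are the fractions lost per step (an assumption about spikey's convention), then 1/τ is a first-order approximation of the exact one-step decay 1 − e^(−1/τ). The snippet reproduces that arithmetic along with the learning rate γ₀ = γ / τ_z.

```python
import math

TAU_M = 20   # membrane time constant, ms
TAU_Z = 25   # eligibility trace time constant, ms
DT = 1       # assumed duration of one network step, ms

exact_potential_decay = 1 - math.exp(-DT / TAU_M)  # fraction lost per step
approx_potential_decay = DT / TAU_M                # first-order approximation
exact_trace_decay = 1 - math.exp(-DT / TAU_Z)
approx_trace_decay = DT / TAU_Z

gamma = 0.625
learning_rate = gamma / TAU_Z  # gamma_0 = gamma / tau_z

print(f"{exact_potential_decay:.4f} vs {approx_potential_decay}")  # 0.0488 vs 0.05
print(f"{exact_trace_decay:.4f} vs {approx_trace_decay}")          # 0.0392 vs 0.04
print(learning_rate)                                               # 0.025
```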


In [5]:
training_params = {
    'n_episodes': 1,
    'len_episode': 800,
    'eval_steps': 50, 
}

In [6]:
N_INPUTS = 60
N_NEURONS = 61
N_OUTPUTS = 1
N_HIDDEN = N_NEURONS - N_OUTPUTS
PROCESSING_TIME = 500

w_matrix = [
    np.random.uniform(0, .2, (N_INPUTS, N_HIDDEN)),
    np.random.uniform(0, .2, (N_HIDDEN, N_OUTPUTS)),
]

LOW_RATE = 0
HIGH_RATE = 40 / PROCESSING_TIME
state_rate_map = {# 2 input groups. 0hz when group false, 40hz when true
    (0, 0): np.array([LOW_RATE, LOW_RATE]),
    (0, 1): np.array([LOW_RATE, HIGH_RATE]),
    (1, 0): np.array([HIGH_RATE, LOW_RATE]),
    (1, 1): np.array([HIGH_RATE, HIGH_RATE]),
}

class network_template(ActiveRLNetwork):
    parts = {
        'inputs': input.RateMap,
        'neurons': neuron.Neuron,
        'synapses': synapse.RLSTDP,
        'weights': weight.Manual,
        'readout': readout.Threshold,
        'rewarder': FlorianReward,
    }
    keys = {
        "n_inputs": N_INPUTS,
        'n_neurons': N_NEURONS,
        "n_outputs": N_OUTPUTS,
        'matrix': w_matrix,

        'input_pct_inhibitory': .5,
        'neuron_pct_inhibitory': 0,
        'magnitude': 1,
        'firing_threshold': 16,
        'refractory_period': 0,  # Gutig, Aharonov, Rotter, & Sompolinsky 2003
        'prob_rand_fire': .15,
        'potential_decay': .05,  # Decay constant Tau=20ms, lambda=e^(-t/T)
        'trace_decay': .04,  # T_z = 25, lambda = e^(-1/T_z)
        "punish_mult": 1,

        'processing_time': PROCESSING_TIME,
        'learning_rate': .625 / 25,  # gamma_0 = gamma / Tau_z
        'max_weight': 5,
        'stdp_window': 20,  # Tau_+ = Tau_- = 20ms
        'action_threshold': 0,  # Makes network always output True

        'continuous_rwd_action': lambda network, game: network.spike_log[-1, -1],
        'state_rate_map': state_rate_map,
    }

In [7]:
# Control, without learning
def train_func():
    network_template.keys.update({'learning_rate': 0, 'len_episode': 50})
    game = Logic(preset="XOR")
    model = network_template()

    original_w = model.synapses.weights.matrix.copy()
    inrates = []
    sysrates = []
    outrates = []
    states = []
    actions = []
    for epoch in range(5):
        model.reset()
        state = game.reset()
        state_next = None

        inrates.append([])
        sysrates.append([])
        outrates.append([])
        states.append([])
        actions.append([])
        for s in range(100):
            action = model.tick(state)
            state_next, _, done, __ = game.step(action)
            reward = model.reward(state, action, state_next)

            inrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, :model._n_inputs])))
            sysrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, model._n_inputs:-model._n_outputs])))
            outrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, -model._n_outputs:])))
            states[-1].append(state)
            actions[-1].append(action)

            state = state_next
            if done:
                break
    print_rates(step_inrates=inrates, step_outrates=outrates, step_states=states, observation_space=game.observation_space)
    print_w_diffs(original_w, model.synapses.weights.matrix, model._n_inputs)
    print_success(states, inrates, sysrates, outrates, training_params)

    return {}

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=1),
    run_config=RunConfig(verbose=0),
)
results = trainer.fit()

[2m[36m(RayTrainWorker pid=26469)[0m 2022-12-11 11:59:42,686	INFO config.py:87 -- Setting up process group for: env:// [rank=0, world_size=1]


[2m[36m(RayTrainWorker pid=26469)[0m (False, False): 0.00 -> 0.15
[2m[36m(RayTrainWorker pid=26469)[0m (False, True): 0.04 -> 0.14
[2m[36m(RayTrainWorker pid=26469)[0m (True, False): 0.04 -> 0.15
[2m[36m(RayTrainWorker pid=26469)[0m (True, True): 0.08 -> 0.15
[2m[36m(RayTrainWorker pid=26469)[0m 362 -> 362
[2m[36m(RayTrainWorker pid=26469)[0m 6 -> 6
[2m[36m(RayTrainWorker pid=26469)[0m 0.1471 0.1555
[2m[36m(RayTrainWorker pid=26469)[0m Florian - Win: False, Accuracy: 0.36


2022-12-11 12:00:49,981	ERROR checkpoint_manager.py:327 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key in the result dict. Valid keys are: ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'done']


Trial TorchTrainer_580c3_00000 completed. Last result: 


In [8]:
# Real experiment
def train_func():
    game = Logic(preset="XOR")
    model = network_template()

    original_w = model.synapses.weights.matrix.copy()
    inrates = []
    sysrates = []
    outrates = []
    states = []
    actions = []
    for epoch in range(5):
        model.reset()
        state = game.reset()
        state_next = None

        inrates.append([])
        sysrates.append([])
        outrates.append([])
        states.append([])
        actions.append([])
        for s in range(100):
            action = model.tick(state)
            state_next, _, done, __ = game.step(action)
            reward = model.reward(state, action, state_next)

            inrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, :model._n_inputs])))
            sysrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, model._n_inputs:-model._n_outputs])))
            outrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, -model._n_outputs:])))
            states[-1].append(state)
            actions[-1].append(action)

            state = state_next
            if done:
                break
    print_rates(step_inrates=inrates, step_outrates=outrates, step_states=states, observation_space=game.observation_space)
    print_w_diffs(original_w, model.synapses.weights.matrix, model._n_inputs)
    print_success(states, inrates, sysrates, outrates, training_params)

    return {}

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=1),
    run_config=RunConfig(verbose=0),
)
results = trainer.fit()

[2m[36m(RayTrainWorker pid=26651)[0m 2022-12-11 12:00:54,470	INFO config.py:87 -- Setting up process group for: env:// [rank=0, world_size=1]


[2m[36m(RayTrainWorker pid=26651)[0m (False, False): 0.00 -> 0.69
[2m[36m(RayTrainWorker pid=26651)[0m (False, True): 0.04 -> 0.86
[2m[36m(RayTrainWorker pid=26651)[0m (True, False): 0.04 -> 0.90
[2m[36m(RayTrainWorker pid=26651)[0m (True, True): 0.08 -> 0.69
[2m[36m(RayTrainWorker pid=26651)[0m 362 -> 12924
[2m[36m(RayTrainWorker pid=26651)[0m 6 -> 116
[2m[36m(RayTrainWorker pid=26651)[0m 0.8267 0.7085
[2m[36m(RayTrainWorker pid=26651)[0m Florian - Win: True, Accuracy: 0.9


2022-12-11 12:02:05,413	ERROR checkpoint_manager.py:327 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key in the result dict. Valid keys are: ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'done']


Trial TorchTrainer_82998_00000 completed. Last result: 


## Temporal Coding

The goal of this experiment is to train a spiking neural network to mimic an XOR gate, meaning it takes two binary inputs and returns one binary output. The desired input-output mapping is as follows:
```
0, 0 -> 0
0, 1 -> 1
1, 0 -> 1
1, 1 -> 0
```
In this experiment the inputs are temporally coded. There are two input groups, one for each boolean input. Each input value, 0 or 1, has a fixed spike pattern that is shared between the input groups and is replayed whenever that value is presented. The specification calls for 60 input neurons (30 per group), 60 hidden neurons and 1 output neuron, with each layer fully connected to the next; the code below uses a smaller network with 2 input neurons (1 per group), 20 hidden neurons and 1 output neuron. Each input pattern is presented to the network for 500 ms, with 800 patterns shown in total. While a pattern is shown, if the correct output is 1, the network receives a reward of 1 whenever its output neuron fires. Otherwise, it receives a reward of -1 whenever its output neuron fires.

Converting this description for use in the framework is straightforward, though it helps to have a frame of reference the first time through.

1. Divide experiment into network and game mechanics.
Splitting the network and game mechanics is simple here: the game gives two random boolean inputs at every game step and the network responds to them.

2. Set up network inputs.
Here we use temporally coded inputs: each input value, 0 or 1, corresponds to a fixed spike pattern (one spike train per neuron in a group) that is shared between the input groups. Whenever an input takes that value, its group replays the pattern.
We accomplish this using the StaticMap input type. It works like the RateMap used before, except that we give it a mapping from input state to spike trains (the `state_spike_map` key) instead of rates. The mapping can be defined per input or in aggregate; here we define it in aggregate, keyed by the full input state, for readability.
```
LOW_TRAIN = np.int_(np.random.uniform(0, 1, (PROCESSING_TIME, N_INPUTS // 2)) <= 50 * .0001)
HIGH_TRAIN = np.int_(np.random.uniform(0, 1, (PROCESSING_TIME, N_INPUTS // 2)) <= 50 * .0001)
input_map = {
    (0, 0): np.hstack((LOW_TRAIN, LOW_TRAIN)),
    (0, 1): np.hstack((LOW_TRAIN, HIGH_TRAIN)),
    (1, 0): np.hstack((HIGH_TRAIN, LOW_TRAIN)),
    (1, 1): np.hstack((HIGH_TRAIN, HIGH_TRAIN)),
}
```
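
To see what this map produces, the sketch below builds the same kind of map with a seeded generator and checks two properties that follow from the construction: every state maps to a (PROCESSING_TIME, N_INPUTS) array of 0s and 1s, and the two mixed states reuse the same trains with the groups swapped. With a per-step probability of 50 × 0.0001 = 0.005, each input neuron is expected to fire about 0.005 × 500 = 2.5 times per pattern. The sizes (N_INPUTS = 2, PROCESSING_TIME = 500) follow the code cell in this section; the seeded generator and `.astype(int)` in place of `np.int_` are choices made only for this illustration.

```python
import numpy as np

PROCESSING_TIME = 500
N_INPUTS = 2
P_SPIKE = 50 * .0001  # per-step spike probability

rng = np.random.default_rng(0)
# One fixed spike pattern per input value, one column per neuron in a group
LOW_TRAIN = (rng.uniform(0, 1, (PROCESSING_TIME, N_INPUTS // 2)) <= P_SPIKE).astype(int)
HIGH_TRAIN = (rng.uniform(0, 1, (PROCESSING_TIME, N_INPUTS // 2)) <= P_SPIKE).astype(int)

input_map = {
    (0, 0): np.hstack((LOW_TRAIN, LOW_TRAIN)),
    (0, 1): np.hstack((LOW_TRAIN, HIGH_TRAIN)),
    (1, 0): np.hstack((HIGH_TRAIN, LOW_TRAIN)),
    (1, 1): np.hstack((HIGH_TRAIN, HIGH_TRAIN)),
}

# Mixed states reuse the same trains, with the two groups swapped
half = N_INPUTS // 2
m10 = input_map[(1, 0)]
swapped = np.hstack((m10[:, half:], m10[:, :half]))
print(np.array_equal(input_map[(0, 1)], swapped))  # True
print(f"{P_SPIKE * PROCESSING_TIME:.1f}")          # 2.5 expected spikes per neuron per pattern
```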

3. Set topology of network.
For this we use a manually configured network. We give it one matrix for each layer: the input-hidden and hidden-output layers.
Each initial weight value is sampled uniformly between 0 and .4, with these parameters chosen by trial and error.
```
w_matrix = [
    np.random.uniform(0, .4, (N_INPUTS, N_HIDDEN)),
    np.random.uniform(0, .4, (N_HIDDEN, N_OUTPUTS)),
]
```

4. Set reward scheme and network readout.
This is typically the most complex part of constructing an experiment with spikey, and this example is no different.
Here we use the ActiveRLNetwork base class so that the reward function is called at every network update; RLNetwork, by contrast, calls it once per game step, i.e. every PROCESSING_TIME network steps.
Alongside it we use the `continuous_rwd_action` parameter in place of the readout, since the readout is only applied once per game step. `continuous_rwd_action` tells the rewarder whether the output neuron fired, via the `action` parameter.
Finally, we use the custom `FlorianReward` rewarder (defined earlier in the notebook), which implements the original experiment's scheme: if the expected output is 1, a reward is given on every output spike; otherwise a punishment is given on every output spike.

5. Set other parameters.
Most of the other parameters are taken directly from the paper or are intuitively chosen.


In [9]:
training_params = {
    'n_episodes': 1,
    'len_episode': 800,
    'eval_steps': 50,
}

In [10]:
N_INPUTS = 2
N_NEURONS = 21
N_OUTPUTS = 1
N_HIDDEN = N_NEURONS - N_OUTPUTS
PROCESSING_TIME = 500

LOW_TRAIN = np.int_(np.random.uniform(0, 1, (PROCESSING_TIME, N_INPUTS // 2)) <= 50 * .0001)
HIGH_TRAIN = np.int_(np.random.uniform(0, 1, (PROCESSING_TIME, N_INPUTS // 2)) <= 50 * .0001)
input_map = {
    (0, 0): np.hstack((LOW_TRAIN, LOW_TRAIN)),
    (0, 1): np.hstack((LOW_TRAIN, HIGH_TRAIN)),
    (1, 0): np.hstack((HIGH_TRAIN, LOW_TRAIN)),
    (1, 1): np.hstack((HIGH_TRAIN, HIGH_TRAIN)),
}

w_matrix = [
    np.random.uniform(0, .4, (N_INPUTS, N_HIDDEN)),
    np.random.uniform(0, .4, (N_HIDDEN, N_OUTPUTS)),
]
class network_template(ActiveRLNetwork):
    parts = {
        'inputs': input.StaticMap,
        'neurons': neuron.Neuron,
        'synapses': synapse.RLSTDP,
        'weights': weight.Manual,
        'readout': readout.Threshold,
        'rewarder': FlorianReward,
    }
    keys = {
        "n_inputs": N_INPUTS,
        'n_neurons': N_NEURONS,
        "n_outputs": N_OUTPUTS,
        'matrix': w_matrix,

        'input_pct_inhibitory': .5,
        'neuron_pct_inhibitory': 0,
        'magnitude': 1,
        'firing_threshold': 16,
        'refractory_period': 0,  # Gutig, Aharonov, Rotter, & Sompolinsky 2003
        'prob_rand_fire': .15,
        'potential_decay': .05,  # Decay constant Tau=20ms, lambda=e^(-t/T)
        'trace_decay': .04,  # T_z = 25, lambda = e^(-1/T_z)
        "punish_mult": 1,

        'processing_time': PROCESSING_TIME,
        'learning_rate': .5 / 25,  # gamma_0 = gamma / Tau_z
        'max_weight': 5,
        'stdp_window': 20,  # Tau_+ = Tau_- = 20ms
        'action_threshold': 0,  # Makes network always output True

        'continuous_rwd_action': lambda network, game: network.spike_log[-1, -1],
        'state_spike_map': input_map,
    }

In [11]:
# Control, without learning
def train_func():
    network_template.keys.update({'learning_rate': 0, 'len_episode': 50})
    game = Logic(preset="XOR")
    model = network_template()

    original_w = model.synapses.weights.matrix.copy()
    inrates = []
    sysrates = []
    outrates = []
    states = []
    actions = []
    for epoch in range(5):
        model.reset()
        state = game.reset()
        state_next = None

        inrates.append([])
        sysrates.append([])
        outrates.append([])
        states.append([])
        actions.append([])
        for s in range(100):
            action = model.tick(state)
            state_next, _, done, __ = game.step(action)
            reward = model.reward(state, action, state_next)

            inrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, :model._n_inputs])))
            sysrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, model._n_inputs:-model._n_outputs])))
            outrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, -model._n_outputs:])))
            states[-1].append(state)
            actions[-1].append(action)

            state = state_next
            if done:
                break
    print_rates(step_inrates=inrates, step_outrates=outrates, step_states=states, observation_space=game.observation_space)
    print_w_diffs(original_w, model.synapses.weights.matrix, model._n_inputs)
    print_success(states, inrates, sysrates, outrates, training_params)

    return {}

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=1),
    run_config=RunConfig(verbose=0),
)
results = trainer.fit()

[2m[36m(RayTrainWorker pid=26761)[0m 2022-12-11 12:02:09,982	INFO config.py:87 -- Setting up process group for: env:// [rank=0, world_size=1]


[2m[36m(RayTrainWorker pid=26761)[0m (False, False): 0.00 -> 0.15
[2m[36m(RayTrainWorker pid=26761)[0m (False, True): 0.01 -> 0.15
[2m[36m(RayTrainWorker pid=26761)[0m (True, False): 0.01 -> 0.15
[2m[36m(RayTrainWorker pid=26761)[0m (True, True): 0.01 -> 0.15
[2m[36m(RayTrainWorker pid=26761)[0m 1 -> 1
[2m[36m(RayTrainWorker pid=26761)[0m 11 -> 11
[2m[36m(RayTrainWorker pid=26761)[0m 0.1495 0.1595
[2m[36m(RayTrainWorker pid=26761)[0m Florian - Win: False, Accuracy: 0.42


2022-12-11 12:03:18,863	ERROR checkpoint_manager.py:327 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key in the result dict. Valid keys are: ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'done']


Trial TorchTrainer_afa98_00000 completed. Last result: 


In [12]:
# Real experiment
def train_func():
    game = Logic(preset="XOR")
    model = network_template()

    original_w = model.synapses.weights.matrix.copy()
    inrates = []
    sysrates = []
    outrates = []
    states = []
    actions = []
    for epoch in range(5):
        model.reset()
        state = game.reset()
        state_next = None

        inrates.append([])
        sysrates.append([])
        outrates.append([])
        states.append([])
        actions.append([])
        for s in range(100):
            action = model.tick(state)
            state_next, _, done, __ = game.step(action)
            reward = model.reward(state, action, state_next)

            inrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, :model._n_inputs])))
            sysrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, model._n_inputs:-model._n_outputs])))
            outrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, -model._n_outputs:])))
            states[-1].append(state)
            actions[-1].append(action)

            state = state_next
            if done:
                break
    print_rates(step_inrates=inrates, step_outrates=outrates, step_states=states, observation_space=game.observation_space)
    print_w_diffs(original_w, model.synapses.weights.matrix, model._n_inputs)
    print_success(states, inrates, sysrates, outrates, training_params)

    return {}

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=1),
    run_config=RunConfig(verbose=0),
)
results = trainer.fit()

[2m[36m(RayTrainWorker pid=26870)[0m 2022-12-11 12:03:23,601	INFO config.py:87 -- Setting up process group for: env:// [rank=0, world_size=1]


[2m[36m(RayTrainWorker pid=26870)[0m (False, False): 0.00 -> 0.35
[2m[36m(RayTrainWorker pid=26870)[0m (False, True): 0.01 -> 0.39
[2m[36m(RayTrainWorker pid=26870)[0m (True, False): 0.01 -> 0.40
[2m[36m(RayTrainWorker pid=26870)[0m (True, True): 0.01 -> 0.34
[2m[36m(RayTrainWorker pid=26870)[0m 1 -> 6
[2m[36m(RayTrainWorker pid=26870)[0m 11 -> 104
[2m[36m(RayTrainWorker pid=26870)[0m 0.416 0.3682
[2m[36m(RayTrainWorker pid=26870)[0m Florian - Win: True, Accuracy: 0.68


2022-12-11 12:04:29,959	ERROR checkpoint_manager.py:327 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key in the result dict. Valid keys are: ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'done']


Trial TorchTrainer_db511_00000 completed. Last result: 
