Rough replication of Izhikevich's 2007 paper,
Solving the Distal Reward Problem through linkage of STDP and Dopamine Signaling

https://www.izhikevich.org/publications/dastdp.pdf

Eugene M. Izhikevich(2007) Solving the Distal Reward Problem through linkage of STDP and Dopamine Signaling. Cerebral Cortex Advance Access, Jan 13, 2007. doi:10.1093/cercor/bhl152

In [1]:
import numpy as np
from spikey.snn import *
from spikey.core import GenericLoop, RLCallback
from spikey.games import Logic
from spikey.viz import print_rates, print_common_action

np.random.seed(0)

In [2]:
class rand_state(Logic):
    """
    A customization of the Logic game, sets the game state randomly in 0..N at each timestep.
    """
    NECESSARY_KEYS = Logic.extend_keys({"n_states": "Number of input groups."})
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.observation_space = list(range(self._n_states))

    def _get_state(self) -> np.ndarray:
        # Randomly sample a state integer on game start and at every update
        return np.random.randint(self._n_states)

## Classical Conditioning

In the original experiment there are N distinct input neuron groups, all pointing towards a single output group. The goal is to condition the output neurons to fire heavily in response to certain input groups, while largely ignoring others. This is accomplished by rewarding the network when the desired input groups fire to strengthen that group's connections to the outputs.

Converting this description for use in the framework is straightforward, but if it's your first time needs frame of reference.

1. Divide experiment into network and game mechanics.
In this experiment the game is very simple, for each step a state in 0..N is randomly chosen that corresponds to the input group that is to fire, see rand_state in the cell above. The network will handle its own topology, input firings and reward scheme.

2. Set up network inputs. 
First we split the set of input neurons into N groups, each will fire at a set rate when its respective state is active. In Spikey we accomplish this with the RateMap input type, with its state_rate_map parameter as an ndarray of all zeros except the diagonal which is set to the desired firing rate(=.2). state_rate_map can be a dictionary, ndarray or any other object that will index the state, used as group_rates = state_rate_map[state]. In this case if the state = 0, then group_rates = [.2, 0, 0, ...] which means group 0 will fire at a rate of 20% and all other groups will remain quincient. RateMap automatically divides the set of inputs into groups based on the size of the group rates vector.

3. Set the topology of the network.
Here we have a single fully connected feedforward layer, with each input connected to each output. Using the Manual weight part, we specify the network topology as a matrix in the shape (n_inputs+n_body, n_body) with n_body = n_hidden + n_output. For our purposes this looks like,

```
n_neurons
------------------
|   connected    |        n_inputs
- - - - - - - - -
|   unconnected  |        n_neurons
------------------
```

with connected = uniform(0, 1) and unconnected = 0.

4. Setup reward scheme and network readout.
In this experiment reward is given solely based on the game state and ignores the network output. Therefore the readout function was arbitrarily chosen to be the simplest possible, a threshold function. A custom rewarder was setup in the state below, giving reward when the states is in the list 0, 3, 6 or 9.

In [3]:
class StateRewarder(reward.template.Reward):
    FIRE_STATES = [0, 3, 6, 9]
    def __call__(self, state, action, state_next):
        # Give reward when state in desired states
        if state in self.FIRE_STATES:
            return self._reward_mult
        else:
            return self._punish_mult

In [4]:
training_params = {
    'n_episodes': 5,
    'len_episode': 100,
}

In [5]:
N_STATES = 10
N_INPUTS = 100
N_NEURONS = 50
N_OUTPUTS = N_NEURONS

w_matrix = np.vstack((  # Feedforward, no hidden layers
    np.random.uniform(0, .5, (N_INPUTS, N_NEURONS)),
    np.zeros((N_NEURONS, N_NEURONS)),
))

# When state is 1 neuron group 1 fires, ...
state_rate_map = np.zeros((N_STATES, N_STATES))
for state in range(N_STATES):
    state_rate_map[state, state] = .2

class network_template(RLNetwork):
    parts = {
        "inputs": input.RateMap,
        "neurons": neuron.Neuron,
        "synapses": synapse.RLSTDP,
        "weights": weight.Manual,
        "readout": readout.Threshold,
        "rewarder": StateRewarder,
    }
    keys = {
        "n_inputs": N_INPUTS,
        "n_neurons": N_NEURONS,
        "n_outputs": N_OUTPUTS,
        "matrix": w_matrix,
        "magnitude": 1,
        "potential_decay": .05,

        "n_states": N_STATES,
        "refractory_period": 0,
        "firing_threshold": 8,

        "processing_time": 100,
        "learning_rate": .1,
        "max_weight": 2,
        "stdp_window": 100,

        "reward_mult": 1,
        "punish_mult": 0,
        "action_threshold": .0,  # Does not matter

        "state_rate_map": state_rate_map, 
    }

In [6]:
# Control, without learning
training_loop = GenericLoop(network_template, rand_state, measure_rates=True, **training_params)
training_loop.reset(**{'learning_rate': 0, 'n_episodes': 1})
network, game, results, info = training_loop()

print(f"{training_loop.callback.results['total_time']:.2f}s")
print_rates(callback=training_loop.callback)

1.47s
0: 0.02 -> 0.03
1: 0.02 -> 0.03
2: 0.02 -> 0.03
3: 0.02 -> 0.03
4: 0.02 -> 0.03
5: 0.02 -> 0.03
6: 0.02 -> 0.03
7: 0.02 -> 0.03
8: 0.02 -> 0.03
9: 0.02 -> 0.03


In [7]:
# Real test
training_loop = GenericLoop(network_template, rand_state, measure_rates=True, **training_params)
network, game, results, info = training_loop()

print("Firing states:", StateRewarder.FIRE_STATES)
print(f"{training_loop.callback.results['total_time']:.2f}s")
print_rates(callback=training_loop.callback)

Firing states: [0, 3, 6, 9]
6.33s
0: 0.02 -> 0.24
1: 0.02 -> 0.03
2: 0.02 -> 0.04
3: 0.02 -> 0.19
4: 0.02 -> 0.03
5: 0.02 -> 0.03
6: 0.02 -> 0.21
7: 0.02 -> 0.03
8: 0.02 -> 0.02
9: 0.02 -> 0.21


## Classical Conditioning with Ordinal Output

On top of the last experiment, here network outputs are split into two groups. The networks output is equal to the highest firing group, eg if group 0 fires more than any other group the network outputs a 0. The network is conditioned to output a 0 for states 2, 3, 6 and 8 and a 1 otherwise.

A variation of the population vector readout was used, defined in the cell below. The base population vector readout returns a relative firing rate per each input group, eg [.25, .75], our custom MaxGroup readout takes this output and returns the index of the max group, eg 0 or 1.

A custom rewarder was used to reward the network when the correct group fires the most.

In [8]:
class MaxGroup(readout.PopulationVector):
    def __call__(self, output_spike_train: np.bool) -> np.float:
        # Network reads out index of highest firing output group
        population_vector = super().__call__(output_spike_train)
        return np.argmax(population_vector)

In [9]:
class OrdinalRewarder(reward.template.Reward):
    A_STATES = [2, 3, 6, 8]
    def __call__(self, state, action, state_next):
        # Expect network to output A(0) when current state in states listed,
        # otherwise B(0)
        if action == (state in self.A_STATES):
            return self._reward_mult
        else:
            return self._punish_mult

In [10]:
training_params = {
    'n_episodes': 10,
    'len_episode': 100,
}

In [11]:
N_STATES = 10
N_INPUTS = 100
N_NEURONS = 60
N_OUTPUTS = N_NEURONS

w_matrix = np.vstack((  # Feedforward, no hidden layers
    np.random.uniform(0, 1, (N_INPUTS, N_NEURONS)),
    np.zeros((N_NEURONS, N_NEURONS)),
))

# When state is 1 neuron group 1 fires, ...
state_rate_map = np.zeros((N_STATES, N_STATES), dtype=float)
for state in range(N_STATES):
    state_rate_map[state, state] = .2

class network_template(RLNetwork):
    parts = {
        "inputs": input.RateMap,
        "neurons": neuron.Neuron,
        "synapses": synapse.RLSTDP,
        "weights": weight.Manual,
        "readout": MaxGroup,
        "rewarder": OrdinalRewarder,
    }
    keys = {
        "n_inputs": N_INPUTS,
        'n_neurons': N_NEURONS,
        "n_outputs": N_OUTPUTS,
        "matrix": w_matrix,
        "magnitude": 1,
        "potential_decay": .05,

        "n_states": N_STATES,
        "refractory_period": 0,
        "firing_threshold": 8,

        "processing_time": 100,
        "learning_rate": .1,
        "max_weight": 2,
        "stdp_window": 100,

        "reward_mult": 1,
        "punish_mult": 0,
        "n_actions": 2,

        "state_rate_map": state_rate_map,
    }

In [12]:
# Control, without learning
training_loop = GenericLoop(network_template, rand_state, measure_rates=True, **training_params)
training_loop.reset(params={'learning_rate': 0, 'n_episodes': 1})
network, game, results, info = training_loop()

print(f"{training_loop.callback.results['total_time']:.2f}s")
print_common_action(callback=training_loop.callback)

1.38s
0: 0.0203 -> B(1). counts=[9]
1: 0.0204 -> B(1). counts=[4 7]
2: 0.0200 -> A(0). counts=[12]
3: 0.0208 -> A(0). counts=[4 3]
4: 0.0202 -> A(0). counts=[10]
5: 0.0198 -> B(1). counts=[1 7]
6: 0.0199 -> A(0). counts=[13]
7: 0.0202 -> B(1). counts=[2 5]
8: 0.0207 -> B(1). counts=[4 7]
9: 0.0199 -> B(1). counts=[12]


In [13]:
# Real test
training_loop = GenericLoop(network_template, rand_state, measure_rates=True, **training_params)
network, game, results, info = training_loop()

print("A States:", OrdinalRewarder.A_STATES)
print(f"{training_loop.callback.results['total_time']:.2f}s")
print("Initial Responses")
print_common_action(callback=training_loop.callback, episode=0)
print("\nFinal Responses")
print_common_action(callback=training_loop.callback, episode=-1)

A States: [2, 3, 6, 8]
12.65s
Initial Responses
0: 0.0204 -> B(1). counts=[7]
1: 0.0199 -> A(0). counts=[12]
2: 0.0198 -> A(0). counts=[10]
3: 0.0195 -> A(0). counts=[7 3]
4: 0.0193 -> A(0). counts=[7]
5: 0.0197 -> B(1). counts=[15]
6: 0.0199 -> A(0). counts=[14]
7: 0.0196 -> B(1). counts=[1 8]
8: 0.0200 -> A(0). counts=[5 2]
9: 0.0200 -> B(1). counts=[9]

Final Responses
0: 0.0198 -> B(1). counts=[11]
1: 0.0197 -> A(0). counts=[9 3]
2: 0.0201 -> A(0). counts=[11]
3: 0.0193 -> A(0). counts=[10]
4: 0.0193 -> B(1). counts=[9]
5: 0.0208 -> B(1). counts=[7]
6: 0.0193 -> A(0). counts=[7]
7: 0.0208 -> B(1). counts=[10]
8: 0.0197 -> A(0). counts=[8]
9: 0.0199 -> B(1). counts=[15]
