Rough replication of Izhikevich's 2007 paper,
Solving the Distal Reward Problem through linkage of STDP and Dopamine Signaling

https://www.izhikevich.org/publications/dastdp.pdf

Eugene M. Izhikevich (2007) Solving the Distal Reward Problem through linkage of STDP and Dopamine Signaling. Cerebral Cortex Advance Access, Jan 13, 2007. doi:10.1093/cercor/bhl152

In [1]:
import numpy as np
from spikey.snn import *
from spikey.games import Logic
from spikey.viz import print_rates, print_common_action

from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, RunConfig

np.random.seed(0)


In [2]:
class rand_state(Logic):
    """
    A customization of the Logic game, sets the game state randomly in 0..N-1 at each timestep.
    """
    NECESSARY_KEYS = Logic.extend_keys({"n_states": "Number of input groups."})
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.observation_space = list(range(self._n_states))

    def _get_state(self) -> np.ndarray:
        # Randomly sample a state integer on game start and at every update
        return np.random.randint(self._n_states)

## Classical Conditioning

In the original experiment there are N distinct input neuron groups, all pointing towards a single output group. The goal is to condition the output neurons to fire heavily in response to certain input groups, while largely ignoring others. This is accomplished by rewarding the network when the desired input groups fire to strengthen that group's connections to the outputs.
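To make the role of the reward concrete, here is a minimal sketch of the eligibility-trace idea from the paper for a single synapse: a pre-before-post pairing tags the synapse with a decaying trace c, and the weight only changes while dopamine d is present (dc/dt = -c/tau_c, ds/dt = c * d). The function name, time constants and amounts are illustrative assumptions, and this is not how Spikey's synapse parts are implemented.

```python
def simulate_synapse(reward_time_ms=None, n_steps=3000,
                     tau_c=1000.0, tau_d=200.0,
                     stdp_amount=1.0, reward_amount=0.5):
    """Euler-integrate one synapse with a 1 ms step (illustrative values)."""
    c = 0.0  # eligibility trace left behind by STDP
    d = 0.0  # dopamine concentration above baseline
    s = 0.0  # accumulated change in synaptic strength
    for t in range(n_steps):
        if t == 0:
            c += stdp_amount      # pre-before-post pairing tags the synapse
        if reward_time_ms is not None and t == reward_time_ms:
            d += reward_amount    # delayed reward releases dopamine
        s += c * d                # ds/dt = c * d
        c -= c / tau_c            # dc/dt = -c / tau_c
        d -= d / tau_d            # dd/dt = -d / tau_d
    return s


no_reward = simulate_synapse(None)      # trace decays, weight never changes
early_reward = simulate_synapse(500)    # reward 0.5 s after the pairing
late_reward = simulate_synapse(2000)    # reward 2 s after the pairing
print(no_reward, early_reward, late_reward)
```

Without reward the weight does not move at all, and a reward that arrives sooner after the pairing produces a larger change than one that arrives later, because more of the trace is still left.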

Converting the experiment description above for use in the framework is straightforward, but if it is your first time it helps to have a frame of reference.

1. Divide experiment into network and game mechanics.
In this experiment the game is very simple: at each step a state in 0..N-1 is chosen at random, and that state selects the input group that is to fire (see rand_state in the cell above). The network handles its own topology, input firings and reward scheme.

2. Set up network inputs. 
First we split the set of input neurons into N groups, each of which fires at a set rate when its respective state is active. In Spikey we accomplish this with the RateMap input type, with its state_rate_map parameter set to an ndarray of all zeros except the diagonal, which is set to the desired firing rate (= .2). state_rate_map can be a dictionary, ndarray or any other object that can be indexed by the state; it is used as group_rates = state_rate_map[state]. In this case if state = 0, then group_rates = [.2, 0, 0, ...], which means group 0 will fire at a rate of 20% and all other groups will remain quiescent. RateMap automatically divides the set of inputs into groups based on the size of the group rates vector.

3. Set the topology of the network.
Here we have a single fully connected feedforward layer, with each input connected to each output. Using the Manual weight part, we specify the network topology as a matrix in the shape (n_inputs+n_body, n_body) with n_body = n_hidden + n_output. For our purposes this looks like,

```
n_neurons
------------------
|   connected    |        n_inputs
- - - - - - - - -
|   unconnected  |        n_neurons
------------------
```

with connected drawn from a uniform distribution (uniform(0, .5) in the first experiment's code, uniform(0, 1) in the second) and unconnected = 0.

4. Set up reward scheme and network readout.
In this experiment reward is given solely based on the game state and ignores the network output. Therefore the readout function was arbitrarily chosen to be the simplest possible, a threshold function. A custom rewarder is set up in the cell below, giving reward when the state is in the list 0, 3, 6 or 9.
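Steps 2 and 3 above can be sketched in plain NumPy, without Spikey. This is a rough illustration only: it assumes that input groups are contiguous, equal-sized blocks of neurons and that each neuron spikes independently with its group's rate at every timestep, which may differ from what RateMap actually does internally.

```python
import numpy as np

rng = np.random.default_rng(0)
N_STATES, N_INPUTS, N_NEURONS, PROCESSING_TIME = 10, 100, 50, 100

# Step 2: all zeros except the diagonal, which holds the firing rate
state_rate_map = np.eye(N_STATES) * 0.2

state = 0
group_rates = state_rate_map[state]          # [.2, 0, 0, ...]

# Assumption: each group is a contiguous block of N_INPUTS // N_STATES neurons
neuron_rates = np.repeat(group_rates, N_INPUTS // N_STATES)

# Assumption: independent Bernoulli spikes per neuron per timestep
input_spikes = rng.uniform(size=(PROCESSING_TIME, N_INPUTS)) < neuron_rates

# Step 3: inputs -> neurons are connected, neurons -> neurons are not
w_matrix = np.vstack((
    rng.uniform(0, 1, (N_INPUTS, N_NEURONS)),   # connected
    np.zeros((N_NEURONS, N_NEURONS)),           # unconnected
))

mean_input_rate = input_spikes.mean()
print(w_matrix.shape, mean_input_rate)
```

Because only one of the ten groups is active at a time, the firing rate averaged over all inputs comes out near .2 / 10 = .02, which matches the input rates printed in the results further down.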

In [3]:
class StateRewarder(reward.template.Reward):
    FIRE_STATES = [0, 3, 6, 9]
    def __call__(self, state, action, state_next):
        # Give reward when state in desired states
        if state in self.FIRE_STATES:
            return self._reward_mult
        else:
            return self._punish_mult

In [4]:
training_params = {
    'n_episodes': 5,
    'len_episode': 100,
}

In [5]:
N_STATES = 10
N_INPUTS = 100
N_NEURONS = 50
N_OUTPUTS = N_NEURONS

w_matrix = np.vstack((  # Feedforward, no hidden layers
    np.random.uniform(0, .5, (N_INPUTS, N_NEURONS)),
    np.zeros((N_NEURONS, N_NEURONS)),
))

# When state is 1 neuron group 1 fires, ...
state_rate_map = np.zeros((N_STATES, N_STATES))
for state in range(N_STATES):
    state_rate_map[state, state] = .2

class network_template(RLNetwork):
    parts = {
        "inputs": input.RateMap,
        "neurons": neuron.Neuron,
        "synapses": synapse.LTP,
        "weights": weight.Manual,
        "readout": readout.Threshold,
        "rewarder": StateRewarder,
    }
    keys = {
        "n_inputs": N_INPUTS,
        "n_neurons": N_NEURONS,
        "n_outputs": N_OUTPUTS,
        "matrix": w_matrix,
        "magnitude": 1,
        "potential_decay": .05,

        "n_states": N_STATES,
        "refractory_period": 0,
        "firing_threshold": 8,

        "processing_time": 100,
        "learning_rate": .1,
        "max_weight": 2,
        "stdp_window": 100,

        "reward_mult": 1,
        "punish_mult": 0,
        "action_threshold": .0,  # Does not matter

        "state_rate_map": state_rate_map, 
    }

In [6]:
def train_func(config):
    network_template.keys.update(config)
    game = rand_state(**network_template.keys)
    model = network_template()

    inrates = []
    outrates = []
    states = []
    actions = []
    for epoch in range(5):
        model.reset()
        state = game.reset()
        state_next = None

        inrates.append([])
        outrates.append([])
        states.append([])
        actions.append([])
        for s in range(100):
            action = model.tick(state)
            state_next, _, done, __ = game.step(action)
            reward = model.reward(state, action, state_next)

            inrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, :model._n_inputs])))
            outrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, -model._n_outputs:])))
            states[-1].append(state)
            actions[-1].append(action)

            state = state_next
            if done:
                break

    print("Firing states:", StateRewarder.FIRE_STATES)
    print_rates(step_inrates=inrates, step_outrates=outrates, step_states=states, observation_space=game.observation_space)

    return {}

In [7]:
# Control, without learning
trainer = TorchTrainer(
    train_func,
    train_loop_config={'learning_rate': 0, 'n_episodes': 1},
    scaling_config=ScalingConfig(num_workers=1),
    run_config=RunConfig(verbose=0),
)
results = trainer.fit()

[2m[36m(RayTrainWorker pid=15639)[0m 2022-12-11 08:04:30,393	INFO config.py:87 -- Setting up process group for: env:// [rank=0, world_size=1]


[2m[36m(RayTrainWorker pid=15639)[0m Firing states: [0, 3, 6, 9]
[2m[36m(RayTrainWorker pid=15639)[0m 0: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15639)[0m 1: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15639)[0m 2: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15639)[0m 3: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15639)[0m 4: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15639)[0m 5: 0.02 -> 0.04
[2m[36m(RayTrainWorker pid=15639)[0m 6: 0.02 -> 0.04
[2m[36m(RayTrainWorker pid=15639)[0m 7: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15639)[0m 8: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15639)[0m 9: 0.02 -> 0.03


2022-12-11 08:04:40,391	ERROR checkpoint_manager.py:327 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key in the result dict. Valid keys are: ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'done']


Trial TorchTrainer_7c824_00000 completed. Last result: 


In [8]:
# Real test
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=1),
    run_config=RunConfig(verbose=0),
)
results = trainer.fit()

[2m[36m(RayTrainWorker pid=15717)[0m 2022-12-11 08:04:44,722	INFO config.py:87 -- Setting up process group for: env:// [rank=0, world_size=1]


[2m[36m(RayTrainWorker pid=15717)[0m Firing states: [0, 3, 6, 9]
[2m[36m(RayTrainWorker pid=15717)[0m 0: 0.02 -> 0.26
[2m[36m(RayTrainWorker pid=15717)[0m 1: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15717)[0m 2: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15717)[0m 3: 0.02 -> 0.31
[2m[36m(RayTrainWorker pid=15717)[0m 4: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15717)[0m 5: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15717)[0m 6: 0.02 -> 0.29
[2m[36m(RayTrainWorker pid=15717)[0m 7: 0.02 -> 0.04
[2m[36m(RayTrainWorker pid=15717)[0m 8: 0.02 -> 0.03
[2m[36m(RayTrainWorker pid=15717)[0m 9: 0.02 -> 0.25


2022-12-11 08:04:55,000	ERROR checkpoint_manager.py:327 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key in the result dict. Valid keys are: ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'done']


Trial TorchTrainer_84d91_00000 completed. Last result: 


## Classical Conditioning with Ordinal Output

Building on the previous experiment, the network outputs are now split into two groups. The network's output is the index of the highest-firing group, e.g. if group 0 fires more than any other group the network outputs a 0. The network is conditioned to output a 0 for states 2, 3, 6 and 8 and a 1 otherwise.

A variation of the population vector readout was used, defined in the cell below. The base population vector readout returns a relative firing rate for each output group, e.g. [.25, .75]; our custom MaxGroup readout takes this output and returns the index of the max group, e.g. 0 or 1.
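As a rough standalone illustration of this readout, the sketch below computes a population vector from a spike train and then takes its argmax. It assumes the output neurons are split into contiguous, equal-sized groups and that the vector is each group's share of the total firing; the helper population_vector is hypothetical and is not Spikey's PopulationVector implementation.

```python
import numpy as np

def population_vector(output_spike_train, n_actions):
    # Split the output neurons (columns) into n_actions contiguous groups
    groups = np.array_split(np.abs(output_spike_train), n_actions, axis=1)
    group_rates = np.array([group.mean() for group in groups])
    total = group_rates.sum()
    if total == 0:
        return np.zeros(n_actions)   # no spikes at all
    return group_rates / total       # each group's share of the firing

# 120 timesteps x 60 output neurons, two groups of 30
spikes = np.zeros((120, 60))
spikes[::12, :30] = 1   # group 0 fires on every 12th step
spikes[::4, 30:] = 1    # group 1 fires on every 4th step

vector = population_vector(spikes, n_actions=2)
action = int(np.argmax(vector))
print(vector, action)
```

Group 1 fires three times as often as group 0 here, so the vector is approximately [.25, .75] and the readout picks action 1.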

A custom rewarder was used to reward the network when the correct group fires the most.
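The intended reward rule can be written out as a small standalone function. This is only a sketch of the behaviour described above, with hypothetical helper names and with default multipliers mirroring the reward_mult and punish_mult keys; it does not depend on Spikey's Reward template.

```python
A_STATES = [2, 3, 6, 8]

def expected_action(state):
    # 0 (group A) for the listed states, 1 (group B) otherwise
    return 0 if state in A_STATES else 1

def ordinal_reward(state, action, reward_mult=1, punish_mult=0):
    # Reward only when the highest-firing group matches the expected one
    return reward_mult if action == expected_action(state) else punish_mult

for state in range(10):
    print(state, expected_action(state),
          ordinal_reward(state, 0), ordinal_reward(state, 1))
```

Each printed row lists the state, the expected action, and the reward for answering 0 and for answering 1.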

In [9]:
class MaxGroup(readout.PopulationVector):
    def __call__(self, output_spike_train: np.ndarray) -> int:
        # Network reads out index of highest firing output group
        population_vector = super().__call__(output_spike_train)
        return int(np.argmax(population_vector))

In [10]:
class OrdinalRewarder(reward.template.Reward):
    A_STATES = [2, 3, 6, 8]
    def __call__(self, state, action, state_next):
        # Expect network to output A(0) when current state in states listed,
        # otherwise B(1)
        expected_action = 0 if state in self.A_STATES else 1
        if action == expected_action:
            return self._reward_mult
        else:
            return self._punish_mult

In [11]:
training_params = {
    'n_episodes': 10,
    'len_episode': 100,
}

In [12]:
N_STATES = 10
N_INPUTS = 100
N_NEURONS = 60
N_OUTPUTS = N_NEURONS

w_matrix = np.vstack((  # Feedforward, no hidden layers
    np.random.uniform(0, 1, (N_INPUTS, N_NEURONS)),
    np.zeros((N_NEURONS, N_NEURONS)),
))

# When state is 1 neuron group 1 fires, ...
state_rate_map = np.zeros((N_STATES, N_STATES), dtype=float)
for state in range(N_STATES):
    state_rate_map[state, state] = .2

class network_template(RLNetwork):
    parts = {
        "inputs": input.RateMap,
        "neurons": neuron.Neuron,
        "synapses": synapse.LTP,
        "weights": weight.Manual,
        "readout": MaxGroup,
        "rewarder": OrdinalRewarder,
    }
    keys = {
        "n_inputs": N_INPUTS,
        'n_neurons': N_NEURONS,
        "n_outputs": N_OUTPUTS,
        "matrix": w_matrix,
        "magnitude": 1,
        "potential_decay": .05,

        "n_states": N_STATES,
        "refractory_period": 0,
        "firing_threshold": 8,

        "processing_time": 100,
        "learning_rate": .1,
        "max_weight": 2,
        "stdp_window": 100,

        "reward_mult": 1,
        "punish_mult": 0,
        "n_actions": 2,

        "state_rate_map": state_rate_map,
    }

In [13]:
def train_func():
    game = rand_state(**network_template.keys)
    model = network_template()

    inrates = []
    states = []
    actions = []
    for epoch in range(5):
        model.reset()
        state = game.reset()
        state_next = None

        inrates.append([])
        states.append([])
        actions.append([])
        for s in range(100):
            action = model.tick(state)
            state_next, _, done, __ = game.step(action)
            reward = model.reward(state, action, state_next)

            inrates[-1].append(np.mean(np.abs(model.spike_log[-model._processing_time:, :model._n_inputs])))
            states[-1].append(state)
            actions[-1].append(action)

            state = state_next
            if done:
                break

    print("A States:", OrdinalRewarder.A_STATES)
    print("Initial Responses")
    print_common_action(step_inrates=inrates, step_actions=actions, step_states=states, observation_space=game.observation_space, episode=0)
    print("\nFinal Responses")
    print_common_action(step_inrates=inrates, step_actions=actions, step_states=states, observation_space=game.observation_space, episode=-1)

    return {}

In [14]:
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=1),
    run_config=RunConfig(verbose=0),
)
results = trainer.fit()

[2m[36m(RayTrainWorker pid=15785)[0m 2022-12-11 08:04:59,260	INFO config.py:87 -- Setting up process group for: env:// [rank=0, world_size=1]


[2m[36m(RayTrainWorker pid=15785)[0m A States: [2, 3, 6, 8]
[2m[36m(RayTrainWorker pid=15785)[0m Initial Responses
[2m[36m(RayTrainWorker pid=15785)[0m 0: 0.0189 -> A(0). counts=[4 1]
[2m[36m(RayTrainWorker pid=15785)[0m 1: 0.0205 -> B(1). counts=[2 8]
[2m[36m(RayTrainWorker pid=15785)[0m 2: 0.0187 -> A(0). counts=[6]
[2m[36m(RayTrainWorker pid=15785)[0m 3: 0.0204 -> B(1). counts=[5 7]
[2m[36m(RayTrainWorker pid=15785)[0m 4: 0.0200 -> B(1). counts=[5 9]
[2m[36m(RayTrainWorker pid=15785)[0m 5: 0.0210 -> B(1). counts=[15]
[2m[36m(RayTrainWorker pid=15785)[0m 6: 0.0197 -> A(0). counts=[11  1]
[2m[36m(RayTrainWorker pid=15785)[0m 7: 0.0208 -> A(0). counts=[7]
[2m[36m(RayTrainWorker pid=15785)[0m 8: 0.0203 -> B(1). counts=[3 6]
[2m[36m(RayTrainWorker pid=15785)[0m 9: 0.0200 -> B(1). counts=[10]
[2m[36m(RayTrainWorker pid=15785)[0m 
[2m[36m(RayTrainWorker pid=15785)[0m Final Responses
[2m[36m(RayTrainWorker pid=15785)[0m 0: 0.0196 -> B(1). counts=

2022-12-11 08:05:08,623	ERROR checkpoint_manager.py:327 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key in the result dict. Valid keys are: ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'done']


Trial TorchTrainer_8db8e_00000 completed. Last result: 
