A rough replication of the results of Izhikevich's 2007 paper,
*Solving the Distal Reward Problem through Linkage of STDP and Dopamine Signaling*.

https://www.izhikevich.org/publications/dastdp.pdf

Eugene M. Izhikevich (2007). Solving the Distal Reward Problem through Linkage of STDP and Dopamine Signaling. *Cerebral Cortex*, Advance Access, Jan 13, 2007. doi:10.1093/cercor/bhl152

In [1]:
# Standard library / third-party imports.
from copy import deepcopy  # NOTE(review): not used in the cells shown here
import numpy as np
import matplotlib.pyplot as plt  # NOTE(review): not used in the cells shown here

# spikey wildcard imports. They supply the names used below that this notebook
# never defines: GenericLoop, RLCallback, RLNetwork, Logic and the
# input/neuron/synapse/weight/readout/reward part modules (which module provides
# which is not visible here). The imported `input` presumably shadows the
# Python builtin `input()`.
from spikey.core import *
from spikey.snn import *
from spikey.RL import *

# Seed NumPy's global RNG so the random initial weights (np.random.uniform) and the
# random game states (np.random.randint) are reproducible on a fresh kernel.
np.random.seed(0)

In [2]:
N_STATES: int = 10  # number of discrete game states, labelled 0 .. N_STATES - 1

In [3]:
# Shared by every experiment below; measure_rates is presumably what fills
# info['step_inrates'] / info['step_outrates'] read by the print helpers.
callback = RLCallback(measure_rates=True, reduced=False)

In [4]:
class game_template(Logic):
    """Game logic whose state is a uniformly random integer in ``0 .. N_STATES - 1``.

    Every call to ``_get_state`` draws a fresh state from NumPy's global RNG,
    so consecutive states are independent of each other and of the actions taken.
    """

    def _get_state(self) -> int:
        # With a single argument, np.random.randint returns one scalar int rather
        # than an ndarray, so the return annotation is int.
        return np.random.randint(N_STATES)

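## Classical Conditioning

Reward is given only in the states listed in `FIRE_STATES`; after training, the output firing rate in those states should exceed the rate in the other states.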

In [5]:
def print_rates(experiment_output, training_params, episode=-1, n_states=10):
    """Print, per state, the mean input and output rate over one episode.

    Args:
        experiment_output: ``(network, game, results, info)`` tuple returned by
            the training loop. Only ``info`` is read; it must hold per-episode
            lists under ``'step_states'``, ``'step_inrates'`` and ``'step_outrates'``.
        training_params: Unused; kept so existing calls keep working.
        episode: Index of the episode to summarize (default: the last one).
        n_states: Report states ``0 .. n_states - 1``. Defaults to 10, the
            value of ``N_STATES`` in this notebook.

    Prints one line per state: ``state: mean_inrate -> mean_outrate``.
    A state that never occurs in the episode is reported explicitly. Otherwise
    ``np.mean`` of an empty slice would emit a RuntimeWarning and print ``nan``.
    """
    _, _, _, info = experiment_output

    # step_states = [[ep 0 states], [ep 1 states], ...]; same layout for the rates.
    states = np.asarray(info['step_states'][episode])
    inrates = np.asarray(info['step_inrates'][episode])
    outrates = np.asarray(info['step_outrates'][episode])

    for state in range(n_states):
        in_state = states == state  # computed once, used for both rate arrays
        if not in_state.any():
            print(f"{state}: no steps in this state")
            continue
        print(f"{state}: {np.mean(inrates[in_state]):.4f} -> {np.mean(outrates[in_state]):.4f}")

In [6]:
training_params = {
    'n_episodes': 5,
    'len_episode': 100,
}

N_INPUTS = 100
N_NEURONS = 50
N_OUTPUTS = N_NEURONS  # every neuron in the layer is an output

FIRE_STATES = [0, 3, 6, 9]  # States network should fire in

# Feedforward, single layer: N_INPUTS rows of random weights in [0, .5) stacked
# on an N_NEURONS x N_NEURONS block of zeros (presumably neuron -> neuron).
w_matrix = np.vstack((
    np.random.uniform(0, .5, (N_INPUTS, N_NEURONS)),
    np.zeros((N_NEURONS, N_NEURONS)),
))
# Zero entries are masked; the mask is taken from the float64 values, before the float16 cast.
w_matrix = np.ma.array(np.float16(w_matrix), mask=(w_matrix == 0), fill_value=0)

# State -> rate map with .2 on the diagonal and 0 elsewhere. The builtin `float`
# is the dtype because the `np.float` alias was removed in NumPy 1.24.
state_rate_map = np.zeros((N_STATES, N_STATES), dtype=float)
np.fill_diagonal(state_rate_map, .2)

class network_template(RLNetwork):
    """Single-layer network configuration for the classical-conditioning task.

    ``keys`` holds the network parameters and ``parts`` selects the spikey
    component classes. With ``action_threshold`` at 0 the Threshold readout
    always outputs True, so ``MatchExpected`` rewards exactly the steps whose
    state is in ``FIRE_STATES``.
    """

    keys = dict(
        n_inputs=N_INPUTS,
        n_neurons=N_NEURONS,
        n_outputs=N_OUTPUTS,
        matrix=w_matrix,
        magnitude=1,
        potential_decay=.05,

        refractory_period=0,
        firing_threshold=8,
        trace_decay=.1,

        processing_time=100,
        learning_rate=.1,
        max_weight=2,
        stdp_window=100,

        reward_mult=1,
        punish_mult=0,
        # Makes network output always True, so reward is only given when state in FIRE_STATES
        action_threshold=.0,

        expected_value=lambda state: state in FIRE_STATES,
        state_rate_map=state_rate_map,
    )

    parts = dict(
        inputs=input.RateMap,
        neurons=neuron.Neuron,
        synapses=synapse.RLSTDPET,
        weights=weight.Manual,
        readout=readout.Threshold,
        rewarder=reward.MatchExpected,
        modifiers=None,
    )

In [7]:
# Control, without learning
training_loop = GenericLoop(network_template, game_template, callback, **training_params)
training_loop.reset(**{'learning_rate': 0, 'n_episodes': 2})
network, game, results, info = training_loop()

print(f"{callback.results['total_time']:.2f}s")
print_rates((network, game, results, info), training_params)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = arr.dtype.type(ret / rcount)
10.80s
0: 0.0203 -> 0.0321
1: 0.0204 -> 0.0325
2: 0.0195 -> 0.0334
3: 0.0197 -> 0.0288
4: 0.0200 -> 0.0280
5: 0.0195 -> 0.0317
6: 0.0195 -> 0.0305
7: 0.0196 -> 0.0313
8: 0.0192 -> 0.0248
9: 0.0195 -> 0.0296


In [8]:
# Real test
training_loop = GenericLoop(network_template, game_template, callback, **training_params)
network, game, results, info = training_loop()

print(FIRE_STATES)
print(f"{callback.results['total_time']:.2f}s")
print_rates((network, game, results, info), training_params)

[0, 3, 6, 9]
29.00s
0: 0.0204 -> 0.3521
1: 0.0204 -> 0.1705
2: 0.0202 -> 0.2000
3: 0.0194 -> 0.3191
4: 0.0199 -> 0.1638
5: 0.0201 -> 0.1888
6: 0.0193 -> 0.3352
7: 0.0202 -> 0.1873
8: 0.0196 -> 0.1625
9: 0.0197 -> 0.3381


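## Instrumental Conditioning

The network now has to choose between two actions each step (the argmax of a population-vector readout), and it is rewarded when that action matches `expected_value(state)`.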

In [9]:
def print_group_rates(experiment_output, training_params, episode=-1, n_states=10):
    """Print, per state, the mean input rate and the most frequent action in one episode.

    Args:
        experiment_output: ``(network, game, results, info)`` tuple returned by
            the training loop. Only ``info`` is read; it must hold per-episode
            lists under ``'step_states'``, ``'step_inrates'`` and ``'step_actions'``.
        training_params: Unused; kept so existing calls keep working.
        episode: Index of the episode to summarize (default: the last one).
        n_states: Report states ``0 .. n_states - 1``. Defaults to 10, the
            value of ``N_STATES`` in this notebook.

    Each line reads ``state: mean_inrate -> LABEL(action) [actions taken]``.
    LABEL is 'A' for action 0, 'B' for 1, and so on. When actions tie in
    frequency, the smallest action wins. A state that never occurs in the
    episode is reported explicitly; the original code would raise ValueError
    from ``np.argmax`` on an empty array in that case.
    """
    _, _, _, info = experiment_output

    # step_states = [[ep 0 states], [ep 1 states], ...]; same layout for the others.
    states = np.asarray(info['step_states'][episode])
    inrates = np.asarray(info['step_inrates'][episode])
    step_actions = np.asarray(info['step_actions'][episode])

    for state in range(n_states):
        in_state = states == state
        if not in_state.any():
            print(f"{state}: no steps in this state")
            continue

        taken = step_actions[in_state]
        actions, counts = np.unique(taken, return_counts=True)
        # np.unique returns sorted values and argmax picks the first maximum,
        # so ties resolve to the smallest action.
        action = actions[np.argmax(counts)]
        label = chr(ord('A') + int(action))  # 0 -> 'A', 1 -> 'B', ... (any number of actions)
        print(f"{state}: {np.mean(inrates[in_state]):.4f} -> {label}({action})", taken)

In [10]:
class max_group(readout.PopulationVector):
    """Readout that returns the index of the largest ``PopulationVector`` entry.

    The parent readout's output is reduced with ``np.argmax``, which returns
    the first maximum, so ties go to the lowest index. Presumably there is one
    entry per output group, which makes the index the chosen action (confirm
    against ``PopulationVector``).
    """

    def __call__(self, output_spike_train: np.ndarray) -> int:
        # Annotations are evaluated when the class is defined. The old `np.float`
        # alias was removed in NumPy 1.24 and raises AttributeError, so
        # standard types are used instead.
        return int(np.argmax(super().__call__(output_spike_train)))


In [11]:
training_params = {
    'n_episodes': 10,
    'len_episode': 100,
}

N_INPUTS = 100
N_NEURONS = 60
N_OUTPUTS = N_NEURONS  # every neuron in the layer is an output

A_STATES = [2, 3, 6, 8]  # States where group A should be higher than B

# Feedforward, single layer: N_INPUTS rows of random weights in [0, 1) stacked
# on an N_NEURONS x N_NEURONS block of zeros (presumably neuron -> neuron).
w_matrix = np.vstack((
    np.random.uniform(0, 1, (N_INPUTS, N_NEURONS)),
    np.zeros((N_NEURONS, N_NEURONS)),
))
# Zero entries are masked; the mask is taken from the float64 values, before the float16 cast.
w_matrix = np.ma.array(np.float16(w_matrix), mask=(w_matrix == 0), fill_value=0)

# State -> rate map with .2 on the diagonal and 0 elsewhere. The builtin `float`
# is the dtype because the `np.float` alias was removed in NumPy 1.24.
state_rate_map = np.zeros((N_STATES, N_STATES), dtype=float)
np.fill_diagonal(state_rate_map, .2)

class network_template(RLNetwork):
    """Single-layer network configuration for the instrumental-conditioning task.

    The action is the index produced by the ``max_group`` readout over
    ``n_actions`` = 2 choices. ``reward.MatchExpected`` compares that action
    with ``expected_value(state)``.

    NOTE(review): ``expected_value`` returns 1 for states in ``A_STATES`` and 0
    otherwise. However, the ``A_STATES`` comment says group A should win in
    those states, and ``print_group_rates`` labels action 0 as 'A'. The recorded
    final responses show action 0 (A) in ``A_STATES``. Confirm which mapping
    ``MatchExpected`` is meant to reward.
    """

    keys = dict(
        n_inputs=N_INPUTS,
        n_neurons=N_NEURONS,
        n_outputs=N_OUTPUTS,
        matrix=w_matrix,
        magnitude=1,
        potential_decay=.05,

        refractory_period=0,
        firing_threshold=8,
        trace_decay=.1,

        processing_time=100,
        learning_rate=.1,
        max_weight=2,
        stdp_window=100,

        reward_mult=1,
        punish_mult=0,
        n_actions=2,

        expected_value=lambda state: int(state in A_STATES),
        state_rate_map=state_rate_map,
    )

    parts = dict(
        inputs=input.RateMap,
        neurons=neuron.Neuron,
        synapses=synapse.RLSTDPET,
        weights=weight.Manual,
        readout=max_group,
        rewarder=reward.MatchExpected,
        modifiers=None,
    )

In [12]:
# Control, without learning
training_loop = GenericLoop(network_template, game_template, callback, **training_params)
training_loop.reset(params={'learning_rate': 0, 'n_episodes': 2})
network, game, results, info = training_loop()

print(A_STATES)
print(f"{callback.results['total_time']:.2f}s")
print_group_rates((network, game, results, info), training_params)

[2, 3, 6, 8]
12.44s
0: 0.0198 -> A(0) [0 0 0 0 0 0 0 0 0 0 0]
1: 0.0199 -> A(0) [0 1 0 0 0 0 0 0 0 0 0 0]
2: 0.0200 -> B(1) [1 1 1 1 1 1 1]
3: 0.0204 -> A(0) [0 0 0 0 0 0 0]
4: 0.0199 -> A(0) [0 0 0 0 0 0 0]
5: 0.0204 -> B(1) [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
6: 0.0201 -> A(0) [0 0 0 0 0 0 0]
7: 0.0201 -> B(1) [1 1 1 1 1 1 1 1 1 1 1 1 1 1]
8: 0.0199 -> A(0) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
9: 0.0197 -> A(0) [0 0 0 0]


In [13]:
# Real test
training_loop = GenericLoop(network_template, game_template, callback, **training_params)
network, game, results, info = training_loop()

print(A_STATES)
print(f"{callback.results['total_time']:.2f}s")
print("Initial Responses")
print_group_rates((network, game, results, info), training_params, 0)
print("\nFinal Responses")
print_group_rates((network, game, results, info), training_params, -1)

[2, 3, 6, 8]
64.86s
Initial Responses
0: 0.0202 -> B(1) [0 0 1 1 1 1 1 1 1 1 1 1]
1: 0.0199 -> B(1) [0 1 1 1 1 1 1]
2: 0.0199 -> B(1) [1 1 1 1 1 1]
3: 0.0196 -> A(0) [0 0 0 0 0 0 0 0 0]
4: 0.0201 -> A(0) [0 0 0 1 0 0 1 0]
5: 0.0198 -> B(1) [1 1 1 1 0 0 1 1 1 1 1 1 1 1]
6: 0.0204 -> A(0) [0 0 0 0 0 0 0 0 0 0 0 0 0 0]
7: 0.0206 -> B(1) [1 1 1 1 1 1 1 1 1]
8: 0.0199 -> A(0) [0 0 0 0 0 0 1 0 0 0 0 1 0 0]
9: 0.0205 -> A(0) [0 0 0 0 0 1 0]

Final Responses
0: 0.0194 -> B(1) [1 1 1 1 1 1 1 1 1 1 1 1 1]
1: 0.0199 -> B(1) [1 1 1 1 1 1 1 1]
2: 0.0195 -> A(0) [0 0 0 0 0 0 0 0 0]
3: 0.0199 -> A(0) [0 0 0 0 0 0 0 0 0 0 0]
4: 0.0200 -> B(1) [1 1 1 1 1 1 1 1 1 1 1]
5: 0.0195 -> B(1) [1 1 1 1 1 1 1 1 1]
6: 0.0204 -> A(0) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
7: 0.0208 -> B(1) [1 1 1 1 1 1 1 1 1 1 1 1 1]
8: 0.0209 -> A(0) [1 0]
9: 0.0200 -> B(1) [1 1 1 1 1 1 1 1]
