A rough replication of the results of Izhikevich's 2007 paper,
*Solving the Distal Reward Problem through linkage of STDP and Dopamine Signaling*.

https://www.izhikevich.org/publications/dastdp.pdf

Eugene M. Izhikevich (2007). Solving the Distal Reward Problem through linkage of STDP and Dopamine Signaling. *Cerebral Cortex* Advance Access, Jan 13, 2007. doi:10.1093/cercor/bhl152

In [1]:
import numpy as np

from spikey.snn import *
from spikey.core import GenericLoop, RLCallback
from spikey.RL import Logic
from spikey.viz import print_rates, print_common_action

# Seed NumPy's global RNG so the random weight initialisation and the random
# state sampling below are reproducible from a fresh kernel.
np.random.seed(seed=0)

In [2]:
# One callback instance, reused by every training loop in this notebook.
callback = RLCallback(measure_rates=True, reduced=False)

In [3]:
class rand_state(Logic):
    """Logic game whose state is a random integer in [0, n_states).

    Extends Logic's necessary keys with `n_states`. The observation space is
    the list of every possible state.
    """
    NECESSARY_KEYS = Logic.extend_keys({"n_states": "Number of input groups."})

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): `self._n_states` is presumably populated from the
        # `n_states` kwarg by the Logic base class; confirm in spikey.
        self.observation_space = list(range(self._n_states))

    def _get_state(self) -> int:
        """Draw a fresh random state.

        Per the original design, this is sampled on game start and at every
        update. Each call is an independent draw from NumPy's global RNG.

        Returns:
            int: a state in ``range(self._n_states)``. ``np.random.randint``
            with a scalar bound returns a scalar, not an ndarray as the old
            annotation claimed.
        """
        return np.random.randint(self._n_states)

## Classical Conditioning

In [4]:
# Display the keys the base Logic game declares as necessary (rand_state
# extends these via Logic.extend_keys).
Logic.NECESSARY_KEYS

["expected_value": "[any] func(state) Correct response of logic gate to specific state."]

In [5]:
# Display rand_state's necessary keys: Logic's keys plus the 'n_states' key it adds.
rand_state.NECESSARY_KEYS

["expected_value": "[any] func(state) Correct response of logic gate to specific state.",
 'n_states']

In [6]:
# Training run length: 5 episodes, each of length 100.
training_params = dict(
    n_episodes=5,
    len_episode=100,
)

In [7]:
# Network size: N_INPUTS input neurons and N_NEURONS body neurons, every one of
# which is read out as an output.
N_STATES = 10
N_INPUTS = 100
N_NEURONS = 50
N_OUTPUTS = N_NEURONS

# Weight matrix of shape (N_INPUTS + N_NEURONS, N_NEURONS): random input
# weights in [0, 0.5) stacked on an all-zero (N_NEURONS, N_NEURONS) block,
# i.e. feedforward only, no hidden layers.
w_matrix = np.concatenate(
    [
        np.random.uniform(low=0, high=.5, size=(N_INPUTS, N_NEURONS)),
        np.zeros((N_NEURONS, N_NEURONS)),
    ],
    axis=0,
)

# In state s only input group s is active, at rate 0.2; all other groups are 0.
state_rate_map = np.eye(N_STATES) * .2

FIRE_STATES = [0, 3, 6, 9]  # states in which the network should output True


def is_target_state(state: int) -> bool:
    """Expected readout for `state`: True iff it is listed in FIRE_STATES."""
    should_fire = state in FIRE_STATES
    return should_fire

class network_template(RLNetwork):
    """Network template for the binary classical-conditioning experiment.

    `parts` names the component class used for each part of the network, and
    `keys` holds the configuration values handed to it. Both dicts are built
    when the class body runs, so they capture the values of w_matrix,
    state_rate_map, etc. at definition time.
    """

    parts = dict(
        inputs=input.RateMap,
        neurons=neuron.Neuron,
        synapses=synapse.RLSTDPET,
        weights=weight.Manual,
        readout=readout.Threshold,
        rewarder=reward.MatchExpected,
    )

    keys = dict(
        n_inputs=N_INPUTS,
        n_neurons=N_NEURONS,
        n_outputs=N_OUTPUTS,
        matrix=w_matrix,
        magnitude=1,
        potential_decay=.05,

        n_states=N_STATES,
        refractory_period=0,
        firing_threshold=8,
        trace_decay=.1,

        processing_time=100,
        learning_rate=.1,
        max_weight=2,
        stdp_window=100,

        reward_mult=1,
        punish_mult=0,
        action_threshold=.0,  # Makes network always output True

        expected_value=is_target_state,
        state_rate_map=state_rate_map,
    )

In [8]:
# Control run: same setup, but learning disabled (learning_rate=0) and only
# one episode, as a baseline for the trained run below.
training_loop = GenericLoop(network_template, rand_state, callback, **training_params)
training_loop.reset(learning_rate=0, n_episodes=1)
network, game, results, info = training_loop()

elapsed = callback.results['total_time']
print(f"{elapsed:.2f}s")
print_rates(callback=callback)

5.61s
0: 0.02 -> 0.03
1: 0.02 -> 0.03
2: 0.02 -> 0.03
3: 0.02 -> 0.03
4: 0.02 -> 0.03
5: 0.02 -> 0.03
6: 0.02 -> 0.03
7: 0.02 -> 0.03
8: 0.02 -> 0.03
9: 0.02 -> 0.03


In [9]:
# Trained run: identical template and game, now with learning enabled.
training_loop = GenericLoop(network_template, rand_state, callback, **training_params)
network, game, results, info = training_loop()

total_seconds = callback.results['total_time']
print("Firing states:", FIRE_STATES)
print(f"{total_seconds:.2f}s")
print_rates(callback=callback)

Firing states: [0, 3, 6, 9]
28.00s
0: 0.02 -> 0.33
1: 0.02 -> 0.18
2: 0.02 -> 0.18
3: 0.02 -> 0.35
4: 0.02 -> 0.15
5: 0.02 -> 0.21
6: 0.02 -> 0.35
7: 0.02 -> 0.16
8: 0.02 -> 0.17
9: 0.02 -> 0.35


## Classical Conditioning with Ordinal Output

In [10]:
class max_group(readout.PopulationVector):
    """Ordinal readout: report which output group fired the most.

    The parent PopulationVector readout reduces the output spike train to one
    value per output group (presumably per-group firing rates, one group per
    action; confirm against spikey). This subclass collapses that vector to
    the index of its largest entry, which is the network's action.
    """

    def __call__(self, output_spike_train: np.ndarray) -> np.intp:
        """Return the index of the highest-firing output group.

        Args:
            output_spike_train: output neuron spikes, forwarded unchanged to
                PopulationVector.__call__.

        Returns:
            np.argmax of the population vector. Ties resolve to the lowest
            index (np.argmax returns the first occurrence of the maximum).
        """
        # The previous annotations `np.bool` / `np.float` are evaluated at
        # definition time. Those aliases were removed in NumPy 1.24 (np.float
        # is still absent in 2.x), so defining this class raised AttributeError.
        population_vector = super().__call__(output_spike_train)
        return np.argmax(population_vector)

In [11]:
# Longer training for the ordinal task: 10 episodes, each of length 100.
training_params = dict(
    n_episodes=10,
    len_episode=100,
)

In [12]:
# Network size for the ordinal task: N_INPUTS input neurons and N_NEURONS body
# neurons, all of which are read out as outputs.
N_STATES = 10
N_INPUTS = 100
N_NEURONS = 60
N_OUTPUTS = N_NEURONS

# Feedforward, no hidden layers: random input weights in [0, 1) stacked on an
# all-zero (N_NEURONS, N_NEURONS) block -> shape (N_INPUTS + N_NEURONS, N_NEURONS).
w_matrix = np.vstack((
    np.random.uniform(0, 1, (N_INPUTS, N_NEURONS)),
    np.zeros((N_NEURONS, N_NEURONS)),
))

# When the state is s, only input group s is active (rate 0.2).
# `dtype=np.float` was removed in NumPy 1.24 and raised AttributeError; the
# builtin `float` yields the same float64 array.
state_rate_map = np.zeros((N_STATES, N_STATES), dtype=float)
np.fill_diagonal(state_rate_map, .2)

A_STATES = [2, 3, 6, 8]  # states whose correct answer is A (action 0)


def is_a_state(state):
    """Expected-value function for the ordinal task.

    The network should answer A (action 0) when `state` is listed in
    A_STATES and B (action 1) otherwise.

    Returns:
        True if `state` is in A_STATES, else False.
    """
    # NOTE(review): in Python True == 1 and False == 0. If
    # reward.MatchExpected compares the chosen action directly to this value,
    # A states would be rewarded for action 1 (B), not A (0). Confirm how
    # MatchExpected interprets `expected_value` before trusting the results.
    return state in A_STATES

class network_template(RLNetwork):
    """Network template for the ordinal (A/B) classical-conditioning task.

    Same structure as the binary-task template, but with the argmax readout
    `max_group` and two actions. `parts` and `keys` are built when the class
    body runs, so they capture the current w_matrix, state_rate_map, etc.
    """

    parts = dict(
        inputs=input.RateMap,
        neurons=neuron.Neuron,
        synapses=synapse.RLSTDPET,
        weights=weight.Manual,
        readout=max_group,
        rewarder=reward.MatchExpected,
    )

    keys = dict(
        n_inputs=N_INPUTS,
        n_neurons=N_NEURONS,
        n_outputs=N_OUTPUTS,
        matrix=w_matrix,
        magnitude=1,
        potential_decay=.05,

        n_states=N_STATES,
        refractory_period=0,
        firing_threshold=8,
        trace_decay=.1,

        processing_time=100,
        learning_rate=.1,
        max_weight=2,
        stdp_window=100,

        reward_mult=1,
        punish_mult=0,
        n_actions=2,

        expected_value=is_a_state,
        state_rate_map=state_rate_map,
    )

In [13]:
# Control run without learning: learning_rate=0 and a single episode give a
# baseline for the trained run below.
# NOTE(review): the earlier binary-task control cell passes these overrides as
# keyword arguments to reset(), while this cell wraps them in `params=`.
# Confirm which form GenericLoop.reset actually honours.
baseline_overrides = {'learning_rate': 0, 'n_episodes': 1}
training_loop = GenericLoop(network_template, rand_state, callback, **training_params)
training_loop.reset(params=baseline_overrides)
network, game, results, info = training_loop()

run_seconds = callback.results['total_time']
print(f"{run_seconds:.2f}s")
print_common_action(callback=callback)

6.63s
0: 0.0201 -> B(1). counts=[10]
1: 0.0202 -> A(0). counts=[12]
2: 0.0195 -> A(0). counts=[4]
3: 0.0194 -> A(0). counts=[7]
4: 0.0203 -> B(1). counts=[15]
5: 0.0200 -> A(0). counts=[9 5]
6: 0.0202 -> B(1). counts=[1 4]
7: 0.0201 -> A(0). counts=[6 5]
8: 0.0199 -> A(0). counts=[8 6]
9: 0.0202 -> B(1). counts=[8]


In [14]:
# Trained run: learning enabled. Compare the most common action per state in
# the first episode against the last episode.
training_loop = GenericLoop(network_template, rand_state, callback, **training_params)
network, game, results, info = training_loop()

print("A States:", A_STATES)
print(f"{callback.results['total_time']:.2f}s")
for heading, episode in (("Initial Responses", 0), ("\nFinal Responses", -1)):
    print(heading)
    print_common_action(callback=callback, episode=episode)

A States: [2, 3, 6, 8]
59.82s
Initial Responses
0: 0.0200 -> B(1). counts=[11]
1: 0.0202 -> A(0). counts=[11  2]
2: 0.0192 -> A(0). counts=[7]
3: 0.0195 -> A(0). counts=[4 2]
4: 0.0201 -> B(1). counts=[14]
5: 0.0192 -> B(1). counts=[6]
6: 0.0201 -> A(0). counts=[11  1]
7: 0.0205 -> A(0). counts=[7 4]
8: 0.0205 -> A(0). counts=[10]
9: 0.0199 -> B(1). counts=[10]

Final Responses
0: 0.0194 -> B(1). counts=[12]
1: 0.0206 -> B(1). counts=[1 9]
2: 0.0199 -> A(0). counts=[13]
3: 0.0203 -> A(0). counts=[6 3]
4: 0.0195 -> B(1). counts=[12]
5: 0.0194 -> B(1). counts=[3 4]
6: 0.0202 -> A(0). counts=[11]
7: 0.0198 -> B(1). counts=[7]
8: 0.0201 -> A(0). counts=[5]
9: 0.0197 -> B(1). counts=[14]
