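Analyze STDP update dynamics: how the synaptic trace depends on relative spike timing (dt) for excitatory vs. inhibitory synapses, for different learning rates, and for different STDP window sizes.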

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from spikey.synapse import *


def onehot(value, buckets):
    """Return a length-``buckets`` float vector with a single 1 at ``value``.

    Parameters
    ----------
    value : int
        Index of the hot element. A value outside ``[0, buckets)`` yields an
        all-zero vector (no event) rather than relying on NumPy indexing,
        where ``-1`` would silently mark the *last* bucket.
    buckets : int
        Length of the returned vector.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(buckets,)``.
    """
    output = np.zeros(buckets)
    # Explicit range check: callers compute indices like ``LENGTH - i - 1``,
    # which reaches -1 one step past the window; that must mean "no spike",
    # not "spike at the final step".
    if 0 <= value < buckets:
        output[value] = 1
    return output


# Synapse class under study; RLSTDPET comes from the `spikey.synapse` star import.
# Every cell below instantiates this alias, so swap it here to analyze another synapse type.
SYNAPSE = RLSTDPET

In [None]:
## Excitatory vs Inhibitory
# For each sign convention, sweep the timing offset i between a spike of
# neuron 1 ("b") and a fixed spike of neuron 0 ("a") on the last step. For each
# curve, b also carries a second fixed spike j steps before the end. The
# resulting trace is plotted against dt (negative dt: trace[0, 1]; positive
# dt: trace[1, 0]).
config = {
    'n_neurons': 2,
    'n_inputs': 0,
    'stdp_window': 5,
    'learning_rate': 1,
    'max_weight': 1,
    'trace_decay': None,
}
synapses = SYNAPSE(np.zeros((config['n_inputs'] + config['n_neurons'], config['n_neurons'])), **config)

# Spike log spans the whole STDP window plus the current step.
LENGTH = config['stdp_window'] + 1

for label, signs in (('Excitatory', [1, 1]), ('Inhibitory', [-1, -1])):
    INHIBITORIES = np.array(signs)

    for j in range(LENGTH):
        X, X_minus, Y, Y_minus = [], [], [], []

        for i in range(LENGTH + 1):
            # NOTE(review): assumes reset() clears the trace between trials — confirm.
            synapses.reset()

            # Neuron 0 ("a") spikes on the final step only.
            a = np.zeros(LENGTH)
            a[-1] = 1

            # Neuron 1 ("b") spikes i steps before the end. i == LENGTH falls
            # outside the log, so no spike is placed. A negative index would
            # wrap to the final step and plot a dt = 0 pairing at the curve's edge.
            b = np.zeros(LENGTH)
            if i < LENGTH:
                b[LENGTH - i - 1] = 1
            # Second, fixed spike of b, j steps before the end.
            b[LENGTH - j - 1] = 1

            # Shape (LENGTH, 2): one row per time step, one column per neuron.
            full_spike_log = np.array([a, b]).T

            # Private API: applies STDP to the log and updates synapses.trace.
            synapses._apply_stdp(full_spike_log, INHIBITORIES)

            X_minus.append(-i)
            # presumably the 0 -> 1 entry (a as pre, b as post) — verify index order
            Y_minus.append(synapses.trace[0, 1])
            X.append(i)
            Y.append(synapses.trace[1, 0])

        # Concatenate into one curve running from dt = -LENGTH to dt = +LENGTH.
        plt.plot(X_minus[::-1] + X, Y_minus[::-1] + Y, label=j)

    plt.title(f"{label} Multiple Spike Effect")
    plt.xlabel('dt')
    plt.ylabel('trace')
    plt.legend()
    plt.show()

In [None]:
## Max Update
# Sweep learning_rate over 0.0 .. 1.0 and plot the single-pairing trace against
# dt (negative dt: trace[0, 1]; positive dt: trace[1, 0]).
INHIBITORIES = np.array([1, 1])

# The stop value 1.1 makes the range include 1.0.
for j in np.arange(0, 1.1, .2):
    config = {
        'n_neurons': 2,
        'n_inputs': 0,
        'stdp_window': 5,
        'learning_rate': j,
        'max_weight': 1,
        'trace_decay': None,
    }
    synapses = SYNAPSE(np.zeros((config['n_inputs'] + config['n_neurons'], config['n_neurons'])), **config)

    X, X_minus, Y, Y_minus = [], [], [], []

    # Spike log spans the whole STDP window plus the current step.
    LENGTH = config['stdp_window'] + 1

    for i in range(LENGTH + 1):
        # NOTE(review): assumes reset() clears the trace between trials — confirm.
        synapses.reset()

        # Neuron 0 ("a") spikes on the final step only.
        a = np.zeros(LENGTH)
        a[-1] = 1

        # Neuron 1 ("b") spikes i steps before the end. i == LENGTH falls
        # outside the log, so there is no spike. A negative index would wrap to
        # the final step.
        b = np.zeros(LENGTH)
        if i < LENGTH:
            b[LENGTH - i - 1] = 1

        # Shape (LENGTH, 2): one row per time step, one column per neuron.
        full_spike_log = np.array([a, b]).T

        # Private API: applies STDP to the log and updates synapses.trace.
        synapses._apply_stdp(full_spike_log, INHIBITORIES)

        X_minus.append(-i)
        Y_minus.append(synapses.trace[0, 1])
        X.append(i)
        Y.append(synapses.trace[1, 0])

    # np.arange accumulates float error (e.g. 0.6000000000000001); round the
    # value for a readable legend.
    plt.plot(X_minus[::-1] + X, Y_minus[::-1] + Y, label=f"{j:.1f}")

plt.title('Max Update')
plt.xlabel('dt')
plt.ylabel('trace')
plt.legend()
# Close out this figure so a later cell does not draw onto it when run as a script.
plt.show()

In [None]:
## Window
# Sweep stdp_window over 1 .. 9 and plot the single-pairing trace against dt
# (negative dt: trace[0, 1]; positive dt: trace[1, 0]).
INHIBITORIES = np.array([1, 1])

for j in range(1, 10):
    config = {
        'n_neurons': 2,
        'n_inputs': 0,
        'stdp_window': j,
        'learning_rate': 1,
        'max_weight': 1,
        'trace_decay': None,
    }
    synapses = SYNAPSE(np.zeros((config['n_inputs'] + config['n_neurons'], config['n_neurons'])), **config)

    X, X_minus, Y, Y_minus = [], [], [], []

    # Spike log spans the whole STDP window plus the current step.
    LENGTH = config['stdp_window'] + 1

    for i in range(LENGTH + 1):
        # NOTE(review): assumes reset() clears the trace between trials — confirm.
        synapses.reset()

        # Neuron 0 ("a") spikes on the final step only.
        a = np.zeros(LENGTH)
        a[-1] = 1

        # Neuron 1 ("b") spikes i steps before the end. i == LENGTH falls
        # outside the log, so there is no spike. A negative index would wrap to
        # the final step.
        b = np.zeros(LENGTH)
        if i < LENGTH:
            b[LENGTH - i - 1] = 1

        # Shape (LENGTH, 2): one row per time step, one column per neuron.
        full_spike_log = np.array([a, b]).T

        # Private API: applies STDP to the log and updates synapses.trace.
        synapses._apply_stdp(full_spike_log, INHIBITORIES)

        X_minus.append(-i)
        Y_minus.append(synapses.trace[0, 1])
        X.append(i)
        Y.append(synapses.trace[1, 0])

    plt.plot(X_minus[::-1] + X, Y_minus[::-1] + Y, label=j)

plt.title('STDP Window')
plt.xlabel('dt')
plt.ylabel('trace')
plt.legend()
# Close out this figure, consistent with the first cell.
plt.show()