In [None]:
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
from norse.torch.functional import stdp

In [2]:
from ipywidgets import interact, IntSlider, FloatSlider
from functools import partial

# Shadow the ipywidgets slider classes with pre-configured factories: every
# slider created below gets continuous_update=False, so (per ipywidgets) its
# value is only sent when the handle is released rather than on every drag
# step. The original classes are captured before the names are rebound.
IntSlider, FloatSlider = (
    partial(slider_cls, continuous_update=False)
    for slider_cls in (IntSlider, FloatSlider)
)

In [3]:
def dw_pre(delay, p_stdp, dt=0.001):
    """Compute the weight change for a pre spike that follows a post spike.

    A post-synaptic spike is applied first. Then, for each of ``delay`` steps,
    the weight change that a pre-synaptic spike *at that step* would cause is
    recorded, while the traces keep evolving with no spikes in between.

    Parameters
    ----------
    delay : int
        Number of time steps (delays) to probe.
    p_stdp
        STDP parameters forwarded unchanged to ``stdp.stdp_step_linear``.
    dt : float, optional
        Integration time step, default ``0.001``. It is used both for the
        simulation and for the returned time axis, so the two cannot disagree.

    Returns
    -------
    time : numpy.ndarray
        ``-dt * k`` for ``k = 0 .. delay - 1``. Values are non-positive
        because the probed pre spike comes after the post spike.
    dw : list of torch.Tensor
        One single-element tensor per delay: the weight obtained after one
        pre-spike step starting from ``w = 0``, i.e. the weight change itself.
    """
    # (z_pre, z_post) input pairs for a single synapse.
    pre_spike = torch.tensor([[1.0]]), torch.tensor([[0.0]])
    post_spike = torch.tensor([[0.0]]), torch.tensor([[1.0]])
    no_spike = torch.tensor([[0.0]]), torch.tensor([[0.0]])
    w = torch.tensor([[0.0]], requires_grad=False)
    state = stdp.STDPState(t_pre=torch.tensor([[0.0]]), t_post=torch.tensor([[0.0]]))
    # The reference post spike; only the resulting trace state is kept.
    _, state = stdp.stdp_step_linear(*post_spike, w, state, p_stdp, dt=dt)
    time = -dt * np.arange(0, delay)
    dw = []
    for _ in time:
        # Start every probe from a zero weight so the result is the change alone.
        w = torch.tensor([[0.0]], requires_grad=False)
        # Probe: a pre spike at this step. Its state is discarded so later
        # probes only see the post spike's decaying trace.
        dw.append(stdp.stdp_step_linear(*pre_spike, w, state, p_stdp, dt=dt)[0][0])
        # Advance the traces by one silent step to reach the next delay.
        state = stdp.stdp_step_linear(*no_spike, w, state, p_stdp, dt=dt)[1]
    return time, dw

In [4]:
def dw_post(delay, p_stdp, dt=0.001):
    """Compute the weight change for a post spike that follows a pre spike.

    A pre-synaptic spike is applied first. Then, for each of ``delay`` steps,
    the weight change that a post-synaptic spike *at that step* would cause is
    recorded, while the traces keep evolving with no spikes in between.

    Parameters
    ----------
    delay : int
        Number of time steps (delays) to probe.
    p_stdp
        STDP parameters forwarded unchanged to ``stdp.stdp_step_linear``.
    dt : float, optional
        Integration time step, default ``0.001``. It is used both for the
        simulation and for the returned time axis, so the two cannot disagree.

    Returns
    -------
    time : numpy.ndarray
        ``dt * k`` for ``k = 0 .. delay - 1``. Values are non-negative
        because the probed post spike comes after the pre spike.
    dw : list of torch.Tensor
        One single-element tensor per delay: the weight obtained after one
        post-spike step starting from ``w = 0``, i.e. the weight change itself.
    """
    # (z_pre, z_post) input pairs for a single synapse.
    pre_spike = torch.tensor([[1.0]]), torch.tensor([[0.0]])
    post_spike = torch.tensor([[0.0]]), torch.tensor([[1.0]])
    no_spike = torch.tensor([[0.0]]), torch.tensor([[0.0]])
    w = torch.tensor([[0.0]], requires_grad=False)
    state = stdp.STDPState(t_pre=torch.tensor([[0.0]]), t_post=torch.tensor([[0.0]]))
    # The reference pre spike; only the resulting trace state is kept.
    _, state = stdp.stdp_step_linear(*pre_spike, w, state, p_stdp, dt=dt)
    time = dt * np.arange(0, delay)
    dw = []
    for _ in time:
        # Start every probe from a zero weight so the result is the change alone.
        w = torch.tensor([[0.0]], requires_grad=False)
        # Probe: a post spike at this step. Its state is discarded so later
        # probes only see the pre spike's decaying trace.
        dw.append(stdp.stdp_step_linear(*post_spike, w, state, p_stdp, dt=dt)[0][0])
        # Advance the traces by one silent step to reach the next delay.
        state = stdp.stdp_step_linear(*no_spike, w, state, p_stdp, dt=dt)[1]
    return time, dw

In [5]:
@interact(
    a_pre=FloatSlider(min=0.0, max=10.0, step=0.01, value=1.0),
    a_post=FloatSlider(min=0.0, max=10.0, step=0.01, value=1.0),
    lgt_pre=FloatSlider(min=-3.0, max=1.0, step=0.01, value=-1.50),
    lgt_post=FloatSlider(min=-3.0, max=1.0, step=0.01, value=-1.50),
    lgη_plus=FloatSlider(min=-8.0, max=0.0, step=0.01, value=-3.0),
    lgη_minus=FloatSlider(min=-8.0, max=0.0, step=0.01, value=-3.0),
    mu=FloatSlider(min=0.0, max=1.0, step=0.01, value=0.5),
)
def experiment(a_pre, a_post, lgt_pre, lgt_post, lgη_plus, lgη_minus, mu):
    """Plot the weight change of four STDP algorithms against spike delay.

    Re-run by ``interact`` with the current slider values. The ``lg*``
    sliders are base-10 exponents: ``tau_*_inv = 10**-lgt_*`` and
    ``eta_* = 10**lgη_*``.

    Left panel: curves from ``dw_pre`` (non-positive delays). Right panel:
    curves from ``dw_post`` (non-negative delays). Both panels are given the
    same y-range so their magnitudes can be compared directly.
    """
    n_delays = 100  # number of delay steps probed per curve
    min_half_range = 1e-5  # keep a visible y-range even if every curve is ~0

    gen_params = partial(
        stdp.STDPParameters,
        a_pre=a_pre,
        a_post=a_post,
        tau_pre_inv=10.0**-lgt_pre,
        tau_post_inv=10.0**-lgt_post,
        w_max=10.0,
        w_min=-10.0,
        eta_plus=10.0**lgη_plus,
        eta_minus=10.0**lgη_minus,
        mu=mu,
    )
    # (stdp_algorithm, line style, legend label), in legend order.
    algorithms = [
        ("additive", "-", "Additive"),
        ("additive_step", "--", "Additive step"),
        ("multiplicative_pow", ":", "Multiplicative pow"),
        ("multiplicative_relu", "-.", "Multiplicative relu"),
    ]
    curves = [
        (gen_params(stdp_algorithm=algorithm), style, label)
        for algorithm, style, label in algorithms
    ]

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle("STDP  algorithms")

    panels = (
        (ax[0], dw_pre, "Weight change for pre spike"),
        (ax[1], dw_post, "Weight change for post spike"),
    )
    for axis, dw_fn, title in panels:
        axis.set_title(title)
        axis.set_xlabel("delay [ms]")
        for params, style, label in curves:
            time, dw = dw_fn(n_delays, params)
            # dw_pre/dw_post space delays by dt = 0.001 (1 ms, assuming
            # norse's seconds convention); convert to ms to match the label.
            axis.plot(np.asarray(time) * 1e3, dw, style, label=label)
        axis.axhline(y=0, color="gray", linestyle="--", linewidth=0.5)
        axis.axvline(x=0, color="gray", linestyle="--", linewidth=0.5)
        axis.legend()
    ax[0].set_ylabel("weight change")
    ax[1].set_yticks([])  # the right panel reuses the left panel's y-scale

    # Union of both panels' y-ranges, widened to at least ±min_half_range.
    y_low = min(-min_half_range, *(axis.get_ylim()[0] for axis in ax))
    y_high = max(min_half_range, *(axis.get_ylim()[1] for axis in ax))
    for axis in ax:
        axis.set_ylim(y_low, y_high)
    fig.tight_layout()
    plt.show()

interactive(children=(FloatSlider(value=1.0, continuous_update=False, description='a_pre', max=10.0, step=0.01…