In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from ehc_sn import config
from ehc_sn.parameters import Plasticity, Synapses
from norse.torch.functional import stdp

# Select the CPU as the ehc_sn device for this notebook; presumably ehc_sn
# modules read `config.device` when placing tensors (TODO confirm in ehc_sn.config).
config.device = "cpu"

In [2]:
from ipywidgets import interact, IntSlider, FloatSlider
from functools import partial

# Rebind both slider factories so that every slider built below reports its
# value only when the user releases it (`continuous_update=False`), not on
# every intermediate drag position.
FloatSlider, IntSlider = (
    partial(slider_cls, continuous_update=False)
    for slider_cls in (FloatSlider, IntSlider)
)

In [3]:
def dw_pre(delay, p_stdp, dt=0.001):  # Post -> Pre: LTD
    """Compute the weight change for a pre spike that follows a post spike (LTD).

    A single post spike is first pushed through ``stdp.stdp_step_linear`` to
    charge the post-synaptic trace. Then, for each of ``delay`` steps, a pre
    spike is probed against a *clone* of the current traces, and the real
    traces are advanced by one silent step.

    Probing a clone matters: norse's ``stdp_step_linear`` updates the traces on
    the state object it is handed. Probing the live state would decay it twice
    per iteration, which halves the apparent time constant, and would leak the
    probe spike into the swept trace.

    Args:
        delay: Number of time steps to sweep (length of the returned arrays).
        p_stdp: STDP parameters forwarded to ``stdp.stdp_step_linear``.
        dt: Integration step forwarded to ``stdp.stdp_step_linear``
            (default 0.001).

    Returns:
        ``(time, dw)``:
            ``time`` is ``-dt * [0, 1, ..., delay - 1]``. It is negative
            because the pre spike comes after the post spike.
            ``dw`` is a float array holding the weight returned by each probe
            step, starting from ``w = 0``.
    """
    pre_spike = torch.tensor([[1.0]]), torch.tensor([[0.0]])
    post_spike = torch.tensor([[0.0]]), torch.tensor([[1.0]])
    no_spike = torch.tensor([[0.0]]), torch.tensor([[0.0]])

    # Condition the traces with one post spike at t = 0.
    state = stdp.STDPState(t_pre=torch.tensor([[0.0]]), t_post=torch.tensor([[0.0]]))
    _, state = stdp.stdp_step_linear(*post_spike, torch.zeros(1, 1), state, p_stdp, dt=dt)

    time = -dt * np.arange(delay)
    dw = np.empty(len(time))
    for k in range(len(time)):
        # Probe on a copy so the measurement does not advance/pollute `state`.
        probe = stdp.STDPState(t_pre=state.t_pre.clone(), t_post=state.t_post.clone())
        w_new, _ = stdp.stdp_step_linear(*pre_spike, torch.zeros(1, 1), probe, p_stdp, dt=dt)
        dw[k] = w_new.item()
        # Advance the real traces by exactly one silent step.
        _, state = stdp.stdp_step_linear(*no_spike, torch.zeros(1, 1), state, p_stdp, dt=dt)
    return time, dw

In [4]:
def dw_post(delay, p_stdp, dt=0.001):  # Pre -> Post: LTP
    """Compute the weight change for a post spike that follows a pre spike (LTP).

    A single pre spike is first pushed through ``stdp.stdp_step_linear`` to
    charge the pre-synaptic trace. Then, for each of ``delay`` steps, a post
    spike is probed against a *clone* of the current traces, and the real
    traces are advanced by one silent step.

    Probing a clone matters: norse's ``stdp_step_linear`` updates the traces on
    the state object it is handed. Probing the live state would decay it twice
    per iteration, which halves the apparent time constant, and would leak the
    probe spike into the swept trace.

    Args:
        delay: Number of time steps to sweep (length of the returned arrays).
        p_stdp: STDP parameters forwarded to ``stdp.stdp_step_linear``.
        dt: Integration step forwarded to ``stdp.stdp_step_linear``
            (default 0.001).

    Returns:
        ``(time, dw)``:
            ``time`` is ``dt * [0, 1, ..., delay - 1]``.
            ``dw`` is a float array holding the weight returned by each probe
            step, starting from ``w = 0``.
    """
    pre_spike = torch.tensor([[1.0]]), torch.tensor([[0.0]])
    post_spike = torch.tensor([[0.0]]), torch.tensor([[1.0]])
    no_spike = torch.tensor([[0.0]]), torch.tensor([[0.0]])

    # Condition the traces with one pre spike at t = 0.
    state = stdp.STDPState(t_pre=torch.tensor([[0.0]]), t_post=torch.tensor([[0.0]]))
    _, state = stdp.stdp_step_linear(*pre_spike, torch.zeros(1, 1), state, p_stdp, dt=dt)

    time = dt * np.arange(delay)
    dw = np.empty(len(time))
    for k in range(len(time)):
        # Probe on a copy so the measurement does not advance/pollute `state`.
        probe = stdp.STDPState(t_pre=state.t_pre.clone(), t_post=state.t_post.clone())
        w_new, _ = stdp.stdp_step_linear(*post_spike, torch.zeros(1, 1), probe, p_stdp, dt=dt)
        dw[k] = w_new.item()
        # Advance the real traces by exactly one silent step.
        _, state = stdp.stdp_step_linear(*no_spike, torch.zeros(1, 1), state, p_stdp, dt=dt)
    return time, dw

In [5]:
# The tau sliders start at 0.01 rather than 0.0: a zero time constant is
# degenerate and not a meaningful STDP window to plot.
@interact(
    gain_ltp=FloatSlider(min=0.0, max=1.0, step=0.01, value=1.0),
    gain_ltd=FloatSlider(min=0.0, max=1.0, step=0.01, value=1.0),
    tau_ltp=FloatSlider(min=0.01, max=1.0, step=0.01, value=0.01),
    tau_ltd=FloatSlider(min=0.01, max=1.0, step=0.01, value=0.02),
)
def experiment(gain_ltp, gain_ltd, tau_ltp, tau_ltd):
    """Plot the LTD and LTP weight-change windows for the chosen STDP settings.

    Builds STDP parameters via ``Synapses(...).stdp_parameters()``. It then
    draws ``dw_pre`` (LTD, negative delays) and ``dw_post`` (LTP, positive
    delays) over ``n_steps`` steps on each side of zero.

    Args:
        gain_ltp: ``Plasticity`` gain for potentiation.
        gain_ltd: ``Plasticity`` gain for depression.
        tau_ltp: ``Plasticity`` time constant for potentiation.
            Presumably in seconds — confirm in ehc_sn.
        tau_ltd: ``Plasticity`` time constant for depression.
            Presumably in seconds — confirm in ehc_sn.
    """
    n_steps = 100  # steps of 0.001 each side of zero, i.e. a 100 ms window

    p = Synapses(
        input_size=1, w_min=-10.0, w_max=10.0,
        ltp=Plasticity(gain=gain_ltp, tau=tau_ltp),
        ltd=Plasticity(gain=gain_ltd, tau=tau_ltd),
    ).stdp_parameters() # fmt: skip

    t_ltd, dw_ltd = dw_pre(n_steps, p)
    t_ltp, dw_ltp = dw_post(n_steps, p)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set_title("Model STDP parameters")

    # dw_pre/dw_post return delays as step * 0.001 (seconds); the axis is in ms.
    ax.plot(np.asarray(t_ltd) * 1e3, dw_ltd, "-", label="LTD")
    ax.plot(np.asarray(t_ltp) * 1e3, dw_ltp, "-", label="LTP")
    ax.axhline(y=0, color="gray", linestyle="--", linewidth=0.5)
    ax.axvline(x=0, color="gray", linestyle="--", linewidth=0.5)

    ax.set_xlabel("delay [ms]")
    ax.set_ylabel("weight change")
    ax.legend()
    fig.tight_layout()
    plt.show()

interactive(children=(FloatSlider(value=1.0, continuous_update=False, description='gain_ltp', max=1.0, step=0.…