In [1]:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, learning
from matplotlib import pyplot as plt
torch.manual_seed(0)  # fixed seed so the random numbers drawn below (e.g. the input spike train) are reproducible


def f_weight(x):
    """Weight-dependence function handed to ``STDPLearner`` as both ``f_pre`` and ``f_post``.

    Returns ``x`` clipped element-wise to the closed range [-1, 1].
    """
    return x.clamp(-1, 1.)

# Hyper-parameters: time constants of the pre-/post-synaptic traces, number of
# time-steps, batch size and learning rate.
tau_pre, tau_post = 2., 2.
T, N = 128, 1
lr = 0.01

# Minimal network: one synapse (1 input -> 1 output, no bias) driving an IFNode neuron.
net = nn.Sequential(layer.Linear(1, 1, bias=False), neuron.IFNode())

# Fixed initial weight of 0.4 (the returned tensor is what this cell displays).
nn.init.constant_(net[0].weight.data, 0.4)

tensor([[0.4000]])

In [2]:
# Plain SGD (no momentum): it simply applies whatever update is stored in .grad.
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.)

# Random input spike train of shape [T, N, 1]; every entry is 1 with probability 0.3.
in_spike = torch.rand(T, N, 1).gt(0.7).float()

# Single-step STDP learner watching the Linear synapse (net[0]) and its neuron (net[1]).
stdp_learner = learning.STDPLearner(
    step_mode='s',
    synapse=net[0],
    sn=net[1],
    tau_pre=tau_pre,
    tau_post=tau_post,
    f_pre=f_weight,
    f_post=f_weight,
)

In [3]:
# Simulate T single steps, applying one STDP weight update after every step.
# NOTE(review): the cell continues from the current neuron/learner state and weight, and
# it finally replaces `in_spike` ([T, N, 1]) with its squeezed [T] version. Re-running
# this cell alone therefore does not reproduce the same run; re-run the cells above first.
out_spike, trace_pre, trace_post, weight = [], [], [], []
with torch.no_grad():
    for t in range(T):
        optimizer.zero_grad()
        out_spike.append(net(in_spike[t]).squeeze())
        # Add the weight update learned by STDP onto the parameter's gradient, then let SGD apply it.
        stdp_learner.step(on_grad=True)
        optimizer.step()
        # Record the weight and both traces after this step.
        weight.append(net[0].weight.data.clone().squeeze())
        trace_pre.append(stdp_learner.trace_pre.squeeze())
        trace_post.append(stdp_learner.trace_post.squeeze())

in_spike = in_spike.squeeze()
# Turn each per-step list into one tensor with time along dim 0.
out_spike, trace_pre, trace_post, weight = (
    torch.stack(seq) for seq in (out_spike, trace_pre, trace_post, weight)
)

# 官方完整代码

In [12]:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, learning
from matplotlib import pyplot as plt

def f_weight(x):
    # Clip weights element-wise to [-1, 1]; used below as the learner's f_pre / f_post.
    return torch.clamp(x, min=-1, max=1.)

torch.manual_seed(0)
# Optional: the published figure uses the 'science' matplotlib style, which presumably
# needs the third-party SciencePlots package.
# plt.style.use(['science'])

if __name__ == '__main__':
    from pathlib import Path

    # Alternative weight-dependence functions for STDPLearner's f_pre / f_post.
    # NOTE: they are defined but NOT used below (the learner is built with f_weight).
    # With the default alpha=0. both return 1 element-wise, i.e. a weight-independent factor.
    def f_pre(x, w_min, alpha=0.):
        return (x - w_min) ** alpha

    def f_post(x, w_max, alpha=0.):
        return (w_max - x) ** alpha

    w_min, w_max = -1., 1.
    tau_pre, tau_post = 2., 2.
    N_in, N_out = 4, 3
    T = 128
    batch_size = 2
    lr = 0.01
    net = nn.Sequential(
        layer.Linear(N_in, N_out, bias=False),
        neuron.IFNode()
    )
    nn.init.constant_(net[0].weight.data, 0.4)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.)

    # Random input spikes; every entry is 1 with probability 0.3.
    in_spike = (torch.rand([T, batch_size, N_in]) > 0.7).float()
    learner = learning.STDPLearner(step_mode='s', synapse=net[0], sn=net[1],
                                   tau_pre=tau_pre, tau_post=tau_post,
                                   f_pre=f_weight, f_post=f_weight)

    out_spike = []
    trace_pre = []
    trace_post = []
    weight = []
    with torch.no_grad():
        for t in range(T):
            optimizer.zero_grad()
            out_spike.append(net(in_spike[t]))
            # Add the STDP update to the weight's gradient, apply it with SGD, and keep
            # the weight inside [w_min, w_max].
            learner.step(on_grad=True)
            optimizer.step()
            net[0].weight.data.clamp_(w_min, w_max)
            weight.append(net[0].weight.data.clone())
            trace_pre.append(learner.trace_pre)
            trace_post.append(learner.trace_post)

    out_spike = torch.stack(out_spike)   # [T, batch_size, N_out]
    trace_pre = torch.stack(trace_pre)   # [T, batch_size, N_in]
    trace_post = torch.stack(trace_post) # [T, batch_size, N_out]
    weight = torch.stack(weight)         # [T, N_out, N_in]

    t = torch.arange(0, T).float()

    # Plot a single synapse: input 0 -> output 0, first sample of the batch.
    in_spike = in_spike[:, 0, 0]
    out_spike = out_spike[:, 0, 0]
    trace_pre = trace_pre[:, 0, 0]
    trace_post = trace_post[:, 0, 0]
    weight = weight[:, 0, 0]

    # Draw on a fresh figure so no stale pyplot state leaks into the saved files.
    fig = plt.figure()
    cmap = plt.get_cmap('tab10')
    plt.subplot(5, 1, 1)
    plt.eventplot(t[in_spike == 1], lineoffsets=0, colors=cmap(0))
    plt.xlim(-0.5, T + 0.5)
    plt.ylabel('$s[i]$', rotation=0, labelpad=10)
    plt.xticks([])
    plt.yticks([])

    plt.subplot(5, 1, 2)
    plt.plot(t, trace_pre, c=cmap(1))
    plt.xlim(-0.5, T + 0.5)
    plt.ylabel('$tr_{pre}$', rotation=0)
    plt.yticks([trace_pre.min().item(), trace_pre.max().item()])
    plt.xticks([])

    plt.subplot(5, 1, 3)
    plt.eventplot(t[out_spike == 1], lineoffsets=0, colors=cmap(2))
    plt.xlim(-0.5, T + 0.5)
    plt.ylabel('$s[j]$', rotation=0, labelpad=10)
    plt.xticks([])
    plt.yticks([])

    plt.subplot(5, 1, 4)
    plt.plot(t, trace_post, c=cmap(3))
    plt.ylabel('$tr_{post}$', rotation=0)
    plt.yticks([trace_post.min().item(), trace_post.max().item()])
    plt.xlim(-0.5, T + 0.5)
    plt.xticks([])

    plt.subplot(5, 1, 5)
    plt.plot(t, weight, c=cmap(4))
    plt.xlim(-0.5, T + 0.5)
    plt.ylabel('$w[i][j]$', rotation=0)
    plt.yticks([weight.min().item(), weight.max().item()])
    plt.xlabel('time-step')

    fig.subplots_adjust(left=0.18)

    # Save BEFORE plt.show(): showing the figure (inline or non-interactive backends) closes
    # it, and a later plt.savefig would write an empty image. The output folder is created
    # first because it may not exist relative to the current working directory.
    out_dir = Path('./docs/source/_static/tutorials/activation_based/stdp')
    out_dir.mkdir(parents=True, exist_ok=True)
    for ext in ('png', 'svg', 'pdf'):
        fig.savefig(out_dir / f'stdp_trace.{ext}')

    plt.show()

tensor([1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1.,
        0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
        1., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0.])

# STDP（训练卷积层）与梯度下降（训练全连接层）结合

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from spikingjelly.activation_based import learning, layer, neuron, functional

# Sequence length T, batch size N and input image shape C x H x W.
T = 8
N = 2
C = 3
H = 32
W = 32
# Learning rate shared by both optimizers, and the STDP trace time constants.
lr = 0.1
tau_pre = 2.
tau_post = 100.
# Multi-step mode: the network is fed the whole [T, ...] sequence in one call.
step_mode = 'm'

# Seed the global torch RNG (as the single-synapse examples do with torch.manual_seed(0))
# so weight initialisation, the random input spikes and the random labels are reproducible.
torch.manual_seed(0)

In [14]:
def f_weight(x):
    """Clip ``x`` element-wise to [-1, 1]; passed to each ``STDPLearner`` as ``f_pre``/``f_post``."""
    return x.clamp(min=-1, max=1.)


def _conv_if_pool(in_channels, out_channels):
    """Return [3x3 conv (padding 1, no bias), IF neuron, 2x2 max-pool] as a list of modules."""
    return [
        layer.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
        neuron.IFNode(),
        layer.MaxPool2d(2, 2),
    ]


# Two conv stages (32x32 -> 16x16 -> 8x8 spatially) followed by a two-layer spiking classifier.
net = nn.Sequential(
    *_conv_if_pool(3, 16),
    *_conv_if_pool(16, 16),
    layer.Flatten(),
    layer.Linear(16 * 8 * 8, 64, bias=False),
    neuron.IFNode(),
    layer.Linear(64, 10, bias=False),
    neuron.IFNode(),
)

functional.set_step_mode(net, step_mode)

In [15]:
# Layer types whose weights are learned by STDP rather than by gradient descent.
instances_stdp = (layer.Conv2d, )

# One learner per STDP-trained synapse; the module directly after it (net[i + 1]) is
# taken as its spiking neuron.
stdp_learners = [
    learning.STDPLearner(step_mode=step_mode, synapse=module, sn=net[i + 1],
                         tau_pre=tau_pre, tau_post=tau_post,
                         f_pre=f_weight, f_post=f_weight)
    for i, module in enumerate(net)
    if isinstance(module, instances_stdp)
]

In [16]:
# Partition the parameters into two disjoint groups: weights of STDP-trained layers and everything else.
params_stdp = [p for m in net.modules() if isinstance(m, instances_stdp) for p in m.parameters()]

params_stdp_set = set(params_stdp)
params_gradient_descent = [p for p in net.parameters() if p not in params_stdp_set]

# Adam trains the rest from the loss gradient; momentum-free SGD just applies the update placed in .grad.
optimizer_gd = Adam(params_gradient_descent, lr=lr)
optimizer_stdp = SGD(params_stdp, lr=lr, momentum=0.)

In [17]:
# Random binary input of shape [T, N, C, H, W] (each element is 1 with probability 0.5)
# and one random class label in [0, 10) per sample.
x_seq = (torch.rand(T, N, C, H, W) > 0.5).to(torch.float32)
target = torch.randint(0, 10, (N,))

In [18]:
# Display the randomly drawn class labels (one per sample in the batch).
target

tensor([6, 8])

In [None]:
# Clear both parameter groups, run the multi-step network and average the output spikes over
# time (dim 0) to obtain the values fed to cross-entropy.
optimizer_gd.zero_grad()
optimizer_stdp.zero_grad()
y = net(x_seq).mean(0)
loss = F.cross_entropy(y, target)
loss.backward()
# backward() also filled .grad of the STDP-trained conv weights; drop those so STDP alone updates them.
optimizer_stdp.zero_grad()

# Each learner adds its STDP weight update onto its synapse's gradient.
for learner in stdp_learners:
    learner.step(on_grad=True)

optimizer_gd.step()
optimizer_stdp.step()

# Reset the network's neuron states and every STDP learner before the next batch.
functional.reset_net(net)
for learner in stdp_learners:
    learner.reset()

# 完整的示例代码

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from spikingjelly.activation_based import learning, layer, neuron, functional

# Sequence length T, batch size N and input image shape C x H x W.
T = 8
N = 2
C = 3
H = 32
W = 32
# Learning rate (used by both optimizers) and STDP trace time constants.
lr = 0.1
tau_pre = 2.
tau_post = 100.
# Multi-step mode: the network consumes the whole [T, ...] sequence in one call.
step_mode = 'm'

# Seed the global torch RNG (consistent with the earlier examples' torch.manual_seed(0)) so
# weight initialisation, the random input spikes and the random labels are reproducible.
torch.manual_seed(0)

def f_weight(x):
    """Weight-dependence function for the STDP learners: ``x`` clipped element-wise to [-1, 1]."""
    lower, upper = -1, 1.
    return torch.clamp(x, lower, upper)


# Shared settings of both 3x3 convolutions: stride 1 and padding 1 keep the spatial size.
_conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)

# conv -> IF -> pool twice, then a two-layer fully connected spiking classifier.
net = nn.Sequential(
    layer.Conv2d(3, 16, **_conv_kwargs),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Conv2d(16, 16, **_conv_kwargs),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Flatten(),
    # 32x32 input halved by the two 2x2 poolings -> 16 channels of 8x8
    layer.Linear(16 * 8 * 8, 64, bias=False),
    neuron.IFNode(),
    layer.Linear(64, 10, bias=False),
    neuron.IFNode(),
)

# Put every module into the configured step mode ('m' above).
functional.set_step_mode(net, step_mode)

# Layer types trained by STDP instead of gradient descent.
instances_stdp = (layer.Conv2d, )

# Attach a learner to every such layer; the module right after it (net[idx + 1]) is used
# as the spiking neuron that emits the post-synaptic spikes.
stdp_learners = []
for idx, module in enumerate(net):
    if not isinstance(module, instances_stdp):
        continue
    stdp_learners.append(learning.STDPLearner(
        step_mode=step_mode, synapse=module, sn=net[idx + 1],
        tau_pre=tau_pre, tau_post=tau_post,
        f_pre=f_weight, f_post=f_weight,
    ))


# Weights of STDP-trained layers go to momentum-free SGD (it only applies the update the
# learners write into .grad); all other parameters are optimised by Adam on the loss.
params_stdp = []
for module in net.modules():
    if isinstance(module, instances_stdp):
        params_stdp.extend(module.parameters())

params_stdp_set = set(params_stdp)
params_gradient_descent = list(filter(lambda p: p not in params_stdp_set, net.parameters()))

optimizer_gd = Adam(params_gradient_descent, lr=lr)
optimizer_stdp = SGD(params_stdp, lr=lr, momentum=0.)



# One random batch: binary inputs (1 with probability 0.5) of shape [T, N, C, H, W]
# and a random label in [0, 10) for each sample.
x_seq = torch.rand(T, N, C, H, W).gt(0.5).float()
target = torch.randint(0, 10, (N,))

# Gradient-descent part: clear both groups, run the network over all T steps and average the
# output spikes over time (dim 0) before the cross-entropy loss.
optimizer_gd.zero_grad()
optimizer_stdp.zero_grad()

y = net(x_seq).mean(0)
loss = F.cross_entropy(y, target)
loss.backward()

# backward() also produced gradients for the STDP-trained conv weights; discard them so
# those weights are changed by STDP only.
optimizer_stdp.zero_grad()

# STDP part: each learner adds its weight update onto its synapse's gradient.
for learner in stdp_learners:
    learner.step(on_grad=True)

optimizer_gd.step()
optimizer_stdp.step()

# Reset the network's neuron states and every STDP learner before the next batch.
functional.reset_net(net)
for learner in stdp_learners:
    learner.reset()