In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import datasets, transforms
from spikingjelly.activation_based import neuron, functional, surrogate, layer

# Define the convolutional spiking neural network (SNN)
class CSNN(nn.Module):
    """Convolutional spiking neural network with a 10-way output.

    The network is a single ``nn.Sequential`` (``conv_fc``) made of two
    Conv -> BatchNorm -> IFNode -> MaxPool stages followed by two
    Linear -> IFNode stages. All sub-modules are switched to multi-step
    mode (``step_mode='m'``), so ``conv_fc`` consumes a whole sequence shaped
    ``[T, N, C, H, W]`` in one call.

    The first ``Linear`` expects ``24 * 7 * 7`` features, i.e. a 7x7 map after
    two 2x2 poolings, which corresponds to single-channel 28x28 inputs.

    Args:
        T: Number of time steps the static input is repeated for. Must be
            at least 1.
        use_cupy: When true, switch the modules to the ``'cupy'`` backend via
            ``functional.set_backend``.

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """

    def __init__(self, T: int = 10, use_cupy: bool = False):
        super().__init__()
        # Fail early: with T == 0 the time axis is empty and the final mean
        # over it is undefined; a negative T only fails later inside
        # ``repeat`` with an unrelated-looking error.
        if T < 1:
            raise ValueError(f"T must be at least 1 time step, got {T!r}")
        self.T = T

        self.conv_fc = nn.Sequential(
            layer.Conv2d(1, 24, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(24),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),  # 28x28 -> 14x14

            layer.Conv2d(24, 24, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(24),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),  # 14x14 -> 7x7

            layer.Flatten(),
            layer.Linear(24 * 7 * 7, 24 * 4 * 4, bias=False),
            neuron.IFNode(surrogate_function=surrogate.ATan()),

            layer.Linear(24 * 4 * 4, 10, bias=False),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )

        # Multi-step mode: every sub-module processes the full time sequence.
        functional.set_step_mode(self, step_mode='m')

        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a static batch and average its output over time.

        Args:
            x: Input batch shaped ``[N, C, H, W]``.

        Returns:
            The output of ``conv_fc`` averaged over the leading time
            dimension (the per-class firing rate, hence the name ``fr``).
        """
        # Feed the same frame at every time step:
        # [N, C, H, W] -> [T, N, C, H, W]
        x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        x_seq = self.conv_fc(x_seq)
        fr = x_seq.mean(0)
        return fr

# Load the MNIST training split, converting each image to a tensor.
# NOTE(review): 'minst' in the cache path looks like a transposition of
# 'mnist'; it is kept verbatim because it decides where the data lives on disk.
train_dataset = datasets.MNIST(
    root='../.cache/minst',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches of 64 samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

# Create the network and its optimizer, using CUDA when it is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = CSNN()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)


# One pass over the training set, reporting the batch loss on the progress bar.
model.train()
print(model)
pbar = tqdm(train_loader)
for data, target in pbar:
    # Move the batch to the training device and build one-hot targets, since
    # the loss below compares against a 10-element output per sample.
    data = data.to(device)
    label = target.to(device)
    label_onehot = F.one_hot(label, num_classes=10).float()

    optimizer.zero_grad()
    out_fr = model(data)
    loss = F.mse_loss(out_fr, label_onehot)
    loss.backward()
    optimizer.step()

    # Reset the network's internal state before the next batch is processed.
    functional.reset_net(model)
    pbar.set_description(f'loss:{loss.item():.4f}')



CSNN(
  (conv_fc): Sequential(
    (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (2): IFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (4): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (5): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (6): IFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (8): Flatten(start_dim=1, end_dim=-1, st

loss:0.0114: 100%|██████████| 938/938 [00:42<00:00, 22.31it/s]
