In [None]:
# for colab use
# %pip install spikingjelly

Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl.metadata (15 kB)
Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl (437 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.6/437.6 kB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: spikingjelly
Successfully installed spikingjelly-0.0.0.0.14


In [2]:
# Debug: expose only the first GPU to this process. CUDA reads this variable when it
# is initialized, so it must be set before torch touches CUDA (torch is imported below).
import os
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
from spikingjelly.activation_based import neuron, functional, layer, surrogate

# Environment report: PyTorch version and which GPU (if any) the kernel can see.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Device count should be 1, since CUDA_VISIBLE_DEVICES is pinned to "0" above.
    print(
        f"Current device: {torch.cuda.get_device_name(0)}",
        f"Device count: {torch.cuda.device_count()}",
        sep="\n",
    )

PyTorch version: 2.9.0+cu126
CUDA available: True
Current device: Tesla T4
Device count: 1


In [4]:

# --- Helper: Heaviside step function ---
@torch.jit.script
def heaviside(x: torch.Tensor):
    """
    Step function used in the forward pass: 1 where x >= 0, otherwise 0.

    The result has the same dtype and device as ``x`` (the same convention as
    SpikingJelly's own ``surrogate.heaviside``), so half/double models get spikes
    in their own precision instead of always float32.
    """
    return (x >= 0).to(x)

# =========================================================
# 1. SuperSpike surrogate
# Formula: h(x) = 1 / (beta * |x| + 1)^2
# =========================================================
class SuperSpikeFunction(torch.autograd.Function):
    """Heaviside spike forward; SuperSpike surrogate derivative backward.

    ``alpha`` plays the role of beta in h(x) = 1 / (beta * |x| + 1)^2.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # State is only needed if autograd may call backward for this input.
        if x.requires_grad:
            ctx.alpha = alpha
            ctx.save_for_backward(x)
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        (membrane,) = ctx.saved_tensors
        denom = ctx.alpha * membrane.abs() + 1.0
        surrogate_grad = 1.0 / (denom * denom)
        # alpha is a plain hyperparameter, so it gets no gradient.
        return grad_output * surrogate_grad, None

class SuperSpike(nn.Module):
    """SuperSpike surrogate-gradient spike function as an nn.Module.

    :param alpha: the beta in h(x) = 1 / (beta * |x| + 1)^2; larger means a sharper gradient
    :param spiking: when False, a plain Heaviside step is used and no surrogate gradient flows
    """

    def __init__(self, alpha=100.0, spiking=True):
        super().__init__()
        self.alpha = alpha
        self.spiking = spiking

    def forward(self, x):
        if not self.spiking:
            return heaviside(x)
        return SuperSpikeFunction.apply(x, self.alpha)

# =========================================================
# 2. Sigmoid' surrogate
# Formula: h(x) = s(x)(1 - s(x)), where s(x) = sigmoid(beta * x)
# =========================================================
class SigmoidDerivativeFunction(torch.autograd.Function):
    """Heaviside spike forward; sigmoid-derivative surrogate backward.

    ``alpha`` plays the role of beta in s(x) = sigmoid(beta * x). Note the surrogate
    is s(1 - s) without a beta factor, so its peak value (at x = 0) is 0.25.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # State is only needed if autograd may call backward for this input.
        if x.requires_grad:
            ctx.alpha = alpha
            ctx.save_for_backward(x)
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        (membrane,) = ctx.saved_tensors
        s = torch.sigmoid(ctx.alpha * membrane)
        # alpha is a plain hyperparameter, so it gets no gradient.
        return grad_output * s * (1.0 - s), None

class SigmoidDerivative(nn.Module):
    """Sigmoid' surrogate-gradient spike function as an nn.Module.

    :param alpha: the beta in s(x) = sigmoid(beta * x)
    :param spiking: when False, a plain Heaviside step is used and no surrogate gradient flows
    """

    def __init__(self, alpha=4.0, spiking=True):
        super().__init__()
        self.alpha = alpha
        self.spiking = spiking

    def forward(self, x):
        return SigmoidDerivativeFunction.apply(x, self.alpha) if self.spiking else heaviside(x)

# =========================================================
# 3. Esser et al. surrogate
# Formula: h(x) = max(0, 1.0 - beta * |x|)
# =========================================================
class EsserFunction(torch.autograd.Function):
    """Heaviside spike forward; triangular (Esser et al.) surrogate backward.

    ``alpha`` plays the role of beta: the gradient is nonzero only for |x| < 1 / beta.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # State is only needed if autograd may call backward for this input.
        if x.requires_grad:
            ctx.alpha = alpha
            ctx.save_for_backward(x)
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        (membrane,) = ctx.saved_tensors
        window = (1.0 - ctx.alpha * membrane.abs()).clamp_min(0.0)
        # alpha is a plain hyperparameter, so it gets no gradient.
        return grad_output * window, None

class Esser(nn.Module):
    """Esser et al. surrogate-gradient spike function as an nn.Module.

    :param alpha: the beta in max(0, 1 - beta * |x|); typically 1.0 or larger
    :param spiking: when False, a plain Heaviside step is used and no surrogate gradient flows
    """

    def __init__(self, alpha=1.0, spiking=True):
        super().__init__()
        self.alpha = alpha
        self.spiking = spiking

    def forward(self, x):
        if not self.spiking:
            return heaviside(x)
        return EsserFunction.apply(x, self.alpha)

In [5]:
# ----------------------------------------
# 1. Hyperparameters and settings
# ----------------------------------------

# Simulation / optimization settings
T = 8             # number of simulation time steps (the key SNN parameter)
BATCH_SIZE = 64
EPOCHS = 10       # kept small for a quick demo
LR = 1e-3         # learning rate

# Surrogate-gradient hyperparameter (noted in the original as the value prescribed by the paper).
# NOTE(review): the surrogate configurations built later in this notebook pass their own
# hard-coded alpha values and do not read BETA — confirm whether BETA is meant to be used.
BETA = 10.0

DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

print(f"--- 实验设置 ---")
print(f"设备 (DEVICE): {DEVICE}")
print(f"仿真时长 (T): {T}")
print(f"批大小 (BATCH_SIZE): {BATCH_SIZE}")
print(f"训练轮数 (EPOCHS): {EPOCHS}")
print(f"------------------\n")

--- 实验设置 ---
设备 (DEVICE): cuda:0
仿真时长 (T): 8
批大小 (BATCH_SIZE): 64
训练轮数 (EPOCHS): 10
------------------



In [6]:
# ----------------------------------------
# 2. Load and preprocess CIFAR10
# ----------------------------------------
print("正在加载 CIFAR10 数据集...")

# Per-channel normalization statistics.
# NOTE(review): this std tuple is the widely copied (0.2023, 0.1994, 0.2010); the per-pixel
# std of the CIFAR10 training set is commonly reported as about (0.247, 0.243, 0.261) —
# confirm which one is intended.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Evaluation pipeline: PIL image -> tensor -> normalized tensor.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std)
])

# Training pipeline: the same steps preceded by a random horizontal flip (light augmentation).
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(), *transform_test.transforms])

# Download if needed and wrap each split in a DataLoader; only the training split is shuffled.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print("数据集加载完毕。\n")

正在加载 CIFAR10 数据集...


100%|██████████| 170M/170M [00:14<00:00, 11.9MB/s] 


数据集加载完毕。



In [7]:
# ----------------------------------------
# 3. Basic convolutional SNN model
# ----------------------------------------
# A simple CNN built with nn.Sequential; the key difference from an ANN is that the
# hidden activations are replaced by spiking (LIF) neurons.

class BasicCSNN(nn.Module):
    """Small convolutional SNN for 3x32x32 inputs (CIFAR10), decoded by averaging over T steps.

    Args:
        T: number of simulation time steps; the same static input is presented at every step.
        surrogate_function: surrogate-gradient module handed to every LIFNode. ``None``
            (the default) creates a fresh ``surrogate.Sigmoid()`` for this model.
    """

    def __init__(self, T, surrogate_function=None):
        super().__init__()
        # Build the default here rather than in the signature: a default argument is
        # evaluated once at class-definition time, so every default-constructed model
        # would otherwise share (and register as a submodule) the same instance.
        if surrogate_function is None:
            surrogate_function = surrogate.Sigmoid()
        self.T = T  # simulation length
        print(f"Initializing Network with Surrogate: {surrogate_function.__class__.__name__}")

        # Structure: [conv -> spike -> pool] x 2 -> [flatten -> fc -> spike] -> [fc]
        self.net = nn.Sequential(
            # Block 1
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            # LIF neuron: simulates LIF dynamics in the forward pass; the given surrogate
            # function supplies the gradient of the spike in the backward pass.
            neuron.LIFNode(surrogate_function=surrogate_function),
            nn.MaxPool2d(2),  # 32x32 -> 16x16

            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            neuron.LIFNode(surrogate_function=surrogate_function),
            nn.MaxPool2d(2),  # 16x16 -> 8x8

            nn.Flatten(),

            # Fully connected layer 1
            nn.Linear(64 * 8 * 8, 128),  # 64 * 8 * 8 = 4096
            neuron.LIFNode(surrogate_function=surrogate_function),

            # Output layer: no spiking neuron, so its real-valued output can be fed
            # straight into a cross-entropy loss.
            nn.Linear(128, 10)  # 10 classes
        )

    def forward(self, x):
        """Simulate the network for ``self.T`` steps on a static input.

        Args:
            x: input batch; the layer sizes expect (N, 3, 32, 32).

        Returns:
            (N, 10) tensor: the output layer averaged over the T time steps (rate decoding).
        """
        # Neurons are stateful (membrane potential), so reset them before each new batch.
        functional.reset_net(self)

        # The input is identical at every step and the first layer is a stateless
        # Conv2d, so its output is loop-invariant: compute it once and run only the
        # remaining (stateful) layers per step. Gradients w.r.t. the first conv are
        # accumulated over the T uses by autograd, exactly as before.
        encoded = self.net[0](x)
        spiking_layers = self.net[1:]

        outputs_over_time = []
        for _ in range(self.T):
            outputs_over_time.append(spiking_layers(encoded))

        # T tensors of shape (N, 10) -> (T, N, 10)
        outputs_stack = torch.stack(outputs_over_time)

        # Mean over time -> (N, 10) logits
        return outputs_stack.mean(dim=0)

In [8]:
# ----------------------------------------
# 4. Initialize models, loss function and optimizers
# ----------------------------------------
# Each model is built right after re-seeding torch's RNG with the same value, so every
# surrogate variant starts from identical initial weights and only the surrogate
# gradient differs. Without this, each model draws different weights from the RNG
# stream and the comparison is confounded by initialization.
INIT_SEED = 0

# One dict per experiment: model, its optimizer, and per-epoch accuracy histories.
experiments = []

# Surrogate gradients to compare: (display name, surrogate module).
# NOTE(review): each entry uses its own hard-coded alpha; BETA from the settings cell is
# not read here — confirm that is intended.
surrogates_config = [
    ("SuperSpike", SuperSpike(alpha=10.0)),
    ("Sigmoid",    SigmoidDerivative(alpha=4.0)),
    ("Esser",      Esser(alpha=1.0))
]

for name, surr_func in surrogates_config:
    torch.manual_seed(INIT_SEED)
    net = BasicCSNN(T=T, surrogate_function=surr_func).to(DEVICE)

    # Each model gets its own optimizer over its own parameters.
    opt = optim.Adam(net.parameters(), lr=LR)

    experiments.append({
        "name": name,
        "model": net,
        "optimizer": opt,
        "train_acc_history": [],
        "test_acc_history": []
    })

# Standard cross-entropy loss, shared by all experiments.
criterion = nn.CrossEntropyLoss()

Initializing Network with Surrogate: SuperSpike
Initializing Network with Surrogate: SigmoidDerivative
Initializing Network with Surrogate: Esser


In [9]:
# ----------------------------------------
# 5. 训练和评估循环
# ----------------------------------------

# --- 训练函数 (Train Loop) ---
def train_epoch(model, optimizer, epoch, model_name):
    model.train()  # 设置为训练模式
    total_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs) # 这里的 model 是参数传进来的
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()


        # 统计损失和准确率
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()


    end_time = time.time()
    avg_loss = total_loss / len(train_loader)
    acc = 100. * correct / total
    print(f"[{model_name}] Epoch {epoch+1} Train | Loss: {avg_loss:.4f} | Acc: {acc:.2f}% | Time: {end_time - start_time:.2f}s")
    return avg_loss, acc

# --- 评估函数 (Eval Loop) ---
def test_epoch(model, epoch, model_name):
    model.eval()  # 设置为评估模式
    total_loss = 0.0
    correct = 0
    total = 0

    # 评估时不需要计算梯度
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            # 前向传播
            outputs = model(inputs)

            # 计算损失
            loss = criterion(outputs, targets)
            total_loss += loss.item()

            # 统计准确率
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    acc = 100. * correct / total
    avg_loss = total_loss / len(test_loader)
    print(f"[{model_name}] Epoch {epoch+1} Test  | Loss: {avg_loss:.4f} | Acc: {acc:.2f}%")
    return avg_loss, acc

In [10]:
# ----------------------------------------
# 6. Train every experiment in turn
# ----------------------------------------
# Re-seeding torch's global RNG before each model's training gives every surrogate the
# same shuffling order and the same random flips: the DataLoader draws its sampler and
# worker seeds from this RNG. Otherwise each model consumes a different part of the RNG
# stream, and data order confounds the comparison. Nondeterministic GPU kernels may
# still cause small run-to-run differences.
DATA_SEED = 0

print(f"=== 开始对比实验 (总 Epochs: {EPOCHS}) ===\n")

for exp in experiments:
    model_name = exp["name"]
    model = exp["model"]
    optimizer = exp["optimizer"]

    # Best test accuracy seen so far, tracked per model.
    best_acc = 0.0

    print(f" >> 正在训练模型: [{model_name}] ...")
    torch.manual_seed(DATA_SEED)

    for epoch in range(EPOCHS):
        _, train_acc = train_epoch(model, optimizer, epoch, model_name)
        _, test_acc = test_epoch(model, epoch, model_name)

        # Per-epoch history, kept for plotting.
        exp["train_acc_history"].append(train_acc)
        exp["test_acc_history"].append(test_acc)

        if test_acc > best_acc:
            best_acc = test_acc
            # To keep the best-performing weights:
            # torch.save(model.state_dict(), f'{model_name}_best.pth')

    # Maximum test accuracy over epochs (i.e. selected on the test set itself).
    exp["best_acc"] = best_acc

    print(f" << 模型 [{model_name}] 训练结束。最佳测试准确率: {best_acc:.2f}%\n")
    print("-" * 60)

# ----------------------------------------
# 7. Final summary
# ----------------------------------------
print(f"\n=== 所有实验完成，最终结果汇总 ===")
for exp in experiments:
    print(f"模型: {exp['name']:<15} | 最佳 Test Acc: {exp['best_acc']:.2f}%")

=== 开始对比实验 (总 Epochs: 10) ===

 >> 正在训练模型: [SuperSpike] ...
[SuperSpike] Epoch 1 Train | Loss: 1.5635 | Acc: 43.80% | Time: 34.61s
[SuperSpike] Epoch 1 Test  | Loss: 1.2612 | Acc: 54.59%
[SuperSpike] Epoch 2 Train | Loss: 1.1668 | Acc: 58.22% | Time: 32.59s
[SuperSpike] Epoch 2 Test  | Loss: 1.0899 | Acc: 61.27%
[SuperSpike] Epoch 3 Train | Loss: 1.0219 | Acc: 63.90% | Time: 32.15s
[SuperSpike] Epoch 3 Test  | Loss: 0.9941 | Acc: 64.91%
[SuperSpike] Epoch 4 Train | Loss: 0.9180 | Acc: 67.59% | Time: 32.18s
[SuperSpike] Epoch 4 Test  | Loss: 0.9338 | Acc: 67.45%
[SuperSpike] Epoch 5 Train | Loss: 0.8510 | Acc: 70.14% | Time: 32.18s
[SuperSpike] Epoch 5 Test  | Loss: 0.9194 | Acc: 67.77%
[SuperSpike] Epoch 6 Train | Loss: 0.7938 | Acc: 72.20% | Time: 31.35s
[SuperSpike] Epoch 6 Test  | Loss: 0.8818 | Acc: 69.11%
[SuperSpike] Epoch 7 Train | Loss: 0.7415 | Acc: 74.05% | Time: 31.30s
[SuperSpike] Epoch 7 Test  | Loss: 0.8632 | Acc: 70.06%
[SuperSpike] Epoch 8 Train | Loss: 0.7023 | Acc: 75