In [1]:
# Debug aid: expose only GPU 0 to this process. This runs before the cell that
# imports torch, so the mask is already in place when CUDA devices are queried.
import os
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
from spikingjelly.activation_based import neuron, functional, layer

# Report the runtime environment so the captured output documents the setup.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_count = torch.cuda.device_count()  # expected to be 1
    print(f"Current device: {gpu_name}")
    print(f"Device count: {gpu_count}")

PyTorch version: 2.9.1+cu128
CUDA available: True
Current device: NVIDIA GeForce RTX 3090
Device count: 1


In [3]:
# ----------------------------------------
# 1. Hyperparameters and global settings
# ----------------------------------------

T = 8             # number of simulation time steps per sample (key SNN parameter)
BATCH_SIZE = 64   # mini-batch size
EPOCHS = 10       # number of training epochs (kept small for a quick demo)
LR = 1e-3         # learning rate
SEED = 42         # seed for PyTorch's global random number generator
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed PyTorch before the model and the data loaders are created so that
# weight initialisation and shuffling start from the same RNG state on every
# "Restart & Run All". Some GPU kernels may still be non-deterministic, so
# results can differ slightly between runs.
torch.manual_seed(SEED)

print("--- 实验设置 ---")
print(f"设备 (DEVICE): {DEVICE}")
print(f"仿真时长 (T): {T}")
print(f"批大小 (BATCH_SIZE): {BATCH_SIZE}")
print(f"训练轮数 (EPOCHS): {EPOCHS}")
print("------------------\n")

--- 实验设置 ---
设备 (DEVICE): cuda:0
仿真时长 (T): 8
批大小 (BATCH_SIZE): 64
训练轮数 (EPOCHS): 10
------------------



In [4]:
# ----------------------------------------
# 2. Load and preprocess the CIFAR10 dataset
# ----------------------------------------
print("正在加载 CIFAR10 数据集...")

# Per-channel mean / std used to normalise CIFAR10 images.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random horizontal flip as light augmentation, followed
# by tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std),
])

# Evaluation pipeline: same conversion and normalisation, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar_mean, std=cifar_std),
])

# Arguments shared by the train and test splits.
dataset_options = dict(root='./data', download=True)
loader_options = dict(batch_size=BATCH_SIZE, num_workers=2)

# Training split is shuffled every epoch.
train_dataset = datasets.CIFAR10(train=True, transform=transform_train, **dataset_options)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)

# Test split keeps its original order.
test_dataset = datasets.CIFAR10(train=False, transform=transform_test, **dataset_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

print("数据集加载完毕。\n")

正在加载 CIFAR10 数据集...
数据集加载完毕。



In [5]:
# ----------------------------------------
# 3. 定义基础的卷积 SNN 模型
# ----------------------------------------
# 使用 nn.Sequential 快速搭建一个简单的 CNN 结构
# 关键在于在激活函数的位置换上 SNN 的脉冲神经元

class BasicCSNN(nn.Module):
    """Small convolutional spiking network for 3-channel 32x32 images, 10 classes.

    Layout: [Conv -> LIF -> MaxPool] x 2 -> Flatten -> Linear -> LIF -> Linear.
    The spiking LIF neurons take the place of the usual activation functions.
    The last Linear layer has no spiking neuron: its real-valued outputs are
    averaged over the ``T`` time steps and used as classification logits, which
    lets the model be trained with a standard cross-entropy loss.

    Args:
        T: number of simulation time steps run per forward pass; must be >= 1.

    Raises:
        ValueError: if ``T`` is smaller than 1.
    """

    def __init__(self, T):
        super().__init__()
        # Fail early with a clear message; with T < 1 the forward pass would
        # otherwise end in torch.stack being called on an empty list.
        if T < 1:
            raise ValueError(f"T must be a positive number of time steps, got {T}")
        self.T = T  # number of simulation time steps

        self.net = nn.Sequential(
            # Block 1
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            # LIF neuron: simulates leaky integrate-and-fire dynamics in the
            # forward pass; SpikingJelly uses a surrogate gradient for the
            # non-differentiable spike in the backward pass.
            neuron.LIFNode(),
            nn.MaxPool2d(2),  # 32x32 -> 16x16

            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            neuron.LIFNode(),
            nn.MaxPool2d(2),  # 16x16 -> 8x8

            nn.Flatten(),

            # Fully connected layer 1
            nn.Linear(64 * 8 * 8, 128),  # 64 * 8 * 8 = 4096 input features
            neuron.LIFNode(),

            # Output layer: plain Linear, no spiking neuron (see class docstring).
            nn.Linear(128, 10)  # 10 classes
        )

    def forward(self, x):
        """Present the static input ``x`` for ``T`` time steps and return mean logits.

        Args:
            x: batch of images; the layer sizes above imply shape (N, 3, 32, 32).

        Returns:
            Tensor of shape (N, 10): the output layer's values averaged over
            the ``T`` time steps (rate / mean-output decoding).
        """
        # Spiking neurons are stateful (membrane potential); clear that state
        # before processing a new batch.
        functional.reset_net(self)

        # The same image is presented at every time step and the first
        # convolution is stateless, so its output never changes within a
        # forward pass: compute it once instead of T times.
        input_current = self.net[0](x)
        spiking_stack = self.net[1:]

        # Run the stateful remainder of the network once per time step; the
        # neurons keep integrating their input and emitting spikes across steps.
        # Each element has shape (N, 10).
        outputs_over_time = [spiking_stack(input_current) for _ in range(self.T)]

        # List of T tensors (N, 10) -> (T, N, 10) -> mean over time -> (N, 10).
        return torch.stack(outputs_over_time).mean(dim=0)

In [6]:
# ----------------------------------------
# 4. 初始化模型、损失函数和优化器
# ----------------------------------------
# Build the network, then move its parameters to the selected device.
model = BasicCSNN(T=T)
model.to(DEVICE)
# print(model)  # uncomment to inspect the architecture

# Standard cross-entropy loss applied to the logits returned by the model.
criterion = nn.CrossEntropyLoss()

# Adam optimiser over all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=LR)

In [7]:
# ----------------------------------------
# 5. 编写训练和评估循环
# ----------------------------------------

# --- 训练函数 (Train Loop) ---
def train_epoch(epoch):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader``, ``DEVICE`` and ``EPOCHS``.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).

    Returns:
        tuple[float, float]: mean training loss per sample and training
        accuracy in percent for this epoch.
    """
    model.train()
    loss_sum = 0.0  # running sum of per-sample losses
    correct = 0
    total = 0
    start_time = time.time()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        batch_size = targets.size(0)

        optimizer.zero_grad()

        # The model runs its own loop over the T time steps internally.
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Gradients flow through the spiking neurons via SpikingJelly's
        # surrogate gradient.
        loss.backward()
        optimizer.step()

        # The criterion yields the mean loss of the batch. Weight it by the
        # batch size so a shorter final batch does not count as much as a
        # full one in the epoch average.
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(targets).sum().item()

        if (batch_idx + 1) % 100 == 0:
            print(f"  [Epoch {epoch+1}/{EPOCHS}, Batch {batch_idx+1}/{len(train_loader)}] "
                  f"Loss: {loss_sum / total:.4f} | "
                  f"Acc: {100. * correct / total:.2f}%")

    end_time = time.time()
    avg_loss = loss_sum / total
    train_acc = 100. * correct / total
    print(f"Epoch {epoch+1} 训练完成。用时: {end_time - start_time:.2f}秒")
    print(f"  训练集平均 Loss: {avg_loss:.4f}, "
          f"训练集准确率: {train_acc:.2f}%")
    return avg_loss, train_acc

# --- 评估函数 (Eval Loop) ---
def test_epoch(epoch):
    model.eval()  # 设置为评估模式
    total_loss = 0.0
    correct = 0
    total = 0

    # 评估时不需要计算梯度
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            # 前向传播
            outputs = model(inputs)

            # 计算损失
            loss = criterion(outputs, targets)
            total_loss += loss.item()

            # 统计准确率
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_acc = 100. * correct / total
    test_loss = total_loss / len(test_loader)
    print(f"--- Epoch {epoch+1} 测试结果 ---")
    print(f"  测试集 Loss: {test_loss:.4f}, 测试集准确率 (Acc): {test_acc:.2f}%")
    print(f"--------------------------\n")
    return test_acc

In [8]:
# ----------------------------------------
# 6. 开始训练
# ----------------------------------------
print("=== 开始训练 ===")

# Train and evaluate once per epoch, keeping the highest test accuracy seen.
best_acc = 0.0
for epoch in range(EPOCHS):
    train_epoch(epoch)
    test_acc = test_epoch(epoch)
    best_acc = max(best_acc, test_acc)

print("=== 训练完成 ===")
print(f"在 {EPOCHS} 轮训练后，最佳测试集准确率为: {best_acc:.2f}%")

=== 开始训练 ===
  [Epoch 1/10, Batch 100/782] Loss: 2.1331 | Acc: 21.31%
  [Epoch 1/10, Batch 200/782] Loss: 1.9418 | Acc: 29.46%
  [Epoch 1/10, Batch 300/782] Loss: 1.8233 | Acc: 33.80%
  [Epoch 1/10, Batch 400/782] Loss: 1.7321 | Acc: 37.36%
  [Epoch 1/10, Batch 500/782] Loss: 1.6718 | Acc: 39.66%
  [Epoch 1/10, Batch 600/782] Loss: 1.6236 | Acc: 41.30%
  [Epoch 1/10, Batch 700/782] Loss: 1.5816 | Acc: 42.83%
Epoch 1 训练完成。用时: 16.02秒
  训练集平均 Loss: 1.5531, 训练集准确率: 43.86%
--- Epoch 1 测试结果 ---
  测试集 Loss: 1.3140, 测试集准确率 (Acc): 52.39%
--------------------------

  [Epoch 2/10, Batch 100/782] Loss: 1.2630 | Acc: 54.64%
  [Epoch 2/10, Batch 200/782] Loss: 1.2340 | Acc: 56.11%
  [Epoch 2/10, Batch 300/782] Loss: 1.2259 | Acc: 56.34%
  [Epoch 2/10, Batch 400/782] Loss: 1.2158 | Acc: 56.75%
  [Epoch 2/10, Batch 500/782] Loss: 1.2041 | Acc: 57.18%
  [Epoch 2/10, Batch 600/782] Loss: 1.1957 | Acc: 57.41%
  [Epoch 2/10, Batch 700/782] Loss: 1.1849 | Acc: 57.78%
Epoch 2 训练完成。用时: 13.68秒
  训练集平均 Loss: 