In [8]:
from __future__ import annotations

import argparse
from pathlib import Path
from typing import Iterable

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from data import get_mnist_dataloaders, get_cifar10_dataloaders
from snn import SNNMLP, SNNResNet18, poisson_encode, static_encode


def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
	pred = logits.argmax(dim=1)
	return (pred == targets).float().mean().item()


def run_epoch(
	model: nn.Module,
	loader,
	device: torch.device,
	optimizer: torch.optim.Optimizer | None,
	steps: int,
	train: bool,
	progress_desc: str,
) -> tuple[float, float]:
	if train:
		model.train(mode=train)
	else:
		model.eval()
	criterion = nn.CrossEntropyLoss()
	total_loss = 0.0
	total_acc = 0.0
	count = 0
	for images, labels in tqdm(loader, desc=progress_desc, leave=False):
		images = images.to(device)
		labels = labels.to(device)
		if train:
			optimizer.zero_grad(set_to_none=True)
		# 泊松编码时间序列 → 模型 → 平均 logits
		spike_seq = static_encode(images, steps)
		T = len(spike_seq)

		# 生成时间维度的随机排列索引
		perm = torch.randperm(T)

		# 根据该随机排列重新索引时间维度
		spike_seq = spike_seq[perm, :, :]
		logits = model.forward_sequence(spike_seq)
		loss = criterion(logits, labels)
		if train:
			loss.backward()
			optimizer.step()
		acc = accuracy(logits.detach(), labels)
		batch = images.size(0)
		total_loss += loss.detach().item() * batch
		total_acc += acc * batch
		count += batch
	return total_loss / count, total_acc / count

In [None]:
from __future__ import annotations

from typing import Iterable

import torch
from torch import nn

from snn.neurons import LIFNeuron, SpikeConv2d, SpikeOutputLayer


class SNNVGG9(nn.Module):
	"""VGG9的SNN版本
	
	VGG9是一个简化的VGG网络，包含9个卷积层。
	适用于CIFAR-10等小尺寸图像数据集。
	对于MNIST等1通道图像，需要调整输入通道数。
	
	结构：
	- Conv Block 1: 64 channels, 2 layers
	- Conv Block 2: 128 channels, 2 layers  
	- Conv Block 3: 256 channels, 2 layers
	- Global Average Pooling
	- Output Layer
	"""
	
	def __init__(
		self,
		num_classes: int = 10,
		in_channels: int = 3,
		norm_layer: type[nn.Module] | None = None,
	):
		super().__init__()
		if norm_layer is None:
			norm_layer = nn.BatchNorm2d
		
		# Conv Block 1: 64 channels
		self.conv_block1 = nn.Sequential(
			SpikeConv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False),
			norm_layer(64),
			LIFNeuron(),
			SpikeConv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
			norm_layer(64),
			LIFNeuron(),
		)
		self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
		
		# Conv Block 2: 128 channels
		self.conv_block2 = nn.Sequential(
			SpikeConv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
			norm_layer(128),
			LIFNeuron(),
			SpikeConv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
			norm_layer(128),
			LIFNeuron(),
		)
		self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
		
		# Conv Block 3: 256 channels
		self.conv_block3 = nn.Sequential(
			SpikeConv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
			norm_layer(256),
			LIFNeuron(),
			SpikeConv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
			norm_layer(256),
			LIFNeuron(),
		)
		self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
		
		# 全局平均池化和输出层
		self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
		self.fc = SpikeOutputLayer(256, num_classes)
	
	def reset_state(self) -> None:
		"""重置所有LIF神经元和输出层的状态"""
		for module in self.modules():
			if isinstance(module, (LIFNeuron, SpikeOutputLayer)):
				module.reset_state()
	
 
 	def forward_step1(self, x_t: torch.Tensor) -> torch.Tensor:
		"""步骤1：Conv Block 1 + Pool1"""
		x = self.conv_block1(x_t)
		x = self.pool1(x)
		return x

	def forward_step2(self, x_t: torch.Tensor) -> torch.Tensor:
		"""步骤2：Conv Block 2 + Pool2"""
		x = self.conv_block2(x_t)
		x = self.pool2(x)
		return x

	def forward_step3(self, x_t: torch.Tensor) -> torch.Tensor:
		"""步骤3：Conv Block 3 + Pool3"""
		x = self.conv_block3(x_t)
		x = self.pool3(x)
		return x

	def forward_step4(self, x_t: torch.Tensor) -> torch.Tensor:
		"""步骤4：全局平均池化 + Flatten"""
		x = self.avgpool(x_t)
		x = x.flatten(1)  # [B, 256]
		return x

	def forward_step5(self, x_t: torch.Tensor) -> torch.Tensor:
		"""步骤5：输出层（累积输出）"""
		logits = self.fc(x_t)
		return logits

	def forward_step(self, x_t: torch.Tensor) -> torch.Tensor:
		"""完整的单时间步前向传播（保持向后兼容）"""
		x = self.forward_step1(x_t)
		x = self.forward_step2(x)
		x = self.forward_step3(x)
		x = self.forward_step4(x)
		logits = self.forward_step5(x)
		return logits

	def forward_sequence(
		self, 
		spike_sequence: Iterable[torch.Tensor] | torch.Tensor,
		shuffle_after_steps: list[int] | None = None
	) -> torch.Tensor:
		"""处理时间序列的脉冲输入，支持在指定步骤后打乱时间步
		
		Args:
			spike_sequence: 时间序列的脉冲输入（tensor [T, B, ...] 或生成器）
			shuffle_after_steps: 在哪些步骤后打乱时间步，例如 [1, 2] 表示在步骤1和2后打乱
								步骤编号：
								1: forward_step1 后（conv_block1 + pool1）
								2: forward_step2 后（conv_block2 + pool2）
								3: forward_step3 后（conv_block3 + pool3）
								4: forward_step4 后（avgpool + flatten）
		
		Returns:
			logits: [B, num_classes]
		"""
		self.reset_state()
		
		# 先收集所有时间步
		if isinstance(spike_sequence, torch.Tensor):
			# 如果是tensor [T, B, C, H, W]
			spike_list = [spike_sequence[i] for i in range(spike_sequence.shape[0])]
		else:
			# 如果是生成器，转换为列表
			spike_list = list(spike_sequence)
		
		T = len(spike_list)
		if shuffle_after_steps is None:
			shuffle_after_steps = []
		
		# 步骤1：Conv Block 1 + Pool1
		# 对所有时间步执行步骤1
		x_sequence = [self.forward_step1(x_t) for x_t in spike_list]
		
		# 在步骤1后打乱（如果需要）
		if 1 in shuffle_after_steps:
			x_tensor = torch.stack(x_sequence, dim=0)  # [T, B, C, H, W]
			perm = torch.randperm(T, device=x_tensor.device)
			x_tensor = x_tensor[perm]
			x_sequence = [x_tensor[i] for i in range(T)]
			# 重置步骤1之后所有层的状态
			for module in self.conv_block2.modules():
				if isinstance(module, LIFNeuron):
					module.reset_state()
			for module in self.conv_block3.modules():
				if isinstance(module, LIFNeuron):
					module.reset_state()
			self.fc.reset_state()
		
		# 步骤2：Conv Block 2 + Pool2
		x_sequence = [self.forward_step2(x_t) for x_t in x_sequence]
		
		# 在步骤2后打乱（如果需要）
		if 2 in shuffle_after_steps:
			x_tensor = torch.stack(x_sequence, dim=0)
			perm = torch.randperm(T, device=x_tensor.device)
			x_tensor = x_tensor[perm]
			x_sequence = [x_tensor[i] for i in range(T)]
			# 重置步骤2之后所有层的状态
			for module in self.conv_block3.modules():
				if isinstance(module, LIFNeuron):
					module.reset_state()
			self.fc.reset_state()
		
		# 步骤3：Conv Block 3 + Pool3
		x_sequence = [self.forward_step3(x_t) for x_t in x_sequence]
		
		# 在步骤3后打乱（如果需要）
		if 3 in shuffle_after_steps:
			x_tensor = torch.stack(x_sequence, dim=0)
			perm = torch.randperm(T, device=x_tensor.device)
			x_tensor = x_tensor[perm]
			x_sequence = [x_tensor[i] for i in range(T)]
			# 重置步骤3之后所有层的状态
			self.fc.reset_state()
		
		# 步骤4：全局平均池化 + Flatten
		x_sequence = [self.forward_step4(x_t) for x_t in x_sequence]
		
		# 在步骤4后打乱（如果需要）
		if 4 in shuffle_after_steps:
			x_tensor = torch.stack(x_sequence, dim=0)  # [T, B, 256]
			perm = torch.randperm(T, device=x_tensor.device)
			x_tensor = x_tensor[perm]
			x_sequence = [x_tensor[i] for i in range(T)]
			self.fc.reset_state()
		
		# 步骤5：输出层（累积所有时间步）
		for x_t in x_sequence:
			logits_t = self.forward_step5(x_t)
		
		# SpikeOutputLayer 已经累积了所有时间步的输出
		return logits_t



IndentationError: expected an indented block after function definition on line 79 (1697412634.py, line 80)

In [6]:
import torch

torch.random.manual_seed(42)
# 配置参数
ckpt_path = "checkpoints/best_snn_vgg9_cifar10.pt"  # 修改为你的模型和数据集类型保存的模型路径
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
steps = 20

# 数据加载
_, test_loader = get_cifar10_dataloaders(batch_size=batch_size)

# 加载模型
model = SNNVGG9(num_classes=10, in_channels=3)
model.to(device)

# 加载权重
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt['model_state'])
model.eval()

# 测试
test_loss, test_acc = run_epoch(model, test_loader, device, optimizer=None, steps=steps, train=False, progress_desc="Test")
print(f"Test loss: {test_loss:.4f}, Test accuracy: {test_acc:.4f}")



                                                       

Test loss: 0.9841, Test accuracy: 0.7608


