In [None]:
from __future__ import annotations

import argparse
from pathlib import Path
from typing import Iterable

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from data import get_mnist_dataloaders, get_cifar10_dataloaders
from snn import SNNMLP, SNNResNet18, SNNVGG9, poisson_encode, static_encode


def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
	"""Return the fraction of rows whose arg-max class equals the target.

	Args:
		logits: Score tensor; the predicted class is the arg-max along dim 1.
		targets: Class indices compared element-wise against the predictions.

	Returns:
		Mean of the element-wise matches, as a Python float.
	"""
	predicted = torch.argmax(logits, dim=1)
	hits = predicted.eq(targets).float()
	return hits.mean().item()


def run_epoch(
	model: nn.Module,
	loader,
	device: torch.device,
	optimizer: torch.optim.Optimizer | None,
	steps: int,
	train: bool,
	progress_desc: str,
) -> tuple[float, float]:
	"""Run one full pass over ``loader`` and return the mean loss and accuracy.

	Each batch is encoded with ``static_encode(images, steps)``, the leading
	(time) axis of the encoded sequence is randomly permuted, and the result
	is passed to ``model.forward_sequence``. The permutation is applied in
	both training and evaluation passes.

	Args:
		model: Network exposing ``forward_sequence``.
		loader: Iterable yielding ``(images, labels)`` batches.
		device: Device the batches are moved to.
		optimizer: Optimizer stepped after every batch; required when
			``train`` is True, ignored otherwise.
		steps: Forwarded to ``static_encode`` as its second argument.
		train: If True, put the model in train mode and update weights;
			otherwise put it in eval mode and disable gradient tracking.
		progress_desc: Label shown on the tqdm progress bar.

	Returns:
		``(mean_loss, mean_accuracy)``, both weighted by batch size.

	Raises:
		ValueError: If ``train`` is True but ``optimizer`` is None.
		ZeroDivisionError: If ``loader`` yields no samples.
	"""
	if train and optimizer is None:
		raise ValueError("run_epoch(train=True) requires an optimizer")

	# nn.Module.train(False) is what .eval() does, so one call covers both modes.
	model.train(train)
	criterion = nn.CrossEntropyLoss()
	total_loss = 0.0
	total_acc = 0.0
	count = 0
	# Only track gradients when training; evaluation then builds no autograd
	# graph, which saves memory and time without changing the forward values.
	with torch.set_grad_enabled(train):
		for images, labels in tqdm(loader, desc=progress_desc, leave=False):
			images = images.to(device)
			labels = labels.to(device)
			if train:
				optimizer.zero_grad(set_to_none=True)
			# Encoded time sequence -> model -> logits.
			spike_seq = static_encode(images, steps)
			num_steps = len(spike_seq)

			# Draw a random permutation of the time-step indices ...
			perm = torch.randperm(num_steps)

			# ... and reorder the sequence along its leading (time) axis.
			spike_seq = spike_seq[perm]
			logits = model.forward_sequence(spike_seq)
			loss = criterion(logits, labels)
			if train:
				loss.backward()
				optimizer.step()
			acc = accuracy(logits.detach(), labels)
			batch = images.size(0)
			# Weight by batch size so a smaller final batch is averaged correctly.
			total_loss += loss.detach().item() * batch
			total_acc += acc * batch
			count += batch
	return total_loss / count, total_acc / count

In [2]:
import torch

# Seed PyTorch's RNG before evaluating: run_epoch draws a torch.randperm
# for every batch, so the reported metrics depend on this seed.
torch.random.manual_seed(42)
# Configuration parameters
ckpt_path = "checkpoints/best_snn_vgg9_cifar10.pt"  # change to the checkpoint path saved for your model and dataset type
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
steps = 20  # forwarded to run_epoch as its `steps` argument

# Data loading: only the second returned loader is kept, as the test loader
_, test_loader = get_cifar10_dataloaders(batch_size=batch_size)

# Build the model
model = SNNVGG9(num_classes=10, in_channels=3)
model.to(device)

# Load the weights stored under the checkpoint's 'model_state' key
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt['model_state'])
# run_epoch(train=False) also switches to eval mode, so this call is redundant but harmless.
model.eval()

# Evaluate on the test loader and report loss / accuracy
test_loss, test_acc = run_epoch(model, test_loader, device, optimizer=None, steps=steps, train=False, progress_desc="Test")
print(f"Test loss: {test_loss:.4f}, Test accuracy: {test_acc:.4f}")



                                                       

Test loss: 0.9841, Test accuracy: 0.7608


