# CIFAR-10 SNN Benchmark
**Spiking Neural Networks vs Artificial Neural Networks**

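This notebook trains and compares two VGG-style networks with the same layer layout on CIFAR-10:
- SNN-VGG (spiking neural network: LIF neurons trained with a surrogate gradient)
- ANN-VGG baseline (conventional artificial neural network with ReLU activations)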

**SOTA References:**
- Spikformer V2: ~95-96%
- SEW-ResNet: ~94%

In [None]:
# Check GPU: run `nvidia-smi` through an IPython shell escape (`!`).
# This line is notebook syntax and is not valid in a plain Python script.
!nvidia-smi

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import time

# Seed the PyTorch and NumPy global RNGs with the same fixed value.
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(42)

# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f'Device: {device}')

## Surrogate Gradient & LIF Neuron

In [None]:
class ATanSurrogate(torch.autograd.Function):
    """Step-function spike with an arctan-shaped surrogate gradient.

    Forward: returns a float32 tensor that is 1.0 where ``x >= 0`` and 0.0
    elsewhere.
    Backward: scales the incoming gradient by
    ``alpha / (2 * (1 + (pi / 2 * alpha * x) ** 2))`` and returns ``None``
    as the gradient for ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha=2.0):
        # Stash what backward() needs: the input and the sharpness factor.
        ctx.alpha = alpha
        ctx.save_for_backward(x)
        fired = x >= 0
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        (pre_activation,) = ctx.saved_tensors
        sharpness = ctx.alpha
        scaled = np.pi / 2 * sharpness * pre_activation
        surrogate = sharpness / (2 * (1 + scaled ** 2))
        # One entry per forward() argument: gradient for x, none for alpha.
        return surrogate * grad_output, None

def spike_fn(x, alpha=2.0):
    """Apply ``ATanSurrogate`` to ``x`` with sharpness ``alpha`` and return the result."""
    spikes = ATanSurrogate.apply(x, alpha)
    return spikes

class LIFNeuron(nn.Module):
    """Leaky integrate-and-fire layer with a subtractive (soft) reset.

    The membrane potential ``v_mem`` is kept on the module between calls, so
    every ``forward`` call advances the neuron by one step until ``reset()``
    clears the state.

    Args:
        tau: time constant; the per-step decay factor is ``1 - 1 / tau``.
        v_threshold: value the membrane potential is compared against, and
            the amount subtracted from it wherever a spike is emitted.
    """

    def __init__(self, tau=2.0, v_threshold=1.0):
        super().__init__()
        self.tau = tau
        self.v_threshold = v_threshold
        # Fraction of the previous membrane potential carried into the next step.
        self.beta = 1.0 - 1.0 / tau
        # Allocated lazily, with the shape of the first input after a reset.
        self.v_mem = None

    def reset(self):
        """Drop the stored membrane potential; the next step starts from zeros."""
        self.v_mem = None

    def forward(self, x):
        """Integrate ``x`` for one step and return the resulting spike tensor."""
        membrane = torch.zeros_like(x) if self.v_mem is None else self.v_mem
        membrane = self.beta * membrane + x
        self.v_mem = membrane
        spike = spike_fn(membrane - self.v_threshold)
        # Soft reset: subtract the threshold only where a spike was emitted.
        self.v_mem = membrane - spike * self.v_threshold
        return spike

## SNN-VGG Model

In [None]:
class SNN_VGG(nn.Module):
    """VGG-style spiking CNN with a 10-way linear read-out.

    Three convolutional stages (64, 128 and 256 channels), each made of two
    Conv-BatchNorm-LIF units followed by 2x2 average pooling, feed a 512-unit
    spiking fully connected layer and a final linear layer. The classifier is
    sized for a 256 x 4 x 4 feature map, i.e. 3-channel 32 x 32 inputs.

    Args:
        tau: time constant passed to every ``LIFNeuron`` in the network.
    """

    def __init__(self, tau=2.0):
        super().__init__()

        # Kept as a ModuleList (not Sequential) so forward() can inspect each
        # layer's output and record the LIF spike statistics.
        self.features = nn.ModuleList([
            nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), LIFNeuron(tau),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), LIFNeuron(tau),
            nn.AvgPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), LIFNeuron(tau),
            nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), LIFNeuron(tau),
            nn.AvgPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), LIFNeuron(tau),
            nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), LIFNeuron(tau),
            nn.AvgPool2d(2),
        ])

        self.classifier = nn.Linear(256 * 4 * 4, 512)
        self.lif_fc = LIFNeuron(tau)
        self.fc_out = nn.Linear(512, 10)

    def reset(self):
        """Clear the membrane potential of every LIF neuron in the network."""
        for m in self.modules():
            if isinstance(m, LIFNeuron):
                m.reset()

    def forward(self, x, timesteps=4):
        """Run the network for ``timesteps`` steps on the same input ``x``.

        Args:
            x: input batch; the identical tensor is presented at every step.
            timesteps: number of simulation steps (must be at least 1).

        Returns:
            A ``(logits, spike_rate)`` tuple: the read-out averaged over all
            timesteps, and a Python float giving the mean of all LIF outputs
            over every layer, timestep and element.

        Raises:
            ValueError: if ``timesteps`` is smaller than 1.
        """
        if timesteps < 1:
            raise ValueError(f'timesteps must be >= 1, got {timesteps}')

        self.reset()
        outputs = []
        # Only a scalar sum per LIF layer is kept, plus the number of elements
        # it covers. Holding (and later concatenating) the full spike tensors
        # would duplicate every activation of every timestep in memory.
        layer_spike_sums = []
        n_elements = 0

        for _ in range(timesteps):
            h = x
            for layer in self.features:
                h = layer(h)
                if isinstance(layer, LIFNeuron):
                    layer_spike_sums.append(h.detach().float().sum())
                    n_elements += h.numel()

            h = h.view(h.size(0), -1)
            h = self.classifier(h)
            h = self.lif_fc(h)
            layer_spike_sums.append(h.detach().float().sum())
            n_elements += h.numel()
            outputs.append(self.fc_out(h))

        # One device-to-host transfer for all the per-layer sums; the final
        # accumulation happens in Python floats.
        total_spikes = sum(torch.stack(layer_spike_sums).tolist())
        # An empty batch has no elements; report NaN like a mean over nothing.
        spike_rate = total_spikes / n_elements if n_elements else float('nan')
        return torch.stack(outputs).mean(0), spike_rate

## ANN Baseline

In [None]:
class ANN_VGG(nn.Module):
    """Conventional VGG-style CNN with ReLU activations and a 10-way output.

    Three stages (64, 128 and 256 channels) of two Conv-BatchNorm-ReLU units,
    each stage closed by 2x2 average pooling, followed by a two-layer
    classifier sized for a 256 x 4 x 4 feature map.
    """

    def __init__(self):
        super().__init__()

        # Build the stages in order so the Sequential indices (and therefore
        # the state_dict keys) follow conv, bn, relu, conv, bn, relu, pool.
        stage_layers = []
        in_channels = 3
        for out_channels in (64, 128, 256):
            for _ in range(2):
                stage_layers.extend([
                    nn.Conv2d(in_channels, out_channels, 3, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(),
                ])
                in_channels = out_channels
            stage_layers.append(nn.AvgPool2d(2))
        self.features = nn.Sequential(*stage_layers)

        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        feature_maps = self.features(x)
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flattened)

## Data Loading

In [None]:
# Per-channel mean / std applied by the Normalize step of both pipelines.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop (with 4-pixel padding) and horizontal flip
# before tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Test pipeline: tensor conversion and normalisation only.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

train_data = datasets.CIFAR10(
    './data', train=True, download=True, transform=train_transform
)
test_data = datasets.CIFAR10(
    './data', train=False, download=True, transform=test_transform
)

BATCH_SIZE = 128
# Only the training loader shuffles; both use two worker processes.
train_loader = DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
)
test_loader = DataLoader(
    test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
)

print(f'Train: {len(train_data)}, Test: {len(test_data)}')

## Training Functions

In [None]:
def train_snn(model, loader, optimizer, device, timesteps=4):
    """Train ``model`` for one pass over ``loader``.

    The model is called as ``model(data, timesteps)`` and must return
    ``(logits, spike_rate)``. Gradients are clipped to a total norm of 1.0
    before each optimizer step.

    Returns:
        ``(mean batch loss, accuracy in percent, mean per-batch spike rate)``.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    batch_rates = []

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits, rate = model(inputs, timesteps)
        batch_rates.append(rate)

        batch_loss = F.cross_entropy(logits, labels)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_sum += batch_loss.item()
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    mean_loss = loss_sum / len(loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy, np.mean(batch_rates)

@torch.no_grad()
def eval_snn(model, loader, device, timesteps=4):
    """Evaluate ``model`` on ``loader`` with gradient tracking disabled.

    The model is called as ``model(data, timesteps)`` and must return
    ``(logits, spike_rate)``.

    Returns:
        ``(accuracy in percent, mean per-batch spike rate)``.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    batch_rates = []

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        logits, rate = model(inputs, timesteps)
        batch_rates.append(rate)

        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    accuracy = 100. * n_correct / n_seen
    return accuracy, np.mean(batch_rates)

def train_ann(model, loader, optimizer, device):
    """Train ``model`` for one pass over ``loader`` with cross-entropy loss.

    Returns:
        ``(mean batch loss, accuracy in percent)``.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = F.cross_entropy(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    mean_loss = loss_sum / len(loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

@torch.no_grad()
def eval_ann(model, loader, device):
    """Return the accuracy (in percent) of ``model`` on ``loader``.

    Runs in eval mode with gradient tracking disabled.
    """
    model.eval()
    n_correct = 0
    n_seen = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        predictions = model(inputs).argmax(dim=1)
        n_correct += (predictions == labels).sum().item()
        n_seen += labels.size(0)

    return 100. * n_correct / n_seen

## Train SNN-VGG

In [None]:
# Number of training epochs and of SNN simulation steps per forward pass.
EPOCHS = 100
TIMESTEPS = 4

snn_model = SNN_VGG(tau=2.0).to(device)
snn_params = sum(param.numel() for param in snn_model.parameters())
print(f'SNN Parameters: {snn_params:,}')

# SGD with momentum and weight decay; the learning rate follows a cosine
# schedule spanning all EPOCHS.
optimizer = optim.SGD(
    snn_model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)

# Per-epoch test metrics, plus the best test accuracy seen so far and the
# spike rate measured in that same epoch.
snn_history = {'acc': [], 'spike_rate': []}
best_snn_acc = 0
best_snn_spike = 0

t0 = time.time()
for epoch in range(EPOCHS):
    train_loss, train_acc, train_spike = train_snn(
        snn_model, train_loader, optimizer, device, TIMESTEPS
    )
    test_acc, test_spike = eval_snn(snn_model, test_loader, device, TIMESTEPS)
    scheduler.step()

    snn_history['acc'].append(test_acc)
    snn_history['spike_rate'].append(test_spike)

    # Save a checkpoint whenever the test accuracy improves.
    if test_acc > best_snn_acc:
        best_snn_acc, best_snn_spike = test_acc, test_spike
        torch.save(snn_model.state_dict(), 'best_snn_cifar10.pth')

    # Report progress on every tenth epoch.
    if epoch % 10 == 9:
        print(f'Epoch {epoch+1:3d} | Test: {test_acc:.2f}% | Spike: {test_spike:.4f} | Best: {best_snn_acc:.2f}%')

snn_time = time.time() - t0
print(f'\nSNN Training Time: {snn_time/60:.1f} min')
print(f'Best SNN Accuracy: {best_snn_acc:.2f}%')

## Train ANN Baseline

In [None]:
ann_model = ANN_VGG().to(device)
ann_params = sum(param.numel() for param in ann_model.parameters())
print(f'ANN Parameters: {ann_params:,}')

# SGD with momentum and weight decay; the learning rate follows a cosine
# schedule spanning all EPOCHS.
optimizer = optim.SGD(
    ann_model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)

# Per-epoch test accuracy and the best value seen so far.
ann_history = {'acc': []}
best_ann_acc = 0

t0 = time.time()
for epoch in range(EPOCHS):
    train_loss, train_acc = train_ann(
        ann_model, train_loader, optimizer, device
    )
    test_acc = eval_ann(ann_model, test_loader, device)
    scheduler.step()

    ann_history['acc'].append(test_acc)

    # Save a checkpoint whenever the test accuracy improves.
    if test_acc > best_ann_acc:
        best_ann_acc = test_acc
        torch.save(ann_model.state_dict(), 'best_ann_cifar10.pth')

    # Report progress on every tenth epoch.
    if epoch % 10 == 9:
        print(f'Epoch {epoch+1:3d} | Test: {test_acc:.2f}% | Best: {best_ann_acc:.2f}%')

ann_time = time.time() - t0
print(f'\nANN Training Time: {ann_time/60:.1f} min')
print(f'Best ANN Accuracy: {best_ann_acc:.2f}%')

## Results Comparison

In [None]:
# Summary table. Each cell is formatted to text first and then padded to the
# same widths as the header (15 / 12 / 12), so the columns line up.
print('=' * 60)
print('CIFAR-10 BENCHMARK RESULTS')
print('=' * 60)
print(f'{"Model":<15} {"Accuracy":<12} {"Spike Rate":<12} {"Params":<12}')
print('-' * 60)
snn_acc_cell = f'{best_snn_acc:.2f}%'
snn_spike_cell = f'{best_snn_spike:.4f}'
ann_acc_cell = f'{best_ann_acc:.2f}%'
print(f'{"SNN-VGG":<15} {snn_acc_cell:<12} {snn_spike_cell:<12} {snn_params:,}')
print(f'{"ANN-VGG":<15} {ann_acc_cell:<12} {"N/A":<12} {ann_params:,}')
print('-' * 60)
# Positive gap means the ANN baseline is more accurate than the SNN.
gap = best_ann_acc - best_snn_acc
print(f'\nAccuracy Gap: {gap:.2f}% (Target: <2%)')
# NOTE(review): 0.9 and 4.6 are presumably per-operation energy costs
# (accumulate vs. multiply-accumulate) and 0.5 an assumed ANN activity
# factor — confirm the source. The expression does not include TIMESTEPS.
print(f'\nTheoretical Energy Ratio: {best_snn_spike * 0.9 / (0.5 * 4.6):.3f}x')

In [None]:
# Two side-by-side panels: test accuracy of both models (left) and the SNN
# test spike rate per epoch (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
acc_ax, spike_ax = axes

for curve_label, curve_history in (('SNN-VGG', snn_history), ('ANN-VGG', ann_history)):
    acc_ax.plot(curve_history['acc'], label=curve_label)
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Test Accuracy (%)')
acc_ax.set_title('CIFAR-10: SNN vs ANN')
acc_ax.legend()
acc_ax.grid(True)

spike_ax.plot(snn_history['spike_rate'])
spike_ax.set_xlabel('Epoch')
spike_ax.set_ylabel('Spike Rate')
spike_ax.set_title('SNN Spike Rate During Training')
spike_ax.grid(True)

# Save the figure to disk before displaying it.
plt.tight_layout()
plt.savefig('cifar10_benchmark.png', dpi=150)
plt.show()