# All SNN Models Comparison on MNIST

Models:
1. **Baseline LIF** - Standard Leaky Integrate-and-Fire
2. **DASNN** - Dendritic Attention SNN
3. **Spiking-KAN** - Kolmogorov-Arnold Network with spikes
4. **NEXUS-SNN** - Multi-innovation SNN
5. **APEX-SNN** - Ultimate combined model

In [None]:
!nvidia-smi

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import time

# Seed the NumPy and PyTorch random number generators with the same fixed value.
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

## Core Components

In [None]:
class ATanSurrogate(torch.autograd.Function):
    """Heaviside step in the forward pass with an arctan-shaped surrogate gradient.

    Forward returns 1.0 where ``x >= 0`` and 0.0 elsewhere.  Backward replaces
    the (zero almost everywhere) derivative of the step with
    ``alpha / (2 * (1 + (pi / 2 * alpha * x) ** 2))``.
    """

    @staticmethod
    def forward(ctx, x, alpha=2.0):
        # Keep the pre-activation and the sharpness factor for the backward pass.
        ctx.alpha = alpha
        ctx.save_for_backward(x)
        return x.ge(0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (pre_activation,) = ctx.saved_tensors
        alpha = ctx.alpha
        scaled = np.pi / 2 * alpha * pre_activation
        surrogate = alpha / (2 * (1 + scaled.pow(2)))
        # Second slot is the gradient for ``alpha``, which is not differentiated.
        return surrogate * grad_output, None

def spike_fn(x):
    """Return binary spikes for ``x >= 0`` via ``ATanSurrogate`` with alpha fixed at 2.0."""
    alpha = 2.0
    return ATanSurrogate.apply(x, alpha)

class LIFNeuron(nn.Module):
    """Leaky integrate-and-fire neuron with a fixed leak and subtractive reset.

    Args:
        tau: Membrane time constant; the per-step decay is ``1 - 1 / tau``.
        threshold: Firing threshold, also the amount subtracted after a spike.

    The membrane potential ``v`` is kept between calls until ``reset()`` is
    invoked, so consecutive ``forward`` calls act as consecutive timesteps.
    """

    def __init__(self, tau=2.0, threshold=1.0):
        super().__init__()
        self.tau = tau
        self.threshold = threshold
        # Fraction of the membrane potential that survives each step.
        self.beta = 1.0 - 1.0 / tau
        self.v = None

    def reset(self):
        """Drop the stored membrane potential; the next call starts from zeros."""
        self.v = None

    def forward(self, x):
        """Advance one timestep with input current ``x`` and return the spikes."""
        membrane = torch.zeros_like(x) if self.v is None else self.v
        # Leak, then integrate the new input.
        membrane = self.beta * membrane + x
        fired = spike_fn(membrane - self.threshold)
        # Units that fired lose exactly one threshold of potential.
        self.v = membrane - fired * self.threshold
        return fired

class AdaptiveLIF(nn.Module):
    """LIF neuron whose time constant and threshold are learnable parameters.

    Args:
        size: Shape passed to ``torch.ones`` for the ``tau`` and ``threshold``
            parameters.
        tau: Initial value for every entry of the time-constant parameter.

    The membrane potential ``v`` persists across calls until ``reset()``.
    """

    def __init__(self, size, tau=2.0):
        super().__init__()
        self.tau = nn.Parameter(torch.ones(size) * tau)
        self.threshold = nn.Parameter(torch.ones(size))
        self.v = None

    def reset(self):
        """Drop the stored membrane potential; the next call starts from zeros."""
        self.v = None

    def forward(self, x):
        """Advance one timestep with input current ``x`` and return the spikes."""
        membrane = torch.zeros_like(x) if self.v is None else self.v
        # |tau| is clamped to at least 1.1, which keeps the decay strictly positive.
        time_constant = self.tau.abs().clamp(min=1.1)
        firing_level = self.threshold.abs()
        membrane = (1.0 - 1.0 / time_constant) * membrane + x
        fired = spike_fn(membrane - firing_level)
        # Subtractive reset by the (absolute) learned threshold.
        self.v = membrane - fired * firing_level
        return fired

## Model 1: Baseline LIF

In [None]:
class BaselineLIF(nn.Module):
    """Three-layer fully connected SNN (784 -> 512 -> 256 -> 10) with LIF neurons."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.lif1 = LIFNeuron()
        self.fc2 = nn.Linear(512, 256)
        self.lif2 = LIFNeuron()
        self.fc3 = nn.Linear(256, 10)

    def reset(self):
        """Clear the membrane state of both LIF layers."""
        for neuron in (self.lif1, self.lif2):
            neuron.reset()

    def forward(self, x, timesteps=4):
        """Feed the same flattened input for ``timesteps`` steps.

        Returns:
            A tuple ``(logits, spike_rate)``: the readout averaged over the
            timesteps, and the mean of all hidden spike tensors as a float.
        """
        self.reset()
        flat = x.view(x.size(0), -1)
        step_logits, spike_log = [], []

        for _ in range(timesteps):
            first = self.lif1(self.fc1(flat))
            spike_log.append(first.detach())
            second = self.lif2(self.fc2(first))
            spike_log.append(second.detach())
            step_logits.append(self.fc3(second))

        # Mean over every recorded spike, across layers and timesteps.
        spike_rate = torch.cat([s.flatten() for s in spike_log]).mean().item()
        return torch.stack(step_logits).mean(0), spike_rate

## Model 2: DASNN (Dendritic Attention SNN)

In [None]:
class DendriticNeuron(nn.Module):
    """Layer that mixes several linear "branches" with an input-dependent gate.

    Args:
        in_features: Size of the input feature dimension.
        out_features: Size of each branch's output.
        n_branches: Number of parallel linear branches.

    The gated sum of the branch outputs drives a single ``LIFNeuron``.
    """

    def __init__(self, in_features, out_features, n_branches=4):
        super().__init__()
        self.n_branches = n_branches
        self.branches = nn.ModuleList(
            nn.Linear(in_features, out_features) for _ in range(n_branches)
        )
        self.gate = nn.Linear(in_features, n_branches)
        self.lif = LIFNeuron()

    def reset(self):
        """Clear the membrane state of the output neuron."""
        self.lif.reset()

    def forward(self, x):
        """Return the spikes produced by the gated combination of all branches."""
        # Softmax over branches, so the gate weights of each sample sum to one.
        gate_weights = self.gate(x).softmax(dim=-1)
        # Branch outputs stacked on a new trailing axis (one slot per branch).
        per_branch = torch.stack([branch(x) for branch in self.branches], dim=-1)
        # Broadcast the gate over the output-feature axis, then sum out the branches.
        mixed = (per_branch * gate_weights[:, None]).sum(dim=-1)
        return self.lif(mixed)

class DASNN(nn.Module):
    """Dendritic Attention SNN: two ``DendriticNeuron`` layers and a linear readout."""

    def __init__(self):
        super().__init__()
        self.dn1 = DendriticNeuron(784, 512)
        self.dn2 = DendriticNeuron(512, 256)
        self.fc = nn.Linear(256, 10)

    def reset(self):
        """Clear the membrane state of both dendritic layers."""
        for layer in (self.dn1, self.dn2):
            layer.reset()

    def forward(self, x, timesteps=4):
        """Feed the same flattened input for ``timesteps`` steps.

        Returns:
            A tuple ``(logits, spike_rate)``: the readout averaged over the
            timesteps, and the mean of all hidden spike tensors as a float.
        """
        self.reset()
        flat = x.view(x.size(0), -1)
        step_logits, spike_log = [], []

        for _ in range(timesteps):
            first = self.dn1(flat)
            spike_log.append(first.detach())
            second = self.dn2(first)
            spike_log.append(second.detach())
            step_logits.append(self.fc(second))

        # Mean over every recorded spike, across layers and timesteps.
        spike_rate = torch.cat([s.flatten() for s in spike_log]).mean().item()
        return torch.stack(step_logits).mean(0), spike_rate

## Model 3: Spiking-KAN

In [None]:
class SpikingKANLayer(nn.Module):
    """KAN-style layer with a learnable polynomial per (input, output) edge.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        degree: Highest power in the per-edge polynomial; ``degree + 1``
            coefficients are learned for every edge.

    For a 2-D input ``x`` of shape ``(batch, in_features)`` the pre-activation is
    ``out[b, o] = sum_i sum_d x[b, i] * coeffs[i, o, d] * x[b, i] ** d`` and is
    passed through a ``LIFNeuron`` to produce spikes.
    """

    def __init__(self, in_features, out_features, degree=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.degree = degree

        # One coefficient per (input, output, power), initialised with small noise.
        self.coeffs = nn.Parameter(torch.randn(in_features, out_features, degree + 1) * 0.1)
        self.lif = LIFNeuron()

    def reset(self):
        """Clear the membrane state of the output neuron."""
        self.lif.reset()

    def forward(self, x):
        """Evaluate the per-edge polynomials on ``x`` and return the resulting spikes."""
        # Shape (batch, in_features, degree + 1): x**0 ... x**degree per input.
        # The einsum below indexes this operand as 'bid', so it must stay 3-D;
        # adding singleton axes before stacking makes torch.einsum reject it.
        powers = torch.stack([x.pow(k) for k in range(self.degree + 1)], dim=-1)

        # Note the extra factor of x: each edge contributes x * polynomial(x).
        out = torch.einsum('bi,iod,bid->bo', x, self.coeffs, powers)
        return self.lif(out)

class SpikingKAN(nn.Module):
    """Two ``SpikingKANLayer`` stages (784 -> 256 -> 128) and a linear readout to 10."""

    def __init__(self):
        super().__init__()
        self.kan1 = SpikingKANLayer(784, 256, degree=3)
        self.kan2 = SpikingKANLayer(256, 128, degree=3)
        self.fc = nn.Linear(128, 10)

    def reset(self):
        """Clear the membrane state of both KAN layers."""
        for layer in (self.kan1, self.kan2):
            layer.reset()

    def forward(self, x, timesteps=4):
        """Feed the same flattened input for ``timesteps`` steps.

        Returns:
            A tuple ``(logits, spike_rate)``: the readout averaged over the
            timesteps, and the mean of all hidden spike tensors as a float.
        """
        self.reset()
        flat = x.view(x.size(0), -1)
        step_logits, spike_log = [], []

        for _ in range(timesteps):
            first = self.kan1(flat)
            spike_log.append(first.detach())
            second = self.kan2(first)
            spike_log.append(second.detach())
            step_logits.append(self.fc(second))

        # Mean over every recorded spike, across layers and timesteps.
        spike_rate = torch.cat([s.flatten() for s in spike_log]).mean().item()
        return torch.stack(step_logits).mean(0), spike_rate

## Model 4: NEXUS-SNN

In [None]:
class TemporalAttention(nn.Module):
    """Scaled dot-product self-attention over a list of per-timestep states.

    Args:
        dim: Feature size of each state; also sets the ``dim ** -0.5`` score scale.
    """

    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x_list):
        """Return the attended representation for the most recent state.

        A single-element list is returned unchanged (no projections applied).
        Otherwise the states are stacked along a new time axis at dim 1 and the
        attention output at the last time position is returned.
        """
        if len(x_list) == 1:
            return x_list[0]

        sequence = torch.stack(x_list, dim=1)
        queries = self.query(sequence)
        keys = self.key(sequence)
        values = self.value(sequence)

        scores = torch.bmm(queries, keys.transpose(-1, -2)) * self.scale
        weights = scores.softmax(dim=-1)
        # Only the newest timestep's row of the attention output is used.
        return torch.bmm(weights, values)[:, -1]

class NEXUSSNN(nn.Module):
    """Fully connected SNN with batch norm, adaptive LIF neurons and temporal attention.

    Layout: 784 -> 512 -> 256 hidden spikes, attention over the history of the
    second layer's spikes, then a linear readout to 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.lif1 = AdaptiveLIF(512)

        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.lif2 = AdaptiveLIF(256)

        self.attn = TemporalAttention(256)
        self.fc3 = nn.Linear(256, 10)

    def reset(self):
        """Clear the membrane state of both adaptive LIF layers."""
        for neuron in (self.lif1, self.lif2):
            neuron.reset()

    def forward(self, x, timesteps=6):
        """Feed the same flattened input for ``timesteps`` steps.

        At every step the readout sees the attention output computed over all
        second-layer spike tensors produced so far.

        Returns:
            A tuple ``(logits, spike_rate)``: the readout averaged over the
            timesteps, and the mean of all hidden spike tensors as a float.
        """
        self.reset()
        flat = x.view(x.size(0), -1)
        step_logits, hidden_history, spike_log = [], [], []

        for _ in range(timesteps):
            first = self.lif1(self.bn1(self.fc1(flat)))
            spike_log.append(first.detach())
            second = self.lif2(self.bn2(self.fc2(first)))
            spike_log.append(second.detach())

            # The non-detached spikes are kept so attention stays in the graph.
            hidden_history.append(second)
            step_logits.append(self.fc3(self.attn(hidden_history)))

        # Mean over every recorded spike, across layers and timesteps.
        spike_rate = torch.cat([s.flatten() for s in spike_log]).mean().item()
        return torch.stack(step_logits).mean(0), spike_rate

## Model 5: APEX-SNN

In [None]:
class APEXSNN(nn.Module):
    """Convolutional SNN: three conv/BN/LIF stages with pooling, then two linear layers.

    ``fc1`` expects ``128 * 3 * 3`` features, which is what three 2x average
    pools leave for a 28x28 input (28 -> 14 -> 7 -> 3).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.lif1 = LIFNeuron()

        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.lif2 = LIFNeuron()

        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.lif3 = LIFNeuron()

        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.lif4 = AdaptiveLIF(256)
        self.fc2 = nn.Linear(256, 10)

        self.pool = nn.AvgPool2d(2)

    def reset(self):
        """Call ``reset()`` on every submodule that defines it (skipping ``self``)."""
        resettable = (m for m in self.modules() if m is not self and hasattr(m, 'reset'))
        for module in resettable:
            module.reset()

    def forward(self, x, timesteps=4):
        """Feed the same image batch for ``timesteps`` steps.

        A 3-D input gets a channel axis inserted at dim 1 before the first conv.

        Returns:
            A tuple ``(logits, spike_rate)``: the readout averaged over the
            timesteps, and the mean of all recorded spike tensors as a float.
        """
        self.reset()
        if x.dim() == 3:
            x = x.unsqueeze(1)

        conv_stages = (
            (self.conv1, self.bn1, self.lif1),
            (self.conv2, self.bn2, self.lif2),
            (self.conv3, self.bn3, self.lif3),
        )
        step_logits, spike_log = [], []

        for _ in range(timesteps):
            h = x
            for conv, norm, neuron in conv_stages:
                h = neuron(norm(conv(h)))
                # Spikes are recorded before pooling.
                spike_log.append(h.detach())
                h = self.pool(h)

            h = self.lif4(self.fc1(h.view(h.size(0), -1)))
            spike_log.append(h.detach())
            step_logits.append(self.fc2(h))

        # Mean over every recorded spike, across layers and timesteps.
        spike_rate = torch.cat([s.flatten() for s in spike_log]).mean().item()
        return torch.stack(step_logits).mean(0), spike_rate

## Data Loading

In [None]:
# Convert images to tensors, then normalise with the given mean / std constants.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Both splits live under ./data and are downloaded on first use.
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Only the training split is shuffled.
train_loader = DataLoader(dataset=train_data, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_data, batch_size=128, shuffle=False, num_workers=2)

print(f'Train: {len(train_data)}, Test: {len(test_data)}')

## Training Function

In [None]:
def _train_one_epoch(model, loader, optimizer, run_device, timesteps):
    """Run one optimisation pass over ``loader`` with gradient norms clipped to 1.0."""
    model.train()
    for data, target in loader:
        data, target = data.to(run_device), target.to(run_device)
        optimizer.zero_grad()
        output, _ = model(data, timesteps)
        loss = F.cross_entropy(output, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()


def _evaluate(model, loader, run_device, timesteps):
    """Return ``(accuracy_percent, mean_spike_rate)`` of ``model`` over ``loader``."""
    model.eval()
    correct, total = 0, 0
    spike_sum = 0.0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(run_device), target.to(run_device)
            output, spike_rate = model(data, timesteps)
            batch_size = target.size(0)
            # Weight each batch's rate by its size so that a short final batch
            # does not count as much as a full one (accuracy is weighted the same way).
            spike_sum += spike_rate * batch_size
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += batch_size
    return 100. * correct / total, spike_sum / total


def train_model(model, train_loader, test_loader, epochs=20, lr=1e-3, timesteps=4):
    """Train ``model`` with Adam and cosine LR annealing, evaluating after every epoch.

    Args:
        model: Module called as ``model(data, timesteps)`` and returning
            ``(logits, spike_rate)`` where ``spike_rate`` is a float.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Iterable of ``(data, target)`` evaluation batches.
        epochs: Number of epochs; also the period of the cosine schedule.
        lr: Initial Adam learning rate.
        timesteps: Number of simulation steps passed to the model.

    Returns:
        ``(best_acc, history)``: the highest evaluation accuracy (percent) seen
        in any epoch, and a dict with per-epoch ``'acc'`` and ``'spike_rate'``.

    Batches are moved to the device the model's parameters live on, so the
    function does not depend on a module-level ``device`` variable.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    # Adam has already rejected an empty parameter list, so next() is safe here.
    run_device = next(model.parameters()).device

    history = {'acc': [], 'spike_rate': []}
    best_acc = 0

    for epoch in range(epochs):
        _train_one_epoch(model, train_loader, optimizer, run_device, timesteps)
        scheduler.step()

        acc, avg_spike = _evaluate(model, test_loader, run_device, timesteps)
        history['acc'].append(acc)
        history['spike_rate'].append(avg_spike)
        best_acc = max(best_acc, acc)

        # Progress is reported every fifth epoch only.
        if (epoch + 1) % 5 == 0:
            print(f'Epoch {epoch+1:2d} | Acc: {acc:.2f}% | Spike: {avg_spike:.4f}')

    return best_acc, history

## Run All Models

In [None]:
EPOCHS = 20
results = {}

# Display name -> (model instance, number of simulation timesteps).
# All models are constructed up front, before any of them is trained.
models = {
    'Baseline LIF': (BaselineLIF(), 4),
    'DASNN': (DASNN(), 4),
    'Spiking-KAN': (SpikingKAN(), 4),
    'NEXUS-SNN': (NEXUSSNN(), 6),
    'APEX-SNN': (APEXSNN(), 4),
}

for name, (model, timesteps) in models.items():
    print(f'\n{"="*50}')
    print(f'Training {name}')
    print(f'{"="*50}')

    model = model.to(device)
    params = sum(p.numel() for p in model.parameters())
    print(f'Parameters: {params:,}')

    # Wall-clock time covers training and the per-epoch evaluation.
    t0 = time.time()
    best_acc, history = train_model(
        model, train_loader, test_loader, EPOCHS, timesteps=timesteps
    )
    train_time = time.time() - t0

    # NOTE(review): 'accuracy' is the best epoch while 'spike_rate' is the
    # last epoch, so the two numbers may come from different epochs.
    results[name] = dict(
        accuracy=best_acc,
        spike_rate=history['spike_rate'][-1],
        parameters=params,
        timesteps=timesteps,
        time=train_time,
        history=history,
    )

    print(f'\n{name}: {best_acc:.2f}% ({train_time:.1f}s)')

## Results Comparison

In [None]:
# Summary table, best accuracy first.
print('\n' + '=' * 70)
print('MNIST BENCHMARK RESULTS')
print('=' * 70)
print(f'{"Model":<15} {"Accuracy":<12} {"Spike Rate":<12} {"Params":<12} {"Time":<10}')
print('-' * 70)

for name, data in sorted(results.items(), key=lambda x: -x[1]['accuracy']):
    # Format each value first, then pad it to the header's column width so the
    # rows line up under their headings.
    acc_col = f'{data["accuracy"]:.2f}%'
    spike_col = f'{data["spike_rate"]:.4f}'
    param_col = f'{data["parameters"]:,}'
    time_col = f'{data["time"]:.1f}s'
    print(f'{name:<15} {acc_col:<12} {spike_col:<12} {param_col:<12} {time_col:<10}')

print('=' * 70)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: per-epoch test accuracy, one curve per model.
for name, data in results.items():
    axes[0].plot(data['history']['acc'], label=name)
axes[0].set(
    xlabel='Epoch',
    ylabel='Test Accuracy (%)',
    title='Model Comparison: Accuracy',
)
axes[0].legend()
axes[0].grid(True)

# Right panel: grouped bars, accuracy on the left y-axis and spike rate on a twin axis.
names = list(results)
accs = [results[n]['accuracy'] for n in names]
spikes = [results[n]['spike_rate'] for n in names]

x = np.arange(len(names))
width = 0.35

ax1 = axes[1]
ax2 = ax1.twinx()

# The two bar groups are offset by half a bar width on either side of each tick.
bars1 = ax1.bar(x - width/2, accs, width, label='Accuracy', color='steelblue')
bars2 = ax2.bar(x + width/2, spikes, width, label='Spike Rate', color='coral')

ax1.set(xlabel='Model', title='Accuracy vs Spike Rate')
ax1.set_ylabel('Accuracy (%)', color='steelblue')
ax2.set_ylabel('Spike Rate', color='coral')
ax1.set_xticks(x)
ax1.set_xticklabels(names, rotation=45, ha='right')

plt.tight_layout()
plt.savefig('all_models_comparison.png', dpi=150)
plt.show()

In [None]:
import json

# Copy every result entry without its per-epoch 'history' before serialising.
results_save = {}
for model_name, entry in results.items():
    results_save[model_name] = {
        field: value for field, value in entry.items() if field != 'history'
    }

with open('mnist_benchmark_results.json', 'w') as f:
    json.dump(results_save, f, indent=2)

print('Results saved to mnist_benchmark_results.json')