# Quantized SNN for MNIST

## Imports

In [None]:
# SNN
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen

# Torch
import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Other
import matplotlib.pyplot as plt
import numpy as np

## 1. Setting up MNIST

In [None]:
# Dataset-level settings used by the loading / training cells:
#   batch_size - samples per mini-batch
#   data_path  - download/cache directory for MNIST
#   num_class  - number of output classes (digits 0-9)
batch_size, data_path, num_class = 128, "./data/mnist", 10

### 1.1 Download Dataset

In [None]:
# Preprocessing pipeline: resize every digit to 32x32 and turn it into a float tensor.
# Grayscale() should be a no-op for MNIST's single-channel images, and
# Normalize((0,), (1,)) computes (x - 0) / 1, i.e. the identity; both are kept as
# explicit placeholders so the pixel range produced by ToTensor() is unchanged.
pipeline_steps = [
    transforms.Resize((32, 32)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
]
transform = transforms.Compose(pipeline_steps)

# Train / test splits, downloaded into data_path on first use.
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Sanity check: show the first three training images with their labels.
fig, axs = plt.subplots(1, 3, figsize=(9, 3))
for i, panel in enumerate(axs):
    image, label = mnist_train[i]
    panel.imshow(image.squeeze(0).numpy(), cmap="gray")  # drop the channel dim for imshow
    panel.set_title(f"Label: {label}")

### 1.2 Dataloader and Spike Encoding

In [None]:
# Create dataloaders: shuffle the training split, keep the test split in fixed order.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

# Pull a single batch to inspect the spike encoding.
data = iter(train_loader)
data_it, targets_it = next(data)

# Latency-encode the batch into a (time step, batch, ...) spike tensor.
# linear / normalize / clip matter for spikegen.latency:
#   linear=True    -> linear (rather than logarithmic) intensity-to-spike-time mapping
#   normalize=True -> spread spike times so the last spikes land within num_steps
#   clip=True      -> drop spikes for inputs below `threshold`
# Details: https://snntorch.readthedocs.io/en/latest/tutorials/tutorial_1.html
spike_data = spikegen.latency(
    data_it,
    num_steps=100,
    tau=5,
    threshold=0.01,
    linear=True,
    normalize=True,
    clip=True,
)

# Raster plot of the first sample in the batch: time steps along dim 0,
# flattened pixels as the input "neurons".
first_sample = spike_data[:, 0].view(spike_data.size(0), -1)
fig, ax = plt.subplots(facecolor="w", figsize=(12, 3))
splt.raster(first_sample, ax, s=25, c="black")
ax.set_title("Input Layer")
ax.set_xlabel("Time step")
ax.set_ylabel("Neuron Number")
plt.show()

### 1.3 Network Definition

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Surrogate gradient used in place of the non-differentiable spike function.
spike_grad = surrogate.fast_sigmoid(slope=25)
# Membrane decay factor of the leaky neurons.
beta = 0.5

# Single-layer SNN: every flattened 32x32 input frame drives one leaky neuron per
# class. The output width comes from num_class so it cannot drift from the dataset
# settings. init_hidden=True keeps the membrane state inside the layer.
# output=True makes the layer return (spikes, membrane), which forward_pass unpacks.
net = nn.Sequential(
    nn.Linear(32 * 32, num_class),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
).to(device)

### 1.4 Training

In [None]:
def forward_pass(net, num_steps, data):
    """Simulate ``net`` for ``num_steps`` time steps of spike input.

    The network's hidden state is reset first, so each call starts from scratch.

    Args:
        net: Spiking network whose call returns ``(spikes, membrane)`` for a
            single time step.
        num_steps: Number of leading time steps of ``data`` to simulate.
        data: Input indexed by time step first. Each ``data[t]`` is flattened to
            ``(batch, -1)`` before being passed to ``net``.

    Returns:
        Tuple ``(spikes, membranes)``: the per-step outputs stacked along a new
        leading time dimension.
    """
    spikes, membranes = [], []
    utils.reset(net)  # clear hidden state left over from any previous call

    for t in range(num_steps):
        frame = data[t]
        out_spk, out_mem = net(frame.view(frame.size(0), -1))
        spikes.append(out_spk)
        membranes.append(out_mem)

    return torch.stack(spikes), torch.stack(membranes)

# Smoke test: push the inspected batch through the untrained network to check that
# shapes line up. Nothing is back-propagated here, so skip building the autograd
# graph for all time steps (it would otherwise stay alive in spk_rec / mem_rec).
spike_data = spikegen.latency(data_it, num_steps=100, tau=5, threshold=0.01, linear=True, normalize=True, clip=True)
spike_data = spike_data.to(device)
with torch.no_grad():
    # Simulate exactly as many steps as were encoded (dim 0 is time).
    spk_rec, mem_rec = forward_pass(net, spike_data.size(0), spike_data)

In [None]:
# Per-iteration training metrics; the training loop appends one entry per mini-batch.
loss_hist, acc_hist = [], []

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=2e-2, betas=(0.9, 0.999))
# Spike-count MSE loss: target firing on ~80% of time steps for the correct class
# neuron and ~20% for the others.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

num_epochs = 1
num_iters = 50   # mini-batches per epoch before stopping early
num_steps = 100  # simulation length = number of latency-coding time steps

# Nothing in the loop switches to eval mode, so set training mode once.
net.train()

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(train_loader):
        # Encode with the same number of steps the network is simulated for.
        data = spikegen.latency(data, num_steps=num_steps, tau=5, threshold=0.01, linear=True, normalize=True, clip=True)
        data = data.to(device)
        targets = targets.to(device)

        spk_rec, _ = forward_pass(net, num_steps, data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_scalar = loss_val.item()
        loss_hist.append(loss_scalar)

        print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_scalar:.2f}")

        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)
        print(f"Accuracy: {acc * 100:.2f}%\n")

        # Stop after exactly num_iters iterations. i counts from 0, so the old
        # `i == num_iters` check ran num_iters + 1 iterations.
        if i + 1 >= num_iters:
            break

In [None]:
# Training-set accuracy recorded once per training iteration.
fig = plt.figure(facecolor="w")
plt.plot(acc_hist)
plt.gca().set(title="Train Set Accuracy", xlabel="Iteration", ylabel="Accuracy")
plt.show()