# SNN on MNIST

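In this notebook I train a spiking neural network (SNN) on the MNIST dataset, using both rate and latency input encoding, and compare it with a conventional (non-spiking) feed-forward network.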


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/{user_name}/{repo_name}/blob/{branch_name}/mnist.ipynb)

In [None]:
# Colab setup: install the spiking-network library (snntorch) and the metrics package (torchmetrics).
# NOTE(review): versions are unpinned — pin them (e.g. `%pip install -q snntorch==X.Y.Z`) for reproducible runs.
%pip install snntorch
%pip install torchmetrics


### Imports

In [None]:

import matplotlib.pyplot as plt
import snntorch.functional as sf
import torch, torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from torchmetrics.classification import MulticlassAccuracy

from snn_net import Net
from snn_net import train

### Helper functions to load the data and plot accuracy and loss

In [None]:

def import_data(root="/dataset/", batch_size=128, test_batch_size=64):
    """Download (if needed) MNIST and wrap the train/test splits in DataLoaders.

    Args:
        root: Directory where torchvision stores / looks for the MNIST files.
            The default is an absolute path at the filesystem root, which only
            works where the process may write there (e.g. Colab); pass a
            writable directory on other machines.
        batch_size: Batch size of the shuffled training loader.
        test_batch_size: Batch size of the (unshuffled) test loader.

    Returns:
        A ``(train_loader, test_loader)`` tuple of DataLoaders.
    """
    # Resize/Grayscale are effectively no-ops for native 28x28 grayscale MNIST,
    # and Normalize with mean 0 / std 1 is the identity, so pixel values stay in
    # the [0, 1] range produced by ToTensor (presumably what the spike encoders
    # inside Net expect — confirm in snn_net).
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ])

    # Gather both splits (downloaded on first use).
    mnist_train = datasets.MNIST(root, train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(root, train=False, download=True, transform=transform)

    # Only the training data is shuffled; test order stays fixed.
    train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(mnist_test, batch_size=test_batch_size)

    return train_loader, test_loader


In [None]:


def plot_accuracy(acc_hist, title, test=False, save_path=None):
    """Plot an accuracy history as a line chart on a fresh figure.

    Args:
        acc_hist: Sequence of accuracy values, one per batch (training) or
            one per epoch (test).
        title: Figure title.
        test: If True, label the x-axis "Epoch"; otherwise "Batch".
        save_path: Optional file path; when given, the figure is also saved there.
    """
    # A dedicated figure avoids drawing onto whatever pyplot figure is current.
    fig, ax = plt.subplots()
    ax.plot(acc_hist)
    ax.set_title(title)
    ax.set_xlabel("Epoch" if test else "Batch")
    ax.set_ylabel("Accuracy")
    if save_path is not None:
        # Save before show(): some backends release the figure once it is shown.
        fig.savefig(save_path)
    plt.show()

def plot_loss(loss_hist, title, save_path=None):
    """Plot a per-batch loss history as a line chart on a fresh figure.

    Args:
        loss_hist: Sequence of loss values, one per training batch.
        title: Figure title.
        save_path: Optional file path; when given, the figure is also saved there.
    """
    # A dedicated figure avoids drawing onto whatever pyplot figure is current.
    fig, ax = plt.subplots()
    ax.plot(loss_hist)
    ax.set_title(title)
    ax.set_xlabel("Batch")
    ax.set_ylabel("Loss")
    if save_path is not None:
        # Save before show(): some backends release the figure once it is shown.
        fig.savefig(save_path)
    plt.show()


### Hyperparameters

In [None]:

# Seed torch's global RNG so weight initialisation, DataLoader shuffling and any
# stochastic spike encoding are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# number of training epochs
epochs = 1

# number of simulation time steps per input sample (nominally 1 ms each)
n_steps = 25

# neuron counts: one input neuron per MNIST pixel, one output neuron per digit
inputs = 28 * 28
hiddens = 200
outputs = 10

# membrane potential decay factor passed to Net
# (presumably the LIF leak / beta — confirm in snn_net)
decay = 0.9


In [None]:

# Download (if needed) and batch MNIST; the later training cells reuse both loaders.
(
    train_loader,
    test_loader,
) = import_data()


### Train the SNN

#### Encoding schemes

##### Rate encoded

In [None]:

print("SNN rate:")

# Rate-coded objective: snnTorch's spike-count cross-entropy, scored with the
# matching rate-based accuracy.
loss_fn = sf.ce_count_loss() # type: ignore
accuracy = sf.accuracy_rate

# Network whose inputs are rate-encoded inside Net, optimized with Adam (default settings).
# NOTE: Adam was the optimizer used in the tutorial; another algorithm might work better.
rate_snn = Net(inputs, hiddens, outputs, decay, n_steps, enc_type='rate')
optimizer = torch.optim.Adam(rate_snn.parameters())

test_acc_snn_rate = train(rate_snn, optimizer, loss_fn, accuracy, train_loader, test_loader, epochs)


##### Latency (time-to-first-spike) encoded

In [None]:

print("SNN temporal:")

# Initialize the net with latency-encoded inputs. The encoding is passed by
# keyword (as in the rate-encoded cell) so it cannot silently bind to a
# different positional parameter of Net.
temp_snn = Net(inputs, hiddens, outputs, decay, n_steps, enc_type='latency')

# optimization algorithm
# NOTE: Adam was the optimizer used in the tutorial; another algorithm might work better.
optimizer = torch.optim.Adam(temp_snn.parameters())

# loss function: snnTorch's temporal (spike-timing) cross-entropy
loss_fn = sf.ce_temporal_loss() # type: ignore

# accuracy function matching the temporal loss
accuracy = sf.accuracy_temporal

test_acc_snn_temp = train(temp_snn, optimizer, loss_fn, accuracy, train_loader, test_loader, epochs)


#### Inspect the rate-encoded SNN on a single image

In [None]:

# Inspect the rate-encoded network on one held-out image: show the digit, then
# plot each output neuron's trace over the simulation time steps.
# Use the (unshuffled) test loader so the image is unseen and the same on every run.
images = next(iter(test_loader))[0].squeeze(1)  # drop only the channel dim -> (batch, 28, 28)
img_idx = 0

fig_img, ax_img = plt.subplots()
ax_img.imshow(images[img_idx], cmap="Greys")
ax_img.set_title("Input image")

# Call the module (not .forward) so nn.Module hooks run; no_grad because this
# is inference only and no autograd graph is needed.
with torch.no_grad():
    output = rate_snn(images.flatten(1))

# The indexing below treats output as [time, batch, class]
# (NOTE(review): confirm against snn_net.Net.forward).
fig, axes = plt.subplots(5, 2)
for digit, ax in enumerate(axes.flat):  # row-major: same order as 2*i + j
    ax.plot(output[:, img_idx, digit].detach().cpu().numpy())
    ax.set_title(f'{digit}')
    ax.set_ybound(-0.2, 1.2)

fig.tight_layout()


### Train a feed-forward network (non-spiking baseline)

In [None]:

print("FFN:")

# Non-spiking baseline with the same layer sizes as the SNNs.
feed_fwd_net = nn.Sequential(nn.Linear(inputs, hiddens),
                             nn.ReLU(),
                             nn.Linear(hiddens, outputs))

# optimization algorithm
# NOTE: Adam was the optimizer used in the tutorial; another algorithm might work better.
optimizer = torch.optim.Adam(feed_fwd_net.parameters())

# loss function
loss_fn = nn.CrossEntropyLoss()

# accuracy function
# MulticlassAccuracy defaults to average="macro" (mean of per-class accuracies).
# snnTorch's accuracy_rate / accuracy_temporal report the plain fraction of
# correctly classified samples, so use "micro" to keep the final comparison fair.
accuracy = MulticlassAccuracy(num_classes=outputs, average="micro")

test_acc_feed = train(feed_fwd_net, optimizer, loss_fn, accuracy, train_loader, test_loader, epochs)


### Comparison

In [None]:

# Compare the per-epoch test accuracy of all three models on one set of axes.
test_results = {
    "SNN rate": test_acc_snn_rate,
    "SNN temporal": test_acc_snn_temp,
    "FFN": test_acc_feed,
}

fig, ax = plt.subplots()
for label, acc_hist in test_results.items():
    # Markers keep short runs visible: with a single epoch each history has one
    # point, and a marker-less line plot of one point draws nothing.
    ax.plot(acc_hist, marker="o", label=label)
ax.set_title("Test accuracy")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.legend()
plt.show()
