<a href="https://colab.research.google.com/github/CSteennis/BscThesis/blob/main/MNIST/mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
!pip install snntorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [13]:


import matplotlib.pyplot as plt
import snntorch.functional as sf
import snntorch as snn
from snntorch import spikegen
import torch, torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from scipy.signal import savgol_filter
import numpy as np

In [14]:
class Net(nn.Module):
    '''Spiking neural network with one hidden layer.

    The architecture is ``Linear -> Leaky -> Linear -> Leaky``: each fully
    connected layer feeds an ``snn.Leaky`` neuron layer that shares the same
    ``beta`` value.

    Args:
        num_inputs: size of the input feature vector.
        num_hiddens: number of hidden neurons.
        num_outputs: number of output neurons.
        beta: value passed as ``beta`` to both ``snn.Leaky`` layers
            (the membrane potential decay).
    '''
    def __init__(self, num_inputs, num_hiddens, num_outputs, beta):
        super().__init__()

        self.num_inputs = num_inputs    # number of inputs
        self.num_hidden = num_hiddens   # number of hidden neurons
        self.num_outputs = num_outputs  # number of output neurons

        # initialize layers
        self.fc1 = nn.Linear(self.num_inputs, self.num_hidden)   # connection input -> hidden layer
        self.lif1 = snn.Leaky(beta=beta)                         # hidden layer
        self.fc2 = nn.Linear(self.num_hidden, self.num_outputs)  # connection hidden layer -> output
        self.lif2 = snn.Leaky(beta=beta)                         # output layer

    def forward(self, data, num_steps):
        '''Run the network for ``num_steps`` time steps and return the output spike trains.

        Args:
            data: input spike trains; indexed along its first axis once per
                time step, so it must hold at least ``num_steps`` entries there.
            num_steps: number of time steps to simulate.

        Returns:
            The output spikes of every step stacked along a new first axis.
        '''
        # Each neuron layer gets its membrane potential from its *own* init
        # method (the output layer previously borrowed the hidden layer's).
        mem_hid = self.lif1.init_leaky()
        mem_out = self.lif2.init_leaky()

        spike_out_rec = []

        for step in range(num_steps):
            hidden_in = self.fc1(data[step])
            spike_hid, mem_hid = self.lif1(hidden_in, mem_hid)
            output_in = self.fc2(spike_hid)
            spike_out, mem_out = self.lif2(output_in, mem_out)

            spike_out_rec.append(spike_out)

        return torch.stack(spike_out_rec)

In [15]:

def import_data(data_dir="/dataset/", train_batch_size=128, test_batch_size=64):
    '''Download MNIST (if needed) and wrap the train and test splits in data loaders.

    Args:
        data_dir: directory the dataset is downloaded to / read from.
        train_batch_size: batch size of the (shuffled) training loader.
        test_batch_size: batch size of the (unshuffled) test loader.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    '''
    # Define a transform; Normalize with mean 0 and std 1 leaves the values unchanged.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ])

    mnist_train = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(data_dir, train=False, download=True, transform=transform)

    train_loader = DataLoader(mnist_train, batch_size=train_batch_size, shuffle=True)
    test_loader = DataLoader(mnist_test, batch_size=test_batch_size)

    return train_loader, test_loader


In [16]:

# rate encode pixels to binary representation for input trains
def gen_spike_trains(data, n_steps):
    '''Rate-encode a batch of inputs into spike trains.

    Every sample is flattened to a single feature vector (all axes after the
    batch axis are merged) before being handed to ``spikegen.rate``.

        In:  [batch, ...]
        Out: [num_steps, batch, input_size] (as produced by ``spikegen.rate``)
    '''
    flat_batch = data.flatten(1)
    return spikegen.rate(flat_batch, num_steps=n_steps)


In [17]:

def plot_accuracy(acc_hist, title, xlabel="Batch"):
    '''Plot an accuracy history as a line chart and show it.

    Args:
        acc_hist: sequence of accuracy values, one per x-axis step.
        title: title of the figure.
        xlabel: label of the x axis; defaults to ``"Batch"``.
    '''
    # Explicit figure/axes instead of the pyplot state machine; no IPython
    # magic inside the function so it is plain, importable Python.
    fig, ax = plt.subplots(facecolor="w")
    ax.plot(acc_hist)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Accuracy")
    plt.show()

def plot_loss(loss_hist, title, xlabel="Batch"):
    '''Plot a loss history as a line chart and show it.

    Args:
        loss_hist: sequence of loss values, one per x-axis step.
        title: title of the figure.
        xlabel: label of the x axis; defaults to ``"Batch"``.
    '''
    # Explicit figure/axes instead of the pyplot state machine; no IPython
    # magic inside the function so it is plain, importable Python.
    fig, ax = plt.subplots(facecolor="w")
    ax.plot(loss_hist)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Loss")
    plt.show()


In [8]:

def test_accuracy(data_loader, net, num_steps):
    '''Compute the accuracy of ``net`` over every batch of ``data_loader``.

    Each batch is rate-encoded with ``gen_spike_trains``, run through the
    network for ``num_steps`` steps and scored with ``sf.accuracy_rate``.
    Batch accuracies are weighted by batch size so a smaller final batch does
    not skew the result.

    Args:
        data_loader: iterable yielding ``(data, targets)`` batches.
        net: the network to evaluate; it is left in eval mode afterwards.
        num_steps: number of simulation time steps per sample.

    Returns:
        The fraction of correctly classified samples, or ``nan`` when the
        loader yields no samples.
    '''
    # Evaluate on whatever device the model lives on instead of relying on a
    # notebook-level global.
    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    weighted_correct = 0.0
    total = 0

    net.eval()
    with torch.no_grad():
        for data, targets in data_loader:
            if target_device is not None:
                # Tensor.to() is not in-place: the result must be re-assigned.
                data = data.to(target_device)
                targets = targets.to(target_device)

            spike_input = gen_spike_trains(data, num_steps)
            spk_rec = net(spike_input, num_steps)

            batch_size = spk_rec.size(1)
            weighted_correct += sf.accuracy_rate(spk_rec, targets) * batch_size
            total += batch_size

    if total == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError.
        return float("nan")
    return weighted_correct / total


In [9]:

def _smooth_history(values, window=11, polyorder=1):
    '''Savitzky-Golay smooth a 1-D history; return it unchanged when it is too short to filter.'''
    values = np.asarray(values, dtype=float)
    if values.ndim != 1:
        return values
    # The window may not exceed the series length and is kept odd, which every
    # SciPy version accepts.
    window = min(window, values.size)
    if window % 2 == 0:
        window -= 1
    if window <= polyorder:
        return values
    return savgol_filter(values, window, polyorder)


def train_snn(net:Net, optimizer:torch.optim.Adam, loss_fn:sf.mse_count_loss, train_loader, test_loader, n_steps, epochs):
    '''Train ``net`` for ``epochs`` epochs, evaluate it after every epoch and plot the histories.

    Args:
        net: the spiking network to train (updated in place).
        optimizer: optimizer holding the parameters of ``net``.
        loss_fn: callable mapping ``(output_spikes, labels)`` to a scalar loss.
        train_loader: iterable yielding ``(data, label)`` training batches.
        test_loader: iterable yielding ``(data, label)`` test batches.
        n_steps: number of simulation time steps per sample.
        epochs: number of passes over ``train_loader``.

    Returns:
        The trained network (the same object as ``net``).
    '''
    # Train on whatever device the model lives on instead of relying on a
    # notebook-level global.
    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    loss_hist = []       # loss of every batch, across all epochs
    test_acc_hist = []   # test accuracy after each epoch
    acc_per_epoch = []   # one list of batch accuracies per epoch

    for epoch in range(epochs):
        # test_accuracy() leaves the net in eval mode, so switch back each epoch
        net.train()

        # A fresh list per epoch: appending one shared, ever-growing list would
        # make every "epoch" row identical and the mean below meaningless.
        epoch_acc = []

        with tqdm(train_loader, unit="batch") as tqepch:
            tqepch.set_description(desc=f"Epoch {epoch}")
            for data, label in tqepch:
                if target_device is not None:
                    # Tensor.to() is not in-place: the result must be re-assigned.
                    data = data.to(target_device)
                    label = label.to(target_device)

                # convert input to spike trains; gen_spike_trains flattens each
                # sample itself, so no squeeze is needed (a squeeze would also
                # drop the batch axis of a single-sample batch)
                spike_input = gen_spike_trains(data, n_steps)

                # do forward pass
                output = net(spike_input, n_steps)

                # calculate loss value
                loss_val = loss_fn(output, label)
                loss_hist.append(loss_val.item())

                # clear previously stored gradients, backpropagate, update weights
                optimizer.zero_grad()
                loss_val.backward()
                optimizer.step()

                # determine batch accuracy
                acc = sf.accuracy_rate(output, label)
                epoch_acc.append(acc)

                tqepch.set_postfix(loss=loss_val.item(), accuracy=f'{acc * 100:.2f}')

        # accuracy per epoch
        acc_per_epoch.append(epoch_acc)

        # accuracy on test set for epoch
        test_acc = test_accuracy(test_loader, net, n_steps)
        test_acc_hist.append(test_acc)

        print(f'Test accuracy: {test_acc * 100:.2f}%')

    # take the mean over all epochs for every batch position
    if acc_per_epoch:
        # truncate to the shortest epoch so the rows form a rectangular array
        n_common = min(len(epoch_acc) for epoch_acc in acc_per_epoch)
        mean_acc = np.mean([epoch_acc[:n_common] for epoch_acc in acc_per_epoch], axis=0)
    else:
        mean_acc = np.array([])

    # smoothing
    mean_acc = _smooth_history(mean_acc)
    loss_smoothed = _smooth_history(loss_hist)

    # plot
    plot_accuracy(mean_acc, "Train accuracy")
    plot_loss(loss_smoothed, "Train loss")
    plot_accuracy(test_acc_hist, "Test accuracy")

    # return trained network
    return net


In [10]:

# Seed the random number generators so that repeated runs of the notebook
# start from the same state (results on GPU may still vary slightly).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# number of epochs
epochs = 15

# number of time steps
n_steps = 25 #ms

# neuron counts
inputs = 28 * 28
hiddens = 200
outputs = 10

# import training and test data
train_loader, test_loader = import_data()

# membrane potential decay
decay = 0.9

# initialize net
net = Net(inputs, hiddens, outputs, decay).to(device)

# optimization algorithm
optimizer = torch.optim.Adam(net.parameters()) # (NOTE: Adam was used in the tutorial; there may be a better algorithm)

# loss function
loss_fn = sf.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2) # type: ignore


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /dataset/MNIST/raw/train-images-idx3-ubyte.gz to /dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to /dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /dataset/MNIST/raw



In [11]:

# Run the full training loop with the configuration defined above.
trained_net = train_snn(
    net=net,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    test_loader=test_loader,
    n_steps=n_steps,
    epochs=epochs,
)


Epoch 0: 100%|██████████| 469/469 [01:30<00:00,  5.21batch/s, accuracy=97.92, loss=0.0771]


Test accuracy: 95.58%


Epoch 1:  81%|████████▏ | 382/469 [01:09<00:15,  5.48batch/s, accuracy=96.09, loss=0.104]


KeyboardInterrupt: ignored