<a href="https://colab.research.google.com/github/NiclasRoer/Event-based-Action-Recognition-using-SNNs/blob/main/SNN_on_MNIST_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install snnTorch into the environment of the running kernel
# (%pip targets the kernel's interpreter, whereas !pip runs whatever pip is on the shell PATH).
%pip install snntorch --quiet

[?25l[K     |███▍                            | 10 kB 17.3 MB/s eta 0:00:01[K     |██████▉                         | 20 kB 7.7 MB/s eta 0:00:01[K     |██████████▎                     | 30 kB 10.4 MB/s eta 0:00:01[K     |█████████████▊                  | 40 kB 4.8 MB/s eta 0:00:01[K     |█████████████████▏              | 51 kB 4.6 MB/s eta 0:00:01[K     |████████████████████▋           | 61 kB 5.4 MB/s eta 0:00:01[K     |████████████████████████        | 71 kB 5.8 MB/s eta 0:00:01[K     |███████████████████████████▌    | 81 kB 5.9 MB/s eta 0:00:01[K     |███████████████████████████████ | 92 kB 6.5 MB/s eta 0:00:01[K     |████████████████████████████████| 95 kB 2.4 MB/s 
[?25h

In [None]:
import torch, torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import utils 
from snntorch import backprop
import snntorch.functional as SF
import plotly.express as px

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Seed torch's global RNG so weight initialisation and batch shuffling are repeatable.
SEED = 0
torch.manual_seed(SEED)

batch_size = 128
data_path='/data/mnist'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define a transform
# Resize/Grayscale force 28x28 single-channel images; Normalize with mean 0 and
# std 1 leaves the [0, 1] values produced by ToTensor unchanged.
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

# Downloads the dataset into data_path on first use.
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
# Only the training set is shuffled: evaluation averages over every test sample,
# so shuffling the test loader would add randomness without changing the result.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)




# Settings shared by every leaky integrate-and-fire (LIF) layer below.
beta = 0.9  # neuron decay rate
spike_grad = surrogate.fast_sigmoid()  # surrogate gradient passed to each LIF layer

# Layer stack, listed in execution order:
# two conv -> pool -> LIF stages followed by a flattened linear read-out with an output LIF layer.
# For 28x28 inputs the feature map shrinks 28 -> 24 -> 12 -> 8 -> 4, hence 16*4*4 linear inputs.
layers = [
    nn.Conv2d(1, 8, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.Conv2d(8, 16, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
    nn.Flatten(),
    nn.Linear(16*4*4, 10),
    # output=True: presumably what makes the network return (spikes, membrane) -- confirm in snnTorch docs
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
]

#  Initialize Network
net = nn.Sequential(*layers).to(device)


def forward_pass(net, data, num_steps):
  """Present the same input to `net` for `num_steps` steps and collect its output spikes.

  Args:
    net: callable network returning a `(spikes, membrane)` pair per call.
    data: input batch, passed unchanged to `net` at every step.
    num_steps: number of times `net` is called.

  Returns:
    The per-step spike tensors stacked along a new leading (time) dimension.
  """
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  recorded = []
  for _ in range(num_steps):
      spikes, _membrane = net(data)  # membrane output is not recorded
      recorded.append(spikes)

  return torch.stack(recorded)

# Optimiser and spike-count loss used for the convolutional network.
optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

num_epochs = 1
num_steps = 25  # run for 25 time steps

# Per-iteration loss, and accuracy sampled every 25 iterations.
loss_hist = []
acc_hist = []

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)

        # Forward simulation over num_steps and loss on the recorded output spikes
        net.train()
        spk_rec = forward_pass(net, data, num_steps)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        batch_loss = loss_val.item()
        loss_hist.append(batch_loss)

        # print every 25 iterations
        if i % 25 == 0:
          print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {batch_loss:.2f}")

          # check accuracy on a single batch (the training batch just processed)
          acc = SF.accuracy_rate(spk_rec, targets)
          acc_hist.append(acc)
          print(f"Accuracy: {acc * 100:.2f}%\n")

        # uncomment for faster termination
        # if i == 150:
        #     break


Epoch 0, Iteration 0 
Train Loss: 2.22
Accuracy: 13.28%

Epoch 0, Iteration 25 
Train Loss: 0.66
Accuracy: 50.00%

Epoch 0, Iteration 50 
Train Loss: 0.43
Accuracy: 82.03%

Epoch 0, Iteration 75 
Train Loss: 0.33
Accuracy: 88.28%

Epoch 0, Iteration 100 
Train Loss: 0.33
Accuracy: 86.72%

Epoch 0, Iteration 125 
Train Loss: 0.28
Accuracy: 91.41%

Epoch 0, Iteration 150 
Train Loss: 0.25
Accuracy: 92.19%

Epoch 0, Iteration 175 
Train Loss: 0.24
Accuracy: 95.31%

Epoch 0, Iteration 200 
Train Loss: 0.22
Accuracy: 96.09%

Epoch 0, Iteration 225 
Train Loss: 0.24
Accuracy: 90.62%

Epoch 0, Iteration 250 
Train Loss: 0.19
Accuracy: 96.09%

Epoch 0, Iteration 275 
Train Loss: 0.19
Accuracy: 96.09%

Epoch 0, Iteration 300 
Train Loss: 0.17
Accuracy: 95.31%

Epoch 0, Iteration 325 
Train Loss: 0.19
Accuracy: 94.53%

Epoch 0, Iteration 350 
Train Loss: 0.17
Accuracy: 97.66%

Epoch 0, Iteration 375 
Train Loss: 0.21
Accuracy: 92.97%

Epoch 0, Iteration 400 
Train Loss: 0.18
Accuracy: 96.09%

Ep

In [None]:
num_epochs = 3

# training loop: one backprop.BPTT call over the whole train_loader per epoch
for epoch in range(num_epochs):
    avg_loss = backprop.BPTT(
        net,
        train_loader,
        num_steps=num_steps,
        optimizer=optimizer,
        criterion=loss_fn,
        time_var=False,  # presumably marks the input as static over time -- confirm in snnTorch docs
        device=device,
    )

    print(f"Epoch {epoch}, Train Loss: {avg_loss.item():.2f}")

def test_accuracy(data_loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    data_loader = iter(data_loader)
    for data, targets in data_loader:
      data = data.to(device)
      targets = targets.to(device)
      spk_rec = forward_pass(net, data, num_steps)

      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
      total += spk_rec.size(1)

  return acc/total
# Report accuracy over the full test loader, scaled from a fraction to a percentage.
print(f"Test set accuracy: {test_accuracy(test_loader, net, num_steps)*100:.3f}%")

Epoch 0, Train Loss: 0.14
Epoch 1, Train Loss: 0.12
Epoch 2, Train Loss: 0.11
Test set accuracy: 98.170%


In [None]:
import torch.nn.functional as F

# Define Network
class Net(nn.Module):
    """Two-layer fully connected spiking network with learnable decay rates.

    Layout: Linear(784 -> 300) -> Leaky -> Linear(300 -> 10) -> Leaky.
    The first Leaky layer starts from a single shared decay rate (0.9); the
    second starts from one random decay rate per output neuron. Both are
    created with learn_beta=True.
    """

    def __init__(self):
        super().__init__()

        num_inputs = 784
        num_hidden = 300
        num_outputs = 10
        spike_grad = surrogate.fast_sigmoid()

        # global decay rate for all leaky neurons in layer 1
        beta1 = 0.9
        # independent decay rate for each leaky neuron in layer 2: [0, 1)
        # Created on the CPU; presumably moved along with the module by .to(device)
        # because learn_beta=True registers it as a parameter -- confirm in snnTorch docs.
        beta2 = torch.rand((num_outputs), dtype = torch.float)

        # Init layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta1, spike_grad=spike_grad, learn_beta=True)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta2, spike_grad=spike_grad,learn_beta=True)

    def forward(self, x, steps=None):
        """Simulate the network on a static input and record the output layer.

        Args:
            x: input batch; everything after dim 0 is flattened before fc1.
            steps: number of simulation steps. When None (the default) the
                module-level `num_steps` is used, which keeps existing
                `net(data)` callers working unchanged.

        Returns:
            Tuple `(spikes, membrane)` of the second Leaky layer, each stacked
            along a new leading (time) dimension.
        """
        n_steps = num_steps if steps is None else steps

        # reset hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        # The input does not change between steps, so the first projection is
        # computed once instead of once per time step.
        cur1 = self.fc1(x.flatten(1))

        for _ in range(n_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec), torch.stack(mem2_rec)

# Load the network onto CUDA if available
net = Net().to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

num_epochs = 1
num_steps = 100  # simulation time steps per forward pass (read as a global by Net.forward)

# Per-iteration loss, and accuracy sampled every 25 iterations.
loss_hist = []
acc_hist = []

# training loop
for epoch in range(num_epochs):
    # Nothing inside the loop leaves train mode, so setting it once per epoch is enough.
    net.train()
    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # Net.forward runs the whole simulation; only the spike record is needed for the loss
        spk_rec, _ = net(data)
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # print every 25 iterations
        if i % 25 == 0:
          print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")

          # check accuracy on a single batch: reuses the spikes already produced for
          # this training batch, so no eval-mode switch is involved
          acc = SF.accuracy_rate(spk_rec, targets)
          acc_hist.append(acc)
          print(f"Accuracy: {acc * 100:.2f}%\n")

        # uncomment for faster termination
        # if i == 150:
        #     break


Epoch 0, Iteration 0 
Train Loss: 10.00
Accuracy: 9.38%

Epoch 0, Iteration 25 
Train Loss: 2.99
Accuracy: 79.69%

Epoch 0, Iteration 50 
Train Loss: 1.60
Accuracy: 87.50%

Epoch 0, Iteration 75 
Train Loss: 1.20
Accuracy: 89.84%

Epoch 0, Iteration 100 
Train Loss: 1.13
Accuracy: 90.62%

Epoch 0, Iteration 125 
Train Loss: 0.83
Accuracy: 94.53%

Epoch 0, Iteration 150 
Train Loss: 0.85
Accuracy: 93.75%

Epoch 0, Iteration 175 
Train Loss: 0.94
Accuracy: 92.19%

Epoch 0, Iteration 200 
Train Loss: 0.72
Accuracy: 96.88%

Epoch 0, Iteration 225 
Train Loss: 0.85
Accuracy: 90.62%

Epoch 0, Iteration 250 
Train Loss: 0.81
Accuracy: 94.53%

Epoch 0, Iteration 275 
Train Loss: 0.92
Accuracy: 94.53%

Epoch 0, Iteration 300 
Train Loss: 1.34
Accuracy: 90.62%

Epoch 0, Iteration 325 
Train Loss: 1.10
Accuracy: 93.75%

Epoch 0, Iteration 350 
Train Loss: 1.14
Accuracy: 92.97%

Epoch 0, Iteration 375 
Train Loss: 0.99
Accuracy: 95.31%

Epoch 0, Iteration 400 
Train Loss: 0.94
Accuracy: 92.97%

Ep

In [None]:
def test_accuracy(data_loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    data_loader = iter(data_loader)
    for data, targets in data_loader:
      data = data.to(device)
      targets = targets.to(device)
      spk_rec, _ = net(data)

      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
      total += spk_rec.size(1)

  return acc/total

# lif1's beta is formatted with :.3f, which relies on it holding a single value.
print(f"Trained decay rate of the first layer: {net.lif1.beta:.3f}\n")
# lif2's beta is interpolated without a format spec, so its full repr is printed.
print(f"Trained decay rates of the second layer: {net.lif2.beta}")
# Accuracy over the full test loader, scaled from a fraction to a percentage.
print(f"Test set accuracy: {test_accuracy(test_loader, net, num_steps)*100:.3f}%")

Trained decay rate of the first layer: 0.977

Trained decay rates of the second layer: Parameter containing:
tensor([0.1856, 0.2179, 0.1685, 0.1878, 0.1592, 0.7038, 0.5146, 0.5724, 0.6376,
        0.7811], requires_grad=True)
Test set accuracy: 94.660%
