<center>

**EE6347 : Devices and Technologies for AI and Neuromorphic Computing**

**Assignment 7: Image Classification using a Spiking Neural Network**

Name : ANIRUDH B S ; Roll No. : EE21B019
</center>

The first step is to install the required library, [snntorch](https://snntorch.readthedocs.io/en/latest/).

In [1]:
!pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.1-py2.py3-none-any.whl.metadata (16 kB)
Collecting nir (from snntorch)
  Downloading nir-1.0.4-py3-none-any.whl.metadata (5.8 kB)
Collecting nirtorch (from snntorch)
  Downloading nirtorch-1.0-py3-none-any.whl.metadata (3.6 kB)
Downloading snntorch-0.9.1-py2.py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.3/125.3 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nir-1.0.4-py3-none-any.whl (18 kB)
Downloading nirtorch-1.0-py3-none-any.whl (13 kB)
Installing collected packages: nir, nirtorch, snntorch
Successfully installed nir-1.0.4 nirtorch-1.0 snntorch-0.9.1


Next, we import the required modules from snntorch, PyTorch/torchvision and the numerical/plotting libraries.

In [2]:
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools

  from snntorch import backprop


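Load the EMNIST dataset ('balanced' split, 47 classes) and split the official training set 90/10 into training and validation sets.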

In [3]:
# Define transform to preprocess the data.
# ToTensor scales pixels to [0, 1]; Normalize((0,), (1,)) is the identity, so pixels stay in
# [0, 1], which spikegen.rate needs because it uses each pixel as a firing probability.
# EMNIST images are already 28x28 grayscale, so Resize/Grayscale are defensive no-ops.
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

# Seed everything stochastic (split, weight init, shuffling) so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Loading the 'balanced' split of EMNIST dataset
emnist_dataset_train = datasets.EMNIST(root='./data', split='balanced', train=True, download=True, transform=transform)
emnist_dataset_test = datasets.EMNIST(root='./data', split='balanced', train=False, download=True, transform=transform)

# Splitting into train and validation sets; the explicit generator pins the split itself.
train_val_split = 0.9
train_size = int(train_val_split * len(emnist_dataset_train))
val_size = len(emnist_dataset_train) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    emnist_dataset_train, [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED))

# Creating data loaders.
# Only the training loader drops its last partial batch (every optimiser step sees a full batch).
# The evaluation loaders keep it so that accuracy is measured on every sample, and the test set
# is not shuffled because evaluation order does not matter.
# NOTE(review): the final smaller batch relies on snntorch (0.9.x) Leaky layers re-creating their
# membrane state when the batch size changes -- re-check if the snntorch version changes.
batch_size = 128
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
test_loader = torch.utils.data.DataLoader(emnist_dataset_test, batch_size=batch_size, shuffle=False, drop_last=False)

Downloading https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip to ./data/EMNIST/raw/gzip.zip


100%|██████████| 562M/562M [00:09<00:00, 60.0MB/s]


Extracting ./data/EMNIST/raw/gzip.zip to ./data/EMNIST/raw


In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Neuron and simulation parameters.
# Alternative surrogate tried: surrogate.fast_sigmoid(slope=25).
spike_grad = surrogate.atan(alpha=2.0)
beta = 0.5        # membrane decay rate shared by every Leaky layer
num_steps = 5     # simulation time steps per sample
num_classes = 47  # number of classes in the EMNIST 'balanced' split

# Defining the network architecture. Spatial size per stage:
# 28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8 -> pool2 -> 4, hence 64*4*4 features.
net = nn.Sequential(nn.Conv2d(1, 12, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(12, 64, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(64*4*4, num_classes),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)

# Rate-code one training batch into spike trains with a leading time dimension.
# Alternative encodings tried: spikegen.delta(data, threshold=0.1, padding=False, off_spike=False)
# and spikegen.latency(data, num_steps=num_steps, normalize=True, linear=True).
data, targets = next(iter(train_loader))
spike_data = spikegen.rate(data, num_steps=num_steps)
spike_data = spike_data.to(device)
print(data.size())
print(spike_data.size())
targets = targets.to(device)

# Smoke test: push one time step of the *encoded* batch (already on the network's device)
# through the untrained net, check the output shape, then clear the hidden state again.
utils.reset(net)
with torch.no_grad():
    spk_out, mem_out = net(spike_data[0])
assert spk_out.shape == (data.size(0), num_classes), spk_out.shape
utils.reset(net)

# Forward pass
def forward_pass(net, num_steps, spike_data):
    """Simulate `net` for `num_steps` time steps on `spike_data`.

    Every LIF neuron's hidden state in `net` is reset first. Then
    `spike_data[t]` is fed in for t = 0 .. num_steps - 1.

    Returns:
        (spk_rec, mem_rec): the output spikes and output membrane potentials,
        each stacked along a new leading time dimension.
    """
    utils.reset(net)  # begin each sequence from a clean membrane state
    outputs = [net(spike_data[t]) for t in range(num_steps)]
    spikes = [spk for spk, _ in outputs]
    membranes = [mem for _, mem in outputs]
    return torch.stack(spikes), torch.stack(membranes)

torch.Size([128, 1, 28, 28])
torch.Size([5, 128, 1, 28, 28])


In [5]:
# Spike-count MSE loss. The correct output neuron is pushed to fire on every time step
# (correct_rate=1) and the incorrect neurons to stay silent (incorrect_rate=0).
loss_fn = SF.mse_count_loss(correct_rate=1, incorrect_rate=0)

In [6]:
# Defining the function for accuracy calculation
def batch_accuracy(data_loader, net, num_steps, device=None):
  """Return the fraction of samples in `data_loader` that `net` classifies correctly.

  Each batch is rate-coded with `spikegen.rate` over `num_steps` steps. It is then
  simulated with `forward_pass` and scored with `SF.accuracy_rate`. Per-batch
  accuracies are weighted by batch size, so a smaller final batch is counted correctly.

  Args:
    data_loader: iterable of (data, targets) batches.
    net: spiking network. It is evaluated in eval mode under `torch.no_grad()`, and its
      previous train/eval mode is restored before returning.
    num_steps: number of simulation time steps per sample.
    device: device to evaluate on. Defaults to the device of `net`'s parameters.

  Returns:
    Accuracy in [0, 1]. The accumulated scalar is returned as-is (not cast to float),
    so callers can keep calling `.item()` on it.

  Raises:
    ValueError: if `data_loader` yields no samples.
  """
  if device is None:
    device = next(net.parameters()).device

  was_training = net.training
  net.eval()
  correct = 0
  total = 0
  try:
    with torch.no_grad():
      for data, targets in data_loader:
        spike_data = spikegen.rate(data, num_steps=num_steps).to(device)
        targets = targets.to(device)
        spk_rec, _ = forward_pass(net, num_steps, spike_data)

        n_samples = spk_rec.size(1)  # spk_rec is (time, batch, ...)
        correct += SF.accuracy_rate(spk_rec, targets) * n_samples
        total += n_samples
  finally:
    net.train(was_training)  # don't leave the caller's model stuck in eval mode

  if total == 0:
    raise ValueError("data_loader yielded no samples; cannot compute accuracy")
  return correct / total

Training loop: Adam optimiser, with validation accuracy reported every 100 iterations.

In [7]:
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
num_epochs = 40
eval_every = 100  # training iterations between validation passes
loss_hist = []
val_acc_hist = []
counter = 0

for epoch in range(num_epochs):
    for data, targets in train_loader:
        # Rate-code the batch and move inputs/labels to the training device.
        spike_data = spikegen.rate(data, num_steps=num_steps).to(device)
        targets = targets.to(device)

        # Simulate the network in training mode (a validation pass may have switched it to eval).
        net.train()
        spk_rec, _ = forward_pass(net, num_steps, spike_data)
        loss_val = loss_fn(spk_rec, targets)

        # Backpropagation through the unrolled time steps, then the weight update.
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting.
        loss_hist.append(loss_val.item())

        # Periodic validation; batch_accuracy itself switches to eval mode and disables gradients.
        if counter % eval_every == 0:
            val_acc = batch_accuracy(val_loader, net, num_steps)
            print(f"Iteration {counter}, Val Acc: {val_acc * 100:.2f}%\n")
            val_acc_hist.append(val_acc.item())

        counter += 1

Iteration 0, Val Acc: 2.11%

Iteration 100, Val Acc: 62.68%

Iteration 200, Val Acc: 70.87%

Iteration 300, Val Acc: 74.87%

Iteration 400, Val Acc: 75.47%

Iteration 500, Val Acc: 76.30%

Iteration 600, Val Acc: 78.38%

Iteration 700, Val Acc: 78.97%

Iteration 800, Val Acc: 79.17%

Iteration 900, Val Acc: 79.15%

Iteration 1000, Val Acc: 79.83%

Iteration 1100, Val Acc: 79.87%

Iteration 1200, Val Acc: 80.56%

Iteration 1300, Val Acc: 81.11%

Iteration 1400, Val Acc: 81.12%

Iteration 1500, Val Acc: 80.98%

Iteration 1600, Val Acc: 80.93%

Iteration 1700, Val Acc: 81.84%

Iteration 1800, Val Acc: 81.41%

Iteration 1900, Val Acc: 81.82%

Iteration 2000, Val Acc: 81.99%

Iteration 2100, Val Acc: 81.26%

Iteration 2200, Val Acc: 82.55%

Iteration 2300, Val Acc: 81.49%

Iteration 2400, Val Acc: 81.78%

Iteration 2500, Val Acc: 81.76%

Iteration 2600, Val Acc: 82.20%

Iteration 2700, Val Acc: 81.90%

Iteration 2800, Val Acc: 82.05%

Iteration 2900, Val Acc: 81.96%

Iteration 3000, Val Acc

KeyboardInterrupt: 

In [8]:
# Evaluate the trained network once on the held-out test split.
net.eval()
with torch.no_grad():
    test_acc = batch_accuracy(test_loader, net, num_steps)
print(f"Test Accuracy: {test_acc * 100:.2f}%\n")

Test Accuracy: 83.00%

