## Spiking Neural Network for Image Classification

**MatTheTab**: [Github](https://github.com/MatTheTab)

This notebook is a live demo of a Spiking Neural Network (SNN) for image classification. To start the demo, run all cells and click the Gradio app link printed at the bottom of the notebook.

If you have retrained the model yourself, upload your version of the model to the Colab notebook and, if needed, update the values in **Variables Declarations** to match.

## Imports

In [1]:
!pip install gradio --quiet
!pip install snntorch --quiet


In [2]:
!wget https://github.com/MatTheTab/Spiking-Classifier/raw/main/SNN_Model.pth

--2024-08-30 16:33:53--  https://github.com/MatTheTab/Spiking-Classifier/raw/main/SNN_Model.pth
Resolving github.com (github.com)... 140.82.114.4
Connecting to github.com (github.com)|140.82.114.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/MatTheTab/Spiking-Classifier/main/SNN_Model.pth [following]
--2024-08-30 16:33:53--  https://raw.githubusercontent.com/MatTheTab/Spiking-Classifier/main/SNN_Model.pth
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94220 (92K) [application/octet-stream]
Saving to: ‘SNN_Model.pth’


2024-08-30 16:33:53 (3.79 MB/s) - ‘SNN_Model.pth’ saved [94220/94220]



In [3]:
# PyTorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.utils.data import DataLoader
from snntorch import surrogate
from torchvision import datasets, transforms
from torch.optim import Adam
from torch.utils.data import random_split
from snntorch import functional as SF
from snntorch import utils

# Additional Imports
import snntorch as snn
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import seaborn as sns
import numpy as np
import time
import os
from tqdm import tqdm
from IPython.display import Video
import gradio as gr
from PIL import Image

## Variables Declarations

If you retrained the model from scratch with different settings, update these values to match. You can change them even when using the pretrained model, but doing so may cause unexpected behaviour.

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
BETA = 0.5
SPIKE_GRAD = surrogate.fast_sigmoid(slope=25)
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
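
To make the role of `BETA` more concrete, here is a minimal pure-Python sketch of a single leaky integrate-and-fire neuron. It is a simplification written for illustration, not snnTorch's actual code: the function name `simulate_lif`, the threshold of 1.0, and the reset-by-subtraction rule are assumptions of this sketch. The membrane potential is multiplied by `beta` at every step, so a smaller `beta` makes the neuron forget past input faster.

```python
def simulate_lif(inputs, beta=0.5, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron (illustrative only).

    Assumed dynamics: the membrane potential decays by `beta` each step,
    accumulates the input current, and is reduced by `threshold` on the
    step after a spike (reset by subtraction).
    """
    mem = 0.0
    spikes, trace = [], []
    for current in inputs:
        # Did the neuron spike on the previous step?
        reset = 1.0 if mem > threshold else 0.0
        mem = beta * mem + current - reset * threshold
        spikes.append(1 if mem > threshold else 0)
        trace.append(mem)
    return spikes, trace


# A constant input of 0.6 is below the threshold, but it accumulates
# across steps until the neuron fires.
spikes, trace = simulate_lif([0.6] * 4, beta=0.5)
print(spikes)
```

With `beta=0.0` the neuron keeps no memory between steps, so the same constant input of 0.6 never reaches the threshold.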

## Model

In [5]:
class Classifier(nn.Module):
    def __init__(self, beta, spike_grad, num_steps):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(32 * 5 * 5, 10)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.num_steps = num_steps

    def forward(self, x):
        mem_rec = []
        spk_rec = []
        batch_dim = int(x.shape[0])
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        for step in range(self.num_steps):
            cur1 = F.max_pool2d(self.conv1(x), 2)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = F.max_pool2d(self.conv2(spk1), 2)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc1(spk2.view(batch_dim, 32*5*5))
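            # NOTE: mem3 is never reassigned, so the output layer starts every
            # step from its initial membrane state. Left unchanged so inference
            # behaves the same as the original code.
            spk_out, mem_out = self.lif3(cur3, mem3)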
            spk_rec.append(spk_out)
            mem_rec.append(mem_out)

        return torch.stack(spk_rec), torch.stack(mem_rec)

In [6]:
classifier = Classifier(beta = BETA, spike_grad = SPIKE_GRAD, num_steps = 50)
classifier.load_state_dict(torch.load("./SNN_Model.pth", map_location=device, weights_only=True))
classifier.eval()

Classifier(
  (conv1): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1))
  (lif1): Leaky()
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (lif2): Leaky()
  (fc1): Linear(in_features=800, out_features=10, bias=True)
  (lif3): Leaky()
)
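
The `in_features=800` of `fc1` follows from the layer sizes above, assuming 32x32 inputs as in CIFAR-10. The short sketch below walks through that arithmetic; the helper names (`conv_out`, `pool_out`, `flattened_features`) are made up for this illustration and are not part of the notebook.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Output size of max pooling with stride equal to the kernel size."""
    return (size - kernel) // kernel + 1


def flattened_features(input_size=32, channels=32):
    """Trace the spatial size through conv1 -> pool -> conv2 -> pool."""
    size = conv_out(input_size, kernel=5)  # 32 -> 28
    size = pool_out(size, kernel=2)        # 28 -> 14
    size = conv_out(size, kernel=5)        # 14 -> 10
    size = pool_out(size, kernel=2)        # 10 -> 5
    return channels * size * size          # 32 * 5 * 5


print(flattened_features())  # 800
```

If you retrain on images of a different resolution, the same calculation gives the value needed in `nn.Linear` and in the `view` call inside `forward`.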

## Gradio App

The app below is launched with Gradio. After running the notebook, click the link printed below the code cell. You can test the model by uploading your own photos, but they should depict one of the CIFAR-10 classes (or, if you retrained the model, the classes of your new dataset). Currently supported classes:

["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

In [7]:
def preprocess_image(image):
    # With inputs="image", Gradio passes an H x W x 3 NumPy array by default.
    transform = transforms.Compose([
        transforms.ToTensor(),        # scale to [0, 1], shape (3, H, W)
        transforms.Resize((32, 32)),  # CIFAR-10 resolution
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map to [-1, 1]
    ])
    return transform(image).unsqueeze(0)  # add the batch dimension

def process_model_output(model_output):
    # model_output holds spikes with shape (num_steps, batch, num_classes).
    activation_sums = torch.sum(model_output, dim=0)  # spike count per class
    most_activated_classes = torch.argmax(activation_sums, dim=1)
    return most_activated_classes

def predict(image):
    if image is None:  # nothing was uploaded
        return "No Label"
    image_tensor = preprocess_image(image)
    with torch.no_grad():
        output, _ = classifier(image_tensor)
        model_labels = process_model_output(output)
    # Convert the 0-dim tensor to a plain int before indexing the class list.
    return classes[model_labels[0].item()]
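
The prediction step above is a form of rate decoding: the class whose output neuron spikes most often over all time steps wins. The sketch below shows the same idea on plain nested lists, so it runs without PyTorch. The function name `decode_rate` and the toy spike record are made up for this illustration.

```python
def decode_rate(spk_rec):
    """Pick the most active class per sample from a spike record.

    spk_rec is a nested list indexed as [time_step][sample][class],
    holding 0/1 spikes, mirroring the (num_steps, batch, num_classes)
    tensor returned by the model.
    """
    num_steps = len(spk_rec)
    batch_size = len(spk_rec[0])
    num_classes = len(spk_rec[0][0])
    predictions = []
    for b in range(batch_size):
        # Count spikes per class across all time steps.
        counts = [
            sum(spk_rec[t][b][c] for t in range(num_steps))
            for c in range(num_classes)
        ]
        # Index of the largest count (the first one in case of a tie).
        predictions.append(max(range(num_classes), key=counts.__getitem__))
    return predictions


# Toy record: 3 time steps, 1 sample, 3 classes.
toy_record = [
    [[0, 1, 0]],
    [[0, 1, 1]],
    [[1, 1, 0]],
]
print(decode_rate(toy_record))  # [1]
```

Class 1 fires on all three steps while the others fire once each, so it is selected.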

In [8]:
interface = gr.Interface(fn=predict, inputs="image", outputs="label")
interface.launch(share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://821161014aff9541b2.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


