In [4]:
!pip install gradio --quiet
! pip install snntorch --quiet

In [5]:
# PyTorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.utils.data import DataLoader
from snntorch import surrogate
from torchvision import datasets, transforms
from torch.optim import Adam
from torch.utils.data import random_split
from snntorch import functional as SF
from snntorch import utils
import torchvision.transforms as transforms

# Additional Imports
import snntorch as snn
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import seaborn as sns
import numpy as np
import time
import os
from tqdm import tqdm
from IPython.display import Video
import gradio as gr
from PIL import Image

In [11]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Settings handed to every leaky integrate-and-fire layer of the classifier.
BETA = 0.5
SPIKE_GRAD = surrogate.fast_sigmoid(slope=25)

# Human-readable label for each predicted class index.
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [12]:
class Classifier(nn.Module):
    """Spiking CNN: two conv + max-pool + LIF stages followed by a linear + LIF readout.

    Args:
        beta: Membrane decay rate passed to every ``snn.Leaky`` layer.
        spike_grad: Surrogate gradient function passed to every ``snn.Leaky`` layer.
        num_steps: Number of simulation steps the (static) input is presented for.
    """

    def __init__(self, beta, spike_grad, num_steps):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        # 16*4*4 corresponds to a 1x28x28 input:
        # 28 -> conv(5) -> 24 -> pool(2) -> 12 -> conv(5) -> 8 -> pool(2) -> 4.
        self.fc1 = nn.Linear(16 * 4 * 4, 10)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.num_steps = num_steps

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` steps on a static input.

        Args:
            x: Input batch; the first dimension is the batch dimension and the
                remaining dimensions must be accepted by ``conv1``
                (one input channel).

        Returns:
            Tuple ``(spk_rec, mem_rec)``: the output-layer spikes and membrane
            potentials of every step, each stacked along a new leading
            dimension of length ``num_steps``.
        """
        spk_rec = []
        mem_rec = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # The same input is presented at every step, so the first layer's
        # input current does not change: compute it once instead of per step.
        cur1 = F.max_pool2d(self.conv1(x), 2)

        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = F.max_pool2d(self.conv2(spk1), 2)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc1(spk2.flatten(1))
            # Feed the updated membrane back in on the next step. Previously the
            # result was stored under a different name, so the readout layer
            # restarted from its initial state on every step.
            # NOTE(review): checkpoints trained with the old stateless readout
            # should be re-evaluated after this change.
            spk3, mem3 = self.lif3(cur3, mem3)
            spk_rec.append(spk3)
            mem_rec.append(mem3)

        return torch.stack(spk_rec), torch.stack(mem_rec)

In [14]:
# Rebuild the network and restore the trained weights for inference.
classifier = Classifier(beta=BETA, spike_grad=SPIKE_GRAD, num_steps=50)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being loaded.
classifier.load_state_dict(
    torch.load("./SNN_Model.pth", map_location=device, weights_only=True)
)
# Switch to evaluation mode; as the last expression this also displays the model.
classifier.eval()

  classifier.load_state_dict(torch.load("./SNN_Model.pth", map_location=torch.device(device)))


Classifier(
  (conv1): Conv2d(1, 3, kernel_size=(5, 5), stride=(1, 1))
  (lif1): Leaky()
  (conv2): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1))
  (lif2): Leaky()
  (fc1): Linear(in_features=256, out_features=10, bias=True)
  (lif3): Leaky()
)

In [15]:
def preprocess_image(image):
    """Turn an input image into a single-item batch tensor for the classifier.

    Args:
        image: A PIL image, or a numpy array (the default format Gradio's
            image input delivers), which is converted to a PIL image first.

    Returns:
        The result of resizing to 28x28, converting to grayscale and to a
        tensor, with a leading batch dimension of size 1 added.
    """
    # The PIL-based torchvision transforms below do not accept numpy arrays.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        # Mean 0 / std 1 leaves the values unchanged; kept so the pipeline
        # stays identical to the original one.
        transforms.Normalize((0,), (1,))
    ])
    return transform(image).unsqueeze(0)

def process_model_output(model_output):
    """Pick, per item, the index with the largest total activation.

    Args:
        model_output: 3-D tensor. Assumed to be laid out as
            (time steps, batch, classes) like the classifier's spike
            record -- TODO confirm against the caller.

    Returns:
        1-D tensor of argmax indices over the last axis, after summing over
        axis 0. Only the first 10 entries along axis 1 are considered, so the
        result has at most 10 elements.
    """
    leading_items = model_output[:, :10, :]
    totals = leading_items.sum(dim=0)
    return totals.argmax(dim=1)

def predict(image):
    """Classify one image and return its class name.

    Args:
        image: Image as accepted by ``preprocess_image``, or ``None`` when
            no image was supplied.

    Returns:
        The entry of ``classes`` for the predicted index, or ``"No Label"``
        when ``image`` is ``None``.
    """
    if image is None:
        # Nothing to classify (e.g. the form was submitted without an image).
        return "No Label"

    image_tensor = preprocess_image(image)
    # Run the input on whichever device the model's weights live on.
    image_tensor = image_tensor.to(next(classifier.parameters()).device)

    with torch.no_grad():
        # The classifier returns (spike record, membrane record); only the
        # spike record is decoded into a label.
        spk_rec, _ = classifier(image_tensor)
        model_labels = process_model_output(spk_rec)

    # Convert the 0-dim tensor to a plain int before indexing the label list.
    return classes[int(model_labels[0])]

In [None]:
# Build the web demo. The image input is requested as a PIL image because the
# torchvision preprocessing in `predict` operates on PIL images, whereas the
# plain "image" shortcut would deliver numpy arrays.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
)
# share=True publishes a temporary public URL; debug=True keeps the cell
# running so errors and logs are shown in the notebook.
interface.launch(share=True, debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://8f55d0f8dae6a6c9f1.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
