# CNN vs spiking-CNN comparison

## Setup

In [4]:
# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

In [5]:
# dataloader arguments
batch_size = 128
data_path='./data/'

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Data and Dataloaders

In [6]:
# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

In [7]:
# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

## Network - CNN

In [6]:
# Network Architecture
num_inputs = 28*28
num_hidden = 1000
num_outputs = 10


In [7]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        #  Initialize Network
        self.net = nn.Sequential(nn.Conv2d(1, 12, 5),
                    nn.MaxPool2d(2),
                    nn.Conv2d(12, 64, 5),
                    nn.MaxPool2d(2),
                    nn.Flatten(),
                    nn.Linear(64*4*4, 10),
                    )

    def forward(self, x):
        return self.net(x)
    
# Load the network onto CUDA if available
cnn_net = CNN().to(device)

### Training CNN

In [24]:
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(cnn_net.parameters(), lr=0.001)
num_epochs = 1


for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optim.zero_grad()

        outputs = cnn_net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optim.step()

        running_loss += loss.item()
        if i%50 == 49:
            print(f'[epoch: {epoch + 1}, iteration: {i + 1:5d}] loss: {running_loss / 50:.3f}]')
            running_loss = 0.0

print('Finished Training!!!')

[epoch: 1, iteration:    50] loss: 0.035]
[epoch: 1, iteration:   100] loss: 0.029]
[epoch: 1, iteration:   150] loss: 0.036]
[epoch: 1, iteration:   200] loss: 0.047]
[epoch: 1, iteration:   250] loss: 0.032]
[epoch: 1, iteration:   300] loss: 0.037]
[epoch: 1, iteration:   350] loss: 0.042]
[epoch: 1, iteration:   400] loss: 0.037]
[epoch: 1, iteration:   450] loss: 0.031]
Finished Training!!!


## Network - SCNN

In [8]:
# neuron and simulation parameters
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5
num_steps = 50

#  Initialize Network
net = nn.Sequential(nn.Conv2d(1, 12, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(12, 64, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(64*4*4, 10),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)

### Training SCNN

In [9]:
def forward_pass(net, num_steps, data):
  mem_rec = []
  spk_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  for step in range(num_steps):
      spk_out, mem_out = net(data)
      spk_rec.append(spk_out)
      mem_rec.append(mem_out)

  return torch.stack(spk_rec), torch.stack(mem_rec)

In [10]:
def batch_accuracy(train_loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    train_loader = iter(train_loader)
    for data, targets in train_loader:
      data = data.to(device)
      targets = targets.to(device)
      spk_rec, _ = forward_pass(net, num_steps, data)

      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
      total += spk_rec.size(1)

  return acc/total

In [12]:
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
loss_fn = SF.ce_rate_loss()
num_epochs = 1
loss_hist = []
# test_acc_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):

    # Training loop
    for data, targets in iter(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        print("data size : ",data.size())
        print("target size : ",targets.size())

        # forward pass
        net.train()
        spk_rec, _ = forward_pass(net, num_steps, data)

        print("spk_rec size : ",spk_rec.size())

        # initialize the loss & sum over time
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # # Test set
        # if counter % 50 == 0:
        #     with torch.no_grad():
        #         net.eval()

        #         # Test set forward pass
        #         test_acc = batch_accuracy(test_loader, net, num_steps)
        #         print(f"Iteration {counter}, Test Acc: {test_acc * 100:.2f}%\n")
        #         test_acc_hist.append(test_acc.item())

        counter += 1

data size :  torch.Size([128, 1, 28, 28])
target size :  torch.Size([128])
spk_rec size :  torch.Size([50, 128, 10])
data size :  torch.Size([128, 1, 28, 28])
target size :  torch.Size([128])
spk_rec size :  torch.Size([50, 128, 10])
data size :  torch.Size([128, 1, 28, 28])
target size :  torch.Size([128])
spk_rec size :  torch.Size([50, 128, 10])
data size :  torch.Size([128, 1, 28, 28])
target size :  torch.Size([128])
spk_rec size :  torch.Size([50, 128, 10])
data size :  torch.Size([128, 1, 28, 28])
target size :  torch.Size([128])
spk_rec size :  torch.Size([50, 128, 10])


KeyboardInterrupt: 

## Comparison

### CNN

In [35]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = cnn_net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the {total} test images: {100 * correct // total} %')

Accuracy of the network on the 9984 test images: 98 %


### SCNN

In [36]:
with torch.no_grad():
    net.eval()

    # Test set forward pass
    test_acc = batch_accuracy(test_loader, net, num_steps)
    print(f"Iteration {counter}, Test Acc: {test_acc * 100:.2f}%\n")
    # test_acc_hist.append(test_acc.item())

Iteration 468, Test Acc: 98.12%

