<a href="https://colab.research.google.com/github/andrewsiyoon/spiking-seRNN/blob/main/Standard_SNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install snntorch

In [None]:
# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

In [None]:
# Leaky neuron model, overriding the backward pass with a custom function (surrogate gradient descent)
class LeakySurrogate(nn.Module):
  def __init__(self, beta, threshold=1.0):
      super(LeakySurrogate, self).__init__()

      # initialize decay rate beta and threshold
      self.beta = beta
      self.threshold = threshold
      self.spike_op = self.SpikeOperator.apply
  
  # the forward function is called each time we call Leaky
  def forward(self, input_, mem):
    spk = self.spike_op((mem-self.threshold))  # call the Heaviside function
    reset = (spk * self.threshold).detach() # removes spike_op gradient from reset
    mem = self.beta * mem + input_ - reset # Eq (1)
    return spk, mem

  # Forward pass: Heaviside function
  # Backward pass: Override Dirac Delta with the Spike itself
  @staticmethod
  class SpikeOperator(torch.autograd.Function):
      @staticmethod
      def forward(ctx, mem):
          spk = (mem > 0).float() # Heaviside on the forward pass: Eq(2)
          ctx.save_for_backward(spk)  # store the spike for use in the backward pass
          return spk

      @staticmethod
      def backward(ctx, grad_output):
          (spk,) = ctx.saved_tensors  # retrieve the spike 
          grad = grad_output * spk # scale the gradient by the spike: 1/0
          return grad

In [None]:
# Reduce the above neuron using PyTorch

lif1 = LeakySurrogate(beta = 0.9)
lif1 = snn.Leaky(beta = 0.9) #the snn function applies the Spike Operator surrogate gradient by default

SETTING UP MNIST

In [None]:
# dataloader arguments
batch_size = 128
data_path='/data/mnist'

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

In [None]:
# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

DEFINE THE NETWORK

In [None]:
# Network Architecture
num_inputs = 28*28 #We transformed the MNIST dataset to dimensions (28,28) in above code
num_hidden = 1000
num_outputs = 10

# Temporal Dynamics
num_steps = 25
beta = 0.95

In [None]:
# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden) #Applies a linear transformation of all pixels from MNIST (inputs)
        self.lif1 = snn.Leaky(beta=beta) #First spiking neuron layer: integrates weighted input over time and emits a spike if threshold is met
        self.fc2 = nn.Linear(num_hidden, num_outputs) #Applies a linear transformation to the output spikes of fc1
        self.lif2 = snn.Leaky(beta=beta) #Second spiking neuron layer: integrates weighted spikes over time

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        
        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)
        
# Load the network onto CUDA if available
net = Net().to(device)

TRAINING THE SNN

1. Accuracy metric: from a batch of data passed through the network, sum all the spikes from each neuron over time and compare the index of the highest spike count to that of the target. If they match, the network correctly predicted the target. This is a way of assessing the accuracy of the network.

In [None]:
def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data.view(batch_size, -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

2. Define the loss function

In [None]:
loss = nn.CrossEntropyLoss() #Single PyTorch function that takes the softmax of the output layer and generates a loss at the output 
#(see Pytorch documentation for possible parameters)

3. Define the optimizer (Adam for this example)

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

#Can adjust the learning rate (lr), and the betas as well?

4. Single training iteration

In [None]:
#Take the first batch of data and load into CUDA

data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

#Flatten the input data into a vector of size 784 then pass it into the network

spk_rec, mem_rec = net(data.view(batch_size, -1)) #Batch_size was defined earlier as 784, then transformed into a 28*28
#The input is taken across 25 time steps, 128 data samples, and 10 output neurons which you can see through print(mem_rec.size())

#Calculate the loss by:

#Initializing the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)

#Summing loss at every step
for step in range(num_steps):
  loss_val += loss(mem_rec[step], targets)

#Printing the loss
print(f"Training loss: {loss_val.item():.3f}")


5. The above code was for one training iteration, but the real training loop is:

In [None]:
num_epochs = 1 #Modify if you want to train for >1 epoch
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch: #taking a batch of data and loading it into CUDA
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, mem_rec = net(data.view(batch_size, -1)) #flatten the input data into a vector and pass it into the network

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device) #Initialize the total loss value
        for step in range(num_steps): #Sum loss at every step
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad() #Clear previously stored gradients
        loss_val.backward() #Calculate the gradients
        optimizer.step() #Weight update

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer()
            counter += 1
            iter_counter +=1