<a href="https://colab.research.google.com/github/AmirHAbbasi/AmirHAbbasi.github.io/blob/main/examples/quickstart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!pip install snntorch --quiet

In [6]:
import torch, torch.nn as nn
import snntorch as snn

## Data Loading
Define the variables used for data loading.

In [7]:
# Mini-batch size used when building the train and test DataLoaders
batch_size = 128

# Directory the MNIST files are downloaded to and read from
data_path = '/tmp/data/mnist'

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Load MNIST dataset.

In [8]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Preprocessing applied to every image before it reaches the network
transform = transforms.Compose([
    transforms.Resize((28, 28)),       # force a 28x28 spatial size
    transforms.Grayscale(),            # force a single channel
    transforms.ToTensor(),             # convert the image to a tensor
    transforms.Normalize((0,), (1,)),  # mean 0 / std 1, i.e. (x - 0) / 1: values are left unchanged
])

# Train / test splits; files are downloaded into `data_path` if they are not there yet
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
# Training batches are reshuffled every epoch.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps test runs deterministic.
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

In [9]:
from snntorch import surrogate

beta = 0.9  # neuron decay rate
spike_grad = surrogate.fast_sigmoid()  # fast sigmoid surrogate gradient

# Alternative, currently disabled, deeper convolutional SNN (kept for reference):
#
# net = nn.Sequential(nn.Conv2d(1, 8, 5),   # Input: 1 channel (grayscale), Output: 8 feature maps, Kernel: 5x5
#                     nn.MaxPool2d(2),      # 2x2 max pooling
#                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
#                     nn.Conv2d(8, 16, 5),  # Input: 8 channels, Output: 16 feature maps, Kernel: 5x5
#                     nn.MaxPool2d(2),
#                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
#                     nn.Flatten(),
#                     nn.Linear(16*4*4, 10),  # fully connected
#                     snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
#                     ).to(device)

# Geometry of the single conv layer. The Linear layer's input size is derived
# from these values instead of being hard-coded.
image_size = 28                               # input images are 28x28 with 1 channel
conv_channels = 2                             # number of conv filters
conv_kernel = 2                               # 2x2 kernel, stride 1, no padding
conv_out_size = image_size - conv_kernel + 1  # 28 - 2 + 1 = 27

net = nn.Sequential(
    # Conv2D layer: 2 filters, 2x2 kernel, input: 1 channel, 28x28
    nn.Conv2d(1, conv_channels, kernel_size=conv_kernel, stride=1, padding=0),

    # Leaky Integrate-and-Fire neuron (replaces ReLU activation)
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),

    # Flatten layer
    nn.Flatten(),

    # Dense layer: 2 * 27 * 27 = 1458 conv features to 10 output units
    nn.Linear(conv_channels * conv_out_size * conv_out_size, 10),

    # Output LIF neuron (replaces softmax); with output=True the network
    # call yields a (spikes, membrane potential) pair
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
).to(device)



## Define the Forward Pass


In [10]:
from snntorch import utils

def forward_pass(net, data, num_steps):
  """Simulate `net` on `data` for `num_steps` time steps and return the output spikes.

  The same `data` is presented to the network at every step. The hidden
  states of the network are reset via `utils.reset(net)` before the first step.

  Returns the per-step spike outputs stacked along a new leading (time) dimension.
  """
  utils.reset(net)  # reset/initialize hidden states for all LIF neurons in net

  # Lazily run one forward step per time step; each call yields (spikes, membrane)
  per_step_outputs = (net(data) for _ in range(num_steps))

  # Keep only the spikes of every step and stack them over time
  return torch.stack([spikes for spikes, _membrane in per_step_outputs])

Define the optimizer and loss function. Here, we use the MSE Count Loss, which counts up the total number of output spikes at the end of the simulation run. The correct class has a target firing rate of 80% of all time steps, and incorrect classes are set to 20%.

In [11]:
import snntorch.functional as SF

# Adam over all parameters of the network
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=2e-3,
    betas=(0.9, 0.999),
)

# MSE spike-count loss: the correct class targets firing on 80% of the time
# steps, every other class on 20%
loss_fn = SF.mse_count_loss(
    correct_rate=0.8,
    incorrect_rate=0.2,
)

## Training Loop

Now for the training loop. The predicted class is the output neuron with the highest firing rate, i.e., a rate-coded output. For simplicity, accuracy is only measured on the current training batch (every 25 iterations). The loop has the same structure as a standard PyTorch training loop.

In [12]:
num_epochs = 1  # run for 1 epoch - each data sample is seen only once
num_steps = 25  # run for 25 time steps

loss_hist = []  # training loss, one entry per iteration
acc_hist = []   # batch accuracy, one entry per logged iteration (every 25th)

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)

        net.train()

        # forward-pass over all time steps, then loss on the recorded spikes
        spk_rec = forward_pass(net, data, num_steps)
        loss_val = loss_fn(spk_rec, targets)

        # standard PyTorch update: clear old gradients, backpropagate, step
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        batch_loss = loss_val.item()
        loss_hist.append(batch_loss)

        # print every 25 iterations
        if i % 25 == 0:
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {batch_loss:.2f}")

            # check accuracy on the batch that was just used for the update
            acc = SF.accuracy_rate(spk_rec, targets)
            acc_hist.append(acc)
            print(f"Accuracy: {acc * 100:.2f}%\n")

        # uncomment for faster termination
        # if i == 150:
        #     break


Epoch 0, Iteration 0 
Train Loss: 2.45
Accuracy: 8.59%

Epoch 0, Iteration 25 
Train Loss: 0.69
Accuracy: 51.56%

Epoch 0, Iteration 50 
Train Loss: 0.47
Accuracy: 76.56%

Epoch 0, Iteration 75 
Train Loss: 0.40
Accuracy: 83.59%

Epoch 0, Iteration 100 
Train Loss: 0.37
Accuracy: 85.16%

Epoch 0, Iteration 125 
Train Loss: 0.40
Accuracy: 82.81%

Epoch 0, Iteration 150 
Train Loss: 0.36
Accuracy: 84.38%

Epoch 0, Iteration 175 
Train Loss: 0.33
Accuracy: 85.94%

Epoch 0, Iteration 200 
Train Loss: 0.32
Accuracy: 87.50%

Epoch 0, Iteration 225 
Train Loss: 0.33
Accuracy: 87.50%

Epoch 0, Iteration 250 
Train Loss: 0.32
Accuracy: 86.72%

Epoch 0, Iteration 275 
Train Loss: 0.30
Accuracy: 92.19%

Epoch 0, Iteration 300 
Train Loss: 0.40
Accuracy: 85.94%

Epoch 0, Iteration 325 
Train Loss: 0.33
Accuracy: 91.41%

Epoch 0, Iteration 350 
Train Loss: 0.34
Accuracy: 87.50%

Epoch 0, Iteration 375 
Train Loss: 0.31
Accuracy: 89.06%

Epoch 0, Iteration 400 
Train Loss: 0.33
Accuracy: 89.06%

Epo

## More control over your model
If you are simulating more complex architectures, such as residual nets, then your best bet is to wrap the network up in a class as shown below. This time, we will explicitly use the membrane potential, `mem`, and let `init_hidden` default to false.

For the sake of speed, we'll just simulate a fully-connected SNN, but this can be generalized to other network types (e.g., Convs).

In addition, let's make the neuron decay rate, `beta`, a learnable parameter in the output layer. The first layer has a single, fixed decay rate shared across all of its neurons. Each neuron in the second layer has its own independent, learnable decay rate, initialized uniformly at random in [0, 1). The decay is clipped between [0,1].

In [13]:
import torch.nn.functional as F

# Define Network
class Net(nn.Module):
    """Fully-connected spiking network: Linear(784 -> 300) -> Leaky -> Linear(300 -> 10) -> Leaky.

    The hidden layer uses one fixed decay rate (0.9). The output layer has one
    learnable decay rate per output neuron, initialised with `torch.rand`.

    Args:
        num_steps: number of simulation time steps per forward call. When left
            as None (the default), the module-level variable `num_steps` is
            looked up each time `forward` runs, which is the original behaviour
            and requires that variable to exist in the notebook namespace.
    """

    def __init__(self, num_steps=None):
        super().__init__()

        num_inputs = 784 # number of inputs
        num_hidden = 300 # number of hidden neurons
        num_outputs = 10 # number of classes (i.e., output neurons)

        beta1 = 0.9 # global decay rate for all leaky neurons in layer 1
        beta2 = torch.rand((num_outputs), dtype = torch.float) # independent decay rate for each leaky neuron in layer 2: [0, 1)

        # Optional explicit step count; None means "use the global `num_steps`"
        self.num_steps = num_steps

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta1) # not a learnable decay rate
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta2, learn_beta=True) # learnable decay rate

    def forward(self, x):
        """Simulate the network on the static input `x` for a number of time steps.

        `x` is flattened from dimension 1 onwards before the first linear layer.

        Returns:
            A pair `(spikes, membrane)` of the output layer, each stacked over
            time along a new leading dimension.
        """
        # Explicit step count if one was given at construction, else the global
        steps = self.num_steps if self.num_steps is not None else num_steps

        mem1 = self.lif1.init_leaky() # reset/init hidden states at t=0
        mem2 = self.lif2.init_leaky() # reset/init hidden states at t=0
        spk2_rec = [] # record output spikes
        mem2_rec = [] # record output hidden states

        # The input does not change over time, so flatten it once up front
        x_flat = x.flatten(1)

        for _ in range(steps): # loop over time
            cur1 = self.fc1(x_flat)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2) # record spikes
            mem2_rec.append(mem2) # record membrane

        return torch.stack(spk2_rec), torch.stack(mem2_rec)

# Build the class-based network, then move it onto `device` (CUDA if available)
net = Net()
net = net.to(device)

In [14]:
# `net` now refers to the class-based model, so build a fresh optimizer over
# its parameters and a fresh loss function
optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

num_epochs = 1 # run for 1 epoch - each data sample is seen only once
num_steps = 25  # run for 25 time steps

loss_hist = [] # record loss over iterations
acc_hist = [] # record accuracy over logged iterations (every 25th)

# Training mode is set once. Nothing inside the loop switches the mode:
# the accuracy below is computed from spikes that were already recorded
# during the training forward pass, so toggling eval mode there had no effect.
net.train()

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        spk_rec, _ = net(data) # forward-pass; membrane recording is not needed here
        loss_val = loss_fn(spk_rec, targets) # loss calculation
        optimizer.zero_grad() # null gradients
        loss_val.backward() # calculate gradients
        optimizer.step() # update weights
        loss_hist.append(loss_val.item()) # store loss

        # print every 25 iterations
        if i % 25 == 0:
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")

            # check accuracy on the batch that was just used for the update
            acc = SF.accuracy_rate(spk_rec, targets)
            acc_hist.append(acc)
            print(f"Accuracy: {acc * 100:.2f}%\n")

        # uncomment for faster termination
        # if i == 150:
        #     break


Epoch 0, Iteration 0 
Train Loss: 2.50
Accuracy: 13.28%

Epoch 0, Iteration 25 
Train Loss: 0.56
Accuracy: 81.25%

Epoch 0, Iteration 50 
Train Loss: 0.40
Accuracy: 82.03%

Epoch 0, Iteration 75 
Train Loss: 0.30
Accuracy: 89.84%

Epoch 0, Iteration 100 
Train Loss: 0.23
Accuracy: 92.19%

Epoch 0, Iteration 125 
Train Loss: 0.21
Accuracy: 93.75%

Epoch 0, Iteration 150 
Train Loss: 0.21
Accuracy: 89.06%

Epoch 0, Iteration 175 
Train Loss: 0.18
Accuracy: 90.62%

Epoch 0, Iteration 200 
Train Loss: 0.18
Accuracy: 91.41%

Epoch 0, Iteration 225 
Train Loss: 0.17
Accuracy: 93.75%

Epoch 0, Iteration 250 
Train Loss: 0.16
Accuracy: 95.31%

Epoch 0, Iteration 275 
Train Loss: 0.16
Accuracy: 92.97%

Epoch 0, Iteration 300 
Train Loss: 0.14
Accuracy: 93.75%

Epoch 0, Iteration 325 
Train Loss: 0.15
Accuracy: 95.31%

Epoch 0, Iteration 350 
Train Loss: 0.13
Accuracy: 95.31%

Epoch 0, Iteration 375 
Train Loss: 0.11
Accuracy: 96.09%

Epoch 0, Iteration 400 
Train Loss: 0.13
Accuracy: 95.31%

Ep

In [16]:
# function to measure accuracy on full test set
def test_accuracy(data_loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    data_loader = iter(data_loader)
    for data, targets in data_loader:
      data = data.to(device)
      targets = targets.to(device)
      spk_rec, _ = net(data)

      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
      total += spk_rec.size(1)

  return acc/total

In [17]:
# Evaluate on the full test set, then report the result as a percentage
test_acc = test_accuracy(test_loader, net, num_steps)
print(f"Test set accuracy: {test_acc*100:.3f}%")

Test set accuracy: 94.890%
