In [6]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import snntorch.functional as SF


batch_size = 128
data_path = '/tmp/data/mnist'
seed = 42  # seeds torch's global RNG: DataLoader shuffling, weight init and the random per-neuron betas

# Make the run reproducible (provided the cells are executed top to bottom).
torch.manual_seed(seed)

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


# Define a transform.
# MNIST digits are already 28x28 single-channel images, so Resize/Grayscale are
# effectively no-ops here, and Normalize((0,), (1,)) computes (x - 0) / 1, i.e. it
# leaves the [0, 1] values produced by ToTensor unchanged. They are kept so the
# same pipeline can be reused for other image datasets.
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders: only training benefits from shuffling; the test set is
# iterated in a fixed order so evaluation results are deterministic.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)


In [7]:
import torch.nn.functional as F

# Define Network
class Net(nn.Module):
    """Two-layer fully-connected spiking network of leaky integrate-and-fire neurons.

    Layer 1 (784 -> 300) uses one fixed decay rate shared by every hidden neuron.
    Layer 2 (300 -> 10) gives each output neuron its own decay rate, drawn
    uniformly from [0, 1) at construction time and learned during training.

    Args:
        num_steps: number of simulation time steps per forward pass.
    """

    def __init__(self, num_steps):
        super().__init__()

        num_inputs = 784  # number of inputs (flattened pixels per sample)
        num_hidden = 300  # number of hidden neurons
        num_outputs = 10  # number of classes (i.e., output neurons)

        beta1 = 0.9  # global decay rate for all leaky neurons in layer 1
        # Independent decay rate for each leaky neuron in layer 2, drawn from [0, 1).
        # Uses torch's global RNG, so seed it beforehand for a reproducible init.
        beta2 = torch.rand(num_outputs, dtype=torch.float)

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta1)  # not a learnable decay rate
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta2, learn_beta=True)  # learnable decay rate
        self.num_steps = num_steps

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` time steps on a static input.

        Args:
            x: input batch. Every dimension after the first is flattened, so each
               sample must contain 784 values (e.g. (batch, 1, 28, 28) — assumed
               from the MNIST loader).

        Returns:
            Tuple ``(spk2_rec, mem2_rec)``: output-layer spikes and membrane
            potentials, each stacked along a new leading time dimension of
            length ``num_steps``.
        """
        mem1 = self.lif1.init_leaky()  # reset/init hidden states at t=0
        mem2 = self.lif2.init_leaky()  # reset/init hidden states at t=0
        spk2_rec = []  # record output spikes
        mem2_rec = []  # record output hidden states

        # The same input is presented at every time step, so the first layer's
        # input current is time-invariant: compute it once rather than num_steps times.
        cur1 = self.fc1(x.flatten(1))

        for _ in range(self.num_steps):  # loop over time
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)  # record spikes
            mem2_rec.append(mem2)  # record membrane

        return torch.stack(spk2_rec), torch.stack(mem2_rec)

# Build the network and move it to the device selected in the config cell
# (CUDA, MPS or CPU, whichever is available first).
num_steps = 25  # simulation time steps per forward pass
net = Net(num_steps).to(device)

In [11]:
optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
# Spike-count MSE: target spike counts are correct_rate * num_steps for the
# correct class and incorrect_rate * num_steps for every other class.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

num_epochs = 1  # run for 1 epoch - each data sample is seen only once
log_every = 25  # print training metrics every `log_every` iterations
max_iters = 150  # stop each epoch after this iteration index; set to None for a full epoch

loss_hist = []  # record loss over iterations
acc_hist = []  # record training-batch accuracy at each logging step

# Nothing in this loop runs the network in eval mode, so set train mode once.
net.train()

# training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        spk_rec, _ = net(data)  # forward pass: output spikes stacked over time
        loss_val = loss_fn(spk_rec, targets)  # loss calculation
        optimizer.zero_grad()  # null gradients
        loss_val.backward()  # calculate gradients
        optimizer.step()  # update weights

        loss = loss_val.item()
        loss_hist.append(loss)  # store loss

        if i % log_every == 0:
            print(f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss:.2f}")

            # Accuracy of the current *training* batch, computed from the spikes
            # of the forward pass above (not a held-out evaluation).
            acc = SF.accuracy_rate(spk_rec, targets)
            acc_hist.append(acc)
            print(f"Accuracy: {acc * 100:.2f}%\n")

        # Early stop for quick experiments; this only exits the inner (batch) loop.
        if max_iters is not None and i >= max_iters:
            break

Epoch 0, Iteration 0 
Train Loss: 0.10
Accuracy: 97.66%

Epoch 0, Iteration 25 
Train Loss: 0.11
Accuracy: 96.09%

Epoch 0, Iteration 50 
Train Loss: 0.09
Accuracy: 96.88%

Epoch 0, Iteration 75 
Train Loss: 0.07
Accuracy: 96.88%

Epoch 0, Iteration 100 
Train Loss: 0.06
Accuracy: 97.66%

Epoch 0, Iteration 125 
Train Loss: 0.11
Accuracy: 95.31%

Epoch 0, Iteration 150 
Train Loss: 0.07
Accuracy: 98.44%

Epoch 0, Iteration 175 
Train Loss: 0.09
Accuracy: 96.88%

Epoch 0, Iteration 200 
Train Loss: 0.11
Accuracy: 96.09%

Epoch 0, Iteration 225 
Train Loss: 0.10
Accuracy: 95.31%

Epoch 0, Iteration 250 
Train Loss: 0.08
Accuracy: 97.66%

Epoch 0, Iteration 275 
Train Loss: 0.07
Accuracy: 99.22%

Epoch 0, Iteration 300 
Train Loss: 0.12
Accuracy: 93.75%

Epoch 0, Iteration 325 
Train Loss: 0.07
Accuracy: 98.44%

Epoch 0, Iteration 350 
Train Loss: 0.09
Accuracy: 95.31%

Epoch 0, Iteration 375 
Train Loss: 0.07
Accuracy: 97.66%

Epoch 0, Iteration 400 
Train Loss: 0.09
Accuracy: 96.88%

Ep

In [12]:
# lif1 was built without learn_beta, so its decay rate is fixed, not trained.
print(f"Fixed (non-learnable) decay rate of the first layer: {float(net.lif1.beta):.3f}\n")

# lif2 learns one decay rate per output neuron; format the values instead of
# printing the raw Parameter repr (which also shows device and requires_grad).
beta2_trained = ", ".join(f"{b:.3f}" for b in net.lif2.beta.detach().cpu().tolist())
print(f"Trained decay rates of the second layer: [{beta2_trained}]")

Trained decay rate of the first layer: 0.900

Trained decay rates of the second layer: Parameter containing:
tensor([0.9363, 0.7177, 0.4824, 0.8845, 0.9742, 0.4527, 0.5148, 0.7540, 0.9218,
        0.9530], device='mps:0', requires_grad=True)
