In [1]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

In [2]:
# Mini-batch size used by both loaders, and the on-disk location of MNIST.
batch_size, data_path = 128, './data/mnist'

# Floating-point dtype for tensors allocated by hand (e.g. loss accumulators).
dtype = torch.float

# Use the first visible CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Preprocessing applied to every image: resize to 28 pixels, force a single
# grayscale channel, convert to a tensor, then normalise with mean 0 / std 1.
transfrom = transforms.Compose(
    [
        transforms.Resize(28),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Training split first, then the test split; both are downloaded to
# `data_path` when missing and share the same preprocessing.
minist_train, minist_test = (
    datasets.MNIST(data_path, train=is_train, download=True, transform=transfrom)
    for is_train in (True, False)
)

In [4]:
# Both loaders use identical settings: shuffled batches of `batch_size`,
# the trailing incomplete batch dropped, and 4 worker processes.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
    for split in (minist_train, minist_test)
)

In [5]:
# Network dimensions: flattened 28x28 input, one hidden layer, 10 outputs.
num_inputs, num_hidden, num_outputs = 28 * 28, 100, 10

# Simulation settings: number of time steps per forward pass and the `beta`
# value handed to every snn.Leaky layer.
num_steps, beta = 25, 0.95

In [6]:
class SNNNet(nn.Module):
    """Two-layer fully connected spiking network built from ``snn.Leaky`` layers.

    The layer sizes (``num_inputs``, ``num_hidden``, ``num_outputs``), the
    ``beta`` value and the number of simulated steps (``num_steps``) are all
    read from module-level variables.
    """

    def __init__(self):
        super().__init__()

        # Linear projection followed by a leaky neuron layer, twice.
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.ac1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.ac2 = snn.Leaky(beta=beta)

    def forward(self, inputs):
        """Run the network for ``num_steps`` steps on the same ``inputs``.

        Returns:
            A pair ``(spikes, membranes)`` holding the output layer's spikes
            and membrane values from every step, stacked along a new leading
            dimension.
        """
        # Fresh membrane state for both neuron layers on every call.
        hidden_mem = self.ac1.init_leaky()
        output_mem = self.ac2.init_leaky()

        output_spikes, output_mems = [], []

        for _ in range(num_steps):
            # The identical input is presented at each step.
            hidden_spk, hidden_mem = self.ac1(self.fc1(inputs), hidden_mem)
            out_spk, output_mem = self.ac2(self.fc2(hidden_spk), output_mem)

            output_spikes.append(out_spk)
            output_mems.append(output_mem)

        return torch.stack(output_spikes, dim=0), torch.stack(output_mems, dim=0)


# Instantiate the network and move it to the selected device.
net = SNNNet().to(device)

In [8]:
def print_batch_acc(data, target, train=False):
    """Print the classification accuracy of the global ``net`` on one minibatch.

    The predicted class of each sample is the output unit with the largest
    spike count summed over all time steps.

    Args:
        data: Batch of inputs; flattened to ``(batch, -1)`` before the
            forward pass.
        target: Class labels, one per sample in ``data``.
        train: When True the message is labelled "Train set", otherwise
            "Test set".
    """
    # Flatten with the batch's own leading dimension rather than the global
    # `batch_size`, so a batch of any size is handled.
    flat_inputs = data.view(data.size(0), -1).to(device)

    # This is only a metric: skip building an autograd graph.
    with torch.no_grad():
        output, _ = net(flat_inputs)
        _, idx = output.sum(0).max(1)
        acc = (idx == target.to(device)).float().mean().item()

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    """Print a progress report for the current training iteration.

    Everything is read from module scope: ``epoch``, ``iter_counter``,
    ``counter``, the two loss histories, and the current train/test batches
    (``data``/``targets`` and ``test_data``/``test_targets``).
    """
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")

    # Training batch first, then the test batch.
    for batch, labels, is_train in ((data, targets, True), (test_data, test_targets, False)):
        print_batch_acc(batch, labels, train=is_train)

    print("\n")

In [23]:
# Loss function and optimiser reused by the cells that follow.
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

# Draw one minibatch from each loader; only the training batch is moved to
# the device here.
data, target = (tensor.to(device) for tensor in next(iter(train_loader)))
test_data, test_targets = next(iter(test_loader))

# Single forward pass on the flattened training batch; the final expression
# displays the shape of the recorded membrane values.
spk_rec, mem_rec = net(data.view(batch_size, -1))
mem_rec.shape

torch.Size([25, 128, 10])

In [24]:
# Cross-entropy of the membrane values at every time step, summed over time.
loss_val = torch.stack([loss(step_mem, target) for step_mem in mem_rec]).sum()
loss_val

tensor(52.5580, device='cuda:0', grad_fn=<SumBackward0>)

In [25]:
# Clear stale gradients, backpropagate the summed loss, then apply one
# optimiser update to the network weights.
optimizer.zero_grad()
loss_val.backward()
optimizer.step()

In [26]:
# Report accuracy on the training batch fetched earlier (`data`, `target`).
print_batch_acc(data, target, train=True)

Train set accuracy for a single minibatch: 31.25%


In [None]:
num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

# Keep ONE test-set iterator alive for the whole run. Building a new iterator
# on every training step (`next(iter(test_loader))`) restarts the loader's
# worker processes each time, which is needlessly slow. The loader shuffles,
# so consecutive batches from this iterator are still random test batches.
test_batch = iter(test_loader)

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0

    # Minibatch training loop
    for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass; flatten with the batch's own size instead of the
        # global `batch_size` so any batch size is accepted.
        net.train()
        spk_rec, mem_rec = net(data.view(data.size(0), -1))

        # Sum the per-step cross-entropy of the membrane values over time.
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step_mem in mem_rec:
            loss_val += loss(step_mem, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Evaluate on one test minibatch without tracking gradients.
        with torch.no_grad():
            net.eval()
            try:
                test_data, test_targets = next(test_batch)
            except StopIteration:
                # Test loader exhausted: start another (reshuffled) pass.
                test_batch = iter(test_loader)
                test_data, test_targets = next(test_batch)
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(test_data.size(0), -1))

            # Test set loss, summed over time like the training loss.
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step_mem in test_mem:
                test_loss += loss(step_mem, test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy every 50 iterations.
            # train_printer() reads epoch, iter_counter, counter, the loss
            # histories and the current batches from this scope.
            if counter % 50 == 0:
                train_printer()

        counter += 1
        iter_counter += 1

Epoch 0, Iteration 0
Train Set Loss: 51.05
Test Set Loss: 48.33
Train set accuracy for a single minibatch: 30.47%
Test set accuracy for a single minibatch: 22.66%


Epoch 0, Iteration 50
Train Set Loss: 18.73
Test Set Loss: 15.64
Train set accuracy for a single minibatch: 89.06%
Test set accuracy for a single minibatch: 90.62%


Epoch 0, Iteration 100
Train Set Loss: 13.67
Test Set Loss: 13.59
Train set accuracy for a single minibatch: 89.84%
Test set accuracy for a single minibatch: 92.97%


Epoch 0, Iteration 150
Train Set Loss: 11.64
Test Set Loss: 8.78
Train set accuracy for a single minibatch: 90.62%
Test set accuracy for a single minibatch: 92.19%


Epoch 0, Iteration 200
Train Set Loss: 10.00
Test Set Loss: 10.72
Train set accuracy for a single minibatch: 92.97%
Test set accuracy for a single minibatch: 89.06%


Epoch 0, Iteration 250
Train Set Loss: 9.50
Test Set Loss: 10.60
Train set accuracy for a single minibatch: 91.41%
Test set accuracy for a single minibatch: 87.50%


Epo