In [1]:
# Main Setup

# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools
import gc

# dataloader arguments
batch_size = 128
data_path = '/tmp/data/mnist'

# Seed the global torch RNG so weight initialisation, DataLoader shuffling and
# V2 below are reproducible on a fresh kernel (up to nondeterministic GPU
# kernels). Networks trained later in the notebook still differ from each
# other because the RNG advances between them.
seed = 0
torch.manual_seed(seed)

dtype = torch.float

# Prefer CUDA, then Apple-silicon MPS, then fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

# Define a transform.
# ToTensor scales pixels to [0, 1]; Normalize((0,), (1,)) (mean 0, std 1)
# leaves that range unchanged. Resize/Grayscale guard against non-28x28 or
# multi-channel input.
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders.
# drop_last=True guarantees every batch has exactly batch_size samples, which
# the `.view(batch_size, -1)` reshapes used elsewhere in this notebook rely on.
# The test loader is shuffled so that taking the first batch of a fresh
# iterator yields a random test minibatch.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

# Network Architecture
num_inputs = 28*28
num_hidden = 1000
num_outputs = 10

# Temporal Dynamics
num_steps = 25   # number of simulation time steps per input
beta = 0.70      # membrane decay rate passed to every RLeaky layer

V1 = 0.5 # shared recurrent connection (one scalar for all output neurons)
V2 = torch.rand(num_outputs) # unshared recurrent connections (one per neuron); not used by Net as written

# Define Network
class Net(nn.Module):
    """Two-layer recurrent spiking network: fc1 -> RLeaky -> fc2 -> RLeaky.

    Layer sizes and dynamics come from the module-level settings
    ``num_inputs``, ``num_hidden``, ``num_outputs``, ``beta``, ``V1`` and
    ``num_steps``.
    """

    def __init__(self):
        super().__init__()

        # initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)

        # Default RLeaky layer: recurrent connections are initialized using
        # the PyTorch defaults of nn.Linear.
        self.lif1 = snn.RLeaky(beta=beta,
                    linear_features=num_hidden)

        self.fc2 = nn.Linear(num_hidden, num_outputs)

        # With all_to_all=False each neuron has a single connection back to
        # itself, where the output spike is scaled by V. V can be shared
        # between neurons (a scalar such as V1) or unique / unshared per
        # neuron (a vector such as V2). V is learnable by default.
        self.lif2 = snn.RLeaky(beta=beta, all_to_all=False, V=V1)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            x: flattened input batch. Presumably (batch, num_inputs), as
               produced by ``data.view(batch_size, -1)`` in the training
               code; TODO confirm for other callers.

        Returns:
            ``(spk2_rec, mem2_rec)``: the output layer's spikes and membrane
            potentials, stacked along a new leading time dimension of length
            ``num_steps``.
        """
        # Initialize hidden states at t=0
        spk1, mem1 = self.lif1.init_rleaky()
        spk2, mem2 = self.lif2.init_rleaky()

        # The same x is presented at every time step, so its projection is
        # loop-invariant. Compute it once instead of num_steps times: same
        # values, far fewer matmuls in both the forward and backward pass.
        cur1 = self.fc1(x)

        # Record output layer spikes and membrane
        spk2_rec = []
        mem2_rec = []

        # time-loop
        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, spk1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, spk2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        # convert lists to tensors with time as the leading dimension
        return torch.stack(spk2_rec), torch.stack(mem2_rec)


cuda


In [2]:
# Number of independently initialised reference networks to train and save.
num_nets = 10

# CrossEntropyLoss is stateless, so one instance is shared by every network.
loss = nn.CrossEntropyLoss()


def loss_over_time(mem_rec, targets):
    """Return the cross-entropy of the output membrane potential summed over time.

    ``mem_rec`` is indexed as ``mem_rec[step]`` for ``step in range(num_steps)``,
    i.e. time is its leading dimension (as returned by ``Net.forward``).
    The result is a shape-(1,) tensor of ``dtype`` on ``device``.
    """
    total = torch.zeros((1), dtype=dtype, device=device)
    for step in range(num_steps):
        total += loss(mem_rec[step], targets)
    return total


def print_batch_accuracy(data, targets, train=False, model=None):
    """Print a network's accuracy on one minibatch.

    The predicted class is the output neuron with the most spikes summed over
    time. ``model`` defaults to the module-level ``net`` (looked up at call
    time), i.e. whichever network is currently being trained.
    """
    if model is None:
        model = net
    output, _ = model(data.view(batch_size, -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")


def train_printer(
        data, targets, epoch,
        counter, iter_counter,
        loss_hist, test_loss_hist, test_data, test_targets,
        model=None):
    """Print the latest train/test loss and single-minibatch accuracies.

    ``counter`` indexes ``loss_hist`` / ``test_loss_hist``. ``model`` is
    forwarded to ``print_batch_accuracy`` (default: the global ``net``).
    """
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True, model=model)
    print_batch_accuracy(test_data, test_targets, train=False, model=model)
    print("\n")


for i in range(num_nets):
    # Fresh network (on `device`) and optimizer for each reference model.
    net = Net().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

    # One warm-up update on a single random training batch before the main
    # training loop.
    data, targets = next(iter(train_loader))
    data = data.to(device)
    targets = targets.to(device)

    spk_rec, mem_rec = net(data.view(batch_size, -1))
    loss_val = loss_over_time(mem_rec, targets)

    optimizer.zero_grad()
    loss_val.backward()
    optimizer.step()

    # Training Loop
    num_epochs = 1
    loss_hist = []
    test_loss_hist = []
    counter = 0

    # Outer training loop
    for epoch in range(num_epochs):
        iter_counter = 0
        train_batch = iter(train_loader)

        # Minibatch training loop
        for data, targets in train_batch:
            data = data.to(device)
            targets = targets.to(device)

            # forward pass + loss summed over every time step
            net.train()
            spk_rec, mem_rec = net(data.view(batch_size, -1))
            loss_val = loss_over_time(mem_rec, targets)

            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Store loss history for future plotting
            loss_hist.append(loss_val.item())

            # Evaluate on one random test minibatch (test_loader shuffles)
            # without building an autograd graph.
            with torch.no_grad():
                net.eval()
                test_data, test_targets = next(iter(test_loader))
                test_data = test_data.to(device)
                test_targets = test_targets.to(device)

                test_spk, test_mem = net(test_data.view(batch_size, -1))
                test_loss = loss_over_time(test_mem, test_targets)
                test_loss_hist.append(test_loss.item())

                # Print train/test loss/accuracy every 50 iterations
                if counter % 50 == 0:
                    train_printer(
                        data, targets, epoch,
                        counter, iter_counter,
                        loss_hist, test_loss_hist,
                        test_data, test_targets)

            counter += 1
            iter_counter += 1

    filename = f'ref_snn_rleaky_{i}.pth'
    torch.save(net.state_dict(), filename)
    print(f"Finished Net#{i+1}")

Epoch 0, Iteration 0
Train Set Loss: 56.09
Test Set Loss: 53.67
Train set accuracy for a single minibatch: 3.91%
Test set accuracy for a single minibatch: 6.25%


Epoch 0, Iteration 50
Train Set Loss: 17.14
Test Set Loss: 14.70
Train set accuracy for a single minibatch: 84.38%
Test set accuracy for a single minibatch: 86.72%


Epoch 0, Iteration 100
Train Set Loss: 11.55
Test Set Loss: 12.00
Train set accuracy for a single minibatch: 93.75%
Test set accuracy for a single minibatch: 91.41%


Epoch 0, Iteration 150
Train Set Loss: 8.74
Test Set Loss: 10.32
Train set accuracy for a single minibatch: 93.75%
Test set accuracy for a single minibatch: 92.19%


Epoch 0, Iteration 200
Train Set Loss: 10.07
Test Set Loss: 10.82
Train set accuracy for a single minibatch: 93.75%
Test set accuracy for a single minibatch: 92.19%


Epoch 0, Iteration 250
Train Set Loss: 9.69
Test Set Loss: 8.28
Train set accuracy for a single minibatch: 92.97%
Test set accuracy for a single minibatch: 92.19%


Epoch 