In [1]:
# imports
import snntorch as snn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import psnn

def train_printer():
    """Print a progress report for the current training iteration.

    Takes no arguments: it reads the notebook globals set by the training
    loops -- ``epoch``, ``iter_counter``, ``counter``, ``loss_hist``,
    ``test_loss_hist``, ``acc_train`` and ``acc_test``. ``counter`` is used to
    index both loss histories, so it must point at the entries appended in
    the current iteration.
    """
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print(f"Train set accuracy for a single minibatch: {acc_train*100:.2f}%")
    # acc_test is measured on a test-set minibatch, so label it as such.
    print(f"Test set accuracy for a single minibatch: {acc_test*100:.2f}%")
    print("\n")

In [2]:
# dataloader arguments
batch_size = 128
data_path='./data/mnist'

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed torch's global RNG so DataLoader shuffling and the random weight
# initialisation in later cells are repeatable on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Define a transform
# Resize to 10x10 so each image flattens to 100 features (num_inputs below).
# Grayscale keeps a single channel (MNIST already has one), and
# Normalize((0,), (1,)) computes (x - 0) / 1, i.e. it leaves ToTensor's values unchanged.
transform = transforms.Compose([
            transforms.Resize((10, 10)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
# drop_last=True guarantees every batch holds exactly batch_size samples, which the
# training loops rely on when they call .view(batch_size, -1).
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

# One sample batch for quick inspection; the training loops rebind data/targets.
data, targets = next(iter(train_loader))
data = data.to(device)
targets = targets.to(device)

In [3]:
# Network Architecture: 10x10 images flattened to 100 inputs -> 50 hidden neurons -> 10 outputs (one per digit)
num_inputs, num_hidden, num_outputs = 100, 50, 10

# Temporal Dynamics: number of simulated time steps per sample, and the
# leak factor handed to every snn.Leaky layer as `beta`
num_steps, beta = 25, 0.95

# Define Network
class Net(nn.Module):
    """Two-layer feed-forward spiking network: Linear -> Leaky -> Linear -> Leaky.

    Layer sizes come from the notebook globals ``num_inputs``, ``num_hidden``
    and ``num_outputs``, the leak factor from ``beta``, and the number of
    simulated time steps from ``num_steps`` (read on every forward call).
    """

    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            x: input to ``fc1``, presented unchanged at every time step (the
                training loop passes a flattened batch, ``(batch_size, -1)``).

        Returns:
            Tuple ``(spk2_rec, mem2_rec)``: the output layer's spikes and
            membrane potentials for every step, stacked along a new leading
            time dimension of length ``num_steps``.
        """
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # x is identical at every step, so its projection is too: compute it once
        # instead of running fc1 (forward and backward) num_steps times.
        cur1 = self.fc1(x)

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Instantiate the snntorch network, then move its parameters to `device`
# (the GPU when one was detected, otherwise the CPU).
net = Net()
net = net.to(device)

In [4]:
# Adam over the snntorch network's parameters (betas spelled out explicitly),
# plus the classification loss the training loop applies at each time step.
optimizer = torch.optim.Adam(net.parameters(), betas=(0.9, 0.999), lr=5e-4)
loss = nn.CrossEntropyLoss()

In [5]:
# Train the snntorch `net`, evaluating on one random test minibatch after every
# training step. epoch, iter_counter, counter, loss_hist, test_loss_hist,
# acc_train and acc_test are read as globals by train_printer().
num_epochs = 1
loss_hist = []
test_loss_hist = []
# global step counter across epochs; train_printer() uses it to index loss_hist/test_loss_hist
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        # flatten each image to one feature vector; relies on drop_last=True so
        # every batch has exactly batch_size samples
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time
        # (cross-entropy on the output membrane potential at each time step)
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Calculate accuracy
        # predicted class = output neuron with the most spikes summed over time (dim 0)
        _, idx = spk_rec.sum(dim=0).max(1)
        acc_train = np.mean((targets == idx).detach().cpu().numpy())

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            # a fresh iterator each step: with shuffle=True this draws a new
            # random test minibatch every iteration
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Calculate accuracy
            _, idx = test_spk.sum(dim=0).max(1)
            acc_test = np.mean((test_targets == idx).detach().cpu().numpy())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer()
            counter += 1
            iter_counter +=1

Epoch 0, Iteration 0
Train Set Loss: 59.05
Test Set Loss: 60.60
Train set accuracy for a single minibatch: 8.59%
Train set accuracy for a single minibatch: 7.81%


Epoch 0, Iteration 50
Train Set Loss: 47.42
Test Set Loss: 47.37
Train set accuracy for a single minibatch: 53.91%
Train set accuracy for a single minibatch: 46.09%


Epoch 0, Iteration 100
Train Set Loss: 38.46
Test Set Loss: 36.31
Train set accuracy for a single minibatch: 69.53%
Train set accuracy for a single minibatch: 77.34%


Epoch 0, Iteration 150
Train Set Loss: 29.68
Test Set Loss: 29.25
Train set accuracy for a single minibatch: 79.69%
Train set accuracy for a single minibatch: 80.47%


Epoch 0, Iteration 200
Train Set Loss: 22.75
Test Set Loss: 25.84
Train set accuracy for a single minibatch: 82.03%
Train set accuracy for a single minibatch: 80.47%


Epoch 0, Iteration 250
Train Set Loss: 20.48
Test Set Loss: 23.14
Train set accuracy for a single minibatch: 87.50%
Train set accuracy for a single minibatch: 82.81%

In [6]:
# The same [num_inputs, num_hidden, num_outputs] layout, built with psnn.
# beta and the threshold are passed as tensors; threshold=1.0 presumably mirrors
# snn.Leaky's default threshold -- verify. random_state=False is forwarded as-is.
model = psnn.SpikingNeuralNetwork(
    [num_inputs, num_hidden, num_outputs],
    beta=torch.tensor(beta),
    threshold=torch.tensor(1.0),
    random_state=False,
).to(device)

In [7]:
# Adam over the psnn model's parameters (same lr/betas as the snntorch run);
# `loss` is rebound to psnn's own loss object.
optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.999), lr=5e-4)
loss = psnn.SNNLoss()

In [8]:
# Train the psnn `model`, evaluating on one random test minibatch after every
# training step. epoch, iter_counter, counter, loss_hist, test_loss_hist,
# acc_train and acc_test are read as globals by train_printer().
num_epochs = 1
loss_hist = []
test_loss_hist = []
# global step counter across epochs; train_printer() uses it to index the loss histories
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass: switch the psnn `model` (the network trained here,
        # not the snntorch `net`) into training mode
        model.train()
        # Flatten to (batch, features), repeat the static input once per time
        # step -> (num_steps, batch, features), then permute to the time-last
        # layout (batch, features, num_steps) that is passed to `model`.
        train_data = data.view(batch_size, -1).repeat(num_steps,1,1).permute(1,2,0)
        spk_rec, mem_rec = model(train_data)
        # psnn.SNNLoss is applied once to the whole membrane-potential record
        loss_val = loss(mem_rec, targets)

        # Calculate accuracy: predicted class = output with the most spikes,
        # assuming spk_rec is time-last like the input (sum over dim 2) -- TODO confirm.
        # .item() yields a Python float, matching the snntorch loop's acc values.
        acc_train = (spk_rec.sum(2).argmax(dim=1) == targets).float().mean().item()

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            model.eval()
            # a fresh iterator each step: with shuffle=True this draws a new
            # random test minibatch every iteration
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass (same time-last layout as training)
            test_data = test_data.view(batch_size, -1).repeat(num_steps,1,1).permute(1,2,0)
            test_spk, test_mem = model(test_data)

            # Test set loss
            test_loss = loss(test_mem, test_targets)
            test_loss_hist.append(test_loss.item())

            # Calculate accuracy
            acc_test = (test_spk.sum(2).argmax(dim=1) == test_targets).float().mean().item()

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer()
            counter += 1
            iter_counter += 1

Epoch 0, Iteration 0
Train Set Loss: 2.58
Test Set Loss: 2.60
Train set accuracy for a single minibatch: 7.03%
Train set accuracy for a single minibatch: 10.16%


Epoch 0, Iteration 50
Train Set Loss: 2.09
Test Set Loss: 2.06
Train set accuracy for a single minibatch: 41.41%
Train set accuracy for a single minibatch: 42.97%


Epoch 0, Iteration 100
Train Set Loss: 1.77
Test Set Loss: 1.80
Train set accuracy for a single minibatch: 61.72%
Train set accuracy for a single minibatch: 51.56%


Epoch 0, Iteration 150
Train Set Loss: 1.59
Test Set Loss: 1.64
Train set accuracy for a single minibatch: 63.28%
Train set accuracy for a single minibatch: 60.94%


Epoch 0, Iteration 200
Train Set Loss: 1.40
Test Set Loss: 1.36
Train set accuracy for a single minibatch: 70.31%
Train set accuracy for a single minibatch: 72.66%


Epoch 0, Iteration 250
Train Set Loss: 1.18
Test Set Loss: 1.32
Train set accuracy for a single minibatch: 82.81%
Train set accuracy for a single minibatch: 76.56%


Epoch 0,

KeyboardInterrupt: 