In [1]:
import torch, torch.nn as nn
import snntorch as snn

In [2]:
batch_size = 128
data_path='/tmp/data/fmnist'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

In [3]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

fmnist_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)
fmnist_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=True)

100%|██████████| 26.4M/26.4M [00:00<00:00, 60.7MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 2.45MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 40.6MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 32.2MB/s]


In [4]:
from snntorch import surrogate

# network parameters
num_inputs = 28*28
num_hidden = 128
num_outputs = 10
num_steps = 1

# spiking neuron parameters
beta = 0.9  # neuron decay rate
grad = surrogate.fast_sigmoid()

In [5]:
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(num_inputs, num_hidden),
                    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True),
                    nn.Linear(num_hidden, num_outputs),
                    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)
                    ).to(device)

In [6]:
pop_outputs = 500

net_pop = nn.Sequential(nn.Flatten(),
                        nn.Linear(num_inputs, num_hidden),
                        snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True),
                        nn.Linear(num_hidden, pop_outputs),
                        snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)
                        ).to(device)

In [7]:
import snntorch.functional as SF

optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0)

In [9]:
from snntorch import utils

def test_accuracy(data_loader, net, num_steps, population_code=False, num_classes=False):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    data_loader = iter(data_loader)
    for data, targets in data_loader:
      data = data.to(device)
      targets = targets.to(device)
      utils.reset(net)
      spk_rec, _ = net(data)

      if population_code:
        acc += SF.accuracy_rate(spk_rec.unsqueeze(0), targets, population_code=True, num_classes=10) * spk_rec.size(1)
      else:
        acc += SF.accuracy_rate(spk_rec.unsqueeze(0), targets) * spk_rec.size(1)

      total += spk_rec.size(1)

  return acc/total

In [10]:
from snntorch import backprop

num_epochs = 5

# training loop
for epoch in range(num_epochs):

    avg_loss = backprop.BPTT(net, train_loader, num_steps=num_steps,
                          optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)

    print(f"Epoch: {epoch}")
    print(f"Test set accuracy: {test_accuracy(test_loader, net, num_steps)*100:.3f}%\n")

Epoch: 0
Test set accuracy: 59.187%

Epoch: 1
Test set accuracy: 72.033%

Epoch: 2
Test set accuracy: 65.823%

Epoch: 3
Test set accuracy: 62.342%

Epoch: 4
Test set accuracy: 64.794%



In [11]:
loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0, population_code=True, num_classes=10)
optimizer = torch.optim.Adam(net_pop.parameters(), lr=2e-3, betas=(0.9, 0.999))

In [12]:
num_epochs = 5

# training loop
for epoch in range(num_epochs):

    avg_loss = backprop.BPTT(net_pop, train_loader, num_steps=num_steps,
                            optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)

    print(f"Epoch: {epoch}")
    print(f"Test set accuracy: {test_accuracy(test_loader, net_pop, num_steps, population_code=True, num_classes=10)*100:.3f}%\n")


Epoch: 0
Test set accuracy: 79.747%

Epoch: 1
Test set accuracy: 80.053%

Epoch: 2
Test set accuracy: 81.230%

Epoch: 3
Test set accuracy: 83.060%

Epoch: 4
Test set accuracy: 82.595%



# Own implementation

In [None]:
from neural_nets.configurable_spiking_neural_net import ConfigurableSpikingNeuralNet
from constants import NUMBER_INPUT_NEURONS, NUMBER_HIDDEN_LAYERS, NUMBER_HIDDEN_NEURONS, NUMBER_OUTPUT_NEURONS, BETA, THRESHOLD, TIME_STEPS, DEVICE
from training.train_simplified_snn import train_simplified_snn

net = ConfigurableSpikingNeuralNet(number_input_neurons=NUMBER_INPUT_NEURONS, 
                                           number_hidden_neurons=NUMBER_HIDDEN_NEURONS,
                                           number_hidden_layers=NUMBER_HIDDEN_LAYERS,
                                           number_output_neurons=NUMBER_OUTPUT_NEURONS, 
                                           beta=BETA, 
                                           threshold=THRESHOLD,
                                           time_steps=TIME_STEPS, 
                                           sparsity=0,
                                           population_coding=False).to(DEVICE)

#train_simplified_snn(net, num_epochs='early_stopping', loss_configuration='membrane_potential_cross_entropy')
train_simplified_snn(net, num_epochs='early_stopping', 
                     loss_configuration='rate_code_cross_entropy', 
                     output_file_path='./output/experiments_population_coding/rate_code_cross_entropy.json',
                     save_plots=f'./output/experiments_population_coding/rate_code_cross_entropy')


Epoch: 0
loss 2.885247600844469
train accuracy 0.1662579695929377
test accuracy 28.48939929328622
Epoch: 1
loss 2.682589676521642
train accuracy 0.2697400686611084
test accuracy 37.014134275618375
Epoch: 2
loss 2.599637602286458
train accuracy 0.42116233447768514
test accuracy 45.75971731448763
