In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from nmnist_dataset import NMNISTDataset, augment # Import NMNIST dataset.

In [2]:
# Folder where the best-performing network weights are checkpointed.
trained_folder = "Trained"
os.makedirs(trained_folder, exist_ok=True)

# Seed the RNGs this notebook draws from (torch: weight init and DataLoader
# shuffling; numpy: presumably used by `augment` — confirm in nmnist_dataset)
# so that Restart & Run All reproduces the same training run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create SLAYER network

In [3]:
import lava.lib.dl.slayer as slayer
from lava.lib.dl.slayer.loss import SpikeMax
from lava.lib.dl.slayer.utils import Assistant

In [4]:
class SLAYERNetwork(torch.nn.Module):
    """Fully connected CUBA spiking network trained with SLAYER.

    Architecture: 34*34*2 -> 512 -> 512 -> 10, built from
    ``slayer.block.cuba.Dense`` blocks with weight normalization. The two
    hidden blocks additionally enable ``delay=True`` (axonal delay, per the
    lava-dl docs); the output block does not.

    NOTE(review): the 34*34*2 input size presumably is the 34x34 N-MNIST
    sensor with two event polarities, flattened — confirm against
    ``NMNISTDataset``.
    """

    def __init__(self):
        super().__init__()

        # Neuron parameters shared by every block. Per the lava-dl docs,
        # ``requires_grad=True`` makes the neuron decay constants learnable.
        neuron_params = {
            "threshold": 1.25,
            "current_decay": 0.25,
            "voltage_decay": 0.03,
            "tau_grad": 0.03,
            "scale_grad": 3,
            "requires_grad": True,
        }

        self.blocks = torch.nn.ModuleList([
            slayer.block.cuba.Dense(neuron_params, 34*34*2, 512, weight_norm=True, delay=True),
            slayer.block.cuba.Dense(neuron_params, 512, 512, weight_norm=True, delay=True),
            slayer.block.cuba.Dense(neuron_params, 512, 10, weight_norm=True),
        ])

    def forward(self, spike):
        """Propagate ``spike`` through every block in order and return the
        output of the last block."""
        for block in self.blocks:
            spike = block(spike)
        return spike

    def _grad_norms(self):
        """Return ``synapse.grad_norm`` for each block that has a ``synapse``."""
        return [b.synapse.grad_norm for b in self.blocks if hasattr(b, "synapse")]

    def grad_flow(self, path):
        """Plot per-block synaptic gradient norms on a log scale and save them.

        Helps monitor how gradients flow back through the blocks.

        Args:
            path: Filename prefix. The figure is written to
                ``path + "gradFlow.png"``, so include a trailing separator
                (e.g. ``"Trained/"``) to save inside a folder.

        Returns:
            list: The plotted gradient norms, one per block that has a
            ``synapse`` attribute.
        """
        grad = self._grad_norms()

        fig, ax = plt.subplots()
        ax.semilogy(grad)
        ax.set(title="Gradient flow", xlabel="block index", ylabel="synapse grad norm")
        fig.savefig(path + "gradFlow.png")
        plt.close(fig)  # avoid accumulating open figures when called every epoch

        return grad

In [5]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model, spike-rate loss and optimizer, all placed on `device`.
slayer_net = SLAYERNetwork().to(device)
error = slayer.loss.SpikeRate(
    true_rate=0.2,    # target spiking rate for the labelled class (per lava-dl SpikeRate docs)
    false_rate=0.03,  # target spiking rate for the other classes
    reduction="sum",
).to(device)
optimizer = torch.optim.Adam(slayer_net.parameters(), lr=1e-3)

# LearningStats collects loss/accuracy; the Assistant bundles net, loss,
# optimizer and stats into the `.train(...)` / `.test(...)` steps used by the
# training loop, classifying outputs by spike rate.
slayer_stats = slayer.utils.LearningStats()
slayer_assistant = Assistant(
    slayer_net,
    error,
    optimizer,
    slayer_stats,
    classifier=slayer.classifier.Rate.predict,
    count_log=False,
)

# Instantiate Dataset and DataLoader

In [6]:
# Only the training split receives the `augment` transform.
# NOTE(review): the test split is constructed without download=True; this relies
# on the data already being present (or on NMNISTDataset's default) — confirm.
training_set = NMNISTDataset(train=True, transform=augment, download=True)
testing_set = NMNISTDataset(train=False)

BATCH_SIZE = 256
train_loader = DataLoader(dataset=training_set, batch_size=BATCH_SIZE, shuffle=True)
# Test metrics are aggregated over the whole split, so shuffling only adds
# nondeterminism; keep the evaluation order fixed.
test_loader = DataLoader(dataset=testing_set, batch_size=BATCH_SIZE, shuffle=False)


NMNIST dataset is freely available here:
https://www.garrickorchard.com/datasets/n-mnist
(c) Creative Commons:
    Orchard, G.; Cohen, G.; Jayawant, A.; and Thakor, N.
    "Converting Static Image Datasets to Spiking Neuromorphic Datasets Using
    Saccades",
    Frontiers in Neuroscience, vol.9, no.437, Oct. 2015



# Train SLAYER Network

In [7]:
# Train for a fixed number of epochs. After each epoch, evaluate on the full
# test split and checkpoint the weights whenever test accuracy hits a new best.
epochs = 20
checkpoint_path = os.path.join(trained_folder, "slayer_network.pt")

for epoch in range(epochs):
    print(f"\r[Epoch {epoch:2d}/{epochs}] {slayer_stats}", end="")

    # Iterate the DataLoader directly (not enumerate(iter(...))) so tqdm can
    # read len(train_loader) and render a real progress bar with an ETA.
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch:2d} train"):
        slayer_assistant.train(inputs, labels)

    for inputs, labels in test_loader:
        slayer_assistant.test(inputs, labels)

    status = f"\r[Epoch {epoch:2d}/{epochs}] {slayer_stats}"
    print(status, end="")

    if epoch % 5 == 4:
        # Every 5 epochs: blank out the single-line status and print a
        # multi-line summary (one line per stats section) instead.
        print("\r", " " * len(status))
        stats_str = str(slayer_stats).replace("| ", "\n")
        print(f"[Epoch {epoch:2d}/{epochs}]\n{stats_str}")

    if slayer_stats.testing.best_accuracy:
        torch.save(slayer_net.state_dict(), checkpoint_path)
    slayer_stats.update()

[Epoch  0/20] Train 

235it [02:56,  1.33it/s]


[Epoch  0/20] Train loss =     5.07748                        accuracy = 0.63355 | Test  loss =     2.32295                        accuracy = 0.88850[Epoch  1/20] Train 

235it [02:55,  1.34it/s]


[Epoch  1/20] Train loss =     2.46712 (min =     5.07748)    accuracy = 0.87288 (max = 0.63355) | Test  loss =     1.84621 (min =     2.32295)    accuracy = 0.92430 (max = 0.88850)[Epoch  2/20] Train 

235it [02:56,  1.33it/s]


[Epoch  2/20] Train loss =     1.98756 (min =     2.46712)    accuracy = 0.90812 (max = 0.87288) | Test  loss =     1.72022 (min =     1.84621)    accuracy = 0.92920 (max = 0.92430)[Epoch  3/20] Train 

235it [02:56,  1.33it/s]


[Epoch  3/20] Train loss =     1.79057 (min =     1.98756)    accuracy = 0.92083 (max = 0.90812) | Test  loss =     1.34987 (min =     1.72022)    accuracy = 0.95070 (max = 0.92920)[Epoch  4/20] Train 

235it [02:55,  1.34it/s]


[Epoch  4/20] Train loss =     1.57116 (min =     1.79057)    accuracy = 0.93200 (max = 0.92083) | Test  loss =     1.28323 (min =     1.34987)    accuracy = 0.95030 (max = 0.95070)                                                                                                                                                                                       
[Epoch  4/20]
Train loss =     1.57116 (min =     1.79057)    accuracy = 0.93200 (max = 0.92083) 
Test  loss =     1.28323 (min =     1.34987)    accuracy = 0.95030 (max = 0.95070)
[Epoch  5/20] Train 

235it [02:55,  1.34it/s]


[Epoch  5/20] Train loss =     1.48068 (min =     1.57116)    accuracy = 0.93572 (max = 0.93200) | Test  loss =     1.15632 (min =     1.28323)    accuracy = 0.95730 (max = 0.95070)[Epoch  6/20] Train 

235it [02:56,  1.33it/s]


[Epoch  6/20] Train loss =     1.38956 (min =     1.48068)    accuracy = 0.93740 (max = 0.93572) | Test  loss =     1.11252 (min =     1.15632)    accuracy = 0.95650 (max = 0.95730)[Epoch  7/20] Train 

235it [02:56,  1.33it/s]


[Epoch  7/20] Train loss =     1.31360 (min =     1.38956)    accuracy = 0.94180 (max = 0.93740) | Test  loss =     1.11020 (min =     1.11252)    accuracy = 0.96020 (max = 0.95730)[Epoch  8/20] Train 

235it [02:55,  1.34it/s]


[Epoch  8/20] Train loss =     1.28305 (min =     1.31360)    accuracy = 0.94400 (max = 0.94180) | Test  loss =     0.99424 (min =     1.11020)    accuracy = 0.96690 (max = 0.96020)[Epoch  9/20] Train 

235it [02:55,  1.34it/s]


[Epoch  9/20] Train loss =     1.21671 (min =     1.28305)    accuracy = 0.94722 (max = 0.94400) | Test  loss =     0.93210 (min =     0.99424)    accuracy = 0.96570 (max = 0.96690)                                                                                                                                                                                       
[Epoch  9/20]
Train loss =     1.21671 (min =     1.28305)    accuracy = 0.94722 (max = 0.94400) 
Test  loss =     0.93210 (min =     0.99424)    accuracy = 0.96570 (max = 0.96690)
[Epoch 10/20] Train 

235it [02:54,  1.34it/s]


[Epoch 10/20] Train loss =     1.16687 (min =     1.21671)    accuracy = 0.94973 (max = 0.94722) | Test  loss =     0.94632 (min =     0.93210)    accuracy = 0.96630 (max = 0.96690)[Epoch 11/20] Train 

235it [02:56,  1.33it/s]


[Epoch 11/20] Train loss =     1.15986 (min =     1.16687)    accuracy = 0.95097 (max = 0.94973) | Test  loss =     0.90963 (min =     0.93210)    accuracy = 0.96680 (max = 0.96690)[Epoch 12/20] Train 

235it [02:56,  1.33it/s]


[Epoch 12/20] Train loss =     1.12586 (min =     1.15986)    accuracy = 0.95332 (max = 0.95097) | Test  loss =     0.92153 (min =     0.90963)    accuracy = 0.96820 (max = 0.96690)[Epoch 13/20] Train 

235it [02:56,  1.33it/s]


[Epoch 13/20] Train loss =     1.13797 (min =     1.12586)    accuracy = 0.95470 (max = 0.95332) | Test  loss =     0.90434 (min =     0.90963)    accuracy = 0.96720 (max = 0.96820)[Epoch 14/20] Train 

235it [02:56,  1.33it/s]


[Epoch 14/20] Train loss =     1.09113 (min =     1.12586)    accuracy = 0.95508 (max = 0.95470) | Test  loss =     0.99115 (min =     0.90434)    accuracy = 0.96870 (max = 0.96820)                                                                                                                                                                                       
[Epoch 14/20]
Train loss =     1.09113 (min =     1.12586)    accuracy = 0.95508 (max = 0.95470) 
Test  loss =     0.99115 (min =     0.90434)    accuracy = 0.96870 (max = 0.96820)
[Epoch 15/20] Train 

235it [02:55,  1.34it/s]


[Epoch 15/20] Train loss =     1.05736 (min =     1.09113)    accuracy = 0.95672 (max = 0.95508) | Test  loss =     0.84850 (min =     0.90434)    accuracy = 0.96880 (max = 0.96870)[Epoch 16/20] Train 

235it [02:55,  1.34it/s]


[Epoch 16/20] Train loss =     1.02757 (min =     1.05736)    accuracy = 0.95673 (max = 0.95672) | Test  loss =     0.91514 (min =     0.84850)    accuracy = 0.97100 (max = 0.96880)[Epoch 17/20] Train 

235it [02:56,  1.33it/s]


[Epoch 17/20] Train loss =     1.05009 (min =     1.02757)    accuracy = 0.95738 (max = 0.95673) | Test  loss =     0.86629 (min =     0.84850)    accuracy = 0.97070 (max = 0.97100)[Epoch 18/20] Train 

235it [02:55,  1.34it/s]


[Epoch 18/20] Train loss =     1.01840 (min =     1.02757)    accuracy = 0.95863 (max = 0.95738) | Test  loss =     0.85288 (min =     0.84850)    accuracy = 0.97040 (max = 0.97100)[Epoch 19/20] Train 

235it [02:55,  1.34it/s]


[Epoch 19/20] Train loss =     1.02871 (min =     1.01840)    accuracy = 0.95945 (max = 0.95863) | Test  loss =     0.90141 (min =     0.84850)    accuracy = 0.97260 (max = 0.97100)                                                                                                                                                                                       
[Epoch 19/20]
Train loss =     1.02871 (min =     1.01840)    accuracy = 0.95945 (max = 0.95863) 
Test  loss =     0.90141 (min =     0.84850)    accuracy = 0.97260 (max = 0.97100)
