In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from decolle_loss import DECOLLELoss
import lava.lib.dl.slayer as slayer
from nmnist_dataset import NMNISTDataset, augment # Import NMNIST dataset.

# Create Network

In [2]:
class DECOLLENetwork(torch.nn.Module):
    """Stack of CUBA dense blocks, each paired with a fixed linear readout.

    Every block receives a *detached* copy of the previous block's spikes, so
    no gradient flows from one block back into the preceding one. Each block's
    output spikes are fed to its own bias-free, non-trainable linear readout,
    which is what the loss is computed on.

    Args:
        input_shape: Number of input features of the first dense block.
        hidden_shape: Sequence (list or tuple) with the width of every block.
        output_shape: Number of outputs of every readout layer.
        burn_in: Default number of leading time steps that ``init_state``
            consumes to warm up the neuron states.
    """

    def __init__(self, input_shape, hidden_shape, output_shape, burn_in=0):
        super(DECOLLENetwork, self).__init__()
        neuron_params = {
            "threshold": 1.25,
            "current_decay": 0.25,
            "voltage_decay": 0.03,
            "tau_grad": 0.03,
            "scale_grad": 3,
            "requires_grad": True,
            "persistent_state": True
        }

        self.burn_in = burn_in

        self.blocks = torch.nn.ModuleList()  # Network's feedforward blocks.
        self.readout_layers = torch.nn.ModuleList()  # Fixed readouts used to calculate the loss.

        # Unpacking (instead of list concatenation) accepts any sequence type.
        layer_sizes = [input_shape, *hidden_shape]

        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            self.blocks.append(
                slayer.block.cuba.Dense(neuron_params, fan_in, fan_out, weight_norm=False)
            )

            # One fixed readout per layer: its weights are never updated.
            readout = torch.nn.Linear(fan_out, output_shape, bias=False)
            readout.weight.requires_grad_(False)
            self.readout_layers.append(readout)

    def forward(self, spike):
        """Run all blocks and their readouts on an input spike tensor.

        Args:
            spike: Input spikes with time on the last axis. Its
                ``requires_grad`` flag is switched on in place.

        Returns:
            A tuple ``(spikes, readouts, voltages, count)`` of lists holding
            one entry per block: the output spikes, the readout of those
            spikes stacked over time, the neuron voltages, and the mean of the
            (detached) output spikes.
        """
        spike.requires_grad_()  # Set requires grad of input spikes to True.
        spikes, readouts, voltages, count = [], [], [], []

        for block in self.blocks:
            # Decompose the behaviour of the block to obtain the voltages
            # for regularization. The input is detached so that the block is
            # cut off from the graph of the previous block.
            z = block.synapse(spike.detach())
            _, voltage = block.neuron.dynamics(z)
            voltages.append(voltage)

            spike = block.neuron.spike(voltage)
            spikes.append(spike)
            count.append(torch.mean(spike.detach()))

        for ro, layer_spikes in zip(self.readout_layers, spikes):
            # Apply the readout one time step at a time, then restore the
            # time axis as the last dimension.
            per_step = [ro(layer_spikes[..., t]) for t in range(layer_spikes.shape[-1])]
            readouts.append(torch.stack(per_step, dim=-1))

        return spikes, readouts, voltages, count

    def init_state(self, inputs, burn_in=None):
        """Reset the neuron states, warm them up and crop the inputs.

        Args:
            inputs: Input spikes with time on the last axis.
            burn_in: Number of leading time steps to consume; falls back to
                ``self.burn_in`` when ``None``.

        Returns:
            ``inputs`` without its first ``burn_in`` time steps.
        """
        self.reset_()
        if burn_in is None:
            burn_in = self.burn_in

        warmup = inputs[..., :burn_in]
        # An empty warm-up window (e.g. the default burn_in=0) has nothing to
        # propagate; running forward on it would fail on the empty time axis.
        if warmup.shape[-1] > 0:
            self.forward(warmup)
        return inputs[..., burn_in:]

    def reset_(self):
        """Zero the current and voltage state of every block, in place."""
        # Reset the states after each example.
        for block in self.blocks:
            block.neuron.current_state[:] = 0.
            block.neuron.voltage_state[:] = 0.

# Instantiate Network, Optimizer, Dataset and DataLoader

In [3]:
# Seed the global RNGs so that weight initialization and DataLoader shuffling
# are repeatable across "Restart Kernel -> Run All".
# NOTE(review): `augment` is defined in nmnist_dataset; confirm which RNG it
# draws from if fully deterministic augmentation is required.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

BATCH_SIZE = 8
LEARNING_RATE = 0.01

# Folder where the best network checkpoint is written.
trained_folder = "Trained"
os.makedirs(trained_folder, exist_ok=True)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two hidden blocks (512 and 256 units) on flattened 34x34x2 inputs, with ten
# readout outputs per block and a 10-step burn-in.
net = DECOLLENetwork(input_shape=34*34*2,
                     hidden_shape=[512, 256],
                     output_shape=10,
                     burn_in=10).to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

training_set = NMNISTDataset(train=True, transform=augment, download=True)
testing_set = NMNISTDataset(train=False)

train_loader = DataLoader(dataset=training_set, batch_size=BATCH_SIZE, shuffle=True)
# The test loader is shuffled as well: evaluation later stops after a fixed
# number of batches, so shuffling makes that subset a random sample.
test_loader = DataLoader(dataset=testing_set, batch_size=BATCH_SIZE, shuffle=True)


NMNIST dataset is freely available here:
https://www.garrickorchard.com/datasets/n-mnist
(c) Creative Commons:
    Orchard, G.; Cohen, G.; Jayawant, A.; and Thakor, N.
    "Converting Static Image Datasets to Spiking Neuromorphic Datasets Using
    Saccades",
    Frontiers in Neuroscience, vol.9, no.437, Oct. 2015



# Error Module

In [4]:
# Per-layer pseudo errors are computed by DECOLLELoss; its regularization
# term (weighted by `reg`) is what keeps the per-layer spike rates in check.
error = DECOLLELoss(
    torch.nn.CrossEntropyLoss,
    reg=0.01,
    reduction="mean",
)

# Training Loop

In [5]:
epochs = 5
training_mode = "batch" # either "online" or "batch".

# Caps that keep one epoch short: only part of each dataset is visited.
max_train_batches = 500
max_test_batches = 100
log_every = 100  # Print a progress line every `log_every` training batches.
checkpoint_path = os.path.join(trained_folder, "network.pt")

test_losses = []
test_accs_l1 = [] # Test accuracies of the first layer.
test_accs_l2 = [] # Test accuracies of the second layer.


def _flatten_events(events):
    """Merge every axis between the first (batch) and last (time) into one."""
    return events.reshape([events.shape[0], -1, events.shape[-1]])


for epoch in range(epochs):
    # ---- Training ----
    net.train()
    for batch, (inputs, target) in enumerate(train_loader, start=1):
        inputs = _flatten_events(inputs).to(device)
        target = target.to(device)

        # Reset network's state + burn-in + resize the inputs accordingly.
        inputs = net.init_state(inputs)

        if training_mode == "online":
            # One optimizer update per time step.
            for t in range(inputs.shape[-1]):
                x = inputs[..., t].unsqueeze(-1)
                spikes, readouts, voltages, count_t = net(x)
                loss = error(readouts, voltages, target)
                loss.backward()

                optimizer.step()
                optimizer.zero_grad()
        else:
            # One optimizer update per batch, over the whole sequence.
            spikes, readouts, voltages, count = net(inputs)
            loss = error(readouts, voltages, target)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        if batch % log_every == 0:
            print("Training Epoch: %s, Batch: %s Done." % (epoch, batch))
        if batch == max_train_batches:
            break

    # ---- Evaluation ----
    net.eval()
    # Collect per-batch results in lists and concatenate once at the end,
    # instead of re-allocating a growing tensor on every batch.
    batch_preds_l1, batch_preds_l2, batch_targets = [], [], []
    epoch_test_loss = 0.

    with torch.no_grad():
        for batch, (inputs, target) in enumerate(test_loader, start=1):
            inputs = _flatten_events(inputs).to(device)
            target = target.to(device)

            # Reset net state + burn-in + resize the inputs accordingly.
            inputs = net.init_state(inputs)

            # Forward pass + record loss (summed over the evaluated batches).
            spikes, readouts, voltages, count = net(inputs)
            loss = error(readouts, voltages, target)
            epoch_test_loss += loss.item()

            # Predicted class = argmax of the readout averaged over time.
            batch_preds_l1.append(torch.mean(readouts[0], dim=-1).argmax(-1).cpu())
            batch_preds_l2.append(torch.mean(readouts[1], dim=-1).argmax(-1).cpu())
            batch_targets.append(target.cpu())

            if batch == max_test_batches:
                print("Partial Testing Done. Breaking loop...")
                break

    test_losses.append(epoch_test_loss)

    targets_test = torch.cat(batch_targets)
    acc_test_l1 = torch.mean((torch.cat(batch_preds_l1) == targets_test).type(torch.float))
    acc_test_l2 = torch.mean((torch.cat(batch_preds_l2) == targets_test).type(torch.float))

    test_accs_l1.append(acc_test_l1.cpu().numpy())
    test_accs_l2.append(acc_test_l2.cpu().numpy())

    # Carriage return plus padding: blanks out the current console line.
    print("\r", " "*len(f"\r[Epoch {epoch:2d}/{epochs}]"))
    print("Test loss = %f (min = %f)   accuracy = %f (max = %f)" % (
          test_losses[epoch], np.min(test_losses), acc_test_l2, np.max(test_accs_l2)))

    # Checkpoint whenever the second-layer accuracy matches or beats the best so far.
    if acc_test_l2 >= np.max(test_accs_l2):
        torch.save(net.state_dict(), checkpoint_path)

Training Epoch: 0, Batch: 100 Done.
Training Epoch: 0, Batch: 200 Done.
Training Epoch: 0, Batch: 300 Done.
Training Epoch: 0, Batch: 400 Done.
Training Epoch: 0, Batch: 500 Done.
Partial Testing Done. Breaking loop...
              
Test loss = 360.149810 (min = 360.149810)   accuracy = 0.572500 (max = 0.572500)
Training Epoch: 1, Batch: 100 Done.
Training Epoch: 1, Batch: 200 Done.
Training Epoch: 1, Batch: 300 Done.
Training Epoch: 1, Batch: 400 Done.
Training Epoch: 1, Batch: 500 Done.
Partial Testing Done. Breaking loop...
              
Test loss = 337.444161 (min = 337.444161)   accuracy = 0.700000 (max = 0.700000)
Training Epoch: 2, Batch: 100 Done.
Training Epoch: 2, Batch: 200 Done.
Training Epoch: 2, Batch: 300 Done.
Training Epoch: 2, Batch: 400 Done.
Training Epoch: 2, Batch: 500 Done.
Partial Testing Done. Breaking loop...
              
Test loss = 328.739234 (min = 328.739234)   accuracy = 0.747500 (max = 0.747500)
Training Epoch: 3, Batch: 100 Done.
Training Epoch: 3, 

# Investigating DECOLLE Network

In [6]:
# Display the network's feedforward blocks.
net.blocks

ModuleList(
  (0): Dense(
    (neuron): Neuron()
    (synapse): Dense(2312, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  )
  (1): Dense(
    (neuron): Neuron()
    (synapse): Dense(512, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  )
)

In [7]:
# Display the fixed per-layer readouts used to calculate the loss.
net.readout_layers

ModuleList(
  (0): Linear(in_features=512, out_features=10, bias=False)
  (1): Linear(in_features=256, out_features=10, bias=False)
)

## Test Accuracies from the Layer 1 and Layer 2 Readouts

In [8]:
# Per-epoch test accuracy of the first layer's readout.
test_accs_l1

[array(0.665, dtype=float32),
 array(0.71125, dtype=float32),
 array(0.74, dtype=float32),
 array(0.7425, dtype=float32),
 array(0.75625, dtype=float32)]

In [9]:
# Per-epoch test accuracy of the second layer's readout.
test_accs_l2

[array(0.5725, dtype=float32),
 array(0.7, dtype=float32),
 array(0.7475, dtype=float32),
 array(0.79875, dtype=float32),
 array(0.78875, dtype=float32)]