In [2]:
from torch import nn
import torch
import snntorch as snn
from snntorch import surrogate
from snntorch import utils
from torchsummary import summary
from src.utils.parameters import load_parameters
from src.utils.dataloaders import load_mnist_dataloader
from snntorch import spikegen
from src.utils.parameters import load_fnc

# Spike encoder used by `forward_pass`, resolved from its dotted module path
# (here `snntorch.spikegen.rate`).
encoding = load_fnc('snntorch.spikegen', 'rate')

# Experiment configuration; in this notebook only the dataset section is used.
params = load_parameters('./params/mnist.yaml')
data_params = params['dataset']['parameters']

# Seed torch's RNGs (CPU and CUDA) before the model and data loaders are built,
# so weight initialisation and torch-based sampling (e.g. the random spike
# trains produced by `encoding`) are reproducible across Restart & Run All.
# NOTE(review): bit-exact GPU determinism would additionally need cuDNN flags.
SEED = 42
torch.manual_seed(SEED)

gpu = torch.cuda.is_available()
DEVICE = torch.device("cuda") if gpu else torch.device("cpu")

in_shape = (1, 28, 28)  # (channels, height, width) of one MNIST image
in_features = 16   # output channels of the first Conv2d
out_features = 32  # output channels of the second Conv2d
k_size = 3         # square kernel size of every Conv2d
spike_grad = surrogate.fast_sigmoid()  # surrogate gradient used by the Leaky layers
beta = 0.5         # decay factor passed to every snn.Leaky layer
num_conv = 2       # number of (Conv2d -> MaxPool2d(2)) stages in `model`
num_steps = 50     # simulation time steps per sample


def _conv_pool_out(size, k_size, num_conv, pool=2):
    """Return the spatial side length after `num_conv` (conv -> max-pool) stages.

    Each stage is an unpadded, stride-1 ``Conv2d`` (``size - k_size + 1``)
    followed by ``MaxPool2d(pool)`` with its default stride and
    ``ceil_mode=False``, which floors the division by ``pool``.
    """
    for _ in range(num_conv):
        size = (size - k_size + 1) // pool
    return size


# Side length of the last feature map (5 for a 28x28 input), computed with the
# same shape rules PyTorch applies instead of a hand-tuned closed form.
# Assumes a square input, since the Linear layer uses fc_features**2.
fc_features = _conv_pool_out(in_shape[1], k_size, num_conv)
assert fc_features > 0, "input is too small for num_conv conv/pool stages"
n_class = 10

def _leaky(**kwargs):
    """Build an snn.Leaky layer with the shared beta / surrogate / hidden-state settings."""
    return snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, **kwargs)


model = nn.Sequential(
    # Stage 1: conv -> 2x2 max-pool -> spiking neuron
    nn.Conv2d(in_shape[0], in_features, k_size),
    nn.MaxPool2d(2),
    _leaky(),

    # Stage 2: conv -> 2x2 max-pool -> spiking neuron
    nn.Conv2d(in_features, out_features, k_size),
    nn.MaxPool2d(2),
    _leaky(),

    # Readout: flatten -> class scores. output=True makes the last Leaky return
    # (spikes, membrane), which forward_pass unpacks.
    nn.Flatten(),
    nn.Linear(out_features * fc_features * fc_features, n_class),
    _leaky(output=True),
).to(DEVICE)

# fc_features is derived from num_conv, so the config must agree with the
# number of conv stages actually built above.
assert sum(isinstance(m, nn.Conv2d) for m in model) == num_conv, (
    "num_conv does not match the number of Conv2d layers in `model`")

# Use the configured input shape and the model's device rather than
# torchsummary's implicit defaults.
summary(model, in_shape, device=DEVICE.type)

def forward_pass(net, data, num_steps):
    """Simulate `net` on a spike-encoded version of `data` for `num_steps` steps.

    Args:
        net: spiking network whose forward call returns ``(spikes, membrane)``.
        data: input batch handed to the module-level `encoding` function.
        num_steps: number of simulation time steps.

    Returns:
        ``(spikes, membranes)``: the per-step outputs stacked along a new
        leading time dimension.
    """
    spike_train = encoding(data, num_steps=num_steps)
    spikes = []
    membranes = []

    # Clear the hidden state held by the network's neurons so nothing carries
    # over from the previous batch.
    utils.reset(net)

    for step in range(num_steps):
        spk, mem = net(spike_train[step])
        spikes.append(spk)
        membranes.append(mem)

    return torch.stack(spikes), torch.stack(membranes)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 26, 26]             160
         MaxPool2d-2           [-1, 16, 13, 13]               0
             Leaky-3           [-1, 16, 13, 13]               0
            Conv2d-4           [-1, 32, 11, 11]           4,640
         MaxPool2d-5             [-1, 32, 5, 5]               0
             Leaky-6             [-1, 32, 5, 5]               0
           Flatten-7                  [-1, 800]               0
            Linear-8                   [-1, 10]           8,010
             Leaky-9       [[-1, 10], [-1, 10]]               0
Total params: 12,810
Trainable params: 12,810
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.17
Params size (MB): 0.05
Estimated Total Size (MB): 0.22
---------------------------------------------

In [3]:
from tqdm import tqdm
from snntorch import functional as SF
import numpy as np

# Optimisation setup: training mode, data loaders, optimizer and loss.
model.train()
lr = 1e-3

# NOTE(review): the positional `True` is a load_mnist_dataloader flag whose
# meaning is not visible here; `n` is not used anywhere in this notebook view.
train_dl, test_dl, n = load_mnist_dataloader(data_params, True)

optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
criterion = SF.ce_rate_loss()  # rate-coded cross-entropy from snntorch.functional

# Per-batch training losses, appended to by the training cell.
losses = []

In [4]:
# One training epoch over train_dl. Re-assert training mode here so re-running
# this cell after the evaluation cell (which calls model.eval()) still trains
# in train mode.
model.train()
with tqdm(train_dl, leave=False, desc="Running training phase") as pbar:
    # Iterating the tqdm wrapper advances the bar itself; no manual update().
    for data, targets in pbar:
        data = data.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        spk_rec, _ = forward_pass(model, data, num_steps)
        loss_val = criterion(spk_rec, targets)
        loss_val.backward()
        optimizer.step()

        losses.append(loss_val.item())
        # Mean over every loss recorded in `losses` so far (all runs of this cell).
        pbar.set_description(
            f"Running training phase | loss/train : {np.mean(losses):.4f}")

                                                                                               

In [5]:
# Evaluate on test_dl without tracking gradients; accuracy is weighted by batch
# size so a smaller final batch is not over-counted.
model.eval()
with torch.no_grad():
    total = 0
    acc = 0
    with tqdm(test_dl, leave=False, desc="Running testing phase") as pbar:
        # Iterating the tqdm wrapper advances the bar itself; no manual update().
        for data, targets in pbar:
            data = data.to(DEVICE)
            targets = targets.to(DEVICE)
            spk_rec, _ = forward_pass(model, data, num_steps)
            batch_size = spk_rec.size(1)  # dim 0 is time, dim 1 is the batch
            acc += SF.accuracy_rate(spk_rec, targets) * batch_size
            total += batch_size

assert total > 0, "test_dl yielded no samples"
total_acc = round(acc/total * 100, 2)
print(f'Final Accuracy {total_acc}%')

                                                                      

Final Accuracy 95.94%


