In [1]:
# Install snnTorch into the running kernel's environment (%pip, unlike !pip, targets the kernel).
# NOTE(review): the version is unpinned; pin it (snntorch==X.Y.Z) so results are reproducible.
%pip install snntorch

Note: you may need to restart the kernel to use updated packages.


# Import libraries

In [2]:
import numpy as np
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import spikeplot as splt
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt
from pathlib import Path

# Batch size

In [3]:
batch_size: int = 512  # samples per DataLoader mini-batch

# Dataset transformation

In [4]:
# Preprocessing pipeline applied to every FashionMNIST image.
# NOTE(review): FashionMNIST images are presumably already 28x28 single-channel, so
# Resize/Grayscale look like safeguards, and Normalize with mean 0 / std 1 leaves the
# values unchanged; confirm before removing any step.
_transform_steps = [
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
]
transform = transforms.Compose(_transform_steps)

In [5]:
# Directory where torchvision stores (and downloads) the datasets.
data_root = '../data'

dtype = torch.float
# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Datasets

In [6]:
# FashionMNIST train and test splits (downloaded if missing), both preprocessed by `transform`.
fmnist_train, fmnist_test = (
    datasets.FashionMNIST(data_root, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)


# Dataloaders

In [7]:
# Create DataLoaders.
# Training: shuffled, and the incomplete last batch is dropped so every step sees
# exactly `batch_size` samples.
train_loader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
# Test: every sample exactly once, in dataset order. Shuffling is unnecessary for
# evaluation, and drop_last=True would silently discard the final partial batch.
test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=False, drop_last=False)

In [8]:
# Utilities

# Targets dictionary: FashionMNIST class index -> human-readable class name.
target_label = dict(enumerate([
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]))

# Library name; used as a sub-directory of the saved classification figures.
lib = "snntorch"

def plot_data_spike_recording(spk_rec, batch_size, targets, data, index=0):
    """Save one two-panel figure per sample: the input image and its output spike raster.

    Each figure is written as both EPS and PNG to
    ``../figures/classifications/<lib>/<class-slug>/<class-slug>-<n>.{eps,png}``.
    The left panel shows the input image and the right panel shows the raster of
    the output-layer spikes.

    Args:
        spk_rec: Output spike recording, indexed as ``spk_rec[step, sample, neuron]``.
        batch_size: Nominal DataLoader batch size. Used for the file number
            ``index * batch_size + i``, and as an upper bound on the number of
            samples plotted. The actual count comes from ``targets``, so a short
            final batch (``drop_last=False``) does not raise ``IndexError``.
        targets: Class indices of the batch (keys of ``target_label``).
        data: Input images of the batch; passed to ``imshow``, so they must be on the CPU.
        index: Position of this batch in the loader, used for file numbering.
    """
    n_classes = len(target_label)
    # "neuron, class name" labels for the raster's y axis, one per output neuron.
    class_ticks = ["{}, {}".format(k, v) for k, v in target_label.items()]
    # The last batch of a loader with drop_last=False can be shorter than batch_size.
    n_samples = min(batch_size, len(targets))

    for i in range(n_samples):
        label_idx = targets[i].item()
        label_name = target_label[label_idx]
        # File-system friendly class name, e.g. 'T-shirt/top' -> 't-shirt-top'.
        item_type = label_name.lower().replace(' ', '-').replace('/', '-')
        save_path = "../figures/classifications/{}/{}".format(lib, item_type)
        Path(save_path).mkdir(parents=True, exist_ok=True)

        fig = plt.figure(figsize=(16, 9))
        bx, ax = fig.subplots(ncols=2)

        # Right panel: output-layer spike raster of sample i, with a small margin
        # after the last time step.
        ax.set_xlim((-1, spk_rec.size(0) + 10))
        ax.set_ylim((-1, n_classes))
        ax.set_yticks(range(n_classes), class_ticks)
        splt.raster(spk_rec[:, i, :], ax, s=1, c="black")
        ax.set_title("Classification - {} - {}".format(label_name, label_idx))
        ax.set_xlabel("Time step")
        ax.set_ylabel("Neuron Number, Item class")

        # Left panel: the input image, without ticks.
        bx.set_title(label_name)
        bx.set_xticks([], None)
        bx.set_yticks([], None)
        bx.imshow(data[i].squeeze().detach())

        fig.tight_layout()
        # Sample number across batches (every batch before the last one is full).
        file_stem = "{}/{}-{}".format(save_path, item_type, (index * batch_size) + i)
        fig.savefig("{}.eps".format(file_stem), format='eps')
        fig.savefig("{}.png".format(file_stem), format='png')
        # Close this figure explicitly so thousands of figures do not pile up in memory.
        plt.close(fig)

# Network architecture and temporal dynamics

In [9]:
# Network Architecture: flattened 28x28 image -> hidden layer -> one neuron per class
num_inputs, num_hidden, num_outputs = 28 * 28, 512, 10

# Temporal Dynamics: simulation length in steps, and the `beta` passed to snn.Leaky
num_steps, beta = 200, 0.95

# Network definition

In [10]:
# Define Network
class Net(nn.Module):
    """Two-layer fully connected spiking network: fc1 -> lif1 -> fc2 -> lif2.

    Layer sizes come from the module-level ``num_inputs``, ``num_hidden`` and
    ``num_outputs``. Both ``snn.Leaky`` layers use the module-level ``beta``, and
    the forward pass is unrolled for the module-level ``num_steps``.
    """

    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps, feeding the same ``x`` at every step.

        Args:
            x: Input batch for ``fc1``. Presumably ``(batch, num_inputs)``, since
                the caller flattens with ``view(batch, -1)``; TODO confirm for other callers.

        Returns:
            Tuple ``(spk_rec, mem_rec)``: output-layer spikes and membrane
            potentials, each stacked along a new leading time dimension of
            length ``num_steps``.
        """
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # `x` does not change between steps, so its projection is loop-invariant:
        # compute it once instead of num_steps times.
        cur1 = self.fc1(x)

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Network instantiation

In [11]:
# Network instantiation: build the SNN, then move its parameters to the selected device.
net = Net()
net = net.to(device)

In [12]:
# Load the trained weights. When the checkpoint is missing, `run_train` is set so
# later cells can skip work that needs a trained network.
run_train = False
state_dict_file_path = "../models/snnTorch-FMNIST-training.pt"

try:
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    load_state_dict = torch.load(state_dict_file_path, map_location=device)
    net.load_state_dict(load_state_dict)
except FileNotFoundError:
    print("Trained weights not found: {}".format(state_dict_file_path))
    run_train = True

if run_train:
    # NOTE(review): notebook name inferred from the snnTorch checkpoint file name; confirm.
    print("Run the snnTorch-FMNIST-training.ipynb first")

In [13]:
# Plot every test sample. No shuffling keeps the saved file numbers equal to the
# dataset order, and drop_last=False keeps the final partial batch.
index = 0
test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=False, drop_last=False)

if run_train:
    # The trained weights were not loaded; plotting the randomly initialised
    # network would only produce thousands of meaningless figures.
    print("Skipping plots: trained weights were not loaded from", state_dict_file_path)
else:
    net.eval()
    with torch.no_grad():
        for data, targets in test_loader:
            device_data = data.to(device)
            targets = targets.to(device)

            # Forward pass on the flattened images; only the output spikes are plotted.
            test_spk, _ = net(device_data.view(data.size(0), -1))

            # `data` stays on the CPU for imshow; `index` numbers this batch for file names.
            plot_data_spike_recording(spk_rec=test_spk, batch_size=batch_size, targets=targets, data=data, index=index)
            index += 1

: 

: 