In [16]:
%pip install rockpool

Note: you may need to restart the kernel to use updated packages.


# Import delle librerie

In [17]:
import numpy as np
import torch
import torch.nn as nn

from snntorch import spikeplot as splt
from rockpool.nn.modules import LinearTorch, LIFTorch
from rockpool.nn.combinators import Sequential
from rockpool.parameters import Constant

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt
from pathlib import Path

# Download dei dataset

In [18]:
# Folder passed to torchvision as the dataset root
data_root = '../data'

# Floating-point dtype and compute device (CUDA when available, otherwise CPU)
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trasformazioni da applicare ai dataset

In [19]:
# Pre-processing pipeline handed to the datasets and applied to every image
transform = transforms.Compose(
    [
        transforms.Resize(size=(28, 28)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # mean 0 / std 1, i.e. (x - 0) / 1: the values pass through unchanged
        transforms.Normalize(mean=(0.,), std=(1,)),
    ]
)

In [20]:
# FashionMNIST train and test splits, both rooted at `data_root` and both
# run through the `transform` pipeline; download=True is requested for each
fmnist_train = datasets.FashionMNIST(
    root=data_root, train=True, download=True, transform=transform
)
fmnist_test = datasets.FashionMNIST(
    root=data_root, train=False, download=True, transform=transform
)

# Definizione della dimensione del batch

In [21]:
# Number of samples per mini-batch handed to the DataLoaders
batch_size = 512

# Creazione dei dataloader

In [22]:
# Mini-batch iterators over both splits: shuffled, and with the trailing
# incomplete batch discarded (drop_last=True)
train_loader = DataLoader(
    dataset=fmnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    dataset=fmnist_test,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [23]:
# Utilities

# Human-readable class names keyed by integer class index (0..9)
target_label = dict(enumerate((
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
)))

# Library tag; used as a sub-folder name when figures are saved
lib = "rockpool"

def plot_data_spike_recording(spk_rec, batch_size, targets, data, index=0):
    """Save one two-panel figure (EPS and PNG) for every sample of a batch.

    The left panel shows the input image, the right panel a raster plot of the
    recorded spikes for that sample. Files are written to
    ``../figures/classifications/<lib>/<class-slug>/<class-slug>-<n>.<ext>``,
    where ``n = index * batch_size + i`` for the i-th sample of the batch.

    Args:
        spk_rec: Spike recording, indexed as ``spk_rec[i, :, :]`` per sample
            (presumably (batch, time, neuron) -- confirm against the network output).
        batch_size: Nominal batch size of the loader. Used for the running
            sample number in the file name and as an upper bound on the number
            of samples plotted.
        targets: Tensor of integer class indices, one per sample.
        data: Batch of input images; ``data[i]`` is squeezed and shown with ``imshow``.
        index: Zero-based position of this batch within the loader.
    """
    # A loader built with drop_last=False yields a final batch shorter than
    # `batch_size`; never index past the samples actually received.
    num_items = min(batch_size, len(targets))

    for i in range(num_items):
        class_name = target_label[targets[i].item()]
        # File-system friendly class name, e.g. 'T-shirt/top' -> 't-shirt-top'
        item_type = class_name.lower().replace(' ', '-').replace('/', '-')

        save_path = "../figures/classifications/{}/{}".format(lib, item_type)
        Path(save_path).mkdir(parents=True, exist_ok=True)

        fig = plt.figure(figsize=(16, 9))
        bx, ax = fig.subplots(ncols=2)

        # Right panel: spike raster of this sample
        ax.set_xlim((-1, 110))
        ax.set_ylim((-1, 10))
        ax.set_yticks(range(0, 10), target_label.items())
        splt.raster(spk_rec[i, :, :], ax, s=1, c="black")
        ax.set_title("Classification - {} - {}".format(class_name, targets[i]))
        ax.set_xlabel("Time step")
        ax.set_ylabel("Neuron Number, Item class")

        # Left panel: the input image, without axis ticks
        bx.set_title(class_name)
        bx.set_xticks([], None)
        bx.set_yticks([], None)
        bx.imshow(data[i].squeeze().detach())

        fig.tight_layout()

        # The running number keeps using the nominal batch size so that it
        # stays unique across batches, including a shorter final one.
        sample_no = (index * batch_size) + i
        for ext in ('eps', 'png'):
            fig.savefig(
                "{}/{}-{}.{}".format(save_path, item_type, sample_no, ext),
                format=ext
            )

        # Release the figure explicitly; one is created per sample
        plt.close(fig)

# Architettura e dinamica temporale

In [24]:
# Layer sizes: 28 * 28 inputs, one hidden layer, 10 outputs
num_inputs = 28 * 28
num_hidden = 512
num_outputs = 10

# Simulation: number of time steps per sample and the step size `dt`
# (presumably in seconds, as are the time constants below -- confirm against Rockpool)
num_steps = 100
dt = 1e-2

# Neuron parameters shared by both LIF layers, each wrapped in Rockpool's
# `Constant` (presumably to exclude them from training -- confirm)
threshold = Constant(1.)
bias = Constant(0.)
tau_mem = Constant(100e-3)
tau_syn = Constant(50e-3)

# Poisson Encode del batch

In [25]:
# - Define a function to encode an input into a poisson event series
def encode_poisson(data: torch.Tensor, num_steps: int, scale: float = 0.1) -> torch.Tensor:
    """Turn a batch of intensities into random binary spike trains.

    Every sample is flattened and repeated over ``num_steps`` time steps; at
    each step an element emits a spike (1.0) when a uniform random draw falls
    below ``scale * scale * value``.

    NOTE(review): ``scale`` is applied twice, so the effective spike
    probability is ``scale ** 2 * value``. This looks unintended, but it is
    kept as-is because the weights loaded by this notebook were presumably
    trained with the same encoding -- confirm before changing it.

    Args:
        data: Tensor whose first axis is the batch; all remaining axes are
            flattened, so both (batch, H, W) and (batch, C, H, W) are accepted.
        num_steps: Number of time steps to generate.
        scale: Factor applied to the intensities (see note above).

    Returns:
        Float tensor of zeros and ones with shape (batch, num_steps, features),
        located on the same device as ``data``.
    """
    num_batches = data.shape[0]
    rates = scale * data.reshape((num_batches, 1, -1)).repeat((1, num_steps, 1))
    # Draw the noise on the input's device so the comparison also works for CUDA inputs
    noise = torch.rand(rates.shape, device=rates.device)
    return (noise < (rates * scale)).float()

# Encode in spike degli obiettivi

In [26]:
# - Define a function to encode the network target
def encode_class(class_idx: torch.Tensor, num_classes: int, num_steps: int) -> torch.Tensor:
    """Expand integer class labels into one-hot targets repeated over time.

    Args:
        class_idx: Integer tensor of class indices.
        num_classes: Length of each one-hot vector.
        num_steps: Number of time steps the one-hot vector is repeated for.

    Returns:
        Float tensor of shape (class_idx.numel(), num_steps, num_classes).
    """
    one_hot = torch.nn.functional.one_hot(class_idx, num_classes=num_classes)
    # One row per label with a singleton time axis, then tiled along that axis
    per_sample = one_hot.view((class_idx.numel(), 1, -1))
    tiled = per_sample.repeat((1, num_steps, 1))
    return tiled.float()

# Network

In [27]:
# Define Network
def DefineNet(num_inputs, num_hidden, num_outputs):
    """Assemble the network: Linear -> LIF -> Linear -> LIF.

    Both LIF layers take their neuron parameters (tau_mem, tau_syn, threshold,
    bias, dt) from the module-level hyper-parameters.

    Args:
        num_inputs: Size of the input layer.
        num_hidden: Number of neurons in the hidden LIF layer.
        num_outputs: Number of neurons in the output LIF layer.

    Returns:
        The Rockpool ``Sequential`` combination of the four layers.
    """
    lif_settings = dict(
        tau_mem=tau_mem,
        tau_syn=tau_syn,
        threshold=threshold,
        bias=bias,
        dt=dt,
    )

    # Layers are created in network order, input side first
    layers = [
        LinearTorch((num_inputs, num_hidden)),
        LIFTorch(num_hidden, **lif_settings),
        LinearTorch((num_hidden, num_outputs)),
        LIFTorch(num_outputs, **lif_settings),
    ]
    return Sequential(*layers)

In [28]:
# Build the network from the configured layer sizes and move it to the
# selected device; the last expression displays the module summary
net = DefineNet(num_inputs, num_hidden, num_outputs)
net.to(device=device)

TorchSequential  with shape (784, 10) {
    LinearTorch '0_LinearTorch' with shape (784, 512)
    LIFTorch '1_LIFTorch' with shape (512, 512)
    LinearTorch '2_LinearTorch' with shape (512, 10)
    LIFTorch '3_LIFTorch' with shape (10, 10)
}

In [29]:
# Saved weights, presumably written by the training notebook named in the message below
state_dict_file_path = "../models/Rockpool-FMNIST-training.pt"

# Set to True when the weights file cannot be found
run_train = False

try:
    load_state_dict = torch.load(state_dict_file_path, map_location=device)
    net.load_state_dict(load_state_dict)
except FileNotFoundError:
    # No weights were loaded: `net` keeps the parameters it was created with
    run_train = True
    print( "File not found running training" )

if run_train:
    print("Run the Rockpool-FMNIST-training.ipynb first")

In [30]:
# Running batch counter, forwarded to the plotting helper for file numbering
index = 0

# drop_last switched to False to keep all samples
test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=False, drop_last=False)

net.eval()
with torch.no_grad():
    for data, targets in test_loader:
        # Drop only the channel axis, (batch, 1, 28, 28) -> (batch, 28, 28);
        # a bare squeeze() would also remove the batch axis for a batch of one
        enc_data = encode_poisson(data.squeeze(1), num_steps)

        # forward pass
        test_spk, _, _ = net(enc_data.to(device))

        plot_data_spike_recording(
            spk_rec=test_spk,
            batch_size=batch_size,
            targets=targets,
            data=data,
            index=index,
        )
        index += 1