In [None]:
%pip install rockpool

# Import delle librerie

In [19]:
import numpy as np
import torch
import torch.nn as nn

from snntorch import spikeplot as splt
from rockpool.nn.modules import LinearTorch, LIFTorch
from rockpool.nn.combinators import Sequential
from rockpool.parameters import Constant

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt

# Download del datasets

In [20]:
# dataloader arguments
data_root='../data'
# Device and data repository
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Trasformazioni da applicare al Datasets

In [21]:
# Define a transform
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1,))
])

In [22]:
# Load Datasets
mnist_train = datasets.MNIST(data_root, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_root, train=False, download=True, transform=transform)

# Definizione della dimensione del batch

In [23]:
batch_size = 512

# Creazione dei dataloader

In [24]:
# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

In [25]:
# Utilities
def plot_mem_rec(mem_rec, batch_size, targets):
    num_steps = len(mem_rec)
    for i in range(0,batch_size):
        fig = plt.figure()
        ax = fig.subplots()
        ax.set_xlim((-10,210))
        ax.set_ylim((-2,2))
        ss = mem_rec[:,i,:]
        plt.plot(range(0,num_steps), ss.cpu().detach())

        plt.title("Output Layer Membrane Output - {}".format(targets[i]))
        plt.xlabel("Time step")
        plt.ylabel("Neuron Number")
        fig.tight_layout()
    plt.show()


def plot_spk_rec(spk_rec, batch_size, targets):
    for i in range(0,batch_size):
        fig = plt.figure()
        ax = fig.subplots()
        ax.set_xlim((-10,210))
        ax.set_ylim((-1,11))
        ax.set_yticks(range(0,11))
        ss = spk_rec[i,:,:]
        splt.raster(ss, ax, s=1, c="black")

        plt.title("Output Layer - {}".format(targets[i]))
        plt.xlabel("Time step")
        plt.ylabel("Neuron Number")
        fig.tight_layout()
    plt.show()


def print_batch_accuracy(net, data, targets, train=False):
    spk_rec, _, _ = net(data)
    _, data_idx = spk_rec.sum(dim=(1)).max(1)
    _, targets_idx = targets.sum(dim=(1)).max(1)
    acc = np.mean((targets_idx == data_idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def csv_print_batch_accuracy(
        net, data, targets, train=False, epoch=0, iteration=0, 
        training_data_file="../data/training-testing.csv"):
    spk_rec, _, _ = net(data)
    _, data_idx = spk_rec.sum(dim=(1)).max(1)
    _, targets_idx = targets.sum(dim=(1)).max(1)
    acc = np.mean((targets_idx == data_idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer(net, data, test_data, targets, test_targets):
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(net, data, targets, train=True)
    print_batch_accuracy(net, test_data, test_targets, train=False)
    print("\n")

# Architettura e dinamica temporale

In [26]:
# Network Architecture
num_inputs = 28*28
# num_hidden = 1000
num_hidden = 512
num_outputs = 10

# Temporal Dynamics
num_steps = 100
dt = 1e-2
tau_mem = Constant(100e-3)
tau_syn = Constant(50e-3)
threshold = Constant(1.)
bias = Constant(0.)

# Poisson Encode del batch

In [27]:
# - Define a function to encode an input into a poisson event series
def encode_poisson(data: torch.Tensor, num_steps: int, scale: float = 0.1) -> torch.Tensor:
    num_batches, _, _ = data.shape
    data = scale * data.view((num_batches, 1, -1)).repeat((1, num_steps, 1))
    return (torch.rand(data.shape) < (data * scale)).float()

# Encode in spike degli obiettivi

In [28]:
# - Define a function to encode the network target
def encode_class(class_idx: torch.Tensor, num_classes: int, num_steps: int) -> torch.Tensor:
    num_batches = class_idx.numel()
    target = torch.nn.functional.one_hot(class_idx, num_classes = num_classes)
    return target.view((num_batches, 1, -1)).repeat((1, num_steps, 1)).float()

# Network

In [29]:
# Define Network
def DefineNet(num_inputs, num_hidden, num_outputs):
    return Sequential(
        LinearTorch((num_inputs, num_hidden)),
        LIFTorch(
            num_hidden,
            tau_mem=tau_mem,
            tau_syn=tau_syn,
            threshold=threshold,
            bias=bias,
            dt=dt
        ),
        LinearTorch((num_hidden, num_outputs)),
        LIFTorch(
            num_outputs,
            tau_mem=tau_mem,
            tau_syn=tau_syn,
            threshold=threshold,
            bias=bias,
            dt=dt
        )
    )

In [30]:
# Network instantiation
net = DefineNet(num_inputs=num_inputs, num_hidden=num_hidden, num_outputs=num_outputs)
net.to(device=device)

TorchSequential  with shape (784, 10) {
    LinearTorch '0_LinearTorch' with shape (784, 512)
    LIFTorch '1_LIFTorch' with shape (512, 512)
    LinearTorch '2_LinearTorch' with shape (512, 10)
    LIFTorch '3_LIFTorch' with shape (10, 10)
}

In [31]:
run_train = False
state_dict_file_path = "../models/Rockpool-MNIST-training.pt"

try:
    load_state_dict = torch.load( state_dict_file_path, map_location=device, )
    net.load_state_dict(load_state_dict)
except FileNotFoundError:
    print( "File not found running training" )
    run_train = True

if ( run_train == True ):
    num_epochs = 3
    loss_hist = []
    test_loss_hist = []
    counter = 0
    # loss = nn.CrossEntropyLoss()
    # loss = nn.MSELoss()
    loss = torch.nn.functional.mse_loss
    optimizer = torch.optim.Adam(net.parameters().astorch())#, lr=5e-4)
    net.train()

    # Outer training loop
    for epoch in range(num_epochs):
        # train_batch = iter()

        # Minibatch training loop
        for data, targets in train_loader:
            print( "Epoch ", epoch, " Iteration: ", counter)
            data = encode_poisson(data.squeeze(), num_steps)

            targets = encode_class(targets, 10, num_steps)

            # - Zero gradients, simulate model
            optimizer.zero_grad()
            # forward pass
            spk_rec, _, _ = net(data.to(device))
            loss_val = loss(spk_rec, targets.to(device))
            print(loss_val)
            # Gradient calculation + weight update
            loss_val.backward()
            optimizer.step()

            # Store loss history for future plotting
            loss_hist.append(loss_val.item())
            counter += 1
            print_batch_accuracy(net, data.to(device), targets.to(device), train=True)

    torch.save( net.state_dict(), state_dict_file_path )

In [32]:
with torch.no_grad():
    net.eval()
    for test_data, test_targets in test_loader:
        test_data = encode_poisson(test_data.squeeze(), num_steps)
        test_data = test_data.to(device)
        test_targets = encode_class(test_targets, 10, num_steps)
        test_targets = test_targets.to(device)
        
        # Test set forward pass
        test_spk, _, _ = net(test_data)

        test_loss = torch.zeros((1), dtype=dtype, device=device)
        test_loss = loss(test_spk, test_targets)

        test_loss_hist.append(test_loss.item())

        # Print train/test loss/accuracy
        print_batch_accuracy(net, test_data.to(device), test_targets.to(device))

Test set accuracy for a single minibatch: 90.04%
Test set accuracy for a single minibatch: 89.26%
Test set accuracy for a single minibatch: 91.60%
Test set accuracy for a single minibatch: 89.84%
Test set accuracy for a single minibatch: 89.45%
Test set accuracy for a single minibatch: 89.65%
Test set accuracy for a single minibatch: 89.45%
Test set accuracy for a single minibatch: 89.65%
Test set accuracy for a single minibatch: 88.09%
Test set accuracy for a single minibatch: 89.84%
Test set accuracy for a single minibatch: 89.45%
Test set accuracy for a single minibatch: 88.87%
Test set accuracy for a single minibatch: 90.62%
Test set accuracy for a single minibatch: 85.94%
Test set accuracy for a single minibatch: 89.45%
Test set accuracy for a single minibatch: 88.67%
Test set accuracy for a single minibatch: 89.45%
Test set accuracy for a single minibatch: 90.82%
Test set accuracy for a single minibatch: 88.09%


In [33]:
print(loss_hist)
print(test_loss_hist)

[0.10682813078165054, 0.10914649069309235, 0.11120312660932541, 0.12012305110692978, 0.12407618016004562, 0.12315430492162704, 0.13258399069309235, 0.13133594393730164, 0.1283300817012787, 0.1274140626192093, 0.12962110340595245, 0.12312696129083633, 0.12244727462530136, 0.12012695521116257, 0.11998828500509262, 0.12404688447713852, 0.11854297667741776, 0.11735742539167404, 0.11878320574760437, 0.11824023723602295, 0.1180429756641388, 0.10902539640665054, 0.10792383551597595, 0.10695899277925491, 0.11232031881809235, 0.10446875542402267, 0.10731054842472076, 0.10698828846216202, 0.10663477331399918, 0.10240625590085983, 0.10024414211511612, 0.11371093988418579, 0.10036914795637131, 0.1063535213470459, 0.10422461479902267, 0.10090039670467377, 0.09357226639986038, 0.09573633223772049, 0.10163672268390656, 0.09200391173362732, 0.09450195729732513, 0.09293555468320847, 0.09063086658716202, 0.09485937654972076, 0.0942968800663948, 0.0891738310456276, 0.09135742485523224, 0.0906054750084877

In [41]:
total = 0
correct = 0

# drop_last switched to False to keep all samples
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)

with torch.no_grad():
  net.eval()
  for data, targets in test_loader:
    data = encode_poisson(data.squeeze(), num_steps)
    targets = encode_class(targets, 10, num_steps)
    targets = targets.to(device)

    # forward pass
    test_spk, _, _ = net(data.to(device))

    # calculate total accuracy
    _, predicted = test_spk.sum(1).max(1)
    total += len(targets.sum(1).max(1).indices)
    correct += (predicted == targets.sum(1).max(1).indices).sum().item()

In [42]:
print(total)
print(correct)
print("{:.2f}".format(100*(correct/total)))

10000
8951
89.51
