# Import delle librerie

In [1]:
import numpy as np
import torch
import torch.nn as nn

from snntorch import spikeplot as splt
from rockpool.nn.modules import LinearTorch, LIFTorch
from rockpool.nn.combinators import Sequential
from rockpool.parameters import Constant

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt

# Download del datasets

In [2]:
# dataloader arguments
data_root='../data'
# Device and data repository
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Trasformazioni da applicare al Datasets

In [3]:
# Define a transform
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1,))
])

In [4]:
# Load Datasets
fmnist_train = datasets.FashionMNIST(data_root, train=True, download=True, transform=transform)
fmnist_test = datasets.FashionMNIST(data_root, train=False, download=True, transform=transform)

# Definizione della dimensione del batch

In [5]:
batch_size = 512

# Creazione dei dataloader

In [6]:
# Create DataLoaders
train_loader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

In [7]:
# Utilities
def plot_mem_rec(mem_rec, batch_size, targets):
    num_steps = len(mem_rec)
    for i in range(0,batch_size):
        fig = plt.figure()
        ax = fig.subplots()
        ax.set_xlim((-10,210))
        ax.set_ylim((-2,2))
        ss = mem_rec[:,i,:]
        plt.plot(range(0,num_steps), ss.cpu().detach())

        plt.title("Output Layer Membrane Output - {}".format(targets[i]))
        plt.xlabel("Time step")
        plt.ylabel("Neuron Number")
        fig.tight_layout()
    plt.show()


def plot_spk_rec(spk_rec, batch_size, targets):
    for i in range(0,batch_size):
        fig = plt.figure()
        ax = fig.subplots()
        ax.set_xlim((-10,210))
        ax.set_ylim((-1,11))
        ax.set_yticks(range(0,11))
        ss = spk_rec[i,:,:]
        splt.raster(ss, ax, s=1, c="black")

        plt.title("Output Layer - {}".format(targets[i]))
        plt.xlabel("Time step")
        plt.ylabel("Neuron Number")
        fig.tight_layout()
    plt.show()


def print_batch_accuracy(net, data, targets, train=False):
    spk_rec, _, _ = net(data)
    _, data_idx = spk_rec.sum(dim=(1)).max(1)
    _, targets_idx = targets.sum(dim=(1)).max(1)
    acc = np.mean((targets_idx == data_idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer(net, data, test_data, targets, test_targets):
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(net, data, targets, train=True)
    print_batch_accuracy(net, test_data, test_targets, train=False)
    print("\n")

# Architettura e dinamica temporale

In [8]:
# Network Architecture
num_inputs = 28*28
# num_hidden = 1000
num_hidden = 512
num_outputs = 10

# Temporal Dynamics
num_steps = 100
dt = 1e-2
tau_mem = Constant(100e-3)
tau_syn = Constant(50e-3)
threshold = Constant(1.)
bias = Constant(0.)

# Poisson Encode del batch

In [9]:
# - Define a function to encode an input into a poisson event series
def encode_poisson(data: torch.Tensor, num_steps: int, scale: float = 0.1) -> torch.Tensor:
    num_batches, _, _ = data.shape
    data = scale * data.view((num_batches, 1, -1)).repeat((1, num_steps, 1))
    return (torch.rand(data.shape) < (data * scale)).float()

# Encode in spike degli obiettivi

In [10]:
# - Define a function to encode the network target
def encode_class(class_idx: torch.Tensor, num_classes: int, num_steps: int) -> torch.Tensor:
    num_batches = class_idx.numel()
    target = torch.nn.functional.one_hot(class_idx, num_classes = num_classes)
    return target.view((num_batches, 1, -1)).repeat((1, num_steps, 1)).float()

# Network

In [11]:
# Define Network
def DefineNet(num_inputs, num_hidden, num_outputs):
    return Sequential(
        LinearTorch((num_inputs, num_hidden)),
        LIFTorch(
            num_hidden,
            tau_mem=tau_mem,
            tau_syn=tau_syn,
            threshold=threshold,
            bias=bias,
            dt=dt
        ),
        LinearTorch((num_hidden, num_outputs)),
        LIFTorch(
            num_outputs,
            tau_mem=tau_mem,
            tau_syn=tau_syn,
            threshold=threshold,
            bias=bias,
            dt=dt
        )
    )

In [12]:
# Network instantiation
net = DefineNet(num_inputs=num_inputs, num_hidden=num_hidden, num_outputs=num_outputs)
net.to(device=device)

TorchSequential  with shape (784, 10) {
    LinearTorch '0_LinearTorch' with shape (784, 512)
    LIFTorch '1_LIFTorch' with shape (512, 512)
    LinearTorch '2_LinearTorch' with shape (512, 10)
    LIFTorch '3_LIFTorch' with shape (10, 10)
}

In [13]:
num_epochs = 3
loss_hist = []
test_loss_hist = []
counter = 0
# loss = nn.CrossEntropyLoss()
# loss = nn.MSELoss()
loss = torch.nn.functional.mse_loss
optimizer = torch.optim.Adam(net.parameters().astorch())#, lr=5e-4)
net.train()

# Outer training loop
for epoch in range(num_epochs):
    # train_batch = iter()

    # Minibatch training loop
    for data, targets in train_loader:
        print( "Epoch ", epoch, " Iteration: ", counter)
        data = encode_poisson(data.squeeze(), num_steps)

        targets = encode_class(targets, 10, num_steps)

        # - Zero gradients, simulate model
        optimizer.zero_grad()
        # forward pass
        spk_rec, _, _ = net(data.to(device))
        loss_val = loss(spk_rec, targets.to(device))
        print(loss_val)
        # Gradient calculation + weight update
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())
        counter += 1
        print_batch_accuracy(net, data.to(device), targets.to(device), train=True)
        


Epoch  0  Iteration:  0
tensor(0.1167, device='cuda:0', grad_fn=<MseLossBackward0>)
Train set accuracy for a single minibatch: 19.92%
Epoch  0  Iteration:  1
tensor(0.1076, device='cuda:0', grad_fn=<MseLossBackward0>)
Train set accuracy for a single minibatch: 28.32%
Epoch  0  Iteration:  2
tensor(0.1102, device='cuda:0', grad_fn=<MseLossBackward0>)
Train set accuracy for a single minibatch: 27.34%
Epoch  0  Iteration:  3
tensor(0.1163, device='cuda:0', grad_fn=<MseLossBackward0>)
Train set accuracy for a single minibatch: 31.45%
Epoch  0  Iteration:  4
tensor(0.1222, device='cuda:0', grad_fn=<MseLossBackward0>)
Train set accuracy for a single minibatch: 33.40%
Epoch  0  Iteration:  5
tensor(0.1288, device='cuda:0', grad_fn=<MseLossBackward0>)
Train set accuracy for a single minibatch: 42.97%
Epoch  0  Iteration:  6
tensor(0.1273, device='cuda:0', grad_fn=<MseLossBackward0>)
Train set accuracy for a single minibatch: 47.27%
Epoch  0  Iteration:  7
tensor(0.1242, device='cuda:0', grad_f

In [14]:
with torch.no_grad():
    net.eval()
    for test_data, test_targets in test_loader:
        test_data = encode_poisson(test_data.squeeze(), num_steps)
        test_data = test_data.to(device)
        test_targets = encode_class(test_targets, 10, num_steps)
        test_targets = test_targets.to(device)
        
        # Test set forward pass
        test_spk, _, _ = net(test_data)

        test_loss = torch.zeros((1), dtype=dtype, device=device)
        test_loss = loss(test_spk, test_targets)

        test_loss_hist.append(test_loss.item())

        # Print train/test loss/accuracy
        print_batch_accuracy(net, test_data.to(device), test_targets.to(device))

Test set accuracy for a single minibatch: 75.98%
Test set accuracy for a single minibatch: 76.37%
Test set accuracy for a single minibatch: 74.02%
Test set accuracy for a single minibatch: 76.76%
Test set accuracy for a single minibatch: 74.80%
Test set accuracy for a single minibatch: 75.98%
Test set accuracy for a single minibatch: 75.20%
Test set accuracy for a single minibatch: 75.39%
Test set accuracy for a single minibatch: 75.98%
Test set accuracy for a single minibatch: 78.91%
Test set accuracy for a single minibatch: 74.22%
Test set accuracy for a single minibatch: 73.83%
Test set accuracy for a single minibatch: 73.44%
Test set accuracy for a single minibatch: 73.24%
Test set accuracy for a single minibatch: 75.98%
Test set accuracy for a single minibatch: 75.39%
Test set accuracy for a single minibatch: 76.76%
Test set accuracy for a single minibatch: 76.37%
Test set accuracy for a single minibatch: 75.00%


In [15]:
print(loss_hist)
print(test_loss_hist)

[0.11671485006809235, 0.10761133581399918, 0.11019531637430191, 0.1163339912891388, 0.12217579036951065, 0.12882617115974426, 0.12726953625679016, 0.12419141083955765, 0.12432422488927841, 0.12160938233137131, 0.11897070705890656, 0.11580859869718552, 0.11087695509195328, 0.11392383277416229, 0.11056641489267349, 0.11030860245227814, 0.10856445878744125, 0.10395313054323196, 0.11546485126018524, 0.10676172375679016, 0.10620313137769699, 0.10250195860862732, 0.09828125685453415, 0.0985097736120224, 0.09978711605072021, 0.0974746122956276, 0.0972910225391388, 0.09257812798023224, 0.10100195556879044, 0.09818555414676666, 0.10102539509534836, 0.09586133062839508, 0.09988086670637131, 0.09141992777585983, 0.09733008593320847, 0.09626953303813934, 0.10298437625169754, 0.09915625303983688, 0.09199219197034836, 0.09403516352176666, 0.09735547006130219, 0.09660547226667404, 0.09378320723772049, 0.08806055039167404, 0.0908808633685112, 0.08885352313518524, 0.09732617437839508, 0.091865241527557

In [16]:
total = 0
correct = 0

# drop_last switched to False to keep all samples
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)

with torch.no_grad():
  net.eval()
  for data, targets in test_loader:
    data = encode_poisson(data.squeeze(), num_steps)
    targets = encode_class(targets, 10, num_steps)
    targets = targets.to(device)

    # forward pass
    test_spk, _, _ = net(data.to(device))

    # calculate total accuracy
    _, predicted = test_spk.sum(1).max(1)
    total += len(targets.sum(1).max(1).indices)
    correct += (predicted == targets.sum(1).max(1).indices).sum().item()

NameError: name 'mnist_test' is not defined

In [None]:
print(total)
print(correct)
print("{:.2f}".format(100*(correct/total)))