# Import delle librerie 

In [29]:
import numpy as np
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import spikeplot as splt
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt

# Batch size

In [30]:
# Number of samples per minibatch; also used by later cells to reshape inputs
# with `view(batch_size, -1)`.
batch_size: int = 512

# Dataset transformation

In [31]:
# Preprocessing applied to every MNIST image: resize to 28x28, force a single
# grayscale channel, convert to a tensor, then Normalize with mean 0 and std 1
# (mathematically an identity rescale: (x - 0) / 1).
transform = transforms.Compose(
    [
        transforms.Resize(size=(28, 28)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0,), std=(1,)),
    ]
)

In [32]:
# Runtime configuration shared by the cells below.
dtype = torch.float  # element type of the accumulated loss tensors
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory where torchvision stores / looks for the MNIST files.
data_root = '../data'

# Datasets

In [33]:
# Download (if needed) and wrap both MNIST splits with the same transform:
# train=True yields the training split, train=False the test split.
mnist_train, mnist_test = (
    datasets.MNIST(data_root, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Dataloaders

In [34]:
# Both loaders share the same settings: reshuffle every epoch and drop the
# final incomplete batch, so every batch holds exactly `batch_size` samples
# (later cells reshape with `view(batch_size, -1)` and rely on this).
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True, drop_last=True)
    for split in (mnist_train, mnist_test)
)

In [35]:
"""
Utilities
"""
def plot_mem_rec(mem_rec, batch_size, targets):
    """Plot the output-layer membrane potential traces, one figure per sample.

    Args:
        mem_rec: recorded membrane potentials indexed as
            ``mem_rec[step, sample, neuron]`` (the layout returned by
            ``Net.forward``).
        batch_size: number of samples to plot; one figure is created for each.
        targets: labels shown in each figure title, indexed by sample.
    """
    num_steps = len(mem_rec)
    for i in range(batch_size):
        fig, ax = plt.subplots()
        # Pad the time axis by 10 steps on each side of the recorded range.
        ax.set_xlim((-10, num_steps + 10))
        ax.set_ylim((-2, 2))
        # One line per output neuron; detach before moving off the device.
        trace = mem_rec[:, i, :].detach().cpu()
        ax.plot(range(num_steps), trace)

        ax.set_title("Output Layer Membrane Output - {}".format(targets[i]))
        ax.set_xlabel("Time step")
        # The y axis carries membrane-potential values, not neuron indices.
        ax.set_ylabel("Membrane potential")
        fig.tight_layout()
    plt.show()


def plot_spk_rec(spk_rec, batch_size, targets):
    """Draw a raster plot of output-layer spikes, one figure per sample.

    Args:
        spk_rec: recorded spikes indexed as ``spk_rec[step, sample, neuron]``
            (the layout returned by ``Net.forward``).
        batch_size: number of samples to plot; one figure is created for each.
        targets: labels shown in each figure title, indexed by sample.
    """
    # Axis extents follow the record instead of assuming 200 steps / 10 neurons.
    num_steps = len(spk_rec)
    num_neurons = spk_rec.shape[-1]
    for i in range(batch_size):
        fig, ax = plt.subplots()
        ax.set_xlim((-10, num_steps + 10))
        ax.set_ylim((-1, num_neurons + 1))
        # One tick per valid neuron index.
        ax.set_yticks(range(num_neurons))
        splt.raster(spk_rec[:, i, :], ax, s=1, c="black")

        ax.set_title("Output Layer - {}".format(targets[i]))
        ax.set_xlabel("Time step")
        ax.set_ylabel("Neuron Number")
        fig.tight_layout()
    plt.show()


def print_batch_accuracy(net, data, targets, train=False):
    """Print the classification accuracy of ``net`` on a single minibatch.

    Each sample is assigned to the output neuron whose output, summed over
    the time dimension (dim 0), is largest.

    Args:
        net: callable returning ``(output, _)`` where ``output`` is indexed as
            ``output[step, sample, class]``, as ``Net.forward`` does.
        data: input batch; flattened to ``(data.size(0), -1)`` before the
            forward pass.
        targets: tensor of integer class labels, one per sample, on the same
            device as ``net``'s output.
        train: only selects the wording ("Train"/"Test") of the message.
    """
    # Reporting only: no autograd graph is needed for this extra forward pass.
    with torch.no_grad():
        # Use the batch's own size so a final, smaller batch also works.
        output, _ = net(data.reshape(data.size(0), -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    split = "Train" if train else "Test"
    print(f"{split} set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    """Print a training progress report built from notebook-level globals.

    Reads the globals ``epoch``, ``counter``, ``loss_hist``, ``test_loss_hist``,
    ``net``, ``data``, ``targets``, ``test_data`` and ``test_targets``.
    NOTE(review): nothing in this notebook currently calls it. It presumably
    must run after both histories have an entry for index ``counter`` and
    before ``counter`` is incremented; otherwise the lookups raise IndexError.
    """
    print(f"Epoch {epoch}, Iteration {counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    # print_batch_accuracy takes the network as its first argument.
    print_batch_accuracy(net, data, targets, train=True)
    print_batch_accuracy(net, test_data, test_targets, train=False)
    print("\n")

# Architettura e dinamica temporale

In [36]:
# Network architecture: flattened 28x28 image -> hidden layer -> one output
# neuron per MNIST digit class.
num_inputs, num_hidden, num_outputs = 28 * 28, 512, 10

# Temporal dynamics: number of simulation steps per input, and the decay
# factor passed as `beta` to every snn.Leaky layer.
num_steps, beta = 200, 0.95

# Definizione della rete

In [37]:
# Define Network
class Net(nn.Module):
    """Two-layer fully connected spiking network built from snntorch Leaky neurons.

    Layer sizes, the decay factor and the number of simulated time steps come
    from the module-level ``num_inputs``, ``num_hidden``, ``num_outputs``,
    ``beta`` and ``num_steps``.
    """

    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            x: flattened input batch whose last dimension is ``num_inputs``;
                the same input is presented at every time step.

        Returns:
            Tuple ``(spk2_rec, mem2_rec)``: output-layer spikes and membrane
            potentials, each stacked along a new leading time dimension
            (``(num_steps, batch, num_outputs)`` for a 2-D ``x``).
        """
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # The input does not change over time, so its synaptic current is
        # computed once rather than once per time step.
        cur1 = self.fc1(x)

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Creazione della rete

In [38]:
"""
Network instantiation
"""
# Build the spiking network and move its parameters to the selected device
# (nn.Module.to moves the module in place and returns it).
net = Net()
net.to(device)

In [39]:
# data, targets = next(iter(train_loader))
# print(data.size())
# print(targets.size())

In [40]:
# print(data.view(batch_size, -1).size())
# spk_rec, mem_rec = net(data.to(device).view(batch_size,-1))
# print(spk_rec.size())
# print(mem_rec.size())

In [41]:
# plot_spk_rec(spk_rec=spk_rec, batch_size=batch_size, targets=targets)

In [42]:
# plot_mem_rec(mem_rec=mem_rec, batch_size=batch_size, targets=targets)

In [43]:
# loss = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

# initialize the total loss value
# loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
# for step in range(num_steps):
#   loss_val += loss(mem_rec[step], targets.to(device))

# print(loss_val)

In [44]:
# clear previously stored gradients
# optimizer.zero_grad()

# calculate the gradients
# loss_val.backward()

# weight update
# optimizer.step()

In [45]:
# data, targets = next(iter(train_loader))

In [46]:
# calculate new network outputs using the same data
# spk_rec, mem_rec = net(data.to(device).view(batch_size, -1))

In [47]:
# initialize the total loss value
# loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
# for step in range(num_steps):
#   loss_val += loss(mem_rec[step], targets.to(device))

In [48]:
# plot_spk_rec(spk_rec=spk_rec,batch_size=batch_size,targets=targets)

In [49]:
# plot_mem_rec(mem_rec=mem_rec,batch_size=batch_size,targets=targets)

In [50]:
# Training hyper-parameters and bookkeeping.
num_epochs = 3
loss_hist = []  # summed-over-time training loss, one entry per iteration
test_loss_hist = []  # appended to by the test-set cell below
counter = 0  # iteration counter across all epochs
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
net.train()

# Outer training loop
for epoch in range(num_epochs):
    # Minibatch training loop
    for data, targets in train_loader:
        print( "Epoch ", epoch, " Iteration: ", counter)
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        optimizer.zero_grad()
        # Flatten each image to a vector; relies on the loader's drop_last=True
        # so every batch holds exactly `batch_size` samples.
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time
        # The output-layer membrane potential at each step is fed to
        # cross-entropy as logits, and the per-step losses are summed.
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)
        print(loss_val)

        # Gradient calculation + weight update
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())
        counter += 1
        # Runs an additional forward pass on the same batch, after the update.
        print_batch_accuracy(net, data, targets, train=True)

Epoch  0  Iteration:  0
tensor([551.9994], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 27.15%
Epoch  0  Iteration:  1
tensor([412.0614], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 48.83%
Epoch  0  Iteration:  2
tensor([358.9297], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 61.52%
Epoch  0  Iteration:  3
tensor([334.2766], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 71.09%
Epoch  0  Iteration:  4
tensor([313.9788], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 72.66%
Epoch  0  Iteration:  5
tensor([300.6126], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 77.34%
Epoch  0  Iteration:  6
tensor([288.2823], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 77.73%
Epoch  0  Iteration:  7
tensor([265.4679], device='cuda:0', gr

In [51]:
# Test set: per-minibatch loss and accuracy.
with torch.no_grad():
    net.eval()

    for test_data, test_targets in test_loader:
        test_data = test_data.to(device)
        test_targets = test_targets.to(device)

        # Test set forward pass; flatten with the batch's own size so a
        # smaller final batch also works, whatever the loader's drop_last.
        test_spk, test_mem = net(test_data.view(test_data.size(0), -1))

        # Test set loss: cross-entropy on the membrane potential, summed over time
        test_loss = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            test_loss += loss(test_mem[step], test_targets)
        test_loss_hist.append(test_loss.item())

        # Accuracy from the spikes computed above (prediction = output neuron
        # with the most spikes over time), avoiding a second forward pass.
        _, idx = test_spk.sum(dim=0).max(1)
        acc = (idx == test_targets).sum().item() / test_targets.size(0)
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

Test set accuracy for a single minibatch: 91.41%
Test set accuracy for a single minibatch: 92.58%
Test set accuracy for a single minibatch: 90.43%
Test set accuracy for a single minibatch: 91.02%
Test set accuracy for a single minibatch: 92.19%
Test set accuracy for a single minibatch: 91.02%
Test set accuracy for a single minibatch: 92.58%
Test set accuracy for a single minibatch: 91.80%
Test set accuracy for a single minibatch: 91.02%
Test set accuracy for a single minibatch: 91.60%
Test set accuracy for a single minibatch: 91.80%
Test set accuracy for a single minibatch: 89.84%
Test set accuracy for a single minibatch: 90.62%
Test set accuracy for a single minibatch: 91.02%
Test set accuracy for a single minibatch: 90.62%
Test set accuracy for a single minibatch: 90.23%
Test set accuracy for a single minibatch: 93.36%
Test set accuracy for a single minibatch: 91.21%
Test set accuracy for a single minibatch: 93.36%


In [52]:
# Dump the recorded training and test loss histories, one per line.
print(loss_hist, test_loss_hist, sep="\n")

[551.9993896484375, 412.0614318847656, 358.9297180175781, 334.27655029296875, 313.978759765625, 300.6126403808594, 288.2822570800781, 265.4678649902344, 251.8961944580078, 252.32183837890625, 245.14071655273438, 232.74342346191406, 228.6934051513672, 230.4676055908203, 201.2893524169922, 208.553466796875, 188.87991333007812, 191.3806610107422, 182.64779663085938, 184.07432556152344, 186.99932861328125, 161.07884216308594, 160.49960327148438, 145.4230499267578, 154.9762725830078, 148.9999542236328, 146.00001525878906, 142.49546813964844, 161.40438842773438, 157.50450134277344, 155.08184814453125, 128.4412841796875, 131.46548461914062, 181.49862670898438, 166.49423217773438, 137.1864776611328, 139.57659912109375, 176.53164672851562, 182.0670166015625, 161.30960083007812, 196.18338012695312, 159.373779296875, 170.6416778564453, 156.44493103027344, 153.03680419921875, 166.39703369140625, 166.86058044433594, 176.5999755859375, 163.63377380371094, 152.28916931152344, 144.81759643554688, 220.

In [53]:
# Accuracy over the whole test set.
total = 0
correct = 0

# Dedicated evaluation loader: drop_last=False keeps every test sample, and
# shuffling is unnecessary when all samples are visited once. It gets its own
# name so the shared `test_loader` (drop_last=True, fixed-size batches) is
# left untouched for other cells.
full_test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, drop_last=False)

with torch.no_grad():
    net.eval()
    for data, targets in full_test_loader:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass; flatten with the batch's own size (the last batch is smaller)
        test_spk, _ = net(data.view(data.size(0), -1))

        # prediction = output neuron with the most spikes over all time steps
        _, predicted = test_spk.sum(dim=0).max(1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

In [54]:
# Report the sample count, the number of correct predictions and the accuracy in percent.
print(total, correct, sep="\n")
print(f"{100 * (correct / total):.2f}")

10000
9144
91.44
