In [58]:
%pip install snntorch

Note: you may need to restart the kernel to use updated packages.


# Import libraries

In [59]:
import os

import matplotlib.pyplot as plt
import numpy as np
import snntorch as snn
import torch
import torch.nn as nn
from snntorch import spikeplot as splt
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

# Batch size

In [60]:
# Number of samples per minibatch; passed to both DataLoaders and used by the
# reshape calls (`data.view(batch_size, -1)`) in the cells below.
batch_size = 512

# Dataset transformation

In [61]:
# Preprocessing pipeline applied to every image when it is read from the dataset
preprocessing_steps = [
    transforms.Resize(size=(28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0,), std=(1,)),
]
transform = transforms.Compose(preprocessing_steps)

In [62]:
# Directory handed to the dataset constructors as their root folder
data_root = '../data'

# Tensor dtype and compute device shared by the cells below;
# CUDA is selected whenever it is available, otherwise the CPU is used.
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Datasets

In [63]:
# The training and test splits share every argument except the `train` flag
fmnist_train, fmnist_test = (
    datasets.FashionMNIST(data_root, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)


# Dataloaders

In [64]:
# Create DataLoaders
# Training batches are reshuffled every epoch. drop_last=True discards the final
# partial batch, so every batch holds exactly batch_size samples.
train_loader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluation gains nothing from shuffling. With drop_last=True a shuffled loader
# would drop a different random subset on every run, so the order is kept fixed
# to make the per-batch test results reproducible.
test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=False, drop_last=True)

In [65]:
"""
Utilities
"""
def plot_mem_rec(mem_rec, batch_size, targets):
    """Plot the recorded output-layer membrane traces, one figure per sample.

    Each figure draws one line per entry of the last axis of ``mem_rec``
    against the time step.

    Args:
        mem_rec: Recorded membrane values, indexed as [time step, sample, neuron].
        batch_size: Number of samples (second axis of ``mem_rec``) to plot.
        targets: Per-sample labels; ``targets[i]`` is shown in the title of figure i.
    """
    num_steps = len(mem_rec)
    time_axis = range(num_steps)
    for i in range(batch_size):
        fig, ax = plt.subplots()
        # 10-step margin on both sides, derived from the recording length
        # (equals the previous fixed (-10, 210) for 200 steps).
        ax.set_xlim((-10, num_steps + 10))
        ax.set_ylim((-2, 2))
        sample_trace = mem_rec[:, i, :]
        ax.plot(time_axis, sample_trace.cpu().detach())

        ax.set_title("Output Layer Membrane Output - {}".format(targets[i]))
        ax.set_xlabel("Time step")
        # The y-axis shows the membrane value, not a neuron index.
        ax.set_ylabel("Membrane Potential")
        fig.tight_layout()
    plt.show()


def plot_spk_rec(spk_rec, batch_size, targets):
    """Draw a spike raster of the recorded output layer, one figure per sample.

    Args:
        spk_rec: Recorded output spikes, indexed as [time step, sample, neuron].
        batch_size: Number of samples (second axis of ``spk_rec``) to plot.
        targets: Per-sample labels; ``targets[i]`` is shown in the title of figure i.
    """
    num_steps = len(spk_rec)
    for i in range(batch_size):
        fig, ax = plt.subplots()
        sample_spikes = spk_rec[:, i, :]
        num_neurons = sample_spikes.shape[-1]

        # Axis ranges follow the recording shape; for 200 steps and 10 neurons
        # they equal the previous fixed values (-10, 210), (-1, 11), range(0, 11).
        ax.set_xlim((-10, num_steps + 10))
        ax.set_ylim((-1, num_neurons + 1))
        ax.set_yticks(range(0, num_neurons + 1))
        splt.raster(sample_spikes, ax, s=1, c="black")

        ax.set_title("Output Layer - {}".format(targets[i]))
        ax.set_xlabel("Time step")
        ax.set_ylabel("Neuron Number")
        fig.tight_layout()
    plt.show()


def print_batch_accuracy(net, data, targets, train=False):
    """Print the classification accuracy of ``net`` on a single minibatch.

    The predicted class of each sample is the output index with the largest
    spike count summed over all time steps.

    Args:
        net: Callable returning a pair whose first element is the output spike
            recording, indexed as [time step, sample, class].
        data: Input batch; flattened to one row per sample before the forward pass.
        targets: Ground-truth class indices, one per sample.
        train: If True the message says "Train set", otherwise "Test set".
    """
    # Flatten with the batch's own size instead of the global batch_size, so a
    # final partial batch (drop_last=False) does not break the reshape.
    # no_grad: this forward pass is only used for a metric, no graph is needed.
    with torch.no_grad():
        output, _ = net(data.view(data.size(0), -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    split = "Train" if train else "Test"
    print(f"{split} set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    """Print a progress report: current losses and single-minibatch accuracies.

    Takes no arguments; it reads the notebook-level variables ``epoch``,
    ``iter_counter``, ``counter``, ``loss_hist``, ``test_loss_hist``, ``net``,
    ``data``, ``targets``, ``test_data`` and ``test_targets``.
    """
    # NOTE(review): `iter_counter` is not assigned in any visible cell — confirm
    # it exists before calling this helper.
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    # print_batch_accuracy takes the network as its first argument; without it
    # these calls raised TypeError (missing `targets`).
    print_batch_accuracy(net, data, targets, train=True)
    print_batch_accuracy(net, test_data, test_targets, train=False)
    print("\n")

# Architecture and temporal dynamics

In [66]:
# Network Architecture
# Input features of the first Linear layer: a flattened 28x28 image.
num_inputs = 28*28
# Output features of the first Linear layer / input features of the second.
num_hidden = 512
# Output features of the second Linear layer.
num_outputs = 10

# Temporal Dynamics
# Number of iterations of the time loop in Net.forward.
num_steps = 200
# Passed as the `beta` argument to both snn.Leaky layers.
beta = 0.95

# Network definition

In [67]:
# Define Network
class Net(nn.Module):
    """Two-layer fully connected spiking network: Linear -> Leaky -> Linear -> Leaky."""

    def __init__(self):
        super().__init__()

        # Layers, created in the order they are applied in forward()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Run the network for ``num_steps`` time steps on the same input ``x``.

        Returns:
            A pair ``(spikes, membranes)`` of the second Leaky layer's outputs,
            each stacked along a new leading time-step axis.
        """
        # Membrane states of both Leaky layers at t=0
        hidden_mem = self.lif1.init_leaky()
        out_mem = self.lif2.init_leaky()

        # Per-step recordings of the final layer
        out_spk_steps, out_mem_steps = [], []

        for _ in range(num_steps):
            hidden_spk, hidden_mem = self.lif1(self.fc1(x), hidden_mem)
            out_spk, out_mem = self.lif2(self.fc2(hidden_spk), out_mem)
            out_spk_steps.append(out_spk)
            out_mem_steps.append(out_mem)

        stacked_spk = torch.stack(out_spk_steps, dim=0)
        stacked_mem = torch.stack(out_mem_steps, dim=0)
        return stacked_spk, stacked_mem

# Network creation

In [68]:
"""
Network instantiation
"""
# Build the model, then move its parameters to the selected device
net = Net()
net = net.to(device)

In [69]:
# data, targets = next(iter(train_loader))
# print(data.size())
# print(targets.size())

In [70]:
# print(data.view(batch_size, -1).size())
# spk_rec, mem_rec = net(data.to(device).view(batch_size,-1))
# print(spk_rec.size())
# print(mem_rec.size())


In [71]:
# plot_spk_rec(spk_rec=spk_rec, batch_size=batch_size, targets=targets)

In [72]:
# plot_mem_rec(mem_rec=mem_rec, batch_size=batch_size, targets=targets)

In [73]:
# loss = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

# initialize the total loss value
# loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
# for step in range(num_steps):
#   loss_val += loss(mem_rec[step], targets.to(device))

# print(loss_val)

In [74]:
# clear previously stored gradients
# optimizer.zero_grad()

# calculate the gradients
# loss_val.backward()

# weight update
# optimizer.step()

In [75]:
# data, targets = next(iter(train_loader))

In [76]:
# calculate new network outputs using the same data
# spk_rec, mem_rec = net(data.to(device).view(batch_size, -1))

In [77]:
# initialize the total loss value
# loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
# for step in range(num_steps):
#   loss_val += loss(mem_rec[step], targets.to(device))

In [78]:
# plot_spk_rec(spk_rec=spk_rec,batch_size=batch_size,targets=targets)

In [79]:
# plot_mem_rec(mem_rec=mem_rec,batch_size=batch_size,targets=targets)

In [80]:
run_train = False
state_dict_file_path = "../models/snnTorch-FMNIST-training.pt"

# Defined unconditionally: the test-set cell below uses `loss` and
# `test_loss_hist` even when the weights are loaded from disk and the training
# branch is skipped (previously that path ended in a NameError).
loss = nn.CrossEntropyLoss()
loss_hist = []
test_loss_hist = []

# Reuse previously saved weights when available; otherwise train from scratch.
try:
    load_state_dict = torch.load(state_dict_file_path, map_location=device)
    net.load_state_dict(load_state_dict)
except FileNotFoundError:
    print( "File not found running training" )
    run_train = True

if run_train:
    num_epochs = 3
    counter = 0
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
    net.train()

    # Outer training loop
    for epoch in range(num_epochs):
        # Minibatch training loop
        for data, targets in train_loader:
            print( "Epoch ", epoch, " Iteration: ", counter)
            data = data.to(device)
            targets = targets.to(device)

            # forward pass (flatten using the batch's own size rather than the
            # global batch_size, so a partial batch cannot break the reshape)
            optimizer.zero_grad()
            spk_rec, mem_rec = net(data.view(data.size(0), -1))

            # initialize the loss & sum over time
            loss_val = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                loss_val += loss(mem_rec[step], targets)
            print(loss_val)
            # Gradient calculation + weight update
            loss_val.backward()
            optimizer.step()

            # Store loss history for future plotting
            loss_hist.append(loss_val.item())
            counter += 1
            print_batch_accuracy(net, data, targets, train=True)

    # Make sure the target directory exists so the trained weights are not lost
    # to a failing save at the very end of training.
    os.makedirs(os.path.dirname(state_dict_file_path), exist_ok=True)
    torch.save(net.state_dict(), state_dict_file_path)

File not found running training
Epoch  0  Iteration:  0
tensor([645.7095], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 29.88%
Epoch  0  Iteration:  1
tensor([444.5945], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 48.24%
Epoch  0  Iteration:  2
tensor([394.8665], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 50.78%
Epoch  0  Iteration:  3
tensor([350.8332], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 54.69%
Epoch  0  Iteration:  4
tensor([339.0765], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 56.84%
Epoch  0  Iteration:  5
tensor([296.2602], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 64.45%
Epoch  0  Iteration:  6
tensor([288.5741], device='cuda:0', grad_fn=<AddBackward0>)
Train set accuracy for a single minibatch: 59.38%
Epoch  0  Iteration:  7
tensor

In [81]:
# Test set
# Use a loss function and a fresh history owned by this cell: it then works when
# the weights were loaded from disk (training branch skipped), and re-running it
# does not append duplicate entries.
test_loss_fn = nn.CrossEntropyLoss()
test_loss_hist = []

net.eval()
with torch.no_grad():
    for test_data, test_targets in test_loader:
        test_data = test_data.to(device)
        test_targets = test_targets.to(device)

        # Test set forward pass (flatten using the batch's own size)
        test_spk, test_mem = net(test_data.view(test_data.size(0), -1))

        # Test set loss, summed over all time steps
        test_loss = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            test_loss += test_loss_fn(test_mem[step], test_targets)
        test_loss_hist.append(test_loss.item())

        # Print test accuracy for this minibatch
        print_batch_accuracy(net, test_data, test_targets)

Test set accuracy for a single minibatch: 70.90%
Test set accuracy for a single minibatch: 62.89%
Test set accuracy for a single minibatch: 69.34%
Test set accuracy for a single minibatch: 69.14%
Test set accuracy for a single minibatch: 64.45%
Test set accuracy for a single minibatch: 65.23%
Test set accuracy for a single minibatch: 68.75%
Test set accuracy for a single minibatch: 62.70%
Test set accuracy for a single minibatch: 66.60%
Test set accuracy for a single minibatch: 67.38%
Test set accuracy for a single minibatch: 66.21%
Test set accuracy for a single minibatch: 66.99%
Test set accuracy for a single minibatch: 67.58%
Test set accuracy for a single minibatch: 66.80%
Test set accuracy for a single minibatch: 70.51%
Test set accuracy for a single minibatch: 64.84%
Test set accuracy for a single minibatch: 60.94%
Test set accuracy for a single minibatch: 68.95%
Test set accuracy for a single minibatch: 66.02%


In [82]:
total = 0
correct = 0

# drop_last switched to False to keep all samples.
# No shuffling: the totals do not depend on the order of the samples.
test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=False, drop_last=False)

net.eval()
with torch.no_grad():
    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass; the last batch may be smaller than batch_size, so the
        # reshape uses the batch's own size
        test_spk, _ = net(data.view(data.size(0), -1))

        # predicted class = output index with the most spikes over all time steps
        _, predicted = test_spk.sum(dim=0).max(1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

In [83]:
# Report the number of evaluated samples, the number of correct predictions,
# and the resulting accuracy in percent with two decimals.
print(total)
print(correct)
accuracy_pct = 100 * (correct / total)
print(f"{accuracy_pct:.2f}")

10000
6680
66.80
