In [6]:
%reload_ext autoreload
%autoreload 2
from spieks.ann.training import train_ann
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from collections import OrderedDict
import os

print(f"Using pytorch {torch.__version__}")

# --- SNN simulation settings -------------------------------------------------
# T is handed to the Classifier and DT to Converter.convert further down.
# NOTE(review): N and MAX_HZ are not referenced by any cell shown here —
# confirm they are still needed.
N, T, DT = 512, 2.0, 1e-2
MAX_HZ = 10

# --- ANN training settings ---------------------------------------------------
BS, LR, EPOCHS = 64, 1e-3, 40

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device ({device})")

Using pytorch 2.3.1.post100
Using device (cuda)


Seeding

In [7]:
SEED = 42

# Seed the global PyTorch generator (weight initialisation, DataLoader
# shuffling) and, when a GPU is visible, the current CUDA device's generator.
torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed(SEED)

Define the model

In [8]:
class MNISTModel(nn.Module):
	"""Fully connected MNIST classifier: 784 (28*28) -> 128 -> 64 -> 10.

	Each Linear layer (w1, w2, w3) is followed by its own activation module
	(r1, r2, r3). The activations are kept as separate named attributes and
	are looked up on every forward call, which lets swap_layers replace each
	nn.ReLU with another activation type in place after construction.

	NOTE(review): r3 applies an activation to the final 10 outputs, so negative
	class scores are clipped to 0 before nn.CrossEntropyLoss sees them.
	Presumably intentional for ANN->SNN conversion (the output layer is
	converted too) — confirm before changing; removing r3 could also break
	loading existing checkpoints if the swapped-in activation carries state.
	"""

	def __init__(self):
		super().__init__()
		self.flatten = nn.Flatten()
		# Modules are registered in the order w1, r1, w2, r2, w3, r3; this order
		# fixes both the state_dict layout and the sequence of RNG draws made
		# while constructing the Linear layers.
		self.w1, self.r1 = nn.Linear(28 * 28, 128), nn.ReLU()
		self.w2, self.r2 = nn.Linear(128, 64), nn.ReLU()
		self.w3, self.r3 = nn.Linear(64, 10), nn.ReLU()

		# Re-initialise every weight matrix with Kaiming-uniform (ReLU gain);
		# biases keep nn.Linear's default initialisation.
		for linear in (self.w1, self.w2, self.w3):
			nn.init.kaiming_uniform_(linear.weight, nonlinearity='relu')

	def forward(self, x):
		"""Flatten ``x`` (keeping the batch dim) and apply each Linear/activation pair; returns (batch, 10)."""
		out = self.flatten(x)
		# Attributes are read here rather than cached in __init__, so activation
		# modules swapped in later are picked up automatically.
		for stage in (self.w1, self.r1, self.w2, self.r2, self.w3, self.r3):
			out = stage(out)
		return out

Train the base ANN on MNIST (or load a previously saved checkpoint)

In [9]:
model_path = "tmp/models/best_mnist_model.pth"
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(root='./tmp/data/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./tmp/data/', train=False, transform=transform, download=False)

# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runs.
use_pinned_memory = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, pin_memory=use_pinned_memory, num_workers=8)
# The test set is NOT shuffled. Combined with drop_last=True, shuffling made every
# evaluation discard a different random tail of samples, so the ANN and SNN scores
# were computed on different subsets. drop_last=True is kept because it is unclear
# whether the downstream evaluators assume full batches.
# NOTE(review): this still discards the last len(test_dataset) % BS test samples.
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=True, pin_memory=use_pinned_memory, num_workers=8)

# `device` was selected in the configuration cell; it is reused here rather than re-derived.

# Build the model and replace every ReLU with a QCFS activation so the trained
# ANN is better suited to SNN conversion.
# NOTE(review): Q=8 is presumably the QCFS quantisation level — confirm in spieks.
from spieks.network.converter import swap_layers
from spieks.ann.neurons import QCFS
model = MNISTModel()
swap_layers(model, old_layer_type=nn.ReLU, new_layer_type=QCFS, neuron_args={ "Q": 8 })
# Move to the target device only after swapping, so any parameters/buffers owned
# by the newly inserted QCFS modules end up on `device` as well.
model = model.to(device)

# Load or train model
if os.path.exists(model_path):
	# map_location lets a checkpoint written during a GPU run load on a CPU-only
	# machine; weights_only=True refuses to unpickle arbitrary objects from the file.
	model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
	print("Model loaded from file:", model_path)
else:
	# Train from scratch otherwise; make sure the checkpoint directory exists first.
	# NOTE(review): LR from the config cell is never passed to train_ann, so
	# train_ann's own learning rate is used — confirm that is intended.
	os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
	loss_fn = nn.CrossEntropyLoss()
	model = train_ann(model, train_loader, test_loader, loss_fn, EPOCHS, device, save_path=model_path)

Using device (cuda)
Model loaded from file: tmp/models/best_mnist_model.pth


In [10]:
# Score the trained ANN on test_loader; judging by the names, test_ann returns
# (mean cross-entropy loss, accuracy).
from spieks.ann.training import test_ann
ann_loss, ann_acc = test_ann(model, device, test_loader, loss_fn=nn.CrossEntropyLoss())
print(f"ANN Loss (CrossEntropy): {ann_loss}")
# Accuracy is not a cross-entropy quantity, so it is labelled plainly.
print(f"ANN Accuracy: {ann_acc}")

ANN Loss (CrossEntropy): 0.023488527187743247
ANN Accuracy (CrossEntropy): 0.9564302884615384


Convert the ANN to an SNN

In [11]:
from spieks.neurons import IF
from spieks.network.converter import Converter

# model_subs maps each ANN activation type to the spiking neuron type that
# replaces it: every QCFS layer becomes an IF neuron. DT is passed to the
# converter, presumably as the simulation time step — confirm in spieks.
qcfs_to_if = {QCFS: IF}
spiking_model = Converter.convert(model, DT, model_subs=qcfs_to_if)

Evaluate the converted SNN on MNIST test-set classification

In [None]:
from spieks.simulator import Classifier

# Simulate the converted network on test_loader for a horizon of T
# (presumably in the same time units as DT — confirm against spieks).
classifier = Classifier(spiking_model, T, test_loader)
loss, accuracy = classifier.run(device=device)
for label, value in (("Loss", loss), ("Accuracy", accuracy)):
    print(f"{label}: {value}")

  0%|          | 0/156 [00:00<?, ?it/s]