In [1]:
%reload_ext autoreload
%autoreload 2
import copy
import os
import warnings

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

from spieks.ann.training import train_ann
# Simulation settings: T is handed to the SNN Classifier (presumably the simulation duration)
# and DT to Converter.convert (presumably the simulation time step) — TODO confirm units.
T = 1.0
DT = 1e-2

# NOTE(review): N and MAX_HZ are not referenced in any cell shown here; wire them in or remove them.
N = 512
MAX_HZ = 10

# ANN training hyperparameters.
BS = 32  # DataLoader batch size for both the train and test splits
# NOTE(review): LR is not passed to train_ann; confirm train_ann's own learning rate matches it.
LR = 1e-3
EPOCHS = 40

# Seed torch's global RNG so data shuffling and weight initialization repeat across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

print(f"Using pytorch {torch.__version__}")

Using pytorch 2.3.1.post100


Define the ANN: a fully connected 784 → 128 → 64 → 10 network with ReLU activations on the two hidden layers.

In [2]:
class MNISTModel(nn.Module):
	"""Fully connected MNIST classifier: flattened 28x28 image -> 128 -> 64 -> 10 logits.

	Each layer is a named attribute (w1, r1, w2, r2, w3). The state-dict keys are derived
	from these names, so renaming them would break loading a previously saved checkpoint.
	"""

	def __init__(self):
		super().__init__()
		self.flatten = nn.Flatten()
		# Hidden layer 1: 28*28 pixels -> 128 units.
		self.w1 = nn.Linear(28 * 28, 128)
		self.r1 = nn.ReLU()
		# Hidden layer 2: 128 -> 64 units.
		self.w2 = nn.Linear(128, 64)
		self.r2 = nn.ReLU()
		# Output layer has no activation: forward returns raw logits, as CrossEntropyLoss expects.
		self.w3 = nn.Linear(64, 10)

	def forward(self, x):
		"""Map a batch of images to 10 class logits."""
		flat = self.flatten(x)
		hidden = self.r1(self.w1(flat))
		hidden = self.r2(self.w2(hidden))
		return self.w3(hidden)

Train the base ANN on MNIST, or load the saved checkpoint if one already exists.

In [3]:
model_path = "tmp/models/best_mnist_model.pth"

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device ({device})')

# ToTensor converts each image to a float tensor in [0, 1]; no further normalization is applied.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# download=True is a no-op when the files are already present, so both splits are safe to request.
train_dataset = torchvision.datasets.MNIST(root='./tmp/data/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./tmp/data/', train=False, transform=transform, download=True)

# Pinned host memory only speeds up host->GPU copies; skip it when running on CPU.
use_pinned_memory = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True, drop_last=True, pin_memory=use_pinned_memory)
# Evaluate on every test image in a fixed order: no shuffling, and keep the final partial batch.
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False, drop_last=False, pin_memory=use_pinned_memory)

# Initialize the model, then either load cached weights or train it.
model = MNISTModel().to(device)
if os.path.exists(model_path):
	# map_location lets a checkpoint saved on a GPU load on a CPU-only machine; weights_only limits
	# unpickling to tensors/containers, which suffices because the file is a plain state dict.
	model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
	print("Model loaded from file:", model_path)
else:
	# Train from scratch otherwise. Create the checkpoint directory first, since train_ann is
	# given save_path (presumably it writes the best weights there — TODO confirm).
	os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
	loss_fn = nn.CrossEntropyLoss()
	model = train_ann(model, train_loader, test_loader, loss_fn, EPOCHS, device, save_path=model_path)

Using device (cuda)
Model loaded from file: tmp/models/best_mnist_model.pth


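Convert the ANN to a spiking neural network (SNN): as the printed structure shows, each ReLU activation is replaced by a spiking `IF` neuron, while the linear layers are kept.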

In [4]:
from spieks.network.converter import Converter

# Convert a deep copy so `model` remains the trained ANN and this cell is safe to re-run.
# NOTE(review): the printed SpikingNetwork wraps an MNISTModel whose r1/r2 became IF(); if
# Converter.convert rewrites its argument in place, converting `model` directly would leave it
# spiking and a second run would re-convert an already-converted network — confirm.
spiking_model = Converter.convert(copy.deepcopy(model), DT)
spiking_model

SpikingNetwork(
  (net): MNISTModel(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (w1): Linear(in_features=784, out_features=128, bias=True)
    (r1): IF()
    (w2): Linear(in_features=128, out_features=64, bias=True)
    (r2): IF()
    (w3): Linear(in_features=64, out_features=10, bias=True)
  )
)


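Evaluate the SNN's classification loss and accuracy on the MNIST test set. In the recorded run the accuracy (0.0974) is at chance level for 10 classes, which indicates that the conversion or simulation settings are not yet working.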

In [5]:
from spieks.simulator import Classifier

# Simulate the converted network over the test set; T is presumably the simulation duration.
classifier = Classifier(spiking_model, T, test_dataset)
loss, accuracy = classifier.run(device=device)
print(f"Loss: {loss}")
print(f"Accuracy: {accuracy}")

# Surface results that are barely better than random guessing instead of reporting them silently.
# Anything at or below twice the chance rate points to a conversion/simulation problem.
chance_accuracy = 1 / len(test_dataset.classes)
if accuracy <= 2 * chance_accuracy:
	warnings.warn(
		f"SNN accuracy {accuracy:.4f} is near chance level ({chance_accuracy:.2f}); "
		"check the ANN-to-SNN conversion and the simulation settings (T, DT)."
	)

  0%|          | 0/157 [00:00<?, ?it/s]

Loss: 19.23607063293457
Accuracy: 0.0974
