In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm, trange

from spieks.network import run_sim
from spieks.neurons import SpikingNeuron, IF

# NOTE(review): N, T, DT and MAX_HZ are not referenced in the cells visible here;
# the meanings below are guesses from the names — confirm against later cells.
N = 512  # presumably a neuron/sample count — TODO confirm
T = 1  # presumably total simulation time — TODO confirm units
DT = 1e-2  # presumably the simulation time step — TODO confirm units

MAX_HZ = 10  # presumably a maximum firing rate — TODO confirm

BS = 128  # batch size passed to both DataLoaders
LR = 1e-3  # learning rate passed to the Adam optimizer
EPOCHS = 20  # number of passes of the training loop

In [None]:
class MNISTModel(nn.Module):
    """Fully connected classifier mapping 28x28 inputs to 10 outputs.

    The network is Flatten -> Linear(784, 128) -> ReLU -> Linear(128, 64)
    -> ReLU -> Linear(64, 10) -> ReLU, stored as ``self.fc``. The last
    Linear is also followed by a ReLU, so every output is non-negative.
    """

    def __init__(self):
        super().__init__()
        # Widths of consecutive layers; each adjacent pair becomes Linear + ReLU.
        layer_widths = (28 * 28, 128, 64, 10)
        stack = [nn.Flatten()]
        for fan_in, fan_out in zip(layer_widths[:-1], layer_widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # Same module order as a hand-written Sequential, so state_dict keys
        # stay fc.1.*, fc.3.*, fc.5.*.
        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten ``x`` and pass it through the fully connected stack."""
        activations = self.fc(x)
        return activations

In [None]:
# Location of the cached "best" checkpoint.
model_path = "tmp/best_mnist_model.pth"

# ToTensor followed by Normalize with mean 0.5 and std 0.5 on the single channel.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST train/test splits, downloaded into ./tmp/ when missing.
train_dataset = torchvision.datasets.MNIST(
    root='./tmp/', train=True, transform=transform, download=True
)
test_dataset = torchvision.datasets.MNIST(
    root='./tmp/', train=False, transform=transform, download=True
)

# Only the training loader shuffles its samples.
train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BS, shuffle=False)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Using device ({device})')

# Model, loss, optimizer, and a scheduler that multiplies the learning rate
# by 0.1 every 10 scheduler steps.
model = MNISTModel()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Reuse the cached checkpoint when one exists; otherwise train from scratch,
# keep the best-scoring weights on disk, and plot the training curves.
if os.path.exists(model_path):
	# map_location lets a checkpoint written on a GPU machine load on a
	# CPU-only one (and vice versa) instead of failing on the stored device.
	model.load_state_dict(torch.load(model_path, map_location=device))
	print("Model loaded from file:", model_path)
else:
	best_accuracy = 0
	train_losses, test_accuracies = [], []

	# torch.save does not create missing parent directories.
	checkpoint_dir = os.path.dirname(model_path)
	if checkpoint_dir:
		os.makedirs(checkpoint_dir, exist_ok=True)

	# Training loop
	for epoch in trange(EPOCHS):
		model.train()
		running_loss = 0.0
		for inputs, targets in train_loader:
			inputs, targets = inputs.to(device), targets.to(device)
			optimizer.zero_grad()
			outputs = model(inputs)
			loss = criterion(outputs, targets)
			loss.backward()
			optimizer.step()
			running_loss += loss.item()
		# Mean per-batch loss for this epoch.
		train_loss = running_loss / len(train_loader)
		train_losses.append(train_loss)
		scheduler.step()

		# Evaluate on test set
		model.eval()
		correct = 0
		total = 0
		with torch.no_grad():
			for inputs, targets in test_loader:
				inputs, targets = inputs.to(device), targets.to(device)
				outputs = model(inputs)
				predicted = outputs.argmax(dim=1)
				total += targets.size(0)
				correct += (predicted == targets).sum().item()
		accuracy = correct / total
		test_accuracies.append(accuracy)

		print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {train_loss:.4f}, Accuracy: {accuracy:.4f}")

		# Save the model if it's the best so far
		if accuracy > best_accuracy:
			best_accuracy = accuracy
			torch.save(model.state_dict(), model_path)

	print("Model saved with accuracy:", best_accuracy)

	# Continue with the best checkpoint rather than the last-epoch weights, so
	# a fresh training run and a later cached run hand the same model to the
	# cells below.
	if os.path.exists(model_path):
		model.load_state_dict(torch.load(model_path, map_location=device))

	# Plot loss and accuracy side by side; the x axis is the 1-based epoch index.
	epoch_axis = range(1, len(train_losses) + 1)
	fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 5))

	loss_ax.plot(epoch_axis, train_losses, label='Loss')
	loss_ax.set_title("Training Loss")
	loss_ax.set_xlabel("Epoch")
	loss_ax.set_ylabel("Loss")
	loss_ax.legend()

	acc_ax.plot(epoch_axis, test_accuracies, label='Accuracy')
	acc_ax.set_title("Test Accuracy")
	acc_ax.set_xlabel("Epoch")
	acc_ax.set_ylabel("Accuracy")
	acc_ax.legend()

	fig.tight_layout()
	plt.show()

In [None]:
# Convert the ann to an snn by replacing all relu layers with IF neurons
def swap_layers(model, old_layer_type: type[nn.Module], new_layer_type: type[nn.Module]):
    """Recursively replace child modules of one type with another, in place.

    Every direct or nested child of ``model`` that is an instance of
    ``old_layer_type`` is replaced by ``new_layer_type(*child.parameters())``;
    for a parameter-free layer such as ``nn.ReLU`` that is simply
    ``new_layer_type()``. Children of any other type are descended into.
    Only children are inspected: ``model`` itself is never replaced.

    Args:
        model: Module whose children are rewritten via ``setattr``.
        old_layer_type: Module class to look for.
        new_layer_type: Module class to instantiate as the replacement.

    Returns:
        The same ``model`` object, after modification.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, old_layer_type):
            # Not a match: look for matches further down this branch.
            if isinstance(child, nn.Module):
                swap_layers(child, old_layer_type, new_layer_type)
            continue
        # The old layer's parameters are forwarded as positional arguments.
        replacement = new_layer_type(*child.parameters())
        setattr(model, child_name, replacement)
    return model
# swap_layers rewrites `model` in place (via setattr) and returns the same
# object, so this rebinding does not create a copy of the network.
model = swap_layers(model, nn.ReLU, IF)
# Show the module tree so the replaced layers can be checked by eye.
print(model)

In [None]:
# Check the accuracy after conversion
def eval_snn(test_dataloader, model, loss_fn, device, sim_len=8, rank=0, batches=-1, dt=1.0):
    """Run the converted network over a test loader (work in progress).

    NOTE(review): as written, the returned arrays are always all zeros,
    because neither ``tot`` nor ``loss`` is updated inside the loop and the
    result of ``run_sim`` is discarded. The old per-time-step loop is still
    present below as commented-out code.

    Args:
        test_dataloader: Iterable yielding ``(img, label)`` batches.
        model: Network called as ``model(img)`` and passed to ``run_sim``.
        loss_fn: Not used by the live code (only by the commented-out loop).
        device: Device the model, accumulators and batches are moved to.
        sim_len: Length of the two accumulator/return arrays.
        rank: Not used.
        batches: When positive, stop after this many batches; otherwise the
            whole loader is consumed.
        dt: Not used by the live code (only by the commented-out loop).

    Returns:
        A tuple ``(tot / length, loss / length)`` of NumPy arrays of length
        ``sim_len``, where ``length`` is the number of labels seen. If no
        batch is processed ``length`` is 0 and the division is by zero.
    """
    # Accumulators of length sim_len; nothing writes to them in the live code.
    tot = torch.zeros(sim_len).to(device)
    loss = torch.zeros(sim_len).to(device)
    length = 0
    model.eval()
    model = model.to(device)

    with torch.no_grad():
        for idx, (img, label) in enumerate(tqdm(test_dataloader)):
            # Optional cap on the number of batches evaluated.
            if batches > 0 and idx >= batches:
                break
            length += len(label)
            img = img.to(device)
            label = label.to(device)
            # One forward pass just to get a zero tensor of the output shape.
            # NOTE(review): if the spiking layers keep state between calls, this
            # extra forward pass may influence the run_sim call below — confirm.
            spikes = torch.zeros_like(model(img))
            # NOTE(review): `time` and `activation` are never used; run_sim's
            # signature and return values are defined in spieks.network and are
            # not visible here — confirm the intended arguments.
            time, activation = run_sim(model, img, )
			
            # Previous manual time-stepping loop, currently disabled:
            #time = np.arange(1, sim_len + 1) * dt
            #for i, t in enumerate(time):
            #	out = model(img)
            #	spikes += out
            #	tot[i] += (label==spikes.max(1)[1]).sum()
            #	loss[i] += loss_fn(spikes / t, label)
            # Per-batch debugging output; `spikes` is still all zeros here.
            print('label:', label)
            print('spikes:', spikes)
            print('spikes.argmax(dim=1):', spikes.argmax(dim=1))
    return tot.detach().cpu().numpy() / length, loss.detach().cpu().numpy() / length