In [None]:
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools
from spikingjelly.activation_based import neuron, functional, layer, surrogate

In [None]:
# Define hyperparameters
batch_size = 64

# FER-2013 preprocessing: one grayscale channel at 48x48, normalised to [-1, 1]
# (ToTensor gives [0, 1]; Normalize with mean=0.5, std=0.5 maps that to [-1, 1]).
# Random flip/rotation are training-time augmentation only: applying them to the
# test split makes the reported accuracy change from run to run.
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((48, 48)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
# Backward-compatible alias: `transform` previously named the (augmenting) pipeline.
transform = train_transform

# Load the dataset (ImageFolder expects one sub-directory per class label)
train_dataset = datasets.ImageFolder(
    root='./dataset/train', transform=train_transform)
test_dataset = datasets.ImageFolder(
    root='./dataset/test', transform=test_transform)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# NOTE(review): drop_last=True discards the final partial test batch, so not every
# test image is scored. Only disable it once the evaluation loop resets stateful
# spiking neurons between batches (their state presumably keeps the previous
# batch's shape, which would clash with a smaller final batch).
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class SimpleCNN(nn.Module):
    """Small two-stage CNN classifier for single-channel face crops (FER-2013).

    Layout: two [Conv3x3 -> ReLU -> MaxPool2] stages (1 -> 32 -> 64 channels),
    then a 128-unit hidden layer with dropout and a 7-way linear head. The
    hidden layer's input size (64 * 12 * 12) matches 48x48 inputs, which two
    2x pools reduce to 12x12.

    The attribute names ``conv_layers`` / ``fc_layers`` and the positions of
    layers inside them determine the ``load_state_dict`` keys, so they must not
    be renamed or reordered.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: each MaxPool2d(2, 2) halves height and width.
        feature_stages = [
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.conv_layers = nn.Sequential(*feature_stages)

        flat_features = 64 * 12 * 12  # channels x pooled height x pooled width
        classifier_head = [
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 7),  # FER-2013 has 7 emotion classes
        ]
        self.fc_layers = nn.Sequential(*classifier_head)

    def forward(self, x):
        """Return unnormalised scores of shape (batch, 7) for a batch of images."""
        features = self.conv_layers(x)
        flattened = features.view(features.size(0), -1)
        return self.fc_layers(flattened)

In [None]:
# Rebuild the trained ANN and load its saved weights.
# weights_only=True restricts torch.load to tensors and plain containers instead
# of unpickling arbitrary objects, which is all a state dict needs.
MODEL_PATH = 'best_model.pth'
model = SimpleCNN().to(device)
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Put the ANN in inference mode (disables its Dropout layer) before conversion;
# the trailing ';' suppresses the module repr in the notebook output.
model.eval();

In [None]:
from spikingjelly.activation_based import ann2snn

# Convert the trained ANN into a spiking network. With mode='max' the converter
# passes `train_loader` through the ANN to calibrate activation scales from the
# observed maxima (see the spikingjelly ann2snn docs), then builds the SNN.
model_conv = ann2snn.Converter(
    mode='max',
    dataloader=train_loader,
    device=device,
)
snn_model = model_conv(model)
snn_model = snn_model.to(device)


In [None]:
# Show the architecture of the converted spiking network.
print(snn_model)

In [None]:
def _reset_stateful_modules(net):
    """Call ``reset()`` on every submodule of ``net`` that defines one (e.g. spiking neurons)."""
    for module in net.modules():
        if hasattr(module, 'reset'):
            module.reset()


def evaluate(net=None, loader=None, num_steps=50):
    """Measure top-1 accuracy of a converted spiking network on a data loader.

    Spiking neurons keep internal state (membrane potential) between forward
    calls. To score each batch independently, that state is reset before every
    batch. The batch is then presented for ``num_steps`` time steps, and the
    per-step outputs are summed before taking the arg-max (rate-coded readout).
    More steps give a closer approximation of the source ANN, at proportional
    cost.

    Args:
        net: Network to evaluate. Defaults to the notebook's ``snn_model``.
        loader: Iterable of ``(images, labels)`` batches. Defaults to ``test_loader``.
        num_steps: Simulation time steps per batch; must be >= 1.

    Returns:
        Accuracy as a fraction in [0, 1]. The percentage is also printed.

    Raises:
        ValueError: If ``num_steps`` < 1 or the loader yields no samples.
    """
    if net is None:
        net = snn_model
    if loader is None:
        loader = test_loader
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    # Run on whatever device the network's weights live on.
    first_param = next(net.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    net.eval()

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(run_device), labels.to(run_device)
            # Start every batch from a clean state so no batch leaks into the next.
            _reset_stateful_modules(net)
            summed = net(images)
            for _ in range(num_steps - 1):
                summed = summed + net(images)
            _, predicted = torch.max(summed, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Leave the network in a clean state for whoever uses it next.
    _reset_stateful_modules(net)

    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute accuracy")

    accuracy = correct / total
    print(f'Test Accuracy: {100 * correct / total:.2f}%')
    return accuracy


# Report the converted SNN's accuracy on the test split (./dataset/test).
evaluate()