In [None]:
import torch
# Environment sanity check: the cell displays True when PyTorch can use a CUDA device.
torch.cuda.is_available()

In [None]:
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import spikegen
from snntorch import utils
from snntorch import spikeplot as splt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

In [None]:
# Define hyperparameters
batch_size = 64

# FER2013 preprocessing: 48x48 single-channel images. ToTensor + Normalize(0.5, 0.5)
# map pixel values from [0, 1] to [-1, 1].
# Random augmentation is applied to the training split only: augmenting the test
# split would show every evaluation differently perturbed images, making the
# reported accuracy noisy and non-reproducible.
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((48, 48)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    # transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Change brightness and contrast
    # transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),  # Add small shifts
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
# Kept so any code that still refers to `transform` gets the training pipeline.
transform = train_transform

# Load the dataset
train_dataset = datasets.ImageFolder(
    root='./dataset/train', transform=train_transform)
test_dataset = datasets.ImageFolder(
    root='./dataset/test', transform=test_transform)
# drop_last keeps every training batch at exactly `batch_size`. The test loader keeps
# its final partial batch so accuracy is measured over every test image.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class CSNN(nn.Module):
    """Convolutional spiking neural network for 48x48 grayscale, 7-class FER2013.

    Per time step the layers are::

        conv 5x5 (1 -> 32)  -> max-pool 2 -> LIF
        conv 5x5 (32 -> 64) -> max-pool 2 -> LIF
        flatten -> linear (64*9*9 -> 128) -> LIF
        linear (128 -> 7)   -> LIF

    A 48x48 input shrinks to 44 -> 22 -> 18 -> 9 pixels per side, so the flattened
    feature vector fed to ``fc1`` has 64 * 9 * 9 = 5184 entries.

    Args:
        beta: membrane decay factor passed to every ``snn.Leaky`` layer.
        slope: slope of the fast-sigmoid surrogate gradient of every LIF layer.
    """

    def __init__(self, beta=0.9, slope=25):
        super(CSNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, 5)  # Input is 1 channel (grayscale)
        self.lif1 = snn.Leaky(
            beta=beta, spike_grad=surrogate.fast_sigmoid(slope=slope))

        self.conv2 = nn.Conv2d(32, 64, 5)
        self.lif2 = snn.Leaky(
            beta=beta, spike_grad=surrogate.fast_sigmoid(slope=slope))

        # conv2 + pool emit 64 channels of 9x9 maps, all of which are flattened
        # into fc1. (An in_features of 9*9 dropped the channel factor.)
        self.fc1 = nn.Linear(64 * 9 * 9, 128)
        self.lif3 = snn.Leaky(
            beta=beta, spike_grad=surrogate.fast_sigmoid(slope=slope))

        self.fc2 = nn.Linear(128, 7)  # 7 classes for FER2013 emotions
        self.lif4 = snn.Leaky(
            beta=beta, spike_grad=surrogate.fast_sigmoid(slope=slope))

        # Default sequence length; not read by forward() itself.
        self.timesteps = 10

    def _step(self, x, mems):
        """Advance every layer by one time step.

        Args:
            x: one time step of input, batched ``(B, 1, 48, 48)`` or
                unbatched ``(1, 48, 48)``.
            mems: tuple ``(mem1, mem2, mem3, mem4)`` of membrane states.

        Returns:
            ``(spk4, mems)``: output-layer spikes and the updated membrane tuple.
        """
        mem1, mem2, mem3, mem4 = mems

        spk1, mem1 = self.lif1(F.max_pool2d(self.conv1(x), 2), mem1)
        spk2, mem2 = self.lif2(F.max_pool2d(self.conv2(spk1), 2), mem2)
        # Flatten only the trailing (C, H, W) dims so an unbatched input keeps its
        # 64 channels together instead of treating them as 64 separate samples.
        spk3, mem3 = self.lif3(self.fc1(spk2.flatten(start_dim=-3)), mem3)
        spk4, mem4 = self.lif4(self.fc2(spk3), mem4)

        return spk4, (mem1, mem2, mem3, mem4)

    def forward(self, x):
        """Run the network on a single step or on a time-major sequence.

        Args:
            x: either one step, ``(B, 1, 48, 48)`` or ``(1, 48, 48)``, or a
                time-major input ``(T, B, 1, 48, 48)``.

        Returns:
            Single step: ``(spk4, mem4)`` of shape ``(B, 7)``, or ``(7,)`` if unbatched.
            5-D input: ``(spk_rec, mem_rec)`` stacked over time, shape ``(T, B, 7)``.

        Membrane states are obtained fresh from ``init_leaky()`` on every call.
        Consecutive single-step calls therefore do not share state; pass a 5-D
        input to integrate over time.
        """
        mems = (self.lif1.init_leaky(), self.lif2.init_leaky(),
                self.lif3.init_leaky(), self.lif4.init_leaky())

        if x.dim() != 5:
            spk4, mems = self._step(x, mems)
            return spk4, mems[-1]

        spk_rec, mem_rec = [], []
        for x_t in x:  # iterate over the leading time dimension
            spk4, mems = self._step(x_t, mems)
            spk_rec.append(spk4)
            mem_rec.append(mems[-1])
        return torch.stack(spk_rec), torch.stack(mem_rec)

In [None]:
# Build the spiking CNN (membrane decay beta=0.9) and move its parameters to `device`.
net = CSNN(beta=0.9, slope=25).to(device)

In [None]:
# Original - beta = 0.9 lr = 0.0001 epochs = 100 accuracy = 49.33% loss < 0.9
# 2nd try - beta = 0.99 lr = 0.0001 epochs = 20 accuracy = 35.64% loss 1.6463
# 3rd try - beta = 0.95 lr = 0.0001 epochs = 20 loss 1.6640
# 5th try - beta = 0.9 lr = 0.001 epochs = 20 accuracy = 28.66% loss 1.7209
# 6th try - beta = 0.9 lr = 0.0001 epochs = 20 accuracy = 35.62% loss 1.6436
# 7th try - beta = 0.8 lr = 0.0001 epochs = 20 accuracy = 33.92% loss 1.6602
# 8th try - beta = 0.8 lr = 0.001 epochs = 20 loss 1.7313
# 9th try - beta = 0.7 lr = 0.0001 epochs = 20 loss 1.6574
# 10th try - beta = 0.5 lr = 0.0001 epochs = 20 loss 1.6670
# 11th try - beta = 0.65 lr = 0.0001 epochs = 20 loss 1.6332
# 8th try - beta = 0.7 lr = 0.0001 epochs = 20 accuracy = 33.92% loss 1.6574  (numbered 8th again; same settings and loss as the 9th try above)

In [None]:
# Number of training images per class index (ImageFolder label -> count).
train_class_counts = {3: 7215, 4: 4965,
                      5: 4830, 2: 4097, 0: 3995, 6: 3171, 1: 436}
# Inverse-frequency weights ordered by class index 0..6. They are only consumed by
# the (currently disabled) weighted loss listed below.
class_weights = torch.tensor(
    [1 / train_class_counts[label] for label in sorted(train_class_counts)]
).to(device)

# Active loss: cross-entropy with label smoothing. Alternatives tried earlier:
#   nn.CrossEntropyLoss(weight=class_weights)
#   SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

In [None]:
def forward_pass(net, data):
    """Reset ``net``'s hidden state, then call it once per slice along dim 0 of ``data``.

    Each call must return a ``(spikes, membrane)`` pair. The spikes of every call
    are stacked along a new leading dimension and returned.
    """
    snn.utils.reset(net)
    spikes = []
    for data_t in data:
        spk_t, _ = net(data_t)
        spikes.append(spk_t)
    return torch.stack(spikes)

In [None]:
# from snntorch import spikegen

# Epoch count for the training run below.
# NOTE(review): num_steps is not read by any visible cell; kept in case other code uses it.
num_epochs, num_steps = 100, 100

# Training loop



def train_snn(num_epochs):
    """Train the global ``net`` for ``num_epochs`` epochs on ``train_loader``.

    Each batch goes through ``forward_pass``, which iterates over dim 0 of the
    batch. Every image is therefore one independent call of ``net``, not a time
    series. The stacked outputs are scored by ``loss_fn`` against integer class labels.

    Uses the notebook globals ``net``, ``optimizer``, ``loss_fn``, ``train_loader``,
    ``device`` and ``forward_pass``. Prints the mean batch loss after each epoch.
    """
    for epoch in range(num_epochs):
        net.train()

        running_loss = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            labels = labels.long()

            optimizer.zero_grad()

            spk_rec = forward_pass(net, images)

            # CrossEntropyLoss takes integer class indices of shape (B,) for (B, C)
            # outputs; a one-hot LongTensor target does not satisfy that contract.
            loss = loss_fn(spk_rec, labels)

            # The graph is rebuilt every batch and backpropagated once, so there is
            # nothing to gain from retain_graph=True (it only holds memory longer).
            loss.backward()

            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')


train_snn(num_epochs)

In [None]:
def evaluate():
    """Print the classification accuracy of the global ``net`` on ``test_loader``.

    Each test batch is passed to ``net`` in a single call. The predicted class is
    the index of the largest output-layer spike value.

    Uses the notebook globals ``net``, ``test_loader`` and ``device``.
    """
    correct = 0
    total = 0
    net.eval()

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # CSNN.forward returns (spikes, membrane); unpacking three values raised ValueError.
            outputs, _ = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        print('Test Accuracy: n/a (test_loader yielded no samples)')
        return
    print(f'Test Accuracy: {100 * correct / total:.2f}%')


evaluate()

In [None]:
# torch.save(net.state_dict(), 'csnn2.pth')

In [None]:
def forward_pass(net, data):
    """Reset ``net``'s hidden state, then run it once per slice along dim 0 of ``data``.

    Args:
        net: module whose call returns ``(spikes, membrane)``, which is the 2-tuple
            ``CSNN.forward`` produces.
        data: tensor whose leading dimension is iterated over.

    Returns:
        The spike outputs of every call, stacked along a new leading dimension.
    """
    spk_rec = []
    snn.utils.reset(net)
    for step in range(data.size(0)):
        # The network returns two values; a three-way unpack raised ValueError.
        spk_out, _ = net(data[step])
        spk_rec.append(spk_out)
    return torch.stack(spk_rec)

In [None]:
num_epochs = 50
counter = 0
rate_steps = 10  # length of each rate-coded spike train

loss_hist = []
acc_hist = []
test_acc_hist = []

# Training loop: each batch is rate-coded into a time-major spike train, and the
# per-class spike counts summed over time serve as the logits for `loss_fn`.
# NOTE(review): spikegen.rate treats inputs as firing probabilities in [0, 1], but the
# transform normalizes pixels to [-1, 1]. Pixels darker than mid-gray may therefore
# never spike; confirm this is intended.
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        net.train()
        data = spikegen.rate(data, num_steps=rate_steps)
        # net returns (spike record, membrane record); keep the spikes, presumably (T, B, 7)
        spk_rec, _ = net(data)
        # CrossEntropyLoss expects (B, C) logits, so sum spikes over the time axis.
        loss_val = loss_fn(spk_rec.sum(dim=0), targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)

        # print metrics every so often
        if counter % 16 == 0:
            print(
                f"Epoch {epoch}, Iteration {i} \nTrain Loss: {loss_val.item():.2f}")
            print(f"Train Accuracy: {acc * 100:.2f}%\n")

            correct = 0
            total = 0

            # Evaluation needs no gradients; separate names avoid clobbering the
            # training loop's variables.
            net.eval()
            with torch.no_grad():
                for test_data, test_targets in test_loader:
                    test_data = test_data.to(device)
                    test_targets = test_targets.to(device)
                    test_data = spikegen.rate(test_data, num_steps=rate_steps)
                    test_spk_rec, _ = net(test_data)
                    batch_n = test_spk_rec.size(1)
                    correct += SF.accuracy_rate(test_spk_rec, test_targets) * batch_n
                    total += batch_n

            test_acc = (correct / total) * 100 if total else 0.0
            test_acc_hist.append(test_acc)
            print(f"========== Test Set Accuracy: {test_acc:.2f}% ==========\n")

        counter += 1

In [None]:
def train_snn(num_epochs, train_loader, val_loader, timesteps):
    """Train the global ``net`` on rate-coded images and validate after every epoch.

    Each batch is rate-coded by ``spikegen.rate`` into a ``timesteps``-long spike
    train. The spike record returned by ``net`` is summed over time, and the
    per-class spike counts are used as logits for ``loss_fn`` and for the argmax
    prediction.

    Args:
        num_epochs: number of passes over ``train_loader``.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        timesteps: number of rate-coding steps per image.

    Uses the notebook globals ``net``, ``optimizer``, ``loss_fn`` and ``device``.
    NOTE: this redefinition shadows the earlier single-argument ``train_snn``.
    NOTE(review): spikegen.rate treats inputs as firing probabilities in [0, 1], but
    the transform normalizes pixels to [-1, 1]. Dark pixels may never spike; confirm.
    """
    for epoch in range(num_epochs):
        # Training Phase
        net.train()
        running_loss = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Rate encoding for training (time-major spike train)
            spike_trains = spikegen.rate(images, timesteps)

            optimizer.zero_grad()
            # net returns (spike record, membrane record). Summing spikes over
            # time gives (B, classes) logits for CrossEntropyLoss.
            spk_rec, _ = net(spike_trains)
            outputs = spk_rec.sum(dim=0)

            loss = loss_fn(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        train_loss = running_loss / max(len(train_loader), 1)
        print(
            f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}")

        # Validation Phase
        net.eval()
        val_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                # Rate encoding for validation
                spike_trains = spikegen.rate(images, timesteps)

                spk_rec, _ = net(spike_trains)
                outputs = spk_rec.sum(dim=0)

                val_loss += loss_fn(outputs, labels).item()

                # Predicted class = most spikes over the whole sequence
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss /= max(len(val_loader), 1)
        val_accuracy = 100 * correct / total if total else 0.0
        print(
            f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")

In [None]:
# Train with rate-coded inputs (10 time steps), validating on the test split each epoch.
train_snn(num_epochs=num_epochs, train_loader=train_loader,
          val_loader=test_loader, timesteps=10)