In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.utils.data import DataLoader, Dataset, Subset

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

In [None]:
import snntorch as snn
import snntorch.spikeplot as splt
from IPython.display import HTML

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:

# ---- Data -----------------------------------------------------------------
DATASET_PATH = Path('dataset')

# Emotion name (as it appears in the label spreadsheet) -> integer class id.
EMOTIONS = {
    emotion_name: class_id
    for class_id, emotion_name in enumerate(
        ('Neutral', 'Happy', 'Sad', 'Anger', 'Fear', 'Disgust', 'Surprise'))
}

# Fractions of the dataset, taken in index order, for training and validation;
# whatever is left over becomes the test split.
TRAIN_SPLIT = 0.8
# NOTE(review): 0 yields an empty validation split, yet EARLY_STOPPING_PATIENCE
# below suggests validation-based early stopping — confirm what is monitored.
VALIDATION_SPLIT = 0

# ---- Model ----------------------------------------------------------------
# NOTE(review): not referenced by any cell visible here — TODO confirm still needed.
HIDDEN_SIZE = 256
LAYER_COUNT = 4

# ---- Optimisation / training loop ------------------------------------------
BATCH_SIZE = 128
LEARNING_RATE = 1e-4
GRADIENT_CLIPPING_MAX_NORM = 0.5
MAX_EPOCHS = 3000
EARLY_STOPPING_PATIENCE = 200
PRINT_FREQUENCY_EPOCHS = 10



In [None]:

class SeedDataset(Dataset):
    """EEG emotion dataset: one DE-feature sequence and one label per video.

    Features are read from ``DATASET_PATH / 'EEG_features' / '<subject>.mat'``
    (keys ``de_LDS_<video>``); labels come from
    ``emotion_label_and_stimuli_order.xlsx`` and are mapped through
    ``EMOTIONS``. Items are ordered subject-major: all videos of subject 1,
    then all videos of subject 2, and so on.

    Args:
        subject_count: Number of subject ``.mat`` files to load (1..N).
        video_count: Number of videos per subject; the label sheet must
            provide exactly this many labels.

    Raises:
        ValueError: If the label sheet does not yield ``video_count`` labels,
            which would otherwise silently misalign features and labels.
    """

    def __init__(self, subject_count: int = 20, video_count: int = 80):
        self.features = []
        for subject_index in range(1, subject_count + 1):
            subject_features = sio.loadmat(str(DATASET_PATH / 'EEG_features'
                                               / f'{subject_index}.mat'))
            for video_index in range(1, video_count + 1):
                de_features = subject_features[f'de_LDS_{video_index}']
                # Flatten every non-time axis into one feature vector per step;
                # per the original notes: (sequence length, 5 * 62).
                de_features = de_features.reshape(de_features.shape[0], -1)
                self.features.append(de_features)

        # Only odd-numbered sheet rows (0-based) and columns B:U are kept.
        label_table = pd.read_excel(
            DATASET_PATH / 'emotion_label_and_stimuli_order.xlsx', header=None,
            usecols='B:U', skiprows=lambda row_index: row_index % 2 == 0
        )
        video_labels = [EMOTIONS[label]
                        for label in label_table.values.flatten().tolist()]
        if len(video_labels) != video_count:
            raise ValueError(
                f'expected {video_count} labels per subject in the label '
                f'sheet, found {len(video_labels)}')
        # The same per-video label list is assumed to apply to every subject,
        # so it is repeated once per subject to line up with self.features.
        self.labels = video_labels * subject_count

    def __len__(self) -> int:
        """Number of (sequence, label) items."""
        return len(self.labels)

    def __getitem__(self, idx) -> tuple[np.ndarray, int]:
        """Return the feature sequence and integer emotion label at ``idx``."""
        return self.features[idx], self.labels[idx]



In [None]:

def get_data_loaders() -> tuple[DataLoader, DataLoader, DataLoader]:
    """Build train / validation / test loaders over a contiguous SeedDataset split.

    The dataset is split in index order, with no shuffling before the split:
    the first ``TRAIN_SPLIT`` fraction is training, the next
    ``VALIDATION_SPLIT`` fraction is validation, and the rest is test. Only
    the training loader shuffles. Every batch is
    ``(padded_sequences, sequence_lengths, labels)``.
    """
    dataset = SeedDataset()
    total = len(dataset)
    train_end = int(TRAIN_SPLIT * total)
    validation_end = int((TRAIN_SPLIT + VALIDATION_SPLIT) * total)

    def collate_fn(batch: list[tuple[np.ndarray, int]]) -> tuple:
        # Variable-length sequences are zero-padded to the longest in the batch;
        # the true lengths are returned so a model can ignore the padding.
        tensors = [torch.tensor(features, dtype=torch.float)
                   for features, _ in batch]
        targets = torch.tensor([label for _, label in batch], dtype=torch.long)
        lengths = torch.tensor([tensor.shape[0] for tensor in tensors],
                               dtype=torch.long)
        # Shape: (batch size, max sequence length, input size)
        return pad_sequence(tensors, batch_first=True), lengths, targets

    split_specs = (
        (range(0, train_end), True),
        (range(train_end, validation_end), False),
        (range(validation_end, total), False),
    )
    train_loader, validation_loader, test_loader = (
        DataLoader(Subset(dataset, list(indices)), batch_size=BATCH_SIZE,
                   shuffle=shuffle, collate_fn=collate_fn)
        for indices, shuffle in split_specs
    )
    return train_loader, validation_loader, test_loader



In [None]:
# Build the three loaders (this loads every subject's .mat file from disk).
(train_loader,
 validation_loader,
 test_loader) = get_data_loaders()


In [None]:
# Rate coding: each input value is the per-time-step probability of a spike.
def gen_spike_data_bernoulli(x, T=50):
    """Encode ``x`` as independent Bernoulli spike trains over ``T`` steps.

    Each element of ``x`` is used as the probability that the corresponding
    input fires at any single time step; all steps are sampled independently.

    Args:
        x (torch.Tensor): Spike probabilities in [0, 1], any shape
            (typically [batch_size, num_features]). ``torch.bernoulli``
            rejects values outside [0, 1].
        T (int): Number of time steps; must be >= 0.

    Returns:
        torch.Tensor: Binary spike tensor of shape [T, *x.shape] with the
        dtype and device of ``x`` (empty along the time axis when T == 0).

    Raises:
        ValueError: If ``T`` is negative.
    """
    if T < 0:
        raise ValueError(f'T must be non-negative, got {T}')
    # One vectorised draw instead of a Python loop: expand() adds the leading
    # time axis without copying, contiguous() materialises the T copies.
    probabilities = x.expand(T, *x.shape).contiguous()
    return torch.bernoulli(probabilities)  # Shape: [T, *x.shape]

In [None]:
# Rate-encode the emotion features as Bernoulli spike trains.
# NOTE(review): `X_train` / `X_test` are not defined in any cell of this
# notebook (`.values` suggests pandas objects); this cell only runs on leftover
# kernel state. Define them explicitly above — TODO.
X_train_raw = torch.tensor(X_train.values, dtype=torch.float32)
X_test_raw = torch.tensor(X_test.values, dtype=torch.float32)

# Min-max scale to [0, 1] using the *training* statistics for both splits, so
# train and test go through the same transform and no test information leaks
# into preprocessing. Test values may then fall outside [0, 1], and
# torch.bernoulli requires probabilities in that range, so clamp afterwards.
# The eps floor guards against division by zero when every value is equal.
train_min = X_train_raw.min()
train_range = (X_train_raw.max() - train_min).clamp_min(
    torch.finfo(torch.float32).eps)
X_train_tensor = ((X_train_raw - train_min) / train_range).clamp(0.0, 1.0)
X_test_tensor = ((X_test_raw - train_min) / train_range).clamp(0.0, 1.0)

# Generate spike data
T = 50  # Number of time steps
spike_train_train = gen_spike_data_bernoulli(X_train_tensor, T)
spike_train_test = gen_spike_data_bernoulli(X_test_tensor, T)

print("Spike train shape (train):", spike_train_train.shape)  # Should be [T, batch_size, num_features]
print("Spike train shape (test):", spike_train_test.shape)

In [None]:
class LeakySurrogate(nn.Module):
    """Leaky integrate-and-fire neuron with a surrogate spike gradient.

    The membrane potential is kept in ``self.mem`` and carried across calls;
    this class never resets it itself. Assign a tensor to ``mem`` (or set it
    to ``None``, meaning "start from zeros") before a new sequence. Each call:

    1. emits a spike wherever the *previous* membrane potential exceeds
       ``threshold`` (Heaviside step forward, smooth surrogate backward);
    2. integrates the input: ``mem = beta * mem + input - spk * threshold``
       (subtractive reset; the reset term is detached from the graph).

    Args:
        beta: Membrane decay factor applied each step.
        z: Steepness of the surrogate gradient.
        threshold: Firing threshold, also the amount subtracted on a spike.
    """

    def __init__(self, beta, z=1, threshold=1.0):
        super().__init__()

        # initialize decay rate beta and threshold
        self.beta = beta
        self.threshold = threshold
        self.spike_op = self.SpikeOperator.apply
        self.z = z
        # Membrane potential; None means "not initialised yet".
        self.mem = None

    def forward(self, input_):
        """Advance one time step with ``input_`` and return the spike tensor."""
        if self.mem is None:
            # No state supplied by the caller: start at rest.
            self.mem = torch.zeros_like(input_)
        spk = self.spike_op(self.mem - self.threshold, self.z)  # Heaviside
        # Detached so the reset does not add a second gradient path through
        # spike_op's surrogate.
        reset = (spk * self.threshold).detach()
        self.mem = self.beta * self.mem + input_ - reset
        return spk

    class SpikeOperator(torch.autograd.Function):
        """Heaviside step forward; logistic-sigmoid derivative backward."""

        @staticmethod
        def forward(ctx, input, z):
            """Return 1.0 where ``input > 0`` and 0.0 elsewhere.

            ``input`` is saved for the backward pass and ``z`` is stashed on
            ``ctx``.
            """
            ctx.save_for_backward(input)
            ctx.z = z
            return (input > 0).to(input.dtype)

        @staticmethod
        def backward(ctx, grad_output):
            """Surrogate gradient ``d/dx sigmoid(z*x) = z * s * (1 - s)``.

            This equals ``z * exp(-z*x) / (1 + exp(-z*x))**2``, but is computed
            via ``torch.sigmoid``, which stays finite for large ``|x|``. The
            explicit-exp form overflows to inf/inf = NaN for very negative
            ``x``. This is the logistic-sigmoid derivative, not the "fast
            sigmoid" ``1 / (1 + z*|x|)**2``. No gradient is returned for ``z``.
            """
            input, = ctx.saved_tensors
            s = torch.sigmoid(ctx.z * input)
            return grad_output * ctx.z * s * (1 - s), None

In [None]:
class SNN(nn.Module):
    """Two-layer fully connected spiking network with a rate-coded output.

    ``forward`` expects a time-major tensor ``[T, batch, ...]``: step ``t``
    is ``input_[t]``, flattened per sample. Each step runs
    fc1 -> LIF -> fc2 -> LIF. The result is each output neuron's mean firing
    rate over the ``T`` steps, with shape ``[batch, num_outputs]``.

    Args:
        T: Number of time steps consumed from the input.
        beta, z, threshold: Passed to both ``LeakySurrogate`` layers.
        num_inputs: Flattened features per time step (default 1500).
        num_hidden: Hidden LIF neurons (default 10).
        num_outputs: Output LIF neurons (default 10).
            NOTE(review): the EMOTIONS mapping in this notebook has 7 classes;
            pass ``num_outputs=len(EMOTIONS)`` if that is the target.
    """

    def __init__(self, T, beta=0.8, z=1, threshold=1.0,
                 num_inputs=1500, num_hidden=10, num_outputs=10):
        super().__init__()
        self.T = T
        self.flatten = nn.Flatten()
        # 1st fully-connected layer
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = LeakySurrogate(beta=beta, z=z, threshold=threshold)
        # 2nd fully-connected layer
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        # Output neurons; their firing rate over T steps is the prediction.
        self.lif2 = LeakySurrogate(beta=beta, z=z, threshold=threshold)

    def init_mem(self, batch_size, feature_num):
        """Return an initial membrane potential of shape [batch_size, feature_num].

        Values are Kaiming-uniform random, allocated directly on the device
        and dtype of this network's parameters, so they match the layers they
        feed.
        NOTE(review): random initial membranes make predictions vary between
        calls, even in eval mode; zeros are the usual choice. Confirm this is
        intended.
        """
        weight = self.fc1.weight
        mem = torch.empty(batch_size, feature_num,
                          device=weight.device, dtype=weight.dtype)
        return nn.init.kaiming_uniform_(mem)

    def forward(self, input_):
        """Run ``T`` steps over ``input_`` and return mean output firing rates."""
        batch_size = input_.shape[1]
        # Fresh membrane state for every forward pass.
        self.lif1.mem = self.init_mem(batch_size, self.fc1.out_features)
        self.lif2.mem = self.init_mem(batch_size, self.fc2.out_features)

        output_spikes = 0
        for t in range(self.T):
            x = self.flatten(input_[t])
            spk1 = self.lif1(self.fc1(x))
            spk2 = self.lif2(self.fc2(spk1))
            output_spikes = output_spikes + spk2

        return output_spikes / self.T