# EEG Signal Classification Notebook

## Introduction

The main objective of this task is to obtain neural representations from two networks, a typical Convolutional Neural Network (CNN) that takes images as input and a CNN that processes and classifies EEG signals, and then to compare the two representations to look for a correlation between them.
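One common way to compare two sets of representations is representational similarity analysis (RSA). For each model, you build a dissimilarity matrix (RDM) over the same set of stimuli, then correlate the two matrices. The sketch below is one possible approach, not necessarily the method this notebook ends up using. The helper names `rdm` and `rsa_score`, the correlation-distance choice, and the synthetic features are all illustrative assumptions.

```python
import torch


def rdm(features: torch.Tensor) -> torch.Tensor:
    """Representational dissimilarity matrix: 1 - Pearson correlation between
    the feature vectors of every pair of stimuli (one stimulus per row)."""
    return 1.0 - torch.corrcoef(features)


def rsa_score(feats_a: torch.Tensor, feats_b: torch.Tensor) -> float:
    """Pearson correlation between the upper triangles of two RDMs."""
    n = feats_a.shape[0]
    if feats_b.shape[0] != n:
        raise ValueError("both inputs need one row per stimulus, in the same order")
    # Indices of the strictly upper triangle (each stimulus pair counted once)
    rows, cols = torch.triu_indices(n, n, offset=1)
    a = rdm(feats_a)[rows, cols]
    b = rdm(feats_b)[rows, cols]
    return torch.corrcoef(torch.stack([a, b]))[0, 1].item()


# Synthetic example (hypothetical sizes): 40 stimuli drawn from 5 classes
torch.manual_seed(0)
classes = torch.arange(40) % 5
image_feats = torch.randn(5, 256)[classes] + 0.1 * torch.randn(40, 256)  # stand-in for CNN image features
eeg_feats = torch.randn(5, 64)[classes] + 0.1 * torch.randn(40, 64)      # stand-in for EEG features, same class structure
noise_feats = torch.randn(40, 64)                                        # unrelated features

related = rsa_score(image_feats, eeg_feats)
unrelated = rsa_score(image_feats, noise_feats)
print(f"related: {related:.2f}, unrelated: {unrelated:.2f}")
```

RSA only compares pairwise dissimilarities, so the image and EEG features can have different dimensionalities. Both inputs, however, must list the same stimuli in the same order. With real data, a rank-based (Spearman) correlation between the RDMs is often preferred. This sketch uses Pearson for brevity.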

Let's start by importing `torch` (PyTorch).

In [3]:
import torch

In the original script, the arguments were defined with the `argparse` library. In this notebook, we instead devote a whole cell to defining all the arguments our analysis needs.
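If you would rather keep the script's `opt.<name>` access pattern, `argparse` can also be driven from a notebook. The sketch below declares a hypothetical subset of flags. The flag names are illustrative; only the defaults mirror the constants in the next cell.

```python
import argparse

# Hypothetical subset of the script's command-line flags
parser = argparse.ArgumentParser(description="EEG signal classification")
parser.add_argument("--subject", type=int, default=0, help="1-6, or 0 for all subjects")
parser.add_argument("--time_low", type=int, default=20)
parser.add_argument("--time_high", type=int, default=460)
parser.add_argument("--model_type", default="lstm")
parser.add_argument("--batch_size", type=int, default=16)

# Inside Jupyter, sys.argv holds the kernel's own arguments (e.g. "-f <connection file>"),
# so parser.parse_args() without an explicit list would fail; pass the list yourself.
opt = parser.parse_args([])                     # all defaults
opt_s3 = parser.parse_args(["--subject", "3"])  # override a single option
print(opt.batch_size, opt_s3.subject)  # 16 3
```

If you need to log the configuration, `vars(opt)` turns the namespace into a plain dictionary.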

In [None]:
EEG_DATASET_PATH = "data/eeg_5_95_std.pth"

SPLITS_PATH = "data/block_splits_by_image_all.pth"

# Index of the split to use; always leave this at 0
SPLIT_NUM = 0

# Subject selection
# Choose a subject from 1 to 6; the default, 0, uses all subjects
SUBJECT = 0

# Time options: keep the EEG samples in the window [TIME_LOW, TIME_HIGH),
# i.e. TIME_HIGH - TIME_LOW = 440 time steps per trial
TIME_LOW = 20
TIME_HIGH = 460

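# Model type/options
# Specify which EEG classifier should be used. Available: lstm | EEGChannelNet
# It is possible to test out multiple deep classifiers:
#   - lstm is the model described in the paper
#     "Deep Learning Human Mind for Automated Visual Classification", CVPR 2017
#   - model10 is the model described in the paper
#     "Decoding brain representations by multimodal learning of neural activity and visual features", TPAMI 2020
# Note: EEGDataset only reshapes samples to (1, 128, T) when MODEL_TYPE == "model10"
MODEL_TYPE = "lstm"

# Optional model options as key=value pairs (empty: use the model's defaults)
MODEL_PARAMS = ""
# Optional path to a pre-trained network to resume training from (empty: train from scratch)
PRETRAINED_NET = ""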

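# Training options
BATCH_SIZE = 16
OPTIMIZER = "Adam"
LEARNING_RATE = 0.0001
# Learning-rate step decay: multiplicative factor and period (in epochs)
LEARNING_RATE_DECAY_BY = 0.5
LEARNING_RATE_DECAY_EVERY = 10
# Number of worker processes used to load data
DATA_WORKERS = 4
EPOCHS = 200

# Save options
# Save the model every SAVE_CHECK epochs
SAVE_CHECK = 100

# Backend options
# CUDA: NVIDIA GPUs; METAL: Apple GPUs (PyTorch's "mps" device)
CUDA = False
METAL = True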

torch.utils.backcompat.broadcast_warning.enabled = True

if CUDA:
    torch.backends.cudnn.benchmark = True
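The cell above only stores the backend and learning-rate options. The sketch below assumes that `METAL` is meant to select PyTorch's Metal-backed `mps` device and that the decay is a standard step schedule, and shows how both could be applied. `select_device` and `step_decay_lr` are hypothetical helper names, not functions from the original script.

```python
import torch


def select_device(use_cuda: bool, use_metal: bool) -> torch.device:
    """Map the CUDA / METAL flags to a torch.device, falling back to the CPU
    when the requested backend is not available on this machine."""
    if use_cuda and torch.cuda.is_available():
        return torch.device("cuda")
    # Apple GPUs are exposed through the Metal-based "mps" backend (PyTorch >= 1.12)
    if use_metal and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def step_decay_lr(base_lr: float, decay_by: float, decay_every: int, epoch: int) -> float:
    """Assumed step decay: multiply base_lr by decay_by once every decay_every epochs."""
    return base_lr * decay_by ** (epoch // decay_every)


device = select_device(use_cuda=False, use_metal=True)
print(device)  # mps on a supported Mac, cpu elsewhere

for epoch in (0, 9, 10, 25, 199):
    print(epoch, step_decay_lr(1e-4, 0.5, 10, epoch))
```

Under that assumption, with `LEARNING_RATE_DECAY_EVERY = 10` and `EPOCHS = 200`, the rate is halved 19 times and ends near 1.9e-10 at epoch 199. In a training loop, the same rule is available as `torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)`, stepped once per epoch.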

Next, we define the `EEGDataset` and `Splitter` classes, following their definitions in the file [`eeg_signal_classification.py`](./eeg_signal_classification.py).

In [None]:
from torch.utils.data import Dataset


class EEGDataset(Dataset):
    def __init__(self, eeg_signals_path, subject=0, time_low=20, time_high=460, model_type="lstm"):
        self.subject = subject
        self.time_low = time_low
        self.time_high = time_high
        self.model_type = model_type

        # Load EEG signals; each entry of loaded["dataset"] is a dict
        # with (at least) "eeg", "label" and "subject" keys
        loaded = torch.load(eeg_signals_path)
        if subject != 0:
            # Keep only the trials recorded from the selected subject
            self.data = [sample for sample in loaded["dataset"] if sample["subject"] == subject]
        else:
            self.data = loaded["dataset"]
        self.labels = loaded["labels"]
        self.images = loaded["images"]

        # Compute size
        self.size = len(self.data)

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        # Process EEG: (channels, time) -> (time, channels), then keep the time window
        eeg = self.data[i]["eeg"].float().t()
        eeg = eeg[self.time_low:self.time_high, :]

        if self.model_type == "model10":
            # Back to (channels, time) with a leading singleton channel: (1, 128, T)
            eeg = eeg.t().unsqueeze(0)
        # Get label
        label = self.data[i]["label"]

        return eeg, label


class Splitter(Dataset):
    def __init__(self, dataset, split_path, split_num=0, split_name="train"):
        # Set EEG dataset
        self.dataset = dataset
        # Load split: a list of indices into dataset.data for the requested split
        loaded = torch.load(split_path)
        self.split_idx = loaded["splits"][split_num][split_name]
        # Keep only trials with 450 to 600 time samples (eeg is stored as channels x time).
        # Note: split indices address dataset.data directly, so if the dataset was
        # filtered by subject they must come from a splits file built for that subset.
        self.split_idx = [i for i in self.split_idx
                          if 450 <= self.dataset.data[i]["eeg"].size(1) <= 600]
        # Compute size
        self.size = len(self.split_idx)

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        # Map the split-local index to a dataset index and fetch the sample
        eeg, label = self.dataset[self.split_idx[i]]
        return eeg, label
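To check what the two classes return without the real `.pth` files, the sketch below applies the same slicing as `EEGDataset.__getitem__` to synthetic trials. The trial dictionaries (128 channels by 500 time samples, with keys `eeg` and `label`) and the class name `SyntheticEEG` are assumptions made for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset

# Synthetic stand-in for loaded["dataset"] (assumed layout): each trial stores a
# (channels, time) EEG tensor and an integer class label
torch.manual_seed(0)
trials = [{"eeg": torch.randn(128, 500), "label": i % 40} for i in range(32)]


class SyntheticEEG(Dataset):
    """Hypothetical helper that repeats the slicing done in EEGDataset.__getitem__."""

    def __init__(self, data, time_low=20, time_high=460, model_type="lstm"):
        self.data = data
        self.time_low = time_low
        self.time_high = time_high
        self.model_type = model_type

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        # (channels, time) -> (time, channels), then keep the time window
        eeg = self.data[i]["eeg"].float().t()[self.time_low:self.time_high, :]
        if self.model_type == "model10":
            # Back to (channels, time) with a leading singleton channel
            eeg = eeg.t().unsqueeze(0)
        return eeg, self.data[i]["label"]


lstm_sample, _ = SyntheticEEG(trials, model_type="lstm")[0]
cnn_sample, _ = SyntheticEEG(trials, model_type="model10")[0]
print(lstm_sample.shape)  # torch.Size([440, 128])
print(cnn_sample.shape)   # torch.Size([1, 128, 440])

# Default collation stacks samples along a new batch dimension
eeg_batch, labels = next(iter(DataLoader(SyntheticEEG(trials), batch_size=16)))
print(eeg_batch.shape, labels.shape)  # torch.Size([16, 440, 128]) torch.Size([16])
```

The `(time, channels)` layout suits a recurrent model that reads one 128-channel sample per time step. The `(1, channels, time)` layout instead treats each trial as a single-channel image for a 2-D convolutional network. The default collate function stacks tensors, so every trial in a batch must yield exactly `TIME_HIGH - TIME_LOW` samples.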