<a href="https://colab.research.google.com/github/Deep-Learning-Qatar/EEG-Vision/blob/main/Data_preparation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [1]:
# Mount Google Drive so the dataset file under /content/gdrive can be read.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/gdrive


In [2]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

import matplotlib.pyplot as plt
import random

In [3]:
# Flag whether PyTorch can see a CUDA device; the DataLoader cells below use
# it to decide whether to add worker processes.
cuda = torch.cuda.is_available()
print(f'Cuda: {cuda}')

Cuda: False


# Data Organisation

In [5]:
# Load the preprocessed EEG dataset. Per the displayed output it is a dict
# with 'dataset', 'labels' and 'images' keys.
# map_location='cpu' lets a file that was saved with CUDA tensors load on a
# CPU-only runtime. Keeping the data on the CPU is also what the multi-worker
# DataLoaders below expect; batches can be moved to the GPU afterwards.
data_path = '/content/gdrive/MyDrive/11-785 Deep Learning/Project/eeg_55_95_std.pth'
data_dict = torch.load(data_path, map_location='cpu')
data_dict.keys()  # notebook display of the top-level keys

dict_keys(['dataset', 'labels', 'images'])

In [10]:
# Split the EEG records into train/val/test sets *by image*, so that all
# records sharing an image ID land in exactly one split.
def split_train_val_test(data_dict, splits=(0.8, 0.1, 0.1), seed=None):
    """Partition ``data_dict['dataset']`` into train/val/test groups by image ID.

    Args:
        data_dict: Mapping with a ``'dataset'`` key holding a list of record
            dicts. Each record must have an ``'image'`` key.
        splits: ``(train, val, test)`` fractions. Only the val and test
            fractions are used. Each of those splits receives
            ``int(fraction * n_images)`` images, and train receives all the
            remaining images, so ``splits[0]`` is effectively ignored.
        seed: Optional seed for a private ``random.Random``, which makes the
            split reproducible. ``None`` (the default) shuffles with the
            global ``random`` module.

    Returns:
        ``(train_data, val_data, test_data)``: three dicts that map image ID
        to the list of records for that image. Records keep the order in
        which they appear in ``data_dict['dataset']``.

    Raises:
        ValueError: if a val/test fraction is negative, or if the val and
            test splits together would need more images than exist.
    """
    val_frac, test_frac = splits[1], splits[2]
    if val_frac < 0 or test_frac < 0:
        raise ValueError(f'split fractions must be non-negative: {splits!r}')

    # Group records by image ID. A plain dict keeps first-seen order, so the
    # ID list below is deterministic. Iterating a set of IDs would not be,
    # because string hashing is randomised per process.
    data_by_image = {}
    for record in data_dict['dataset']:
        data_by_image.setdefault(record['image'], []).append(record)

    # Shuffle the image IDs so that split membership is random.
    image_ids = list(data_by_image)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(image_ids)

    n_images = len(image_ids)
    val_len = int(val_frac * n_images)
    test_len = int(test_frac * n_images)
    if val_len + test_len > n_images:
        raise ValueError(f'val + test fractions exceed the data: {splits!r}')
    train_len = n_images - val_len - test_len

    def _subset(ids):
        # Build an image-ID -> records dict for a slice of the shuffled IDs.
        return {image_id: data_by_image[image_id] for image_id in ids}

    train_data = _subset(image_ids[:train_len])
    val_data = _subset(image_ids[train_len:train_len + val_len])
    test_data = _subset(image_ids[train_len + val_len:])
    return train_data, val_data, test_data

In [None]:
# TODO: Add a split function that holds a subset of classes out of the training set entirely.

In [19]:
# General data set for EEG data dictionary
class EEGDataSet(data.Dataset):
    def __init__(self, data_dict, x_label='eeg', y_label='label', interval=(20, 460)):
        dl = list(data_dict.values())
        self.data_list = [item for sublist in dl for item in sublist]
        self.length = len(self.data_list)
        self.interval = interval
        self.x_label = x_label
        self.y_label = y_label
        
    def __len__(self):
        return self.length
    
    def __getitem__(self, index):
        data_entry = self.data_list[index]
        x = data_entry[self.x_label]
        if self.x_label == 'eeg':
            x = x[:, self.interval[0]:self.interval[1]]
        else:
            x = torch.as_tensor(x).float()
        y = data_entry[self.y_label]
        if self.y_label == 'eeg':
            y = y[:, self.interval[0]:self.interval[1]]
        else:
            y = torch.as_tensor(y).long()
        return x, y

In [11]:
# Split by image (80/10/10) so that no image appears in more than one split.
train_data, val_data, test_data = split_train_val_test(data_dict, splits=(0.8, 0.1, 0.1))

In [20]:
def _loader_args(shuffle):
    """DataLoader kwargs: batch size 16, plus 2 worker processes when CUDA is available."""
    args = dict(shuffle=shuffle, batch_size=16)
    if cuda:
        args['num_workers'] = 2
    return args


# Wrap each split in an EEGDataSet and a DataLoader; only training data is shuffled.
train_dataset = EEGDataSet(train_data)
train_loader_args = _loader_args(shuffle=True)
train_loader = data.DataLoader(train_dataset, **train_loader_args)

val_dataset = EEGDataSet(val_data)
val_loader_args = _loader_args(shuffle=False)
val_loader = data.DataLoader(val_dataset, **val_loader_args)

test_dataset = EEGDataSet(test_data)
test_loader_args = _loader_args(shuffle=False)
test_loader = data.DataLoader(test_dataset, **test_loader_args)

In [23]:
# Sanity check: pull one batch from the training loader and print its input
# shape and labels. Prints nothing if the loader is empty.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    print(first_batch[0].shape)
    print(first_batch[1])

# Expected output:
# torch.Size([16, 128, 440])
# tensor([ list with 16 values between 0 and 39 ])

torch.Size([16, 128, 440])
tensor([11, 30, 19, 22, 15, 24, 37, 39, 28, 18, 24, 23, 19, 29,  6,  9])
