In [1]:
import sys
import os
import glob
import random
import math
import time
import torch; torch.utils.backcompat.broadcast_warning.enabled = True
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.backends.cudnn as cudnn; cudnn.benchmark = True
from scipy.fftpack import fft, rfft, fftfreq, irfft, ifft, rfftfreq
from scipy import signal
import numpy as np
import importlib
import cv2
class EEGDataset:
    """Map-style dataset over a serialized EEG recording file.

    The file read with ``torch.load`` must be a mapping with the keys
    ``'dataset'`` (a list of per-trial dicts carrying ``'eeg'``, ``'label'``
    and ``'image'``), ``'labels'`` and ``'images'``.
    """

    def __init__(self, eeg_signals_path):
        contents = torch.load(eeg_signals_path)

        self.data = contents['dataset']
        self.labels = contents["labels"]
        self.images = contents["images"]

        # Cache the trial count so __len__ is a plain attribute lookup.
        self.size = len(self.data)

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        """Return ``(eeg, image, label)`` for trial ``i``.

        ``eeg`` is the stored signal cast to float32 and transposed with
        ``.t()`` (so it is assumed to be at most 2-D). ``image`` is the entry
        of ``self.images`` selected by the trial's ``'image'`` index.
        """
        trial = self.data[i]
        signal = trial["eeg"].float().t()
        return signal, self.images[trial["image"]], trial["label"]

# Splitter class
class Splitter:

    def __init__(self, dataset, split_path, split_num=0, split_name="train"):
        # Set EEG dataset
        self.dataset = dataset
        # Load split
        loaded = torch.load(split_path)
        self.split_idx = loaded["splits"][split_num][split_name]
        # Filter data
        self.split_idx = [i for i in self.split_idx if 450 <= self.dataset.data[i]["eeg"].size(1) <= 600]
        # Compute size
        self.size = len(self.split_idx)

    # Get size
    def __len__(self):
        return self.size

    # Get item
    def __getitem__(self, i):
        # Get sample from dataset
        eeg, image, label = self.dataset[self.split_idx[i]]
        # Return
        return eeg, image, label


class ToTensor(object):

    r"""
    Make Image data to tensor type
    """

    def __call__(self, data:np.array) -> None:
        data = torch.from_numpy(data.transpose(2, 0, 1).astype(np.float32))
        return data



  warn(


In [2]:
# Locations of the EEG recordings, the stimulus images and the split definitions.
eeg_signals_path = "/media/NAS/EEG2IMAGE/eeg_cvpr_2017/data/eeg_signals_raw_with_mean_std.pth"
img_path = '/media/NAS/EEG2IMAGE/eeg_cvpr_2017/image'
split_path = "/media/NAS/EEG2IMAGE/eeg_cvpr_2017/data/block_splits_by_image_all.pth"

# Read the full recording file once; each split below indexes into it.
dataset = EEGDataset(eeg_signals_path=eeg_signals_path)

# One shuffled, single-sample loader per split (entry 0 of the split file).
loaders = {
    name: DataLoader(
        Splitter(dataset, split_path=split_path, split_num=0, split_name=name),
        batch_size=1,
        drop_last=True,
        shuffle=True,
    )
    for name in ("train", "val", "test")
}

In [6]:
import os,sys
from tqdm.notebook  import tqdm

path = os.path.join(".","DATA","raw")
file_name  = eeg_signals_path.split("/")[-1].replace(".pth", "")

for split in ["train", "val", "test"]:
    for idx, (eeg, image, label) in tqdm(enumerate(loaders[split]), total = len(loaders[split]), desc = f"{split} data preprocessing..."):
        data = {"eeg":eeg.numpy().squeeze(), "image":image, "label":label.item()}
        torch.save(data, os.path.join(path, split, f"{file_name}_{idx}.pth"))


train data preprocessing...:   0%|          | 0/7959 [00:00<?, ?it/s]

val data preprocessing...:   0%|          | 0/1994 [00:00<?, ?it/s]

test data preprocessing...:   0%|          | 0/1987 [00:00<?, ?it/s]