In [2]:
import sys
import os
import glob
import random
import math
import time
import torch; torch.utils.backcompat.broadcast_warning.enabled = True
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.backends.cudnn as cudnn; cudnn.benchmark = True
from scipy.fftpack import fft, rfft, fftfreq, irfft, ifft, rfftfreq
from scipy import signal
import numpy as np
import importlib
import cv2

In [3]:
import torchvision
print(torchvision.__version__)
# print(torch.__version__)

0.18.0+cu121


In [3]:
eeg_signals_path = "/media/NAS/EEG2IMAGE/eeg_cvpr_2017/data/eeg_5_95_std.pth"
img_path = '/media/NAS/EEG2IMAGE/eeg_cvpr_2017/image'

In [4]:
class EEGDataset:
    
    # Constructor
    def __init__(self, eeg_signals_path, eeg_data_path):
        # Load EEG signals
        print("Start Load...")
        loaded = torch.load(eeg_signals_path)
        # if opt.subject!=0:
        #     self.data = [loaded['dataset'][i] for i in range(len(loaded['dataset']) ) if loaded['dataset'][i]['subject']==0]
        # else:
        self.data=loaded['dataset']        
        self.labels = loaded["labels"]
        self.images = loaded["images"]
        self.image_path = eeg_data_path
        
        # Compute size
        self.size = len(self.data)

    # Get size
    def __len__(self):
        return self.size

    # Get item
    def __getitem__(self, i):
        # Process EEG
        # print(self.data[i].keys())
        eeg = self.data[i]["eeg"].float().t()
        eeg = eeg[20:460,:]

        # if opt.model_type == "model10":
        #     eeg = eeg.t()
        #     eeg = eeg.view(1,128,460-20)
        # Get label        
        label = self.data[i]["label"]

        # Get Original Image
        image = self.images[self.data[i]["image"]]

        subject = self.data[i]['subject']

        # Return
        return eeg, image, label, subject

# Splitter class
class Splitter:

    def __init__(self, dataset, split_path, split_num=0, split_name="train"):
        # Set EEG dataset
        self.dataset = dataset
        # Load split
        loaded = torch.load(split_path)
        self.split_idx = loaded["splits"][split_num][split_name]
        # Filter data
        self.split_idx = [i for i in self.split_idx if 450 <= self.dataset.data[i]["eeg"].size(1) <= 600]
        # Compute size
        self.size = len(self.split_idx)

    # Get size
    def __len__(self):
        return self.size

    # Get item
    def __getitem__(self, i):
        # Get sample from dataset
        eeg, image, label, subject = self.dataset[self.split_idx[i]]
        # Return
        return eeg, image, label, subject


# Load dataset
dataset = EEGDataset(eeg_signals_path = eeg_signals_path,  eeg_data_path = img_path)
# Create loaders
loaders = {split: DataLoader(Splitter(dataset, split_path = "/media/NAS/EEG2IMAGE/eeg_cvpr_2017/data/block_splits_by_image_all.pth", 
                                      split_num = 0, 
                                      split_name = split), 1, drop_last = False, shuffle = False) for split in ["train", "val", "test"]}


Start Load...


In [5]:

class EEGPreDataset:

    # Constructo
    def __init__(self, eeg_pre_path, eeg_data_path, transforms=None):
        # Load EEG signals
        print("Start Load...")
        # loaded = torch.load(eeg_signals_path)

        # split_loaded = torch.load(split_path)
        # if opt.subject!=0:
        #     self.data = [loaded['dataset'][i] for i in range(len(loaded['dataset']) ) if loaded['dataset'][i]['subject']==0]
        # # else:
        # self.data=loaded['dataset']        
        # self.labels = loaded["labels"]
        # self.images = loaded["images"]
        self.image_path = eeg_data_path
        self.data = glob.glob(os.path.join(eeg_pre_path, "*"))

        # Compute size
        self.dataset_size = len(self.data)

        
        # Transforms
        self.transforms = transforms
        # self.to_tensor  = ToTensor()

    # Get size
    def __len__(self):
        return self.dataset_size

    # Get item
    def __getitem__(self, i):

        loaded = torch.load(self.data[i])
        # Process EEG
        eeg = loaded["eeg"]

        # Get label        
        label = loaded["label"]

        # Get Original Image
        image_name = loaded["image"]
        s, _ = image_name.split("_")
        image = torch.empty((224,224))
        if os.path.exists(os.path.join(self.image_path, s, image_name+".JPEG")):
            image = cv2.imread(os.path.join(self.image_path, s, image_name+".JPEG"))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.
        else:
            print(os.path.join(self.image_path, s, image_name+".JPEG"))

        if self.transforms:
            image = self.transforms(image)
        # 
        # image = self.to_tensor(image)
        
        # Return
        return eeg, image, label

dataset = EEGPreDataset(os.path.join(".","preprocessing_data","train") , img_path)
loaders = DataLoader(dataset, 1, drop_last = False, shuffle = False)

Start Load...


In [14]:
for l in loaders["train"]:
    print(l)
    test = l
    break

dict_keys(['eeg', 'image', 'label', 'subject'])
[tensor([[[-0.2232, -0.0730, -0.1207,  ..., -0.4589, -0.0330, -0.2386],
         [-0.2237, -0.0651, -0.1656,  ..., -0.4115, -0.0304, -0.2057],
         [-0.1934, -0.0134, -0.2735,  ..., -0.3696, -0.0285, -0.1856],
         ...,
         [ 0.3540,  0.2597,  0.3852,  ...,  0.3877,  0.0131,  0.1309],
         [ 0.3000,  0.1950,  0.4775,  ...,  0.3140,  0.0086,  0.0727],
         [ 0.2586,  0.1368,  0.5080,  ...,  0.2619,  0.0052,  0.0441]]]), ('n02951358_31190',), tensor([10]), tensor([4])]


In [15]:
l[-1].item()

4

In [7]:
import os,sys
from tqdm.notebook  import tqdm

path = os.path.join("/media/NAS/EEG2IMAGE/eeg_cvpr_2017","preprocessing_data")
file_name  = eeg_signals_path.split("/")[-1].replace(".pth", "")

for split in ["train", "val", "test"]:
    for idx, data in tqdm(enumerate(loaders[split]), total = len(loaders[split]), desc = f"{split} data preprocessing..."):
        subject = data[2].item()
        data = {"eeg":data[0].numpy().squeeze(), "image":data[1][0], "label":data[2].item(), "subject":data[3].item()}
        # torch.save(data, os.path.join(path, split, f"{file_name}_{idx}.pth"))
        if not os.path.exists(os.path.join(path, f"class_{subject}")): os.mkdir(os.path.join(path, f"class_{subject}"))
        torch.save(data, os.path.join(path, f"class_{subject}", f"{file_name}_{idx}.pth"))

    

train data preprocessing...:   0%|          | 0/7959 [00:00<?, ?it/s]

val data preprocessing...:   0%|          | 0/1994 [00:00<?, ?it/s]

test data preprocessing...:   0%|          | 0/1987 [00:00<?, ?it/s]

In [4]:
import os,sys
from tqdm.notebook  import tqdm
from random import shuffle
import shutil

frac = 0.8
path = glob.glob(os.path.join("/media/NAS/EEG2IMAGE/eeg_cvpr_2017","preprocessing_data", "by_class", "*"))
train_path = os.path.join("/media/NAS/EEG2IMAGE/eeg_cvpr_2017","preprocessing_data","train")
test_path = os.path.join("/media/NAS/EEG2IMAGE/eeg_cvpr_2017","preprocessing_data","test")
valid_path = os.path.join("/media/NAS/EEG2IMAGE/eeg_cvpr_2017","preprocessing_data","val")
tr_idx = 0
te_idx = 0
va_idx = 0

for i, p in tqdm(enumerate(path), desc=f"preprocessing...", total=len(path)):
    ep_lst = os.listdir(p)
    name   = p.split("/")[-1]
    shuffle(ep_lst)
    length = len(ep_lst)
    train_len  = int(length * frac)
    valid_len  = int(length * (1-frac)//2)
    for idx, t in enumerate(ep_lst[:train_len]):
        shutil.copy2(os.path.join(p, t), os.path.join(train_path, f"{name}_{tr_idx}.pth"))
        tr_idx +=1

    for idx, t in enumerate(ep_lst[train_len:train_len+valid_len]): 
        shutil.copy2(os.path.join(p, t), os.path.join(valid_path, f"{name}_{va_idx}.pth"))
        va_idx +=1

    for idx, t in enumerate(ep_lst[train_len+valid_len:]): 
        shutil.copy2(os.path.join(p, t), os.path.join(test_path, f"{name}_{te_idx}.pth"))
        te_idx +=1

preprocessing...:   0%|          | 0/40 [00:00<?, ?it/s]

In [37]:
p.split("/")[-1]

'class_31'

In [33]:
train_len + valid_len + valid_len

296

In [24]:
print(valid_len)
train_len

29


238

In [9]:
os.listdir(path)

['class_31',
 'class_5',
 'class_21',
 'class_20',
 'class_6',
 'class_2',
 'class_30',
 'class_37',
 'class_22',
 'class_23',
 'class_11',
 'class_18',
 'class_15',
 'class_26',
 'class_29',
 'class_39',
 'class_7',
 'class_24',
 'class_12',
 'class_9',
 'class_8',
 'class_36',
 'class_13',
 'class_25',
 'class_17',
 'class_4',
 'class_34',
 'class_16',
 'class_3',
 'class_33',
 'class_28',
 'class_27',
 'class_0',
 'class_10',
 'class_35',
 'class_19',
 'class_38',
 'class_14',
 'class_1',
 'class_32']

In [9]:
data['eeg'].shape

(440, 128)

In [None]:
torch.load(os.path.join(path, split, f"{file_name}_{idx}.pth"))

In [None]:
image_name = dataset.images[dataset.data[1]["image"]]
s, _ = image_name.split("_")
image = cv2.imread(os.path.join(dataset.image_path, s, image_name+".JPEG"))

In [5]:
for l in loaders:
    pass