<a href="https://colab.research.google.com/github/SangMin316/EEG_Data/blob/main/220929Data_Loader_concat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install MNE-Python (mne.filter.resample is used below); `!` runs a shell command in Colab/IPython.
!pip install mne

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import bisect
import pickle
import time

import mne
import numpy as np
import torch
from torch.utils.data import Dataset

class Sleepedf_dataset(Dataset):
    """Fixed-length EEG windows drawn lazily from Sleep-EDF pickle files.

    Each path in ``files`` must hold a pickled 2-D array of shape
    ``(channels, time)`` (presumably a NumPy array -- confirm against the
    export script). When an item is requested, its whole recording is
    resampled by ``RESAMPLE_UP / RESAMPLE_DOWN``. Every channel is then cut
    into non-overlapping windows of ``seq_len`` samples, and a trailing
    remainder shorter than ``seq_len`` is dropped. Windows are numbered file
    by file and, within a file, channel by channel.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    class at trusted files.

    Args:
        files: sequence of pickle file paths.
        seq_len: window length in samples *after* resampling.
    """

    # Upsample 2x. The original note says this yields 200 Hz, which implies
    # 100 Hz recordings -- TODO confirm the source sampling rate.
    RESAMPLE_UP = 2.0
    RESAMPLE_DOWN = 1.0

    def __init__(self, files, seq_len):
        self.files = files
        self.sequence_length = seq_len

        # Cumulative segment counts: file k owns the global indices
        # [data_adress[k], data_adress[k + 1]). The count must use the
        # *resampled* length, rounded down per channel exactly as split_data
        # does. Otherwise the last indices of a file point past the end of
        # its segment array.
        data_adress = [0]
        for path in self.files:
            with open(file=path, mode='rb') as f:
                sample = pickle.load(f)
            channels, length = sample.shape
            data_adress.append(data_adress[-1] + self._num_segments(channels, length))
        self.data_adress = data_adress

        # One-entry cache of the last file's segments, so that consecutive
        # indices from the same file do not reload and resample it.
        # Costs the memory of one preprocessed recording.
        self._cached_file = None
        self._cached_segments = None

    def _resampled_length(self, length):
        """Return the length ``preprocessing`` is expected to yield for ``length`` samples.

        Intended to mirror the output-length rule of ``mne.filter.resample``
        (``max(round(length * up / down), 1)``) -- re-check after MNE upgrades.
        The check in ``_load_segments`` catches any mismatch.
        """
        ratio = float(self.RESAMPLE_UP) / self.RESAMPLE_DOWN
        return max(int(round(ratio * length)), 1)

    def _num_segments(self, channels, length):
        """Return how many windows a (channels, length) recording yields after preprocessing."""
        return channels * (self._resampled_length(length) // self.sequence_length)

    def preprocessing(self, data):
        """Resample ``data`` along its last axis by RESAMPLE_UP / RESAMPLE_DOWN."""
        return mne.filter.resample(data, up=self.RESAMPLE_UP, down=self.RESAMPLE_DOWN)

    def split_data(self, data):
        """Cut a (channels, length) array into rows of ``sequence_length`` samples.

        Returns an array of shape ``(channels * (length // L), L)``: all
        windows of channel 0 first, then channel 1, and so on. Samples past
        the last full window of each channel are dropped.
        """
        L = self.sequence_length
        channels, length = data.shape
        per_channel = length // L
        return np.reshape(data[:, :per_channel * L], (channels * per_channel, L))

    def _load_segments(self, file_idx):
        """Return the preprocessed, split segments of file ``file_idx``, using the cache."""
        if file_idx != self._cached_file:
            with open(file=self.files[file_idx], mode='rb') as f:
                sample = pickle.load(f)
            segments = self.split_data(self.preprocessing(sample))
            expected = self.data_adress[file_idx + 1] - self.data_adress[file_idx]
            if segments.shape[0] != expected:
                raise RuntimeError(
                    f'{self.files[file_idx]}: expected {expected} segments, '
                    f'got {segments.shape[0]}')
            # Update the cache only after preprocessing succeeded.
            self._cached_segments = segments
            self._cached_file = file_idx
        return self._cached_segments

    def __getitem__(self, index):
        """Return window ``index`` as a 1-D array of ``sequence_length`` samples.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f'index out of range for dataset of length {total}')
        # The owning file is the last one whose start offset is <= index.
        file_idx = bisect.bisect_right(self.data_adress, index) - 1
        segments = self._load_segments(file_idx)
        # Copy so callers cannot mutate the cached array.
        return segments[index - self.data_adress[file_idx], :].copy()

    def __len__(self):
        """Return the total number of windows over all files."""
        return self.data_adress[-1]


In [3]:
class MASS_dataset(Dataset):
    """Fixed-length EEG windows drawn lazily from MASS pickle files.

    Each path in ``files`` must hold a pickled 2-D array of shape
    ``(channels, time)`` (presumably a NumPy array -- confirm against the
    export script). When an item is requested, its whole recording is
    resampled by ``RESAMPLE_UP / RESAMPLE_DOWN``. Every channel is then cut
    into non-overlapping windows of ``seq_len`` samples, and a trailing
    remainder shorter than ``seq_len`` is dropped. Windows are numbered file
    by file and, within a file, channel by channel.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    class at trusted files.

    Args:
        files: sequence of pickle file paths.
        seq_len: window length in samples *after* resampling.
    """

    # Downsample by 1.28. The original note says this yields 200 Hz, which
    # implies 256 Hz recordings -- TODO confirm the source sampling rate.
    RESAMPLE_UP = 1.0
    RESAMPLE_DOWN = 1.28

    def __init__(self, files, seq_len):
        self.files = files
        self.sequence_length = seq_len

        # Cumulative segment counts: file k owns the global indices
        # [data_adress[k], data_adress[k + 1]). The count must use the
        # *downsampled* length, rounded down per channel exactly as
        # split_data does. The previous formula assumed 2x upsampling and
        # overstated the length ~2.56x.
        data_adress = [0]
        for path in self.files:
            with open(file=path, mode='rb') as f:
                sample = pickle.load(f)
            channels, length = sample.shape
            data_adress.append(data_adress[-1] + self._num_segments(channels, length))
        self.data_adress = data_adress

        # One-entry cache of the last file's segments, so that consecutive
        # indices from the same file do not reload and resample it.
        # Costs the memory of one preprocessed recording.
        self._cached_file = None
        self._cached_segments = None

    def _resampled_length(self, length):
        """Return the length ``preprocessing`` is expected to yield for ``length`` samples.

        Intended to mirror the output-length rule of ``mne.filter.resample``
        (``max(round(length * up / down), 1)``) -- re-check after MNE upgrades.
        The check in ``_load_segments`` catches any mismatch.
        """
        ratio = float(self.RESAMPLE_UP) / self.RESAMPLE_DOWN
        return max(int(round(ratio * length)), 1)

    def _num_segments(self, channels, length):
        """Return how many windows a (channels, length) recording yields after preprocessing."""
        return channels * (self._resampled_length(length) // self.sequence_length)

    def preprocessing(self, data):
        """Resample ``data`` along its last axis by RESAMPLE_UP / RESAMPLE_DOWN."""
        return mne.filter.resample(data, up=self.RESAMPLE_UP, down=self.RESAMPLE_DOWN)

    def split_data(self, data):
        """Cut a (channels, length) array into rows of ``sequence_length`` samples.

        Returns an array of shape ``(channels * (length // L), L)``: all
        windows of channel 0 first, then channel 1, and so on. Samples past
        the last full window of each channel are dropped.
        """
        L = self.sequence_length
        channels, length = data.shape
        per_channel = length // L
        return np.reshape(data[:, :per_channel * L], (channels * per_channel, L))

    def _load_segments(self, file_idx):
        """Return the preprocessed, split segments of file ``file_idx``, using the cache."""
        if file_idx != self._cached_file:
            with open(file=self.files[file_idx], mode='rb') as f:
                sample = pickle.load(f)
            segments = self.split_data(self.preprocessing(sample))
            expected = self.data_adress[file_idx + 1] - self.data_adress[file_idx]
            if segments.shape[0] != expected:
                raise RuntimeError(
                    f'{self.files[file_idx]}: expected {expected} segments, '
                    f'got {segments.shape[0]}')
            # Update the cache only after preprocessing succeeded.
            self._cached_segments = segments
            self._cached_file = file_idx
        return self._cached_segments

    def __getitem__(self, index):
        """Return window ``index`` as a 1-D array of ``sequence_length`` samples.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f'index out of range for dataset of length {total}')
        # The owning file is the last one whose start offset is <= index.
        file_idx = bisect.bisect_right(self.data_adress, index) - 1
        segments = self._load_segments(file_idx)
        # Copy so callers cannot mutate the cached array.
        return segments[index - self.data_adress[file_idx], :].copy()

    def __len__(self):
        """Return the total number of windows over all files."""
        return self.data_adress[-1]

In [4]:
from sklearn.model_selection import train_test_split

class concat_dataset():
    """Split per-dataset file lists into train/val/test and concatenate each split.

    The split is done at file level, so every recording lands in exactly one
    split.

    Args:
        data_dic: mapping ``{dataset_name: [file paths]}``; supported names
            are ``'Sleep_edf'`` and ``'MASS'``.
        seq_len: window length passed to every per-dataset Dataset.
        random_state: optional seed forwarded to ``train_test_split`` so the
            split is reproducible; ``None`` keeps it random.
    """

    # Supported dataset names -> label used in the progress messages.
    _LABELS = {'Sleep_edf': 'sleep', 'MASS': 'MASS'}

    def __init__(self, data_dic, seq_len, random_state=None):
        self.data_dic = data_dic
        self.seq_len = seq_len
        self.random_state = random_state

    def tr_val_te_split(self, data_list):
        """Split ``data_list`` into 60 % train, 20 % val and 20 % test."""
        # 20 % test first, then 25 % of the remaining 80 % (= 20 % overall) as val.
        train, test = train_test_split(data_list, test_size=0.2, random_state=self.random_state)
        train, val = train_test_split(train, test_size=0.25, random_state=self.random_state)
        print('split done')
        return train, val, test

    def call(self):
        """Build and return ``(train_dataset, val_dataset, test_dataset)`` as ConcatDatasets.

        Raises:
            ValueError: if ``data_dic`` is empty or contains an unsupported name.
        """
        if not self.data_dic:
            raise ValueError('data_dic is empty; nothing to concatenate')
        unknown = [name for name in self.data_dic if name not in self._LABELS]
        if unknown:
            raise ValueError(
                f'unsupported dataset name(s) {unknown}; expected any of {list(self._LABELS)}')

        # Each source needs its own class so that it is resampled with its own
        # factor. Looked up here rather than at class level because the
        # dataset classes are defined elsewhere in the notebook.
        dataset_classes = {'Sleep_edf': Sleepedf_dataset, 'MASS': MASS_dataset}

        # ConcatDataset is used instead of extending plain lists: extending
        # lists with the datasets' samples ran out of memory.
        splits = {}
        for name, data_list in self.data_dic.items():
            print(name)
            file_splits = self.tr_val_te_split(data_list)
            dataset_cls = dataset_classes[name]
            label = self._LABELS[name]
            built = []
            for part_name, files in zip(('train', 'val', 'test'), file_splits):
                built.append(dataset_cls(files, self.seq_len))
                print(f'{label} {part_name} done')
            splits[name] = built

        # Concatenate Sleep-EDF before MASS (the original order). Each split
        # is only combined with its own counterpart.
        ordered = [splits[name] for name in self._LABELS if name in splits]
        train_dataset = torch.utils.data.ConcatDataset([s[0] for s in ordered])
        val_dataset = torch.utils.data.ConcatDataset([s[1] for s in ordered])
        test_dataset = torch.utils.data.ConcatDataset([s[2] for s in ordered])

        return train_dataset, val_dataset, test_dataset
        

In [5]:
import glob 
# Collect the recording files of each dataset. glob's result order depends
# on the file system, so sort it to make any seeded split of these lists
# reproducible. Without recursive=True, '**' matches like '*' (direct
# children only).
sleepedf_list = sorted(glob.glob('/content/drive/MyDrive/sleep_edfx/sleep_edfx_CT+SC/**'))
print(len(sleepedf_list))

MASS_list = sorted(glob.glob('/content/drive/MyDrive/EEG_data/MASS/**'))
print(len(MASS_list))

data_dic = {'MASS' : MASS_list, 'Sleep_edf': sleepedf_list}

151
3


MASS

In [10]:
from torch.utils.data import DataLoader

# MASS dataset: segment counts are computed here; signals are resampled only when an item is requested.
data = MASS_dataset(files=MASS_list, seq_len=3000)

Sleep-edf

In [11]:
from torch.utils.data import DataLoader

# One window per batch, visited in random order.
trainLoader = DataLoader(dataset=data, batch_size=1, shuffle=True)

이전에는 Sleep-EDF sample을 10개만 올려도 메모리가 터졌는데, 지금은 153개를 올려도 RAM을 적게 차지함 (파일을 필요할 때만 읽기 때문).

ConcatDataset을 이용해 두 dataset을 하나로 합치는 방법

In [7]:
# Only the first 20 Sleep-EDF files (presumably to keep this experiment small).
data2 = Sleepedf_dataset(files=sleepedf_list[:20], seq_len=3000)

In [19]:
# Indices 0..len(data)-1 come from `data` (MASS); the rest from `data2` (Sleep-EDF).
train_dataset = torch.utils.data.ConcatDataset(datasets=[data, data2])

In [20]:
from torch.utils.data import DataLoader

# Batches of two windows in index order (shuffle defaults to False).
trainLoader = DataLoader(dataset=train_dataset, batch_size=2)

In [None]:
import time
# Time the first three batches. The printed time is cumulative since start.
# perf_counter is monotonic, unlike the wall clock read by time.time().
start_T = time.perf_counter()
for batch_idx, batch in enumerate(trainLoader):
    print('batch_idx:', batch_idx, ' ', batch.shape)
    end_T = time.perf_counter()
    print('time:', end_T - start_T)
    if batch_idx >= 2:
        break

위에서 정의한 concat_dataset class를 이용해 dataset을 concatenate

In [12]:
# Collect the recording files of each dataset. glob's result order depends
# on the file system, so sort it to make any seeded split of these lists
# reproducible. Without recursive=True, '**' matches like '*' (direct
# children only).
sleepedf_list = sorted(glob.glob('/content/drive/MyDrive/sleep_edfx/sleep_edfx_CT+SC/**'))
print(len(sleepedf_list))

MASS_list = sorted(glob.glob('/content/drive/MyDrive/EEG_data/MASS/**'))
print(len(MASS_list))

data_dic = {'MASS' : MASS_list, 'Sleep_edf': sleepedf_list}

151
3


In [13]:
# File-level train/val/test split of every dataset in data_dic, each split concatenated across datasets.
train_dataset, val_dataset, test_dataset = concat_dataset(data_dic=data_dic, seq_len=3000).call()

MASS
split done
MASS train done
MASS val done
MASS test done
Sleep_edf
split done
sleep train done
sleep val done
sleep test done


In [8]:
from torch.utils.data import DataLoader

# Batches of two windows drawn in random order from the concatenated training split.
trainLoader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)

In [9]:
import time
# Time the first three batches. The printed time is cumulative since start.
# perf_counter is monotonic, unlike the wall clock read by time.time().
start_T = time.perf_counter()
for batch_idx, batch in enumerate(trainLoader):
    print('batch_idx:', batch_idx, ' ', batch.shape)
    end_T = time.perf_counter()
    print('time:', end_T - start_T)
    if batch_idx >= 2:
        break

batch_idx: 0   torch.Size([2, 3000])
time: 1021.0426411628723
batch_idx: 1   torch.Size([2, 3000])
time: 1573.0411252975464


KeyboardInterrupt: ignored

RAM 문제 없이 데이터가 잘 로드되지만, recording 하나의 길이가 길어서 preprocessing에 시간이 너무 오래 걸림. → `__getitem__`으로 sample 하나를 꺼낼 때마다 해당 파일 전체를 다시 읽고 resampling(preprocessing)하는 꼴임.

해결책: data를 짧게 자르기(예: 30초 단위) 또는 전처리한 파일로 저장하기. resampling에 시간이 오래 걸리므로, 적어도 sampling rate는 미리 통일해서 저장.
