In [1]:
import pandas as pd
import re
import os
import numpy as np
# Project root: four directory levels above the current working directory.
project_path = os.path.abspath(os.path.join(os.getcwd(), *[os.pardir] * 4))
import lightning as L
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

In [7]:
import multiprocessing
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader

class CHBindependant_train(Dataset):
    """Map-style dataset over the leading samples of a directory of ``.npz`` files.

    Every file in ``data_dir`` (processed in sorted filename order) must hold a
    ``'data'`` and a ``'label'`` array of equal length. From each file only the
    first ``num_leave_out`` samples are exposed; ``CHBindependant_valid`` reads
    ``[num_leave_out:]`` from each file, so the two splits do not overlap.

    Only the first and the last file are opened in ``__init__`` to size the
    dataset: every file except the last is assumed to contain as many samples
    as the first one.

    Files are loaded lazily and only the most recently used file is cached.
    With a shuffling sampler, consecutive indices usually fall in different
    files, so most accesses will reload a file.

    NOTE(review): earlier versions shifted indices >= num_leave_out back by
    num_leave_out. That repeated the leading samples and, for files longer than
    2*num_leave_out, served samples that the validation split also uses. This
    version exposes exactly data[:num_leave_out] per file. Confirm that this is
    the intended split.

    Args:
        data_dir: Directory containing the ``.npz`` files.
        num_leave_out: Number of leading samples per file used by this dataset.

    Raises:
        ValueError: If ``data_dir`` contains no files.
    """

    def __init__(self, data_dir, num_leave_out=500):
        self.data_dir = data_dir
        self.file_list = sorted(os.listdir(data_dir))
        if not self.file_list:
            raise ValueError(f'no files found in {data_dir!r}')
        self.leave_out = num_leave_out

        # Total sample counts of the first file and the last file.
        self.num_samples_per_file = self._count_samples(self.file_list[0])
        num_samples_last_file = self._count_samples(self.file_list[-1])

        # Number of samples this dataset takes from each file (the first `leave_out`).
        self.train_samples_per_file = min(self.leave_out, self.num_samples_per_file)
        train_samples_last_file = min(self.leave_out, num_samples_last_file)
        self.total_num_samples = (
            (len(self.file_list) - 1) * self.train_samples_per_file
            + train_samples_last_file
        )

        # Single-file cache, guarded by `lock`.
        self.lock = multiprocessing.Lock()
        self.current_file_idx = None
        self.current_data = None
        self.current_label = None

    def _count_samples(self, file_name):
        """Return the length of the ``'data'`` array stored in ``file_name``."""
        with np.load(os.path.join(self.data_dir, file_name), allow_pickle=False) as npz:
            return len(npz['data'])

    def _load_file(self, file_idx):
        """Read file ``file_idx`` into the cache. The caller must hold ``self.lock``."""
        path = os.path.join(self.data_dir, self.file_list[file_idx])
        # Indexing an NpzFile materialises the array in memory, so the file can be
        # closed straight away (mmap_mode has no effect on .npz archives).
        with np.load(path, allow_pickle=False) as npz:
            self.current_data = npz['data']
            self.current_label = npz['label']
        self.current_file_idx = file_idx

    def __len__(self):
        return self.total_num_samples

    def __getitem__(self, idx):
        """Return ``(data, label)`` for flat index ``idx`` (negative indices allowed).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.total_num_samples
        if not 0 <= idx < self.total_num_samples:
            raise IndexError(
                f'index out of range for dataset of length {self.total_num_samples}')

        last_file_idx = len(self.file_list) - 1
        stride = self.train_samples_per_file
        # Clamp so that a last file longer than the others still maps into that file.
        file_idx = last_file_idx if stride == 0 else min(idx // stride, last_file_idx)
        sample_idx = idx - file_idx * stride

        with self.lock:
            if self.current_file_idx != file_idx:
                self._load_file(file_idx)
            # Read while still holding the lock, so another thread cannot swap
            # the cached file between loading it and indexing into it.
            data = self.current_data[sample_idx]
            label = self.current_label[sample_idx]
        return data, label

In [8]:
class CHBindependant_valid(Dataset):
    """In-memory dataset of the trailing samples of every ``.npz`` file in a directory.

    Each file is processed in sorted filename order, the same order that
    ``CHBindependant_train`` uses. From each file the samples from index
    ``num_leave_out`` onward of its ``'data'`` and ``'label'`` arrays are kept.
    Everything is concatenated into ``self.data`` / ``self.label`` at
    construction time.

    Args:
        data_dir: Directory containing the ``.npz`` files.
        num_leave_out: Number of leading samples per file to skip.

    Raises:
        ValueError: If ``data_dir`` contains no files.
    """

    def __init__(self, data_dir, num_leave_out=500):
        self.data_dir = data_dir
        # Sort so the sample order is deterministic; os.listdir order is arbitrary.
        self.file_list = sorted(os.listdir(data_dir))
        if not self.file_list:
            raise ValueError(f'no files found in {data_dir!r}')
        datas = []
        labels = []
        for file_name in self.file_list:
            # np.load on an .npz returns an NpzFile with an open handle; close it
            # once the slices are materialised (mmap_mode has no effect on .npz).
            with np.load(os.path.join(data_dir, file_name), allow_pickle=False) as npz:
                datas.append(npz['data'][num_leave_out:])
                labels.append(npz['label'][num_leave_out:])
        self.data = np.concatenate(datas)
        self.label = np.concatenate(labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair at position ``idx``."""
        return self.data[idx], self.label[idx]


In [9]:

# Usage example: load the samples that CHBindependant_valid keeps from every
# file in the Test directory, then iterate over them in order in the main process.
data_dir = os.path.join(
    project_path,
    'BilinearNetwork/Data/PreprocessedData/CHB-MIT/Concanate/Test',
)
dataset_test = CHBindependant_valid(data_dir)
dataloader = DataLoader(
    dataset=dataset_test,
    batch_size=32,
    shuffle=False,
    num_workers=0,
)

In [None]:
class CHBDependentDM(L.LightningDataModule):
    """LightningDataModule wiring the CHB-MIT train / validation / test datasets.

    Both the training and the validation sets are built from
    ``<root_dir>/Train`` in ``__init__``: the training set with
    ``CHBindependant_train`` and the validation set with
    ``CHBindependant_valid``, each using its default ``num_leave_out``. The test
    set is built from ``<root_dir>/Test`` only when ``setup('test')`` runs.

    Args:
        root_dir: Directory that contains ``Train`` and ``Test`` subdirectories.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker processes per dataloader. The default of 0 loads
            data in the main process, matching the previous behaviour.
    """

    def __init__(self, root_dir: str, batch_size: int = 32, num_workers: int = 0):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.trainset = CHBindependant_train(os.path.join(root_dir, 'Train'))
        self.valset = CHBindependant_valid(os.path.join(root_dir, 'Train'))
        self.testset = None

    def prepare_data(self):
        # Nothing to do: the datasets read existing files under root_dir.
        pass

    def setup(self, stage: str):
        """Create the test set for the ``'test'`` stage; clear it for any other stage."""
        if stage == 'test':
            # NOTE(review): the test split reuses the train-style dataset (and its
            # num_leave_out sample selection) rather than every test sample.
            # Confirm that this is intended.
            self.testset = CHBindependant_train(os.path.join(self.root_dir, 'Test'))
        else:
            self.testset = None

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader using this module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        # CHBindependant_train caches only one file at a time, so shuffling
        # may trigger a file reload on most accesses.
        return self._make_loader(self.trainset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.valset, shuffle=False)

    def test_dataloader(self):
        """Return the test DataLoader.

        Raises:
            RuntimeError: If ``setup('test')`` has not been called, so no test set exists.
        """
        if self.testset is None:
            # Fail here with a clear message. Otherwise DataLoader(None) would
            # only fail later, during iteration.
            raise RuntimeError("test set is not initialised; call setup('test') first")
        return self._make_loader(self.testset, shuffle=False)

    def predict_dataloader(self):
        # No prediction split is defined.
        return None