In [1]:
import pandas as pd
import re
import os
import numpy as np
# Project root: four directory levels above the notebook's working directory.
project_path = os.path.abspath(os.path.join(os.getcwd(), '..', '..', '..', '..'))
# Build the data path from components so it is portable; a single backslash-separated
# literal contains invalid escape sequences ('\D', '\P', '\C') and is Windows-only.
data_dir = os.path.join(project_path, 'BilinearNetwork', 'Data', 'PreprocessedData', 'CHB-MIT', 'Prediction')
import lightning as L
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

In [2]:
def get_files_plan(data_dir, patient_code, leave_one_code=None):
    """Build a leave-one-preictal-file-out plan for one CHB-MIT patient.

    Files in ``data_dir`` are selected by the ``chbNN`` prefix, where ``NN`` is the
    zero-padded ``patient_code``. Exactly one selected file must contain
    ``'interictal'`` in its name; every selected file containing ``'preictal'`` is a
    candidate for being held out.

    Args:
        data_dir: Directory containing the preprocessed files.
        patient_code: Patient number, e.g. ``1`` selects files starting with ``chb01``.
        leave_one_code: Index into the *name-sorted* preictal files of the file to
            hold out. Negative indices count from the end, as with a list.

    Returns:
        dict with keys ``'interictal'`` (the single interictal file name),
        ``'preictal'`` (list of the remaining preictal file names, sorted) and
        ``'leave_out'`` (the held-out preictal file name).

    Raises:
        AssertionError: if the patient does not have exactly one interictal file.
        IndexError: if ``leave_one_code`` is out of range for the preictal files.
        TypeError: if ``leave_one_code`` is not an integer (e.g. the default ``None``).
    """
    prefix = 'chb' + str(patient_code).zfill(2)
    patient_files = [f for f in os.listdir(data_dir) if f.startswith(prefix)]
    # os.listdir returns names in arbitrary, platform-dependent order; sorting makes a
    # given leave_one_code select the same held-out file on every machine.
    files_preictal = sorted(f for f in patient_files if 'preictal' in f)
    files_interictal = [f for f in patient_files if 'interictal' in f]
    assert len(files_interictal) == 1, \
        'expected exactly one interictal file for {}, found {}'.format(prefix, files_interictal)

    # Normalise the index exactly as list indexing would (negatives allowed,
    # out-of-range -> IndexError, non-integer -> TypeError).
    held_out_idx = range(len(files_preictal))[leave_one_code]
    return {
        'interictal': files_interictal[0],
        'preictal': [f for i, f in enumerate(files_preictal) if i != held_out_idx],
        'leave_out': files_preictal[held_out_idx],
    }
# Show the file plan for patient chb01 with preictal file index 0 held out.
get_files_plan(data_dir, patient_code=1, leave_one_code=0)

In [3]:
import torch
from torch.utils.data import TensorDataset


class CHBDependentDMT(L.LightningDataModule):
    """Patient-dependent CHB-MIT data module with a leave-one-preictal-file-out test set.

    All data are loaded eagerly in ``__init__``, so ``trainset``, ``valset`` and
    ``testset`` are available right after construction:

    * the patient's single interictal file is split 80/20; the 80 % part is used for
      fitting and the 20 % part forms the negative half of the test set;
    * every preictal file except the held-out one provides the fitting positives;
    * the held-out preictal file provides the positive half of the test set;
    * the fitting data are split 90/10 into train and validation sets.

    Labels are float32: ``0.0`` for interictal samples, ``1.0`` for preictal samples.

    NOTE(review): Lightning recommends assigning dataset state in ``setup`` rather than
    ``__init__`` (matters for multi-process training). Loading stays in ``__init__`` so
    callers that read the datasets right after construction keep working.
    """

    def __init__(self, data_dir: str, patient_id: int, leave_out_id: int, batch_size: int = 32,
                 seed: "int | None" = None):
        """Load and split the data for one patient and one held-out preictal file.

        Args:
            data_dir: Directory holding the preprocessed files (each read with ``np.load``).
            patient_id: CHB-MIT patient number; selects files prefixed ``chbNN``.
            leave_out_id: Index (into the name-sorted preictal files) of the file held out
                as the positive half of the test set.
            batch_size: Batch size used by every dataloader.
            seed: Optional ``random_state`` for both ``train_test_split`` calls so the
                splits are reproducible. ``None`` (default) draws from NumPy's global RNG.

        Raises:
            ValueError: if no preictal files remain for fitting after the hold-out.
        """
        super().__init__()

        self.batch_size = batch_size
        self.data_dir = data_dir
        self.patient_id = patient_id
        self.leave_out_id = leave_out_id
        self.seed = seed

        file_plan = self.get_files_plan(data_dir, patient_id, leave_out_id)
        if not file_plan['preictal']:
            raise ValueError(
                'patient {} needs at least two preictal files for a leave-one-out split'.format(patient_id))

        interictal_data = np.load(os.path.join(data_dir, file_plan['interictal']))
        preictal_data = np.concatenate(
            [np.load(os.path.join(data_dir, f)) for f in file_plan['preictal']])
        leave_out_data = np.load(os.path.join(data_dir, file_plan['leave_out']))

        # train_test_split shuffles by default, so no separate shuffle is needed.
        interictal_fit, interictal_test = train_test_split(
            interictal_data, test_size=0.2, random_state=seed)
        # NOTE(review): if interictal samples are overlapping/adjacent time windows, a
        # random split may leak near-duplicates between fitting and test data — confirm.

        data_test = np.concatenate([interictal_test, leave_out_data])
        label_test = np.concatenate([np.zeros(len(interictal_test)), np.ones(len(leave_out_data))])

        data_fit = np.concatenate([interictal_fit, preictal_data])
        label_fit = np.concatenate([np.zeros(len(interictal_fit)), np.ones(len(preictal_data))])
        X_train, X_valid, y_train, y_valid = train_test_split(
            data_fit, label_fit, test_size=0.1, shuffle=True, random_state=seed)

        # Count with boolean-mask sums instead of materialising filtered array copies.
        print('Train Count: Negative {}, Positive {}'.format(
            int((y_train == 0).sum()), int((y_train == 1).sum())))
        print('Validation Count: Negative {}, Positive {}'.format(
            int((y_valid == 0).sum()), int((y_valid == 1).sum())))
        print('Test Count: Negative {}, Positive {}'.format(
            int((label_test == 0).sum()), int((label_test == 1).sum())))

        self.trainset = self._to_dataset(X_train, y_train)
        self.valset = self._to_dataset(X_valid, y_valid)
        self.testset = self._to_dataset(data_test, label_test)

    @staticmethod
    def _to_dataset(features, labels):
        """Wrap NumPy features and labels as a float32 ``TensorDataset`` (copies the data)."""
        return TensorDataset(torch.tensor(features, dtype=torch.float32),
                             torch.tensor(labels, dtype=torch.float32))

    def get_files_plan(self, data_dir, patient_code, leave_one_code=None):
        """Return ``{'interictal', 'preictal', 'leave_out'}`` file names for one patient.

        Selects files prefixed ``chbNN`` (zero-padded ``patient_code``). Exactly one file
        must contain ``'interictal'``. The files containing ``'preictal'`` are sorted by
        name, and the one at index ``leave_one_code`` (list-style, negatives allowed) is
        held out. Raises AssertionError / IndexError / TypeError like list indexing would.
        """
        prefix = 'chb' + str(patient_code).zfill(2)
        patient_files = [f for f in os.listdir(data_dir) if f.startswith(prefix)]
        # os.listdir order is arbitrary; sort so leave_one_code is reproducible.
        files_preictal = sorted(f for f in patient_files if 'preictal' in f)
        files_interictal = [f for f in patient_files if 'interictal' in f]
        assert len(files_interictal) == 1, \
            'expected exactly one interictal file for {}, found {}'.format(prefix, files_interictal)

        held_out_idx = range(len(files_preictal))[leave_one_code]
        return {
            'interictal': files_interictal[0],
            'preictal': [f for i, f in enumerate(files_preictal) if i != held_out_idx],
            'leave_out': files_preictal[held_out_idx],
        }

    def prepare_data(self):
        """No download/preparation step: the data are already preprocessed on disk."""
        pass

    def setup(self, stage: str):
        """No-op: all datasets are built in ``__init__``."""
        pass

    def train_dataloader(self):
        """Shuffled training batches."""
        return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True, pin_memory=True)

    def val_dataloader(self):
        """Validation batches in fixed order (evaluation does not need shuffling)."""
        return DataLoader(self.valset, batch_size=self.batch_size, shuffle=False, pin_memory=True)

    def test_dataloader(self):
        """Test batches in fixed order (evaluation does not need shuffling)."""
        return DataLoader(self.testset, batch_size=self.batch_size, shuffle=False, pin_memory=True)

    def predict_dataloader(self):
        """No prediction set is defined for this module."""
        return None

In [4]:
# Patient chb01, holding out preictal file index 0 as the test positives.
dm = CHBDependentDMT(
    data_dir=data_dir,
    patient_id=1,
    leave_out_id=0,
    batch_size=32,
)