In [1]:
# %%
import os

os.environ['TF_GPU_THREAD_MODE']='gpu_private'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

from datetime import datetime
import sys
import numpy as np
import random
import operator
import matplotlib as plt
import pickle
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.utils import Sequence


from sklearn.model_selection import KFold

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    tf.config.experimental.set_memory_growth(gpus[0], True)
  except RuntimeError as e:
    print(e)

In [26]:
# %%
class autoencoder_generator(Sequence):
    def __init__(self,type_1_data, type_2_data, type_3_data, batch_size):
        
        self.type_1_data = type_1_data
        self.type_2_data = type_2_data
        self.type_3_data = type_3_data

        self.type_1_data_len = len(type_1_data)
        self.type_2_data_len = len(type_2_data)
        
        type_3_sampled_for_balance = type_3_data[np.random.choice(len(type_3_data), int((self.type_1_data_len + self.type_2_data_len)*1.5),replace=False)]
        self.type_3_data_len = len(type_3_sampled_for_balance)

        self.batch_num = int((self.type_1_data_len + self.type_2_data_len + self.type_3_data_len)/batch_size)

        self.type_1_batch_indexes = GetBatchIndexes(self.type_1_data_len, self.batch_num)
        self.type_2_batch_indexes = GetBatchIndexes(self.type_2_data_len, self.batch_num)
        self.type_3_batch_indexes = GetBatchIndexes(self.type_3_data_len, self.batch_num)

    def __len__(self):
        return self.batch_num
    
    def __getitem__(self, idx):
        input_seg = np.concatenate((self.type_1_data[self.type_1_batch_indexes[idx]], self.type_2_data[self.type_2_batch_indexes[idx]], self.type_3_data[self.type_3_batch_indexes[idx]]))
        X_batch = Segments2Data(input_seg)
        #X_batch = np.random.standard_normal((300,21,512))
        return X_batch, X_batch

# %%
if __name__=='__main__':
    window_size = 2
    overlap_sliding_size = 1
    normal_sliding_size = window_size
    state = ['preictal_ontime', 'ictal', 'preictal_late', 'preictal_early', 'postictal','interictal']

    # for WSL
    train_info_file_path = "/host/d/SNU_DATA/patient_info_train.csv"
    test_info_file_path = "/host/d/SNU_DATA/patient_info_test.csv"
    edf_file_path = "/host/d/SNU_DATA"

    ## for window
    # train_info_file_path = "D:/SNU_DATA/patient_info_train.csv"
    # test_info_file_path = "D:/SNU_DATA/patient_info_test.csv"
    # edf_file_path = "D:/SNU_DATA"


    train_interval_set = LoadDataset(train_info_file_path)
    train_segments_set = {}

    test_interval_set = LoadDataset(test_info_file_path)
    test_segments_set = {}

    # 상대적으로 데이터 갯수가 적은 것들은 window_size 2초에 sliding_size 1초로 overlap 시켜 데이터 증강
    for state in ['preictal_ontime', 'ictal', 'preictal_late', 'preictal_early']:
        train_segments_set[state] = Interval2Segments(train_interval_set[state],edf_file_path, window_size, overlap_sliding_size)
        test_segments_set[state] = Interval2Segments(test_interval_set[state],edf_file_path, window_size, overlap_sliding_size)
        

    for state in ['postictal', 'interictal']:
        train_segments_set[state] = Interval2Segments(train_interval_set[state],edf_file_path, window_size, normal_sliding_size)
        test_segments_set[state] = Interval2Segments(test_interval_set[state],edf_file_path, window_size, normal_sliding_size)

    # type 1은 True Label데이터 preictal_ontime
    # type 2는 특별히 갯수 맞춰줘야 하는 데이터
    # type 3는 나머지

    # AutoEncoder 단계에서는 1:1:3

    train_type_1 = np.array(train_segments_set['preictal_ontime'])
    train_type_2 = np.array(train_segments_set['ictal'] + train_segments_set['preictal_early'] + train_segments_set['preictal_late'])
    train_type_3 = np.array(train_segments_set['postictal'] + train_segments_set['interictal'])

    test_type_1 = np.array(test_segments_set['preictal_ontime'])
    test_type_2 = np.array(test_segments_set['ictal'] + test_segments_set['preictal_early'] + test_segments_set['preictal_late'])
    test_type_3 = np.array(test_segments_set['postictal'] + test_segments_set['interictal'])

    fold_n = 5

    kf = KFold(n_splits=5, shuffle=True)
    epochs = 100
    batch_size = 500   # 한번의 gradient update시마다 들어가는 데이터의 사이즈
    total_len = len(train_type_1)+len(train_type_2)
    total_len = int(total_len*2.5) # 데이터 비율 2:2:6

    type_1_kfold_set = kf.split(train_type_1)
    type_2_kfold_set = kf.split(train_type_2)
    type_3_kfold_set = kf.split(train_type_3)


    for _ in range(fold_n):
        (type_1_train_indexes, type_1_val_indexes) = next(type_1_kfold_set)
        (type_2_train_indexes, type_2_val_indexes) = next(type_2_kfold_set)
        (type_3_train_indexes, type_3_val_indexes) = next(type_3_kfold_set)
        if os.path.exists(f"./AutoEncoder_training_{_+1}"):
            if os.path.exists(f"./AutoEncoder_training_{_}"):
                autoencoder_model = tf.keras.models.create_model()
                autoencoder_model = tf.keras.models.load_model(f"/AutoEncoder_training_{_}/cp.ckpt")
            else:
                encoder_inputs = Input(shape=(21,512,1))
                encoder_outputs = FullChannelEncoder(encoded_feature_num=64,inputs = encoder_inputs)
                decoder_outputs = FullChannelDecoder(encoder_outputs)
                autoencoder_model = Model(inputs=encoder_inputs, outputs=decoder_outputs)
                autoencoder_model.compile(optimizer = 'Adam', loss='mse',)
            

        type_1_data_len = len(type_1_train_indexes)
        type_2_data_len = len(type_2_train_indexes)
        type_3_data_len = int((type_1_data_len + type_2_data_len)*1.5)
        train_batch_num = int((type_1_data_len + type_2_data_len + type_3_data_len)/batch_size)

        type_1_data_len = len(type_1_val_indexes)
        type_2_data_len = len(type_2_val_indexes)
        type_3_data_len = int((type_1_data_len + type_2_data_len)*1.5)
        val_batch_num = int((type_1_data_len + type_2_data_len + type_3_data_len)/batch_size)
        logs = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")

        

        tboard_callback = tf.keras.callbacks.TensorBoard(log_dir = logs,
                                                        histogram_freq = 1,
                                                        profile_batch = '1,20')
        checkpoint_path = f"AutoEncoder_training_{_}/cp.ckpt"
        checkpoint_dir = os.path.dirname(checkpoint_path)

        # Create a callback that saves the model's weights
        cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                        save_best_only=True,
                                                        verbose=1)
        
        train_generator = autoencoder_generator(train_type_1[type_1_train_indexes], train_type_2[type_2_train_indexes], train_type_3[type_3_train_indexes],batch_size)
        validation_generator = autoencoder_generator(train_type_1[type_1_val_indexes], train_type_2[type_2_val_indexes], train_type_3[type_3_val_indexes],batch_size)
# %%
        history = autoencoder_model.fit_generator(
                    train_generator,
                    epochs = epochs,
                    steps_per_epoch =  train_batch_num,
                    validation_data = validation_generator,
                    validation_steps = val_batch_num,
                    use_multiprocessing=True,
                    workers=4,
                    callbacks= [ tboard_callback, cp_callback ]
                    )
        
        with open(f'./AutoEncoder_training_{_}/trainHistoryDict', 'wb') as file_pi:
            pickle.dump(history.history, file_pi)




In [2]:
import pandas as pd
from pyedflib.highlevel import read_edf, read_edf_header

In [3]:
WINDOW_SIZE = 2 # sec
OVERLAP = 0 # sec
FREQUENCY_CHB = 256 #Hz
ictal_section_name = ['ictal', 'preictal_late', 'preictal_ontime', 'preictal_early', 'postictal', 'interictal']

origin = pd.read_csv("./patient_info_chb_split.csv")

patient_info_chb_segment = list()

for state in ictal_section_name:
    total_current_state = origin[origin['state']==state]

    # short-term data with overlap
    if state in ['ictal', 'preictal_late', 'preictal_ontime', 'preictal_early']:
        OVERLAP = 1

    # long-term data without overlap
    if state in ["interictal", "postictal"]:
        OVERLAP = 0

    for current_state in total_current_state.itertuples():
        filename, start, end = current_state[1], current_state[2], current_state[3]

        if end-start<WINDOW_SIZE:
            continue
        
        dirname = filename.split('_')[0]
        filepath = os.path.join('./data/CHB/',dirname,filename+'.edf')
        
        header = read_edf_header(filepath)
        startdate = header['startdate']
        
        start_from_0sec = int(start-startdate.timestamp())*FREQUENCY_CHB
        end_from_0sec = int(end-startdate.timestamp())*FREQUENCY_CHB            

        step_size = (WINDOW_SIZE-OVERLAP)*FREQUENCY_CHB #index
        window_size_with_frequency = int(WINDOW_SIZE*FREQUENCY_CHB)
        for window_start_index in range(start_from_0sec, end_from_0sec, step_size):
            # ["name", "start", "duration", "state", "frequency"]
            current_segment = [filename, window_start_index, window_size_with_frequency, state, FREQUENCY_CHB]
            patient_info_chb_segment.append(current_segment)

In [4]:
print(len(patient_info_chb_segment))

1770539


In [5]:
df =pd.DataFrame(patient_info_chb_segment, columns=["name", "start", "duration", "state", "frequency"])

In [6]:
df[df['state']=='ictal'].shape

(257581, 5)

In [9]:
df1 = df.iloc[:df.shape[0]//2,:]
df2 = df.iloc[df.shape[0]//2:,:]

In [12]:
df1.to_csv('./patient_info_chb_segment_1.csv', index=False)

In [13]:
df2.to_csv('./patient_info_chb_segment_2.csv', index=False)

In [4]:
from pyedflib import EdfReader
from pyedflib.highlevel import read_edf

In [2]:
def read_segment(path, chn, start, n):
    """
    Returns the physical data of signal chn. When start and n is set, a subset is returned

    Parameters:	
    chn : int
    channel number

    start : int
    start pointer (default is 0)

    n : int
    length of data to read (default is None, by which the complete data of the channel are returned)

    digital: bool
    will return the signal in original digital values instead of physical values

    Examples
    ——–
    >>> import pyedflib
    >>> f = pyedflib.data.test_generator()
    >>> x = f.readSignal(0,0,1000)
    >>> int(x.shape[0])
    1000
    >>> x2 = f.readSignal(0)
    >>> int(x2.shape[0])
    120000
    >>> f.close()    
    """
    path = "./data/CHB/CHB001/CHB001_01.edf"
    f = EdfReader(path)
    segment = f.readSignal(chn,start,n,digital=True)
    

array([ 0.1953602 ,  0.1953602 ,  0.1953602 ,  0.1953602 ,  0.97680098,
        0.58608059, -1.75824176, -1.36752137,  2.93040293,  4.1025641 ])

In [6]:
f.close()

array([[-1.45934066e+02,  1.95360195e-01,  1.95360195e-01, ...,
        -1.15262515e+01, -2.93040293e+00,  1.93406593e+01],
       [-1.04517705e+02,  1.95360195e-01,  1.95360195e-01, ...,
         2.36385836e+01,  2.75457875e+01,  3.06715507e+01],
       [-4.27838828e+01,  1.95360195e-01,  1.95360195e-01, ...,
         4.86446886e+01,  4.51282051e+01,  3.45787546e+01],
       ...,
       [-2.64713065e+02,  1.95360195e-01,  5.86080586e-01, ...,
         9.76800977e-01, -1.58241758e+01, -2.94993895e+01],
       [ 9.47496947e+01,  1.95360195e-01,  1.95360195e-01, ...,
        -7.22832723e+00, -1.03540904e+01, -1.34798535e+01],
       [ 4.47374847e+01,  1.95360195e-01,  1.95360195e-01, ...,
         1.69963370e+01,  2.24664225e+01,  2.63736264e+01]])