In [1]:
import pyedflib

import numpy as np
import pandas as pd

import plotly.graph_objects as go

import glob
import os
import re
import random
from tqdm import tqdm

In [2]:
def get_patient_dict(patient:str, root_path:str) -> dict:
    '''
    Creates dictionary of files and seizures of a patient

    Parameters
    ----------
    patient : str
        identifier of the patient
    root_path : str
        path to the root directory of the scalp database

    Returns
    -------
    dict
        dictionary that contains the list of channels and times of seizures for each file
    '''
    info_file = open(root_path + patient + '/' + patient + '-summary.txt','r').readlines() # Open txt file
    patient_dict = {'channel_list': []} # Create empty dictionary
    for line in info_file: # Iterate over lines in txt file
        if(re.findall(r'(File Name: )\D*\d*(_)\d*(.edf)', line)): # If information about next file
            file = re.findall(r'(?:chb)\d*_\d*(?:.edf)', line)[0] # Get filename
            patient_dict[file] = {'seizure_start': [], 'seizure_end': []} # Create new sub-dict for new file
        elif(re.findall(r'Channel \d+', line)): # If channel description
            patient_dict['channel_list'].append(str(re.findall(r'Channel\s\d+:\s(\S*)', line)[0])) # Add channels to list
        elif(re.findall(r'Seizure Start Time|Seizure \d+ Start Time', line)): # If seizure start timestamp
            patient_dict[file]['seizure_start'].append(int(re.findall(r'(\d+)\sseconds', line)[0])) # Add seizure start to list
        elif(re.findall(r'Seizure End Time|Seizure \d+ End Time', line)): # If seizure end timestamp
            patient_dict[file]['seizure_end'].append(int(re.findall(r'(\d+)\sseconds', line)[0])) # Add seizure end to list
    return patient_dict

In [3]:
def get_labeled_file(file_path:str, channel_list:list, patient_dict:dict) -> pd.DataFrame:
    '''
    Converts a single file from edf-format to a pandas DataFrame and adds a label if a seizure is present

    Parameters
    ----------
    file_path : str
        relative path to the file
    channel_list : list
        list of the requested channels names
    patient_dict : dict
        dict that contains the seizure information for each file of the patient

    Returns
    -------
    pd.DataFrame
        pandas DataFrame that contains the requested channels with seizure labels
    '''
    edf_file = pyedflib.EdfReader(file_path) # Read edf file
    if not set(channel_list).issubset(set(edf_file.getSignalLabels())): # Check if all requested channels are present
        raise ValueError("File " + file_path + " does not contain requested channels!") # Raise error if not
    signal_data = np.zeros((edf_file.getNSamples()[0], len(channel_list))) # Create empty array for data
    for i, channel in enumerate(channel_list): # Iterate over channels
        signal_data[:, i] = edf_file.readSignal(edf_file.getSignalLabels().index(channel)) # Add channel data to array
    dataframe = pd.DataFrame(signal_data, columns=channel_list).astype('float32') # Create a dataframe from array
    dataframe["seconds"] = np.floor(np.linspace(0, len(dataframe)/edf_file.getSampleFrequencies()[0], len(dataframe), endpoint=False)).astype('uint16') # Add seconds column
    file_name = re.findall(r'([^\/]+$)', file_path)[-1] # Get name of file
    seizure_start_list = patient_dict.get(file_name).get("seizure_start") # Get list of seizure starts for file
    seizure_end_list = patient_dict.get(file_name).get("seizure_end") # Get list of seizure ends for file
    dataframe["seizure"] = 0 # Create new column for seizure labels
    if(len(seizure_start_list) > 0): # If seizures are present in file
        for seizure in range(len(seizure_start_list)): # Iterate over seizures
            start_second = seizure_start_list[seizure] # Get current start of seizure
            end_second = seizure_end_list[seizure] # Get current end of seizure
            dataframe.loc[dataframe["seconds"].between(start_second, end_second), "seizure"] = 1 # Label timeframe of seizure
    dataframe = dataframe.drop(columns=["seconds"]) # Drop seconds column
    return dataframe

In [4]:
def get_complete_patient_data(patient:str, channel_list:list, root_path:str) -> pd.DataFrame:
    '''
    Creates a pandas DataFrame that contains all requested channels of the complete eeg data of a patient

    Parameters
    ----------
    patient : str
        identifier of the patient
    channel_list : list
        list of the requested channels names
    root_path: str
        path to the root directory of the scalp database

    Returns
    -------
    pd.DataFrame
        pandas DataFrame that contains the complete labeled eeg data of one patient
    '''
    parent_path = root_path + patient # Get path of patients parent directory
    all_patient_files = sorted(glob.glob(os.path.join(parent_path , ("*.edf")))) # Get list of patients files
    all_patient_files = [ x for x in all_patient_files if "+" not in x ] # Clean file list
    patient_dict = get_patient_dict(patient=patient, root_path=root_path) # Get dict of patient data information
    concat_list = [] # Create empty list for files
    bar = tqdm(total=len(all_patient_files)) # Create progress bar
    for file in all_patient_files: # Iterate over all files
        try:
            concat_list.append(get_labeled_file(file_path=file, channel_list=channel_list, patient_dict=patient_dict)) # Get labeled dataframe of file
        except Exception as e:
            print(e)
        bar.update(1) # Update progress bar
    bar.close() # Close progress bar
    dataframe = pd.concat(concat_list, axis=0, ignore_index=True) # Combine all dataframes into one
    dataframe["patient"] = patient # Create column with patient identifier
    dataframe["timestamp"] = pd.date_range('1970-01-01 00:00:00', freq='3906250N', periods=len(dataframe)) # Add timestamp for later resampling
    return dataframe

In [5]:
def resample_dataframe(dataframe:pd.DataFrame, resample_freq:str, time_col:str) -> pd.DataFrame:
    '''
    Resamples a pandas DataFrame to the desired frequency

    Parameters
    ----------
   dataframe : pd.DataFrame
        Dataframe to be resampled
    resample_freq : str
        Target data frequency
    time_col : str
        Column that contains the timestamp

    Returns
    -------
    pd.DataFrame
        pandas DataFrame that contains the resampled data
    '''
    resampled = dataframe.resample(rule=resample_freq, on=time_col).agg("first")
    resampled = resampled.reset_index(drop=True)
    return resampled

In [6]:
def scalp_database_to_dataframe(patient_list:list, channel_list:list, root_path:str):
    '''
    Creates a pandas DataFrame that contains the complete data of one patient and saves the dataframe

    Parameters
    ----------
    patient_list : list
        list of all patient identifiers
    channel_list : list
        list of the requested channels names
    root_path: str
        path to the root directory of the scalp database

    Returns
    -------
    pd.DataFrame
        pandas DataFrame that contains the complete labeled eeg data of one patient
    '''
    for patient in patient_list:
        print("Processing Patient: " + patient)
        try:
            temp_df = get_complete_patient_data(patient, channel_list, root_path) # Create dataframe that contains the labeled data of a patient
            print("Resample Data")
            temp_df_resampled = resample_dataframe(temp_df, '10ms', 'timestamp') # Resample Dataset
            temp_df_resampled.to_pickle('../00_Data/Dataframes/' + patient + '.pkl') # Store dataframe as a pickel
        except Exception as e:
            print(e)
    return None

In [7]:
def get_valid_channels(patient_list:list, root_path:str) -> list:
    '''
    Creates a list of channels present for all patients in all files

    Parameters
    ----------
    patient_list : list
        list of all patient identifiers
    root_path: str
        path to the root directory of the scalp database

    Returns
    -------
    list
        list that contains the channels that are present for all files
    '''
    channel_list = []
    patients = patient_list
    for patient in patients:
        parent_path = root_path + patient
        all_patient_files = sorted(glob.glob(os.path.join(parent_path , ("*.edf"))))
        all_patient_files = [ x for x in all_patient_files if "+" not in x ]
        for file in all_patient_files:
            temp_file = pyedflib.EdfReader(file)
            channel_list.append(temp_file.getSignalLabels())
    elements_in_all = list(set.intersection(*map(set, channel_list)))
    return elements_in_all

In [None]:
root_path = '../00_Data/chb-mit-scalp-eeg-database-1.0.0/'
all_patients = sorted([patient for patient in os.listdir(root_path) if re.match(r'(chb)\d+', patient)])
all_patients.remove("chb12")
# channels = ['FP1-F7', 'C3-P3', 'C4-P4', 'CZ-PZ', 'F3-C3', 'F4-C4', 'F7-T7', 'F8-T8', 'FP1-F3', 'FP2-F4', 'FP2-F8', 'FT10-T8', 'FT9-FT10', 'FZ-CZ', 'P3-O1', 'P4-O2', 'P7-O1', 'P7-T7', 'P8-O2', 'T7-FT9', 'T7-P7', 'T8-P8-0', 'T8-P8-1']
# channels = get_valid_channels(patient_list=all_patients, root_path=root_path)
channels = get_valid_channels(patient_list=all_patients, root_path=root_path)

In [None]:
scalp_database_to_dataframe(patient_list=[all_patients[10]], channel_list=channels, root_path=root_path)

In [8]:
df = pd.read_pickle("../00_Data/Dataframes/chb11.pkl")

In [9]:
def create_balanced_time_windows(dataframe:pd.DataFrame, window_length:int, id_column:str, label_column:str, balance_ratio:float, step:int, extract_series:bool, random_state:int):
    unique_ids = dataframe[id_column].unique()
    for id in unique_ids:
        print("Processing patient: " + str(id))
        index_positive = list(dataframe[(dataframe[id_column] == id) & (dataframe[label_column] == 1)].index.values)[::step]
        index_negative = dataframe[((dataframe[id_column] == id) & (dataframe[label_column] == 0))].index
        random.seed(random_state)
        index_negative_sample = random.sample(list(index_negative), int(len(index_positive) * balance_ratio))
        sample_indices = list(index_positive + index_negative_sample)
        X = []
        y = []
        bar = tqdm(total=len(sample_indices))
        i = 0
        for index in sample_indices:
            end_index = index + window_length
            if (end_index <= len(dataframe)):
                if(dataframe[id_column].iloc[index] == dataframe[id_column].iloc[end_index]):
                    seq_x = dataframe.drop(columns=[id_column, label_column]).iloc[index:end_index].values.tolist()
                    if extract_series:
                        seq_y = dataframe[label_column].iloc[index:end_index]
                    else:
                        seq_y = dataframe[label_column].iloc[end_index]
                    X.append(seq_x)
                    y.append(seq_y)
            bar.update(1)
            i += 1
        bar.close()
    return X, y

In [10]:
features, labels = create_balanced_time_windows(df, 10000, "patient", "seizure", 1.5, 100, False, 43)

Processing patient: chb11


100%|██████████| 4045/4045 [15:51<00:00,  4.25it/s]  


In [16]:
len(labels)

4044

In [20]:
np_labels = np.array(labels)
np_features = np.array(features)

In [21]:
np.savez_compressed('../test', label=np_labels, features=np_features)

In [None]:
# Version 1 of Window Generation
# Issues:
#   - OOM-Exception
#   - Extremely ineffecient for big data
#   - Very imbalanced data


# def create_sliding_windows(dataframe:pd.DataFrame, window_size:int, id_col:str, label_col:str) -> tuple:
#     """
#     Function for the creation of time windows of certain size

#     Parameters
#     ----------
#     dataframe : pd.DataFrame
#         Dataframe containing all features and target variable as well as a unique identifier
#     window_size : int
#         Number of timesteps in a timewindow
#     id_col : str
#         Name of the column that contains the ids
#     label_col : str
#         Name of the column that is to predicted

#     Returns
#     -------
#     X : np.array
#         Array that contains all of the features of each time window
#     y : np.array
#         Array that contains all of the labels of each time window
#     """
#     X, y = list(), list() # Create empty lists for X and y
#     unique_ids = dataframe[id_col].unique()
#     bar = tqdm(total=(len(dataframe) - window_size + 1))
#     for i in unique_ids:
#         temp_df = dataframe.loc[dataframe[id_col] == i].reset_index().drop(columns="index")
#         for n in range(len(temp_df)): # Iterate over rows of temporary dataframe
#             end_ix = n + window_size # Calculate last idx of time window
#             if (end_ix <= len(temp_df)): # If last idx is still within temporary dataframe
#                 seq_x = temp_df.drop(columns=[id_col, label_col])[n:end_ix].values.tolist()
#                 seq_y = temp_df.loc[end_ix][label_col]
#                 X.append(seq_x) # Append X of current time window to global list
#                 y.append(seq_y) # Append y of current time window to global list
#             bar.update(1)
#     bar.close()
#     X = np.array(X) # Create array from global list of X
#     y = np.float_(np.array(y)) # Create array of type float from global list of y
#     return X, y

In [None]:
# Version 2 of Window Generation
# Issues
#   - Still ineffecient
#   - Unbalanced data

# def create_sliding_windows_step(dataframe:pd.DataFrame, window_size:int, id_col:str, label_col:str, step:int, y_series:bool) -> tuple:
#     """
#     Function for the creation of time windows of certain size

#     Parameters
#     ----------
#     dataframe : pd.DataFrame
#         Dataframe containing all features and target variable as well as a unique identifier
#     window_size : int
#         Number of timesteps in a timewindow
#     id_col : str
#         Name of the column that contains the ids
#     label_col : str
#         Name of the column that is to predicted

#     Returns
#     -------
#     X : np.array
#         Array that contains all of the features of each time window
#     y : np.array
#         Array that contains all of the labels of each time window
#     """
#     X, y = list(), list() # Create empty lists for X and y
#     unique_ids = dataframe[id_col].unique()
#     bar = tqdm(total=len(range(0, len(dataframe), step)))
#     for i in unique_ids:
#         temp_df = dataframe.loc[dataframe[id_col] == i].reset_index().drop(columns="index")
#         for n in range(0, len(temp_df), step): # Iterate over rows of temporary dataframe
#             end_ix = n + window_size # Calculate last idx of time window
#             if (end_ix <= len(temp_df)): # If last idx is still within temporary dataframe
#                 seq_x = temp_df.drop(columns=[id_col, label_col])[n:end_ix].values.tolist()
#                 if y_series:
#                     seq_y = temp_df.loc[n:end_ix][label_col].values.tolist()
#                 else:
#                     seq_y = temp_df.loc[end_ix][label_col]
#                 X.append(seq_x) # Append X of current time window to global list
#                 y.append(seq_y) # Append y of current time window to global list
#             bar.update(1)
#     bar.close()
#     X = np.array(X) # Create array from global list of X
#     y = np.float_(np.array(y)) # Create array of type float from global list of y
#     return X, y

In [None]:
# Version 3 of Window Generation
# Issues
#   - Memmap has size of >100GB

# def create_balanced_time_windows(dataframe:pd.DataFrame, window_length:int, id_column:str, label_column:str, balance_ratio:float, extract_series:bool, random_state:int):
#     unique_ids = dataframe[id_column].unique()
#     for id in unique_ids:
#         print("Processing patient: " + str(id))
#         index_positive = dataframe[(dataframe[id_column] == id) & (dataframe[label_column] == 1)].index
#         index_negative = dataframe[((dataframe[id_column] == id) & (dataframe[label_column] == 0))].index
#         random.seed(random_state)
#         index_negative_sample = random.sample(list(index_negative), int(len(index_positive) * float(balance_ratio)))
#         sample_indices = index_positive + index_negative_sample
#         X = np.memmap('../00_Data/Dataframes/' + str(id) + '_features.npy', np.float32, mode='w+', shape=(len(sample_indices), window_length, 18))
#         if extract_series:
#             y = np.memmap('../00_Data/Dataframes/' + str(id) + '_label.npy', np.int16, mode='w+', shape=(len(sample_indices), window_length))
#         else:
#             y = np.memmap('../00_Data/Dataframes/' + str(id) + '_label.npy', np.int16, mode='w+', shape=(len(sample_indices), 1))
#         bar = tqdm(total=len(sample_indices))
#         i = 0
#         for index in sample_indices:
#             end_index = index + window_length # Calculate last idx of time window
#             if (end_index <= len(dataframe)):
#                 if(dataframe[id_column].iloc[index] == dataframe[id_column].iloc[end_index]):
#                     seq_x = dataframe.drop(columns=[id_column, "timestamp", label_column]).iloc[index:end_index].values.tolist()
#                     if extract_series:
#                         seq_y = dataframe[label_column].iloc[index:end_index]
#                     else:
#                         seq_y = dataframe[label_column].iloc[end_index]
#                     X[i] = seq_x # Append X of current time window to global list
#                     y[i] = seq_y # Append y of current time window to global list
#             bar.update(1)
#             i += 1
#         bar.close()
#     return None