# In [None]:
# extract patient ids from edf files
from glob2 import glob
import pandas as pd
import numpy as np
import mne 
import tqdm

# Collect the paths of all EDF recordings in the corpus.
# NOTE: the pattern uses Windows path separators ("\\"), and the path
# parsing below relies on that separator as well.
glob_path_to_edf_files = "..\\tuh_eeg_epilepsy\\edf\\*epilepsy\\*\\*\\*\\*\\*.edf"
edf_file_list = glob(glob_path_to_edf_files)

# The patient ID is the part of the file name in front of the first "_".
# IDs are stripped *before* de-duplication so that two spellings differing only
# in surrounding whitespace collapse into one entry. Sorting gives a stable
# order between runs; iterating a bare set of strings does not guarantee one.
unique_epilepsy_patient_ids = sorted(
    {path.split("\\")[-1].split("_")[0].strip() for path in edf_file_list}
)

# pick your desired preprocessing configuration.
SAMPLING_FREQ = 250.0
WINDOW_LENGTH_SECONDS = 60.0
WINDOW_LENGTH_SAMPLES = int(WINDOW_LENGTH_SECONDS * SAMPLING_FREQ)

# loop over one patient at a time and collect one metadata row per window;
# the rows are turned into the index table afterwards
dataset_index_rows = []
label_count = {
    "epilepsy": 0,
    "no_epilepsy": 0,
}

# `tqdm` is imported as a module at the top of the file, so the progress-bar
# wrapper has to be addressed as `tqdm.tqdm`; calling the module itself raises
# "TypeError: 'module' object is not callable".
for patient_id in tqdm.tqdm(unique_epilepsy_patient_ids):

    # find all edf files corresponding to this patient id
    patients_edf_file = f"..\\tuh_eeg_epilepsy\\edf\\*epilepsy\\*\\*\\{patient_id}\\*\\{patient_id}_*.edf"
    patient_edf_file_list = glob(patients_edf_file)
    if not patient_edf_file_list:
        raise ValueError(f"no edf files found for patient {patient_id!r}")

    # the label is the 4th path component ("epilepsy" / "no_epilepsy");
    # every recording of one subject must carry the same label.
    # the label of the recording is copied to each of its windows
    labels = [path.split("\\")[3] for path in patient_edf_file_list]
    if len(set(labels)) != 1:
        raise ValueError(
            f"patient {patient_id!r} has inconsistent labels: {sorted(set(labels))}"
        )

    label = labels[0]
    label_count[label] += 1

    # keep only the first file per patient
    raw_file_path = patient_edf_file_list[0]
    raw_data = mne.io.read_raw_edf(raw_file_path, verbose=False, preload=False)

    # Window indices are expressed at SAMPLING_FREQ (the rate the data is
    # resampled to later), whereas raw_data.n_times counts samples at the
    # file's own rate. Convert the recording length to the target rate before
    # checking whether a window fits; rounding down keeps the bound
    # conservative.
    n_samples_at_target_freq = int(
        raw_data.n_times * SAMPLING_FREQ / raw_data.info["sfreq"]
    )
    last_valid_start = n_samples_at_target_freq - WINDOW_LENGTH_SAMPLES

    # values shared by every window of this recording
    record_length_seconds = raw_data.times[-1]
    channel_config = raw_file_path.split("\\")[4]
    numeric_label = 0 if label == "no_epilepsy" else 1

    # generate window metadata = one row of dataset_index.
    # only complete, non-overlapping windows are emitted: the range is empty
    # when the recording is shorter than one window.
    for start_sample_index in range(0, last_valid_start + 1, WINDOW_LENGTH_SAMPLES):
        dataset_index_rows.append({
            "patient_id": patient_id,
            "raw_file_path": raw_file_path,
            "record_length_seconds": record_length_seconds,
            # this is the desired SFREQ using which sample indices are derived.
            # this is not the original SFREQ at which the data is recorded.
            "sampling_freq": SAMPLING_FREQ,
            "channel_config": channel_config,
            "start_sample_index": start_sample_index,
            # index of the last sample that belongs to the window (inclusive)
            "end_sample_index": start_sample_index + (WINDOW_LENGTH_SAMPLES - 1),
            "text_label": label,
            "numeric_label": numeric_label,
        })
        
# assemble the window index table from the collected rows and write it to a
# CSV file whose name carries the window length in whole seconds
index_columns = [
    "patient_id",
    "raw_file_path",
    "record_length_seconds",
    "sampling_freq",
    "channel_config",
    "start_sample_index",
    "end_sample_index",
    "text_label",
    "numeric_label",
]
df = pd.DataFrame(data=dataset_index_rows, columns=index_columns)

index_csv_name = f"epilepsy_corpus_window_index_{str(int(WINDOW_LENGTH_SECONDS))}s.csv"
df.to_csv(index_csv_name, index=False)

# functions for preprocessing the eeg data
def standardize_sensors(raw_data):
    """Rename the channels of a recording and keep only the 19 shared ones.

    The TUEP database has 3 EEG channel configurations ('02_tcp_le',
    '03_tcp_ar_a', '01_tcp_ar') whose channel counts and channel names differ.
    To make recordings comparable, every channel is renamed to its short
    label and only the channels listed in ``channels_to_use`` are kept, in
    that order.

    Parameters
    ----------
    raw_data
        Recording object exposing ``rename_channels``, ``pick_channels`` and
        ``info["ch_names"]`` (presumably an ``mne.io.Raw`` - confirm against
        the caller). The renaming is written into this in-memory object.

    Returns
    -------
    The object returned by ``raw_data.pick_channels``.

    Notes
    -----
    A message is printed when not all required channels are present after
    picking; no exception is raised by this function for that case.
    NOTE(review): depending on the MNE version, ``pick_channels`` with
    ``ordered=True`` may itself raise on a missing channel - confirm.
    """
    # the list of 19 channels (their short labels) that we will use for
    # analysing EEG data
    channels_to_use = ["FP1", "FP2", "F7", "F3", "FZ", "F4", "F8",
                       "T3", "C3", "CZ", "C4", "T4", "T5",
                       "P3", "PZ", "P4", "T6", "O1", "O2"]

    def short_channel_name(ch):
        # take the text after the last space and drop everything from the
        # first '-' on, e.g. "EEG FP1-REF" -> "FP1"
        return ch.split(' ')[-1].split('-')[0]

    # renaming the original channel names in one .edf file;
    # the update will be written into the in-memory edf object
    raw_data.rename_channels(mapping=short_channel_name)

    raw_data = raw_data.pick_channels(channels_to_use, ordered=True)

    # check if all required channels are in the edf file; an explicit test is
    # used instead of assert + bare except so the warning is not lost under
    # `python -O` and unrelated errors are not swallowed
    available = raw_data.info["ch_names"]
    missing_channels = [ch for ch in channels_to_use if ch not in available]
    if missing_channels:
        print('Not all required channels are in the edf file.')

    return raw_data

def downsample(raw_data, freq=250):
    """Resample ``raw_data`` to ``freq`` and report the rate that was used.

    Calls ``raw_data.resample`` with ``sfreq=freq``, ``n_jobs=-2`` and
    ``verbose=False``.
    NOTE(review): ``n_jobs=-2`` presumably follows the joblib convention
    ("all CPUs but one") - confirm against the MNE documentation.

    Returns
    -------
    tuple
        ``(resampled, freq)`` where ``resampled`` is whatever
        ``raw_data.resample`` returned.
    """
    resample_options = {"sfreq": freq, "n_jobs": -2, "verbose": False}
    resampled = raw_data.resample(**resample_options)
    return resampled, freq


# compute mi and correlation matrices
index_df = pd.read_csv("epilepsy_corpus_window_index_60s.csv")

# we will use only up to 10 EEG windows per patient
n_windows_per_patient = 10

# Keep the first `n_windows_per_patient` rows of every patient.
# groupby(...).head(n) compares the IDs as stored values, so it works whether
# read_csv parsed them as integers or as strings (building a query string from
# the raw ID only works for numeric IDs), and it avoids re-concatenating the
# growing frame once per patient. Rows keep their original relative order.
reduced_index_df = index_df.groupby("patient_id", sort=False).head(n_windows_per_patient)

# reset_index() keeps the old row labels in an "index" column and gives the
# reduced table a fresh 0..n-1 index
reduced_index_df = reduced_index_df.reset_index()
reduced_index_df.to_csv('reduced_epilepsy_corpus_window_index_60s.csv', index=False)
grouped_df = reduced_index_df.groupby("raw_file_path")

num_channels = 19

# one row per window; each row holds a flattened
# (num_channels x num_channels) matrix
n_windows = reduced_index_df.shape[0]
mutual_info_matrix = np.zeros((n_windows, num_channels**2))
correlation_matrix = np.zeros((n_windows, num_channels**2))

# Raw window data is stored by window index rather than appended: the groupby
# below visits files in sorted-path order, which generally differs from the row
# order of reduced_index_df that the label array and the two feature matrices
# follow. Appending would save windows that are misaligned with their labels.
all_window_data = [None] * n_windows

# open up one raw_file at a time.
# (`tqdm` is imported as a module, so the progress bar is `tqdm.tqdm`.)
for raw_file_path, group_df in tqdm.tqdm(grouped_df):

    raw_data = mne.io.read_raw_edf(raw_file_path, preload=True, verbose=False)
    raw_data = standardize_sensors(raw_data)
    raw_data, sfreq = downsample(raw_data, SAMPLING_FREQ)

    # data is ready for feature extraction, loop over windows, extract features
    for window_idx in group_df.index:

        # get raw data for the window.
        # end_sample_index is the last sample of the window (inclusive) while
        # get_data() treats `stop` as exclusive, hence the + 1.
        start_sample = int(group_df.at[window_idx, "start_sample_index"])
        stop_sample = int(group_df.at[window_idx, "end_sample_index"]) + 1
        window_data = raw_data.get_data(start=start_sample, stop=stop_sample)
        all_window_data[window_idx] = window_data

        # columns = channels, rows = samples
        window_df = pd.DataFrame(window_data.T)
        corr_matrix = window_df.corr().values
        correlation_matrix[window_idx] = corr_matrix.reshape(1, num_channels**2)

        # NOTE: compute_mi_matrix / compute_normed_mi_matrix are not defined in
        # this section of the file; they must be in scope before this loop runs.
        mi_matrix = compute_mi_matrix(window_df)
        normed_mi_matrix = compute_normed_mi_matrix(mi_matrix)
        mutual_info_matrix[window_idx] = normed_mi_matrix.reshape(1, num_channels**2)

# save the features and labels as numpy array to disk
np.save("reduced_X_windows_epilepsy_corpus_60s.npy", np.array(all_window_data))
np.save("reduced_X_normed_mi_epilepsy_corpus_60s.npy", mutual_info_matrix)
np.save("reduced_X_correlation_epilepsy_corpus_60s.npy", correlation_matrix)
np.save("reduced_y_epilepsy_corpus_60s.npy", reduced_index_df["text_label"].to_numpy())
