# Import Required Libraries

In [None]:
import os
from collections import OrderedDict

import mne
import numpy as np
import pandas as pd
from glob2 import glob
from mne.time_frequency import psd_array_welch

## Load the normal (non-epilepsy) and abnormal (epilepsy) EDF files from the NMT dataset

In [None]:
# Collect the EDF recordings of the training split, one list per class.
normal_edf_files, abnormal_edf_files = (
    glob('Dataset/normal/train/*.edf'),
    glob('Dataset/abnormal/train/*.edf'),
)

# Extract the normal and abnormal patient IDs and combine them into a single list of all patient IDs

In [None]:
# A patient ID is the EDF file name without its directory and extension.
# os.path handles both path separators and file names that contain extra
# dots; sorted() makes the processing order (and therefore the row order of
# the generated index) reproducible between runs, which list(set(...)) is not.
normal_patient_id = sorted(
    {os.path.splitext(os.path.basename(path))[0] for path in normal_edf_files}
)
abnormal_patient_id = sorted(
    {os.path.splitext(os.path.basename(path))[0] for path in abnormal_edf_files}
)
# Abnormal recordings first, then normal ones.
all_patient_id = abnormal_patient_id + normal_patient_id

# The sampling frequency is 200.0 Hz according to the NMT dataset. The window length is set by `WINDOW_LENGTH_SECONDS` (20 seconds in the code below)

In [1]:
# Parameters used to cut every recording into fixed-size windows.
SAMPLING_FREQ = 200.0  # Hz
WINDOW_LENGTH_SECONDS = 20.0
# Number of samples covered by one window at SAMPLING_FREQ.
WINDOW_LENGTH_SAMPLES = int(SAMPLING_FREQ * WINDOW_LENGTH_SECONDS)

# Accumulates one dict per indexed window; filled by the indexing loop.
dataset_index_rows = list()


# Information from the EDF files is extracted, labeled and saved as a list of dictionaries (one per window)

In [None]:
# Build the window index: one row for every full, non-overlapping window of
# every recording.

# Map each EDF file name to the path that glob actually returned, so the file
# that gets opened is the one the patient ID was derived from (the files live
# in the train/ sub-directory).
normal_path_by_name = {os.path.basename(path): path for path in normal_edf_files}
abnormal_path_by_name = {os.path.basename(path): path for path in abnormal_edf_files}
# Set for O(1) membership tests inside the loop.
normal_ids = set(normal_patient_id)

for idx, patient in enumerate(all_patient_id):
    print(f"{patient}: {idx+1}/{len(all_patient_id)}\n\n")
    if patient in normal_ids:
        label = 'no_epilepsy'
        raw_file_path = normal_path_by_name[f'{patient}.edf']
    else:
        # NOTE(review): 'epilespy' looks like a typo of 'epilepsy'. It is kept
        # unchanged because the string is written to the index file and any
        # consumer may compare against it - confirm before renaming.
        label = 'epilespy'
        raw_file_path = abnormal_path_by_name[f'{patient}.edf']
    print(label)
    print(raw_file_path)

    # preload=False: only the header is needed here, not the signal itself.
    raw_data = mne.io.read_raw_edf(
        raw_file_path, verbose=False, preload=False)
    n_samples = raw_data.n_times
    record_length_seconds = raw_data.times[-1]

    # Only start positions that leave room for a complete window are visited,
    # so the last (inclusive) end index is at most n_samples - 1.
    # Assumes the file is recorded at SAMPLING_FREQ - TODO confirm for any
    # recording with a different native rate.
    for start_sample_index in range(
            0, n_samples - WINDOW_LENGTH_SAMPLES + 1, WINDOW_LENGTH_SAMPLES):
        # Inclusive index of the last sample of the window.
        end_sample_index = start_sample_index + (WINDOW_LENGTH_SAMPLES - 1)

        # One row per window (this must stay inside the window loop).
        dataset_index_rows.append({
            "patient_id": patient,
            "raw_file_path": raw_file_path,
            "record_length_seconds": record_length_seconds,
            # this is the desired SFREQ using which sample indices are derived.
            # CAUTION - this is not the original SFREQ at which the data is recorded.
            "sampling_freq": SAMPLING_FREQ,
            "channel_config": '02_tcp_le',
            "start_sample_index": start_sample_index,
            "end_sample_index": end_sample_index,
            "text_label": label,
            "numeric_label": 0 if label == "no_epilepsy" else 1,
        })


# Patient information is saved as `nmt_data.csv` under data folder

In [None]:
# Turn the collected window rows into a table with a fixed column order and
# persist it as CSV.
index_columns = [
    "patient_id",
    "raw_file_path",
    "record_length_seconds",
    "sampling_freq",
    "channel_config",
    "start_sample_index",
    "end_sample_index",
    "text_label",
    "numeric_label",
]
df = pd.DataFrame(dataset_index_rows, columns=index_columns)

# The feature-extraction step reads "data/nmt_data.csv", so the index is
# written there (the directory is created on first use).
os.makedirs("data", exist_ok=True)
df.to_csv("data/nmt_data.csv", index=False)


# Sensor standardization function (selects and orders the montage channels), plus high-pass and line-noise removal functions. The two filter functions are not called in the cells below and are kept for future use.

In [None]:
def standardize_sensors(raw_data):
    """Restrict a recording to the fixed 21-channel sensor set, in a fixed order.

    Args:
        raw_data: object exposing ``pick_channels(ch_names, ordered=...)``
            (the callers in this file pass an MNE Raw recording).

    Returns:
        Whatever ``raw_data.pick_channels`` returns for the sensor list below.
    """
    # Channel names in the order requested from pick_channels (ordered=True).
    wanted_channels = [
        'FP1', 'FP2',
        'F3', 'F4',
        'C3', 'C4',
        'P3', 'P4',
        'O1', 'O2',
        'F7', 'F8',
        'T3', 'T4',
        'T5', 'T6',
        'FZ', 'PZ', 'CZ',
        'A1', 'A2',
    ]
    return raw_data.pick_channels(wanted_channels, ordered=True)


def highpass(raw_data, cutoff=1.0, sfreq=200.0):
    """High-pass filter a recording or a plain sample array.

    The previous version called ``mne.filter.filter_data`` and dropped its
    return value; since that function filters a copy by default, the input
    came back unfiltered.

    Args:
        raw_data: either a NumPy array of samples or an object with a
            ``filter(l_freq=..., h_freq=...)`` method such as an MNE Raw
            recording (which must have its data loaded to be filtered).
        cutoff: lower pass-band edge in Hz.
        sfreq: sampling frequency in Hz; only used for NumPy array input,
            Raw-like objects carry their own sampling frequency.

    Returns:
        The filtered array for array input, otherwise ``raw_data`` itself
        after its ``filter`` method has been applied.
    """
    if isinstance(raw_data, np.ndarray):
        # Array-level API: the filtered data is the return value.
        return mne.filter.filter_data(
            raw_data, sfreq=sfreq, l_freq=cutoff, h_freq=None)

    # Object-level API: h_freq=None leaves the upper edge open (high-pass).
    raw_data.filter(l_freq=cutoff, h_freq=None)
    return raw_data


def remove_line_noise(raw_data, ac_freqs=np.arange(50, 101, 50), sfreq=200.0):
    """Notch-filter power-line interference out of a recording or sample array.

    The previous version called the array-level ``mne.filter.notch_filter``
    without its required sampling-frequency argument and dropped the result,
    so it could not filter anything.

    Args:
        raw_data: either a NumPy array of samples or an object with a
            ``notch_filter(freqs=..., picks=...)`` method such as an MNE Raw
            recording.
        ac_freqs: frequencies in Hz to notch out (50 and 100 by default).
            The default array is never modified here.
            NOTE(review): 100 Hz equals the Nyquist frequency of a 200 Hz
            recording - confirm the filter accepts it, or pass ``[50]``.
        sfreq: sampling frequency in Hz; only used for NumPy array input.

    Returns:
        The filtered array for array input, otherwise ``raw_data`` itself
        after its ``notch_filter`` method has been applied.
    """
    if isinstance(raw_data, np.ndarray):
        # Array-level API: needs the sampling rate and returns the result.
        return mne.filter.notch_filter(
            raw_data, Fs=sfreq, freqs=ac_freqs, verbose=False)

    # Object-level API: restrict the filter to the EEG channels.
    raw_data.notch_filter(freqs=ac_freqs, picks="eeg", verbose=False)
    return raw_data


# PSD extraction based on standard brain rhythm bands

In [None]:
def get_brain_waves_power(psd_welch, freqs):
    """Sum PSD values inside each of six brain-rhythm frequency bands.

    Args:
        psd_welch: 2-D array indexed as ``[row, frequency_bin]`` (the caller
            in this file passes one row per sensor).
        freqs: 1-D array with the frequency of each PSD bin.

    Returns:
        Array of shape ``(psd_welch.shape[0], 6)``; column order is delta,
        theta, alpha, lower beta, higher beta, gamma.
    """
    band_limits = OrderedDict([
        ("delta", (1.0, 4.0)),
        ("theta", (4.0, 7.5)),
        ("alpha", (7.5, 13.0)),
        ("lower_beta", (13.0, 16.0)),
        ("higher_beta", (16.0, 30.0)),
        ("gamma", (30.0, 40.0)),
    ])

    band_powers = np.zeros((psd_welch.shape[0], 6))

    for column, (low_edge, high_edge) in enumerate(band_limits.values()):
        # Both band edges are inclusive.
        in_band = freqs <= high_edge
        if column != 0:
            # The first band ignores its lower edge and takes every bin up
            # to its upper edge; all other bands are bounded on both sides.
            in_band = in_band & (freqs >= low_edge)

        # Total over the selected frequency bins, one value per row.
        band_powers[:, column] = psd_welch[:, in_band].sum(axis=1)

    return band_powers


# Loading the patient data file

In [None]:
# Load the window index table and pre-allocate the feature matrix.
SAMPLING_FREQ = 200.0  # Hz, handed to the Welch PSD estimator

# Feature layout: one value per (channel, band) pair for every window.
num_channels = 21
power_bands = 6

index_df = pd.read_csv("data/nmt_data.csv")

# One group per recording, so each EDF file is opened once for all its windows.
grouped_df = index_df.groupby("raw_file_path")

# One row per indexed window.
feature_matrix = np.zeros((len(index_df), num_channels * power_bands))

# Extract features from the raw EDF data and store them in the data folder as `features.npy`; the labels are stored as `labels.npy`

In [None]:
# Compute the band-power features of every indexed window, recording by
# recording, then store the features and their labels.
for raw_file_path, group_df in grouped_df:
    raw_data = mne.io.read_raw_edf(raw_file_path, verbose=False, preload=True)
    raw_data = standardize_sensors(raw_data)
    print(raw_file_path)

    # Plain (index, start, end) tuples; avoids building a mixed-dtype row
    # Series twice per window as the chained .loc[...][...] lookup did.
    window_bounds = group_df[['start_sample_index', 'end_sample_index']]
    for window_idx, start_sample, end_sample in window_bounds.itertuples(
            index=True, name=None):
        # end_sample_index is the inclusive last sample of the window, while
        # get_data's stop is exclusive: +1 keeps the final sample. The stop
        # is clamped so an index at the very end of the file stays in range.
        stop_sample = min(int(end_sample) + 1, raw_data.n_times)
        window_data = raw_data.get_data(start=int(start_sample), stop=stop_sample)

        psd_welch, freqs = psd_array_welch(window_data, sfreq=SAMPLING_FREQ, fmax=50.0, n_per_seg=150,
                                           average='mean', verbose=False)
        # Convert to decibels before the band sums are taken.
        # NOTE(review): the per-band sums are therefore sums of dB values,
        # not of linear power - confirm this is the intended feature.
        psd_welch = 10 * np.log10(psd_welch)
        band_powers = get_brain_waves_power(psd_welch, freqs)

        # window_idx is the row label of index_df and is used directly as the
        # row position in feature_matrix (holds for the default integer index
        # produced by read_csv).
        feature_matrix[window_idx, :] = band_powers.flatten()

np.save("data/features.npy", feature_matrix)
np.save("data/labels.npy", index_df["text_label"].to_numpy())