In [None]:
import mne
import numpy as np
import pandas as pd
from collections import OrderedDict
from mne.time_frequency import psd_array_welch

# Helper functions: sensor standardization (channel selection and ordering), high-pass filtering and line-noise removal. The two filtering helpers are not called in the cells shown here and are kept for future use.

In [None]:
def standardize_sensors(raw_data):
	"""Restrict a recording to a fixed set of 21 sensors in a fixed order.

	Parameters
	----------
	raw_data : object exposing ``pick_channels(ch_names, ordered=...)``
		In this notebook it is the object returned by ``mne.io.read_raw_edf``.

	Returns
	-------
	Whatever ``raw_data.pick_channels`` returns for the sensor list below,
	requested with ``ordered=True``.
	"""
	# Left/right pairs first, then midline sites, then the A1/A2 sensors.
	wanted_sensors = [
		'FP1', 'FP2',
		'F3', 'F4',
		'C3', 'C4',
		'P3', 'P4',
		'O1', 'O2',
		'F7', 'F8',
		'T3', 'T4',
		'T5', 'T6',
		'FZ', 'PZ', 'CZ',
		'A1', 'A2',
	]
	return raw_data.pick_channels(wanted_sensors, ordered=True)


def highpass(raw_data, cutoff=1.0, sfreq=200.0):
	"""High-pass filter a recording and return the filtered result.

	The previous implementation called ``mne.filter.filter_data`` and threw
	away its return value, so the caller always got the input back unfiltered.

	Parameters
	----------
	raw_data : object with a ``filter`` method, or array
		If the object exposes ``filter`` (as MNE Raw objects do), that method is
		used; otherwise the input is handed to ``mne.filter.filter_data`` as an
		array of samples.
	cutoff : float
		Lower pass-band edge in Hz (passed as ``l_freq``); no upper edge is set.
	sfreq : float
		Sampling frequency in Hz. Only used on the array path; 200.0 matches
		the value that was hard-coded before.

	Returns
	-------
	The return value of ``raw_data.filter(...)`` for objects with a ``filter``
	method, otherwise the array returned by ``mne.filter.filter_data``.
	"""
	if hasattr(raw_data, "filter"):
		# Raw-style object: it knows its own sampling rate, so sfreq is not needed.
		return raw_data.filter(l_freq=cutoff, h_freq=None)

	# Array path: filter_data hands back the filtered data, so return it.
	return mne.filter.filter_data(raw_data, sfreq=sfreq, l_freq=cutoff, h_freq=None)


def remove_line_noise(raw_data, ac_freqs=np.arange(50, 101, 50), sfreq=200.0):
	"""Notch-filter power-line frequencies and return the filtered result.

	The previous implementation called ``mne.filter.notch_filter`` without the
	sampling-rate argument that function requires and ignored its return value,
	so it could not produce filtered data.

	Parameters
	----------
	raw_data : object with a ``notch_filter`` method, or array
		If the object exposes ``notch_filter`` (as MNE Raw objects do), that
		method is used with ``picks="eeg"``; otherwise the input is passed to
		``mne.filter.notch_filter`` as an array of samples.
	ac_freqs : array-like
		Frequencies in Hz to notch out. The default evaluates to ``[50, 100]``.
	sfreq : float
		Sampling frequency in Hz; only used on the array path.

	Returns
	-------
	The return value of ``raw_data.notch_filter(...)`` for objects with that
	method, otherwise the array returned by ``mne.filter.notch_filter``.
	"""
	# NOTE(review): the default list contains 100 Hz, which equals the Nyquist
	# frequency of a 200 Hz recording — confirm MNE accepts it for this data,
	# or pass ac_freqs explicitly.
	if hasattr(raw_data, "notch_filter"):
		return raw_data.notch_filter(freqs=ac_freqs, picks="eeg", verbose=False)

	# Array path: sampling rate and frequencies are passed positionally; no
	# channel-type picks here because a bare array carries no channel info.
	return mne.filter.notch_filter(raw_data, sfreq, ac_freqs, verbose=False)


# PSD extraction based on standard brain rhythm bands

In [None]:
def get_brain_waves_power(psd_welch, freqs):
	"""Sum PSD values inside six fixed frequency bands, one sum per row of the PSD.

	Parameters
	----------
	psd_welch : ndarray, 2-D
		PSD values; the second axis is indexed by frequency bin.
	freqs : ndarray, 1-D
		Frequency in Hz of each bin along the second axis of ``psd_welch``.

	Returns
	-------
	ndarray of shape ``(psd_welch.shape[0], 6)``
		Column order: delta, theta, alpha, lower_beta, higher_beta, gamma.

	Notes
	-----
	Band edges are inclusive on both sides, so a bin lying exactly on an edge
	shared by two bands contributes to both. The first band (delta) has no
	lower bound applied: every bin at or below its upper edge is included.
	"""
	# (name, low edge in Hz, high edge in Hz)
	bands = (
		("delta", 1.0, 4.0),
		("theta", 4.0, 7.5),
		("alpha", 7.5, 13.0),
		("lower_beta", 13.0, 16.0),
		("higher_beta", 16.0, 30.0),
		("gamma", 30.0, 40.0),
	)

	n_rows = psd_welch.shape[0]
	band_powers = np.zeros((n_rows, len(bands)))

	for column, (_name, low_edge, high_edge) in enumerate(bands):
		in_band = freqs <= high_edge
		if column > 0:
			# Only bands after the first are limited from below.
			in_band = (freqs >= low_edge) & in_band

		bin_indices = np.argwhere(in_band).ravel()
		band_powers[:, column] = np.sum(psd_welch[:, bin_indices], axis=1)

	return band_powers


# Loading the patient data file

In [None]:
# Fixed geometry of the feature matrix and the sampling rate used for the PSD.
SAMPLING_FREQ = 200.0
num_channels = 21
power_bands = 6

# Window index table; windows are grouped by the recording they are cut from
# so that each recording only has to be opened once.
index_df = pd.read_csv("data/nmt_data.csv")
grouped_df = index_df.groupby("raw_file_path")

# One feature row per row of the index table.
feature_matrix = np.zeros((len(index_df), num_channels * power_bands))

# Extract features from the raw EDF data and save them in the data folder as `features.npy`; the labels are saved as `labels.npy`.

In [None]:
# For every recording, compute band-power features for each of its windows and
# write them into the matching row of `feature_matrix`.
for raw_file_path, group_df in grouped_df:
    # Load the recording once; all of its windows are sliced from this object.
    raw_data = mne.io.read_raw_edf(raw_file_path, verbose=False, preload=True)
    raw_data = standardize_sensors(raw_data)
    print(raw_file_path)

    # Read the sample bounds straight from their columns instead of building a
    # mixed-dtype row Series per window.
    window_bounds = zip(group_df.index,
                        group_df['start_sample_index'],
                        group_df['end_sample_index'])

    for window_idx, start_sample, stop_sample in window_bounds:
        # Cast to plain ints so the bounds are valid sample indices.
        window_data = raw_data.get_data(start=int(start_sample), stop=int(stop_sample))

        psd_welch, freqs = psd_array_welch(window_data, sfreq=SAMPLING_FREQ, fmax=50.0, n_per_seg=150,
                                           average='mean', verbose=False)
        # PSD is converted to decibels before the per-band sums are taken.
        psd_welch = 10 * np.log10(psd_welch)
        band_powers = get_brain_waves_power(psd_welch, freqs)

        # `window_idx` is the row label from `index_df`; it is used directly as
        # the row position, which relies on the default 0..n-1 index.
        feature_matrix[window_idx, :] = band_powers.flatten()

np.save("data/features.npy", feature_matrix)
np.save("data/labels.npy", index_df["text_label"].to_numpy())