# Pre-processing pipeline

In [1]:
import mne
import numpy
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import concurrent.futures
import pywt
import scipy as sp
from sklearn.model_selection import StratifiedGroupKFold
import re

### configuration setup

In [2]:
mne.set_log_level('WARNING')

# Input corpus root (must contain an `edf/` sub-directory) and output directory.
# The hard-coded defaults are kept; set TUH_RAW_PATH / TUH_OUTPUT_PATH in the
# environment to run the notebook against a different checkout.
RAW_PATH = os.environ.get('TUH_RAW_PATH', '/home/nis/Git/tuh-eeg-seizure-detection/data/raw')
OUTPUT_PATH = os.environ.get('TUH_OUTPUT_PATH', '/home/nis/Git/tuh-eeg-seizure-detection/data/processed_fast')

SAMPLING_FREQ = 250  # Hz; every event segment is resampled to this rate before feature extraction
WINDOW_LENGTH = 30   # seconds per analysis window
OVERLAP = 10         # seconds shared by consecutive windows (window step = WINDOW_LENGTH - OVERLAP)
CONFIGURATIONS = ["01_tcp_ar"]  # configuration directory names (4th path level under edf/) to include
CHANNELS = ["EEG FP1-REF", "EEG FP2-REF", "EEG F7-REF", "EEG F3-REF", "EEG F4-REF", "EEG F8-REF", "EEG T3-REF", "EEG C3-REF", "EEG C4-REF", "EEG T4-REF", "EEG T5-REF", "EEG P3-REF", "EEG P4-REF", "EEG T6-REF", "EEG O1-REF", "EEG O2-REF", "EEG CZ-REF", "EEG A1-REF", "EEG A2-REF"]

In [3]:
def split_channels_to_hemispheres(channels: list):
    """Partition channel names into left- and right-hemisphere lists.

    The first run of digits in each name decides the side: an odd number goes
    to the left list, an even number to the right list. Names without any
    digit (e.g. "EEG CZ-REF") are placed in neither list. Input order is
    preserved within each list.

    Returns:
        A ``(left_hemisphere, right_hemisphere)`` tuple of lists.
    """
    digits = re.compile(r'\d+')
    left, right = [], []

    for name in channels:
        found = digits.search(name)
        if found is not None:
            side = right if int(found.group()) % 2 == 0 else left
            side.append(name)

    return left, right

# Channels without a digit (the midline CZ) belong to neither hemisphere list.
LEFT_HEMISPHERE, RIGHT_HEMISPHERE = split_channels_to_hemispheres(channels=CHANNELS)

### load raw data

#### extract events from annotations

In [4]:
def extract_events_from_annotations(annotation_file):
    """Parse a TUH ``.csv_bi`` annotation file into a list of event dicts.

    Expected layout: ``#``-prefixed metadata lines, a column header row whose
    first field is ``channel``, then data rows
    ``channel,start_time,stop_time,label[,...]``. Blank lines, comment lines
    and the header row are skipped, so the parser no longer depends on the
    preamble being exactly six lines long.

    Args:
        annotation_file: path to the annotation file.

    Returns:
        List of dicts, in file order, with keys ``label`` (str), ``onset``
        (float start_time) and ``duration`` (float, stop_time - start_time).
    """
    data = []
    with open(annotation_file, "r", encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            # Blank lines (e.g. a trailing newline) used to crash with IndexError.
            if not stripped or stripped.startswith("#"):
                continue

            parts = [field.strip() for field in stripped.split(",")]
            if parts[0] == "channel":  # column header row
                continue

            start_time = float(parts[1])
            stop_time = float(parts[2])
            data.append({
                "label": parts[3],
                "onset": start_time,
                "duration": stop_time - start_time,
            })

    return data

#### load the TUH EEG dataset

In [13]:
def load_tuh_eeg(raw_path=None, configurations=None):
    """Index every annotated event of the TUH EEG seizure corpus into a DataFrame.

    Walks ``<raw_path>/edf`` for recordings laid out as
    ``<set>/<patient_id>/<session_id>/<configuration>/<name>.edf`` with a
    sibling ``<name>.csv_bi`` annotation file. Directories at any other depth,
    configurations not in ``configurations``, and recordings without an
    annotation file are skipped.

    Args:
        raw_path: corpus root; defaults to the module-level ``RAW_PATH``.
        configurations: configuration directory names to keep; defaults to
            the module-level ``CONFIGURATIONS``.

    Returns:
        DataFrame with one row per annotation event and columns set,
        patient_id, session_id, configuration, recording_id, recording_path,
        label, onset, duration.
    """
    if raw_path is None:
        raw_path = RAW_PATH
    if configurations is None:
        configurations = CONFIGURATIONS

    cols = ["set", "patient_id", "session_id", "configuration", "recording_id", "recording_path", "label", "onset", "duration"]
    data = []

    edf_path = os.path.join(raw_path, "edf")
    for root, _dirs, files in os.walk(edf_path):
        # os.sep (not a literal "/") so the layout check also works on Windows.
        parts = os.path.relpath(root, edf_path).split(os.sep)
        if len(parts) != 4:
            continue

        set_name, patient_id, session_id, configuration = parts
        if configuration not in configurations:
            continue

        for file in files:
            stem, ext = os.path.splitext(file)
            if ext != ".edf":
                continue

            recording_path = os.path.join(root, file)
            # Build the annotation path from the stem instead of str.replace on
            # the full path, which would also rewrite ".edf" inside directory names.
            annotation_path = os.path.join(root, stem + ".csv_bi")
            if not os.path.exists(annotation_path):
                continue

            recording_id = stem.split("_")[-1]
            for event in extract_events_from_annotations(annotation_path):
                data.append({
                    "set": set_name,
                    "patient_id": patient_id,
                    "session_id": session_id,
                    "configuration": configuration,
                    "recording_id": recording_id,
                    "recording_path": recording_path,
                    "label": event["label"],
                    "onset": event["onset"],
                    "duration": event["duration"],
                })

    return pd.DataFrame(data, columns=cols)

# One row per annotated event across every matching recording; show the last few.
data = load_tuh_eeg()
data.tail(n=5)

Unnamed: 0,set,patient_id,session_id,configuration,recording_id,recording_path,label,onset,duration
66,train,aaaaajrj,s006_2012,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,bckg,0.0,300.0
67,train,aaaaajrj,s006_2012,01_tcp_ar,t004,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,bckg,0.0,601.0
68,train,aaaaajrj,s006_2012,01_tcp_ar,t008,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,bckg,0.0,1224.0
69,train,aaaaajrj,s006_2012,01_tcp_ar,t006,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,seiz,2352.0111,35.9867
70,train,aaaaajrj,s006_2012,01_tcp_ar,t005,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,bckg,0.0,300.0


In [33]:
def load_windows(tuh_eeg_seizure_corpus_path: str, window_length: int, overlap: int, configurations: list):
    """Slice every annotated event of the TUH EEG seizure corpus into overlapping windows.

    Walks ``<tuh_eeg_seizure_corpus_path>/edf`` for recordings laid out as
    ``<set>/<patient_id>/<session_id>/<configuration>/<name>.edf`` with a
    sibling ``<name>.csv_bi`` annotation file. Each event is covered by windows
    of ``window_length`` seconds whose onsets are ``onset + k * (window_length - overlap)``.
    A window is emitted only while it ends strictly before the event's stop
    time, so events no longer than one window yield nothing.

    Args:
        tuh_eeg_seizure_corpus_path: corpus root containing ``edf/``.
        window_length: window length in seconds.
        overlap: seconds shared by consecutive windows.
        configurations: configuration directory names to keep.

    Returns:
        DataFrame with columns set, patient_id, session_id, configuration,
        recording_id, recording_path, event_index, onset, stop, label.

    Raises:
        ValueError: if ``overlap >= window_length`` (the window would never advance).
    """
    step = window_length - overlap
    if step <= 0:
        raise ValueError(f"overlap ({overlap}) must be smaller than window_length ({window_length})")

    cols = ["set", "patient_id", "session_id", "configuration", "recording_id", "recording_path", "event_index", "onset", "stop", "label"]
    data = []

    # Use the arguments, not the module-level RAW_PATH/WINDOW_LENGTH/OVERLAP/CONFIGURATIONS.
    edf_path = os.path.join(tuh_eeg_seizure_corpus_path, "edf")

    for root, _dirs, files in os.walk(edf_path):
        parts = os.path.relpath(root, edf_path).split(os.sep)
        if len(parts) != 4:
            continue

        set_name, patient_id, session_id, configuration = parts
        if configuration not in configurations:
            continue

        for file in files:
            stem, ext = os.path.splitext(file)
            if ext != ".edf":
                continue

            recording_path = os.path.join(root, file)
            annotation_path = os.path.join(root, stem + ".csv_bi")
            if not os.path.exists(annotation_path):
                continue

            recording_id = stem.split("_")[-1]
            events = extract_events_from_annotations(annotation_path)

            for event_index, event in enumerate(events):
                event_onset = event["onset"]
                event_stop = event["onset"] + event["duration"]

                # Compute each onset from k instead of repeated += to avoid float
                # drift. The strict "<" keeps the original window count, which the
                # driver cell compares against the epochs produced downstream.
                k = 0
                while event_onset + k * step + window_length < event_stop:
                    window_onset = event_onset + k * step
                    data.append({
                        "set": set_name,
                        "patient_id": patient_id,
                        "session_id": session_id,
                        "configuration": configuration,
                        "recording_id": recording_id,
                        "recording_path": recording_path,
                        "event_index": event_index,
                        "onset": window_onset,
                        "stop": window_onset + window_length,
                        "label": event["label"],
                    })
                    k += 1

    return pd.DataFrame(data, columns=cols)

# Overlapping windows for every event; displayed as the cell result.
windows = load_windows(RAW_PATH, window_length=WINDOW_LENGTH, overlap=OVERLAP, configurations=CONFIGURATIONS)
windows
            

Unnamed: 0,set,patient_id,session_id,configuration,recording_id,recording_path,event_index,onset,stop,label
0,train,aaaaaozv,s001_2013,01_tcp_ar,t001,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,0.0,30.0,bckg
1,train,aaaaaozv,s001_2013,01_tcp_ar,t001,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,20.0,50.0,bckg
2,train,aaaaaozv,s001_2013,01_tcp_ar,t001,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,40.0,70.0,bckg
3,train,aaaaaozv,s001_2013,01_tcp_ar,t001,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,60.0,90.0,bckg
4,train,aaaaaozv,s001_2013,01_tcp_ar,t001,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,80.0,110.0,bckg
...,...,...,...,...,...,...,...,...,...,...
1521,train,aaaaajrj,s006_2012,01_tcp_ar,t005,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,180.0,210.0,bckg
1522,train,aaaaajrj,s006_2012,01_tcp_ar,t005,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,200.0,230.0,bckg
1523,train,aaaaajrj,s006_2012,01_tcp_ar,t005,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,220.0,250.0,bckg
1524,train,aaaaajrj,s006_2012,01_tcp_ar,t005,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,240.0,270.0,bckg


### train/test split

In [None]:
# Cut every annotated event into non-overlapping WINDOW_LENGTH-second windows;
# events shorter than one window contribute nothing.
# Stored as `split_windows` (not `windows`) so running this cell does not
# overwrite the overlapping-window table from `load_windows`, which the
# feature-extraction cell groups by session_id / recording_id / event_index.
split_rows = []
for _, event in data.iterrows():
    num_windows = int(event["duration"] / WINDOW_LENGTH)
    for k in range(num_windows):
        split_rows.append({
            "patient_id": event["patient_id"],
            "onset": event["onset"] + k * WINDOW_LENGTH,
            "duration": WINDOW_LENGTH,
            "label": event["label"],
        })

split_windows = pd.DataFrame(split_rows, columns=["patient_id", "onset", "duration", "label"])
x = np.array(split_windows["duration"])
y = np.array(split_windows["label"])
groups = np.array(split_windows["patient_id"])

# StratifiedGroupKFold rejects n_splits larger than the number of groups
# (patients), so cap it; at least two patients are needed for a held-out fold.
n_patients = len(np.unique(groups))
n_splits = min(5, n_patients)
if n_splits < 2:
    raise ValueError(f"Need at least 2 patients for a patient-grouped split, found {n_patients}")

cv = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
splits = list(cv.split(x, y, groups))

# Seeded so the held-out fold is as reproducible as the fold assignment itself.
test_fold_idx = int(np.random.default_rng(42).integers(len(splits)))

train_idx, test_idx = splits[test_fold_idx]

x_train, y_train = x[train_idx], y[train_idx]
x_test, y_test = x[test_idx], y[test_idx]

# Class balance per split as seiz:bckg; NaN when a split has no background windows
# (indexing counts[1] / counts[0] crashed when a fold held only one class).
unique, counts = np.unique(y_train, return_counts=True)
train_counts = dict(zip(unique, counts))
train_ratio = train_counts.get("seiz", 0) / train_counts["bckg"] if train_counts.get("bckg") else np.nan
print(f"Train - Unique: {unique} Counts: {counts}")
print(f"Train ratio: {train_ratio}")

unique, counts = np.unique(y_test, return_counts=True)
test_counts = dict(zip(unique, counts))
test_ratio = test_counts.get("seiz", 0) / test_counts["bckg"] if test_counts.get("bckg") else np.nan
print(f"Test - Unique: {unique} Counts: {counts}")
print(f"Test ratio: {test_ratio}")


### feature extraction

In [38]:
def calc_coeffs_features(coeffs):
    """Summarise one array of wavelet coefficients with seven statistics.

    ``coeffs`` must support element-wise ``** 2`` (e.g. a numpy array).
    Returns a dict with keys, in this order: mean, median, variance, std,
    skew, kurtosis (scipy's default Fisher/excess definition) and rms.
    """
    features = {}
    features["mean"] = np.mean(coeffs)
    features["median"] = np.median(coeffs)
    features["variance"] = np.var(coeffs)
    features["std"] = np.std(coeffs)
    features["skew"] = sp.stats.skew(coeffs)
    features["kurtosis"] = sp.stats.kurtosis(coeffs)
    features["rms"] = np.sqrt(np.mean(coeffs ** 2))
    return features
    
def extract_wavelet_features(channel_data: np.ndarray) -> dict[str, float]:
    """Statistics of a 5-level 'db4' discrete wavelet decomposition of one signal.

    Only the approximation ``a5`` and the details ``d5``, ``d4`` and ``d3`` are
    summarised; the two finest detail levels (``d2``, ``d1``) are discarded.
    Keys have the form ``"<coeff>_<stat>"`` (e.g. ``"a5_mean"``, ``"d3_rms"``),
    grouped by coefficient in the order a5, d5, d4, d3.
    """
    decomposition = pywt.wavedec(channel_data, 'db4', level=5)
    kept = dict(zip(("a5", "d5", "d4", "d3"), decomposition[:4]))

    features = {}
    for coeff_name, coeff_values in kept.items():
        for stat_name, stat_value in calc_coeffs_features(coeff_values).items():
            features[f"{coeff_name}_{stat_name}"] = stat_value
    return features

In [39]:
def extract_band_power(channel_data, sfreq=SAMPLING_FREQ, n_fft=256) -> dict[str, float]:
    """Sum Welch PSD bins inside the delta/theta/alpha/beta/gamma bands.

    Args:
        channel_data: signal samples. The PSD is taken along the last axis, so a
            1-D array gives scalar band powers and a 2-D ``(n_signals, n_times)``
            array gives one value per row.
        sfreq: sampling frequency of ``channel_data`` in Hz.
        n_fft: FFT length for Welch's method, capped at the signal length.

    Returns:
        Dict band name -> summed PSD values of the bins in that band. The sum is
        not multiplied by the frequency resolution, so values are relative
        rather than absolute power.
    """
    frequency_bands = {
        "delta": (0.5, 4),
        "theta": (4, 7),
        "alpha": (7, 12),
        "beta": (12, 30),
        "gamma": (30, 50)
    }

    channel_data = np.asarray(channel_data)
    n_fft = min(n_fft, channel_data.shape[-1])
    psds, freqs = mne.time_frequency.psd_array_welch(channel_data, sfreq=sfreq, n_fft=n_fft, fmin=0.5, fmax=50)

    top_band = list(frequency_bands)[-1]
    band_powers = {}
    for band, (fmin, fmax) in frequency_bands.items():
        # Half-open [fmin, fmax): a bin lying exactly on a shared edge (e.g. 4 Hz
        # at 1 Hz resolution) was previously counted in two bands. The top band
        # keeps its upper edge because the PSD itself stops at fmax=50.
        upper = freqs <= fmax if band == top_band else freqs < fmax
        in_band = (freqs >= fmin) & upper

        # Index the frequency (last) axis so multi-signal input also works.
        band_powers[band] = np.sum(psds[..., in_band], axis=-1)

    return band_powers


In [40]:
def calc_power_ratios(band_powers: dict[str, float]) -> dict[str, float]:
    """Derive six band-power ratios from a band -> power mapping.

    Reads the "alpha", "beta" and "theta" entries. A ratio whose denominator
    is exactly 0 is reported as NaN instead of raising or producing inf.
    """
    def ratio(numerator, denominator):
        return np.nan if denominator == 0 else numerator / denominator

    alpha = band_powers["alpha"]
    beta = band_powers["beta"]
    theta = band_powers["theta"]

    return {
        "alpha_beta_ratio": ratio(alpha, beta),
        "theta_beta_ratio": ratio(theta, beta),
        "theta_alpha_beta_ratio": ratio(theta + alpha, beta),
        "theta_alpha_beta_alpha_ratio": ratio(theta + alpha, beta + alpha),
        "alpha_theta_ratio": ratio(alpha, theta),
        "theta_alpha_ratio": ratio(theta, alpha),
    }

In [41]:
def calc_avg_band_powers(band_powers):
    """Average each band's power across a list of per-channel band-power dicts.

    Band names (and their order) come from the first dict; every dict must
    contain those bands. An empty list raises IndexError.
    """
    reference_bands = band_powers[0]
    return {
        band: np.mean([powers[band] for powers in band_powers])
        for band in reference_bands
    }

In [42]:
def calc_asymmetry(band_powers, channels=None, left_hemisphere=None, right_hemisphere=None):
    """Return log(total left power) - log(total right power).

    ``band_powers[i]`` is a dict band -> power for channel ``channels[i]``. A
    side's total is the sum over all bands of all its channels; channels on
    neither side (e.g. midline) are ignored. A side whose total is exactly 0
    contributes 0 instead of log(0).

    Args:
        band_powers: per-channel band-power dicts, positionally aligned with ``channels``.
        channels: channel names for ``band_powers``; defaults to ``CHANNELS``
            (the previous, implicit assumption).
        left_hemisphere: left-side channel names; defaults to ``LEFT_HEMISPHERE``.
        right_hemisphere: right-side channel names; defaults to ``RIGHT_HEMISPHERE``.

    Raises:
        ValueError: if ``band_powers`` and ``channels`` have different lengths,
            which means the positional alignment cannot be right.
    """
    if channels is None:
        channels = CHANNELS
    if left_hemisphere is None:
        left_hemisphere = LEFT_HEMISPHERE
    if right_hemisphere is None:
        right_hemisphere = RIGHT_HEMISPHERE

    if len(band_powers) != len(channels):
        raise ValueError(
            f"got {len(band_powers)} band-power dicts for {len(channels)} channels"
        )

    left_set = set(left_hemisphere)
    right_set = set(right_hemisphere)

    left_power = 0
    right_power = 0
    for channel, powers in zip(channels, band_powers):
        if channel in left_set:
            for power in powers.values():
                left_power += power
        elif channel in right_set:
            for power in powers.values():
                right_power += power

    left_log = np.log(left_power) if left_power != 0 else 0
    right_log = np.log(right_power) if right_power != 0 else 0

    return left_log - right_log


### pre-process the data

In [43]:
def remove_powerline_noise(raw):
    """Apply notch filters at 50 Hz and then 60 Hz mains frequencies.

    Calls ``raw.notch_filter`` once per frequency, in that order, and returns
    the same ``raw`` object.
    """
    raw.notch_filter(freqs=50)  # 50 Hz mains
    raw.notch_filter(freqs=60)  # 60 Hz mains
    return raw

In [None]:
def preprocess(patient_id):
    """Extract per-window features for every annotated event of one patient.

    Legacy entry point. It selects the patient's events from the module-level
    ``data`` table (one row per annotation, with ``onset``/``duration``) and
    keeps the original rule of skipping events shorter than one
    ``WINDOW_LENGTH`` window. It then delegates loading, cleaning and feature
    extraction to ``preprocess_recordings`` instead of carrying a second,
    near-identical copy of that code.

    Args:
        patient_id: value matched against ``data["patient_id"]``.

    Returns:
        ``(features, corrupted)``, exactly as returned by
        ``preprocess_recordings``: a DataFrame of per-epoch feature rows, and
        a list of ``(patient_id, session_id, recording_id)`` tuples for
        annotations that run past the end of their recording.
    """
    patient_events = data[data["patient_id"] == patient_id]

    # The original skipped events with int(duration / WINDOW_LENGTH) == 0.
    long_enough = (patient_events["duration"] / WINDOW_LENGTH).astype(int) > 0
    patient_events = patient_events[long_enough]

    # preprocess_recordings expects an explicit "stop" column.
    patient_events = patient_events.assign(stop=patient_events["onset"] + patient_events["duration"])

    return preprocess_recordings(patient_events)

In [44]:
def _crop_event_segment(raw, onset, stop):
    """Copy ``raw`` cropped to one event, or return None if the event overruns the data.

    Events ending inside the recording are cropped to ``[onset, stop)``. An
    event may end up to one sample period (plus 1e-6 s of float tolerance) past
    the last sample time ``raw.times[-1]``; it is then cropped up to and
    including that last sample. Anything further out is reported as corrupted.
    """
    last_sample_time = raw.times[-1]
    if stop <= last_sample_time:
        return raw.copy().crop(onset, stop, include_tmax=False)

    # An exact float comparison (stop - 1/sfreq == last sample time) rejects
    # valid annotations on rounding noise, so compare with a tolerance.
    if stop - last_sample_time <= 1 / raw.info["sfreq"] + 1e-6:
        return raw.copy().crop(onset, last_sample_time, include_tmax=True)

    return None


def _clean_segment(raw_event):
    """Resample to SAMPLING_FREQ, band-pass 0.5-50 Hz (4th-order Butterworth IIR), then notch mains."""
    raw_event.resample(SAMPLING_FREQ)
    raw_event.filter(0.5, 50, method='iir', iir_params=dict(order=4, ftype='butter'))
    return remove_powerline_noise(raw_event)


def _extract_epoch_features(epoch, channels, sfreq, label, patient_id):
    """Build one feature row from an epoch array indexed as ``epoch[channel_index]``.

    Key order: ``<channel>_<wavelet feature>`` columns, then asymmetry, label,
    patient_id, the averaged band powers and the power ratios.
    """
    epoch_features = {}
    band_powers = []

    for channel, channel_data in zip(channels, epoch):
        lowest = np.min(channel_data)
        span = np.max(channel_data) - lowest
        # Min-max scale to [0, 1]. A flat channel would divide by zero and
        # fill every feature with NaN, so map it to zeros instead.
        if span == 0:
            normalized = np.zeros_like(channel_data)
        else:
            normalized = (channel_data - lowest) / span

        for key, value in extract_wavelet_features(normalized).items():
            epoch_features[f"{channel}_{key}"] = value

        # Pass the epoch's real sampling rate rather than relying on the default.
        band_powers.append(extract_band_power(normalized, sfreq=sfreq))

    avg_band_powers = calc_avg_band_powers(band_powers)
    power_ratios = calc_power_ratios(avg_band_powers)

    # calc_asymmetry pairs band_powers with CHANNELS by position; the caller
    # reorders the recording's channels to CHANNELS so the pairing holds.
    epoch_features["asymmetry"] = calc_asymmetry(band_powers)
    epoch_features["label"] = label
    epoch_features["patient_id"] = patient_id
    return {**epoch_features, **avg_band_powers, **power_ratios}


def preprocess_recordings(patient: pd.DataFrame):
    """Load, clean, window and featurize every event span of one patient.

    ``patient`` needs one row per event span with the columns recording_path,
    patient_id, session_id, recording_id, onset, stop and label. For each
    recording, the CHANNELS are loaded and put in CHANNELS order. Each event is
    then cropped and cleaned (resample, band-pass, notch). Only after that is
    the segment cut into WINDOW_LENGTH-second epochs overlapping by OVERLAP
    seconds, so the features actually see the cleaned, resampled signal.
    Previously the epochs were copied before cleaning, and the cleaned data was
    thrown away.

    Returns:
        ``(features, corrupted)``: a DataFrame with one row per epoch, and a
        list of ``(patient_id, session_id, recording_id)`` tuples for
        annotations that run past the end of their recording.
    """
    corrupted = []
    windows = []

    for recording_path, events in patient.groupby("recording_path"):
        raw = mne.io.read_raw_edf(recording_path, preload=True).pick(picks=CHANNELS)
        try:
            # Fix the channel order so feature rows and the positional
            # asymmetry computation do not depend on the EDF's channel order.
            raw.reorder_channels(CHANNELS)

            # sometimes meas date is corrupted/missing
            raw.set_meas_date(None)

            onsets = events["onset"].values
            stops = events["stop"].values
            raw.set_annotations(mne.Annotations(onset=onsets, duration=stops - onsets, description=events["label"].values))

            for _, event in events.iterrows():
                patient_id = event["patient_id"]
                label = event["label"]

                raw_event = _crop_event_segment(raw, event["onset"], event["stop"])
                if raw_event is None:
                    print("Corrupted annotation", patient_id, event["session_id"], event["recording_id"])
                    corrupted.append((patient_id, event["session_id"], event["recording_id"]))
                    continue

                raw_event = _clean_segment(raw_event)

                # Epoch the *cleaned* segment.
                epochs = mne.make_fixed_length_epochs(raw_event, duration=WINDOW_LENGTH, overlap=OVERLAP, preload=True)
                channels = epochs.info["ch_names"]
                sfreq = epochs.info["sfreq"]

                for epoch in epochs:
                    windows.append(_extract_epoch_features(epoch, channels, sfreq, label, patient_id))

                raw_event.close()
        finally:
            raw.close()

    return pd.DataFrame(windows), corrupted

In [62]:
# group patient windows by recording path, aggregate onset to min and stop to max
patient_recordings = windows.groupby(["patient_id", "session_id", "recording_id", "recording_path", "event_index"]).agg({"onset": "min", "stop": "max", "label": "first"}).reset_index()
patient_ids = patient_recordings["patient_id"].unique()

for patient_id in patient_ids:

    p = patient_recordings[patient_recordings["patient_id"] == patient_id]

    print("Processing", patient_id)

    features, corrupted = preprocess_recordings(p)

    w = windows[windows["patient_id"] == patient_id]
    
    if features.shape[0] != w.shape[0]:
        print("Mismatch", features.shape[0], w.shape[0])
        
    if len(corrupted) > 0:
        print("Corrupted", corrupted)


Processing aaaaajrj
Processing aaaaamoq
Processing aaaaaozv


In [60]:
len(features)  # feature rows produced for the last processed patient

983

In [61]:
len(windows.loc[windows["patient_id"] == patient_ids[0]])  # windows expected for the first patient

983