# Pre-processing pipeline

In [1]:
import mne
import numpy
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import concurrent.futures
import pywt
import scipy as sp
from sklearn.model_selection import StratifiedGroupKFold
import re
from mne.preprocessing import ICA

### configuration setup

In [2]:
mne.set_log_level('WARNING')

RAW_PATH = '/dhc/home/jannis.hajda/tuh-eeg-seizure-detection/data/raw'
OUTPUT_PATH = '/dhc/home/jannis.hajda/tuh-eeg-seizure-detection/data/processed_fast'

SAMPLING_FREQ = 250
WINDOW_LENGTH = 21
OVERLAP = 10.5 
CONFIGURATIONS = ["01_tcp_ar"]
CHANNELS = ["EEG FP1-REF", "EEG FP2-REF", "EEG F7-REF", "EEG F3-REF", "EEG F4-REF", "EEG F8-REF", "EEG T3-REF", "EEG C3-REF", "EEG C4-REF", "EEG T4-REF", "EEG T5-REF", "EEG P3-REF", "EEG P4-REF", "EEG T6-REF", "EEG O1-REF", "EEG O2-REF", "EEG CZ-REF", "EEG A1-REF", "EEG A2-REF"]

In [3]:
def split_channels_to_hemispheres(channels: list):
    left_hemisphere = []
    right_hemisphere = []
    
    for channel in channels:
        channel_number = re.search(r'\d+', channel)
        if channel_number is None:
            continue
        
        if int(channel_number.group()) % 2 == 0:
            right_hemisphere.append(channel)
        else:
            left_hemisphere.append(channel)
    
    return left_hemisphere, right_hemisphere

LEFT_HEMISPHERE, RIGHT_HEMISPHERE = split_channels_to_hemispheres(CHANNELS) 

### load windows

In [4]:
def extract_events_from_annotations(annotation_file):
    with open(annotation_file, "r") as f:
        annotations = f.readlines()
        events = annotations[6:] 
        
        data = []
        for event in events:
            parts = event.split(",")
            
            start = float(parts[1])
            stop = float(parts[2])
            label = parts[3]
            
            data.append({
                "label": label,
                "start": start,
                "stop": stop,
            })
            
    return pd.DataFrame(data)

In [5]:
def load_windows():
    cols = ["set", "patient_id", "session_id", "configuration", "recording_id", "recording_path", "event_index", "start", "stop", "label"]
    data = []
    
    edf_path = os.path.join(RAW_PATH, "edf")
    
    for root, _, files in os.walk(edf_path):
        for file in files:
            if not file.endswith(".edf"):
                continue
            
            rel_path = os.path.relpath(root, edf_path)
            parts = rel_path.split("/")
            
            if len(parts) != 4:
                continue
        
            set_name, patient_id, session_id, configuration = parts
            
            if configuration not in CONFIGURATIONS:
                continue
        
            recording_path = os.path.join(root, file)
            recording_id = file.replace(".edf", "").split("_")[-1]
            annotation_path = recording_path.replace(".edf", ".csv_bi")
            
            if not os.path.exists(recording_path) or not os.path.exists(annotation_path):
                continue
            
            events = extract_events_from_annotations(annotation_path)
            
            for i, event in events.iterrows():
                start, stop, label = event.loc[["start", "stop", "label"]]
                duration = stop - start

                if duration < WINDOW_LENGTH:
                    continue

                while start + WINDOW_LENGTH < stop:
                    data.append({
                        "set": set_name,
                        "patient_id": patient_id,
                        "session_id": session_id,
                        "configuration": configuration,
                        "recording_id": recording_id,
                        "recording_path": recording_path,
                        "event_index": i,
                        "start": start,
                        "stop": start + WINDOW_LENGTH,
                        "label": label,
                    })
                    
                    start += WINDOW_LENGTH - OVERLAP

    return pd.DataFrame(data, columns=cols)

windows = load_windows()
windows 

Unnamed: 0,set,patient_id,session_id,configuration,recording_id,recording_path,event_index,start,stop,label
0,train,aaaaaiui,s001_2009,01_tcp_ar,t001,/dhc/home/jannis.hajda/tuh-eeg-seizure-detecti...,0,0.0000,21.0000,bckg
1,train,aaaaaiui,s001_2009,01_tcp_ar,t001,/dhc/home/jannis.hajda/tuh-eeg-seizure-detecti...,0,10.5000,31.5000,bckg
2,train,aaaaaiui,s001_2009,01_tcp_ar,t001,/dhc/home/jannis.hajda/tuh-eeg-seizure-detecti...,0,21.0000,42.0000,bckg
3,train,aaaaaiui,s001_2009,01_tcp_ar,t001,/dhc/home/jannis.hajda/tuh-eeg-seizure-detecti...,0,31.5000,52.5000,bckg
4,train,aaaaaiui,s001_2009,01_tcp_ar,t001,/dhc/home/jannis.hajda/tuh-eeg-seizure-detecti...,0,42.0000,63.0000,bckg
...,...,...,...,...,...,...,...,...,...,...
153705,eval,aaaaaiih,s006_2015,01_tcp_ar,t001,/dhc/home/jannis.hajda/tuh-eeg-seizure-detecti...,0,490.0134,511.0134,seiz
153706,eval,aaaaaiih,s006_2015,01_tcp_ar,t001,/dhc/home/jannis.hajda/tuh-eeg-seizure-detecti...,0,500.5134,521.5134,seiz
153707,eval,aaaaaiih,s006_2015,01_tcp_ar,t001,/dhc/home/jannis.hajda/tuh-eeg-seizure-detecti...,0,511.0134,532.0134,seiz
153708,eval,aaaaaiih,s006_2015,01_tcp_ar,t001,/dhc/home/jannis.hajda/tuh-eeg-seizure-detecti...,0,521.5134,542.5134,seiz


In [80]:
### undersample windows

In [6]:
seiz_windows = windows[windows["label"] == "seiz"]
bckg_windows = windows[windows["label"] == "bckg"]

print("Seizure windows:", len(seiz_windows))
print("Background windows:", len(bckg_windows))

Seizure windows: 12421
Background windows: 141289


In [7]:
# undersample majority class
bckg_windows = bckg_windows.sample(n=len(seiz_windows), random_state=42)
windows = pd.concat([seiz_windows, bckg_windows])

### feature extraction

In [8]:
def calc_coeffs_features(coeffs):
    mean = np.mean(coeffs)
    median = np.median(coeffs)
    std = np.std(coeffs)
    variance = np.var(coeffs)
    skew = sp.stats.skew(coeffs)
    kurtosis = sp.stats.kurtosis(coeffs)
    rms = np.sqrt(np.mean(coeffs ** 2))
    energy = np.sum(coeffs ** 2)
    
    return {
        "mean": mean,
        "median": median,
        "variance": variance,
        "std": std,
        "skew": skew,
        "kurtosis": kurtosis,
        "rms": rms,
        "energy": energy,
    }
    
def extract_wavelet_features(channel_data: np.ndarray ) -> dict[str, float]:
    a5, d5, d4, d3, d2, d1 = pywt.wavedec(channel_data, 'db4', level=5)
    
    wavelet_features = {f"{coeff}_{stat}": value 
                        for coeff, data in zip(["a5", "d5", "d4", "d3"], [a5, d5, d4, d3])
                        for stat, value in calc_coeffs_features(data).items()}
    
    return wavelet_features

In [9]:
def extract_band_power(channel_data, sfreq=SAMPLING_FREQ, n_fft=256) -> dict[str, float]:
    frequency_bands = {
        "delta": (0.5, 4),
        "theta": (4, 7),
        "alpha": (7, 12),
        "beta": (12, 30),
        "gamma": (30, 50)
    }

    band_powers = {}

    n_fft = min(n_fft, sfreq)
    psds, freqs = mne.time_frequency.psd_array_welch(channel_data, sfreq=sfreq, n_fft=n_fft, fmin=0.5, fmax=50)

    # Calculate power within each frequency band
    for band, (fmin, fmax) in frequency_bands.items():
        # Find indices of frequencies within the band
        band_indices = np.where((freqs >= fmin) & (freqs <= fmax))[0]

        # Sum the power spectral density values within the band
        band_power = np.sum(psds[band_indices])

        band_powers[band] = band_power

    return band_powers 


In [10]:
def calc_power_ratios(band_powers : dict[str, float]) -> dict[str, float]:
    alpha_beta_ratio = band_powers["alpha"] / band_powers["beta"] if band_powers["beta"] != 0 else np.nan
    theta_beta_ratio = band_powers["theta"] / band_powers["beta"] if band_powers["beta"] != 0 else np.nan
    theta_alpha_beta_ratio = (band_powers["theta"] + band_powers["alpha"]) / band_powers["beta"] if band_powers["beta"] != 0 else np.nan
    theta_alpha_beta_alpha_ratio = (band_powers["theta"] + band_powers["alpha"]) / (band_powers["beta"] + band_powers["alpha"]) if (band_powers["beta"] + band_powers["alpha"]) != 0 else np.nan
    alpha_theta_ratio = band_powers["alpha"] / band_powers["theta"] if band_powers["theta"] != 0 else np.nan
    theta_alpha_ratio = band_powers["theta"] / band_powers["alpha"] if band_powers["alpha"] != 0 else np.nan
    
    return {
        "alpha_beta_ratio": alpha_beta_ratio,
        "theta_beta_ratio": theta_beta_ratio,
        "theta_alpha_beta_ratio": theta_alpha_beta_ratio,
        "theta_alpha_beta_alpha_ratio": theta_alpha_beta_alpha_ratio,
        "alpha_theta_ratio": alpha_theta_ratio,
        "theta_alpha_ratio": theta_alpha_ratio
    }

In [23]:
def calc_avg_band_powers(band_powers):
    avg_band_powers = {}    
    bands = band_powers[0].keys()
     
    for band in bands:
        avg_band_powers[band] = np.mean([channel[band] for channel in band_powers])
        
    return avg_band_powers

In [24]:
def calc_asymmetry(band_powers):
    left_power = 0
    right_power = 0
    
    for i, channel in enumerate(CHANNELS):
        if channel in LEFT_HEMISPHERE:
            powers = list(band_powers[i].values())
            for power in powers:
                left_power += power
        elif channel in RIGHT_HEMISPHERE:
            powers = list(band_powers[i].values())
            for power in powers:
                right_power += power
            
    left_power = np.log(left_power) if left_power != 0 else 0
    right_power = np.log(right_power) if right_power != 0 else 0
    
    asymmetry = left_power - right_power
    return asymmetry

### preprocessing

In [25]:
def remove_powerline_noise(raw):
    powerline_noises = [60]

    for freq in powerline_noises:
        raw.notch_filter(freqs=freq)

    return raw

In [26]:
def butterworth_filter(raw):
    iir_params = dict(order=4, ftype='butter')
    raw.filter(0.5, 50, method='iir', iir_params=iir_params)
    return raw

In [27]:
def min_max_normalization(data):
    return (data - np.min(data)) / (np.max(data) - np.min(data))

In [28]:
def crop_raw_event(raw, start, stop):
    """Crops the raw data based on the onset and duration, handling edge cases."""
    if stop > raw.times[-1]:
        if stop - 1 / raw.info["sfreq"] == raw.times[-1]:
            return raw.copy().crop(start, raw.times[-1], include_tmax=True), True
        else:
            return None, False
    else:
        return raw.copy().crop(start, stop, include_tmax=False), True

### process windows

In [32]:
def process_windows(windows: pd.DataFrame):
    corrupted = []
    features = []

    recordings = windows.groupby("recording_path")
    
    for recording_path, recording in recordings:
        raw_recording = mne.io.read_raw_edf(recording_path, preload=True).pick(picks=CHANNELS)
        raw_recording.set_meas_date(None)

        raw_recording = butterworth_filter(raw_recording)
        raw_recording = remove_powerline_noise(raw_recording)
        raw_recording = raw_recording.resample(SAMPLING_FREQ)

        for _, window in recording.iterrows():
            patient_id, session_id, recording_id, event_index, start, stop, label = window.loc[["patient_id", "session_id", "recording_id", "event_index", "start", "stop", "label"]]

            raw_window, success = crop_raw_event(raw_recording, start, stop)
            if not success:
                print(f"Failed to crop event {event_index} in recording {recording_id} for patient {patient_id} and session {session_id}.")
                corrupted.append((patient_id, session_id, recording_id))
                continue

            channels = raw_window.info["ch_names"]

            window_features = {}

            band_powers = {}

            for i, channel in enumerate(channels):
                channel_data, _ = raw_window[i]
                channel_data = channel_data.flatten()
                
                channel_data = min_max_normalization(channel_data)

                wavelet_features = extract_wavelet_features(channel_data)
                for key, value in wavelet_features.items():
                    window_features[f"{channel}_{key}"] = value

                band_power = extract_band_power(channel_data)

                band_powers[channel] = band_power

            left_band_powers = [band_powers[channel] for channel in LEFT_HEMISPHERE]
            right_band_powers = [band_powers[channel] for channel in RIGHT_HEMISPHERE]
           
            left_avg_band_powers = calc_avg_band_powers(left_band_powers)
            right_avg_band_powers = calc_avg_band_powers(right_band_powers)

            left_avg_band_powers = {f"left_{key}": value for key, value in left_avg_band_powers.items()}
            right_avg_band_powers = {f"right_{key}": value for key, value in right_avg_band_powers.items()}
            
            #avg_band_powers = calc_avg_band_powers(band_powers.items())
            #power_ratios = calc_power_ratios(avg_band_powers)
            
            #asymmetry = calc_asymmetry(band_powers)
            
            window_features.update({
                "patient_id": patient_id,
                #"asymmetry": asymmetry,
                "label": label,
            })
            
            window_features = {**window_features, **left_avg_band_powers, **right_avg_band_powers}
            features.append(window_features)
                
            raw_window.close()
            
        raw_recording.close()
        
    return pd.DataFrame(features), corrupted

In [33]:
patient_ids = windows["patient_id"].unique()
patient_windows = windows[windows["patient_id"] == patient_ids[0]]
process_windows(patient_windows)

KeyError: 0

In [32]:
import multiprocessing as mp
from tqdm import tqdm

def process_patient_windows(patient_windows):
    patient_id = patient_windows["patient_id"].iloc[0]
    
    try:
        features, corrupted = process_windows(patient_windows)

        if patient_windows.shape[0] != features.shape[0]:
            print(f"Window mismatch for patient {patient_id}")

        if corrupted:
            print(f"Corrupted patient {patient_id}")

        return patient_id, features, corrupted
    except Exception as e:
        print(f"Error processing patient {patient_id}: {e}")
        return patient_id, None, None

def process_windows_parallel(windows, num_processes=None):
    patient_ids = windows["patient_id"].unique()
    
    if num_processes is None:
        num_processes = mp.cpu_count()

    manager = mp.Manager()
    queue = manager.Queue()
    
    def update_progress(result):
        queue.put(1)
    
    def listener(q, total):
        pbar = tqdm(total=total, desc="Processing patients")
        for _ in range(total):
            q.get()
            pbar.update()
        pbar.close()
    
    with mp.Pool(num_processes) as pool:
        results = []
        
        listener_process = mp.Process(target=listener, args=(queue, len(patient_ids)))
        listener_process.start()
        
        for patient_id in patient_ids:
            patient_windows = windows[windows["patient_id"] == patient_id]
            result = pool.apply_async(process_patient_windows, args=(patient_windows,), callback=update_progress)
            results.append(result)
        
        pool.close()
        pool.join()
        
        queue.put(None) 
        listener_process.join()

        processed_results = [result.get() for result in results]
    
    all_features = []
    for patient_id, features, corrupted in processed_results:
        if features is not None:
            all_features.append(features)

    if all_features:
        aggregated_features = pd.concat(all_features, ignore_index=True)
        return aggregated_features
    else:
        return pd.DataFrame()
   
features = process_windows_parallel(windows)


Processing patients:   0%|          | 0/357 [00:00<?, ?it/s]

Error processing patient aaaaapas: 0
Error processing patient aaaaaovm: 0

Processing patients:   0%|          | 1/357 [00:02<15:57,  2.69s/it]


Error processing patient aaaaaqro: 0Error processing patient aaaaandx: 0Error processing patient aaaaamyy: 0
Error processing patient aaaaagpk: 0

Error processing patient aaaaakkm: 0



Processing patients:   1%|          | 3/357 [00:02<04:21,  1.35it/s]

Error processing patient aaaaaosc: 0
Error processing patient aaaaatlp: 0Error processing patient aaaaaocy: 0


Processing patients:   3%|▎         | 9/357 [00:02<01:09,  5.01it/s]


Error processing patient aaaaacyf: 0Error processing patient aaaaabxe: 0

Error processing patient aaaaaron: 0Error processing patient aaaaaajy: 0



Processing patients:   4%|▎         | 13/357 [00:03<00:43,  7.83it/s]

Error processing patient aaaaaqrs: 0
Error processing patient aaaaaooy: 0


Processing patients:   4%|▍         | 16/357 [00:03<00:33, 10.05it/s]

Error processing patient aaaaaqkh: 0
Error processing patient aaaaaasy: 0
Error processing patient aaaaaltg: 0


Processing patients:   5%|▌         | 19/357 [00:03<00:34,  9.85it/s]

Error processing patient aaaaadsz: 0
Error processing patient aaaaaqtl: 0
Error processing patient aaaaanmh: 0
Error processing patient aaaaampk: 0
Error processing patient aaaaarqt: 0

Processing patients:   6%|▌         | 22/357 [00:03<00:27, 12.15it/s]


Error processing patient aaaaaoqf: 0
Error processing patient aaaaaogk: 0

Processing patients:   7%|▋         | 25/357 [00:03<00:24, 13.56it/s]

Error processing patient aaaaaedy: 0

Error processing patient aaaaaoya: 0Error processing patient aaaaaosa: 0
Error processing patient aaaaaroo: 0



Processing patients:   8%|▊         | 29/357 [00:03<00:19, 17.13it/s]

Error processing patient aaaaaqtx: 0
Error processing patient aaaaarqh: 0


Processing patients:   9%|▉         | 32/357 [00:04<00:21, 15.45it/s]

Error processing patient aaaaajnw: 0
Error processing patient aaaaaoiz: 0
Error processing patient aaaaatvr: 0Error processing patient aaaaatds: 0



Processing patients:  10%|▉         | 35/357 [00:04<00:25, 12.86it/s]

Error processing patient aaaaanwk: 0


Processing patients:  10%|█         | 37/357 [00:04<00:23, 13.72it/s]

Error processing patient aaaaamnd: 0
Error processing patient aaaaatdt: 0


Processing patients:  11%|█         | 39/357 [00:04<00:26, 12.06it/s]

Error processing patient aaaaaghb: 0Error processing patient aaaaalpn: 0

Error processing patient aaaaapxw: 0

Processing patients:  11%|█▏        | 41/357 [00:05<00:28, 11.25it/s]


Error processing patient aaaaaoxa: 0


Processing patients:  12%|█▏        | 43/357 [00:05<00:31,  9.87it/s]

Error processing patient aaaaabiw: 0
Error processing patient aaaaaovn: 0
Error processing patient aaaaarei: 0


Processing patients:  13%|█▎        | 45/357 [00:05<00:36,  8.59it/s]

Error processing patient aaaaankc: 0
Error processing patient aaaaaelb: 0Error processing patient aaaaaouk: 0
Error processing patient aaaaajoe: 0


Processing patients:  13%|█▎        | 48/357 [00:05<00:36,  8.58it/s]


Error processing patient aaaaaoek: 0


Processing patients:  14%|█▍        | 51/357 [00:06<00:27, 11.08it/s]

Error processing patient aaaaammu: 0
Error processing patient aaaaajsl: 0


Processing patients:  15%|█▍        | 53/357 [00:06<00:40,  7.56it/s]

Error processing patient aaaaatba: 0
Error processing patient aaaaaror: 0


Processing patients:  15%|█▌        | 55/357 [00:06<00:34,  8.79it/s]

Error processing patient aaaaapks: 0Error processing patient aaaaaqtq: 0

Error processing patient aaaaasjh: 0Error processing patient aaaaaoys: 0



Processing patients:  16%|█▌        | 58/357 [00:07<00:31,  9.48it/s]

Error processing patient aaaaarso: 0
Error processing patient aaaaanrc: 0
Error processing patient aaaaaijh: 0

Processing patients:  17%|█▋        | 61/357 [00:07<00:24, 12.14it/s]


Error processing patient aaaaaqjn: 0
Error processing patient aaaaamoa: 0Error processing patient aaaaambs: 0

Processing patients:  18%|█▊        | 63/357 [00:07<00:22, 12.85it/s]



Error processing patient aaaaanjw: 0


Processing patients:  18%|█▊        | 66/357 [00:07<00:18, 15.96it/s]

Error processing patient aaaaamck: 0Error processing patient aaaaaqbt: 0

Error processing patient aaaaaswt: 0Error processing patient aaaaaizb: 0



Processing patients:  19%|█▉        | 69/357 [00:07<00:16, 17.03it/s]

Error processing patient aaaaarpv: 0
Error processing patient aaaaamiz: 0Error processing patient aaaaatao: 0

Error processing patient aaaaaarq: 0

Processing patients:  20%|██        | 73/357 [00:07<00:13, 21.01it/s]


Error processing patient aaaaarcs: 0Error processing patient aaaaalzg: 0

Error processing patient aaaaaayg: 0

Processing patients:  21%|██▏       | 76/357 [00:07<00:14, 20.05it/s]


Error processing patient aaaaaiat: 0
Error processing patient aaaaamoq: 0
Error processing patient aaaaapcr: 0

Processing patients:  22%|██▏       | 79/357 [00:07<00:13, 20.44it/s]


Error processing patient aaaaakbz: 0
Error processing patient aaaaaict: 0Error processing patient aaaaalak: 0



Processing patients:  23%|██▎       | 82/357 [00:08<00:13, 21.02it/s]

Error processing patient aaaaahsi: 0
Error processing patient aaaaamyc: 0Error processing patient aaaaanoa: 0
Error processing patient aaaaalnt: 0Error processing patient aaaaatjr: 0


Processing patients:  24%|██▍       | 85/357 [00:08<00:17, 15.35it/s]



Error processing patient aaaaaooo: 0


Processing patients:  25%|██▍       | 89/357 [00:08<00:13, 19.27it/s]

Error processing patient aaaaanog: 0Error processing patient aaaaaegi: 0

Error processing patient aaaaakis: 0Error processing patient aaaaaprf: 0Error processing patient aaaaaqtw: 0


Error processing patient aaaaalnv: 0

Processing patients:  26%|██▌       | 92/357 [00:08<00:12, 20.48it/s]


Error processing patient aaaaapon: 0Error processing patient aaaaasfw: 0

Error processing patient aaaaajdn: 0

Processing patients:  27%|██▋       | 96/357 [00:08<00:10, 24.01it/s]


Error processing patient aaaaapnl: 0
Error processing patient aaaaajat: 0
Error processing patient aaaaamod: 0


Processing patients:  28%|██▊       | 100/357 [00:09<00:14, 18.35it/s]

Error processing patient aaaaamtj: 0
Error processing patient aaaaantp: 0


Processing patients:  29%|██▉       | 103/357 [00:09<00:15, 15.91it/s]

Error processing patient aaaaajru: 0
Error processing patient aaaaamww: 0


Processing patients:  29%|██▉       | 105/357 [00:09<00:16, 14.88it/s]

Error processing patient aaaaaoee: 0
Error processing patient aaaaalfs: 0Error processing patient aaaaaoxr: 0Error processing patient aaaaaint: 0



Processing patients:  30%|██▉       | 107/357 [00:09<00:16, 15.13it/s]




KeyboardInterrupt: 

Processing patients:  31%|███       | 109/357 [00:20<00:16, 15.13it/s]

In [30]:
features.to_csv("data/processed/windows_21_10_balanced_avg_energy.csv", index=False)