# Pre-processing pipeline

In [None]:
import mne
import numpy
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import concurrent.futures
import pywt
import scipy as sp
from sklearn.model_selection import StratifiedGroupKFold
import re
from mne.preprocessing import ICA
import xarray as xr
import multiprocessing as mp
from tqdm import tqdm

### configuration setup

In [None]:
# Only let MNE emit messages at WARNING level or above for the rest of the notebook.
mne.set_log_level('WARNING')

# Input / output locations of the pipeline.
RAW_PATH = '/dhc/home/jannis.hajda/tuh-eeg-seizure-detection/data/raw'
OUTPUT_PATH = '/dhc/home/jannis.hajda/tuh-eeg-seizure-detection/data/processed'

# Rate every recording is resampled to before windows are cut.
SAMPLING_FREQ = 250
# Window length and overlap between consecutive windows, expressed in the same
# time unit as the annotation start/stop values.
WINDOW_LENGTH = 21
OVERLAP = 10.5

# Only recordings stored under one of these configuration folders are used.
CONFIGURATIONS = ["01_tcp_ar"]

# Channels picked from every recording; the list order is also used as the
# "channel" coordinate of the resulting data array.
CHANNELS = [
    "EEG FP1-REF", "EEG FP2-REF",
    "EEG F7-REF", "EEG F3-REF", "EEG F4-REF", "EEG F8-REF",
    "EEG T3-REF", "EEG C3-REF", "EEG C4-REF", "EEG T4-REF",
    "EEG T5-REF", "EEG P3-REF", "EEG P4-REF", "EEG T6-REF",
    "EEG O1-REF", "EEG O2-REF",
    "EEG CZ-REF",
    "EEG A1-REF", "EEG A2-REF",
]

In [None]:
def split_channels_to_hemispheres(channels: list):
    """Partition channel names by the parity of the first number in the name.

    Args:
        channels: Channel names such as ``"EEG FP1-REF"``.

    Returns:
        A ``(left, right)`` tuple of lists. Names whose first run of digits is
        an odd number go to ``left``, even numbers go to ``right``; names
        without any digit (e.g. ``"EEG CZ-REF"``) appear in neither list.
        The input order is preserved within each list.
    """
    odd_numbered, even_numbered = [], []

    for name in channels:
        digits = re.search(r'\d+', name)
        if digits is not None:
            is_even = int(digits.group()) % 2 == 0
            target = even_numbered if is_even else odd_numbered
            target.append(name)

    return odd_numbered, even_numbered

# Hemisphere channel lists derived once from CHANNELS (odd-numbered names -> left, even-numbered -> right).
LEFT_HEMISPHERE, RIGHT_HEMISPHERE = split_channels_to_hemispheres(CHANNELS) 

### load windows

In [None]:
def extract_events_from_annotations(annotation_file):
    """Parse the event rows of an annotation file into a DataFrame.

    The first six lines of the file are treated as header material and are
    skipped. Every remaining non-blank line is split on commas: fields 1 and 2
    are read as the float start and stop, field 3 as the label.

    Args:
        annotation_file: Path of the comma-separated annotation file.

    Returns:
        A DataFrame with the columns ``"label"``, ``"start"`` and ``"stop"``,
        one row per event. The columns are present even when the file holds no
        events.

    Raises:
        IndexError: If a non-blank event line has fewer than four fields.
        ValueError: If a start or stop field is not a valid float.
    """
    # Read everything first so the file handle is released before parsing.
    with open(annotation_file, "r") as f:
        lines = f.readlines()

    data = []
    for line in lines[6:]:
        # A trailing newline or blank separator line is not an event; indexing
        # its fields would raise IndexError.
        if not line.strip():
            continue

        parts = line.split(",")
        data.append({
            # Strip so a label that happens to be the last field does not keep
            # its line terminator.
            "label": parts[3].strip(),
            "start": float(parts[1]),
            "stop": float(parts[2]),
        })

    return pd.DataFrame(data, columns=["label", "start", "stop"])

In [None]:
def load_windows():
    """Build the table of fixed-length windows for all annotated recordings.

    Walks ``RAW_PATH/edf`` and keeps ``.edf`` files that sit exactly four
    directory levels below it (set / patient / session / configuration), whose
    configuration is listed in ``CONFIGURATIONS`` and that have a sibling
    ``.csv_bi`` annotation file. Every annotated event is cut into windows of
    ``WINDOW_LENGTH`` that advance by ``WINDOW_LENGTH - OVERLAP``; a window is
    emitted as long as it fits completely inside its event, including a
    window that ends exactly on the event's stop.

    Returns:
        A DataFrame with one row per window and the columns ``set``,
        ``patient_id``, ``session_id``, ``configuration``, ``recording_id``,
        ``recording_path``, ``event_index``, ``start``, ``stop``, ``label``.

    Raises:
        ValueError: If ``OVERLAP`` is not smaller than ``WINDOW_LENGTH``
            (the windows would never advance).
    """
    cols = ["set", "patient_id", "session_id", "configuration", "recording_id", "recording_path", "event_index", "start", "stop", "label"]
    data = []

    step = WINDOW_LENGTH - OVERLAP
    if step <= 0:
        raise ValueError("OVERLAP must be smaller than WINDOW_LENGTH")

    edf_path = os.path.join(RAW_PATH, "edf")

    for root, _, files in os.walk(edf_path):
        # The directory layout is a property of `root`, so check it once per
        # directory. os.sep keeps the split correct on every platform.
        parts = os.path.relpath(root, edf_path).split(os.sep)
        if len(parts) != 4:
            continue

        set_name, patient_id, session_id, configuration = parts
        if configuration not in CONFIGURATIONS:
            continue

        for file in files:
            stem, extension = os.path.splitext(file)
            if extension != ".edf":
                continue

            recording_path = os.path.join(root, file)
            recording_id = stem.split("_")[-1]
            # Swap only the file extension; a plain str.replace would also
            # rewrite any ".edf" occurring elsewhere in the path.
            annotation_path = os.path.splitext(recording_path)[0] + ".csv_bi"

            if not os.path.exists(annotation_path):
                continue

            events = extract_events_from_annotations(annotation_path)

            for event in events.itertuples():
                start = event.start

                # `<=` so an event lasting exactly WINDOW_LENGTH (or a window
                # that ends exactly on the event's stop) still yields a
                # window; events shorter than WINDOW_LENGTH never enter.
                while start + WINDOW_LENGTH <= event.stop:
                    data.append({
                        "set": set_name,
                        "patient_id": patient_id,
                        "session_id": session_id,
                        "configuration": configuration,
                        "recording_id": recording_id,
                        "recording_path": recording_path,
                        "event_index": event.Index,
                        "start": start,
                        "stop": start + WINDOW_LENGTH,
                        "label": event.label,
                    })

                    start += step

    return pd.DataFrame(data, columns=cols)

In [None]:
# Build the full window table from the raw corpus and display it.
windows = load_windows()
windows 

### undersample majority class

In [None]:
# Split the window table by annotation label and report the class sizes.
window_labels = windows["label"]
seiz_windows = windows.loc[window_labels == "seiz"]
bckg_windows = windows.loc[window_labels == "bckg"]

print("Seizure windows:", len(seiz_windows))
print("Background windows:", len(bckg_windows))

In [None]:
# Fixed seed so the randomly drawn subset, and every artefact derived from
# it, is identical on each run of the notebook.
UNDERSAMPLING_SEED = 42

# Shrink the larger class to the size of the smaller one.
if len(bckg_windows) > len(seiz_windows):
    bckg_windows = bckg_windows.sample(n=len(seiz_windows), random_state=UNDERSAMPLING_SEED)
else:
    seiz_windows = seiz_windows.sample(n=len(bckg_windows), random_state=UNDERSAMPLING_SEED)

# Balanced window table: seizure rows first, then background rows.
windows = pd.concat([seiz_windows, bckg_windows])
windows

### preprocessing

In [None]:
def remove_powerline_noise(raw):
    """Apply a notch filter at the power-line frequency to ``raw``.

    ``raw.notch_filter`` is called once per listed frequency (currently only
    60), after which the same ``raw`` object is returned.
    """
    for mains_frequency in (60,):
        raw.notch_filter(freqs=mains_frequency)
    return raw

In [None]:
def butterworth_filter(raw):
    """Band-pass ``raw`` between 0.5 and 50 with a 4th-order Butterworth IIR.

    Calls ``raw.filter`` with ``method='iir'`` and returns the same ``raw``
    object.
    """
    raw.filter(
        l_freq=0.5,
        h_freq=50,
        method='iir',
        iir_params={"order": 4, "ftype": 'butter'},
    )
    return raw

In [None]:
def crop_raw(raw, start, stop):
    """Return a cropped copy of ``raw`` between ``start`` and ``stop``.

    Three cases are distinguished by comparing ``stop`` with the last time
    stamp ``raw.times[-1]``:

    * ``stop`` is not beyond the last time stamp: crop ``[start, stop)``
      (``include_tmax=False``).
    * ``stop`` is one sample period past the last time stamp: crop from
      ``start`` up to and including the last sample.
    * ``stop`` lies further out: the window cannot be served.

    Args:
        raw: Recording object exposing ``times``, ``info["sfreq"]`` and
            ``copy().crop(...)``.
        start: Start of the window, in the unit of ``raw.times``.
        stop: End of the window, in the unit of ``raw.times``.

    Returns:
        ``(cropped, True)`` on success, ``(None, False)`` when ``stop`` lies
        beyond the end of the recording.
    """
    last_time = raw.times[-1]

    if stop <= last_time:
        return raw.copy().crop(start, stop, include_tmax=False), True

    sample_period = 1 / raw.info["sfreq"]
    # Compare with a tolerance far below one sample: an exact float `==`
    # fails for values such as 0.3 - 0.1 != 0.2 and would silently drop the
    # final window of a recording.
    if np.isclose(stop - sample_period, last_time, rtol=0.0, atol=sample_period * 1e-6):
        return raw.copy().crop(start, last_time, include_tmax=True), True

    return None, False

### process recordings

In [None]:
def process_recording(recording_path, recording_windows: pd.DataFrame):
    """Load, filter, resample and window a single EDF recording.

    The recording is restricted to ``CHANNELS``, notch- and band-pass
    filtered, resampled to ``SAMPLING_FREQ`` and then cropped once per row of
    ``recording_windows``. Windows that ``crop_raw`` reports as invalid are
    skipped.

    Args:
        recording_path: Path of the EDF file.
        recording_windows: Rows providing at least the columns
            ``patient_id``, ``label``, ``start`` and ``stop``.

    Returns:
        An ``xr.DataArray`` with dims ``("window", "channel", "time")``, the
        per-window coordinates ``patient_id`` and ``label`` and the channel
        coordinate ``CHANNELS``. ``None`` is returned (after printing a
        message) when processing fails or no window could be cropped.
    """
    raw_recording = None
    try:
        raw_recording = mne.io.read_raw_edf(recording_path, preload=True).pick(picks=CHANNELS)
        # The channel coordinate below is labelled with CHANNELS, so force the
        # data into exactly that order instead of relying on the order in
        # which pick() happens to leave the channels.
        raw_recording.reorder_channels(CHANNELS)
        raw_recording.set_meas_date(None)

        raw_recording = remove_powerline_noise(raw_recording)
        raw_recording = butterworth_filter(raw_recording)
        raw_recording = raw_recording.resample(SAMPLING_FREQ)

        channel_data = []
        labels = []
        patient_ids = []

        for window in recording_windows.itertuples(index=False):
            raw_window, valid = crop_raw(raw_recording, window.start, window.stop)
            if not valid:
                continue

            try:
                channel_data.append(raw_window.get_data())
            finally:
                raw_window.close()

            labels.append(window.label)
            patient_ids.append(window.patient_id)

        if not channel_data:
            # Report this case explicitly rather than through the error
            # np.stack raises for an empty list.
            print(f"Failed to process recording {recording_path}: no valid windows")
            return None

        # create xarray dataset from raw windows
        return xr.DataArray(np.stack(channel_data), dims=("window", "channel", "time"), coords={
            "patient_id": ("window", np.array(patient_ids)),
            "label": ("window", np.array(labels)),
            "channel": CHANNELS,
        })
    except Exception as e:
        # Worker boundary: one broken recording must not abort the whole run.
        print(f"Failed to process recording {recording_path}: {e}")
        return None
    finally:
        # Release the recording on the error path as well.
        if raw_recording is not None:
            raw_recording.close()

In [None]:
def process_recordings_parallel(recordings, num_processes=None):
    """Run ``process_recording`` for every recording in a process pool.

    Args:
        recordings: Sized iterable of ``(recording_path, recording_windows)``
            pairs, e.g. a DataFrame grouped by recording path.
        num_processes: Number of worker processes; defaults to
            ``mp.cpu_count()``.

    Returns:
        The per-recording results concatenated along the ``"window"``
        dimension. Recordings whose task failed or returned ``None`` are
        left out.

    Raises:
        ValueError: If no recording produced a result.
    """
    if num_processes is None:
        num_processes = mp.cpu_count()

    print("Starting parallel processing...")

    # apply_async callbacks run in the parent process, so the progress bar
    # can be advanced directly. This removes the manager, queue and listener
    # process (the manager was never shut down, and the nested listener
    # function cannot be pickled as a process target under spawn).
    with tqdm(total=len(recordings), desc="Processing recordings") as pbar:

        def on_done(_):
            pbar.update()

        def on_error(e):
            print(f"Error: {e}")
            pbar.update()

        with mp.Pool(num_processes) as pool:
            pending = [
                pool.apply_async(
                    process_recording,
                    args=(recording_path, recording_windows),
                    callback=on_done,
                    error_callback=on_error,
                )
                for recording_path, recording_windows in recordings
            ]

            pool.close()
            pool.join()

    results = []
    for task in pending:
        # A failed task was already reported by on_error; calling get() on it
        # would re-raise and discard every successful recording.
        if not task.successful():
            continue

        result = task.get()
        if result is not None:
            results.append(result)

    if not results:
        raise ValueError("No recording could be processed; nothing to combine.")

    print("Combining results...")

    data = xr.concat(results, dim="window")

    print("Finished processing recordings.")

    return data

In [None]:
# One group per recording file, so each recording is loaded and filtered once
# for all of its windows.
recordings = windows.groupby("recording_path")
# NOTE(review): the worker count is hard-coded to 24 — confirm it matches the machine this runs on.
data = process_recordings_parallel(recordings, num_processes=24)
data

### normalize data (min-max, computed over the window dimension)

In [None]:
def min_max_normalize(data):
    """Min-max scale ``data`` using extrema taken over the "window" dimension.

    The minimum and maximum are reduced over ``dim="window"`` only, so every
    remaining position is scaled with its own pair of extrema:
    ``(data - min) / (max - min)``. Positions where the maximum equals the
    minimum are divided by zero and are not handled specially.
    """
    lowest = data.min(dim="window")
    highest = data.max(dim="window")
    value_range = highest - lowest
    return (data - lowest) / value_range

In [None]:
# Replace the windowed data with its min-max scaled version.
data = min_max_normalize(data)

### write preprocessed data to disk

In [None]:
# Create the output directory first: otherwise a missing folder only surfaces
# here, after the whole pipeline has already run.
os.makedirs(OUTPUT_PATH, exist_ok=True)
data.to_netcdf(os.path.join(OUTPUT_PATH, "windows_normalized.nc"))