# Pre-processing pipeline

In [21]:
import mne
import pandas as pd
import os
import numpy as np
import re
import xarray as xr
import multiprocessing as mp
from tqdm import tqdm
import yaml
from config import get_preprocessing_config

### configuration setup

In [22]:
# Limit MNE's logging to messages of level WARNING and above for the rest of the notebook.
mne.set_log_level('WARNING')

In [23]:
# Load the preprocessing configuration and expose its entries as notebook-level constants.
conf = get_preprocessing_config()

# Root of the raw dataset (load_windows looks for an "edf" sub-directory inside it)
# and the destination path handed to `to_netcdf` at the end of the notebook.
RAW_PATH = conf["input_path"]
OUTPUT_FILE = conf["output_file"]

# Target rate passed to `raw.resample` in process_recording.
SAMPLING_FREQ = conf["sampling_frequency"]
# Window length and overlap between consecutive windows; both are combined with the
# annotation start/stop times, so they must be in the same time unit
# (presumably seconds — TODO confirm against the config file).
WINDOW_LENGTH = conf["window_length"]
OVERLAP = conf["overlap"] 
# Recording configurations (directory names) to include; others are skipped by load_windows.
CONFIGURATIONS = conf["configurations"]
# Channel names picked from every recording; also used as the "channel" coordinate of the output.
CHANNELS = conf["channels"]

### load windows

In [24]:
def extract_events_from_annotations(annotation_file):
    """Parse an annotation file into a table of labelled events.

    The first six lines of the file are treated as a header and skipped.
    Every remaining non-blank line is split on commas; field 1 is read as
    the event start, field 2 as the event stop and field 3 as the label.

    Parameters
    ----------
    annotation_file : str
        Path of the annotation file to read.

    Returns
    -------
    pd.DataFrame
        One row per event with the columns ``label``, ``start`` and ``stop``.
        The columns are present even when the file contains no events.

    Raises
    ------
    IndexError, ValueError
        If a non-blank event line has fewer than four fields or a
        non-numeric start/stop value.
    """
    header_line_count = 6
    records = []

    with open(annotation_file, "r") as f:
        for line_number, line in enumerate(f):
            if line_number < header_line_count:
                continue

            line = line.strip()
            if not line:
                # Blank lines (typically a trailing newline at the end of the
                # file) carry no event and would otherwise raise IndexError.
                continue

            parts = line.split(",")
            records.append({
                # Strip so that a label in the last column does not keep
                # stray whitespace and fail later equality comparisons.
                "label": parts[3].strip(),
                "start": float(parts[1]),
                "stop": float(parts[2]),
            })

    return pd.DataFrame(records, columns=["label", "start", "stop"])

In [25]:
def load_windows():
    """Enumerate fixed-length windows over every annotated event of the dataset.

    Walks ``RAW_PATH/edf`` and considers only directories that sit exactly
    four levels below it, interpreted as
    ``<set>/<patient_id>/<session_id>/<configuration>``. Directories whose
    configuration is not listed in ``CONFIGURATIONS`` are skipped. For each
    ``.edf`` file that has a sibling ``.csv_bi`` annotation file, every
    event at least ``WINDOW_LENGTH`` long is cut into windows of
    ``WINDOW_LENGTH`` whose starts are ``WINDOW_LENGTH - OVERLAP`` apart.

    Returns
    -------
    pd.DataFrame
        One row per window with the columns ``set``, ``patient_id``,
        ``session_id``, ``configuration``, ``recording_id``,
        ``recording_path``, ``event_index``, ``start``, ``stop``, ``label``.

    Raises
    ------
    ValueError
        If ``OVERLAP`` is not smaller than ``WINDOW_LENGTH``; the window
        start would then never advance and the loop would not terminate.
    """
    cols = ["set", "patient_id", "session_id", "configuration", "recording_id", "recording_path", "event_index", "start", "stop", "label"]

    stride = WINDOW_LENGTH - OVERLAP
    if stride <= 0:
        raise ValueError(
            f"OVERLAP ({OVERLAP}) must be smaller than WINDOW_LENGTH ({WINDOW_LENGTH})"
        )

    data = []
    edf_path = os.path.join(RAW_PATH, "edf")

    for root, _, files in os.walk(edf_path):
        # The layout check depends only on the directory, so do it once per
        # directory. Split on os.sep because os.path.relpath returns paths
        # with the platform separator, which is not always "/".
        parts = os.path.relpath(root, edf_path).split(os.sep)
        if len(parts) != 4:
            continue

        set_name, patient_id, session_id, configuration = parts
        if configuration not in CONFIGURATIONS:
            continue

        for file in files:
            if not file.endswith(".edf"):
                continue

            recording_path = os.path.join(root, file)
            # splitext only touches the extension; str.replace would also
            # rewrite any ".edf" occurring elsewhere in the path.
            recording_id = os.path.splitext(file)[0].split("_")[-1]
            annotation_path = os.path.splitext(recording_path)[0] + ".csv_bi"

            if not os.path.exists(recording_path) or not os.path.exists(annotation_path):
                continue

            events = extract_events_from_annotations(annotation_path)

            for i, event in events.iterrows():
                start, stop, label = event.loc[["start", "stop", "label"]]

                if stop - start < WINDOW_LENGTH:
                    continue

                # NOTE(review): the strict "<" means a window ending exactly at
                # the event stop is never emitted (an event exactly
                # WINDOW_LENGTH long yields no window). Kept as-is to leave the
                # generated dataset unchanged; confirm whether "<=" was meant.
                while start + WINDOW_LENGTH < stop:
                    data.append({
                        "set": set_name,
                        "patient_id": patient_id,
                        "session_id": session_id,
                        "configuration": configuration,
                        "recording_id": recording_id,
                        "recording_path": recording_path,
                        "event_index": i,
                        "start": start,
                        "stop": start + WINDOW_LENGTH,
                        "label": label,
                    })

                    start += stride

    return pd.DataFrame(data, columns=cols)

In [26]:
# Build the table of candidate windows for all recordings and show it as the cell output.
windows = load_windows()
windows 

Unnamed: 0,set,patient_id,session_id,configuration,recording_id,recording_path,event_index,start,stop,label
0,train,aaaaanme,s011_2014,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,689.8535,710.8535,seiz
1,train,aaaaanme,s011_2014,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,700.3535,721.3535,seiz
2,train,aaaaanme,s011_2014,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,710.8535,731.8535,seiz
3,train,aaaaanme,s011_2014,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,721.3535,742.3535,seiz
4,train,aaaaanme,s011_2014,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,731.8535,752.8535,seiz
...,...,...,...,...,...,...,...,...,...,...
4148,train,aaaaanme,s007_2014,01_tcp_ar,t015,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,246.0166,267.0166,seiz
4149,train,aaaaanme,s007_2014,01_tcp_ar,t015,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,256.5166,277.5166,seiz
4150,train,aaaaanme,s007_2014,01_tcp_ar,t015,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,267.0166,288.0166,seiz
4151,train,aaaaanme,s007_2014,01_tcp_ar,t002,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,320.0522,341.0522,seiz


### undersample majority class

In [27]:
# Split the window table by label so the two class sizes can be compared.
seiz_windows = windows.loc[windows["label"].eq("seiz")]
bckg_windows = windows.loc[windows["label"].eq("bckg")]

print("Seizure windows:", seiz_windows.shape[0])
print("Background windows:", bckg_windows.shape[0])

Seizure windows: 312
Background windows: 3841


In [28]:
# Randomly undersample the larger class down to the size of the smaller one.
# A fixed seed makes the selection identical on every Restart & Run All, so the
# written dataset is reproducible.
UNDERSAMPLE_SEED = 42

if len(bckg_windows) > len(seiz_windows):
    bckg_windows = bckg_windows.sample(n=len(seiz_windows), random_state=UNDERSAMPLE_SEED)
else:
    seiz_windows = seiz_windows.sample(n=len(bckg_windows), random_state=UNDERSAMPLE_SEED)

# Seizure rows first, then background rows; the original row index is kept.
windows = pd.concat([seiz_windows, bckg_windows])
windows

Unnamed: 0,set,patient_id,session_id,configuration,recording_id,recording_path,event_index,start,stop,label
0,train,aaaaanme,s011_2014,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,689.8535,710.8535,seiz
1,train,aaaaanme,s011_2014,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,700.3535,721.3535,seiz
2,train,aaaaanme,s011_2014,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,710.8535,731.8535,seiz
3,train,aaaaanme,s011_2014,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,721.3535,742.3535,seiz
4,train,aaaaanme,s011_2014,01_tcp_ar,t000,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,731.8535,752.8535,seiz
...,...,...,...,...,...,...,...,...,...,...
3816,train,aaaaanme,s007_2014,01_tcp_ar,t010,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,987.0000,1008.0000,bckg
3227,train,aaaaanme,s007_2014,01_tcp_ar,t009,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,178.5000,199.5000,bckg
3353,train,aaaaanme,s007_2014,01_tcp_ar,t007,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,283.5000,304.5000,bckg
2555,train,aaaaanme,s002_2012,01_tcp_ar,t010,/home/nis/Git/tuh-eeg-seizure-detection/data/r...,0,420.0000,441.0000,bckg


### preprocessing

In [29]:
def remove_powerline_noise(raw, freqs=(60,)):
    """Notch-filter power-line interference out of a recording, in place.

    Parameters
    ----------
    raw
        Recording object exposing ``notch_filter(freqs=...)``; it is
        modified in place.
    freqs : number or iterable of numbers, optional
        Frequencies to notch out, one ``notch_filter`` call per value.
        Defaults to ``(60,)``, the value that was previously hard-coded;
        pass e.g. ``(50,)`` for recordings made on a 50 Hz mains grid.

    Returns
    -------
    The same ``raw`` object, to allow chaining.
    """
    # Accept a bare number as a convenience for the single-frequency case.
    if isinstance(freqs, (int, float)):
        freqs = (freqs,)

    for freq in freqs:
        raw.notch_filter(freqs=freq)

    return raw

In [30]:
def butterworth_filter(raw):
    """Apply a 4th-order Butterworth IIR filter with 0.5 and 50 as cut-offs.

    The filter is applied through ``raw.filter`` on the object that is
    passed in, and that same object is returned.
    """
    raw.filter(
        0.5,
        50,
        method="iir",
        iir_params={"order": 4, "ftype": "butter"},
    )
    return raw

In [31]:
def crop_raw(raw, start, stop):
    """Return a cropped copy of ``raw`` covering ``[start, stop)``.

    Parameters
    ----------
    raw
        Recording exposing ``times``, ``info["sfreq"]``, ``copy()`` and
        ``crop(tmin, tmax, include_tmax=...)``. It is not modified.
    start, stop : float
        Window boundaries, in the same unit as ``raw.times``.

    Returns
    -------
    (window, valid) : tuple
        * ``stop`` not beyond the last sample time: the copy cropped to
          ``[start, stop)`` and ``True``.
        * ``stop`` exactly one sample period after the last sample time
          (the window ends where the recording ends): the copy cropped to
          ``[start, last sample]`` inclusive and ``True``.
        * any other ``stop`` beyond the recording: ``(None, False)``.
    """
    last_time = raw.times[-1]

    if stop <= last_time:
        return raw.copy().crop(start, stop, include_tmax=False), True

    sample_period = 1 / raw.info["sfreq"]
    recording_end = last_time + sample_period

    # Compare with a small tolerance instead of "==": both sides come out of
    # floating-point arithmetic (e.g. 0.4 - 0.1 != 0.3), and an exact test
    # silently discards windows that do end at the end of the recording.
    if abs(stop - recording_end) <= 1e-6 * sample_period:
        return raw.copy().crop(start, last_time, include_tmax=True), True

    return None, False

### process recordings

In [32]:
def process_recording(recording_path, recording_windows: pd.DataFrame):
    """Load one recording, filter and resample it, and extract its windows.

    The recording is read with ``mne.io.read_raw_edf``, restricted to
    ``CHANNELS``, passed through ``remove_powerline_noise`` and
    ``butterworth_filter`` and resampled to ``SAMPLING_FREQ``. Each row of
    ``recording_windows`` is then cropped with ``crop_raw``; windows that
    ``crop_raw`` reports as invalid are skipped.

    Parameters
    ----------
    recording_path : str
        Path of the EDF file.
    recording_windows : pd.DataFrame
        Rows with at least the columns ``patient_id``, ``label``, ``start``
        and ``stop``.

    Returns
    -------
    xr.DataArray or None
        Array with dims ``("window", "channel", "time")`` and the
        coordinates ``patient_id`` and ``label`` (per window) and
        ``channel``. ``None`` if the recording has no valid window or if any
        step raised; in both cases a message is printed instead of raising,
        so one bad recording does not abort a batch.
    """
    raw_recording = None
    try:
        raw_recording = mne.io.read_raw_edf(recording_path, preload=True)
        raw_recording = raw_recording.pick(picks=CHANNELS)
        raw_recording.set_meas_date(None)

        raw_recording = remove_powerline_noise(raw_recording)
        raw_recording = butterworth_filter(raw_recording)
        raw_recording = raw_recording.resample(SAMPLING_FREQ)

        channel_data = []
        labels = []
        patient_ids = []

        for _, window in recording_windows.iterrows():
            patient_id, label, start, stop = window[["patient_id", "label", "start", "stop"]]
            raw_window, valid = crop_raw(raw_recording, start, stop)
            if not valid:
                continue

            channel_data.append(raw_window.get_data())
            labels.append(label)
            patient_ids.append(patient_id)

            raw_window.close()

        if not channel_data:
            # np.stack([]) would raise and be reported as a processing
            # failure; say what actually happened instead.
            print(f"No valid windows for recording {recording_path}; skipping")
            return None

        # create xarray dataset from raw windows
        return xr.DataArray(
            np.stack(channel_data),
            dims=("window", "channel", "time"),
            coords={
                "patient_id": ("window", np.array(patient_ids)),
                "label": ("window", np.array(labels)),
                "channel": CHANNELS,
            },
        )
    except Exception as e:
        print(f"Failed to process recording {recording_path}: {e}")
        return None
    finally:
        # Release the recording on every path, including failures.
        if raw_recording is not None:
            raw_recording.close()

In [33]:
def process_recordings_parallel(recordings, num_processes=None):
    """Run ``process_recording`` over many recordings in a process pool.

    Parameters
    ----------
    recordings : iterable of (recording_path, recording_windows)
        For example the result of ``windows.groupby("recording_path")``.
    num_processes : int, optional
        Pool size; defaults to ``mp.cpu_count()``.

    Returns
    -------
    xr.DataArray
        The per-recording results concatenated along the ``window``
        dimension. Recordings for which ``process_recording`` returned
        ``None`` or whose task raised are left out.

    Raises
    ------
    ValueError
        If no recording produced a result.
    """
    if num_processes is None:
        num_processes = mp.cpu_count()

    # Materialise once so the total is known and the input is iterated a
    # single time.
    tasks = list(recordings)

    print("Starting parallel processing...")

    # apply_async callbacks are invoked in the parent process, so the progress
    # bar can be advanced directly. This avoids a Manager queue plus a separate
    # listener process whose nested target function cannot be pickled when the
    # multiprocessing start method is not "fork".
    with tqdm(total=len(tasks), desc="Processing recordings") as pbar:

        def on_success(_):
            pbar.update()

        def on_error(e):
            print(f"Error: {e}")
            pbar.update()

        with mp.Pool(num_processes) as pool:
            pending = [
                pool.apply_async(
                    process_recording,
                    args=(recording_path, recording_windows),
                    callback=on_success,
                    error_callback=on_error,
                )
                for recording_path, recording_windows in tasks
            ]

            pool.close()
            pool.join()

    data = []
    for result in pending:
        try:
            part = result.get()
        except Exception:
            # Already reported by on_error; skip it like any other failed
            # recording instead of aborting the whole batch.
            continue
        if part is not None:
            data.append(part)

    if not data:
        raise ValueError("No recording was processed successfully; nothing to combine.")

    print("Combining results...")

    data = xr.concat(data, dim="window")

    print("Finished processing recordings.")

    return data

In [34]:
# Group the selected windows by recording so that each file is loaded and filtered only once,
# then process the groups in a pool of worker processes.
# NOTE(review): num_processes=16 is hard-coded; passing None would use mp.cpu_count() instead.
recordings = windows.groupby("recording_path")
data = process_recordings_parallel(recordings, num_processes=16)
data

Starting parallel processing...


Processing recordings: 100%|██████████| 94/94 [00:07<00:00, 12.87it/s] 


Combining results...
Finished processing recordings.


### apply min-max normalization per channel, across all windows and time samples

In [35]:
def min_max_normalize(data):
    """Rescale each channel using its own minimum and maximum.

    The minimum and maximum are taken over the ``window`` and ``time``
    dimensions, so there is one pair per remaining coordinate (the channel
    for the arrays built in this notebook).

    Parameters
    ----------
    data : xr.DataArray
        Array with at least the dimensions ``window`` and ``time``.

    Returns
    -------
    xr.DataArray
        ``(data - min) / (max - min)``. Channels whose values are all
        identical (``max == min``) come out as all zeros instead of NaN
        from a division by zero.
    """
    channel_mins = data.min(dim=["window", "time"])
    channel_maxs = data.max(dim=["window", "time"])

    channel_ranges = channel_maxs - channel_mins
    # Replace a zero range by 1: the numerator is 0 for such a channel, so
    # the result is 0 rather than 0/0.
    safe_ranges = channel_ranges.where(channel_ranges != 0, 1.0)

    return (data - channel_mins) / safe_ranges

In [36]:
# Replace `data` with its normalized version.
# NOTE(review): this rebinds `data`, so re-running the cell normalizes already-normalized values.
data = min_max_normalize(data)

### write preprocessed data to disk

In [37]:
# Persist the normalized array to the configured output path in netCDF format.
data.to_netcdf(OUTPUT_FILE)