# Pre-processing pipeline

In [None]:
import mne
import pandas as pd
import os
import numpy as np
import re
import xarray as xr
import multiprocessing as mp
from tqdm import tqdm
import yaml
from config import get_preprocessing_config
import matplotlib.pyplot as plt

### configuration setup

In [None]:
# Set MNE's logging threshold to WARNING for the rest of the notebook.
mne.set_log_level('WARNING')

In [None]:
# Pull every pipeline setting from the project's preprocessing configuration.
conf = get_preprocessing_config()

# Input and output locations.
RAW_PATH, OUTPUT_FILE = conf["input_path"], conf["output_file"]

# Signal and windowing parameters.
SAMPLING_FREQ, WINDOW_LENGTH, OVERLAP = (
    conf["sampling_frequency"],
    conf["window_length"],
    conf["overlap"],
)

# Recording configurations and channels to keep.
CONFIGURATIONS, CHANNELS = conf["configurations"], conf["channels"]

In [None]:
# NOTE(review): these literals replace the WINDOW_LENGTH and OVERLAP values
# read from the configuration earlier in the notebook, so the corresponding
# config entries currently have no effect.
WINDOW_LENGTH, OVERLAP = 30, 15

### load windows

In [None]:
def extract_events_from_annotations(annotation_file):
    """Parse an annotation file into a table of labelled events.

    The first six lines of the file are skipped. Every following non-blank
    line is split on commas and is expected to hold the start time in
    field 1, the stop time in field 2 and the label in field 3 (0-based).

    Args:
        annotation_file: Path of the annotation file to read.

    Returns:
        A ``pandas.DataFrame`` with the columns ``label`` (0 for ``"bckg"``,
        1 for ``"seiz"``), ``start`` and ``stop`` (floats), one row per
        event. The columns are present even when the file has no events.

    Raises:
        KeyError: If a label is neither ``"bckg"`` nor ``"seiz"``.
        ValueError: If a start or stop field cannot be parsed as a float.
        IndexError: If an event line has fewer than four fields.
    """
    label_map = {"bckg": 0, "seiz": 1}

    with open(annotation_file, "r") as f:
        lines = f.readlines()

    data = []
    for line in lines[6:]:
        # Tolerate blank lines, e.g. an extra newline at the end of the file.
        if not line.strip():
            continue

        parts = line.split(",")
        start = float(parts[1])
        stop = float(parts[2])
        # Strip so that a label in the last column does not keep its newline.
        label = label_map[parts[3].strip()]

        data.append({
            "label": label,
            "start": start,
            "stop": stop,
        })

    return pd.DataFrame(data, columns=["label", "start", "stop"])

In [None]:
def load_windows():
    """Enumerate fixed-length windows for every annotated recording.

    Walks ``RAW_PATH/edf`` and considers ``.edf`` files located exactly four
    directory levels below it (``set/patient/session/configuration``) whose
    configuration is listed in ``CONFIGURATIONS`` and that have a sibling
    ``.csv_bi`` annotation file. Each annotated event is cut into windows of
    ``WINDOW_LENGTH`` that advance by ``WINDOW_LENGTH - OVERLAP``; windows
    never extend past the end of their event.

    Returns:
        A ``pandas.DataFrame`` with one row per window and the columns
        ``set``, ``patient_id``, ``session_id``, ``configuration``,
        ``recording_id``, ``recording_path``, ``event_index``, ``start``,
        ``stop`` and ``label``.

    Raises:
        ValueError: If ``OVERLAP`` is not smaller than ``WINDOW_LENGTH``
            (the window start would never advance).
    """
    cols = ["set", "patient_id", "session_id", "configuration", "recording_id", "recording_path", "event_index", "start", "stop", "label"]

    step = WINDOW_LENGTH - OVERLAP
    if step <= 0:
        raise ValueError("OVERLAP must be smaller than WINDOW_LENGTH")

    suffix = ".edf"
    data = []
    edf_path = os.path.join(RAW_PATH, "edf")

    for root, _, files in os.walk(edf_path):
        # The directory layout is shared by every file in `root`, so it is
        # resolved once per directory. os.sep keeps this portable.
        parts = os.path.relpath(root, edf_path).split(os.sep)
        if len(parts) != 4:
            continue

        set_name, patient_id, session_id, configuration = parts
        if configuration not in CONFIGURATIONS:
            continue

        for file in files:
            if not file.endswith(suffix):
                continue

            # Only strip the trailing extension; str.replace would also touch
            # a ".edf" occurring in the middle of the name or path.
            stem = file[:-len(suffix)]
            recording_path = os.path.join(root, file)
            annotation_path = os.path.join(root, stem + ".csv_bi")

            if not os.path.exists(annotation_path):
                continue

            recording_id = stem.split("_")[-1]
            events = extract_events_from_annotations(annotation_path)

            for event in events.itertuples():
                start, stop, label = event.start, event.stop, int(event.label)

                # `<=` keeps the window that ends exactly at the event's stop;
                # events shorter than WINDOW_LENGTH yield no window at all.
                while start + WINDOW_LENGTH <= stop:
                    data.append({
                        "set": set_name,
                        "patient_id": patient_id,
                        "session_id": session_id,
                        "configuration": configuration,
                        "recording_id": recording_id,
                        "recording_path": recording_path,
                        "event_index": event.Index,
                        "start": start,
                        "stop": start + WINDOW_LENGTH,
                        "label": label,
                    })

                    start += step

    return pd.DataFrame(data, columns=cols)

In [None]:
# Build the table of candidate windows; the bare name below is the cell's displayed value.
windows = load_windows()
windows 

### undersample majority class

In [None]:
# Split the window table by class: label 0 is background, label 1 is seizure.
bckg_windows = windows.loc[windows["label"].eq(0)]
seiz_windows = windows.loc[windows["label"].eq(1)]

# Report the class sizes before balancing.
print("Seizure windows:", len(seiz_windows))
print("Background windows:", len(bckg_windows))

In [None]:
# Balance the classes by randomly down-sampling whichever one is larger.
# A fixed seed makes the selected subset identical across runs, so the
# dataset written at the end of the notebook is reproducible.
UNDERSAMPLING_SEED = 42

if len(bckg_windows) > len(seiz_windows):
    bckg_windows = bckg_windows.sample(n=len(seiz_windows), random_state=UNDERSAMPLING_SEED)
else:
    seiz_windows = seiz_windows.sample(n=len(bckg_windows), random_state=UNDERSAMPLING_SEED)

# Seizure rows first, then background rows, with a fresh 0..N-1 index.
windows = pd.concat([seiz_windows, bckg_windows]).reset_index(drop=True)
windows

### preprocessing

In [None]:
def remove_powerline_noise(raw, freqs=(60,)):
    """Notch-filter ``raw`` at each given power-line frequency.

    Args:
        raw: Object exposing ``notch_filter(freqs=...)``; in this pipeline
            an MNE raw recording.
        freqs: Frequencies to notch out. ``notch_filter`` is called once per
            entry, in order. Defaults to ``(60,)``.

    Returns:
        The ``raw`` object that was passed in. The return value of
        ``notch_filter`` is ignored, so the filter is expected to act on
        ``raw`` itself.
    """
    for freq in freqs:
        raw.notch_filter(freqs=freq)

    return raw

In [None]:
def butterworth_filter(raw):
    """Run a 4th-order Butterworth IIR filter with cut-offs 0.5 and 50 on ``raw``.

    The return value of ``raw.filter`` is ignored; the ``raw`` object that was
    passed in is returned.
    """
    raw.filter(
        0.5,
        50,
        method="iir",
        iir_params={"order": 4, "ftype": "butter"},
    )
    return raw

In [None]:
def crop_raw(raw, start, stop):
    """Return a cropped copy of ``raw`` between ``start`` and ``stop``.

    Args:
        raw: Recording exposing ``times``, ``info["sfreq"]``, ``copy()`` and
            ``crop(tmin, tmax, include_tmax=...)``.
        start: Crop start, on the same time axis as ``raw.times``.
        stop: Crop end (exclusive), on the same time axis as ``raw.times``.

    Returns:
        ``(window, True)`` on success. When ``stop`` lies exactly one sample
        period after the last sample, the crop runs up to and including the
        last sample. ``(None, False)`` when ``stop`` lies further beyond the
        end of the recording.
    """
    last_time = raw.times[-1]

    if stop > last_time:
        sample_period = 1 / raw.info["sfreq"]
        # Compare with a tolerance far below one sample period: exact float
        # equality can fail purely because of rounding in `stop - period`.
        if np.isclose(stop - sample_period, last_time, rtol=0.0, atol=sample_period * 1e-6):
            return raw.copy().crop(start, last_time, include_tmax=True), True
        return None, False

    return raw.copy().crop(start, stop, include_tmax=False), True

In [None]:
def convert_to_freq_domain(channel_data):
    """Return the normalised FFT amplitude of a 1-D signal.

    Args:
        channel_data: 1-D sequence of ``n`` samples.

    Returns:
        ``numpy.ndarray`` of length ``n // 2`` holding ``|FFT| / n`` for the
        first half of the spectrum (from the zero-frequency bin up to, but
        not including, bin ``n // 2``).
    """
    n = len(channel_data)

    # Magnitude of the spectrum, normalised by the number of samples.
    amplitude = np.abs(np.fft.fft(channel_data)) / n

    # Keep only the first half; the frequency axis itself is not needed here.
    return amplitude[: n // 2]
    

### process recordings

In [None]:
def process_recording(recording_path, recording_windows: pd.DataFrame, domain: str = "time"):
    """Load one recording, filter it and cut it into the requested windows.

    The recording is read, restricted to ``CHANNELS``, notch-filtered,
    band-pass filtered and resampled to ``SAMPLING_FREQ``. Each row of
    ``recording_windows`` is then cropped out; rows that ``crop_raw`` reports
    as invalid are skipped.

    Args:
        recording_path: Path of the EDF file to read.
        recording_windows: Rows with at least the columns ``patient_id``,
            ``label``, ``start`` and ``stop``.
        domain: ``"freq"`` applies ``convert_to_freq_domain`` to every
            channel of each window; any other value keeps the cropped data
            unchanged.

    Returns:
        An ``xarray.DataArray`` with dims ``("window", "channel", "time")``
        and the coordinates ``patient_id``, ``label`` (per window) and
        ``channel``. The last dimension is named ``"time"`` for either
        domain. Returns ``None`` if the recording could not be processed or
        produced no valid window; a message is printed in both cases.
    """
    raw_recording = None
    try:
        raw_recording = mne.io.read_raw_edf(recording_path, preload=True)
        raw_recording = raw_recording.pick(picks=CHANNELS)
        raw_recording.set_meas_date(None)

        raw_recording = remove_powerline_noise(raw_recording)
        raw_recording = butterworth_filter(raw_recording)
        raw_recording = raw_recording.resample(SAMPLING_FREQ)

        raw_windows = []

        for _, window in recording_windows.iterrows():
            patient_id, label, start, stop = window[["patient_id", "label", "start", "stop"]]
            raw_window, valid = crop_raw(raw_recording, start, stop)
            if not valid:
                continue

            channel_data = raw_window.get_data()

            if domain == "freq":
                channel_data = np.apply_along_axis(convert_to_freq_domain, 1, channel_data)

            raw_windows.append({
                "patient_id": patient_id,
                "channel_data": channel_data,
                "label": label,
            })

            raw_window.close()

        # Report the empty case explicitly rather than through the opaque
        # error np.stack raises for an empty list.
        if not raw_windows:
            print(f"Failed to process recording {recording_path}: no valid windows")
            return None

        # create xarray dataset from raw windows
        channel_data = np.stack([window["channel_data"] for window in raw_windows])
        labels = np.array([window["label"] for window in raw_windows])
        patient_id = np.array([window["patient_id"] for window in raw_windows])

        return xr.DataArray(channel_data, dims=("window", "channel", "time"), coords={
            "patient_id": ("window", patient_id),
            "label": ("window", labels),
            "channel": CHANNELS,
        })
    except Exception as e:
        # Worker boundary: one bad recording must not abort the whole run.
        print(f"Failed to process recording {recording_path}: {e}")
        return None
    finally:
        # Release the recording on success and on failure alike.
        if raw_recording is not None:
            raw_recording.close()

In [None]:
def process_recordings_parallel(recordings, num_processes=None, domain="time"):
    """Run ``process_recording`` over all recordings in a process pool.

    Args:
        recordings: Sized iterable of ``(recording_path, recording_windows)``
            pairs, e.g. a pandas ``groupby`` object.
        num_processes: Size of the worker pool; defaults to
            ``mp.cpu_count()``.
        domain: Forwarded to ``process_recording``.

    Returns:
        The per-recording results concatenated along the ``"window"``
        dimension. Recordings for which ``process_recording`` returned
        ``None`` are left out.

    Raises:
        ValueError: If no recording produced any data.
    """
    if num_processes is None:
        num_processes = mp.cpu_count()

    print("Starting parallel processing...")

    # Pool callbacks run in the parent process, so the progress bar can be
    # advanced directly. This avoids a Manager, a queue and a separate
    # listener process whose nested target function cannot be pickled under
    # the "spawn"/"forkserver" start methods.
    pbar = tqdm(total=len(recordings), desc="Processing recordings")

    def callback(_):
        pbar.update()

    def error_callback(e):
        print(f"Error: {e}")
        pbar.update()

    try:
        with mp.Pool(num_processes) as pool:
            pending = [
                pool.apply_async(
                    process_recording,
                    args=(recording_path, recording_windows, domain),
                    callback=callback,
                    error_callback=error_callback,
                )
                for recording_path, recording_windows in recordings
            ]

            pool.close()
            pool.join()

            results = [result.get() for result in pending]
    finally:
        pbar.close()

    data = [d for d in results if d is not None]
    if not data:
        raise ValueError("No recording could be processed; nothing to combine.")

    print("Combining results...")

    data = xr.concat(data, dim="window")

    print("Finished processing recordings.")

    return data

In [None]:
# One group of windows per recording file, so each file is read only once.
recordings = windows.groupby(by="recording_path")

# NOTE(review): the worker count is fixed at 64 regardless of the host's CPU count.
data = process_recordings_parallel(
    recordings,
    num_processes=64,
    domain="time",
)

In [None]:
# NOTE(review): hard-coded, user-specific absolute path; it replaces the
# OUTPUT_FILE value that was read from the configuration.
OUTPUT_FILE = "/dhc/home/jannis.hajda/tuh-eeg-seizure-detection/data/preprocessed/windows_30.nc"

### write preprocessed data to disk

In [None]:
# Write the combined window array to OUTPUT_FILE via its to_netcdf method.
data.to_netcdf(OUTPUT_FILE)