# Pre-processing pipeline

In [None]:
import mne
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import concurrent.futures
import pywt

### Configuration setup

In [None]:
# Dataset locations. The defaults point at the original author's checkout; set
# the TUH_RAW_PATH / TUH_OUTPUT_PATH environment variables to run the notebook
# on another machine without editing it.
RAW_PATH = os.environ.get(
    "TUH_RAW_PATH", '/Users/jannis/Git/tuh-eeg-seizure-detection/data/raw'
)
OUTPUT_PATH = os.environ.get(
    "TUH_OUTPUT_PATH", '/Users/jannis/Git/tuh-eeg-seizure-detection/data/processed_fast'
)

# Rate (Hz) every recording is resampled to before windowing.
SAMPLING_FREQ = 250
# Length of each saved window, in the same time unit as the annotation
# onsets/durations (seconds, as passed to MNE's crop).
WINDOW_LENGTH = 10
# NOTE(review): not referenced by the visible cells; presumably placeholders
# for configuration/channel filtering — confirm before removing.
CONFIGURATIONS = []
CHANNELS = []

### Load raw data

#### Extract events from annotations

In [None]:
def extract_events_from_annotations(annotation_file, header_lines=6):
    """Parse a comma-separated annotation file into a list of event dicts.

    The first ``header_lines`` lines are skipped. Every remaining non-blank
    line is split on commas; column 0 is ignored and columns 1, 2 and 3 are
    read as start time, stop time and label. Any further columns are ignored.

    Parameters
    ----------
    annotation_file : str or path-like
        Path to the annotation file (the ``.csv_bi`` files of the dataset).
    header_lines : int, optional
        Number of leading lines to skip before the event rows. Defaults to 6,
        the value this notebook has always used for its annotation files.

    Returns
    -------
    list of dict
        One ``{"label": str, "onset": float, "duration": float}`` per event,
        where ``duration = stop - start``.
    """
    with open(annotation_file, "r") as f:
        lines = f.readlines()

    events = []
    for line in lines[header_lines:]:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing on
        # a missing column.
        if not line.strip():
            continue

        parts = line.split(",")
        start_time = float(parts[1])
        stop_time = float(parts[2])
        events.append({
            # strip() so a label in the last column doesn't keep its newline
            # (which would make comparisons like label == "seiz" fail).
            "label": parts[3].strip(),
            "onset": start_time,
            "duration": stop_time - start_time,
        })

    return events

#### Load the TUH EEG dataset

In [None]:
def load_tuh_eeg(raw_path=None):
    """Index every annotated EDF recording under ``<raw_path>/edf``, one row per event.

    Only ``.edf`` files exactly four directory levels below ``edf/`` are used;
    those levels are read as set, patient id, session id and configuration.
    A recording is included only if a sibling file with the same stem and a
    ``.csv_bi`` extension exists. That file is parsed with
    ``extract_events_from_annotations``, and each event becomes one row.

    Parameters
    ----------
    raw_path : str or path-like, optional
        Dataset root containing the ``edf`` directory. Defaults to the
        module-level ``RAW_PATH``.

    Returns
    -------
    pandas.DataFrame
        Columns ``set, patient_id, session_id, configuration, recording_id,
        recording_path, label, onset, duration``. ``recording_id`` is the part
        of the file stem after its last underscore.
    """
    if raw_path is None:
        raw_path = RAW_PATH

    cols = ["set", "patient_id", "session_id", "configuration", "recording_id", "recording_path", "label", "onset", "duration"]
    rows = []

    edf_path = os.path.join(raw_path, "edf")
    for root, _dirs, files in os.walk(edf_path):
        # Split on the platform separator: relpath uses os.sep, so a
        # hard-coded "/" split would skip every file on Windows.
        parts = os.path.relpath(root, edf_path).split(os.sep)
        if len(parts) != 4:
            continue
        set_name, patient_id, session_id, configuration = parts

        for file in files:
            if not file.endswith(".edf"):
                continue

            stem = file[: -len(".edf")]
            recording_path = os.path.join(root, file)
            recording_id = stem.split("_")[-1]
            # Swap only the file extension; str.replace on the full path would
            # also rewrite any ".edf" appearing in a directory name.
            annotation_path = os.path.join(root, stem + ".csv_bi")

            # os.walk also lists broken symlinks, which os.path.exists rejects.
            if not os.path.exists(recording_path) or not os.path.exists(annotation_path):
                continue

            for event in extract_events_from_annotations(annotation_path):
                rows.append({
                    "set": set_name,
                    "patient_id": patient_id,
                    "session_id": session_id,
                    "configuration": configuration,
                    "recording_id": recording_id,
                    "recording_path": recording_path,
                    "label": event["label"],
                    "onset": event["onset"],
                    "duration": event["duration"],
                })

    return pd.DataFrame(rows, columns=cols)

# Build the event table (one row per annotated event) and display its last rows.
data = load_tuh_eeg()
data.tail(n=5)


### Pre-process the data

In [None]:
def _crop_window(raw, start, stop, sample_period):
    """Return a copy of ``raw`` cropped to the window [start, stop), or None if it runs past the end.

    A window whose ``stop`` lies exactly one sample period after the last
    sample (up to float rounding) is still accepted. It is cropped up to and
    including the last sample, which covers the same span as [start, stop)
    on the sample grid.
    """
    last_time = raw.times[-1]
    # NOTE(review): copy() duplicates the whole recording before cropping; for
    # long recordings this likely dominates runtime — worth profiling.
    if stop <= last_time:
        return raw.copy().crop(start, stop, include_tmax=False)
    # Tolerance instead of ==: stop is built from float arithmetic on the
    # annotation times, so an exact comparison can miss by one ulp.
    if np.isclose(stop - sample_period, last_time, rtol=0.0, atol=1e-6):
        return raw.copy().crop(start, last_time, include_tmax=True)
    return None


def preprocess(patient):
    """Cut every annotated event of one patient into fixed-length windows and save them.

    For each recording of ``patient`` in the module-level ``data`` table:
    1. Load the EDF file and clear its measurement date.
    2. Resample it to ``SAMPLING_FREQ`` and attach the recording's events as
       annotations.
    3. Split each event into ``int(duration / WINDOW_LENGTH)`` consecutive
       windows; a trailing partial window is dropped.
    4. Save each window as a ``.fif`` file. Events labelled ``"seiz"`` go to
       ``OUTPUT_PATH/<patient>/seizure``, all others to
       ``OUTPUT_PATH/<patient>/non_seizure``.

    Parameters
    ----------
    patient : str
        Patient identifier, matched against ``data["patient_id"]``.

    Returns
    -------
    list of str
        File names of windows that were skipped because they extend past the
        end of the (resampled) recording.
    """
    corrupted = []

    seizure_output_dir = os.path.join(OUTPUT_PATH, patient, "seizure")
    non_seizure_output_dir = os.path.join(OUTPUT_PATH, patient, "non_seizure")
    os.makedirs(seizure_output_dir, exist_ok=True)
    os.makedirs(non_seizure_output_dir, exist_ok=True)

    sample_period = 1 / SAMPLING_FREQ
    recordings = data[data["patient_id"] == patient]["recording_path"].unique()

    for recording in recordings:
        raw = mne.io.read_raw_edf(recording, preload=True)
        try:
            # sometimes meas date breaks code
            raw.set_meas_date(None)
            # bring every recording to a common sampling rate
            raw.resample(SAMPLING_FREQ)

            events = data[data["recording_path"] == recording]
            raw.set_annotations(mne.Annotations(
                onset=events["onset"].values,
                duration=events["duration"].values,
                description=events["label"].values,
            ))

            # Running window index per output directory, shared by all events
            # of this recording. A per-event index would restart at 0, so two
            # events of the same class (e.g. background before and after a
            # seizure) would produce identical file names, and overwrite=True
            # would silently replace the earlier windows.
            next_index = {seizure_output_dir: 0, non_seizure_output_dir: 0}

            for _, event in events.iterrows():
                onset = event["onset"]
                num_windows = int(event["duration"] / WINDOW_LENGTH)
                if num_windows == 0:
                    continue

                output_dir = seizure_output_dir if event["label"] == "seiz" else non_seizure_output_dir
                prefix = f"{event['patient_id']}_{event['session_id']}_{event['recording_id']}"

                for i in range(num_windows):
                    start = onset + i * WINDOW_LENGTH
                    stop = onset + (i + 1) * WINDOW_LENGTH
                    file_name = f"{prefix}_{next_index[output_dir]}_raw.fif"
                    next_index[output_dir] += 1

                    raw_window = _crop_window(raw, start, stop, sample_period)
                    if raw_window is None:
                        print("Corrupted annotation", file_name)
                        corrupted.append(file_name)
                        continue

                    raw_window.save(os.path.join(output_dir, file_name), overwrite=True)
                    raw_window.close()
        finally:
            raw.close()

    return corrupted

In [None]:
patients = data["patient_id"].unique()
# Only patients from position 200 onwards are processed — presumably the first
# 200 were handled in an earlier run (TODO confirm). unique() keeps
# first-appearance order, which follows the os.walk traversal in load_tuh_eeg,
# so this slice is only reproducible if that directory order is stable.
patients = patients[200:]

# NOTE(review): the per-patient work (resampling, copying, saving) looks largely
# CPU-bound; a ProcessPoolExecutor might scale better, but functions defined in
# a notebook are not always picklable for worker processes — confirm first.
failed = {}
with concurrent.futures.ThreadPoolExecutor() as executor:
    # Map each future back to its patient so failures can be attributed.
    future_to_patient = {
        executor.submit(preprocess, patient=patient): patient for patient in patients
    }
    for future in concurrent.futures.as_completed(future_to_patient):
        patient = future_to_patient[future]
        try:
            print(future.result())
        except Exception as exc:  # top-level boundary: report and keep going
            print(f"Preprocessing failed for patient {patient}: {exc!r}")
            failed[patient] = exc

In [None]:
# Number of distinct patients in the event table (NaN counted, like len(unique())).
data["patient_id"].nunique(dropna=False)