Import Libraries

In [3]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter, welch
from scipy import stats
import mne
from mne.filter import filter_data
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, f1_score

Parse the summary file to extract seizure start and end times

In [None]:
def extract_seizure_events_from_txt(data_path):
    """
    Find the summary ``.txt`` file in a folder and parse its seizure times.

    The summary is read line by line. A ``File Name: <edf>`` line opens a new
    entry; subsequent seizure start/end lines are attached to it. Both the
    plain form (``Seizure Start Time: 2996 seconds``) and the numbered form
    (``Seizure 1 Start Time: 2996 seconds``) are accepted.

    Args:
        data_path: Directory to search. The first ``.txt`` file in sorted
            name order is used as the summary.

    Returns:
        dict mapping each file name to a list of ``(start, end)`` integer
        tuples (the first numeric token after the colon), in file order.

    Raises:
        FileNotFoundError: if ``data_path`` contains no ``.txt`` file.
        ValueError: if a seizure line appears before any ``File Name:`` line,
            or an end time appears without a preceding start time.
    """
    # Sorted so the choice is deterministic when several .txt files exist.
    txt_files = sorted(f for f in os.listdir(data_path) if f.endswith(".txt"))
    if not txt_files:
        raise FileNotFoundError("No Summary Text File Found!")
    summary_file = os.path.join(data_path, txt_files[0])

    def _first_int_after_colon(text):
        # "Seizure End Time: 3036 seconds" -> 3036 (whole token, not first char).
        return int(text.split(":", 1)[1].split()[0])

    seizure_info = {}
    current_file = None
    start_time = None
    with open(summary_file, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith("File Name:"):
                current_file = line.split(":", 1)[1].strip()
                seizure_info[current_file] = []
                start_time = None
            elif line.startswith("Seizure") and "Start Time:" in line:
                if current_file is None:
                    raise ValueError(
                        f"Seizure time found before any 'File Name:' line: {line!r}"
                    )
                start_time = _first_int_after_colon(line)
            elif line.startswith("Seizure") and "End Time:" in line:
                if current_file is None or start_time is None:
                    raise ValueError(
                        f"Seizure end time without a matching start time: {line!r}"
                    )
                seizure_info[current_file].append(
                    (start_time, _first_int_after_colon(line))
                )
                start_time = None
    return seizure_info

Data Loading

In [2]:
def load_edf_with_seizures(edf_path, seizure_times, sampling_rate=None):
    """
    Load an EDF recording and convert seizure times to sample indices.

    Args:
        edf_path: Path to the ``.edf`` file.
        seizure_times: Iterable of ``(start, end)`` pairs in seconds.
        sampling_rate: Samples per second used to convert seconds to sample
            indices. When ``None`` (default), the recording's own
            ``raw.info['sfreq']`` is used so the indices always line up with
            the loaded data; pass a number to override it.

    Returns:
        dict with ``'raw'`` (the preloaded MNE Raw object) and
        ``'seizure_samples'`` (list of ``(start_sample, end_sample)`` ints).
    """
    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
    rate = raw.info['sfreq'] if sampling_rate is None else sampling_rate
    seizure_samples = [
        (int(start * rate), int(end * rate))
        for start, end in seizure_times
    ]
    return {
        'raw': raw,
        'seizure_samples': seizure_samples
    }


Data Processing

In [4]:
def process_single_file(raw, edf_filename, output_folder, selected_channels, seizure_windows,
                        l_freq=0.5, h_freq=25.0):
    """
    Band-pass filter the selected EEG channels and save them with metadata.

    The result is written to ``<output_folder>/<base>_preprocessed.npz`` with
    keys ``data``, ``seizure_windows``, ``sampling_rate``, ``channels`` and
    ``file_name``. The caller's ``raw`` object is not modified.

    Args:
        raw: Loaded MNE Raw object (must be preloaded or loadable).
        edf_filename: Original EDF file name; used for the output name.
        output_folder: Directory for the ``.npz`` file (created if missing).
        selected_channels: Channel names to keep. Rows of the saved ``data``
            follow this order exactly, matching the saved ``channels`` array.
        seizure_windows: Iterable of ``(start_sample, end_sample)`` pairs.
        l_freq: Lower pass-band edge in Hz (default 0.5).
        h_freq: Upper pass-band edge in Hz (default 25.0).

    Returns:
        The filtered data array, one row per selected channel.

    Raises:
        ValueError: if any of ``selected_channels`` is absent from ``raw``.
    """
    os.makedirs(output_folder, exist_ok=True)

    missing = [ch for ch in selected_channels if ch not in raw.ch_names]
    if missing:
        raise ValueError(f"{edf_filename}: missing channels {missing}")

    # Index rows explicitly so they follow selected_channels order. In older
    # MNE versions pick_channels() keeps the file's own channel order, which
    # would desynchronise the data from the saved ``channels`` labels.
    ch_index = [raw.ch_names.index(ch) for ch in selected_channels]
    data = raw.get_data()[ch_index]
    sfreq = raw.info['sfreq']

    data = filter_data(data, sfreq=sfreq, l_freq=l_freq, h_freq=h_freq, verbose=False)

    # reshape keeps an empty seizure list as shape (0, 2) instead of (0,).
    windows = np.asarray(seizure_windows, dtype=np.int32).reshape(-1, 2)

    base_name = os.path.splitext(edf_filename)[0]
    save_path = os.path.join(output_folder, f"{base_name}_preprocessed.npz")
    np.savez(save_path,
             data=data,
             seizure_windows=windows,
             sampling_rate=sfreq,
             channels=np.array(selected_channels),
             file_name=edf_filename)

    return data

In [None]:
def preprocess_and_save(edf_folder, selected_channels=None):
    """
    Preprocess every EDF file in a folder and save one ``.npz`` per file.

    Seizure times are read from the folder's summary ``.txt`` file. Output is
    written to ``<edf_folder>/preprocessed``. A file that fails is reported
    and skipped, so one bad recording does not abort the whole batch.

    Args:
        edf_folder: Directory containing the ``.edf`` files and the summary.
        selected_channels: Channel names to keep, in output order. Defaults
            to the 18 channel derivations listed below.
    """
    output_folder = os.path.join(edf_folder, 'preprocessed')
    if selected_channels is None:
        selected_channels = [
            'FP1-F7', 'F7-T7', 'T7-P7', 'P7-O1',
            'FP1-F3', 'F3-C3', 'C3-P3', 'P3-O1',
            'FP2-F4', 'F4-C4', 'C4-P4', 'P4-O2',
            'FP2-F8', 'F8-T8', 'T8-P8', 'P8-O2',
            'FZ-CZ', 'CZ-PZ'
        ]
    seizure_dict = extract_seizure_events_from_txt(edf_folder)

    # Sorted for a reproducible processing order across platforms.
    for fname in sorted(os.listdir(edf_folder)):
        if not fname.lower().endswith('.edf'):
            continue
        edf_path = os.path.join(edf_folder, fname)
        seizure_times = seizure_dict.get(fname, [])
        base_name = os.path.splitext(fname)[0]
        print(f"Processing: {fname}")

        # Broad catch is deliberate: this is the batch boundary, and a
        # single corrupt recording should not stop the remaining files.
        try:
            data_obj = load_edf_with_seizures(edf_path, seizure_times)
            process_single_file(
                raw=data_obj['raw'],
                edf_filename=fname,
                output_folder=output_folder,
                selected_channels=selected_channels,
                seizure_windows=data_obj['seizure_samples']
            )
        except Exception as e:
            print(f"Failed to process {fname}: {e}")
            continue

        print(f"Saved to: {os.path.join(output_folder, base_name + '_preprocessed.npz')}\n")

In [5]:
def segment_and_label(eeg_data, seizure_windows, window_duration=2.0, sampling_rate=256, overlap=0.5):
    """
    Cut a multichannel recording into fixed-length, possibly overlapping windows.

    A window ``[start, end)`` gets label 1 if it shares at least one sample
    with any seizure interval ``[sz_start, sz_end)``, otherwise 0.

    Args:
        eeg_data: 2-D array of shape ``(channels, samples)``.
        seizure_windows: Iterable of ``(start_sample, end_sample)`` pairs.
        window_duration: Window length in seconds.
        sampling_rate: Samples per second of ``eeg_data``.
        overlap: Fraction of a window shared with the next one (0.5 = 50%).

    Returns:
        ``(X, y)``: ``X`` has shape ``(n_windows, channels, window_size)`` and
        ``y`` has shape ``(n_windows,)`` with integer labels. Both are empty
        (``n_windows == 0``) when the recording is shorter than one window.

    Raises:
        ValueError: if the window or step size rounds down to zero or less.
    """
    window_size = int(window_duration * sampling_rate)
    step_size = int(window_size * (1 - overlap))
    if window_size <= 0 or step_size <= 0:
        raise ValueError(
            f"Invalid segmentation: window_size={window_size}, step_size={step_size} "
            f"(window_duration={window_duration}, sampling_rate={sampling_rate}, overlap={overlap})"
        )
    channels, total_samples = eeg_data.shape
    # Materialise once: it is scanned for every window, so a one-shot
    # iterable would otherwise be exhausted after the first window.
    intervals = list(seizure_windows)

    X, y = [], []
    for start in range(0, total_samples - window_size + 1, step_size):
        end = start + window_size
        X.append(eeg_data[:, start:end])
        y.append(int(any(end > sz_start and start < sz_end
                         for sz_start, sz_end in intervals)))

    if not X:
        # np.stack([]) raises; return correctly shaped empty arrays instead.
        return (np.empty((0, channels, window_size), dtype=eeg_data.dtype),
                np.zeros(0, dtype=int))
    return np.stack(X), np.array(y)

In [None]:
def batch_segment_preprocessed(folder_path):
    """
    Segment every ``*_preprocessed.npz`` file in a folder into labelled windows.

    Each input is cut into 2 s windows with 50% overlap by
    ``segment_and_label`` and saved as ``<base>_segmented.npz`` (keys ``X``,
    ``y``, ``sampling_rate``, ``channels``, ``file_name``) under
    ``<folder_path>/segmented``. A file that fails is reported and skipped.
    """
    output_folder = os.path.join(folder_path, 'segmented')
    os.makedirs(output_folder, exist_ok=True)

    # Sorted for a reproducible processing order across platforms.
    for fname in sorted(os.listdir(folder_path)):
        if not fname.endswith("_preprocessed.npz"):
            continue
        full_path = os.path.join(folder_path, fname)

        # Broad catch is deliberate: this is the batch boundary.
        try:
            # The context manager closes the NpzFile's underlying file handle.
            # allow_pickle stays at its safe default (False): the arrays
            # written by process_single_file are numeric or string arrays.
            with np.load(full_path) as npz:
                eeg = npz['data']
                sz_windows = npz["seizure_windows"]
                sfreq = int(npz['sampling_rate'])
                file_name = str(npz['file_name'])
                channels = list(npz['channels'])

            X, y = segment_and_label(eeg_data=eeg,
                                     seizure_windows=sz_windows,
                                     window_duration=2.0,
                                     sampling_rate=sfreq,
                                     overlap=0.5)
            base_name = fname.replace('_preprocessed.npz', '')
            save_path = os.path.join(output_folder, f"{base_name}_segmented.npz")
            np.savez(save_path,
                     X=X,
                     y=y,
                     sampling_rate=sfreq,
                     channels=channels,
                     file_name=file_name)

            print(f"Segmented and saved: {save_path}")

        except Exception as e:
            print(f"Failed on {fname}: {e}")