In [1]:
import os
import gc
import numpy as np
import mne
from mne.preprocessing import ICA, create_eog_epochs

In [2]:
#change paths
RAW_PATH = "/workspace/raw_dataset"
SAVE_PATH = "/workspace/preprocessed_data_2"
os.makedirs(SAVE_PATH, exist_ok=True)

In [3]:
CANONICAL_CHANNELS = [
    "EEG Fp1", "EEG F3", "EEG C3", "EEG P3", "EEG O1",
    "EEG F7", "EEG T3", "EEG T5",
    "EEG Fc1", "EEG Fc5",
    "EEG Cp1", "EEG Cp5",
    "EEG F9", "EEG Fz", "EEG Pz",
    "EEG F4", "EEG C4", "EEG P4", "EEG O2",
    "EEG F8", "EEG T4", "EEG T6",
    "EEG Fc2", "EEG Fc6",
    "EEG Cp2", "EEG Cp6",
    "EEG F10"
]

In [4]:
# Bandpass
LOW_F = 0.5
HIGH_F = 40.0

# Notch
NOTCH = 50.0

# Epoching (USED LATER, NOT HERE)
EPOCH_LEN = 5.0

# Resampling
RESAMPLE_SFREQ = 256

# Memory safety
MAX_DURATION = 600          # seconds (final output)
ICA_MAX_DURATION = 20 * 60  # seconds (ICA only)

# ICA
ICA_METHOD = "picard"
ICA_N_COMPONENTS = 0.95
ICA_RANDOM_STATE = 97

In [5]:
def load_edf(edf_path):
    raw = mne.io.read_raw_edf(edf_path, preload=False, verbose=False)
    raw.pick("eeg")

    if raw.n_times < 1:
        raise ValueError("Empty EDF")

    return raw


In [6]:
def preprocess_raw(raw):
    # ---- HARD MEMORY BOUND FIRST ----
    duration = raw.times[-1]
    crop_t = min(duration, ICA_MAX_DURATION)

    raw.crop(tmax=crop_t)     # crop BEFORE loading data
    raw.load_data()           # now load only cropped data

    # ---- now it's safe ----
    raw.resample(RESAMPLE_SFREQ, npad="auto")
    raw.filter(LOW_F, HIGH_F, verbose=False)
    raw.notch_filter(NOTCH, verbose=False)

    return raw


In [7]:
def run_ica(raw):
    raw_ica = raw.copy().filter(l_freq=1.0, h_freq=None, verbose=False)

    try:
        ica = ICA(
            n_components=ICA_N_COMPONENTS,
            method=ICA_METHOD,
            random_state=ICA_RANDOM_STATE,
            max_iter="auto"
        )

        ica.fit(raw_ica)

        if ica.n_components_ < 2:
            print("Skipping ICA (rank too low)")
            return raw

        ica.apply(raw)

    except Exception as e:
        print("ICA failed, skipping:", e)

    finally:
        del raw_ica
        gc.collect()

    return raw


In [8]:
def apply_final_crop(raw):
    if raw.times[-1] > MAX_DURATION:
        raw.crop(tmax=MAX_DURATION)
    return raw


In [9]:
def extract_data(raw):
    data = raw.get_data()
    ch_names = raw.ch_names
    sfreq = raw.info["sfreq"]
    n_times = raw.n_times
    return data, ch_names, sfreq, n_times


In [10]:
def read_siena_seizure_times(annotation_path):
    seizure_times = []

    with open(annotation_path, "r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            parts = line.replace(",", " ").split()
            if len(parts) < 2:
                continue

            try:
                start = float(parts[0])
                end = float(parts[1])
                seizure_times.append((start, end))
            except ValueError:
                continue

    return seizure_times


In [11]:
def epoch_and_label(raw, seizure_times, epoch_len=5.0):
    epochs = mne.make_fixed_length_epochs(
        raw,
        duration=epoch_len,
        preload=True
    )

    X = epochs.get_data().astype(np.float32)

    # Normalize per epoch
    mean = X.mean(axis=-1, keepdims=True)
    std = X.std(axis=-1, keepdims=True) + 1e-6
    X = (X - mean) / std

    n_epochs = len(X)
    y = np.zeros(n_epochs, dtype=np.int8)

    for i in range(n_epochs):
        start_sec = i * epoch_len
        end_sec = start_sec + epoch_len

        for s0, s1 in seizure_times:
            if end_sec > s0 and start_sec < s1:
                y[i] = 1
                break

    return X, y, epochs.info["sfreq"], epochs.ch_names


In [12]:
def apply_canonical_channels(data, ch_names, canonical_channels, n_times):
    ch_index = {ch.upper(): i for i, ch in enumerate(ch_names)}

    final_data = np.zeros(
        (len(canonical_channels), n_times),
        dtype=np.float32
    )

    for i, ch in enumerate(canonical_channels):
        key = ch.upper()
        if key in ch_index:
            final_data[i] = data[ch_index[key]]

    return final_data


In [13]:
def save_npz(save_path, X, y, sfreq, channels):
    np.savez_compressed(
        save_path,
        data=X,
        labels=y,
        sfreq=sfreq,
        channels=channels
    )


In [14]:
def cleanup(*objs):
    for obj in objs:
        try:
            del obj
        except Exception:
            pass
    gc.collect()


In [15]:
def process_edf(edf_path, save_path):
    raw = None
    data = None
    final_data = None

    try:
        raw = load_edf(edf_path)
        raw = preprocess_raw(raw)
        raw = run_ica(raw)
        raw = apply_final_crop(raw)

        data, ch_names, sfreq, n_times = extract_data(raw)

        final_data = apply_canonical_channels(
            data,
            ch_names,
            CANONICAL_CHANNELS,
            n_times
        )

        save_npz(save_path, final_data, sfreq, CANONICAL_CHANNELS)
        print("Processed:", os.path.basename(edf_path))

    except Exception as e:
        print("Failed:", edf_path, "|", e)

    finally:
        cleanup(raw, data, final_data)


In [16]:
def process_dataset(raw_dir, save_dir):
    os.makedirs(save_dir, exist_ok=True)

    for root, _, files in os.walk(raw_dir):
        for file in files:
            if not file.lower().endswith(".edf"):
                continue

            edf_path = os.path.join(root, file)
            save_path = os.path.join(
                save_dir,
                file.replace(".edf", ".npz")
            )

            if os.path.exists(save_path):
                print("Skipping:", file)
                continue

            process_edf(edf_path, save_path)


In [17]:
for patient in sorted(os.listdir(RAW_PATH)):
    patient_dir = os.path.join(RAW_PATH, patient)
    if not os.path.isdir(patient_dir):
        continue

    for file in sorted(os.listdir(patient_dir)):
        if not file.endswith(".edf"):
            continue

        edf_path = os.path.join(patient_dir, file)
        # annotation_path = os.path.join(ANN_DIR, file.replace(".edf", ".csv"))
        save_path = os.path.join(SAVE_PATH, file.replace(".edf", ".npz"))

        try:
            raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
            raw.pick("eeg")

            # seizure_times = read_siena_seizure_times(annotation_path)

            X, y, sfreq, channels = epoch_and_label(
                raw,
                seizure_times,
                epoch_len=EPOCH_LEN
            )

            save_npz(save_path, X, y, sfreq, channels)

            print("Saved:", file, "| epochs:", len(y), "| seizures:", y.sum())

        except Exception as e:
            print("Failed:", edf_path, "|", str(e))


  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)


Failed: /workspace/raw_dataset/PN00/PN00-1.edf | name 'seizure_times' is not defined


  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)


Failed: /workspace/raw_dataset/PN00/PN00-2.edf | name 'seizure_times' is not defined


  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)


Failed: /workspace/raw_dataset/PN00/PN00-3.edf | name 'seizure_times' is not defined


  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)


Failed: /workspace/raw_dataset/PN00/PN00-4.edf | name 'seizure_times' is not defined


  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)


Failed: /workspace/raw_dataset/PN00/PN00-5.edf | name 'seizure_times' is not defined


  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)


KeyboardInterrupt: 

- saved as npz as its compressed stack of numpy arrays compatible with dl models

CNN
-  input layer
- 1 d convolution layer - kernel slides over the signal to find features
- activation layer - make the features non-linear
- batch normalisation - normalise feature maps - better generalisation
- pooling layer - samples temporal dimensions - robust to noise and change in temporal cahnges

- bilstm focuses on temporal features
- memory cells - 3 gates
- forget, input, cell state
- forget removes noise
- input - stores features
- cell state - accumalte all features
- forward - from past to present
- backward - present to past

- attention mechanism
- higher wts to all seizure features
