# Convert new EDF EEG files → preprocessed MNE FIF

This notebook mirrors the original preprocessing pipeline (resample to 200 Hz, 1–70 Hz band-pass, 50 Hz notch, 1 s fixed-length epochs). It selects a fixed set of 19 10-20 channels by name, instead of keeping the first 19 channels by position, and it replaces the old I/O:

- **Input:** `.edf` files in `G:\ChristianMusaeus\New_EEG\Clean_raw`
- **Output:** MNE `.fif` files in `G:\ChristianMusaeus\New_EEG\Processed`

Additionally, it uses EDF annotations to label each 1 s epoch of the recording:

- **EO:** epochs lying fully within `EO2` → `EO3`
- **EC:** epochs lying fully within `EC2` → `EC3`
- **OTHER:** all remaining epochs


Notes:

- The EDF files already include event annotations (e.g., `EO1/EO2/EO3`, `EC1/EC2/EC3`).
- This notebook saves one `<EDF stem>_epo.fif` per EDF. It contains the **1-second Epochs** of the whole recording, labeled `EO`/`EC`/`OTHER`. The preprocessed Raw data are not saved separately.


In [None]:
import os
from pathlib import Path
import warnings
import json
import traceback

# Avoid numba caching/JIT issues in some environments
os.environ.setdefault("NUMBA_DISABLE_JIT", "1")

import numpy as np
import mne
from IPython.display import clear_output


In [None]:
def _guess_project_root(max_up: int = 6) -> Path:
    """Locate the project root so relative data paths work from any notebook CWD.

    Starting at the resolved current working directory, check it and at most
    ``max_up`` of its ancestors (nearest first). Return the first one that
    contains a ``data`` entry. If none does, fall back to the resolved current
    working directory.
    """
    here = Path.cwd().resolve()
    # The CWD plus its ancestors, capped at max_up + 1 directories in total.
    lineage = [here, *here.parents][: max(max_up + 1, 0)]
    for candidate in lineage:
        if (candidate / "data").exists():
            return candidate
    return Path.cwd().resolve()

project_root = _guess_project_root()

# Default data locations on Windows (requested).
windows_edf_dir = Path(r"G:\ChristianMusaeus\New_EEG\Clean_raw")
windows_output_dir = Path(r"G:\ChristianMusaeus\New_EEG\Processed")

# On other operating systems keep the notebook usable by falling back to
# repo-relative data folders.
if os.name == "nt":
    edf_dir = windows_edf_dir
    output_dir = windows_output_dir
else:
    edf_dir = project_root / "data" / "New_EEG.nosync"
    output_dir = project_root / "data" / "NEW_processed.nosync"
output_dir.mkdir(parents=True, exist_ok=True)

print("CWD:", Path.cwd())
print("Project root:", project_root)
print("EDF dir:", edf_dir, "(exists:", edf_dir.exists(), ")")
print("Output dir:", output_dir)

# Quiet the notebook: MNE only logs errors, and Python warnings are silenced.
# Note: the warnings filter is process-wide, so it hides *all* warnings for the
# rest of the session.
mne.set_log_level("ERROR")
warnings.filterwarnings("ignore")

def _first_annotation_onset(raw: mne.io.BaseRaw, description: str) -> float:
    """Look up when an annotation label first occurs in a recording.

    EDF annotation labels are not always written consistently (case,
    surrounding whitespace). Both the stored labels and ``description`` are
    therefore compared after ``strip()`` + upper-casing.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Recording whose ``annotations`` are searched.
    description : str
        Label to find, e.g. ``"EO2"``.

    Returns
    -------
    float
        ``raw.annotations.onset`` (seconds) of the first matching annotation.

    Raises
    ------
    ValueError
        If the recording has no annotations, or none matches ``description``.
        In the latter case the message previews up to 30 distinct labels.
    """
    annotations = raw.annotations
    if annotations is None or len(annotations) == 0:
        raise ValueError("No annotations found in EDF.")

    labels = np.asarray(annotations.description, dtype=str)
    wanted = str(description).strip().upper()
    hits = np.flatnonzero(np.char.upper(np.char.strip(labels)) == wanted)
    if hits.size:
        return float(annotations.onset[int(hits[0])])

    # Not found: list the distinct (non-blank) labels to help diagnose typos.
    distinct = sorted({label for label in (str(x).strip() for x in labels.tolist()) if label})
    shown = distinct[:30]
    more = "..." if len(distinct) > len(shown) else ""
    raise ValueError(
        f"Missing annotation: {description}. Available (n={len(distinct)}): {shown}{more}"
    )

import re

def _canonicalize_ch_name(name: str) -> str:
    """Reduce a channel label to a case-insensitive comparison key.

    The label is converted to ``str`` and stripped. Then, in order:
    - a leading ``EEG`` prefix followed by whitespace is removed (case-insensitively);
    - a trailing ``-REF`` suffix is removed (case-insensitively);
    - any remaining whitespace is deleted.
    Finally the result is upper-cased. For example, ``"EEG Fp1-REF"`` and
    ``"fp1"`` both become ``"FP1"``.
    """
    key = str(name).strip()
    for pattern in (r"^EEG\s+", r"-REF$", r"\s+"):
        key = re.sub(pattern, "", key, flags=re.IGNORECASE)
    return key.upper()

# Electrode sites of the fixed 19-channel 10-20 set, in the exact order in
# which they are stored in every output file.
_TARGET_SITES = (
    "Fp1", "Fp2", "F3", "F4", "C3", "C4", "P3", "P4", "O1", "O2",
    "F7", "F8", "T7", "T8", "P7", "P8", "Fz", "Cz", "Pz",
)

# Desired *stored* channel names (exact order), e.g. "EEG Fp1-REF".
TARGET_CHANNELS = [f"EEG {site}-REF" for site in _TARGET_SITES]
# Comparison keys used to find these channels in an EDF regardless of naming style.
TARGET_CANONICAL = [_canonicalize_ch_name(ch) for ch in TARGET_CHANNELS]

# Opt-in fallback: recordings that contain T9/T10 instead of Cz/Pz may use them
# as substitutes. Only enable this if you know the substitution is correct.
# By default we fail fast, so all saved FIFs are truly comparable.
ALLOW_T9_T10_SUBSTITUTE_FOR_CZ_PZ = False
SUBSTITUTE_MAP = {"CZ": "T9", "PZ": "T10"}

def _preprocess_raw_like_original(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    """Resample, filter and standardize *raw* to the fixed 19-channel set.

    Mirrors the original pipeline (resample to 200 Hz, 1-70 Hz band-pass,
    50 Hz notch), but selects channels by *name* rather than by position.
    Every output therefore has exactly ``TARGET_CHANNELS`` in that order. This
    avoids per-file channel differences (e.g. Cz/Pz vs T9/T10) that would
    otherwise change downstream feature dimensionality.

    Channels are selected *before* resampling and filtering. Those operations
    process each channel independently, so the kept channels' data match the
    filter-then-pick order. The benefit is that a file missing required
    channels fails before any expensive processing.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Preloaded recording. It is modified in place (MNE's pick/resample/
        filter methods act in place) and also returned.

    Returns
    -------
    mne.io.BaseRaw
        The same object, now containing only ``TARGET_CHANNELS`` at 200 Hz.

    Raises
    ------
    ValueError
        If a required channel is missing, or several channels map to the same
        required channel.
    RuntimeError
        If the final channel list does not equal ``TARGET_CHANNELS``.
    """
    # Source file name for error messages. It is taken from the Raw object
    # itself so the function does not depend on any notebook-level variable.
    # Raws created in memory may have no file name.
    source_files = [f for f in (getattr(raw, "filenames", None) or ()) if f]
    source = Path(source_files[0]).name if source_files else "<unknown source>"

    # Step 1: Pick the desired channels by *name* (not by position).
    by_canonical = {}
    for ch in raw.ch_names:
        by_canonical.setdefault(_canonicalize_ch_name(ch), []).append(ch)

    desired_actual = []
    missing = []
    ambiguous = {}
    for canon in TARGET_CANONICAL:
        key = canon
        if key not in by_canonical and ALLOW_T9_T10_SUBSTITUTE_FOR_CZ_PZ:
            sub = SUBSTITUTE_MAP.get(key)
            if sub:
                key = sub.upper()
        matches = by_canonical.get(key, [])
        if not matches:
            missing.append(canon)
        elif len(matches) > 1:
            # e.g. both "EEG Fp1-REF" and "Fp1" present: refuse to guess.
            ambiguous[canon] = matches
        else:
            desired_actual.append(matches[0])

    if missing:
        raise ValueError(
            f"{source}: Missing required channels: {missing}. "
            f"Available (canonical): {sorted(by_canonical)}"
        )
    if ambiguous:
        raise ValueError(
            f"{source}: Several channels map to the same required channel: {ambiguous}"
        )

    # Enforce ordering (older MNE versions have no `ordered` argument).
    try:
        raw.pick_channels(desired_actual, ordered=True)
    except TypeError:
        raw.pick_channels(desired_actual)
        raw.reorder_channels(desired_actual)

    # Step 2: Resample to 200 Hz
    raw = raw.resample(200)
    # Step 3: Band-pass filter from 1-70 Hz
    raw = raw.filter(l_freq=1.0, h_freq=70.0)
    # Step 4: Notch filter at 50 Hz
    raw = raw.notch_filter(freqs=50)

    # Rename to the desired *stored* names
    mapping = {old: new for old, new in zip(raw.ch_names, TARGET_CHANNELS) if old != new}
    if mapping:
        raw.rename_channels(mapping)

    if list(raw.ch_names) != list(TARGET_CHANNELS):
        raise RuntimeError(f"{source}: Channel standardization failed. Got: {raw.ch_names}")

    return raw

def _epochs_like_original(raw: mne.io.BaseRaw) -> mne.Epochs:
    """Cut *raw* into consecutive 1 s epochs, exactly as the original pipeline did.

    Events are placed every 1.0 s with ``mne.make_fixed_length_events``. Each
    epoch spans ``tmin=0.0`` to ``tmax=1.0``. MNE treats ``tmax`` as inclusive,
    so an epoch holds one sample more than the 1 s step; this is kept to match
    the original output shape. There is no baseline correction, and the data
    are preloaded. Epochs overlapping annotations whose description starts
    with "bad" are rejected.
    """
    fixed_events = mne.make_fixed_length_events(raw, duration=1.0)
    return mne.Epochs(
        raw,
        fixed_events,
        baseline=None,
        tmin=0.0,
        tmax=1.0,
        reject_by_annotation=True,
        preload=True,
    )

def _annotation_window(raw, start_label: str, end_label: str, name: str):
    """Return the ``(start, end)`` onsets in seconds of an annotation-delimited window.

    Uses the first occurrence of each label (via ``_first_annotation_onset``).
    Raises ``ValueError`` unless ``start < end``.
    """
    start = _first_annotation_onset(raw, start_label)
    end = _first_annotation_onset(raw, end_label)
    if not start < end:
        raise ValueError(f"Invalid {name} window: {start_label}={start} {end_label}={end}")
    return start, end


def preprocess_edf_to_fif(edf_path: Path, out_dir: Path):
    """Convert one EDF into a single labeled ``<stem>_epo.fif`` file.

    The whole recording is preprocessed with ``_preprocess_raw_like_original``
    and cut into consecutive 1 s epochs (``tmin=0``, ``tmax=1``, no baseline,
    epochs overlapping "bad" annotations rejected). Each epoch is labeled:

    - ``'EO'`` if it lies fully within the EO2 → EO3 annotation window,
    - ``'EC'`` if it lies fully within the EC2 → EC3 annotation window,
    - ``'OTHER'`` otherwise.

    Parameters
    ----------
    edf_path : Path
        Input EDF; must contain EO2/EO3/EC2/EC3 annotations.
    out_dir : Path
        Directory receiving ``<edf stem>_epo.fif`` (overwritten if present).

    Returns
    -------
    Path
        Path of the written epochs file.

    Raises
    ------
    ValueError
        The annotation checks raise it if a required annotation is missing, a
        window is inverted, the EO and EC windows overlap, or no full epoch
        fits inside the EO or EC window. Channel standardization can also
        raise it.
    """
    raw = mne.io.read_raw_edf(str(edf_path), preload=True, verbose="ERROR")
    raw = _preprocess_raw_like_original(raw)

    # Debug aid: show the standardized channel list for this file.
    print("  Channels after standardization (n=%d):" % len(raw.ch_names))
    for ch in raw.ch_names:
        print("   -", ch)

    # EO2→EO3 and EC2→EC3 windows from annotation onsets (seconds).
    eo_start, eo_end = _annotation_window(raw, "EO2", "EO3", "EO")
    ec_start, ec_end = _annotation_window(raw, "EC2", "EC3", "EC")
    if max(eo_start, ec_start) < min(eo_end, ec_end):
        # Overlapping windows would make the label of shared epochs ambiguous.
        raise ValueError(
            f"{edf_path.name}: EO window ({eo_start}-{eo_end} s) overlaps "
            f"EC window ({ec_start}-{ec_end} s)"
        )

    # Fixed-length events over the entire recording (same logic as original).
    epoch_sec = 1.0
    events = mne.make_fixed_length_events(raw, duration=epoch_sec)
    sfreq = float(raw.info["sfreq"])

    # Epoch start/end times in seconds, measured from the first sample.
    # NOTE(review): these are compared directly with annotation onsets, which
    # assumes both share the same time origin (first_samp == 0). Confirm this
    # if files come from a reader other than read_raw_edf.
    epoch_starts = (events[:, 0] - raw.first_samp) / sfreq
    epoch_ends = epoch_starts + epoch_sec

    EO_CODE, EC_CODE, OTHER_CODE = 1, 2, 3
    codes = np.full(events.shape[0], OTHER_CODE, dtype=int)
    codes[(epoch_starts >= eo_start) & (epoch_ends <= eo_end)] = EO_CODE
    codes[(epoch_starts >= ec_start) & (epoch_ends <= ec_end)] = EC_CODE

    # Fail with a clear message instead of MNE's generic "No matching events".
    n_eo = int(np.count_nonzero(codes == EO_CODE))
    n_ec = int(np.count_nonzero(codes == EC_CODE))
    if n_eo == 0 or n_ec == 0:
        raise ValueError(
            f"{edf_path.name}: no full {epoch_sec:g} s epoch fits inside the EO "
            f"and/or EC window (EO epochs={n_eo}, EC epochs={n_ec})"
        )

    events_labeled = events.copy()
    events_labeled[:, 2] = codes

    event_id = {"EO": EO_CODE, "EC": EC_CODE, "OTHER": OTHER_CODE}
    epochs = mne.Epochs(
        raw,
        events_labeled,
        event_id=event_id,
        tmin=0.0,
        # tmax is inclusive in MNE (sfreq + 1 samples per epoch); kept to
        # match the original pipeline's output shape.
        tmax=epoch_sec,
        baseline=None,
        preload=True,
        reject_by_annotation=True,
        # EO/EC presence is checked above. 'OTHER' may legitimately have no
        # epochs (if every epoch falls inside a window); keep the same
        # event_id in every file instead of crashing.
        on_missing="ignore",
    )

    out_path = out_dir / f"{edf_path.stem}_epo.fif"
    epochs.save(str(out_path), overwrite=True)
    return out_path

def _list_edf_files(directory: Path) -> list:
    """Return the resolved, de-duplicated ``.edf`` files directly inside *directory*.

    The suffix is matched case-insensitively (``.edf``, ``.EDF``, ``.Edf``...).
    A missing directory yields an empty list.
    """
    if not directory.is_dir():
        return []
    return sorted(
        {p.resolve() for p in directory.iterdir() if p.is_file() and p.suffix.lower() == ".edf"}
    )


def _convert_one(edf_path: Path, failure_log: list, error_label: str):
    """Convert one EDF into ``output_dir``, recording any failure instead of raising.

    Returns the written FIF path on success. On any exception, a record
    ``{"edf", "error", "traceback"}`` is appended to *failure_log*, a one-line
    ``"  <error_label>: <file>: <error>"`` message is printed, and ``None`` is
    returned. This is the per-file batch boundary, so one bad file does not
    stop the run.
    """
    try:
        out_path = preprocess_edf_to_fif(edf_path, output_dir)
        if not out_path.exists():
            raise FileNotFoundError(f"Expected output not found after save: {out_path}")
    except Exception as e:
        failure_log.append(
            {
                "edf": str(edf_path),
                "error": repr(e),
                "traceback": traceback.format_exc(),
            }
        )
        print(f"  {error_label}: {edf_path.name}: {repr(e)}")
        return None
    return out_path


def _sync_failure_report(report_path: Path, records: list, label: str) -> None:
    """Write *records* as JSON to *report_path*, or delete a stale report when empty.

    Deleting a report left over from an earlier run prevents old failures from
    being mistaken for current ones. File-system errors are printed rather than
    raised, so the rest of the notebook still runs.
    """
    try:
        if records:
            report_path.write_text(json.dumps(records, indent=2), encoding="utf-8")
            print(f"\nWrote {label}: {report_path}")
        elif report_path.exists():
            report_path.unlink()
            print(f"\nRemoved stale {label}: {report_path}")
    except OSError as e:
        print(f"\nCould not update {label} ({report_path}): {repr(e)}")


edf_files = _list_edf_files(edf_dir)
print(f"Found {len(edf_files)} EDF files in {edf_dir}")

# If True, keeps the output tidy, but it also hides errors as the notebook runs.
# For debugging missing files, keep this False.
USE_CLEAR_OUTPUT = False

processed = 0
failures = []  # list[dict]: one record per failed EDF (path, error repr, traceback)

# The loop stays at notebook (module) level, as in the original, so `edf_path`
# is also a notebook global while each file is processed.
for i, edf_path in enumerate(edf_files, 1):
    if USE_CLEAR_OUTPUT:
        clear_output(wait=True)
    print(f"[{i}/{len(edf_files)}] Processing {edf_path.name}...")
    out_path = _convert_one(edf_path, failures, "Error")
    if out_path is not None:
        processed += 1
        print(f"  Saved {out_path.name} ({processed} processed so far)")

print(f"\nTotal EDF files successfully processed: {processed} / {len(edf_files)}")

# Summarize failures clearly (and persist a report file for sharing/debugging).
if failures:
    print(f"\nFailed EDF files (n={len(failures)}):")
    for record in failures:
        print(" -", Path(record["edf"]).name)
_sync_failure_report(output_dir / "edf_to_fif_failures.json", failures, "failure report")

# Optional: retry only the failed EDF files once more (useful if the failure was transient).
RETRY_FAILED = True
retry_failures = []
retry_success = 0
if RETRY_FAILED and failures:
    print("\nRetrying failed EDF files...")
    retry_targets = sorted({Path(record["edf"]).resolve() for record in failures})
    for i, edf_path in enumerate(retry_targets, 1):
        print(f"[retry {i}/{len(retry_targets)}] {edf_path.name}...")
        out_path = _convert_one(edf_path, retry_failures, "Retry error")
        if out_path is not None:
            retry_success += 1
            print(f"  Saved {out_path.name}")
    print(f"\nRetry summary: {retry_success} succeeded, {len(retry_failures)} still failing")
_sync_failure_report(
    output_dir / "edf_to_fif_retry_failures.json", retry_failures, "retry failure report"
)

# Cross-check which expected outputs are missing. This runs after the retry so
# it reflects the final state. Note: an output left over from an earlier run
# also counts as present.
expected_outputs = {f"{p.stem}_epo.fif" for p in edf_files}
existing_outputs = {p.name for p in output_dir.glob("*_epo.fif")}
missing_outputs = sorted(expected_outputs - existing_outputs)
if missing_outputs:
    print(f"\nMissing output FIFs (n={len(missing_outputs)}):")
    for name in missing_outputs:
        print(" -", name)
else:
    print("\nAll expected output FIFs are present.")


### Sanity check: print the number of epochs per saved FIF

The whole recording is cut into fixed 1 s epochs, so the total epoch count depends on the recording length. The `EO` and `EC` counts depend on the durations of the EO2→EO3 and EC2→EC3 windows.

In [None]:
epo_files = sorted(output_dir.glob("*_epo.fif"))
n_files = len(epo_files)
print(f"Found {n_files} epoch FIF files in {output_dir}")

# Report, per saved file, the total epoch count and the count per event label.
# Unreadable files are reported and skipped.
for idx, fif_path in enumerate(epo_files, start=1):
    prefix = f"[{idx}/{n_files}] {fif_path.name}"
    try:
        ep = mne.read_epochs(str(fif_path), preload=False, verbose=False)
        event_codes = ep.events[:, 2]
        per_label = {label: int(np.sum(event_codes == code)) for label, code in ep.event_id.items()}
        print(f"{prefix} — Epochs: {len(ep)} — counts: {per_label}")
    except Exception as err:
        print(f"{prefix} — Error: {err}")
