In [1]:
#!/usr/bin/env python3
"""
XDF → (EEG + optional fNIRS) segmentation around RollerCoaster markers,
with EEG preprocessing (MNE) and MAT export per segment.

Design goals:
- easy to read
- minimal global state
- functions with single responsibility
"""

from __future__ import annotations

import os
import re
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import pyxdf
from scipy.io import savemat
from scipy.stats import zscore

import mne
from mne.preprocessing import ICA

In [2]:
# =============================================================================
# Configuration
# =============================================================================

@dataclass(frozen=True)
class Config:
    """
    Settings for one ``run_pipeline`` call: input recording, output location,
    segment length, stream-selection overrides and marker filtering.
    """
    # Path of the .xdf recording to process (the only required field).
    xdf_path: str
    # Directory where the per-segment .mat files are written (created if missing).
    out_dir: str = "segmented_clean"
    # Segment length in seconds, measured from each start marker.
    window_sec: float = 58.0
    # Participant identifier used in the output file names ("P<id>_...").
    participant_id: str = "17"

    # Stream selection (leave as None to auto-detect)
    # Each value is matched exactly against the XDF stream name; an unknown name
    # makes stream selection raise RuntimeError instead of falling back.
    force_eeg_name: Optional[str] = None
    force_fnirs_name: Optional[str] = None
    force_marker_name: Optional[str] = None

    # Marker filtering
    user_id_filter: Optional[int] = None  # None = accept any user
    # Also keep the first "RollerCoasterBaselineStarted ... attempt=0" marker per user.
    keep_first_baseline: bool = True

In [3]:
# =============================================================================
# Small helpers: stream metadata
# =============================================================================

def stream_name(stream: Dict[str, Any]) -> str:
    """Return the XDF ``info/name`` of *stream*, or ``""`` when the entry is absent."""
    info = stream.get("info", {})
    names = info.get("name", [""])
    return names[0]

def stream_type(stream: Dict[str, Any]) -> str:
    """Return the XDF ``info/type`` of *stream*, or ``""`` when the entry is absent."""
    info = stream.get("info", {})
    types = info.get("type", [""])
    return types[0]

def get_channel_labels(stream: Optional[Dict[str, Any]]) -> Optional[List[str]]:
    """
    Read the channel labels from an XDF stream header.

    Returns None when *stream* is None or when the header does not have the
    ``info/desc/channels/channel`` layout. A channel without a label yields "".
    """
    if stream is None:
        return None
    try:
        channel_entries = stream["info"]["desc"][0]["channels"][0]["channel"]
        labels = []
        for entry in channel_entries:
            labels.append(entry.get("label", [""])[0])
        return labels
    except Exception:
        return None


# =============================================================================
# XDF stream selection
# =============================================================================

def pick_stream(
    streams: Sequence[Dict[str, Any]],
    force_name: Optional[str],
    kind: str,  # "markers" | "eeg" | "fnirs"
) -> Optional[Dict[str, Any]]:
    """
    Select the stream to use for ``kind``.

    - If ``force_name`` is given, return the stream with exactly that name, or raise
      RuntimeError when no stream has that name.
    - Otherwise match by heuristic on stream type / name:
        markers: type "MARKERS" (case-insensitive) or "marker" in the name
        eeg:     type "EEG" or the word "eeg" in the name
        fnirs:   type "NIRS"/"FNIRS" or "fnirs"/"nirs"/"nirstar" in the name
      Streams that actually contain samples are preferred over empty ones (XDF files
      can contain outlets with zero samples, e.g. unused marker streams). Among the
      preferred streams, the first in file order wins. When several non-empty
      streams match, a warning lists them so the choice can be pinned via
      ``force_name``.

    Returns:
        The selected stream dict, or None if nothing matches (including an unknown
        ``kind``).
    """
    if force_name:
        for s in streams:
            if stream_name(s) == force_name:
                return s
        raise RuntimeError(f"Could not find {kind} stream with name='{force_name}'")

    def matches(s: Dict[str, Any]) -> bool:
        name_l = stream_name(s).lower()
        stype_u = stream_type(s).upper()
        if kind == "markers":
            return stype_u == "MARKERS" or "marker" in name_l
        if kind == "eeg":
            return stype_u == "EEG" or re.search(r"\beeg\b", name_l) is not None
        if kind == "fnirs":
            return stype_u in {"NIRS", "FNIRS"} or re.search(r"(fnirs|nirs|nirstar)", name_l) is not None
        return False

    candidates = [s for s in streams if matches(s)]
    if not candidates:
        return None

    non_empty = [s for s in candidates if len(s.get("time_stamps", [])) > 0]
    if len(non_empty) > 1:
        # Auto-detection is ambiguous: make the choice visible so results stay
        # comparable across recordings (pin it with force_name if needed).
        names = ", ".join(f"'{stream_name(s)}'" for s in non_empty)
        print(
            f"⚠ Several {kind} streams match ({names}); using "
            f"'{stream_name(non_empty[0])}'. Set the force_*_name option to choose explicitly."
        )
    return non_empty[0] if non_empty else candidates[0]

def print_streams(streams: Sequence[Dict[str, Any]]) -> None:
    """Print one summary line (index, name, type, sample count) per XDF stream."""
    print("\nStreams found:")
    for idx, st in enumerate(streams):
        n_samples = len(st.get("time_stamps", []))
        print(f"  [{idx}] name='{stream_name(st)}' type='{stream_type(st)}' samples={n_samples}")


# =============================================================================
# Marker parsing
# =============================================================================

def flatten_marker_labels(markers_time_series: Sequence[Any]) -> List[str]:
    """
    Turn a marker stream's ``time_series`` into one string label per sample.

    LabRecorder string markers usually look like ``[['label'], ['label2'], ...]``;
    numeric marker streams arrive as 2-D numpy arrays. Each sample is handled on its
    own:
    - list / tuple / ndarray sample -> ``str`` of its first value ("" if empty)
    - anything else                 -> ``str(sample)``

    Args:
        markers_time_series: list-like or numpy array of marker samples (may be None).

    Returns:
        List of labels, same length as the input ([] for None / empty input).
    """
    # len() instead of truthiness: `not ndarray` raises for multi-element arrays.
    if markers_time_series is None or len(markers_time_series) == 0:
        return []

    labels: List[str] = []
    for sample in markers_time_series:
        if isinstance(sample, (list, tuple, np.ndarray)):
            labels.append(str(sample[0]) if len(sample) > 0 else "")
        else:
            labels.append(str(sample))
    return labels


@dataclass(frozen=True)
class StartEvent:
    """One segment start marker selected by ``extract_start_events``."""
    # Marker timestamp (from the marker stream's time_stamps); segments start here.
    t0: float
    # Full marker label text.
    label: str
    # Integer parsed from "user=<n>" in the label, or None when absent.
    user_id: Optional[int]
    # Integer parsed from "attempt=<n>" in the label, or None when absent.
    attempt: Optional[int]

def extract_start_events(
    markers_stream: Dict[str, Any],
    user_id_filter: Optional[int],
    keep_first_baseline: bool,
) -> List[StartEvent]:
    """
    Build the time-sorted list of segment start events from a marker stream.

    Kept markers:
    - every label starting with ``RollerCoasterStarted`` (case-insensitive);
    - if ``keep_first_baseline`` is true, only the first
      ``RollerCoasterBaselineStarted`` label carrying ``attempt=0`` per user id
      (first in stream order).

    ``user=<n>`` and ``attempt=<n>`` are parsed from the label text. When
    ``user_id_filter`` is not None, markers whose parsed user id differs from it
    (including markers without a user id) are dropped.

    TODO : add final baseline for this as well
    """
    ride_start_re = re.compile(r"^RollerCoasterStarted\b", re.IGNORECASE)
    baseline_re = re.compile(r"^RollerCoasterBaselineStarted\b", re.IGNORECASE)
    attempt_re = re.compile(r"\battempt=(\d+)\b", re.IGNORECASE)
    user_re = re.compile(r"\buser=(\d+)\b", re.IGNORECASE)

    def int_field(pattern, text):
        found = pattern.search(text)
        return None if found is None else int(found.group(1))

    timestamps = np.asarray(markers_stream["time_stamps"])
    labels = flatten_marker_labels(markers_stream["time_series"])

    baseline_seen_for = set()
    selected: List[StartEvent] = []

    for stamp, text in zip(timestamps, labels):
        if not isinstance(text, str):
            continue

        user = int_field(user_re, text)
        if user_id_filter is not None and user != user_id_filter:
            continue
        attempt_no = int_field(attempt_re, text)

        if ride_start_re.search(text):
            selected.append(StartEvent(float(stamp), text, user, attempt_no))
        elif (
            keep_first_baseline
            and baseline_re.search(text)
            and attempt_no == 0
            and user not in baseline_seen_for
        ):
            selected.append(StartEvent(float(stamp), text, user, attempt_no))
            baseline_seen_for.add(user)

    return sorted(selected, key=lambda event: event.t0)

In [4]:
# =============================================================================
# Data slicing utilities
# =============================================================================

def slice_stream_by_time(stream: Dict[str, Any], t0: float, t1: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return ``(data, times)`` for the samples whose timestamp lies in [t0, t1].

    Both bounds are inclusive. Rows of ``time_series`` are selected with the same
    boolean mask as ``time_stamps``, so samples stay paired with their times.
    """
    ts = np.asarray(stream["time_stamps"])
    values = np.asarray(stream["time_series"])
    in_window = np.logical_and(ts >= t0, ts <= t1)
    return values[in_window], ts[in_window]

In [5]:
# =============================================================================
# EEG channel separation (EEG vs AUX signals)
# =============================================================================

@dataclass
class SeparatedSignals:
    """
    Channel groups produced by ``separate_eeg_and_aux``.

    Each field is a bucket dict with keys "channels" (column indices in the source
    stream), "labels" and "data". Non-empty buckets also carry "time_stamps" and
    "sampling_rate", and their "data" is a (samples x channels) array; empty
    buckets keep "data" as an empty list.
    """
    # Non-AUX channels.
    eeg: Dict[str, Any]
    # AUX channel assigned to EDA/GSR.
    eda: Dict[str, Any]
    # AUX channel assigned to PPG.
    ppg: Dict[str, Any]
    # AUX channel assigned to RESP.
    resp: Dict[str, Any]

def separate_eeg_and_aux(eeg_stream: Dict[str, Any]) -> SeparatedSignals:
    """
    Split the channels of an XDF EEG stream into EEG (non-AUX) and AUX signals.

    AUX channels are those whose label contains "AUX" (case-insensitive). They are
    mapped to EDA / PPG / RESP in two passes:

    1. Unit pass: the first AUX channel whose unit looks like siemens
       (contains "SIEMENS", or is exactly "uS"/"µS") becomes EDA, regardless of its
       position among the AUX channels.
    2. Order pass: the remaining AUX channels, in stream order, fill the slots that
       are still empty, in the order EDA -> PPG -> RESP. Extra AUX channels are left
       unassigned.

    Every bucket is a dict with "channels" (stream column indices), "labels" and
    "data". Non-empty buckets also get "time_stamps" and "sampling_rate", and their
    "data" becomes a (samples x channels) array.

    Returns:
        SeparatedSignals(eeg, eda, ppg, resp).
    """
    channel_info = eeg_stream["info"]["desc"][0]["channels"][0]["channel"]
    channel_labels = [ch["label"][0] for ch in channel_info]
    channel_units = [ch.get("unit", ["N/A"])[0] for ch in channel_info]

    time_series = np.asarray(eeg_stream["time_series"])
    time_stamps = np.asarray(eeg_stream["time_stamps"])
    sampling_rate = float(eeg_stream["info"]["nominal_srate"][0])

    def empty_bucket() -> Dict[str, Any]:
        return {"channels": [], "labels": [], "data": []}

    def add_channel(bucket: Dict[str, Any], idx: int, label: str) -> None:
        bucket["channels"].append(idx)
        bucket["labels"].append(label)
        bucket["data"].append(time_series[:, idx])

    def is_siemens_unit(unit: Any) -> bool:
        text = str(unit).strip()
        # Whole-unit match for the short forms: a substring test for "US" would
        # also hit unrelated units such as "CELSIUS".
        return "SIEMENS" in text.upper() or text.lower() in {"us", "µs", "μs"}

    eeg = empty_bucket()
    eda = empty_bucket()
    ppg = empty_bucket()
    resp = empty_bucket()

    aux_channels = [
        (i, lab, channel_units[i]) for i, lab in enumerate(channel_labels) if "AUX" in lab.upper()
    ]

    print(f"\nTotal channels found: {len(channel_labels)}")
    print(f"Sampling rate: {sampling_rate} Hz")
    print(f"Found {len(aux_channels)} AUX channels.")

    # Pass 1: unit-based EDA detection over ALL AUX channels, so an earlier
    # unit-less AUX channel cannot claim the EDA slot "by order" first.
    eda_by_unit = next((ch for ch in aux_channels if is_siemens_unit(ch[2])), None)
    if eda_by_unit is not None:
        idx, label, unit = eda_by_unit
        add_channel(eda, idx, label)
        print(f"  {label} (unit: {unit}) → EDA/GSR (by unit)")

    # Pass 2: order fallback for whatever slots are still open.
    open_slots = [
        (slot_name, bucket)
        for slot_name, bucket in (("EDA/GSR", eda), ("PPG", ppg), ("RESP", resp))
        if not bucket["channels"]
    ]
    remaining = [ch for ch in aux_channels if eda_by_unit is None or ch[0] != eda_by_unit[0]]
    for (idx, label, unit), (slot_name, bucket) in zip(remaining, open_slots):
        add_channel(bucket, idx, label)
        print(f"  {label} (unit: {unit}) → {slot_name} (by order)")

    # Non-AUX = EEG
    for idx, label in enumerate(channel_labels):
        if "AUX" not in label.upper():
            add_channel(eeg, idx, label)

    def finalize(bucket: Dict[str, Any]) -> Dict[str, Any]:
        if bucket["data"]:
            bucket["data"] = np.array(bucket["data"]).T  # samples x channels
            bucket["time_stamps"] = time_stamps
            bucket["sampling_rate"] = sampling_rate
        return bucket

    return SeparatedSignals(
        eeg=finalize(eeg),
        eda=finalize(eda),
        ppg=finalize(ppg),
        resp=finalize(resp),
    )

In [6]:
# =============================================================================
# EEG preprocessing (MNE)
# =============================================================================

def preprocess_eeg_mne(
    eeg_data: Dict[str, Any],
    highpass_hz: float = 0.5,
    notch_hz: float = 50.0,
    bad_z_threshold: float = 3.0,
    ica_highpass_hz: Optional[float] = 1.0,
    eog_channels: Sequence[str] = ("Fp1", "Fp2"),
) -> mne.io.Raw:
    """
    Minimal, readable EEG preprocessing pipeline built on MNE.

    Steps, in order:
    1. set the ``standard_1020`` montage (best-effort)
    2. high-pass at ``highpass_hz`` and notch at ``notch_hz``
    3. flag channels whose variance z-score exceeds ``bad_z_threshold`` in absolute
       value; interpolate them when the montage provided sensor positions,
       otherwise keep them flagged in ``info["bads"]``
    4. average reference (done after step 3 so a noisy channel does not leak into
       every other channel through the reference)
    5. ICA (seeded, random_state=42). It is fitted on a copy high-passed at
       ``ica_highpass_hz`` (MNE's recommendation for stable decompositions) and
       applied to the data from step 4. Components matching ``eog_channels`` are
       removed.

    Args:
        eeg_data: bucket dict with "labels", "data" (samples x channels) and
            "sampling_rate", as produced for the EEG group by separate_eeg_and_aux.
        highpass_hz, notch_hz: filter settings applied to the returned data.
        bad_z_threshold: |z| above which a channel's variance marks it bad.
        ica_highpass_hz: high-pass used only for the ICA fit; None (or a value not
            above ``highpass_hz``) fits on the data as-is.
        eog_channels: frontal channels used as EOG proxies; missing ones are ignored.

    Returns:
        The cleaned ``mne.io.Raw``.

    Raises:
        RuntimeError: if there are no EEG channels to preprocess.
    """
    if "data" not in eeg_data or eeg_data["data"] is None or len(eeg_data["labels"]) == 0:
        raise RuntimeError("No EEG channels found to preprocess.")

    ch_names = list(eeg_data["labels"])
    sfreq = float(eeg_data["sampling_rate"])
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * len(ch_names))
    # Buckets store samples x channels; MNE expects channels x samples.
    # NOTE(review): values are passed through unscaled while MNE assumes volts for
    # EEG — confirm the stream's unit (its AUX channels report microvolts).
    raw = mne.io.RawArray(np.asarray(eeg_data["data"]).T, info)

    print("\n" + "=" * 70)
    print("EEG PREPROCESSING")
    print("=" * 70)
    print(f"Raw created: {raw.n_times} samples, {raw.info['nchan']} channels")

    # Montage (best-effort). Bad-channel interpolation needs sensor positions,
    # so remember whether this succeeded.
    try:
        raw.set_montage("standard_1020")
        has_positions = True
        print("✓ Montage: standard_1020")
    except Exception:
        has_positions = False
        print("⚠ Montage not set (channel names may not match 10-20).")

    # Filters
    raw.filter(l_freq=highpass_hz, h_freq=None, verbose=False)
    raw.notch_filter(freqs=notch_hz, verbose=False)
    print(f"✓ Filters: high-pass {highpass_hz:g} Hz, notch {notch_hz:g} Hz")

    # Bad channel detection (variance z-score), BEFORE re-referencing.
    channel_vars = np.var(raw.get_data(), axis=1)
    z_scores = zscore(channel_vars)
    # abs(z) > thr is False for NaN z-scores (all variances equal) -> no bads.
    bads = [ch for ch, z in zip(raw.ch_names, z_scores) if abs(z) > bad_z_threshold]

    if not bads:
        print("✓ Bad channels: none detected")
    else:
        raw.info["bads"].extend(bads)
        if has_positions:
            print(f"✓ Bad channels: {bads} (interpolating)")
            raw.interpolate_bads(reset_bads=True)
        else:
            # interpolate_bads would fail without positions; keep the channels
            # flagged so the average reference and the ICA fit leave them out.
            print(f"⚠ Bad channels: {bads} (kept as bad — no montage, cannot interpolate)")

    # Re-reference (channels still listed in info["bads"] are not used for the average)
    raw.set_eeg_reference("average", verbose=False)
    print("✓ Reference: average")

    # ICA: the average reference removes one degree of freedom, hence n_good - 1.
    n_good = len(raw.ch_names) - len(raw.info["bads"])
    n_components = min(20, n_good - 1)
    if n_components < 2:
        print(f"⚠ ICA skipped: only {n_good} good channel(s)")
        return raw

    ica = ICA(n_components=n_components, random_state=42, max_iter=800)
    if ica_highpass_hz is not None and ica_highpass_hz > highpass_hz:
        # Slow drifts degrade the ICA decomposition; fit on a more strongly
        # high-passed copy and apply the unmixing to the original data below.
        raw_for_fit = raw.copy().filter(l_freq=ica_highpass_hz, h_freq=None, verbose=False)
    else:
        raw_for_fit = raw
    ica.fit(raw_for_fit)
    del raw_for_fit  # free the filtered copy before apply()
    print(f"✓ ICA fitted ({n_components} components)")

    # EOG detection (best-effort), using only the proxy channels that exist.
    present_eog = [ch for ch in eog_channels if ch in raw.ch_names]
    if not present_eog:
        print(f"⚠ EOG auto-detection skipped: none of {list(eog_channels)} in data")
    else:
        try:
            eog_idx, _scores = ica.find_bads_eog(raw, ch_name=present_eog, threshold=2.5)
            if len(eog_idx) > 0:
                ica.exclude = list(eog_idx)
                print(f"✓ ICA exclude (EOG): {ica.exclude}")
            else:
                print("✓ ICA exclude (EOG): none")
        except Exception as e:
            print(f"⚠ EOG auto-detection failed: {e}")

    raw = ica.apply(raw)
    print(f"✓ ICA applied (removed {len(ica.exclude)} components)")
    return raw

In [7]:
# =============================================================================
# Segmentation + MAT saving
# =============================================================================

def build_clean_eeg_stream(raw: mne.io.Raw, eeg_stream: Dict[str, Any]) -> Dict[str, Any]:
    """
    Wrap preprocessed EEG in a minimal stream dict usable by slice_stream_by_time().

    The samples come from ``raw.get_data()`` (transposed to samples x channels). The
    timestamps are the ORIGINAL EEG stream's, so segmentation stays on the recording
    clock. The info block names the stream "EEG_CLEAN" with type "EEG".
    """
    cleaned = raw.get_data()  # MNE layout: channels x samples
    original_stamps = np.asarray(eeg_stream["time_stamps"])
    meta = {"name": ["EEG_CLEAN"], "type": ["EEG"]}
    return dict(time_series=cleaned.T, time_stamps=original_stamps, info=meta)

def save_segments(
    starts: Sequence[StartEvent],
    eeg_clean_stream: Dict[str, Any],
    eeg_labels: Optional[List[str]],
    fnirs_stream: Optional[Dict[str, Any]],
    fnirs_labels: Optional[List[str]],
    out_dir: str,
    participant_id: str,
    window_sec: float,
) -> None:
    """
    Write one compressed .mat file per start event covering [t0, t0 + window_sec].

    Each file contains:
    - scalar metadata: segment_index, attempt, user_id (-1 when missing),
      t_start, t_end
    - the marker label
    - an "EEG" struct (data, t, chan_labels, stream_name, stream_type)
    - when ``fnirs_stream`` is given, an equivalent "fNIRS" struct.

    File names are ``P<participant_id>_round_<attempt:02d>.mat`` for events with an
    attempt number and ``P<participant_id>_segment_<index:02d>.mat`` otherwise. If two
    events in this call map to the same name, the later one gets a ``_seg<index>``
    suffix instead of silently overwriting the earlier file.

    Raises:
        RuntimeError: if ``starts`` is empty.
    """
    if not starts:
        raise RuntimeError("No start markers found. Cannot segment.")

    os.makedirs(out_dir, exist_ok=True)
    print(f"\nFound {len(starts)} start markers. Segmenting {window_sec:.1f}s after each...")

    used_fnames = set()

    for seg_idx, ev in enumerate(starts, start=1):
        t0 = ev.t0
        t1 = t0 + window_sec

        eeg_seg, eeg_t = slice_stream_by_time(eeg_clean_stream, t0, t1)
        if len(eeg_t) == 0:
            print(f"⚠ Segment {seg_idx} ('{ev.label}') has no EEG samples in [{t0:.3f}, {t1:.3f}]")

        mdict: Dict[str, Any] = {
            "segment_index": np.array([[seg_idx]]),
            "attempt": np.array([[ev.attempt if ev.attempt is not None else -1]]),
            "user_id": np.array([[ev.user_id if ev.user_id is not None else -1]]),
            "t_start": np.array([[t0]]),
            "t_end": np.array([[t1]]),
            "marker": np.array([ev.label], dtype=object),
            "EEG": {
                "data": eeg_seg,
                "t": eeg_t,
                "chan_labels": np.array(eeg_labels or [], dtype=object),
                "stream_name": np.array([stream_name(eeg_clean_stream)], dtype=object),
                "stream_type": np.array([stream_type(eeg_clean_stream)], dtype=object),
            },
        }

        if fnirs_stream is not None:
            fnirs_seg, fnirs_t = slice_stream_by_time(fnirs_stream, t0, t1)
            mdict["fNIRS"] = {
                "data": fnirs_seg,
                "t": fnirs_t,
                "chan_labels": np.array(fnirs_labels or [], dtype=object),
                "stream_name": np.array([stream_name(fnirs_stream)], dtype=object),
                "stream_type": np.array([stream_type(fnirs_stream)], dtype=object),
            }

        if ev.attempt is not None:
            fname = f"P{participant_id}_round_{ev.attempt:02d}.mat"
        else:
            fname = f"P{participant_id}_segment_{seg_idx:02d}.mat"

        if fname in used_fnames:
            # Several events can share an attempt number — e.g. the attempt=0
            # baseline and a ride start that also carries attempt=0, or rides of
            # different users. Disambiguate rather than overwrite the earlier file.
            stem, ext = os.path.splitext(fname)
            fname = f"{stem}_seg{seg_idx:02d}{ext}"
            print(f"⚠ File name collision for marker '{ev.label}' → saving as {fname}")
        used_fnames.add(fname)

        out_path = os.path.join(out_dir, fname)
        savemat(out_path, mdict, do_compression=True)

        msg = f"Saved {out_path} | EEG samples={len(eeg_t)}"
        if fnirs_stream is not None:
            msg += f" | fNIRS samples={len(fnirs_t)}"
        print(msg)

In [8]:
# =============================================================================
# Main
# =============================================================================

def run_pipeline(cfg: Config) -> None:
    """
    Load ``cfg.xdf_path``, preprocess its EEG and write one .mat file per start marker.

    Steps:
    1. pick the marker, EEG and (optional) fNIRS streams
    2. extract start events
    3. split EEG from AUX channels
    4. run MNE preprocessing
    5. slice ``cfg.window_sec`` seconds of EEG (and fNIRS) after each start and
       save the segments to ``cfg.out_dir``.

    If no start markers are found, the unique marker labels are printed and the
    function returns without writing anything.

    Raises:
        FileNotFoundError: the XDF file does not exist.
        RuntimeError: the marker or EEG stream cannot be found, or the cleaned EEG
            no longer lines up sample-for-sample with the original timestamps.
    """
    if not os.path.exists(cfg.xdf_path):
        raise FileNotFoundError(f"XDF file not found: {cfg.xdf_path}")

    print(f"Loading XDF: {cfg.xdf_path}")
    streams, _header = pyxdf.load_xdf(cfg.xdf_path)
    print_streams(streams)

    markers_stream = pick_stream(streams, cfg.force_marker_name, "markers")
    eeg_stream = pick_stream(streams, cfg.force_eeg_name, "eeg")
    fnirs_stream = pick_stream(streams, cfg.force_fnirs_name, "fnirs")

    if markers_stream is None:
        raise RuntimeError("Could not detect marker stream. Set force_marker_name.")
    if eeg_stream is None:
        raise RuntimeError("Could not detect EEG stream. Set force_eeg_name.")

    print(f"\nUsing marker stream: name='{stream_name(markers_stream)}' type='{stream_type(markers_stream)}'")
    print(f"Using EEG stream:    name='{stream_name(eeg_stream)}' type='{stream_type(eeg_stream)}'")
    if fnirs_stream is not None:
        print(f"Using fNIRS stream:  name='{stream_name(fnirs_stream)}' type='{stream_type(fnirs_stream)}'")
    else:
        print("fNIRS stream:        NOT FOUND (EEG only)")

    # Markers → events
    starts = extract_start_events(
        markers_stream,
        user_id_filter=cfg.user_id_filter,
        keep_first_baseline=cfg.keep_first_baseline,
    )

    if not starts:
        # Debug: show unique markers
        uniq = sorted(set(flatten_marker_labels(markers_stream["time_series"])))
        print("\nNo relevant start markers found.")
        print("Unique markers (first 60):")
        for u in uniq[:60]:
            print(" ", u)
        return

    # Separate EEG vs AUX.
    # NOTE(review): separated.eda / .ppg / .resp are not exported yet; only the EEG
    # bucket continues down the pipeline.
    separated = separate_eeg_and_aux(eeg_stream)

    # Preprocess EEG
    raw_clean = preprocess_eeg_mne(separated.eeg)

    # Build clean stream for slicing (uses original timestamps)
    eeg_clean_stream = build_clean_eeg_stream(raw_clean, eeg_stream)
    n_clean = eeg_clean_stream["time_series"].shape[0]
    n_stamps = len(eeg_clean_stream["time_stamps"])
    if n_clean != n_stamps:
        raise RuntimeError(
            f"Cleaned EEG has {n_clean} samples but the EEG stream has {n_stamps} timestamps."
        )

    # Channel labels must describe the columns of the CLEANED data (non-AUX channels
    # only). The original stream header also lists the AUX channels, so its labels
    # would not line up with the saved EEG columns.
    eeg_labels = list(raw_clean.ch_names)
    fnirs_labels = get_channel_labels(fnirs_stream)
    save_segments(
        starts=starts,
        eeg_clean_stream=eeg_clean_stream,
        eeg_labels=eeg_labels,
        fnirs_stream=fnirs_stream,
        fnirs_labels=fnirs_labels,
        out_dir=cfg.out_dir,
        participant_id=cfg.participant_id,
        window_sec=cfg.window_sec,
    )

    print("\nDone.")

In [9]:
# Participant 17: marker stream pinned to "Game_Markers"; EEG and fNIRS are auto-detected.
# NOTE(review): auto-detection chose Photon_Cap_C2022044_STATS here but
# Photon_Cap_C2022044_RAW for participant 10. Pin force_fnirs_name if fNIRS
# segments must be comparable across participants.
cfg = Config(
    xdf_path="./Data/17_Mireia/sub-P17_ses-S001_task-Default_run-001_eeg.xdf",
    participant_id="17",
    out_dir="./ISC_Analysis/Data",
    window_sec=58.0,
    # stream selection — None means auto-detect
    force_marker_name="Game_Markers",
    force_eeg_name=None,
    force_fnirs_name=None,
    # marker filtering — None keeps markers from every user (e.g. 99 keeps only user 99)
    user_id_filter=None,
    keep_first_baseline=True,
)
run_pipeline(cfg)


Loading XDF: ./Data/17_Mireia/sub-P17_ses-S001_task-Default_run-001_eeg.xdf

Streams found:
  [0] name='Photon_Cap_C2022044_STATS' type='NIRS' samples=8612
  [1] name='Photon_Cap_C2022044_RAW' type='NIRS' samples=8612
  [2] name='cortivision_markers_mirror' type='Markers' samples=0
  [3] name='actiCHampMarkers-24020270' type='Markers' samples=0
  [4] name='actiCHamp-24020270' type='EEG' samples=554000
  [5] name='Game_Markers' type='Markers' samples=32
  [6] name='FMS_Score' type='Survey' samples=10
  [7] name='Coaster' type='Object' samples=64481
  [8] name='HMD_MotionData' type='VR' samples=64480

Using marker stream: name='Game_Markers' type='Markers'
Using EEG stream:    name='actiCHamp-24020270' type='EEG'
Using fNIRS stream:  name='Photon_Cap_C2022044_STATS' type='NIRS'

Total channels found: 35
Sampling rate: 500.0 Hz
Found 3 AUX channels.
  AUX_1 (unit: microvolts) → EDA/GSR (by order)
  AUX_2 (unit: microvolts) → PPG (by order)
  AUX_3 (unit: microvolts) → RESP (by order)
Crea

In [10]:
# Participant 10: marker stream pinned to "Game_Markers"; EEG and fNIRS are auto-detected.
# NOTE(review): auto-detection chose Photon_Cap_C2022044_RAW here but
# Photon_Cap_C2022044_STATS for participant 17. Pin force_fnirs_name if fNIRS
# segments must be comparable across participants.
cfg = Config(
    xdf_path="./Data/10_Javad/Subject10_Javad_Javad.xdf",
    participant_id="10",
    out_dir="./ISC_Analysis/Data",
    window_sec=58.0,
    # stream selection — None means auto-detect
    force_marker_name="Game_Markers",
    force_eeg_name=None,
    force_fnirs_name=None,
    # marker filtering — None keeps markers from every user (e.g. 99 keeps only user 99)
    user_id_filter=None,
    keep_first_baseline=True,
)
run_pipeline(cfg)


Loading XDF: ./Data/10_Javad/Subject10_Javad_Javad.xdf

Streams found:
  [0] name='Photon_Cap_C2022044_RAW' type='NIRS' samples=12200
  [1] name='Photon_Cap_C2022044_STATS' type='NIRS' samples=12200
  [2] name='cortivision_markers_mirror' type='Markers' samples=0
  [3] name='actiCHampMarkers-24020270' type='Markers' samples=0
  [4] name='actiCHamp-24020270' type='EEG' samples=793744
  [5] name='Coaster' type='Object' samples=89166
  [6] name='FMS_Score' type='Survey' samples=24
  [7] name='Game_Markers' type='Markers' samples=30
  [8] name='HMD_MotionData' type='VR' samples=89166
  [9] name='Shimmer_8462' type='Sensor_Data' samples=196417

Using marker stream: name='Game_Markers' type='Markers'
Using EEG stream:    name='actiCHamp-24020270' type='EEG'
Using fNIRS stream:  name='Photon_Cap_C2022044_RAW' type='NIRS'

Total channels found: 35
Sampling rate: 500.0 Hz
Found 3 AUX channels.
  AUX_1 (unit: microvolts) → EDA/GSR (by order)
  AUX_2 (unit: microvolts) → PPG (by order)
  AUX_3 (u

In [11]:
# Participant 99: marker stream pinned to "Game_Markers"; EEG is auto-detected.
# This recording has no fNIRS stream, so the run is EEG-only. Leave
# force_fnirs_name=None here: forcing a name that does not exist raises RuntimeError.
cfg = Config(
    xdf_path="./Data/99_Sebastian/sub-P99_ses-S001_task-Default_run-001_eeg.xdf",
    participant_id="99",
    out_dir="./ISC_Analysis/Data",
    window_sec=58.0,
    # stream selection — None means auto-detect
    force_marker_name="Game_Markers",
    force_eeg_name=None,
    force_fnirs_name=None,
    # marker filtering — None keeps markers from every user (e.g. 99 keeps only user 99)
    user_id_filter=None,
    keep_first_baseline=True,
)
run_pipeline(cfg)


Loading XDF: ./Data/99_Sebastian/sub-P99_ses-S001_task-Default_run-001_eeg.xdf

Streams found:
  [0] name='actiCHampMarkers-24020270' type='Markers' samples=0
  [1] name='actiCHamp-24020270' type='EEG' samples=447032
  [2] name='Coaster' type='Object' samples=51946
  [3] name='Game_Markers' type='Markers' samples=29
  [4] name='FMS_Score' type='Survey' samples=57
  [5] name='HMD_MotionData' type='VR' samples=51946
  [6] name='Shimmer_8462' type='Sensor_Data' samples=114435

Using marker stream: name='Game_Markers' type='Markers'
Using EEG stream:    name='actiCHamp-24020270' type='EEG'
fNIRS stream:        NOT FOUND (EEG only)

Total channels found: 35
Sampling rate: 500.0 Hz
Found 3 AUX channels.
  AUX_1 (unit: microvolts) → EDA/GSR (by order)
  AUX_2 (unit: microvolts) → PPG (by order)
  AUX_3 (unit: microvolts) → RESP (by order)
Creating RawArray with float64 data, n_channels=32, n_times=447032
    Range : 0 ... 447031 =      0.000 ...   894.062 secs
Ready.

EEG PREPROCESSING
Raw cr