In [2]:
!pip install mne

Defaulting to user installation because normal site-packages is not writeable


[notice] A new release of pip is available: 25.1.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip





In [2]:
#!/usr/bin/env python3
"""
Optimized 3-Channel Event-Locked Feature Extraction (FIXED)
- Voltage scaled to microvolts before feature extraction
- Fractal dimension normalized to prevent scale explosion
- Enhanced artifact rejection (amplitude + kurtosis + edge effects)
- Faster feature computation with vectorization
"""

import numpy as np
import pandas as pd
import mne
from pathlib import Path
from scipy.signal import detrend, welch, hilbert
from scipy.stats import entropy as scipy_entropy, kurtosis
from mne.filter import filter_data
import warnings
warnings.filterwarnings('ignore')

# === CONFIG ===
# Root folder of the recordings; process_session() looks for
# DATA_ROOT / <subject> / <subject> / <session> / "eeg".
DATA_ROOT = Path(r"C:\Users\rapol\Downloads\manifold\subjects")
# Destination folder for the per-session CSV files (created on import if missing).
SAVE_DIR = Path(r"C:\Users\rapol\Downloads\eeg_features_3ch_event_locked_optimized")
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Session folder names iterated for every subject in main().
SESSIONS = ["ses-S1", "ses-S2", "ses-S3"]
# Channels kept for feature extraction; a recording is skipped by the
# extractors when fewer than three of these are present.
MUSE_CHANNELS = ['Fp1', 'Fp2', 'TP10']
# Epoch windows in seconds relative to the event onset. Only the start of the
# baseline window and the end of the task window are read by extract_discrete().
BASELINE_WINDOW = (-0.5, 0.0)
TASK_WINDOW = (0.0, 2.0)
# Percentile of |data| over the whole recording used as the amplitude threshold.
PERCENTILE = 95
# Per-channel kurtosis (scipy.stats.kurtosis default) above which a trial is rejected.
KURTOSIS_THRESH = 5.0
# Margin, in seconds, excluded at both ends of a recording.
EDGE_BUFFER = 0.5  # seconds

# Frequency bands as (low, high) in Hz; used both as Welch PSD masks and as
# band-pass edges in compute_features_batch().
BANDS = {
    'delta': (0.5, 4),
    'theta': (4, 8),
    'alpha': (8, 13),
    'beta': (13, 30),
    'gamma': (30, 45)
}

# File stems handled by extract_discrete(); every other stem in TASK_MAPPINGS
# is handled by extract_continuous().
DISCRETE_TASKS = [
    'zeroBACK','oneBACK','twoBACK',
    'PVT','Flanker',
    'RS_Beg_EO','RS_Beg_EC','RS_End_EO','RS_End_EC'
]
# NOTE(review): not referenced anywhere in the visible code; the continuous
# path is selected by "not in DISCRETE_TASKS".
CONTINUOUS_TASKS = ['MATBeasy','MATBmed','MATBdiff']

# Maps the "<stem>.set" file stem to the task label written to the CSV.
TASK_MAPPINGS = {
    "zeroBACK":"nback_0","oneBACK":"nback_1","twoBACK":"nback_2",
    "MATBeasy":"matb_easy","MATBmed":"matb_med","MATBdiff":"matb_diff",
    "PVT":"pvt","Flanker":"flanker",
    "RS_Beg_EO":"rest_begin_open","RS_Beg_EC":"rest_begin_closed",
    "RS_End_EO":"rest_end_open","RS_End_EC":"rest_end_closed"
}

print(f"Output: {SAVE_DIR.resolve()}\n")

# === UTILITIES ===

def _autoscale(data):
    """Autoscale data already in microvolts."""
    m = np.median(np.abs(data))
    if m < 1.0:  # Already scaled to µV, just return
        return data, 1.0
    elif m < 1000.0:
        return data, 1.0
    return data, 1.0

def _artifact_reject(baseline_data, task_data, thresh_amp):
    """Multi-criterion artifact rejection."""
    if baseline_data.max() > thresh_amp or task_data.max() > thresh_amp:
        return True, "amplitude"
    
    k_base = kurtosis(baseline_data, axis=1).max()
    k_task = kurtosis(task_data, axis=1).max()
    if k_base > KURTOSIS_THRESH or k_task > KURTOSIS_THRESH:
        return True, "kurtosis"
    
    return False, "none"

# === FEATURE COMPUTATION ===

def _welch_entropy(signal, sfreq, nperseg, eps):
    """Shannon entropy of the Welch PSD of ``signal`` normalised to sum to 1."""
    _, psd = welch(signal, fs=sfreq, nperseg=nperseg)
    psd /= psd.sum() + eps
    return scipy_entropy(psd)


def _higuchi_fd(signal, kmax=5):
    """Higuchi fractal dimension of a 1-D signal for lags 1..kmax.

    Returns NaN when a curve length is zero (e.g. a flat signal), because the
    log-log fit is undefined in that case.
    """
    n = len(signal)
    lags = np.arange(1, kmax + 1)
    curve_len = []
    for k in lags:
        per_offset = []
        for m in range(k):
            steps = np.abs(np.diff(signal[m::k]))
            if steps.size == 0:
                continue
            # Higuchi normalisation: rescale to the full series length, then
            # divide by the lag once more.
            per_offset.append(steps.sum() * (n - 1) / (steps.size * k) / k)
        curve_len.append(np.mean(per_offset) if per_offset else 0.0)
    curve_len = np.asarray(curve_len, dtype=float)
    if np.any(curve_len <= 0):
        return np.nan
    # L(k) ~ k**(-FD), so FD is the slope of log L(k) against log(1/k).
    return np.polyfit(np.log(1.0 / lags), np.log(curve_len), 1)[0]


def _katz_fd(signal):
    """Katz fractal dimension of a 1-D signal; 1.0 for degenerate input."""
    steps = np.abs(np.diff(signal))
    total_len = steps.sum()
    extent = np.max(np.abs(signal - signal[0]))
    if steps.size < 2 or total_len <= 0 or extent <= 0:
        return 1.0
    log_n = np.log10(steps.size)
    return log_n / (log_n + np.log10(extent / total_len))


def _hjorth_params(signal):
    """Return Hjorth (activity, mobility, complexity) of a 1-D signal."""
    d1 = np.diff(signal)
    d2 = np.diff(d1) if len(d1) > 1 else np.array([0.0])
    activity = np.var(signal)
    v1 = np.var(d1) if len(d1) > 0 else 0.0
    v2 = np.var(d2) if len(d2) > 0 else 0.0
    mobility = np.sqrt(v1 / (activity + 1e-10))
    complexity = np.sqrt(v2 / (v1 + 1e-10)) / (mobility + 1e-10)
    return activity, mobility, complexity


def compute_features_batch(baseline_data, task_data, sfreq):
    """Compute spectral, complexity and connectivity features for one trial.

    Parameters
    ----------
    baseline_data, task_data : np.ndarray
        2-D arrays indexed as (channel, sample). The connectivity step indexes
        channels 0, 1 and 2, so at least three channels are required.
        Docstring of the original states the data are in microvolts.
    sfreq : float
        Sampling frequency in Hz.

    Returns
    -------
    dict
        Keys are inserted in a fixed order (callers flatten by insertion
        order), each prefixed with ``baseline_`` or ``task_``:

        * ``*_bp_<band>``  - mean Welch PSD inside each band of ``BANDS``,
          one value per channel;
        * ``*_entropy``    - spectral entropy of the z-scored channel;
        * ``*_be_<band>``  - spectral entropy of the band-pass filtered channel;
        * ``*_fh``         - per channel ``[higuchi, katz, activity, mobility,
          complexity]``;
        * ``*_pli``        - phase-lag index (4-30 Hz) for channel pairs
          (0, 1), (0, 2), (1, 2).
    """
    features = {}
    segments = (('baseline', baseline_data), ('task', task_data))

    # 1. Bandpowers: one Welch estimate per segment, reused for every band.
    spectra = {}
    for pfx, seg in segments:
        nperseg = min(256, max(16, seg.shape[1] // 2))
        spectra[pfx] = welch(seg, fs=sfreq, nperseg=nperseg)

    for band_name, (lo, hi) in BANDS.items():
        for pfx, _ in segments:
            freqs, psd = spectra[pfx]
            mask = (freqs >= lo) & (freqs <= hi)
            features[f'{pfx}_bp_{band_name}'] = psd[:, mask].mean(axis=1)

    # 2. Spectral entropy of each z-scored channel.
    for pfx, seg in segments:
        ent = []
        for ch in seg:
            ch_norm = (ch - ch.mean()) / (ch.std() + 1e-10)
            ent.append(_welch_entropy(ch_norm, sfreq, min(256, len(ch) // 4), 1e-12))
        features[f'{pfx}_entropy'] = np.array(ent)

    # 3. Band entropy: spectral entropy after band-pass filtering.
    for band_name, (lo, hi) in BANDS.items():
        for pfx, seg in segments:
            filtered = filter_data(seg, sfreq, lo, hi, verbose=False)
            features[f'{pfx}_be_{band_name}'] = np.array([
                _welch_entropy(ch, sfreq, min(128, len(ch) // 8), 1e-10)
                for ch in filtered
            ])

    # 4. Fractal dimensions and Hjorth parameters on the detrended,
    #    unit-variance channel.
    for pfx, seg in segments:
        rows = []
        for ch in seg:
            ch_det = detrend(ch)
            ch_det = (ch_det - ch_det.mean()) / (ch_det.std() + 1e-10)
            # NOTE(review): activity is the variance of a z-scored signal and
            # is therefore ~1 for every trial; confirm this is intended.
            activity, mobility, complexity = _hjorth_params(ch_det)
            rows.append([
                _higuchi_fd(ch_det),
                _katz_fd(ch_det),
                activity,
                mobility,
                complexity,
            ])
        features[f'{pfx}_fh'] = np.array(rows)

    # 5. Phase-lag index between channel pairs in the 4-30 Hz range.
    for pfx, seg in segments:
        filt = filter_data(seg, sfreq, 4, 30, verbose=False)
        phase = np.angle(hilbert(filt, axis=-1))
        pli = []
        for i, j in [(0, 1), (0, 2), (1, 2)]:
            phase_diff = phase[i] - phase[j]
            pli.append(abs(np.mean(np.sign(np.sin(phase_diff)))))
        features[f'{pfx}_pli'] = np.array(pli)

    return features

def flatten_features(feature_dict):
    """Concatenate all values of ``feature_dict`` into one 1-D float array.

    Values are visited in dict insertion order. NumPy arrays contribute all of
    their elements in row-major order; any other value contributes itself as a
    single element.
    """
    pieces = [
        value.flatten() if isinstance(value, np.ndarray) else [value]
        for value in feature_dict.values()
    ]
    return np.array([item for piece in pieces for item in piece], dtype=float)

# === EXTRACTION FUNCTIONS ===

def extract_discrete(raw, task, subject, session):
    """Extract event-locked features for a discrete task recording.

    Every annotation-derived event yields one trial made of a baseline segment
    immediately before the onset and a task segment starting at the onset.
    Trials that touch the recording boundaries, fall inside the edge buffer, or
    fail artifact rejection are skipped and counted.

    Parameters
    ----------
    raw : mne.io.Raw
        Loaded recording. It is modified in place: only the channels listed in
        ``MUSE_CHANNELS`` are kept.
    task, subject, session : str
        Labels copied into every output record.

    Returns
    -------
    list of dict
        One record per accepted trial with the metadata fields and the
        flattened features stored as ``f0``, ``f1``, ... Empty when fewer than
        three of the required channels exist or no events can be read.
    """
    sfreq = raw.info['sfreq']
    available = [c for c in MUSE_CHANNELS if c in raw.ch_names]

    if len(available) < 3:
        return []

    raw.pick_channels(available)
    # Scale to microvolts first (assumes the loader returns volts).
    data = raw.get_data() * 1e6
    data, _ = _autoscale(data)
    thresh = np.percentile(np.abs(data), PERCENTILE)
    n_samples = data.shape[1]

    edge_samp = int(EDGE_BUFFER * sfreq)
    bs = int(abs(BASELINE_WINDOW[0]) * sfreq)
    ts = int(TASK_WINDOW[1] * sfreq)

    print(f"{subject} {session} {task} thresh={thresh:.1f}µV")

    try:
        ev, _ = mne.events_from_annotations(raw, verbose=False)
    except Exception as ex:
        # Report why there are no trials instead of failing silently.
        print(f"  No usable events: {ex}")
        return []

    out = []
    rej_counts = {'amplitude': 0, 'kurtosis': 0, 'edge': 0, 'boundary': 0}
    tot = 0

    for idx, e in enumerate(ev):
        tot += 1
        # Event samples include raw.first_samp, while get_data() starts at
        # index 0, so convert to an index into `data`.
        onset = int(e[0]) - raw.first_samp

        if onset - bs < 0 or onset + ts > n_samples:
            rej_counts['boundary'] += 1
            continue

        if onset - bs < edge_samp or onset + ts > n_samples - edge_samp:
            rej_counts['edge'] += 1
            continue

        bd = data[:, onset - bs:onset]
        td = data[:, onset:onset + ts]

        is_artifact, reason = _artifact_reject(bd, td, thresh)
        if is_artifact:
            rej_counts[reason] += 1
            continue

        try:
            feats_flat = flatten_features(compute_features_batch(bd, td, sfreq))

            rec = {
                'subject': subject,
                'session': session,
                'task': task,
                'trial_idx': idx,
                'event_code': int(e[2]),
                'onset_sample': onset,
                'onset_time': onset / sfreq
            }
            rec.update({f'f{i}': v for i, v in enumerate(feats_flat)})

            out.append(rec)
        except Exception as ex:
            # Best effort: one failing trial must not abort the recording.
            print(f"  Feature error trial {idx}: {ex}")
            continue

    n_rejected = sum(rej_counts.values())
    if n_rejected > 0:
        print(f"  Rejected: {rej_counts} ({n_rejected}/{tot})")

    return out

def extract_continuous(raw, task, subject, session, window_sec=2.0, overlap_sec=0.5):
    """Extract sliding-window features for a continuous task recording.

    The recording is cut into overlapping windows. The first half of each
    window is passed as the "baseline" segment and the second half as the
    "task" segment of ``compute_features_batch``. Windows inside the edge
    buffer or failing artifact rejection are skipped and counted.

    Parameters
    ----------
    raw : mne.io.Raw
        Loaded recording. It is modified in place: only the channels listed in
        ``MUSE_CHANNELS`` are kept.
    task, subject, session : str
        Labels copied into every output record.
    window_sec : float, optional
        Window length in seconds (default 2.0).
    overlap_sec : float, optional
        Overlap between consecutive windows in seconds (default 0.5).

    Returns
    -------
    list of dict
        One record per accepted window with the metadata fields
        (``event_code`` is always -1) and the flattened features stored as
        ``f0``, ``f1``, ... Empty when fewer than three of the required
        channels exist.

    Raises
    ------
    ValueError
        If the overlap is not smaller than the window.
    """
    sfreq = raw.info['sfreq']
    available = [c for c in MUSE_CHANNELS if c in raw.ch_names]

    if len(available) < 3:
        return []

    ws = int(window_sec * sfreq)
    ov = int(overlap_sec * sfreq)
    step = ws - ov
    if step <= 0:
        raise ValueError("overlap_sec must be smaller than window_sec")

    raw.pick_channels(available)
    # Scale to microvolts first (assumes the loader returns volts).
    data = raw.get_data() * 1e6
    data, _ = _autoscale(data)
    thresh = np.percentile(np.abs(data), PERCENTILE)
    n_samples = data.shape[1]

    edge_samp = int(EDGE_BUFFER * sfreq)

    print(f"{subject} {session} {task} thresh={thresh:.1f}µV")

    out = []
    rej_counts = {'amplitude': 0, 'kurtosis': 0, 'edge': 0}
    tot = 0

    # Number of full windows that fit; every generated window ends inside the
    # data, so no extra bounds check is needed in the loop.
    nw = (n_samples - ws) // step + 1
    for w in range(nw):
        tot += 1
        start = w * step
        stop = start + ws

        if start < edge_samp or stop > n_samples - edge_samp:
            rej_counts['edge'] += 1
            continue

        wd = data[:, start:stop]

        is_artifact, reason = _artifact_reject(wd, wd, thresh)
        if is_artifact:
            rej_counts[reason] += 1
            continue

        try:
            mid = wd.shape[1] // 2
            feats_flat = flatten_features(
                compute_features_batch(wd[:, :mid], wd[:, mid:], sfreq)
            )

            rec = {
                'subject': subject,
                'session': session,
                'task': task,
                'trial_idx': w,
                'event_code': -1,
                'onset_sample': start,
                'onset_time': start / sfreq
            }
            rec.update({f'f{i}': v for i, v in enumerate(feats_flat)})

            out.append(rec)
        except Exception as ex:
            # Best effort: one failing window must not abort the recording.
            print(f"  Feature error window {w}: {ex}")
            continue

    n_rejected = sum(rej_counts.values())
    if n_rejected > 0:
        print(f"  Rejected: {rej_counts} ({n_rejected}/{tot})")

    return out

def process_session(subject, session):
    """Run feature extraction on every task file of one subject-session.

    Looks for ``<stem>.set`` files (stems from ``TASK_MAPPINGS``) under
    ``DATA_ROOT/<subject>/<subject>/<session>/eeg``, extracts features from
    each one that exists, and writes all records of the session to a single
    CSV in ``SAVE_DIR``.

    Returns ``(subject, session, n_records)`` when a CSV was written, or
    ``None`` when the folder is missing or no record was produced.
    """
    eeg_dir = DATA_ROOT / subject / subject / session / "eeg"
    if not eeg_dir.exists():
        return None

    session_records = []
    for stem, label in TASK_MAPPINGS.items():
        set_path = eeg_dir / f"{stem}.set"
        if not set_path.exists():
            continue

        try:
            raw = mne.io.read_raw_eeglab(str(set_path), preload=True, verbose=False)

            # Discrete tasks are event-locked; everything else is windowed.
            extractor = extract_discrete if stem in DISCRETE_TASKS else extract_continuous
            task_records = extractor(raw, label, subject, session)

            session_records.extend(task_records)
            print(f"  {stem:15s}: {len(task_records):4d}")
        except Exception as ex:
            # A broken file is reported and skipped so the rest still runs.
            print(f"  {stem:15s}: ERROR - {ex}")
            continue

    if not session_records:
        return None

    out_file = SAVE_DIR / f"{subject}_{session}_trials_event_locked.csv"
    pd.DataFrame(session_records).to_csv(out_file, index=False)
    print(f"✓ {subject} {session}: {len(session_records)} trials → {out_file.name}")
    return (subject, session, len(session_records))

def main():
    """Process subjects sub-01..sub-21 for every session and print a summary."""
    job_queue = [
        (f"sub-{number:02d}", sess)
        for number in range(1, 22)
        for sess in SESSIONS
    ]

    print(f"Processing {len(job_queue)} subject-session pairs...\n")

    results = []
    for position, (subject, session) in enumerate(job_queue, start=1):
        print(f"\n[{position}/{len(job_queue)}] {subject} {session}")
        outcome = process_session(subject, session)
        # process_session returns None when nothing was written.
        if outcome:
            results.append(outcome)

    banner = "=" * 80
    total_trials = sum(entry[2] for entry in results)

    print("\n" + banner)
    print(f"EXTRACTION COMPLETE: {len(results)} sessions processed")
    print(f"Total trials: {total_trials}")
    print(f"Output: {SAVE_DIR.resolve()}")
    print(banner)

if __name__ == "__main__":
    main()

Output: C:\Users\rapol\Downloads\eeg_features_3ch_event_locked_optimized

Processing 63 subject-session pairs...


[1/63] sub-01 ses-S1
sub-01 ses-S1 nback_0 thresh=12529.0µV
  Rejected: {'amplitude': 0, 'kurtosis': 2, 'edge': 0, 'boundary': 3} (5/203)
  zeroBACK       :  198
sub-01 ses-S1 nback_1 thresh=14366.3µV
  Rejected: {'amplitude': 0, 'kurtosis': 0, 'edge': 0, 'boundary': 2} (2/198)
  oneBACK        :  196
sub-01 ses-S1 nback_2 thresh=14255.9µV
  Rejected: {'amplitude': 0, 'kurtosis': 7, 'edge': 0, 'boundary': 3} (10/208)
  twoBACK        :  198
sub-01 ses-S1 matb_easy thresh=14503.6µV
  Rejected: {'amplitude': 0, 'kurtosis': 0, 'edge': 2} (2/199)
  MATBeasy       :  197
sub-01 ses-S1 matb_med thresh=12209.6µV
  Rejected: {'amplitude': 0, 'kurtosis': 0, 'edge': 2} (2/199)
  MATBmed        :  197
sub-01 ses-S1 matb_diff thresh=10558.0µV
  Rejected: {'amplitude': 0, 'kurtosis': 2, 'edge': 2} (4/199)
  MATBdiff       :  195
sub-01 ses-S1 pvt thresh=13481.3µV
  Rejected: {'amplitud