# N-back EEG — Minimal Pipeline
**What it does:** Load XDF → (optional) preprocess → align markers → infer block difficulties → write session marker + 5s difficulty segments → (optional) build epochs.

In [16]:

from pathlib import Path
import sys, re, importlib.util
# Helper: path inside RESULTS_DIR (defined below; looked up when R is called)
def R(name: str) -> Path:
    return RESULTS_DIR / name

# --- SETTINGS --------------------------------------------------------------
XDF_PATH     = Path(r"../../data/sub-P001_jannik/ses-S001/eeg/sub-P001_ses-S001_task-Default_run-001_eeg.xdf")
MARKERS_CSV  = Path(r"../../data/sub-P001_jannik/ses-S001/marker_log_p1s1_Jannik.csv")
BDE_PATH     = Path(r"../Block_difficulty_extractor.py")     # provides calculate_nvals(df)
PREPROC_MOD  = Path(r"../preprocess_raw.py")                 # optional; if present, will be used

DO_PREPROCESS = True      # set False to skip preprocess_raw
DO_EPOCHS     = True      # set False if you only want annotations

# Where to save everything: a results/ folder next to the XDF file
RESULTS_DIR = (XDF_PATH.parent / "results").resolve()
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Output file names derive from the XDF file name
OUT_BASENAME = XDF_PATH.with_suffix("").name

RAW_FIF = R(f"{OUT_BASENAME}_raw.fif")                # only used if you save the raw data
EPO_FIF = R(f"{OUT_BASENAME}_segments-epo.fif")       # 5 s segment epochs

print("XDF:", XDF_PATH, "\nCSV:", MARKERS_CSV, "\nExtractor:", BDE_PATH)
assert XDF_PATH.exists() and MARKERS_CSV.exists() and BDE_PATH.exists(), "One of the input files is missing."


XDF: ..\..\data\sub-P001_jannik\ses-S001\eeg\sub-P001_ses-S001_task-Default_run-001_eeg.xdf 
CSV: ..\..\data\sub-P001_jannik\ses-S001\marker_log_p1s1_Jannik.csv 
Extractor: ..\Block_difficulty_extractor.py


In [17]:

import numpy as np, pandas as pd, mne, pyxdf

# Optional preprocessing module
preprocess_raw = None
if PREPROC_MOD.exists():
    spec = importlib.util.spec_from_file_location("preprocess_raw", str(PREPROC_MOD))
    preprocess_raw = importlib.util.module_from_spec(spec); spec.loader.exec_module(preprocess_raw)
    print("Loaded preprocess_raw from:", PREPROC_MOD)

# Block difficulty extractor
spec = importlib.util.spec_from_file_location("bde", str(BDE_PATH))
bde = importlib.util.module_from_spec(spec); sys.modules["bde"] = bde; spec.loader.exec_module(bde)
assert hasattr(bde, "calculate_nvals") and callable(bde.calculate_nvals), "Extractor must expose calculate_nvals(df)."


Loaded preprocess_raw from: ..\preprocess_raw.py
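Both modules above are loaded with the same `importlib.util` pattern. A small helper keeps that pattern in one place; this is a sketch that assumes the file exists and defines the expected function, and `load_module_from_path` is a hypothetical name, not part of the repository.

```python
import importlib.util
import sys
from pathlib import Path
from types import ModuleType


def load_module_from_path(path: Path, name: str) -> ModuleType:
    """Import a Python file as a module (hypothetical helper mirroring the cell above)."""
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {name!r} from {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module      # register before executing, as the extractor load does
    spec.loader.exec_module(module)
    return module
```

Usage would be `bde = load_module_from_path(BDE_PATH, "bde")`, followed by the same `calculate_nvals` check.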


In [18]:

streams, header = pyxdf.load_xdf(str(XDF_PATH))

def _is_numeric(s): 
    ts = np.asarray(s.get("time_series")); return np.issubdtype(ts.dtype, np.number)
def _nchan(s):
    try:
        desc = s["info"].get("desc", [None])[0]
        return len(desc["channels"][0]["channel"]) if desc and "channels" in desc and desc["channels"] else 1
    except Exception:
        ts = np.asarray(s.get("time_series"), dtype=object)
        return 1 if ts.ndim==1 else ts.shape[1]
def _srate(s):
    try: return float(s["info"]["nominal_srate"][0])
    except (KeyError, IndexError, TypeError, ValueError): return 0.0

eeg = [s for s in streams if s["info"].get("type",[""])[0].upper()=="EEG" and _is_numeric(s)]
if not eeg: eeg = [s for s in streams if _is_numeric(s) and _srate(s)>0 and _nchan(s)>=4]
if not eeg: eeg = [s for s in streams if _is_numeric(s)]
eeg = sorted(eeg, key=lambda s: (_nchan(s), _srate(s)), reverse=True)
eeg_stream = eeg[0]

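data = np.asarray(eeg_stream["time_series"])
if data.ndim == 1: data = data[:, None]
data = data.T  # (n_samples, n_channels) -> (n_channels, n_samples), as RawArray expects
# MNE assumes EEG in volts; if this stream stores microvolts, scale with data = data * 1e-6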
sfreq = _srate(eeg_stream)
if sfreq<=0:
    ts = np.asarray(eeg_stream.get("time_stamps")); sfreq = 1.0/np.median(np.diff(ts))

try:
    chs = [c["label"][0] for c in eeg_stream["info"]["desc"][0]["channels"][0]["channel"]]
    if len(chs)!=data.shape[0]: raise ValueError
except Exception:
    chs = [f"EEG{i+1}" for i in range(data.shape[0])]

info = mne.create_info(chs, sfreq=sfreq, ch_types="eeg")
raw  = mne.io.RawArray(data, info, verbose=False)
try: raw.set_montage("standard_1020", on_missing="ignore")
except: pass
print(raw)


<RawArray | 17 x 397344 (1589.4 s), ~51.6 MiB, data loaded>
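The compact helpers above encode a fallback order for picking the EEG stream: a stream typed `EEG`, else a numeric stream with a positive nominal rate and at least 4 channels, else any numeric stream, preferring more channels and then a higher rate. The sketch below spells out that order. It assumes pyxdf's convention of one-element lists inside `info` and, for brevity, counts channels from the data shape instead of the channel description; `pick_eeg_stream` is a hypothetical name.

```python
import numpy as np


def _first(info: dict, key: str, default=""):
    """pyxdf stores header fields as one-element lists; return that element (or the default)."""
    val = info.get(key, [default])
    return val[0] if val else default


def pick_eeg_stream(streams: list) -> dict:
    """Simplified sketch of the stream-selection order used in the cell above."""
    def numeric(s):
        return np.issubdtype(np.asarray(s["time_series"]).dtype, np.number)

    def nchan(s):
        ts = np.asarray(s["time_series"])
        return 1 if ts.ndim == 1 else ts.shape[1]

    def srate(s):
        try:
            return float(_first(s["info"], "nominal_srate", "0"))
        except (TypeError, ValueError):
            return 0.0

    candidates = [s for s in streams if str(_first(s["info"], "type")).upper() == "EEG" and numeric(s)]
    if not candidates:
        candidates = [s for s in streams if numeric(s) and srate(s) > 0 and nchan(s) >= 4]
    if not candidates:
        candidates = [s for s in streams if numeric(s)]
    if not candidates:
        raise ValueError("No numeric stream found")
    # Prefer the stream with the most channels, then the highest nominal sampling rate
    return max(candidates, key=lambda s: (nchan(s), srate(s)))
```

String-valued marker streams never qualify, because their `time_series` converts to a non-numeric array.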


In [19]:

if DO_PREPROCESS and preprocess_raw is not None:
    cfg = preprocess_raw.default_config()
    cfg["save_base"] = OUT_BASENAME
    raw, ica, rep = preprocess_raw.preprocess_raw(raw, cfg)
    print("ICA excluded:", rep["exclude"])
else:
    print("Skipping preprocess (set DO_PREPROCESS=True and provide preprocess_raw.py to enable).")


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 1651 samples (6.604 s)



Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 40 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-12 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-12 dB cutoff frequency: 45.00 Hz)
- Filter length: 8251 samples (33.004 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Sampling frequency of the instance is already 250.0, returning unmodified.
Filtering raw data in 1 contiguous segment
Setting up high-pass filter at 1 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal highpass filter:
- Windowed time-domain design (firwin) me

In [20]:

df = pd.read_csv(MARKERS_CSV)
# normalize names
lower = {c.lower(): c for c in df.columns}
on_col  = lower.get("timestamp") or lower.get("onset") or lower.get("time")
desc_col= lower.get("marker") or lower.get("description") or lower.get("event")
assert on_col and desc_col, "CSV must have a time (timestamp/onset) and a marker/description column."
df = df.rename(columns={on_col:"timestamp", desc_col:"marker"})
df["marker"] = df["marker"].astype(str).str.strip()

# Align CSV timestamps to the EEG stream: seconds relative to the first EEG sample (Raw t=0)
eeg_ts = np.asarray(eeg_stream.get("time_stamps"), dtype=float)
t0 = float(eeg_ts[0])
raw_dur = float(raw.times[-1])
on = pd.to_numeric(df["timestamp"], errors="coerce").to_numpy(dtype=float)
rel = on - t0
if np.nanmax(rel) > raw_dur * 10:
    # Naive guard: values far beyond the recording length are treated as milliseconds
    rel = on / 1000.0 - t0
df["timestamp"] = rel

# Detect block starts -> forward-fill block_idx
blk_re = re.compile(r"(?i)\b(?:practice|main)_block_(\d+)_start\b")
starts = df["marker"].str.extract(blk_re)[0]
df["_block_idx"] = pd.to_numeric(starts, errors="coerce").ffill()

# Infer difficulties per block using your extractor
levels = bde.calculate_nvals(df[["marker","timestamp"]].copy())
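# Reuse the extract result as a boolean mask (str.contains with a capturing regex triggers a pandas UserWarning)
block_order = df.loc[starts.notna(), "_block_idx"].astype(int).unique().tolist()
# Pair levels with block starts by position; assumes calculate_nvals returns one level per block, in start order
k = min(len(levels), len(block_order))
levels = list(map(int, levels[:k]))
block_order = block_order[:k]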
blocks_df = pd.DataFrame({"block_idx": block_order, "n": levels})
print(blocks_df)


   block_idx  n
0          0  0
1          1  3
2          2  1
3          3  2
4          4  2
5          5  1
6          6  3


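The alignment step converts absolute marker times into seconds since the first EEG sample and tags every row with the most recent block start. The sketch below packages that logic, including the millisecond branch written as `on / 1000 - t0`. It assumes the CSV and the XDF stream share the same clock, as the cell above does; `align_markers` is a hypothetical helper name.

```python
import re

import numpy as np
import pandas as pd

BLOCK_START = re.compile(r"(?i)\b(?:practice|main)_block_(\d+)_start\b")


def align_markers(df: pd.DataFrame, t0: float, raw_dur: float) -> pd.DataFrame:
    """Shift marker times to seconds since the first EEG sample and forward-fill block indices."""
    out = df.copy()
    on = pd.to_numeric(out["timestamp"], errors="coerce").to_numpy(dtype=float)
    rel = on - t0
    if np.nanmax(rel) > raw_dur * 10:
        # Assumption: such large values are milliseconds on the same clock as the EEG stream
        rel = on / 1000.0 - t0
    out["timestamp"] = rel
    # Rows before the first block start stay NaN; later rows inherit the latest block index
    starts = out["marker"].str.extract(BLOCK_START)[0]
    out["_block_idx"] = pd.to_numeric(starts, errors="coerce").ffill()
    return out
```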


In [21]:

# Session marker from path
p = str(XDF_PATH)
session_label = "lab_noise" if re.search(r"ses[-_]?S001", p, flags=re.I) else ("gw2_noise" if re.search(r"ses[-_]?S002", p, flags=re.I) else "noise_unknown")

# Block time ranges from df
blk_times = (df.dropna(subset=["_block_idx"])
               .groupby("_block_idx")["timestamp"]
               .agg(["min","max"]).reset_index()
               .rename(columns={"_block_idx":"block_idx","min":"t_start","max":"t_end"}))

blk = blk_times.merge(blocks_df, on="block_idx", how="inner").sort_values("block_idx")

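# Build 5 s segments per block; the last segment of a block may be shorter than 5 s
seg_on, seg_dur, seg_desc = [], [], []
raw_end = float(raw.times[-1])
for _, r in blk.iterrows():
    n = int(r["n"])
    # Clip the block to the recording; separate names avoid overwriting t0 from the alignment cell
    b_start = max(float(r["t_start"]), 0.0)
    b_end = min(float(r["t_end"]), raw_end)
    t = b_start
    while t < b_end:
        t_next = min(t + 5.0, b_end)
        seg_on.append(t); seg_dur.append(t_next - t); seg_desc.append(f"difficulty/{n}-back")
        t = t_next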

sess_ann = mne.Annotations(onset=[0.0], duration=[float(raw.times[-1])], description=[session_label], orig_time=None)
seg_ann  = mne.Annotations(onset=np.array(seg_on), duration=np.array(seg_dur), description=seg_desc, orig_time=None)

raw.set_annotations(sess_ann + seg_ann)  # replaces any annotations already on raw
print("Session:", session_label, "| segments:", len(seg_ann))


Session: lab_noise | segments: 266
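Each block is cut into consecutive 5 s windows, and the last window of a block is usually shorter. Because the epochs built next always span 5 s from each onset, a short segment yields an epoch that runs past its block into whatever follows. The sketch below reproduces the splitting for one block and adds a `full_only` switch (an assumption, not in the original code) that drops the short remainder; `block_segments` is a hypothetical helper.

```python
import math


def block_segments(t_start: float, t_end: float, raw_end: float,
                   seg_len: float = 5.0, full_only: bool = False) -> list:
    """Split one block into consecutive (onset, duration) windows of at most seg_len seconds.

    The block is first clipped to [0, raw_end], as in the cell above. With full_only=True
    a trailing window shorter than seg_len is dropped.
    """
    start, stop = max(t_start, 0.0), min(t_end, raw_end)
    segments = []
    t = start
    while t < stop:
        t_next = min(t + seg_len, stop)
        if not full_only or math.isclose(t_next - t, seg_len):
            segments.append((t, t_next - t))
        t = t_next
    return segments
```

The loop over `blk` could then extend `seg_on` and `seg_dur` from `block_segments(r["t_start"], r["t_end"], raw_end, full_only=True)`, at the cost of fewer (but label-clean) segments.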


In [22]:

if DO_EPOCHS:
    desc = np.asarray(raw.annotations.description, dtype=str)
    present_n = sorted({int(d.split("/")[1].split("-")[0]) for d in desc if d.startswith("difficulty/")})
    event_id = {f"difficulty/{n}-back": 10+n for n in present_n}
    events, eid_used = mne.events_from_annotations(raw, event_id=event_id)
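    # Fixed 5 s epochs from each segment onset; tmax is inclusive (5 * sfreq + 1 samples),
    # and segments shorter than 5 s still produce full-length epochs
    epochs = mne.Epochs(raw, events, event_id=eid_used, tmin=0.0, tmax=5.0,
                        baseline=None, preload=True, reject_by_annotation=False)
    # Map each event code back to its n-back level (e.g. 12 -> 2) for the metadata
    code_to_n = {code: int(name.split("/")[1].split("-")[0]) for name, code in eid_used.items()}
    epochs.metadata = pd.DataFrame({"difficulty": [code_to_n[c] for c in epochs.events[:, 2]]})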
    print(epochs)
    epochs.save(EPO_FIF, overwrite=True)
else:
    print("Skipping epochs (set DO_EPOCHS=True to build).")


Used Annotations descriptions: [np.str_('difficulty/0-back'), np.str_('difficulty/1-back'), np.str_('difficulty/2-back'), np.str_('difficulty/3-back')]
Not setting metadata
266 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 266 events and 1251 original time points ...
0 bad epochs dropped
Adding metadata with 1 columns
<Epochs | 266 events (all good), 0 – 5 s (baseline off), ~43.2 MiB, data loaded, with metadata,
 np.str_('difficulty/0-back'): 28
 np.str_('difficulty/1-back'): 84
 np.str_('difficulty/2-back'): 84
 np.str_('difficulty/3-back'): 70>
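The log reports 1251 time points per epoch: with `tmin=0.0` and `tmax=5.0` at 250 Hz, MNE keeps both endpoint samples, so consecutive 5 s epochs share one boundary sample. The plain-Python arithmetic below (a sketch, not an MNE function; `n_epoch_samples` is a hypothetical name) shows where 1251 comes from and that `tmax = 5.0 - 1 / sfreq` would give exactly 1250 samples.

```python
def n_epoch_samples(tmin: float, tmax: float, sfreq: float) -> int:
    """Number of samples in an epoch when both tmin and tmax samples are included."""
    return int(round((tmax - tmin) * sfreq)) + 1


sfreq = 250.0
print(n_epoch_samples(0.0, 5.0, sfreq))               # matches the 1251 points in the log
print(n_epoch_samples(0.0, 5.0 - 1.0 / sfreq, sfreq))  # non-overlapping 5 s windows
```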


### ML (Random Forest): bandpower features → predict n-back
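This cell extracts **relative bandpower features** per epoch (delta/theta/alpha/beta/gamma power divided by total 1-40 Hz power, for each of the first 8 channels) from a Welch PSD and trains a **RandomForestClassifier** to predict the n-back level. It reports stratified K-fold CV accuracy (5 folds unless a class has fewer than 5 epochs), a classification report, a confusion matrix, a majority-class baseline and the top feature importances, then saves the fitted model and the feature table. If `epochs` isn't in memory, it loads `EPO_FIF`.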
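The feature step reduces to: a Welch PSD with 2 s windows and 1 s overlap, restricted to 1-40 Hz, then each band's summed power divided by the total 1-40 Hz power, stacked band by band across channels. Below is a standalone sketch of that idea using `scipy.signal.welch` in place of MNE's `compute_psd` (a swapped-in implementation, so values will not match MNE's exactly, and no 1-40 Hz filtering is applied first); `relative_bandpower` is a hypothetical name.

```python
import numpy as np
from scipy.signal import welch

BANDS = {"delta": (1, 4), "theta": (4, 8), "alpha": (8, 13), "beta": (13, 30), "gamma": (30, 40)}


def relative_bandpower(x: np.ndarray, sfreq: float) -> np.ndarray:
    """x: (n_epochs, n_channels, n_times) -> (n_epochs, n_channels * n_bands), band-major columns."""
    freqs, psd = welch(x, fs=sfreq, nperseg=int(sfreq * 2), noverlap=int(sfreq), axis=-1)
    keep = (freqs >= 1) & (freqs <= 40)
    freqs, psd = freqs[keep], psd[..., keep]
    total = psd.sum(axis=-1) + 1e-12                  # total 1-40 Hz power per epoch and channel
    feats = [psd[..., (freqs >= lo) & (freqs < hi)].sum(axis=-1) / total
             for lo, hi in BANDS.values()]            # one (n_epochs, n_channels) array per band
    return np.concatenate(feats, axis=1)


# Synthetic check: a pure 10 Hz sine should put almost all power in the alpha band
sf = 250.0
t = np.arange(int(5 * sf)) / sf
x = np.tile(np.sin(2 * np.pi * 10 * t), (2, 3, 1))   # 2 epochs, 3 channels, 5 s
F = relative_bandpower(x, sf)
print(F.shape, F[:, 6:9].round(3))                   # alpha columns for the 3 channels
```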

In [23]:

# --- Random Forest on bandpower features ---
import numpy as np, pandas as pd, mne, joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score, cross_val_predict
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# 1) Load epochs if needed (must happen before any channel selection)
if "epochs" not in globals():
    assert "EPO_FIF" in globals(), "EPO_FIF path missing; run epoching or set EPO_FIF."
    print("Loading epochs from:", EPO_FIF)
    epochs = mne.read_epochs(EPO_FIF, preload=True)

# Keep only the first 8 channels for decoding; work on a copy so `epochs` keeps all channels
# and re-running this cell gives the same result
N = min(8, len(epochs.ch_names))
picks8 = epochs.ch_names[:N]
ep8 = epochs.copy().pick(picks8)
print(f"Using {N} channels:", picks8)

assert ep8.metadata is not None and "difficulty" in ep8.metadata.columns, "epochs.metadata must contain 'difficulty'."
print(ep8)

# 2) Band-pass 1-40 Hz before estimating band powers
ep_filt = ep8.copy().filter(1., 40., picks="eeg")

# 3) PSD per epoch: Welch with 2 s windows and 1 s overlap
sf = ep_filt.info["sfreq"]
if hasattr(ep_filt, "compute_psd"):
    # MNE >= 1.2 API
    psd = ep_filt.compute_psd(method="welch", fmin=1., fmax=40.,
                              n_fft=int(sf * 2), n_overlap=int(sf * 1),
                              picks="eeg", verbose=False)
    psds, freqs = psd.get_data(return_freqs=True)  # (n_epochs, n_channels, n_freqs)
else:
    # Older MNE only (psd_welch was deprecated and later removed)
    from mne.time_frequency import psd_welch
    psds, freqs = psd_welch(ep_filt, fmin=1., fmax=40.,
                            n_fft=int(sf * 2), n_overlap=int(sf * 1),
                            picks="eeg", average="mean", n_per_seg=None, verbose=False)


# 4) Band definitions and integration
bands = {"delta": (1, 4), "theta": (4, 8), "alpha": (8, 13), "beta": (13, 30), "gamma": (30, 40)}
bin_mask = {b: (freqs >= lo) & (freqs < hi) for b, (lo, hi) in bands.items()}
total_pow = psds.sum(axis=2) + 1e-12  # avoid div by zero

# Build feature matrix: relative bandpower per channel
feat_list = []
col_names = []
for b, m in bin_mask.items():
    bp = psds[:, :, m].sum(axis=2)  # (n_epochs, n_channels)
    rel = bp / total_pow
    feat_list.append(rel)
    col_names += [f"{ch}_{b}" for ch in ep_filt.ch_names]
X = np.concatenate(feat_list, axis=1)   # shape: (n_epochs, n_channels * n_bands)
y = ep_filt.metadata["difficulty"].astype(int).to_numpy()

print("Feature matrix:", X.shape, "| labels:", y.shape)

# 5) Random Forest
rf = RandomForestClassifier(
    n_estimators=400, max_depth=None, min_samples_split=2, min_samples_leaf=1,
    class_weight="balanced_subsample", random_state=42, n_jobs=-1
)

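# Folds cannot exceed the smallest class count; count only labels that actually occur
_, class_counts = np.unique(y, return_counts=True)
cv = StratifiedKFold(n_splits=max(2, min(5, class_counts.min())), shuffle=True, random_state=42)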
scores = cross_val_score(rf, X, y, cv=cv, scoring="accuracy", n_jobs=-1)
print("RF CV accuracy:", np.round(scores, 3), " | mean±sd:", f"{scores.mean():.3f} ± {scores.std():.3f}")

y_pred = cross_val_predict(rf, X, y, cv=cv, n_jobs=-1)
print("\nClassification report:\n", classification_report(y, y_pred, digits=3))

cm = confusion_matrix(y, y_pred, labels=sorted(np.unique(y)))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

maj = np.bincount(y).argmax()
print(f"Majority-class baseline: {(y==maj).mean():.3f} (label={maj})")

# 6) Fit on full data and report top features
rf.fit(X, y)
importances = rf.feature_importances_
topk = np.argsort(importances)[::-1][:20]
top_table = pd.DataFrame({
    "feature": [col_names[i] for i in topk],
    "importance": importances[topk]
})
print("\nTop 20 features by importance:")
display(top_table)

# 7) Save artifacts (optional)
MODEL_PATH = EPO_FIF.with_suffix("").with_name(f"{OUT_BASENAME}_rf_bandpower.joblib") if 'OUT_BASENAME' in globals() else Path("rf_bandpower.joblib")
FEAT_PATH  = MODEL_PATH.with_suffix(".features.csv")
joblib.dump({"model": rf, "bands": bands, "ch_names": ep_filt.ch_names, "col_names": col_names}, MODEL_PATH)
pd.DataFrame(X, columns=col_names).assign(y=y).to_csv(FEAT_PATH, index=False)
print("Saved model to:", MODEL_PATH)
print("Saved features to:", FEAT_PATH)

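# 8) Helper: transform new epochs -> the same features (same channels, bands and PSD settings)
def extract_bandpower_features(ep: mne.Epochs, ch_names=None) -> np.ndarray:
    # Default to the channels the model was trained on
    ch_names = list(ep_filt.ch_names) if ch_names is None else list(ch_names)
    ep2 = ep.copy().pick(ch_names).filter(1., 40., picks="eeg")
    sf2 = ep2.info["sfreq"]
    kw = dict(fmin=1., fmax=40., n_fft=int(sf2 * 2), n_overlap=int(sf2 * 1), picks="eeg", verbose=False)
    if hasattr(ep2, "compute_psd"):
        psd, fr = ep2.compute_psd(method="welch", **kw).get_data(return_freqs=True)
    else:
        psd, fr = mne.time_frequency.psd_welch(ep2, average="mean", n_per_seg=None, **kw)
    bm = {b: (fr >= lo) & (fr < hi) for b, (lo, hi) in bands.items()}
    tp = psd.sum(axis=2) + 1e-12
    feats = [psd[:, :, m].sum(axis=2) / tp for m in bm.values()]  # relative power per band
    return np.concatenate(feats, axis=1)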


Using 8 channels: ['EEG1', 'EEG2', 'EEG3', 'EEG4', 'EEG5', 'EEG6', 'EEG7', 'EEG8']
<Epochs | 266 events (all good), 0 – 5 s (baseline off), ~20.3 MiB, data loaded, with metadata,
 np.str_('difficulty/0-back'): 28
 np.str_('difficulty/1-back'): 84
 np.str_('difficulty/2-back'): 84
 np.str_('difficulty/3-back'): 70>
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 825 samples (3.300 s)



Feature matrix: (266, 40) | labels: (266,)
RF CV accuracy: [0.537 0.585 0.604 0.623 0.66 ]  | mean±sd: 0.602 ± 0.041

Classification report:
               precision    recall  f1-score   support

           0      0.611     0.393     0.478        28
           1      0.573     0.655     0.611        84
           2      0.701     0.560     0.623        84
           3      0.553     0.671     0.606        70

    accuracy                          0.602       266
   macro avg      0.610     0.570     0.580       266
weighted avg      0.612     0.602     0.600       266

Confusion matrix (rows=true, cols=pred):
 [[11  5  2 10]
 [ 1 55 13 15]
 [ 4 20 47 13]
 [ 2 16  5 47]]
Majority-class baseline: 0.316 (label=1)

Top 20 features by importance:


| # | feature | importance |
|---:|:---|---:|
| 0 | EEG2_gamma | 0.042043 |
| 1 | EEG6_gamma | 0.04149 |
| 2 | EEG1_alpha | 0.037714 |
| 3 | EEG8_gamma | 0.037671 |
| 4 | EEG8_beta | 0.03222 |
| 5 | EEG8_theta | 0.031024 |
| 6 | EEG2_theta | 0.030745 |
| 7 | EEG2_beta | 0.030556 |
| 8 | EEG6_theta | 0.030532 |
| 9 | EEG7_theta | 0.030349 |


Saved model to: C:\Users\janni\Documents\GitHub\eeg-brain-interface\data\sub-P001_jannik\ses-S001\eeg\results\sub-P001_ses-S001_task-Default_run-001_eeg_rf_bandpower.joblib
Saved features to: C:\Users\janni\Documents\GitHub\eeg-brain-interface\data\sub-P001_jannik\ses-S001\eeg\results\sub-P001_ses-S001_task-Default_run-001_eeg_rf_bandpower.features.csv
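Shuffled stratified K-fold places neighbouring 5 s segments of the same block in both training and test folds, so the 0.602 accuracy may be optimistic. A stricter check holds out whole blocks with scikit-learn's `GroupKFold`. This needs a per-epoch block index (`groups`), which the notebook does not currently store; it could be recorded next to `difficulty` in `epochs.metadata` when the segments are built. With a single 0-back block (block 0 in the block table), that level cannot appear in both training and test data under a block-wise split. The sketch below runs on synthetic placeholder data; `blockwise_cv_accuracy` is a hypothetical helper.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupKFold, cross_val_score


def blockwise_cv_accuracy(X: np.ndarray, y: np.ndarray, groups: np.ndarray, n_splits: int = 3) -> np.ndarray:
    """Per-fold accuracy when every block (group) lands entirely in either train or test."""
    rf = RandomForestClassifier(n_estimators=200, class_weight="balanced_subsample", random_state=42)
    cv = GroupKFold(n_splits=min(n_splits, len(np.unique(groups))))
    return cross_val_score(rf, X, y, groups=groups, cv=cv, scoring="accuracy")


# Synthetic placeholder data: 6 blocks x 10 segments, 40 features, one n-back level per block
rng = np.random.default_rng(0)
groups = np.repeat(np.arange(6), 10)
y = np.repeat([1, 2, 3, 1, 2, 3], 10)
X = rng.normal(size=(60, 40)) + y[:, None] * 0.5
scores = blockwise_cv_accuracy(X, y, groups)
print("Block-wise CV accuracy:", np.round(scores, 3))
```

Expect block-wise scores to be lower and noisier than the shuffled ones on real data; the gap indicates how much the shuffled estimate relies on within-block similarity.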
