# BART → P300 (Explosion-locked) ERP pipeline (XDF → CSV/XLSX)

This notebook:
- loads one or more **LabRecorder `.xdf`** files
- finds **EEG** + **BART_Markers** streams
- extracts **BART_EXPLODE** events and epochs EEG around them
- computes per-explosion **P300 amplitude + latency**
- exports `p300_explosions.csv` / `p300_explosions.xlsx` and a `p300_session_summary.csv`

**Best practice:** have your BART task send the `BART_EXPLODE` marker on the same display frame as the BOOM/flash visual — any marker-to-stimulus delay shifts every measured P300 latency by that same amount.


In [2]:
!pip -q install pyxdf mne openpyxl

import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import pyxdf
import mne

mne.set_log_level("WARNING")


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/7.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━[0m [32m5.6/7.5 MB[0m [31m169.4 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m7.4/7.5 MB[0m [31m151.3 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.5/7.5 MB[0m [31m92.2 MB/s[0m eta [36m0:00:00[0m
[?25h

## Upload XDF files
Upload one or more `.xdf` files to `/content` (e.g., drag-and-drop them into the Colab Files pane). The next cell picks up every `.xdf` file it finds there.


In [3]:
from pathlib import Path

# Folder holding the LabRecorder recordings. /content is the Colab working
# directory; point this elsewhere when running the notebook locally.
XDF_DIR = Path("/content")

# Collect every .xdf recording in XDF_DIR (sorted for a stable processing order).
# The extension test is case-insensitive so "*.XDF" files are not skipped, and a
# missing folder yields an empty list (caught by the assertion below).
xdf_files = sorted(
    str(p)
    for p in (XDF_DIR.iterdir() if XDF_DIR.is_dir() else [])
    if p.is_file() and p.suffix.lower() == ".xdf"
)

print("XDF files:", xdf_files)
assert len(xdf_files) > 0, f"Upload at least one .xdf file to {XDF_DIR}."

XDF files: []


AssertionError: Upload at least one .xdf file to /content.

## Configuration


In [None]:
# ---------------- USER CONFIG (ActiCAP 16ch montage) ----------------
# Every tunable setting of the pipeline lives in this cell.

# --- LSL streams ----------------------------------------------------------
# Names tried, in this order, when looking for the EEG stream.
EEG_STREAM_NAME_CANDIDATES = ["openvibeSignal", "EEG", "ActiCAP", "BrainVision"]
# Stream that carries the BART task markers.
MARKER_STREAM_NAME = "BART_Markers"

# --- Sampling -------------------------------------------------------------
# Rate of the cap/diagram setup in Hz; the XDF nominal_srate is still read
# from the recording when present.
EXPECTED_SFREQ = 512.0

# --- Epoch window around each explosion (seconds) -------------------------
TMIN, TMAX = -0.2, 0.8
BASELINE = (-0.2, 0.0)  # pre-explosion interval used for baseline correction

# --- P300 measurement windows (seconds after the explosion marker) --------
P300_WIN = (0.25, 0.5)       # search window for the peak (maximum)
P300_MEAN_WIN = (0.3, 0.45)  # averaging window; often steadier than a peak

# --- Amplifier input -> electrode label (inputs 1..16, in order) ----------
# A "2" in the montage drawing stands for the midline suffix "z".
#   inputs  1-4 : Fz  Cz  Pz  POz
#   inputs  5-8 : Fp1 Fp2 F3  F4
#   inputs  9-12: FCz C3  C4  CPz
#   inputs 13-16: P3  P4  O1  O2
ACTICAP_16_CH_NAMES = [
    "Fz", "Cz", "Pz", "POz",
    "Fp1", "Fp2", "F3", "F4",
    "FCz", "C3", "C4", "CPz",
    "P3", "P4", "O1", "O2",
]

# --- Where to read the P300 -----------------------------------------------
# First available channel from this list is used; the ROI average is used
# only when none of them exists.
P300_CHANNEL_PREFERENCE = ["Pz", "CPz", "POz", "Cz", "P3", "P4"]
P300_ROI = ["Pz", "CPz", "P3", "P4", "POz"]

# --- Artifact rejection ---------------------------------------------------
# Peak-to-peak limit in microvolts; set REJECT_UV = None to disable.
REJECT_UV = 120.0
if REJECT_UV is None:
    REJECT = None
else:
    REJECT = {"eeg": REJECT_UV * 1e-6}  # MNE expects Volts

print("Reject:", REJECT)
print("ActiCAP16 channel map:", ACTICAP_16_CH_NAMES)


## Helpers


In [None]:
def parse_bids_from_xdf_filename(xdf_path: str):
    """Pull BIDS entities out of an XDF file name.

    Example: ``sub-P001_ses-S032_task-Default_run-001_eeg.xdf`` gives
    ``{"sub": "P001", "ses": "S032", "run": "001", "task": "Default", "file": ...}``.

    Only the final path component is inspected. Entities that are absent map to
    None; ``"file"`` always holds the bare file name. Key order is
    sub, ses, run, task, file.
    """
    filename = Path(xdf_path).name
    entity_patterns = (
        ("sub", r"sub-([A-Za-z0-9]+)"),
        ("ses", r"ses-([A-Za-z0-9]+)"),
        ("run", r"run-([A-Za-z0-9]+)"),
        ("task", r"task-([A-Za-z0-9]+)"),
    )
    entities = {}
    for key, pattern in entity_patterns:
        hit = re.search(pattern, filename)
        entities[key] = hit.group(1) if hit else None
    entities["file"] = filename
    return entities

def find_stream(streams, want_name=None, want_type=None):
    """Return the first stream matching ``want_name`` OR ``want_type``.

    Each stream is a dict with an ``"info"`` mapping whose ``"name"`` and
    ``"type"`` entries are sequences; element 0 is compared. A criterion left
    as None is ignored. Returns None when nothing matches.
    """
    def _first_field(meta, key):
        return meta.get(key, [""])[0]

    for stream in streams:
        meta = stream["info"]
        stream_name = _first_field(meta, "name")
        stream_type = _first_field(meta, "type")
        name_hit = want_name is not None and stream_name == want_name
        type_hit = want_type is not None and stream_type == want_type
        if name_hit or type_hit:
            return stream
    return None

def find_eeg_stream(streams):
    """Return the stream carrying EEG samples, or None if nothing qualifies.

    Lookup order:
    1. the first name in EEG_STREAM_NAME_CANDIDATES that matches a stream name;
    2. the first stream whose type is "EEG";
    3. fallback: the numeric 2-D stream with the most channels (at least 4).
    """
    for nm in EEG_STREAM_NAME_CANDIDATES:
        s = find_stream(streams, want_name=nm)
        if s is not None:
            return s
    s = find_stream(streams, want_type="EEG")
    if s is not None:
        return s

    # fallback: choose the largest *numeric* 2D stream
    best, best_n = None, -1
    for s in streams:
        ts = s.get("time_series", None)
        if ts is None:
            continue
        arr = np.asarray(ts)
        # String/object streams (e.g. markers) also convert to 2-D arrays;
        # only numeric sample data can be EEG.
        if not np.issubdtype(arr.dtype, np.number):
            continue
        if arr.ndim == 2 and arr.shape[1] >= 4 and arr.shape[1] > best_n:
            best, best_n = s, arr.shape[1]
    return best

def parse_marker_strings(marker_stream):
    """Flatten a marker stream into a list of ``(timestamp, text)`` tuples.

    Each sample may be a channel vector (list/tuple/ndarray; channel 0 is
    used, an empty vector becomes "") or a scalar. Bytes are decoded as UTF-8
    with undecodable bytes dropped. Timestamps are converted to float.
    """
    samples = marker_stream["time_series"]
    stamps = marker_stream["time_stamps"]

    def _as_text(sample):
        if isinstance(sample, (list, tuple, np.ndarray)):
            sample = sample[0] if len(sample) else ""
        if isinstance(sample, bytes):
            sample = sample.decode("utf-8", errors="ignore")
        return str(sample)

    return [(float(stamp), _as_text(sample)) for stamp, sample in zip(stamps, samples)]

def nearest_sample_index(eeg_t, event_t):
    """Index of the EEG timestamp closest to ``event_t`` (first one on a tie)."""
    distance = np.abs(eeg_t - event_t)
    return int(distance.argmin())

def compute_p300_features(epoch_1d, sfreq, times=None, tmin=None,
                          peak_win=None, mean_win=None):
    '''
    Measure the P300 on one single-channel (or ROI-averaged) epoch.

    Parameters
    ----------
    epoch_1d : array-like, shape (n_times,)
        Epoch samples in Volts.
    sfreq : float
        Sampling rate in Hz (used only when ``times`` is not given).
    times : array-like, optional
        Time in seconds of every sample (e.g. ``epochs.times``). When given it
        is used as-is and must have the same shape as ``epoch_1d``.
    tmin : float, optional
        Epoch start in seconds; defaults to the global TMIN. It is snapped to
        the sample grid the way MNE does (first sample at round(tmin*sfreq)).
    peak_win, mean_win : (float, float), optional
        Windows in seconds; default to the globals P300_WIN / P300_MEAN_WIN.

    Returns
    -------
    peak_amp_V, peak_lat_s, mean_amp_V : float
        - peak is the maximum within the peak window and its latency
        - mean is the average within the mean window (often more stable than a peak)
        Each value is NaN when its window contains no samples.

    Raises
    ------
    ValueError
        If ``times`` is given with a shape different from ``epoch_1d``.
    '''
    epoch_1d = np.asarray(epoch_1d, dtype=float)

    if times is None:
        t0 = TMIN if tmin is None else tmin
        # MNE puts the first epoch sample at round(tmin*sfreq)/sfreq, not at tmin
        # itself (e.g. -0.2 s @ 512 Hz -> -102/512 s). Build the same grid so the
        # windows and reported latencies match epochs.times exactly.
        first = int(round(t0 * sfreq))
        t = (np.arange(epoch_1d.size) + first) / sfreq
    else:
        t = np.asarray(times, dtype=float)
        if t.shape != epoch_1d.shape:
            raise ValueError(
                f"times shape {t.shape} does not match epoch shape {epoch_1d.shape}")

    # --- peak within the peak window ---
    w0, w1 = P300_WIN if peak_win is None else peak_win
    mask = (t >= w0) & (t <= w1)
    if not mask.any():
        peak_amp, peak_lat = np.nan, np.nan
    else:
        seg = epoch_1d[mask]
        i_peak = int(np.argmax(seg))
        peak_amp = float(seg[i_peak])
        peak_lat = float(t[mask][i_peak])

    # --- mean within the mean window ---
    m0, m1 = P300_MEAN_WIN if mean_win is None else mean_win
    mmask = (t >= m0) & (t <= m1)
    mean_amp = float(np.mean(epoch_1d[mmask])) if mmask.any() else np.nan

    return peak_amp, peak_lat, mean_amp


## Extract explosions + compute P300 (per explosion)


In [None]:
def _extract_channel_labels(eeg_stream, n_channels):
    """Return the channel labels stored in the XDF stream header, or None.

    None is returned when the header has no channel description or when the
    number of non-empty labels differs from ``n_channels``; the caller then
    falls back to a default montage.
    """
    try:
        channels = eeg_stream["info"]["desc"][0]["channels"][0]["channel"]
    except (KeyError, IndexError, TypeError):
        # no (or empty) channel description in the header
        return None
    labels = [c["label"][0] for c in channels
              if isinstance(c, dict) and c.get("label") and c["label"][0]]
    return labels if len(labels) == n_channels else None


def _parse_marker_meta(msg):
    """Parse the ``key=value`` fields of a ``;``-separated marker into a dict.

    The first field (the event name, e.g. ``BART_EXPLODE``) and fields without
    ``=`` are skipped; keys and values are whitespace-stripped strings.
    """
    meta = {}
    for field in msg.split(";")[1:]:
        if "=" in field:
            key, value = field.split("=", 1)
            meta[key.strip()] = value.strip()
    return meta


rows = []

for xdf_path in xdf_files:
    bids = parse_bids_from_xdf_filename(xdf_path)
    print("\n=== Loading:", xdf_path, "===")
    streams, _header = pyxdf.load_xdf(xdf_path)

    eeg_stream = find_eeg_stream(streams)
    marker_stream = find_stream(streams, want_name=MARKER_STREAM_NAME)

    if eeg_stream is None:
        print("⚠️ No EEG stream found. Skipping.")
        continue
    if marker_stream is None:
        print(f"⚠️ No marker stream named {MARKER_STREAM_NAME} found. Skipping.")
        continue

    eeg = np.asarray(eeg_stream["time_series"])
    eeg_t = np.asarray(eeg_stream["time_stamps"])
    if eeg.ndim != 2 or eeg.shape[0] < 2:
        print("⚠️ EEG stream has no usable samples. Skipping.")
        continue

    # Prefer the header's nominal rate; fall back to the median timestamp spacing.
    sfreq = float(eeg_stream["info"].get("nominal_srate", [0])[0] or 0)
    if not np.isfinite(sfreq) or sfreq <= 0:
        sfreq = 1.0 / np.median(np.diff(eeg_t))
    if abs(sfreq - EXPECTED_SFREQ) > 1.0:
        print(f"⚠️ sfreq {sfreq:.3f} Hz differs from EXPECTED_SFREQ={EXPECTED_SFREQ} Hz; check the recording.")

    print("EEG shape:", eeg.shape, "sfreq~", sfreq)

    n_ch = eeg.shape[1]
    ch_names = _extract_channel_labels(eeg_stream, n_ch)
    if ch_names is None:
        if n_ch == 16:
            ch_names = list(ACTICAP_16_CH_NAMES)
            print("✅ Applied ActiCAP 16-channel label map.")
        else:
            ch_names = [f"Ch{i+1}" for i in range(n_ch)]
            print("⚠️ No channel labels; using Ch1..")

    data = eeg.T.astype(float)  # (n_ch, n_times)

    # Heuristic: a median |value| above 1e-3 means the samples are in µV -> convert to V.
    if np.median(np.abs(data)) > 1e-3:
        data = data * 1e-6

    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
    raw = mne.io.RawArray(data, info, verbose="ERROR")

    markers = parse_marker_strings(marker_stream)
    explode = [(t, msg) for (t, msg) in markers if msg.startswith("BART_EXPLODE")]
    print("Explosions found:", len(explode))
    if not explode:
        continue

    # nearest_sample_index would clamp markers outside the recording to the
    # first/last sample; ignore them instead of epoching unrelated data.
    t_lo, t_hi = float(eeg_t.min()), float(eeg_t.max())
    kept = [k for k, (t, _) in enumerate(explode) if t_lo <= t <= t_hi]
    if len(kept) < len(explode):
        print(f"⚠️ Ignoring {len(explode) - len(kept)} explosion marker(s) outside the EEG recording.")
    if not kept:
        continue

    event_samps = [nearest_sample_index(eeg_t, explode[k][0]) for k in kept]
    n_ev = len(event_samps)
    events = np.column_stack([event_samps, np.zeros(n_ev, dtype=int), np.ones(n_ev, dtype=int)])

    epochs = mne.Epochs(raw, events, event_id={"explode": 1},
                        tmin=TMIN, tmax=TMAX, baseline=BASELINE,
                        reject=REJECT, preload=True, verbose="ERROR")
    print(f"Epochs kept after rejection: {len(epochs)}/{n_ev}")
    if len(epochs) == 0:
        continue

    available = set(epochs.ch_names)
    pick = next((ch for ch in P300_CHANNEL_PREFERENCE if ch in available), None)
    roi = [ch for ch in P300_ROI if ch in available]
    use_roi = pick is None and len(roi) > 0

    if pick is None and not use_roi:
        pick = "Cz" if "Cz" in available else epochs.ch_names[0]
        print("⚠️ Using fallback channel:", pick)

    # Channel indices do not change between epochs: resolve them once.
    if use_roi:
        roi_idxs = [epochs.ch_names.index(ch) for ch in roi]
        ch_used = "ROI:" + ",".join(roi)
    else:
        pick_idx = epochs.ch_names.index(pick)
        ch_used = pick

    # epochs.selection[j] is the row of `events` that produced retained epoch j.
    # Rejected / out-of-bounds epochs are missing from `epochs`, so pairing epochs
    # with explode[i] via enumerate() would attach the wrong marker time and
    # block/trial/pump metadata to every epoch after the first drop.
    for ep, sel in zip(epochs, epochs.selection):
        k = kept[sel]  # position of this explosion among the file's BART_EXPLODE markers
        t_evt, msg = explode[k]
        meta = _parse_marker_meta(msg)

        ep_1d = ep[roi_idxs, :].mean(axis=0) if use_roi else ep[pick_idx, :]
        peak_amp, peak_lat, mean_amp = compute_p300_features(ep_1d, sfreq)

        rows.append({
            **bids,
            "event_index": int(k),
            "event_time_lsl": float(t_evt),
            "event_sample": int(event_samps[sel]),
            "channel_used": ch_used,
            "p300_peak_amp_V": peak_amp,
            "p300_peak_amp_uV": (peak_amp * 1e6 if np.isfinite(peak_amp) else np.nan),
            "p300_peak_lat_s": peak_lat,
            "p300_mean_amp_V": mean_amp,
            "p300_mean_amp_uV": (mean_amp * 1e6 if np.isfinite(mean_amp) else np.nan),
            "block": meta.get("block", ""),
            "trial": meta.get("trial", ""),
            "pump": meta.get("pump", ""),
            "loss": meta.get("loss", ""),
            "total": meta.get("total", "")
        })

p300_df = pd.DataFrame(rows)
print("Rows:", len(p300_df))
p300_df.head()


## Export results


In [None]:
# Write the per-explosion table twice: plain CSV and an XLSX workbook whose
# single sheet is named "explosions". Both go to the current working directory.
out_csv, out_xlsx = "p300_explosions.csv", "p300_explosions.xlsx"

p300_df.to_csv(out_csv, index=False)
p300_df.to_excel(out_xlsx, sheet_name="explosions", index=False, engine="openpyxl")

print("Saved:", out_csv, out_xlsx)


## Session summary (for plotting across sessions)


In [None]:
# One summary row per recording. "file" is part of the grouping key: for
# non-BIDS filenames sub/ses/run/task are all None, and grouping on those alone
# would merge every recording into a single row.
if len(p300_df) == 0:
    session_df = pd.DataFrame()
else:
    session_df = (
        p300_df
        # latency in ms; pandas mean/median skip NaN without all-NaN warnings
        .assign(p300_peak_lat_ms=p300_df["p300_peak_lat_s"] * 1000.0)
        .groupby(["sub", "ses", "run", "task", "file"], dropna=False)
        .agg(
            # rows per recording = explosion epochs that survived rejection
            n_explosions=("p300_peak_amp_uV", "size"),
            p300_peak_amp_uV_mean=("p300_peak_amp_uV", "mean"),
            p300_peak_amp_uV_median=("p300_peak_amp_uV", "median"),
            p300_mean_amp_uV_mean=("p300_mean_amp_uV", "mean"),
            p300_mean_amp_uV_median=("p300_mean_amp_uV", "median"),
            p300_peak_lat_ms_mean=("p300_peak_lat_ms", "mean"),
            p300_peak_lat_ms_median=("p300_peak_lat_ms", "median"),
            channel_used=("channel_used", "first"),
        )
        .reset_index()
    )

session_df.to_csv("p300_session_summary.csv", index=False)
session_df.head()
