## Setup

In [None]:
# imports
from pathlib import Path
import re
import numpy as np
import pyvista as pv
import pandas as pd
from scipy.stats import ttest_rel
import mne
from src.MovieEEGSourcePipeline.source import _load_epochs, make_forward, make_inverse_from_baseline


# Root locations: epoch files to process, and the directory holding the
# FreeSurfer-style subject folders.
DATA_DIR = Path("data/epochs")
SUBJECTS_DIR = Path("data")
FS_SUBJECT = "fsaverage"

# Both template files live in the subject's "bem" folder.
# The source space is ico-4 (coarse; noted by the author as appropriate for a
# Yeo-7 parcellation).
_FS_BEM_DIR = SUBJECTS_DIR / FS_SUBJECT / "bem"
FS_SRC_FNAME = _FS_BEM_DIR / "fsaverage-ico-4-src.fif"
FS_BEM_FNAME = _FS_BEM_DIR / "fsaverage-5120-5120-5120-bem-sol.fif"


def extract_stcs(
    epochs: mne.Epochs,
    inv: mne.minimum_norm.InverseOperator,
) -> list[mne.SourceEstimate]:
    """Apply an inverse operator to every epoch and return the source estimates.

    Parameters
    ----------
    epochs
        Sensor-space epochs to project into source space.
    inv
        Inverse operator passed to ``apply_inverse_epochs``.

    Returns
    -------
    list of mne.SourceEstimate
        The result of ``apply_inverse_epochs`` with ``return_generator=False``
        (one estimate per epoch).
    """
    # Fixed inverse settings for this analysis: eLORETA, lambda2 = 1/9 and an
    # explicit "normal" orientation pick.
    inverse_settings = {
        "inverse_operator": inv,
        "method": "eLORETA",
        "lambda2": 1.0 / 9.0,
        "pick_ori": "normal",
        "return_generator": False,
        "verbose": False,
    }
    return mne.minimum_norm.apply_inverse_epochs(epochs, **inverse_settings)


def average_stcs(stcs: list[mne.SourceEstimate]) -> mne.SourceEstimate:
    """Average per-epoch source estimates element-wise into one estimate.

    Parameters
    ----------
    stcs
        Source estimates to average; all must share the same data shape.

    Returns
    -------
    mne.SourceEstimate
        A new estimate holding the mean data, with ``vertices``, ``tmin``,
        ``tstep`` and ``subject`` copied from the first input.

    Raises
    ------
    ValueError
        If ``stcs`` is empty.
    """
    if not stcs:
        raise ValueError("No STCs to average.")

    # The mean is taken over the epoch axis (axis 0 of the stacked arrays);
    # the per-estimate layout of ``stc.data`` is left untouched.
    # NOTE(review): MNE documents ``stc.data`` as (n_vertices, n_times) - confirm.
    per_epoch_data = [stc.data for stc in stcs]
    mean_data = np.mean(per_epoch_data, axis=0)

    template = stcs[0]
    return mne.SourceEstimate(
        mean_data,
        vertices=template.vertices,
        tmin=template.tmin,
        tstep=template.tstep,
        subject=template.subject,
    )


# Helpers
def run_source_localisation(
    data_dir,
    fs_subject,
    fs_src_fname,
    fs_bem_fname,
    out_dir=Path("data/stcs"),
):
    """Source-localise every epochs file in ``data_dir`` and save the epoch average.

    For each ``<digits>_<a>_<b>_epo.fif`` file, the epochs are projected to
    source space with ``extract_stcs``, averaged with ``average_stcs`` and
    written to ``<out_dir>/<subject>_<film>_stcs.npz``. Files whose output
    already exists are skipped without being loaded.

    Parameters
    ----------
    data_dir : str or Path
        Directory containing the ``*_epo.fif`` files.
    fs_subject, fs_src_fname, fs_bem_fname
        Forwarded to ``make_forward`` as ``FS_SUBJECT``, ``FS_SRC_FNAME`` and
        ``FS_BEM_FNAME``.
    out_dir : str or Path, optional
        Destination directory for the ``.npz`` archives (created if missing).

    Raises
    ------
    FileNotFoundError
        If ``data_dir`` contains no ``*_epo.fif`` file.
    """
    data_dir = Path(data_dir)

    # Pick a single example file to define channel set for forward model
    example = next(data_dir.glob("*_city_l_epo.fif"), None)
    if example is None:
        example = next(data_dir.glob("*_epo.fif"), None)
    if example is None:
        raise FileNotFoundError(f"No epoch files found in {data_dir} to build forward model.")
    fwd = make_forward(example, FS_SUBJECT=fs_subject, FS_SRC_FNAME=fs_src_fname, FS_BEM_FNAME=fs_bem_fname)

    # Loop-invariant setup: output directory and filename pattern.
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    name_pattern = re.compile(r"^(\d+)_([^_]+_[^_]+)_epo$")

    # subject -> inverse operator.
    # NOTE(review): the inverse is built from the baseline of whichever film of
    # that subject is processed first in this run; if that film's output
    # already exists and is skipped, a different film's baseline is used.
    inv_cache = {}

    for epochs_path in sorted(data_dir.glob("*_epo.fif")):
        m = name_pattern.search(epochs_path.stem)
        if m is None:
            continue
        subject, film = m.groups()

        # Check for existing output before the (expensive) epochs load.
        output_path_stcs = out_dir / f"{subject}_{film}_stcs.npz"
        if output_path_stcs.exists():
            print(f"STCs for {subject} {film} already exist, skipping.")
            continue

        epochs = _load_epochs(epochs_path)

        # Ensure we have an inverse per subject
        if subject not in inv_cache:
            # pick baseline portion (and avoid immediate pre-cut because of anticipatory activity)
            epochs_base = epochs.copy().crop(tmin=-0.2, tmax=-0.05)
            inv_cache[subject] = make_inverse_from_baseline(epochs_base, fwd)

        inv = inv_cache[subject]

        print(f">>>>>>>> {subject} {film}")
        stcs = extract_stcs(epochs, inv)
        stcs_avg = average_stcs(stcs)
        # save the averaged STC data in compressed format
        np.savez_compressed(
            output_path_stcs,
            data=stcs_avg.data,
            vertices=stcs_avg.vertices,
            tmin=stcs_avg.tmin,
            tstep=stcs_avg.tstep,
            subject=stcs_avg.subject,
        )

# Load the saved per-film STC archives, average across films within each subject, and stack the subjects for each condition

def _read_stc_npz(path):
    """Read one saved STC archive and return ``(data, meta)``."""
    # allow_pickle is kept because ``vertices``/``subject`` may have been
    # stored as object arrays. Only load archives written by this pipeline:
    # unpickling untrusted files can execute arbitrary code.
    with np.load(path, allow_pickle=True) as npz:
        data = npz["data"]
        meta = {
            "vertices": npz["vertices"],
            "tmin": float(npz["tmin"]),
            "tstep": float(npz["tstep"]),
            "subject": str(npz["subject"]),
        }
    return data, meta


def _assert_same_meta(meta_ref, meta_new, label):
    """Raise ``ValueError`` if ``meta_new`` disagrees with ``meta_ref``."""
    same_count = len(meta_new["vertices"]) == len(meta_ref["vertices"])
    if not same_count or not all(
        np.array_equal(v, v0)
        for v, v0 in zip(meta_new["vertices"], meta_ref["vertices"])
    ):
        raise ValueError(f"Vertices mismatch for {label}.")
    if meta_new["tmin"] != meta_ref["tmin"]:
        raise ValueError(f"tmin mismatch for {label}: {meta_new['tmin']} != {meta_ref['tmin']}")
    if meta_new["tstep"] != meta_ref["tstep"]:
        raise ValueError(f"tstep mismatch for {label}: {meta_new['tstep']} != {meta_ref['tstep']}")


def load_all_stcs(stcs_dir=Path("data/stcs")):
    """Load per-subject condition averages from the saved STC archives.

    Files are expected to be named ``<subject>_<film>_stcs.npz``. For every
    subject, the available films of each condition are averaged; subjects
    missing either condition entirely are skipped.

    Parameters
    ----------
    stcs_dir : str or Path, optional
        Directory holding the ``*_stcs.npz`` archives.

    Returns
    -------
    evoked_A : np.ndarray
        Scrambled condition (``city_nl``, ``art_nl``), stacked with subjects
        on axis 0.
    evoked_B : np.ndarray
        Linear condition (``city_l``, ``art_l``), stacked with subjects on
        axis 0.
    meta : dict
        ``vertices``, ``tmin``, ``tstep`` and ``subject`` of the first archive
        read; every other archive is checked against it.

    Raises
    ------
    ValueError
        If archives disagree on vertices/tmin/tstep, or if no subject has
        both conditions.
    """
    stcs_dir = Path(stcs_dir)
    cond_A_suffixes = ("city_nl", "art_nl")
    cond_B_suffixes = ("city_l", "art_l")

    def _existing_files(sub, suffixes):
        candidates = (stcs_dir / f"{sub}_{c}_stcs.npz" for c in suffixes)
        return [fp for fp in candidates if fp.exists()]

    # collect subjects from filenames
    subjects = sorted({p.name.split("_")[0] for p in stcs_dir.glob("*_stcs.npz")})

    evoked_A = []
    evoked_B = []
    meta = None

    for sub in subjects:
        files_A = _existing_files(sub, cond_A_suffixes)  # scrambled
        files_B = _existing_files(sub, cond_B_suffixes)  # linear

        if not files_A or not files_B:
            print(f"Skipping subject {sub}: missing scrambled or linear STC files.")
            continue

        # average within subject across films for each condition
        for files, collected in ((files_A, evoked_A), (files_B, evoked_B)):
            film_data = []
            for fp in files:
                data, meta_new = _read_stc_npz(fp)
                if meta is None:
                    meta = meta_new
                else:
                    _assert_same_meta(meta, meta_new, f"{fp.name}")
                film_data.append(data)
            collected.append(np.mean(film_data, axis=0))

    if not evoked_A:
        # Same exception type np.stack would raise on an empty list, but with
        # a message that names the actual problem.
        raise ValueError(
            f"No subjects with both scrambled and linear STCs found in {stcs_dir}."
        )

    return np.stack(evoked_A), np.stack(evoked_B), meta


In [None]:
# Run the pipeline over every epochs file in DATA_DIR using the fsaverage
# template files configured above.
run_source_localisation(
    data_dir=DATA_DIR,
    fs_subject=FS_SUBJECT,
    fs_src_fname=FS_SRC_FNAME,
    fs_bem_fname=FS_BEM_FNAME,
)

## Average STC data across subjects

In [56]:
evoked_A, evoked_B, meta = load_all_stcs()

# Validate before touching the arrays: without metadata there is nothing to
# build a SourceEstimate from.
if meta is None:
    raise RuntimeError("No subjects with complete conditions found.")

# Grand averages: mean over the subject axis (axis 0).
GA_A = evoked_A.mean(0)
GA_B = evoked_B.mean(0)


def _grand_average_stc(data):
    """Wrap a grand-average array in a SourceEstimate using the shared metadata."""
    return mne.SourceEstimate(
        data,
        vertices=meta["vertices"].tolist(),
        tmin=meta["tmin"],
        tstep=meta["tstep"],
        subject=meta["subject"],
    )


GA_A_stc = _grand_average_stc(GA_A)  # scrambled
GA_B_stc = _grand_average_stc(GA_B)  # linear

## Visualisation

In [None]:
# Render the scrambled-condition grand average at a single time point and
# export the scene as a standalone HTML file.
# for t in [-0.2, 0.0, 0.1, 0.2, 0.4, 0.8]:
t = 0.5
stc_one = GA_A_stc.copy().crop(t, t)

brain = stc_one.plot(
    subject="fsaverage",
    subjects_dir=SUBJECTS_DIR,
    hemi="both",
    views="lateral",
    colorbar=True,
    time_viewer=True,
    backend="notebook",  # for faster html saving
)

# NOTE(review): this reaches into a private attribute of the Brain object and
# may break across MNE versions.
plotter = brain._renderer.plotter

out_dir = Path("data/stc_vis")
out_dir.mkdir(parents=True, exist_ok=True)
# round() rather than int(): a float product such as 1.005 * 1000 can land just
# below the intended integer, and int() would truncate it to the wrong
# millisecond tag.
out_html = out_dir / f"scrambled_{round(t * 1000)}.html"

# Prefer a Brain-level exporter when the object provides one; otherwise fall
# back to the plotter's HTML export.
if hasattr(brain, "save_html"):
    brain.save_html(out_html, time_viewer=True)
else:
    plotter.export_html(str(out_html))

## compare STCs between conditions across subjects

In [84]:
# Read the "aparc" (Desikan-Killiany) parcellation for fsaverage and merge the
# selected posterior regions into a single label.
labels = mne.read_labels_from_annot(
    subject="fsaverage",
    parc="aparc",
    subjects_dir=SUBJECTS_DIR,
)

# A label is kept when its lower-cased name contains "occipital" or
# "precuneus". A separate plain "cuneus" match is intentionally not applied.
posterior_keywords = ("occipital", "precuneus")
posterior_labels = [
    label
    for label in labels
    if any(keyword in label.name.lower() for keyword in posterior_keywords)
]

if not posterior_labels:
    raise RuntimeError("No posterior labels matched the selection rule.")

# Fold the matches together, starting from the first one.
posterior_label, *remaining_labels = posterior_labels
for lab in remaining_labels:
    posterior_label += lab

Reading labels from parcellation...
   read 35 labels from /Users/yeganeh/Codes/MovieEEG-SourcePipeline/data/fsaverage/label/lh.aparc.annot
   read 34 labels from /Users/yeganeh/Codes/MovieEEG-SourcePipeline/data/fsaverage/label/rh.aparc.annot


In [62]:
# Used below to build stc_lin / stc_scr from one subject's averaged data.
def create_sub_stcs(data, meta):
    """Wrap a data array in an ``mne.SourceEstimate`` using the shared metadata.

    Parameters
    ----------
    data
        Array passed to ``mne.SourceEstimate`` as the source data.
    meta
        Mapping with ``vertices`` (array; converted with ``.tolist()``),
        ``tmin``, ``tstep`` and ``subject``.

    Returns
    -------
    mne.SourceEstimate
    """
    stc_kwargs = {
        "vertices": meta["vertices"].tolist(),
        "tmin": meta["tmin"],
        "tstep": meta["tstep"],
        "subject": meta["subject"],
    }
    return mne.SourceEstimate(data, **stc_kwargs)

# Window centres and half-width, in seconds, for the planned comparisons.
times_of_interest = [0.3, 0.5, 0.7]
window = 0.05

# Per-subject condition averages; subjects are indexed along axis 0.
scrambled, linear, meta = load_all_stcs()
n_sub = len(scrambled)

# One dict per (subject, time) is appended here by the analysis loop.
subject_results = []

In [86]:
# Start from an empty list every time this cell runs; otherwise re-running it
# appends duplicate rows and inflates n in the paired tests.
subject_results = []

for sub in range(n_sub):

    stc_lin = create_sub_stcs(linear[sub], meta)
    stc_scr = create_sub_stcs(scrambled[sub], meta)

    # Restrict both conditions to the merged posterior label.
    stc_lin_post = stc_lin.in_label(posterior_label)
    stc_scr_post = stc_scr.in_label(posterior_label)
    n_times = stc_lin_post.data.shape[1]

    for t in times_of_interest:
        tmin = t - window
        tmax = t + window

        idx_min = stc_lin.time_as_index(tmin)[0]
        idx_max = stc_lin.time_as_index(tmax)[0]

        # A window starting before the first sample gives a negative index,
        # which would silently wrap around when slicing; clamp it to 0.
        idx_min = max(idx_min, 0)
        # Include tmax in the averaging window and avoid out-of-bounds slicing.
        idx_max = min(idx_max + 1, n_times)
        if idx_min >= idx_max:
            raise RuntimeError(f"Empty time window for t={t}s (idx_min={idx_min}, idx_max={idx_max}).")

        lin_data = stc_lin_post.data[:, idx_min:idx_max]
        scr_data = stc_scr_post.data[:, idx_min:idx_max]

        # RMS across posterior vertices and window samples.
        lin_rms = np.sqrt(np.mean(lin_data ** 2))
        scr_rms = np.sqrt(np.mean(scr_data ** 2))

        subject_results.append({
            "subject": sub,
            "time": t,
            "linear_rms": lin_rms,
            "scrambled_rms": scr_rms,
        })

In [88]:
df = pd.DataFrame(subject_results)

# A paired test needs exactly one row per (subject, time). Drop any repeats
# (e.g. from accumulating results over several runs), keeping the most recent.
df = df.drop_duplicates(subset=["subject", "time"], keep="last")

results = []
for t in times_of_interest:
    sub_df = (
        df[df["time"] == t]
        .sort_values("subject")
        .dropna(subset=["linear_rms", "scrambled_rms"])
    )

    # Paired t-test: linear vs scrambled RMS within subject.
    tstat, pval = ttest_rel(
        sub_df["linear_rms"].to_numpy(),
        sub_df["scrambled_rms"].to_numpy(),
    )
    # Effect size of the paired differences: mean difference divided by the
    # sample SD of the differences.
    diff = sub_df["linear_rms"] - sub_df["scrambled_rms"]
    d = diff.mean() / diff.std(ddof=1)

    results.append({"time": t, "n": len(sub_df), "t": tstat, "p": pval, "d": d})

# Bonferroni correction across the planned timepoint tests.
m = len(results)
for r in results:
    r["p_bonf"] = min(r["p"] * m, 1.0)

for r in results:
    # round() avoids truncating a float product that lands just below an integer.
    print(
        f"{round(r['time'] * 1000)} ms (n={r['n']}): "
        f"t={r['t']:.3f}, p={r['p']:.4f}, p_bonf={r['p_bonf']:.4f}"
        f", cohen_d={r['d']:.3f}"
    )

300 ms (n=306): t=-0.777, p=0.4379, p_bonf=1.0000, cohen_d=-0.044
500 ms (n=306): t=-1.164, p=0.2455, p_bonf=0.7366, cohen_d=-0.067
700 ms (n=306): t=-0.935, p=0.3507, p_bonf=1.0000, cohen_d=-0.053
