# Multi-Participant MEG Preprocessing Pipeline


## IMPORTANT: To run ICA and the other parallelised steps, choose a machine with at least 16 cores. When finished, switch back to a smaller machine.

## Pipeline Steps:
1. Load all participant data
2. Preprocessing (filtering, artifact detection)
3. Mark bad channels per participant
4. Continue with epoching, evoked responses, and source reconstruction

**Note:** All code is kept in notebook cells (a few small helper functions are defined inline); no external scripts are used.

## 1. Setup

In [None]:
# Threading for parallel processing: limit each process's BLAS/OpenMP
# libraries to a single thread so the joblib workers used for ICA further
# down do not oversubscribe the CPU cores. These variables generally only
# take effect if set before numpy/scipy/mne are imported, which is why this
# cell runs before the imports cell (and therefore imports `os` itself).
import os

os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'


In [None]:
## INSTALLS
# python-picard provides the Picard solver used by ICA(method='picard') in the
# ICA section; `%pip` is an IPython magic, so this line only runs in Jupyter.

%pip install python-picard

In [1]:
## IMPORTS

from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
import numpy as np
from mne.preprocessing import ICA
from joblib import Parallel, delayed
import mne
import glob
import os


In [3]:
## PATHS
# Root folder containing one sub-folder of *_raw.fif recordings per participant
data_folder = Path('/work/MEG_data/workshop_data')
# FreeSurfer subjects directory (presumably for later source reconstruction)
subjects_dir = '/work/freesurfer'
# Behavioural logs are stored inside the MEG data folder
behaviour_path = str(data_folder / 'behavioural_logs')

In [None]:
## FIND ALL PARTICIPANTS - walk every participant folder and load all of its
## *_raw.fif recordings (searched recursively), concatenating multiple runs.

# participant ID (folder name) -> one preloaded Raw object
participant_data = {}

for subject_dir in sorted(data_folder.iterdir()):
    # only real, non-hidden folders are participants
    if not subject_dir.is_dir() or subject_dir.name.startswith('.'):
        continue

    subject_id = subject_dir.name
    print(f"\n=== Loading {subject_id} ===")
    run_files = sorted(subject_dir.rglob('*_raw.fif'))

    if not run_files:
        print("  No .fif files found")
        continue

    print(f"  Found {len(run_files)} file(s)")
    runs = []
    for run_file in run_files:
        print(f"    Loading: {run_file.name}")
        runs.append(mne.io.read_raw_fif(run_file, preload=True))

    if len(runs) == 1:
        participant_data[subject_id] = runs[0]
    else:
        participant_data[subject_id] = mne.concatenate_raws(runs)
        print(f"  Concatenated {len(runs)} files")

print(f"\n\nTotal participants loaded: {len(participant_data)}")
print(f"Participant IDs: {list(participant_data.keys())}")

## 2. Inspecting and improving data quality

In [None]:
## INSPECT ONE PARTICIPANT FIRST
# Use the first participant (in load order) as a representative example.

first_participant = [*participant_data][0]
raw_example = participant_data[first_participant]
example_info = raw_example.info

for summary_line in (
    f"Inspecting participant: {first_participant}",
    f"Duration: {raw_example.times[-1]:.1f} seconds",
    f"Sampling frequency: {example_info['sfreq']} Hz",
    f"Number of channels: {len(raw_example.ch_names)}",
):
    print(summary_line)

In [None]:
## COMPUTE PSD FOR FIRST PARTICIPANT - HPI frequencies
# Spectrum of the still-unfiltered example recording.

psd_before = raw_example.compute_psd()
psd_before.plot()
plt.suptitle(f'PSD - {first_participant} (BEFORE filtering)');

#### Applying filtering to remove HPI frequencies

(Typical HPI coil frequencies are around 150–350 Hz; the 1–40 Hz band-pass filter below removes them, together with slow drifts and line noise.)


In [None]:
## FILTER ALL PARTICIPANTS
# Band-pass 1-40 Hz, applied in place to every participant's Raw object.

BAND = dict(l_freq=1, h_freq=40)

for pid, raw in participant_data.items():
    print(f"Filtering {pid}...")
    raw.filter(**BAND)
    print("  Done!")

print("\nAll participants filtered!")

In [None]:
## CHECK FILTERING RESULT
# raw_example is the same object that was filtered in place above, so its
# spectrum now reflects the 1-40 Hz band-pass.

psd_after = raw_example.compute_psd()
psd_after.plot()
plt.suptitle(f'PSD - {first_participant} (AFTER filtering)');

#### Identifying bad channels


In [None]:
## MANUAL INSPECTION - PARTICIPANT 1

# Fixed participant order used by all manual-inspection cells below
participant_ids = list(participant_data.keys())
current_id = participant_ids[0]
raw_current = participant_data[current_id]

print(f"Inspecting: {current_id}")
print(f"Currently marked bad: {raw_current.info['bads']}")
print("\nClick channel names to mark as bad. Close window when done.")

# block=True halts execution until the browser window is closed, so the line
# below reports the channels actually clicked as bad (as in the later cells).
raw_current.plot(duration=10.0, n_channels=30, block=True)

print(f"\nFinal bad channels for {current_id}: {raw_current.info['bads']}")

In [None]:
## MANUAL INSPECTION - PARTICIPANT 2

if len(participant_ids) > 1:
    current_id = participant_ids[1]
    raw_current = participant_data[current_id]

    print(f"Inspecting: {current_id}")
    print(f"Currently marked bad: {raw_current.info['bads']}")

    # block=True waits until the browser window is closed so that the summary
    # below includes the channels clicked as bad.
    raw_current.plot(duration=10.0, n_channels=20, block=True)

    print(f"\nFinal bad channels for {current_id}: {raw_current.info['bads']}")
else:
    print("Only one participant in dataset")

In [None]:
## MANUAL INSPECTION - PARTICIPANT 3
# The one-element slice is empty when there is no third participant,
# so the body simply does not run in that case.

for current_id in participant_ids[2:3]:
    raw_current = participant_data[current_id]

    print(f"Inspecting: {current_id}")
    print(f"Currently marked bad: {raw_current.info['bads']}")

    # block=True: wait for the window to close before printing the result
    raw_current.plot(duration=10.0, n_channels=30, scalings='auto', block=True)

    print(f"\nFinal bad channels for {current_id}: {raw_current.info['bads']}")

In [None]:
## LOOP THROUGH REMAINING PARTICIPANTS
# Participants 4..N, one blocking browser window at a time.

remaining_ids = participant_ids[3:]
for i, participant_id in zip(range(4, 4 + len(remaining_ids)), remaining_ids):
    raw_current = participant_data[participant_id]

    print(f"\n=== Inspecting Participant {i}: {participant_id} ===")
    print(f"Currently marked bad: {raw_current.info['bads']}")

    raw_current.plot(duration=10.0, n_channels=30, scalings='auto', block=True)

    print(f"Final bad channels: {raw_current.info['bads']}")

In [None]:
## FINAL BAD CHANNELS SUMMARY
# One entry per participant: count of bad channels, plus their names if any.

print("=== Final Bad Channels Summary ===")
print("\n")
for participant_id, raw in participant_data.items():
    bads = raw.info['bads']
    print(f"{participant_id}: {len(bads)} bad channels")
    if bads:
        print(f"  Channels: {bads}")
    print()

#### The bad channels we annotated are listed below; run this cell if you don't want to go through the manual process again.

In [None]:

# Bad channels found during a previous manual inspection. Running this cell
# REPLACES (does not extend) each listed participant's info['bads'].
bad_channels_dict = {
    '0164': ['MEG0321'],
    '0170': ['MEG0423', 'MEG1443', 'MEG1922', 'MEG1933', 'MEG2621'],
}

# applying to each participant's raw data
for participant_id, bad_list in bad_channels_dict.items():
    if participant_id not in participant_data:
        # make a mistyped ID (or a participant that failed to load) visible
        print(f"Skipping {participant_id}: not among the loaded participants")
        continue
    raw = participant_data[participant_id]
    # store a copy so later in-place edits of info['bads'] cannot silently
    # change bad_channels_dict (and vice versa)
    raw.info['bads'] = list(bad_list)
    print(f"Marked bad channels for {participant_id}: {bad_list}")

#### Running ICA + interpolation

In [None]:
## ICA with parallelisation

# config
DRIFT_SEC = 2.0          # exclude the first 2 s from ICA fitting and detection
N_JOBS    = 8            # one worker per participant (8 participants)
SAVE_DIR  = "/work/GrétaHarsányi#3675/Assignment2/2025Neuro/ICA_cleaned"
os.makedirs(SAVE_DIR, exist_ok=True)


# core function
def run_ica(raw, n_components=0.95, decim=3, rng=97):
    """Fit ICA on drift-free data, remove EOG/ECG components, interpolate bads.

    Parameters
    ----------
    raw : mne.io.Raw
        Preloaded recording. It is not modified; all work happens on copies.
    n_components : float | int
        Passed to ``mne.preprocessing.ICA`` (a float < 1 selects components
        by cumulative explained variance).
    decim : int
        Decimation factor used only while fitting.
    rng : int
        Random seed passed as ``random_state`` for reproducible solutions.

    Returns
    -------
    ica : ICA
        Fitted ICA with ``ica.exclude`` set to the detected components.
    raw_clean : mne.io.Raw
        Full-length copy of ``raw`` with the excluded components removed and
        bad channels interpolated (``info['bads']`` reset).
    exclude : list of int
        Sorted indices of the excluded (EOG + ECG) components.
    """
    # The first DRIFT_SEC seconds are dropped from fitting and detection.
    # fit() leaves this copy untouched, so it is reused for artifact detection
    # instead of making a second full copy of the data in every worker.
    raw_fit = raw.copy().crop(tmin=DRIFT_SEC)
    picks = mne.pick_types(raw_fit.info, meg=True, eog=False, ecg=False, exclude='bads')

    ica = ICA(n_components=n_components, method='picard',
              random_state=rng, max_iter='auto')
    ica.fit(raw_fit, picks=picks, decim=decim,
            reject=dict(mag=5e-12, grad=4000e-13))

    # EOG (blinks + eye movements)
    eog_inds, _ = ica.find_bads_eog(raw_fit)

    # ECG (heartbeat): prefer CTPS, fall back to correlation if it fails
    try:
        ecg_inds, _ = ica.find_bads_ecg(raw_fit, method='ctps')
    except Exception as err:  # best-effort fallback, but say why it happened
        print(f"find_bads_ecg(method='ctps') failed ({err!r}); using 'correlation'")
        ecg_inds, _ = ica.find_bads_ecg(raw_fit, method='correlation')

    exclude = sorted(set(eog_inds) | set(ecg_inds))
    ica.exclude = exclude

    # apply to the full (uncropped) recording, then interpolate bad channels
    raw_clean = ica.apply(raw.copy())
    raw_clean.interpolate_bads(reset_bads=True)

    return ica, raw_clean, exclude


In [None]:
# parallelisation
def run_all_ica(participant_data, n_jobs=N_JOBS):
    """Fit ICA for every participant concurrently.

    Returns the list of ``run_ica`` results ``(ica, raw_clean, exclude)`` in
    the iteration order of ``participant_data``, so it can be zipped with
    the dictionary's keys.
    """
    jobs = [delayed(run_ica)(raw) for raw in participant_data.values()]
    return Parallel(n_jobs=n_jobs)(jobs)

In [None]:
# applying ICA to every participant
ica_results = run_all_ica(participant_data, n_jobs=N_JOBS)

participant_ica = {}
# run_all_ica returns results in participant_data iteration order; iterate over
# a snapshot of the keys because the values are replaced inside the loop
for pid, (ica, raw_clean, exclude) in zip(list(participant_data), ica_results):
    participant_data[pid] = raw_clean   # cleaned data replaces the original
    participant_ica[pid] = ica
    print(f"[{pid}] excluded components: {exclude}")

    raw_clean.save(f"{SAVE_DIR}/{pid}_cleaned_raw.fif", overwrite=True)
    # overwrite=True (as for the raw file) so the cell can be re-run
    ica.save(f"{SAVE_DIR}/{pid}_ica.fif", overwrite=True)

print("\n✓ ICA complete and files saved.")


In [None]:
## Number, mean and SD of ICA components removed per participant (for report)

ICA_DIR = "/work/GrétaHarsányi#3675/Assignment2/2025Neuro/ICA_cleaned"

# one saved ICA solution per participant, named "<pid>_ica.fif"
ica_files = sorted(glob.glob(os.path.join(ICA_DIR, "*_ica.fif")))
if not ica_files:
    raise FileNotFoundError("No *_ica.fif files found in ICA_cleaned directory.")

rows = []
for ica_path in ica_files:
    pid = os.path.basename(ica_path).split("_")[0]
    try:
        n_excluded = len(mne.preprocessing.read_ica(ica_path).exclude)
    except Exception as e:
        print(f"Could not read {ica_path}: {e}")
        continue
    rows.append({"pid": pid, "n_excluded": n_excluded})

df_ica = pd.DataFrame(rows).sort_values("pid").reset_index(drop=True)
n_removed = df_ica["n_excluded"]

print(df_ica)
print("\nTotal ICA components removed:", n_removed.sum())
print("Mean =", n_removed.mean())
print("SD   =", n_removed.std(ddof=1))


##### Next, we run some quality checks to make sure that ICA appropriately removed both eye-blink and heartbeat artifacts.

In [None]:
# component summary: fitted component count and excluded indices per participant
for pid in participant_ica:
    fitted = participant_ica[pid]
    print(pid, "components:", fitted.n_components_, "excluded:", fitted.exclude)


In [None]:
# quick quality check: re-score the ICA components against EOG/ECG on the
# cleaned data and report the strongest remaining association per participant


def qc_eog_ecg(pid, raw_clean, ica):
    """Return the largest absolute EOG and ECG component score for one participant.

    The correlation-based scores from ``find_bads_eog``/``find_bads_ecg`` are
    signed, and a strongly negative correlation is just as artifactual as a
    positive one. A plain ``max`` of the signed values would hide it, so both
    maxima are taken over absolute values (0.0 when there are no scores).

    Returns
    -------
    tuple
        ``(pid, max_abs_eog_score, max_abs_ecg_score)``.
    """
    rd = raw_clean.copy().crop(tmin=DRIFT_SEC)
    _, eog_scores = ica.find_bads_eog(rd)
    _, ecg_scores = ica.find_bads_ecg(rd, method='correlation')

    def _max_abs(scores):
        # scores may be 1-D (one reference channel) or 2-D (several channels)
        arr = np.abs(np.asarray(scores, float))
        return float(arr.max()) if arr.size else 0.0

    return pid, _max_abs(eog_scores), _max_abs(ecg_scores)


qc_results = Parallel(n_jobs=N_JOBS)(
    delayed(qc_eog_ecg)(pid, participant_data[pid], participant_ica[pid])
    for pid in participant_data.keys()
)

for pid, max_eog, max_ecg in qc_results:
    print(f"{pid}  max |EOG|={max_eog:.3f},  max |ECG|={max_ecg:.3f}")

# saving qc table (max_EOG / max_ECG hold maxima of absolute scores)
qc_df = pd.DataFrame(qc_results, columns=["participant", "max_EOG", "max_ECG"])
qc_df.to_csv(f"{SAVE_DIR}/qc_scores.csv", index=False)
print(f"\n✓ QC scores saved to {SAVE_DIR}/qc_scores.csv")


In [None]:
# checking one participant: component topographies and source time courses
# NOTE: participant_data now holds the ICA-cleaned recordings, so the time
# courses of the excluded components should look (near-)flat here.

pid = '0168'
ica, raw = participant_ica[pid], participant_data[pid]
ica.plot_components()
ica.plot_sources(raw)


## 3. Finding events

Extract event triggers from the data for all participants

Events themselves are as follows: 

- Epochs for all participants
event_id = {
    "stimulus_0": 1, 
    "stimulus_1": 3, 
    "mask": 4,
    "response_stimulus_0": 6, 
    "response_stimulus_1": 8,
    "response_PAS_1": 10, 
    "response_PAS_2": 12, 
    "response_PAS_3": 14, 
    "response_PAS_4": 16
    # optional codes, add if present: "response_auto": 32, "response_PAS_auto": 64 ("response_PAS_4" may be missing for some participants)
}


In [None]:

### SETUP


# directories
_PROJECT_ROOT = "/work/GrétaHarsányi#3675/Assignment2/2025Neuro"
DRIFT_SEC  = 2.0   # events in the first 2 s are discarded (same cut as for ICA)
SAVE_DIR   = f"{_PROJECT_ROOT}/events_epochs_evokeds"
GET_DIR    = f"{_PROJECT_ROOT}/ICA_cleaned"   # ICA-cleaned raws are read from here
EVOKED_DIR = f"{SAVE_DIR}/evokeds"
EVENTS_DIR = f"{SAVE_DIR}/events"
EPOCHS_DIR = f"{SAVE_DIR}/epochs"
for _out_dir in (EVENTS_DIR, EPOCHS_DIR, EVOKED_DIR):
    os.makedirs(_out_dir, exist_ok=True)


# peak-to-peak rejection thresholds: 4000 fT (mag) and 4000 fT/cm (grad)
reject = {"mag": 4e-12, "grad": 4000e-13}

# Stimulus trigger codes (kept neutral; relabel as left/right if that mapping is known)
STIM_CODES = {1, 3}

# PAS response trigger -> PAS rating (PAS 4 may be missing for some subjects)
PAS_CODE_TO_LABEL = {10: 1, 12: 2, 14: 3, 16: 4}
PAS_LABELS = sorted(PAS_CODE_TO_LABEL.values())

# stimulus-locked epoch window (s); baseline from epoch start to stimulus onset
tmin, tmax = -0.2, 0.75
baseline   = (None, 0)


In [None]:
### LOAD IN RAWS + FIND EVENTS

participant_data, participant_events = {}, {}
for f in sorted(os.listdir(GET_DIR)):
    if not f.endswith("_cleaned_raw.fif"):
        continue
    pid = f.split("_cleaned_raw.fif")[0]
    raw = mne.io.read_raw_fif(os.path.join(GET_DIR, f), preload=True)
    participant_data[pid] = raw

    ev = mne.find_events(raw, stim_channel="STI101", shortest_event=1, min_duration=0.002)
    # Event sample numbers include raw.first_samp, whereas time_as_index() is
    # relative to the start of the data. Add first_samp before comparing;
    # otherwise the drift cut-off silently keeps every event.
    drift_cutoff = raw.first_samp + raw.time_as_index(DRIFT_SEC)[0]
    ev = ev[ev[:, 0] >= drift_cutoff]
    participant_events[pid] = ev
    mne.write_events(f"{EVENTS_DIR}/{pid}-eve.fif", ev, overwrite=True)

print("Participants:", list(participant_data.keys()))

In [None]:
### HELPER FUNCTION:
# assigning each stimulus the PAS response given in that same trial -> associates
# each stimulus-locked brain response with a specific PAS score, so we can plot them

def label_stim_with_pas(events, pas_map, stim_codes=None):
    """Label every stimulus event with the PAS rating given in the same trial.

    A stimulus's trial is taken to run until the next stimulus event. The
    first PAS response trigger inside that span (if any) provides the rating.
    Trials without a PAS response get NaN instead of borrowing the rating of
    a later trial.

    Parameters
    ----------
    events : ndarray, shape (n_events, 3)
        MNE events array; rows are assumed to be in temporal order (as returned
        by ``mne.find_events``) and column 2 holds the trigger code.
    pas_map : dict
        PAS trigger code -> PAS rating, e.g. ``{10: 1, 12: 2}``.
    stim_codes : iterable of int | None
        Stimulus trigger codes. ``None`` uses the notebook-level ``STIM_CODES``.

    Returns
    -------
    stim_events : ndarray, shape (n_stim, 3)
        The rows of ``events`` whose code is a stimulus code.
    meta : pandas.DataFrame
        One row per stimulus event with a ``"PAS"`` column (NaN if unrated).
    """
    if stim_codes is None:
        stim_codes = STIM_CODES
    codes = events[:, 2]
    stim_idx = np.flatnonzero(np.isin(codes, list(stim_codes)))
    pas_idx = np.flatnonzero(np.isin(codes, list(pas_map)))

    # position of the first PAS event after each stimulus, and the index at
    # which each stimulus's trial ends (the next stimulus, or end of events)
    next_pas_pos = np.searchsorted(pas_idx, stim_idx, side="right")
    trial_end = np.append(stim_idx[1:], len(events))

    stim_pas = []
    for pos, end in zip(next_pas_pos, trial_end):
        if pos < len(pas_idx) and pas_idx[pos] < end:
            stim_pas.append(pas_map[codes[pas_idx[pos]]])
        else:
            stim_pas.append(np.nan)

    stim_events = events[stim_idx]
    meta = pd.DataFrame({"PAS": stim_pas})
    return stim_events, meta

In [None]:
### PLOTTING EVENTS
# Event raster for the first participant; first_samp is passed so the event
# sample numbers (which include it) are placed on the correct time axis.

first_pid = [*participant_events][0]
first_raw = participant_data[first_pid]
ev = participant_events[first_pid]
fig = mne.viz.plot_events(
    ev,
    sfreq=first_raw.info['sfreq'],
    first_samp=first_raw.first_samp,
)
plt.title(f'Events - {first_pid}')

## 4. Creating epochs 


In [None]:
### EPOCHING + ARTEFACT REJECTION


# epochs carry a PAS metadata column from here on

participant_epochs_clean = {}

for pid, raw in participant_data.items():
    events  = participant_events[pid]
    present = set(np.unique(events[:, 2]))

    if not (present & STIM_CODES):
        print(f"{pid}: no stimulus codes → skip")
        continue

    # per-subject PAS map (only codes that occur, so a missing PAS-4 is fine)
    pas_map = {c: PAS_CODE_TO_LABEL[c] for c in PAS_CODE_TO_LABEL if c in present}

    stim_events, metadata = label_stim_with_pas(events, pas_map)
    if len(stim_events) == 0:
        print(f"{pid}: no stimulus events after filter → skip")
        continue

    # metadata (one row per stimulus event) goes to the constructor so MNE
    # keeps it aligned when epochs are dropped during loading, e.g. ones that
    # run past the end of the recording or overlap BAD annotations.
    epochs = mne.Epochs(
        raw,
        stim_events,
        event_id={"stimulus_code1": 1, "stimulus_code3": 3},  # relabel as left/right if known
        tmin=tmin, tmax=tmax,
        baseline=baseline,
        metadata=metadata,
        preload=True,
        reject_by_annotation=True,
        on_missing="ignore",
    )

    # keep only trials that received a PAS rating
    epochs = epochs[epochs.metadata["PAS"].notna().values]

    # epoch-level peak-to-peak rejection (thresholds from the setup cell)
    epochs_clean = epochs.copy()
    epochs_clean.drop_bad(reject=reject)

    # file name: stimulus epochs, PAS score added, artifact-cleaned
    epochs_clean.save(f"{EPOCHS_DIR}/{pid}-epo_stim_withPAS_clean.fif", overwrite=True)
    participant_epochs_clean[pid] = epochs_clean

    print(f"{pid}: kept {len(epochs_clean)}/{len(epochs)} PAS-labeled epochs after p2p rejection")
    print("    PAS counts:", epochs_clean.metadata["PAS"].value_counts(dropna=False).sort_index().to_dict())



In [None]:
## PAS scale double check


# distribution of PAS ratings among the epochs kept after rejection, per participant

def _pas_row(pid, ep):
    """One table row: participant ID plus a PAS<k> count column per rating present."""
    counts = ep.metadata["PAS"].value_counts().sort_index()
    return {"pid": pid, **{f"PAS{int(k)}": int(v) for k, v in counts.items()}}


rows = [_pas_row(pid, ep) for pid, ep in participant_epochs_clean.items()]

pas_counts_df = pd.DataFrame(rows).fillna(0).astype({"pid": str})
print(pas_counts_df.to_string(index=False))

In [None]:
## this is also for the report: kept vs. rejected epochs per participant
EPOCH_DIR = "/work/GrétaHarsányi#3675/Assignment2/2025Neuro/events_epochs_evokeds/epochs"

rows = []

for f in sorted(glob.glob(os.path.join(EPOCH_DIR, "*_clean.fif"))):
    epochs = mne.read_epochs(f, preload=False)
    pid = os.path.basename(f).split("-epo")[0]

    kept = len(epochs)

    # Count only genuinely rejected epochs. Epochs removed on purpose by the
    # PAS-label selection are logged as ('IGNORED',) in drop_log and must not
    # inflate the rejection rate.
    dropped = sum(
        1 for reasons in epochs.drop_log
        if reasons and "IGNORED" not in reasons
    )

    # rejection percentage among PAS-labeled stimulus epochs
    total = kept + dropped
    rej_pct = 100 * dropped / total if total > 0 else np.nan

    rows.append({"pid": pid, "kept": kept, "dropped": dropped, "total": total, "rej_pct": rej_pct})

rows

In [None]:
## summary table and rejection-rate statistics built from the rows above

df = pd.DataFrame(rows)
print(df)

rej = df["rej_pct"]
print("\nMean rejection rate: ", rej.mean())
print("SD rejection rate:   ", rej.std(ddof=1))

## 5. Computing evokeds

Here, we will 
1. Compute evokeds
2. plot per-participant overlay plots
3. plot group overlay 


In [None]:

EPOCHS_DIR = "/work/GrétaHarsányi#3675/Assignment2/2025Neuro/events_epochs_evokeds/epochs"
EVOKED_DIR  = "/work/GrétaHarsányi#3675/Assignment2/2025Neuro/events_epochs_evokeds/evokeds/pas_1_4"
os.makedirs(EVOKED_DIR, exist_ok=True)

# {pid: {"PAS1": path, ...}} — only ratings that occur for that participant
participant_evokeds_by_pas = {}

for fif_file in sorted(glob.glob(os.path.join(EPOCHS_DIR, "*-epo_stim_withPAS_clean.fif"))):
    pid = os.path.basename(fif_file).split("-epo")[0]
    epochs = mne.read_epochs(fif_file, preload=True, verbose=False)

    has_pas = epochs.metadata is not None and "PAS" in epochs.metadata.columns
    if not has_pas:
        print(f"→ {pid}: no PAS metadata; skipping.")
        continue

    pas_vals = pd.to_numeric(epochs.metadata["PAS"], errors="coerce")

    ev_paths = {}
    for level in (1, 2, 3, 4):
        label = f"PAS{level}"
        mask = (pas_vals == level).to_numpy()
        if not mask.any():
            continue  # this participant never gave this rating
        ev = epochs[mask].average()
        ev.comment = label
        out_f = os.path.join(EVOKED_DIR, f"{pid}-{label}-ave.fif")
        ev.save(out_f, overwrite=True)
        ev_paths[label] = out_f
        print(f"✓ {pid}: saved {label} → {out_f}")

    participant_evokeds_by_pas[pid] = ev_paths

print("\nSaved per-PAS evokeds to:", EVOKED_DIR)


In [None]:

LABELS = ["PAS1", "PAS2", "PAS3", "PAS4"]

# grand average (across participants) per PAS level, and how many went into it
ga_by_label = {}
n_subj_by_label = {}

for lab in LABELS:
    ev_list = []
    for pid, mapping in participant_evokeds_by_pas.items():
        path = mapping.get(lab)
        if not path:
            continue  # this participant has no evoked for this rating
        ev = mne.read_evokeds(path, condition=0, verbose=False)
        ev.pick_types(meg=True, eeg=False, exclude=[])
        ev_list.append(ev)

    print(f"{lab}: collected {len(ev_list)} evokeds")
    if not ev_list:
        continue

    try:
        # equalize_channels returns the equalized instances (copies by
        # default), so the result must be assigned back; otherwise the call
        # has no effect on ev_list.
        ev_list = mne.channels.equalize_channels(ev_list)
    except Exception as e:
        print(f"[WARN] {lab}: equalize_channels failed: {e} — skipping GA.")
        continue

    # ensure identical time base
    t0 = ev_list[0].times
    if not all(np.array_equal(t0, e.times) for e in ev_list):
        print(f"[WARN] {lab}: time vectors differ — skipping GA.")
        continue

    ga = mne.grand_average(ev_list, interpolate_bads=True)
    ga_by_label[lab] = ga
    n_subj_by_label[lab] = len(ev_list)

# plotting
# NOTE(review): the RMS below pools magnetometers (T) and gradiometers (T/m),
# whose scales differ, so the curve is only meaningful in arbitrary units and
# is likely dominated by one sensor type — consider one curve per sensor type.
out_dir = f"{EVOKED_DIR}/figs"; os.makedirs(out_dir, exist_ok=True)
plt.figure(figsize=(8,4))
plotted = False

for lab in LABELS:
    ga = ga_by_label.get(lab)
    if ga is None:
        continue
    data = ga.get_data()
    if data.size == 0:
        print(f"[WARN] {lab}: empty GA after picks.")
        continue
    gfp = np.sqrt((data**2).mean(axis=0))  # RMS across channels per time point
    plt.plot(ga.times, gfp, label=f"{lab} (N={n_subj_by_label[lab]})")
    plotted = True

plt.axvline(0, ls="--", color="k", lw=1)
plt.xlabel("Time (s)"); plt.ylabel("Global Field Power (a.u.)")
plt.title("Group — stimulus-locked GFP (PAS1–PAS4)")
if plotted:  # avoid the empty-legend warning when nothing was drawn
    plt.legend(frameon=False)
plt.tight_layout()

png = os.path.join(out_dir, "GROUP_evoked_PAS1_4_overlay.png")
plt.savefig(png, dpi=150); plt.close()
print("Saved group figure:", png, "| drew lines:", plotted)


#### Following group discussion: collapsing evokeds PAS 3 and 4 
Note: the 4-level PAS scale will be kept for the plots and to give a good background for why we do the collapsing
Reasoning:
- scarcity of PAS 4
- baseline issues when plotting evokeds

In [None]:
# making collapsed evokeds: PAS1, PAS2 and PAS3_4 (ratings 3 and 4 pooled)

EPOCHS_DIR = "/work/GrétaHarsányi#3675/Assignment2/2025Neuro/events_epochs_evokeds/epochs"
EVOKED_COLLAPSED_DIR = "/work/GrétaHarsányi#3675/Assignment2/2025Neuro/events_epochs_evokeds/evokeds/collapsed_evk"
os.makedirs(EVOKED_COLLAPSED_DIR, exist_ok=True)

# collapsed label -> PAS ratings pooled into it
COLLAPSED_LEVELS = {"PAS1": [1], "PAS2": [2], "PAS3_4": [3, 4]}

participant_evokeds_by_pas = {}  # {pid: {"PAS1": path, "PAS2": path, "PAS3_4": path}}

# loading participants
epoch_files = sorted(glob.glob(os.path.join(EPOCHS_DIR, "*-epo_stim_withPAS_clean.fif")))

for fif_file in epoch_files:
    pid = os.path.basename(fif_file).split("-epo")[0]
    epochs = mne.read_epochs(fif_file, preload=True, verbose=False)

    if epochs.metadata is None or "PAS" not in epochs.metadata.columns:
        print(f"→ {pid}: no PAS metadata, skipping.")
        continue

    pas_vals = epochs.metadata["PAS"].astype(float)

    ev_paths = {}
    for label, levels in COLLAPSED_LEVELS.items():
        mask = pas_vals.isin(levels).to_numpy()
        if not mask.any():
            continue
        ev = epochs[mask].average()
        ev.comment = label
        out_f = os.path.join(EVOKED_COLLAPSED_DIR, f"{pid}-{label}-ave.fif")
        ev.save(out_f, overwrite=True)
        ev_paths[label] = out_f
        print(f"{pid}: saved {label} → {out_f}")

    if not ev_paths:
        print(f"→ {pid}: no PAS1/2/3_4 trials found after collapsing.")
    participant_evokeds_by_pas[pid] = ev_paths

print("\nSaved collapsed evokeds to:", EVOKED_COLLAPSED_DIR)


In [None]:

### Plotting per-participant GFP overlays

out_dir = f"{EVOKED_COLLAPSED_DIR}/figs"
os.makedirs(out_dir, exist_ok=True)

for pid, mapping in participant_evokeds_by_pas.items():
    if not mapping:
        print(f"{pid}: no evokeds — skipping.")
        continue

    plt.figure(figsize=(8,4))
    for pas in sorted(mapping.keys()):
        ev = mne.read_evokeds(mapping[pas], condition=0, verbose=False).pick_types(meg=True)
        data = ev.get_data()
        gfp  = np.sqrt((data**2).mean(axis=0))  # RMS across channels per time point
        # keys are already "PAS1", "PAS2", "PAS3_4"; the old f"PAS {pas}"
        # label rendered as "PAS PAS1"
        plt.plot(ev.times, gfp, label=pas)

    plt.axvline(0, ls="--", color="k", lw=1)
    plt.xlabel("Time (s)"); plt.ylabel("Global Field Power (a.u.)")
    plt.title(f"{pid} — stimulus-locked GFP by PAS")
    plt.legend(frameon=False); plt.tight_layout()

    out_png = os.path.join(out_dir, f"{pid}_evoked_PAS_overlay.png")
    plt.savefig(out_png, dpi=150); plt.close()
    print("Saved:", out_png)


In [None]:


LABELS = ["PAS1", "PAS2", "PAS3_4"]

# grand average (across participants) per collapsed PAS level
ga_by_label = {}
n_subj_by_label = {}

for lab in LABELS:
    ev_list = []
    for pid, mapping in participant_evokeds_by_pas.items():
        path = mapping.get(lab)
        if not path:
            continue
        ev = mne.read_evokeds(path, condition=0, verbose=False)
        ev.pick_types(meg=True, eeg=False, exclude=[])
        ev_list.append(ev)

    if not ev_list:
        continue

    try:
        # equalize_channels returns the equalized instances (copies by
        # default); assign the result, otherwise the call has no effect.
        ev_list = mne.channels.equalize_channels(ev_list)
    except Exception as e:
        print(f"[WARN] {lab}: equalize_channels failed: {e} — skipping GA.")
        continue

    # ensuring identical timebase
    base_t = ev_list[0].times
    if not all(np.array_equal(base_t, e.times) for e in ev_list):
        print(f"[WARN] {lab}: time vectors differ across subjects—skipping GA for this label.")
        continue

    ga = mne.grand_average(ev_list, interpolate_bads=True)
    ga_by_label[lab] = ga
    n_subj_by_label[lab] = len(ev_list)

# plotting gfp
# NOTE(review): the RMS pools magnetometers (T) and gradiometers (T/m), so the
# curve is in arbitrary units and likely dominated by one sensor type.
out_dir = f"{EVOKED_COLLAPSED_DIR}/figs"; os.makedirs(out_dir, exist_ok=True)
plt.figure(figsize=(8,4))
plotted = False

for lab, ga in ga_by_label.items():
    data = ga.get_data()
    if data.size == 0:
        print(f"[WARN] {lab}: GA has zero data after picks—skipping.")
        continue
    gfp = np.sqrt((data**2).mean(axis=0))  # RMS across channels per time point
    plt.plot(ga.times, gfp, label=f"{lab} (N={n_subj_by_label[lab]})")
    plotted = True

plt.axvline(0, ls="--", color="k", lw=1)
plt.xlabel("Time (s)"); plt.ylabel("Global Field Power (a.u.)")
plt.title("Group — stimulus-locked GFP (PAS1, PAS2, PAS3_4)")
if plotted:  # avoid the empty-legend warning when nothing was drawn
    plt.legend(frameon=False)
plt.tight_layout()

png = os.path.join(out_dir, "GROUP_evoked_PAS_overlay.png")
plt.savefig(png, dpi=150); plt.close()
print("Saved group figure:", png, "| drew lines:", plotted)


In [None]:
print("=" * 60)
print("FINAL PREPROCESSING SUMMARY")
print("=" * 60)

print(f"\nTotal participants loaded: {len(participant_data)}")
print("Participants:", list(participant_data.keys()))

print("\n" + "-" * 60)
print("Epoch counts after artifact rejection (per PAS):")
print("-" * 60)
for pid, ep in participant_epochs_clean.items():
    counts = ep.metadata["PAS"].value_counts().sort_index().to_dict()
    print(f"{pid}: total {len(ep)} epochs → {counts}")

print("\n" + "-" * 60)
# Two sets of evokeds were written: the 4-level PAS evokeds (EVOKED_DIR) and
# the collapsed PAS1/PAS2/PAS3_4 evokeds (EVOKED_COLLAPSED_DIR).
print("Evoked files saved to:")
print("  PAS1-PAS4:        ", EVOKED_DIR)
print("  PAS1/PAS2/PAS3_4: ", EVOKED_COLLAPSED_DIR)
print("-" * 60)

print("\nPreprocessing complete! Ready for source or decoding analysis.")
print("=" * 60)


## Next Steps

Continue with:
1. **Noise covariance estimation** for each participant
2. **Forward model computation** (requires MRI/BEM)
3. **Inverse solution** (source reconstruction)
4. **Group-level analysis**

See "multi_participant_analysis.ipynb" for analysis process. 