In [None]:
## IMPORTS AND DEFAULT PLOTTING PARAMETERS

import mne ## MNE-Python for analysing data
## below magic provides interactive plots in notebook
%matplotlib widget
from os import chdir
from os.path import join
import matplotlib.pyplot as plt ## for basic plotting
import matplotlib as mpl ## for setting default parameters
import pandas as pd
import numpy as np
import os

# 1. Define path and load in data

In [None]:
# --- Analysis configuration ---
# Root folder of the MEG data (subject folders and behavioural logs live below it).
MEG_path = '/work/MEG_data/workshop_data'
# FreeSurfer directory.
subjects_dir = '/work/freesurfer'
# Folder holding the behavioural CSV logs.
behaviour_path = os.path.join(MEG_path, 'behavioural_logs')

# --- Subject IDs to process ---
subjects = [
    "0163", "0164", "0165", "0166",
    "0167", "0168", "0169", "0170",
]

# Raw recordings keyed by subject ID; filled by the loading loop.
raws = dict()

In [None]:
# --- Load the last "2025" session recording of every subject into `raws` ---
for subj in subjects:
    print(f"\n=== Loading subject {subj} ===")

    subj_path = os.path.join(MEG_path, subj)

    # A missing subject folder is skipped like every other missing input in this
    # loop, instead of letting os.listdir() raise and abort the remaining subjects.
    if not os.path.isdir(subj_path):
        print(f"Subject folder not found for {subj}: {subj_path}")
        continue

    # session folders: sub-directories whose name starts with "2025", sorted by name
    session_dirs = sorted(
        d for d in os.listdir(subj_path)
        if d.startswith("2025") and os.path.isdir(os.path.join(subj_path, d))
    )

    if not session_dirs:
        print(f"No session folder starting with '2025' found for {subj}")
        continue

    # the last folder in name order is the one that gets loaded
    session_path = os.path.join(subj_path, session_dirs[-1])
    raw_path = os.path.join(session_path, "workshop_2025_raw.fif")

    if not os.path.exists(raw_path):
        print(f"Raw file not found for {subj}: {raw_path}")
        continue

    # preload=True reads the data into memory right away
    raw = mne.io.read_raw_fif(raw_path, preload=True)
    raws[subj] = raw

# Report what was actually loaded; subjects skipped above are listed explicitly.
not_loaded = [subj for subj in subjects if subj not in raws]
print(f"Finished loading: {len(subjects) - len(not_loaded)} of {len(subjects)} subjects loaded.")
if not_loaded:
    print(f"Not loaded: {', '.join(not_loaded)}")


# 3. filter the data

In [None]:
# Band-pass every loaded recording between 1 Hz and 40 Hz.
# The return value of filter() is not stored, so this relies on the recording
# being modified in place.
for subj in raws:
    raw = raws[subj]
    print(f"\n=== Filtering subject {subj} ===")
    raw.filter(l_freq=1, h_freq=40)

# 4. events


In [None]:
# Events found in each recording, keyed by subject ID.
events_by_subj = dict()

for subj in raws:
    raw = raws[subj]
    print(f"Finding events for subject {subj} ...")
    # shortest_event=1 is handed to find_events as-is
    events = mne.find_events(raw, shortest_event=1)
    events_by_subj[subj] = events

# 5. link the events with the behavioural logs

In [None]:
# --- Load each subject's behavioural log and derive the PAS event code ---
behaviour_by_subj = {}

# List the log folder once, in sorted order, so the file chosen for a subject
# does not depend on the arbitrary order returned by os.listdir().
log_files = sorted(os.listdir(behaviour_path))

for subj in subjects:
    print(f"\n=== Loading behavioural data for subject {subj} ===")

    # candidate logs: "<subject>_2025..._experiment_data.csv"
    matches = [
        name for name in log_files
        if name.startswith(f"{subj}_2025") and name.endswith("_experiment_data.csv")
    ]

    if not matches:
        print(f" No behavioural file found for {subj}")
        continue

    if len(matches) > 1:
        print(f" {subj}: {len(matches)} behavioural files match, using {matches[-1]}")

    # take the last match in name order, mirroring the raw-data loader, which
    # also uses the last "2025" session folder in name order
    behaviour_file = os.path.join(behaviour_path, matches[-1])

    behaviour = pd.read_csv(behaviour_file, index_col=False)
    # PAS_score is the response with "00" appended as text (e.g. "3" -> "300");
    # it is converted to int when it is written into the events array later on
    behaviour['PAS_score'] = behaviour['subjective_response'].astype(str) + "00"
    behaviour_by_subj[subj] = behaviour

    # show the columns and only the first rows instead of dumping the whole table
    print(behaviour.columns)
    print(behaviour.head())

print("Finished loading behavioural data for all subjects!")

# 6. Only keep stimulus events

In [None]:
# --- Replace the trigger code of each stimulus event with the PAS code from the log ---
for subj in subjects:
    print(f"\n=== Applying behavioural PAS labels for subject {subj} ===")

    # Subjects whose recording was skipped have no entry in events_by_subj;
    # skip them here as well instead of failing with a KeyError on the lookup below.
    if subj not in events_by_subj:
        print(f" No events for {subj}, skipping.")
        continue

    # skip if we don't have behaviour for this subject
    if subj not in behaviour_by_subj:
        print(f" No behavioural data for {subj}, skipping.")
        continue

    events = events_by_subj[subj]
    behaviour = behaviour_by_subj[subj]

    # keep only events whose code (third column) is 1 or 3, treated here as the
    # stimulus events; boolean indexing returns a new array
    target_indices = np.isin(events[:, 2], [1, 3])
    events = events[target_indices, :]

    if len(events) == 0:
        print(f" No stimulus events found for {subj}, skipping.")
        continue

    # Events and log rows are paired by position, so both need the same length.
    # On a mismatch both are cut to the shorter one, keeping the leading entries.
    n_behav = len(behaviour)
    n_events = len(events)
    if n_behav != n_events:
        print(f" {subj}: mismatch (events={n_events}, behaviour={n_behav}), trimming to shortest.")
        min_len = min(n_behav, n_events)
        events = events[:min_len, :]
        behaviour = behaviour.iloc[:min_len]

    # overwrite the event codes with the integer PAS codes and store the result
    events[:, 2] = behaviour["PAS_score"].astype(int).to_numpy()
    events_by_subj[subj] = events

print("\n Finished applying PAS labels to all subjects!")


# 7. Merge PAS 4 into PAS 3

In [None]:
# --- Merge PAS 4 into PAS 3 and print the resulting code counts per subject ---
for subj in subjects:
    print(f"\n=== Summarizing and merging PAS scores for subject {subj} ===")

    # subjects without events (recording skipped earlier) cannot be merged;
    # skip them rather than raising KeyError
    if subj not in events_by_subj:
        print(f" No events found for {subj}, skipping.")
        continue

    events = events_by_subj[subj]

    # merge PAS 4 (400) with PAS 3 (300); this edits the stored array in place
    events[events[:, 2] == 400, 2] = 300

    # update dictionary
    events_by_subj[subj] = events

    # show counts after the merge: one row per event code, columns (code, count)
    unique, counts = np.unique(events[:, 2], return_counts=True)
    print(np.asarray((unique, counts)).T)


print("Finished summarizing and merging PAS scores for all subjects!")

# Only keep PAS 2 and 3:

In [None]:
# --- Create filtered event sets: PAS 200 vs 300 ---
events_2vs3_by_subj = {}

# event codes that are kept; every message below is built around these two codes
keep_codes = [200, 300]

for subj in subjects:
    print(f"\n=== Filtering PAS 200 vs 300 for subject {subj} ===")

    if subj not in events_by_subj:
        print(f" No events found for {subj}, skipping.")
        continue

    events = events_by_subj[subj]

    # keep only PAS 200 and 300
    mask = np.isin(events[:, 2], keep_codes)
    events_2vs3 = events[mask]

    if len(events_2vs3) == 0:
        print(f" No PAS 200/300 trials for {subj}, skipping.")
        continue

    events_2vs3_by_subj[subj] = events_2vs3
    print(f" {subj}: kept {len(events_2vs3)} trials (PAS 200/300 only)")

print("\n Finished filtering PAS 200 vs 300 for all subjects!")


# 8. Creating the epochs

In [None]:
# --- Epoch every subject that has both a recording and PAS 2/3 events ---

# event codes mapped to condition labels
event_ids = {"PAS2": 200,
             "PAS3": 300}

# epoching parameters; identical for all subjects, so defined once
tmin = -0.200
tmax = 0.550
baseline = (-0.200, 0)
# rejection threshold for the EOG channels, passed to mne.Epochs(reject=...)
reject = {'eog': 250e-6}

# Start from an empty dict on every run so entries left over from an earlier
# execution of this cell cannot leak into later steps.
epochs_by_subj = {}
skipped = []

for subj in subjects:

    print(f"\n=== Creating epochs for subject {subj} ===")

    # subjects dropped in an earlier step have no recording and/or no PAS 2/3
    # events; skip them instead of failing with a KeyError
    if subj not in raws or subj not in events_2vs3_by_subj:
        print(f" No raw data or PAS 2/3 events for {subj}, skipping.")
        skipped.append(subj)
        continue

    raw = raws[subj]
    events = events_2vs3_by_subj[subj]

    # epoch data with EOG rejection
    epochs = mne.Epochs(raw,
                        events=events,
                        event_id=event_ids,
                        tmin=tmin,
                        tmax=tmax,
                        baseline=baseline,
                        preload=True,
                        reject=reject,
                        on_missing='ignore')

    epochs_by_subj[subj] = epochs

# report the subjects that were really skipped rather than a hard-coded ID
skipped_txt = ", ".join(skipped) if skipped else "none"
print(f"\n Finished creating epochs for {len(epochs_by_subj)} subjects (skipped: {skipped_txt}).")



# Force Balance

In [None]:
# --- Force-balance the two PAS conditions per subject ---

epochs_balanced_by_subj = {}

for subj, ep in epochs_by_subj.items():
    print(f"\n=== Balancing epochs for subject {subj} ===")

    # condition names are read from the epochs' own event_id mapping
    conds = list(ep.event_id.keys())
    if len(conds) != 2:
        print(f" {subj}: Expected 2 conditions, found {conds}. Skipping.")
        continue

    # split epochs into the two PAS conditions
    ep_list = [ep[conds[0]], ep[conds[1]]]

    # a condition may have no epochs left (e.g. after rejection); equalizing
    # would then empty the other one too, so such subjects are skipped
    if min(len(part) for part in ep_list) == 0:
        print(f" {subj}: no epochs left in at least one of {conds}. Skipping.")
        continue

    # balance the number of trials (modifies the objects in ep_list)
    mne.epochs.equalize_epoch_counts(ep_list)

    # merge back together
    ep_balanced = mne.concatenate_epochs(ep_list)
    epochs_balanced_by_subj[subj] = ep_balanced

    print(f" {subj}: kept {len(ep_list[0])} epochs per condition (total {len(ep_balanced)})")

print("\n Finished balancing PAS conditions for all subjects!")


# Logistic regression decoding (PAS 2 vs PAS 3) in sensor space

In [None]:
# packages for logistic regression 
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt


In [None]:

# --- config ---
subjects_use = subjects                     # use all subjects you kept
picks = "meg"
# 5-fold stratified CV with a fixed seed, so the fold assignment is reproducible
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=2002)

# z-score each sensor, then fit a logistic regression
logr = make_pipeline(
    StandardScaler(),
    LogisticRegression(max_iter=2000, class_weight='balanced', solver='lbfgs')
)

scores_by_subj = {}  # {subj: scores_t}
times = None

for subj in subjects_use:
    # subjects dropped in an earlier step have no balanced epochs; skip them
    # instead of failing with a KeyError
    if subj not in epochs_balanced_by_subj:
        print(f"\n No balanced epochs for {subj}, skipping.")
        continue

    print(f"\n=== Decoding PAS200 vs PAS300 for {subj} ===")

    # pick MEG channels (on a copy, the stored epochs stay untouched)
    e = epochs_balanced_by_subj[subj].copy().pick(picks)

    X = e.get_data()         # shape: (trials, sensors, time)
    y = e.events[:, 2]       # contains only 200 & 300 now

    # recode labels: 200 -> 0, 300 -> 1
    y_bin = (y == 300).astype(int)

    # the time axis of the first decoded subject is used for all plots
    if times is None:
        times = e.times
    n_times = X.shape[2]

    scores_t = np.zeros(n_times)

    # one independent cross-validated classifier per time sample;
    # scoring is left at cross_val_score's default (the estimator's own score)
    for t in range(n_times):
        print(f"{subj}: time {t+1}/{n_times}", end="\r")
        Xt = X[:, :, t]   # (trials, sensors)
        scores_t[t] = cross_val_score(logr, Xt, y_bin, cv=cv, n_jobs=-1).mean()

    scores_by_subj[subj] = scores_t
    print()  # new line after subject

# --- compute group mean curve ---
# fail with a clear message rather than an opaque vstack error
if not scores_by_subj:
    raise RuntimeError("No subject could be decoded: epochs_balanced_by_subj has no usable entries.")

S = np.vstack(list(scores_by_subj.values()))   # (subjects, times)
group_mean = S.mean(axis=0)

# NOTE(review): later cells read "group_scores_2vs3.npy" and
# "group_mean_results_2vs3.npz", but nothing here writes them; presumably
# S / group_mean / times were saved by hand. Confirm.

# --- group-mean decoding curve with a chance reference line ---
chance = 0.5

fig, ax = plt.subplots()
ax.axhline(chance, linestyle="--", linewidth=1, label=f"Chance ({chance:.2f})")
ax.plot(times, group_mean, label="PAS200 vs PAS300")

ax.set_xlabel("Time (s)")
ax.set_ylabel("Decoding Accuracy")
ax.set_title("Time-resolved PAS decoding (sensor space)")
ax.legend()
fig.tight_layout()
plt.show()


In [None]:
# --- group-mean decoding curve with a chance reference line ---
# (draws `group_mean` against `times` in a fresh figure)
chance = 0.5

fig, ax = plt.subplots()
ax.axhline(chance, linestyle="--", linewidth=1, label=f"Chance ({chance:.2f})")
ax.plot(times, group_mean, label="PAS200 vs PAS300")

ax.set_xlabel("Time (s)")
ax.set_ylabel("Decoding Accuracy")
ax.set_title("Time-resolved PAS decoding (sensor space)")
ax.legend()
fig.tight_layout()
plt.show()


In [None]:
# Load the saved PAS 2 vs PAS 3 decoding results for later use.
S_2vs3 = np.load("group_scores_2vs3.npy")

# np.load() on an .npz returns an NpzFile that keeps the archive open. Read the
# arrays inside a `with` block so the file handle is closed again, and keep them
# in a plain dict so `data_2vs3[...]` lookups still work afterwards.
with np.load("group_mean_results_2vs3.npz") as npz_file:
    data_2vs3 = {name: npz_file[name] for name in npz_file.files}

# Access the arrays
group_mean_2vs3 = data_2vs3["mean"]
times = data_2vs3["times"]

In [None]:
from mne.stats import permutation_cluster_1samp_test

# One-sample cluster permutation test of the decoding scores against chance.
# S_2vs3 is expected to be (subjects x times) - TODO confirm against the saved file.
T_obs_2vs3, clusters_2vs3, cluster_p_values_2vs3, H0_2vs3 = permutation_cluster_1samp_test(
    S_2vs3 - 0.5,  # subtract chance
    n_permutations=1000,
    tail=1,  # test above chance
    threshold=None,  # t-threshold can be auto-determined
    out_type='mask',
    seed=2002,  # fixed seed so the permutation p-values are reproducible
)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

chance = 0.5

fig, ax = plt.subplots(figsize=(10, 4))

# dashed reference line at chance level
ax.axhline(chance, linestyle="--", linewidth=1, color='k', label=f"Chance ({chance:.2f})")

# group-mean decoding curve
ax.plot(times, group_mean_2vs3, label="PAS200 vs PAS300", color='b')

# shade the time span of every cluster with p < 0.05
for cluster, p_val in zip(clusters_2vs3, cluster_p_values_2vs3):
    if not (p_val < 0.05):
        continue

    # a cluster arrives either as a tuple holding a slice or as a boolean mask;
    # both are turned into an array of sample indices
    if isinstance(cluster, tuple):
        sample_idx = np.arange(*cluster[0].indices(len(times)))
    else:
        sample_idx = np.where(cluster)[0]

    first, last = sample_idx[0], sample_idx[-1]
    ax.axvspan(times[first], times[last], color='red', alpha=0.3)

ax.set_xlabel("Time (s)")
ax.set_ylabel("Balanced Accuracy")
ax.set_title("Time-resolved PAS decoding (sensor space)")
ax.legend()
fig.tight_layout()
plt.show()


In [None]:
# Load the saved PAS 1 vs PAS 3 decoding results for later use.
S_1vs3 = np.load("group_scores_1vs3.npy")

# np.load() on an .npz returns an NpzFile that keeps the archive open. Read the
# arrays inside a `with` block so the file handle is closed again, and keep them
# in a plain dict so `data_1vs3[...]` lookups still work afterwards.
with np.load("group_mean_results_1vs3.npz") as npz_file:
    data_1vs3 = {name: npz_file[name] for name in npz_file.files}

# Access the arrays
group_mean_1vs3 = data_1vs3["mean"]
times = data_1vs3["times"]

In [None]:
from mne.stats import permutation_cluster_1samp_test

# One-sample cluster permutation test of the decoding scores against chance.
# S_1vs3 is expected to be (subjects x times) - TODO confirm against the saved file.
T_obs_1vs3, clusters_1vs3, cluster_p_values_1vs3, H0_1vs3 = permutation_cluster_1samp_test(
    S_1vs3 - 0.5,  # subtract chance
    n_permutations=1000,
    tail=1,  # test above chance
    threshold=None,  # t-threshold can be auto-determined
    out_type='mask',
    seed=2002,  # fixed seed so the permutation p-values are reproducible
)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

chance = 0.5

fig, ax = plt.subplots(figsize=(10, 4))

# dashed reference line at chance level
ax.axhline(chance, linestyle="--", linewidth=1, color='k', label=f"Chance ({chance:.2f})")

# group-mean decoding curve
ax.plot(times, group_mean_1vs3, label="PAS100 vs PAS300", color='b')

# shade the time span of every cluster with p < 0.05
for cluster, p_val in zip(clusters_1vs3, cluster_p_values_1vs3):
    if not (p_val < 0.05):
        continue

    # a cluster arrives either as a tuple holding a slice or as a boolean mask;
    # both are turned into an array of sample indices
    if isinstance(cluster, tuple):
        sample_idx = np.arange(*cluster[0].indices(len(times)))
    else:
        sample_idx = np.where(cluster)[0]

    first, last = sample_idx[0], sample_idx[-1]
    ax.axvspan(times[first], times[last], color='red', alpha=0.3)

ax.set_xlabel("Time (s)")
ax.set_ylabel("Balanced Accuracy")
ax.set_title("Time-resolved PAS decoding (sensor space)")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Load the saved PAS 1 vs PAS 2 decoding results for later use.
S_1vs2 = np.load("group_scores_1vs2.npy")

# np.load() on an .npz returns an NpzFile that keeps the archive open. Read the
# arrays inside a `with` block so the file handle is closed again, and keep them
# in a plain dict so `data_1vs2[...]` lookups still work afterwards.
with np.load("group_mean_results_1vs2.npz") as npz_file:
    data_1vs2 = {name: npz_file[name] for name in npz_file.files}

# Access the arrays
group_mean_1vs2 = data_1vs2["mean"]
times = data_1vs2["times"]

In [None]:
from mne.stats import permutation_cluster_1samp_test

# One-sample cluster permutation test of the decoding scores against chance.
# S_1vs2 is expected to be (subjects x times) - TODO confirm against the saved file.
T_obs_1vs2, clusters_1vs2, cluster_p_values_1vs2, H0_1vs2 = permutation_cluster_1samp_test(
    S_1vs2 - 0.5,  # subtract chance
    n_permutations=1000,
    tail=1,  # test above chance
    threshold=None,  # t-threshold can be auto-determined
    out_type='mask',
    seed=2002,  # fixed seed so the permutation p-values are reproducible
)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

chance = 0.5

fig, ax = plt.subplots(figsize=(10, 4))

# dashed reference line at chance level
ax.axhline(chance, linestyle="--", linewidth=1, color='k', label=f"Chance ({chance:.2f})")

# group-mean decoding curve
ax.plot(times, group_mean_1vs2, label="PAS100 vs PAS200", color='b')

# shade the time span of every cluster with p < 0.05
for cluster, p_val in zip(clusters_1vs2, cluster_p_values_1vs2):
    if not (p_val < 0.05):
        continue

    # a cluster arrives either as a tuple holding a slice or as a boolean mask;
    # both are turned into an array of sample indices
    if isinstance(cluster, tuple):
        sample_idx = np.arange(*cluster[0].indices(len(times)))
    else:
        sample_idx = np.where(cluster)[0]

    first, last = sample_idx[0], sample_idx[-1]
    ax.axvspan(times[first], times[last], color='red', alpha=0.3)

ax.set_xlabel("Time (s)")
ax.set_ylabel("Balanced Accuracy")
ax.set_title("Time-resolved PAS decoding (sensor space)")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
import numpy as np

# Time axis belonging to the 2 vs 3 scores the clusters were computed on.
# It is read from the loaded results rather than from one hard-coded subject's
# epochs, so this cell does not need the epochs to be in memory.
times = data_2vs3["times"]  # array of time points in seconds

# Collect index, p-value and time range of every significant 2 vs 3 cluster
significant_clusters = []

for i_c, c in enumerate(clusters_2vs3):
    p_val = cluster_p_values_2vs3[i_c]
    if p_val < 0.05:  # significant cluster
        # Convert slice or boolean mask to indices
        if isinstance(c, tuple):
            cluster_indices = np.arange(*c[0].indices(len(times)))
        else:  # boolean array
            cluster_indices = np.where(c)[0]

        # Map indices to time in ms
        cluster_times_ms = times[cluster_indices] * 1000
        start_ms = cluster_times_ms[0]
        end_ms = cluster_times_ms[-1]

        significant_clusters.append({
            "cluster_index": i_c,
            "p_value": p_val,
            "time_range_ms": (start_ms, end_ms)
        })

# Print results (the range separator is an en dash; it was mis-encoded before)
for cl in significant_clusters:
    print(f"Cluster {cl['cluster_index']} (p={cl['p_value']:.3f}) "
          f"time range: {cl['time_range_ms'][0]:.1f}–{cl['time_range_ms'][1]:.1f} ms")


In [None]:
import matplotlib.pyplot as plt

chance = 0.5

fig, ax = plt.subplots(figsize=(10, 4))

# dashed reference line at chance level
ax.axhline(chance, linestyle="--", linewidth=1, color='k', label=f"Chance ({chance:.2f})")

# one curve per contrast, all drawn against the shared `times` array
contrast_curves = [
    (group_mean_1vs3, "PAS-1 vs PAS-3", '#1B3A6F'),  # navy
    (group_mean_2vs3, "PAS-2 vs PAS-3", '#2878B5'),  # medium blue
    (group_mean_1vs2, "PAS-1 vs PAS-2", '#76B7E5'),  # sky blue
]
for curve, curve_label, curve_colour in contrast_curves:
    ax.plot(times, curve, label=curve_label, color=curve_colour)

ax.set_xlabel("Time (s)")
ax.set_ylabel("Decoding Accuracy")
ax.set_title("Time-resolved PAS decoding (sensor space)")
ax.legend()
fig.tight_layout()
plt.show()


In [None]:
import numpy as np

# Sample index at which the PAS 2 vs PAS 3 group-mean curve is highest
peak_idx = np.argmax(group_mean_2vs3)

# Value of the curve at that sample, and the matching entry of `times`
# scaled by 1000 (reported as ms)
peak_value, peak_time = group_mean_2vs3[peak_idx], times[peak_idx]*1000

print(f"Peak mean decoding: {peak_value:.3f} at {peak_time:.1f} ms")
