In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt



In [6]:
# --- Core packages
import mne
import pandas as pd
import numpy as np
import glob
import matplotlib.pyplot as plt

# --- Complementary packages
import os
import sys
from mne.decoding import UnsupervisedSpatialFilter
from sklearn.decomposition import PCA
import joblib

# --- Local modules
sys.path.append('/MEG/2-MEGSEQ')
from modules import *


In [2]:
# Path names
# -- root
root_path='/Volumes/T5_EVO/REPLAYSEQ/ICM/Data_ICM'
# ---- Behavior
behavior_path=root_path+'/behavior'
behavior_sub_01=behavior_path+'/sub-02_2024-04-10_09h41.08.639_results.csv'

# ---- Epochs
full_sequences_path=root_path+'/epochs_sequences/mne-bids-pipeline/sub-01/meg'
items_path=root_path+'/2-epochs_items/mne-bids-pipeline/sub-01/meg'
# ------ fif path
seq_run_01_path=full_sequences_path+'/sub-01_task-reproduction_epo.fif'
items_run_01_path=items_path+'/sub-01_task-reproduction-epo.fif'

# -- Save plot path
decoder_path="/Users/et/Documents/UNICOG/MEG/2-MEGSEQ/plots/decoder"

In [4]:
# Formating behavioral dataFrame
df=pd.read_csv(behavior_sub_01)
behavior=format_df(df)

In [5]:
# Load epochs with preload=False to optimize memory usage
epochs_items_run01 = mne.read_epochs(items_run_01_path, preload=True)

Reading /Volumes/T5_EVO/REPLAYSEQ/ICM/Data_ICM/2-epochs_items/mne-bids-pipeline/sub-01/meg/sub-01_task-reproduction-epo.fif ...
    Found the data of interest:
        t =    -200.00 ...     600.00 ms
        0 CTF compensation matrices available
Adding metadata with 58 columns
3240 matching events found
No baseline correction applied
0 projection items activated


---

# Decoding

In [6]:
# Create a big object containing all the elements shown to the participant in order to append it to the epochs.events
all_items_shown=np.concatenate(behavior['PresentedSequence'])



In [7]:
9*30*12

3240

In [8]:
15*12

180

In [9]:
len(epochs_items_run01.events[:,2])

3240

In [10]:
len(np.unique(epochs_items_run01.events[:,2]))

54

In [11]:
np.unique(epochs_items_run01.events[:,2])

array([ 4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
       21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
       38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
       55, 56, 57], dtype=int32)

In [12]:
12*40

480

from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold

# Create an iterator over the cross-validation splits
cv_iter = StratifiedKFold(n_splits=5)  # Example: 5-fold cross-validation

# Initialize a list to store scores
scores = []

# Iterate over cross-validation splits
for train, test in tqdm(cv_iter.split(epochs_items_run01.get_data()), desc='Cross-validation progress'):
    # Fit the decoder on training data
    dec_seq.fit(epochs_items_run01[train])

    # Predict and score on test data
    score = dec_seq.score(epochs_items_run01[test])

    # Append score to list
    scores.append(score)

# Calculate average score
mean_score = np.mean(scores)
print(f"Average decoding score: {mean_score}")

In [13]:
dec_seq=SVM_decoder()

In [14]:
#score_seq=mne.decoding.cross_val_multiscore(dec_seq, epochs_items_run01.pick_types(meg='mag')._data, y=all_items_shown)


In [15]:
#dec = SVM_decoder()

#score = mne.decoding.cross_val_multiscore(dec, epochs_items.pick_types(meg='mag')._data, y=epochs_items.events[:,2])
#plotting_funcs.pretty_gat(np.mean(score, axis=0), epochs_items.times,chance=1/6)


In [16]:
#pretty_gat(np.mean(score_seq,axis=0),epochs_items_run01.times,chance=1/6)

In [17]:
#epochs_items_run01.times

In [18]:
def pretty_gat(score,times,chance):
    # gat= generalization across time.
    fig, ax = plt.subplots(1, 1)
    im = ax.imshow(
        score,
        interpolation="lanczos",
        origin="lower",
        cmap="RdBu_r",
        extent=times,
        vmin=chance-2*np.std(score),
        vmax=chance+2*np.std(score),
    )
    ax.set_xlabel("Testing Time (s)")
    ax.set_ylabel("Training Time (s)")
    ax.set_title("Temporal generalization")
    ax.axvline(0, color="k")
    ax.axhline(0, color="k")
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label("score")

    plt.show()

In [19]:
#plt.imshow(np.random.rand(1000,800))

In [20]:
#plt.imshow(np.mean(score_seq,axis=0),interpolation='lanczos',origin='lower',vmin=1/6-2*np.std(score_seq),vmax=1/6+2*np.std(score_seq))


In [21]:
## I forgot to baseline
epochs_items_baseline = mne.read_epochs(items_run_01_path, preload=True)

Reading /Volumes/T5_EVO/REPLAYSEQ/ICM/Data_ICM/2-epochs_items/mne-bids-pipeline/sub-01/meg/sub-01_task-reproduction-epo.fif ...
    Found the data of interest:
        t =    -200.00 ...     600.00 ms
        0 CTF compensation matrices available
Adding metadata with 58 columns
3240 matching events found
No baseline correction applied
0 projection items activated


In [22]:
epochs_items_baseline.apply_baseline((-0.2,0))

Applying baseline correction (mode: mean)


0,1
Number of events,3240
Events,SequenceID-C1RepEmbed/Position-1: 68 SequenceID-C1RepEmbed/Position-2: 44 SequenceID-C1RepEmbed/Position-3: 80 SequenceID-C1RepEmbed/Position-4: 68 SequenceID-C1RepEmbed/Position-5: 56 SequenceID-C1RepEmbed/Position-6: 44 SequenceID-C2RepEmbed/Position-1: 44 SequenceID-C2RepEmbed/Position-2: 64 SequenceID-C2RepEmbed/Position-3: 52 SequenceID-C2RepEmbed/Position-4: 64 SequenceID-C2RepEmbed/Position-5: 68 SequenceID-C2RepEmbed/Position-6: 68 SequenceID-CRep2/Position-1: 60 SequenceID-CRep2/Position-2: 78 SequenceID-CRep2/Position-3: 48 SequenceID-CRep2/Position-4: 66 SequenceID-CRep2/Position-5: 42 SequenceID-CRep2/Position-6: 66 SequenceID-CRep3/Position-1: 72 SequenceID-CRep3/Position-2: 76 SequenceID-CRep3/Position-3: 52 SequenceID-CRep3/Position-4: 44 SequenceID-CRep3/Position-5: 56 SequenceID-CRep3/Position-6: 60 SequenceID-CRep4/Position-1: 63 SequenceID-CRep4/Position-2: 51 SequenceID-CRep4/Position-3: 60 SequenceID-CRep4/Position-4: 57 SequenceID-CRep4/Position-5: 72 SequenceID-CRep4/Position-6: 57 SequenceID-Rep2/Position-1: 60 SequenceID-Rep2/Position-2: 60 SequenceID-Rep2/Position-3: 42 SequenceID-Rep2/Position-4: 84 SequenceID-Rep2/Position-5: 48 SequenceID-Rep2/Position-6: 66 SequenceID-Rep3/Position-1: 44 SequenceID-Rep3/Position-2: 72 SequenceID-Rep3/Position-3: 72 SequenceID-Rep3/Position-4: 60 SequenceID-Rep3/Position-5: 60 SequenceID-Rep3/Position-6: 52 SequenceID-Rep4/Position-1: 69 SequenceID-Rep4/Position-2: 69 SequenceID-Rep4/Position-3: 54 SequenceID-Rep4/Position-4: 45 SequenceID-Rep4/Position-5: 66 SequenceID-Rep4/Position-6: 57 SequenceID-RepEmbed/Position-1: 60 SequenceID-RepEmbed/Position-2: 56 SequenceID-RepEmbed/Position-3: 48 SequenceID-RepEmbed/Position-4: 60 SequenceID-RepEmbed/Position-5: 60 SequenceID-RepEmbed/Position-6: 76
Time range,-0.200 – 0.600 s
Baseline,-0.200 – 0.000 s


In [23]:
score_seq_baseline=mne.decoding.cross_val_multiscore(dec_seq, epochs_items_baseline.pick_types(meg='mag')._data, y=all_items_shown)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


  0%|          | Fitting GeneralizingEstimator : 0/201 [00:00<?,       ?it/s]



  0%|          | Scoring GeneralizingEstimator : 0/40401 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/201 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/40401 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/201 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/40401 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/201 [00:00<?,       ?it/s]



  0%|          | Scoring GeneralizingEstimator : 0/40401 [00:00<?,       ?it/s]

  0%|          | Fitting GeneralizingEstimator : 0/201 [00:00<?,       ?it/s]

  0%|          | Scoring GeneralizingEstimator : 0/40401 [00:00<?,       ?it/s]



In [41]:
plt.imshow(np.mean(score_seq_baseline, axis=0), interpolation='lanczos', origin='lower', cmap="RdBu_r", 
           vmin=1/6-2*np.std(score_seq_baseline), vmax=1/6+2*np.std(score_seq_baseline))
plt.xlabel("Testing Time (s)")
plt.ylabel("Training Time (s)")
plt.title("Temporal generalization")
plt.axvline(0, color="k")  # Changed vline to axvline
plt.axhline(0, color="k")  # Changed hline to axhline
cbar = plt.colorbar(orientation='vertical')  # Set orientation to vertical
cbar.set_label("score")
plt.savefig(decoder_path+'/sub01_decoder_position.png')
plt.show()

In [35]:
# Save the scores
#joblib.dump(score_seq_baseline, '/Users/et/Documents/UNICOG/MEG/2-MEGSEQ/saved_cross_val/MEG/sub_01/score_seq_baseline.pkl')


['/Users/et/Documents/UNICOG/MEG/2-MEGSEQ/saved_cross_val/MEG/sub_01/score_seq_baseline.pkl']

In [8]:
# Load the scores

score_seq_baseline = joblib.load('/Users/et/Documents/UNICOG/MEG/2-MEGSEQ/saved_cross_val/MEG/sub_01/score_seq_baseline.pkl')

In [10]:
plt.imshow(np.mean(score_seq_baseline, axis=0), interpolation='lanczos', origin='lower', cmap="RdBu_r", 
           vmin=1/6-2*np.std(score_seq_baseline), vmax=1/6+2*np.std(score_seq_baseline))

<matplotlib.image.AxesImage at 0x1583db0b0>