# Decoding the sentence / constituent embeddings across the epoch time window

We:

- Calculate the embeddings for each sentence / constituent

- Decode them at each time point of the epoch window

- Plot it for each condition (level / start)


Todo: 

Integrate this:

```python
sent_starts = epochs['word_id==0'].apply_baseline((-.300, 0.))
sent_starts.average().plot()

sent_stops = epochs['is_last_word']
bsl = (epochs.times>-.300 )*(epochs.times<=0)
baseline_starts = sent_starts.get_data()[:, :, bsl].mean(-2)

sent_stop_data = sent_stops.get_data()
n_sentences, n_channels, n_times = sent_stop_data.shape
sent_stop_data -= baseline_starts[:, :, None]
```


In [1]:
from dataset import read_raw, get_subjects, get_path, add_embeddings
from utils import decod_xy, mne_events
import mne
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from utils import match_list
import spacy

In [2]:
def plot_scores(all_scores, levels=None, starts=None):
    """Plot the mean decoding score over time for each (level, start) condition.

    One subplot is drawn per condition: rows follow `levels`, columns follow
    `starts`. Within a condition, scores sharing the same time `t` are
    averaged (i.e. across whatever other rows, such as subjects, are present).

    Parameters
    ----------
    all_scores : pandas.DataFrame
        Must provide the columns `level`, `start`, `t` and `score`.
    levels, starts : sequence of str, optional
        Conditions to plot. When omitted, the distinct values found in
        `all_scores` are used, in order of first appearance.
    """
    if levels is None:
        levels = list(all_scores.level.unique())
    if starts is None:
        starts = list(all_scores.start.unique())
    if not levels or not starts:
        # Nothing to draw; a 0-sized subplot grid is not valid.
        return

    # The size must be given to `subplots` itself: calling `figure(...)` first
    # only opens an extra, empty figure and leaves the grid at default size.
    fig, axes = plt.subplots(
        len(levels), len(starts), figsize=(16, 10), dpi=80, squeeze=False
    )

    for row_axes, level in zip(axes, levels):
        for ax, start in zip(row_axes, starts):
            in_condition = (all_scores.level == f'{level}') & (all_scores.start == f'{start}')
            # groupby sorts by `t`, so the curve is drawn in chronological order.
            mean_by_time = all_scores[in_condition].groupby('t')['score'].mean()
            ax.plot(mean_by_time.index, mean_by_time.values)
            ax.set_title(f'{level} {start}')


In [3]:
def _add_span_info(meta, id_col, word_id_col, start_col, stop_col):
    """Annotate, in place, each word with its rank in and the time span of its group.

    Rows are grouped by `id_col`. For every group, `word_id_col` receives the
    0-based position of the word, `start_col` the earliest `start` of the
    group, and `stop_col` the latest `start` plus the `duration` of the row
    with the highest index in the group.
    """
    for _, words in meta.groupby(id_col):
        rows = words.index
        meta.loc[rows, word_id_col] = range(len(words))
        meta.loc[rows, start_col] = words.start.min()
        last_word_duration = meta.loc[rows.max(), 'duration']
        meta.loc[rows, stop_col] = words.start.max() + last_word_duration


def epoch_add_metadata(modality, subject, levels, starts, runs=9):
    """
    Build, for one subject, the epochs of every (level x start) condition.

    For each run the raw data and word metadata are read, the metadata is
    enriched (onset flags, sentence / constituent ids and time spans,
    embeddings), and the raw signal is epoched once per condition. The per-run
    epochs of a condition are finally concatenated.

    Reads the module-level names `nlp` (spaCy pipeline used for the word
    vectors) and `epoch_windows` (tmin / tmax per level and start).

    Parameters:

    - modality: passed to `read_raw`
    - subject: subject identifier, passed to `read_raw`
    - levels: levels of epoching wanted (word, sentence, constituent)
    - starts: (onset, offset)
    - runs: number of runs to read, from 1 to `runs` included (can be lowered
      for debugging)

    Returns:

    A dict of epochs objects, concatenated on the key (levels x starts)

    e.g: {'word_onset': <Epochs 10000 objects>, 'sentence_offset': <Epochs 1000 objects> ....}

    Raises:

    ValueError if a level has 10 or fewer onset events in a run.
    """
    # One list of per-run epochs for each condition (start x level)
    dict_epochs = {f'{level}_{start}': [] for start in starts for level in levels}

    # Iterating on runs, building the metadata and re-epoching
    for run in range(1, runs + 1):
        raw, meta_, _ = read_raw(subject, run, events_return=True, modality=modality)
        meta = meta_.copy()

        # Metadata update: every row is a word, hence a word onset
        meta['word_onset'] = True
        meta['word_stop'] = meta.start + meta.duration
        meta['sentence_onset'] = meta.word_id == 0
        # A constituent starts where n_closing drops to exactly 1. Comparisons
        # against the NaN of the first shifted row are False, so no fillna is needed.
        prev_closing = meta['n_closing'].shift(1)
        meta['constituent_onset'] = (prev_closing > meta['n_closing']) & (meta['n_closing'] == 1)
        meta['const_end'] = meta.constituent_onset.shift(-1)

        # Adding the sentence stop info
        meta['sentence_id'] = np.cumsum(meta.sentence_onset)
        _add_span_info(meta, 'sentence_id', 'sent_word_id', 'sentence_start', 'sentence_stop')

        # Adding the constituents stop info
        meta['constituent_id'] = np.cumsum(meta.constituent_onset)
        _add_span_info(meta, 'constituent_id', 'const_word_id',
                       'constituent_start', 'constituent_stop')

        # Adding embeddings info
        meta = add_embeddings(meta, run, 'constituent')
        meta = add_embeddings(meta, run, 'sentence')
        meta['embeds_word'] = meta.word.apply(lambda word: nlp(word).vector).values

        for start in starts:
            for level in levels:
                # Select only the onset rows of this level; for start='offset' the
                # same rows are used and `mne_events` is told to epoch on the offset.
                sel = meta.query(f'{level}_onset==True')
                if sel.shape[0] <= 10:
                    raise ValueError(
                        f"Only {sel.shape[0]} '{level}' onset events found in run {run}; "
                        "more than 10 are required."
                    )

                # TODO: add adaptative baseline
                window = epoch_windows[f'{level}']
                epochs = mne.Epochs(
                    raw,
                    **mne_events(sel, raw, start=start, level=level),
                    decim=100,
                    tmin=window[f'{start}_min'],
                    tmax=window[f'{start}_max'],
                    event_repeated='drop',
                    preload=True,
                    baseline=None,
                )
                dict_epochs[f'{level}_{start}'].append(epochs)

    # Once every run has been epoched, concatenate the runs of each condition.
    # All runs are given the same dev_head_t first, taken from the first run so
    # that a single-run call (runs=1) also works.
    for epoch_key, run_epochs in dict_epochs.items():
        reference_dev_head_t = run_epochs[0].info["dev_head_t"]
        for epo in run_epochs:
            epo.info["dev_head_t"] = reference_dev_head_t
        dict_epochs[epoch_key] = mne.concatenate_epochs(run_epochs)

    return dict_epochs

In [4]:
def decoding_from_criterion(criterion, dict_epochs, starts, levels, subject=None):
    """
    Decode `criterion` from the MEG epochs of every (level x start) condition.

    Input:
    - criterion: what is decoded. One of 'emb_sentence', 'emb_constituent'
      (both read the `embeds_{level}` metadata column of the current level),
      'emb_word' (spaCy vector of each word) or 'wlength' (`wlength` metadata column).
    - dict_epochs: the dictionary containing the epochs for each condition
      (keys are '{level}_{start}')
    - starts: (onset, offset)
    - levels: (word, sentence, constituent)
    - subject: label stored in the output rows. When omitted, the module-level
      `subject` variable is used if it exists (else None).

    Returns:
    Two dataframes:
    - all_scores: one row per subject / start / level / time sample, with the decoding score
    - all_evos: one row per subject / start / level, with the median evoked response

    Raises:
    ValueError if `criterion` is not one of the supported values.

    Note: the MEG channel selection is applied in place, so the epochs held in
    `dict_epochs` keep only their MEG channels afterwards.
    """
    known_criteria = ('emb_sentence', 'emb_constituent', 'emb_word', 'wlength')
    if criterion not in known_criteria:
        # Without this check an unknown criterion would leave `scores` undefined.
        raise ValueError(
            f"Unknown criterion {criterion!r}; expected one of {known_criteria}."
        )

    if subject is None:
        # Backward compatibility: callers used to rely on a global `subject`.
        subject = globals().get('subject')

    # Load the language model once rather than once per condition.
    word_nlp = spacy.load("fr_core_news_sm") if criterion == 'emb_word' else None

    all_evos = []
    all_scores = []
    # All epochs -> Decoding and generate evoked potentials
    for start in starts:
        for level in levels:
            epochs = dict_epochs[f'{level}_{start}']

            # Evoked response, computed on a copy so the channel pick stays local
            evo = epochs.copy().pick_types(meg=True).average(method='median')
            all_evos.append(dict(subject=subject, evo=evo, start=start, level=level))

            epochs = epochs.load_data().pick_types(meg=True, stim=False, misc=False)
            X = epochs.get_data()
            # NOTE(review): decod_xy presumably returns one score per time sample,
            # with an extra trailing axis for multi-dimensional targets (hence the
            # mean over axis 1 for embeddings) - confirm against utils.decod_xy.
            if criterion in ('emb_sentence', 'emb_constituent'):
                embeddings = np.vstack(epochs.metadata[f'embeds_{level}'].values)
                scores = np.mean(decod_xy(X, embeddings), axis=1)
            elif criterion == 'emb_word':
                embeddings = np.array(
                    [word_nlp(word).vector for word in epochs.metadata.word]
                )
                scores = np.mean(decod_xy(X, embeddings), axis=1)
            else:  # 'wlength'
                scores = decod_xy(X, epochs.metadata.wlength)

            times = epochs.times
            all_scores.extend(
                dict(subject=subject, score=score, start=start, level=level, t=times[t])
                for t, score in enumerate(scores)
            )

    all_scores = pd.DataFrame(all_scores)
    all_evos = pd.DataFrame(all_evos)
    return all_scores, all_evos

In [5]:
# --- Configuration -----------------------------------------------------------
modality = "visual"
# spaCy pipeline, read as a global by epoch_add_metadata for the word vectors
nlp = spacy.load("fr_core_news_sm")
path = get_path(modality)
subjects = get_subjects(path)
runs = 9
decoding_criterion = 'wlength'
# Epoch limits (in seconds, as passed to mne.Epochs tmin / tmax) for each
# level, relative to the onset or to the offset event.
epoch_windows = {"word": {"onset_min": -0.3, "onset_max": 1.0, "offset_min": -1.0, "offset_max": 0.3},
                  "constituent": {"offset_min": -2.0, "offset_max": 0.5, "onset_min": -0.5, "onset_max": 2.0},
                  "sentence": {"offset_min": -4.0, "offset_max": 1.0, "onset_min": -1.0, "onset_max": 4.0}}
levels = ('word','constituent','sentence')
starts = ('onset', 'offset')

# Accept a single condition given as a bare string
if isinstance(levels, str):
    levels = [levels]

if isinstance(starts, str):
    starts = [starts]

# Iterate on subjects to epoch and decode; results are kept for every subject
# (instead of being overwritten at each iteration) so they can be averaged later.
scores_per_subject = []
evos_per_subject = []
for subject in subjects[2:5]:
    dict_epochs = epoch_add_metadata(modality, subject, levels, starts, runs)

    subject_scores, subject_evos = decoding_from_criterion(decoding_criterion, dict_epochs, starts, levels)
    scores_per_subject.append(subject_scores)
    evos_per_subject.append(subject_evos)

all_scores = pd.concat(scores_per_subject, ignore_index=True)
all_evos = pd.concat(evos_per_subject, ignore_index=True)

Reading raw files for modality: visual

 Epoching for run 1, subject: 3

Opening raw data file /home/is153802/data/LPP_MEG_visual/sub-3/ses-01/meg/sub-3_ses-01_task-read_run-01_meg.fif...
    Read a total of 13 projection items:
        grad_ssp_upright.fif : PCA-v1 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v2 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v3 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v4 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v5 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v1 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v2 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v3 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v4 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v5 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v6 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v7 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v8 (1 x 306)  idle
    Range : 36000 ... 517999 =     36.000 ...   517.999 secs
Ready.
Reading e

  raw = mne_bids.read_raw_bids(bids_path)
  raw = mne_bids.read_raw_bids(bids_path)
  raw = mne_bids.read_raw_bids(bids_path)


Trigger channel has a non-zero initial value of 8 (consider using initial_event=True to detect this event)
Removing orphaned offset at the beginning of the file.
1466 events found
Event IDs: [ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
1684          rire
1685          mais
1686            où
1687          veux
1688            tu
1689            qu
1691            il
1692         aille
1693             n
1695       importe
1696            où
1697         droit
1698        devant
1699           lui
1700         alors
1701            le
1702         petit
1703        prince
1704      remarqua
1705     gravement
1706            ça
1707            ne
1708          fait
1709          rien
1710             c
1712           est
1713     tellement
1714         petit
1715          chez
1716          moi!
1717            et
1718          avec
1719            un
1720           peu
1721            de
1722    mélancolie
1723          peut
1724          être
1725            il
1726        ajouta
17

KeyError: '[1465, 1466, 1467, 1468, 1469, 1470, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478, 1479, 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487, 1488, 1489, 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498, 1499, 1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543, 1544, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1571, 1572, 1573, 1574, 1575, 1576, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589] not in index'

In [None]:
%debug

> [0;32m/home/is153802/.pyenv/versions/meg-masc/lib/python3.10/site-packages/pandas/core/indexes/base.py[0m(5859)[0;36m_raise_if_missing[0;34m()[0m
[0;32m   5857 [0;31m[0;34m[0m[0m
[0m[0;32m   5858 [0;31m            [0mnot_found[0m [0;34m=[0m [0mlist[0m[0;34m([0m[0mensure_index[0m[0;34m([0m[0mkey[0m[0;34m)[0m[0;34m[[0m[0mmissing_mask[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m][0m[0;34m.[0m[0munique[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 5859 [0;31m            [0;32mraise[0m [0mKeyError[0m[0;34m([0m[0;34mf"{not_found} not in index"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   5860 [0;31m[0;34m[0m[0m
[0m[0;32m   5861 [0;31m    [0;34m@[0m[0moverload[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> u
> [0;32m/home/is153802/.pyenv/versions/meg-masc/lib/python3.10/site-packages/pandas/core/indexes/base.py[0m(5796)[0;36m_get_indexer_strict

ipdb> len(ià
*** SyntaxError: '(' was never closed
ipdb> len(i)
1558
ipdb> len(j)
1558
ipdb> meta_tokens
0       lorsque
1             j
2         avais
3           six
4           ans
         ...   
1585       peut
1586        pas
1587      aller
1588       bien
1589       loin
Name: clean_word, Length: 1590, dtype: object
ipdb> synt_tokens
0       lorsque
1             j
3         avais
4           six
5           ans
         ...   
1732       peut
1733        pas
1734      aller
1735       bien
1736       loin
Name: word, Length: 1614, dtype: object
ipdb> key
'n_closing'
ipdb> i
array([   0,    1,    2, ..., 1587, 1588, 1589])
ipdb> meta
      Unnamed: 0      word  onset  duration  \
0              0   Lorsque    0.7      0.25   
1              1   j'avais    1.0      0.25   
2              2       six    1.3      0.25   
3              3      ans,    1.6      0.25   
4              4      j'ai    1.9      0.25   
...          ...       ...    ...       ...   
1460        1460    

In [None]:
# Draw the decoding score over time, one subplot per (level, start) condition.
plot_scores(all_scores)

In [None]:
# Same plot as the previous cell, re-run on the current `all_scores`.
plot_scores(all_scores)

In [None]:
# List the metadata columns attached to the word-onset epochs.
# `keys` must be called: the bare attribute only shows the bound method.
dict_epochs['word_onset'].metadata.keys()