In [1]:
from dataset import get_path, get_subjects, epoch_data
from utils import decod, correlate, match_list
import mne_bids
from pathlib import Path
import pandas as pd
import numpy as np
import mne
import spacy

nlp = spacy.load('fr_core_news_sm')

from sklearn.model_selection import KFold, cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.linear_model import RidgeCV
from wordfreq import zipf_frequency
from Levenshtein import editops
import matplotlib.pyplot as plt
import matplotlib

matplotlib.use("Agg")
mne.set_log_level(False)

In [2]:
print(get_path())

/home/co/data/BIDS_lecture


In [None]:
# TODO : integrate this inside epoch_data in dataset.py directly

# what's currently missing: metadata information adding, basically adding closing column.
# For that, the parse function is needed
def load_events(self):

    raw = self.raw()
    path = StudyPaths().download
    bids_path = BIDSPath(
        subject=self.subject_uid[4:],
        session="01",
        run=self.run,
        task="listen",
        root=path,
        datatype="meg",
    )

    # extract annotations
    event_file = path / f"sub-{bids_path.subject}"
    event_file = event_file / f"ses-{bids_path.session}"
    event_file = event_file / "meg"
    event_file = str(event_file / f"sub-{bids_path.subject}")
    event_file += f"_ses-{bids_path.session}"
    event_file += f"_task-{bids_path.task}"
    event_file += f"_run-{bids_path.run}_events.tsv"
    assert Path(event_file).exists()
    meta = pd.read_csv(event_file, sep="\t")
    meta['word'] = [eval(w)['word'] for w in meta.trial_type]
    meta['kind'] = [eval(w)['kind'] for w in meta.trial_type]
    
    events = self.raw_events

    # word events
    # match events and metadata
    word_events = events[events[:, 2] > 1]
    meg_delta = np.diff(word_events[:, 0].astype(float) / raw.info["sfreq"])
    meta_delta = np.diff(meta.onset.values)

    pres = 1e2
    i, j = utils.match_list(np.round(meg_delta*pres), np.round(meta_delta*pres))
    assert len(i) / len(meg_delta) > .95
    assert len(i) > 500
    meta = meta.iloc[j].reset_index(drop=True)
    meta["start"] = word_events[i, 0] / self.raw().info["sfreq"]

    # Sound events
    CHAPTERS = {
        1: "1-3",
        2: "4-6",
        3: "7-9",
        4: "10-12",
        5: "13-14",
        6: "15-19",
        7: "20-22",
        8: "23-25",
        9: "26-27",
    }

    # Event start and end:
    idx = [np.where(events[:, 2] == 1)[0][0]]
    sound_start = events[idx, 0] / self.raw().info["sfreq"]
    assert len(sound_start) == 1
    sound_start = sound_start[0]
    chapter = CHAPTERS[self.run_uid]

    sound_event = []
    sound_path = path / "stimuli" / f"ch{chapter}.wav"
    assert sound_path.exists()
    sound_event.append(
        dict(kind="sound", filepath=sound_path, start=sound_start)
    )

    meta = pd.concat([meta, pd.DataFrame(sound_event)], ignore_index=True)
    meta["condition"] = "sentence"
    meta = meta.sort_values('start').reset_index(drop=True)
    
    # add parsing data
    meta = enrich(meta, path/'stimuli'/f'ch{chapter}.txt')
    
    return meta[['start', 'duration', 'kind', 'word', 'filepath', 'condition', 'sequence_id', 'sequence_uid', 'word_index', 'closing_', 'match_token']]

In [6]:
run = 1

def parse(sent):
    'identifies the number of closing nodes'

    def is_closed(node, position):
        """JR quick code to know whether is a word is closed given a word position"""
        if node.i > position:
            return False
        for child in node.children:
            if child.i > position:
                return False
            if not is_closed(child, position):
                return False
        return True

    closeds = []
    for current in range(1, len(sent)+1):
        closed = 0
        for position, word in enumerate(sent): # [:current]
            closed += is_closed(word, current)
        closeds.append(closed)

    closing = np.r_[np.diff(closeds), closeds[-1]]
    return closing

def format_meta(meta,run_id):
    model = 'fr_core_news_sm'
    if not spacy.util.is_package(model):
        spacy.cli.download(model)

    nlp = spacy.load(model)

    CHAPTERS = {
    1: "1-3",
    2: "4-6",
    3: "7-9",
    4: "10-12",
    5: "13-14",
    6: "15-19",
    7: "20-22",
    8: "23-25",
    9: "26-27",
    }
    txt_file = f'./../../data/syntax/ch{CHAPTERS[run]}.syntax.txt'
    with open(txt_file, 'r') as f:
        txt = f.read().replace('\n', '')

    # parse text file
    doc = nlp(txt)

    # add parse information to metadata
    parse_annots = []
    for sent_id, sent in enumerate(doc.sents):
        # HERE ADD ERIC DE LA CLERGERIE parser instead
        closings = parse(sent)
        assert len(closings) == len(sent)
        for word, closing in zip(sent, closings):
            parse_annots.append(dict(
                word_index=word.i - sent[0].i,
                sequence_id=sent_id,
                sequence_uid=str(sent),
                closing=closing,
                match_token=word.text,
            ))

    # align text file and meg metadata
    def format_text(text):
        for char in ('jlsmtncd'):
            text = text.replace(f"{char}'", char)
        text = text.replace('œ', 'oe')
        return text.lower()

    meg_words = meta.word.fillna('######').values
    text_words = [format_text(w.text) for w in doc]

    i, j = match_list(meg_words, text_words)

    # deal with missed tokens (e.g. wrong spelling, punctuation)
    assert len(parse_annots) == len(text_words)
    parse_annots = pd.DataFrame(parse_annots)
    parse_annots.closing = parse_annots.closing.fillna(0)
    parse_annots['closing_'] = 0
    parse_annots['missed_closing'] = 0
    missing = np.setdiff1d(range(len(parse_annots)), j)
    for missed in missing:
        current_closing = parse_annots.iloc[missed].closing
        prev_word = parse_annots.iloc[[missed-1]].index
        if prev_word[0] >=0:
            parse_annots.loc[prev_word, 'missed_closing'] = current_closing
    parse_annots.closing_ = parse_annots.closing + parse_annots.missed_closing

    # Add new columns to original mne.Epochs.metadata
    # fill columns
    columns = ('word_index', 'sequence_id', 'sequence_uid', 'closing_', 'match_token')
    for column in columns:
        meta[column] = None
        meta.loc[meta.iloc[i].index, column] = parse_annots[column].iloc[j].values
    return meta

In [7]:
report = mne.Report()
path = get_path('LPP_read')
subjects = get_subjects(path)
RUN = 1

print("\nSubjects for which the decoding will be tested: \n")
print(subjects)

for subject in subjects:  # Ignore the first one

    print(f"Subject {subject}'s decoding started")
    epochs = []
    for run_id in range(1, RUN + 1):
        print(".", end="")
        epo = epoch_data(subject, "%.2i" % run_id, task='listen', path=path)
        epo.metadata["label"] = f"run_{run_id}"
        epochs.append(epo)

    # Quick fix for the dev_head: has to be
    # fixed before doing source reconstruction
    for epo in epochs:
        epo.info["dev_head_t"] = epochs[0].info["dev_head_t"]
        # epo.info['nchan'] = epochs[0].info['nchan']

    epochs = mne.concatenate_epochs(epochs)

    # Get the evoked potential averaged on all epochs for each channel
    evo = epochs.average(method="median")
    evo.plot(spatial_colors=True)

    # Handling the data structure
    epochs.metadata["kind"] = epochs.metadata.trial_type.apply(
        lambda s: eval(s)["kind"]
    )
    epochs.metadata["word"] = epochs.metadata.trial_type.apply(
        lambda s: eval(s)["word"]
    )
    # TODO : re-epoch
    print(format_meta(epochs.metadata, run_id))
    epochs.metadata
    epochs.metadata['closing'] = epochs.metadata.closing_.fillna(0)
    # Run a linear regression between MEG signals
    # and word frequency classification
    X = epochs.get_data()

    embeddings = epochs.metadata.word.apply(lambda word: nlp(word).vector).values
    embeddings = np.array([emb for emb in embeddings])

    y = embeddings

    R_vec = decod(X, y)
    R_vec_avg = np.mean(R_vec,axis = 1)

    fig, ax = plt.subplots(1, figsize=[6, 6])
    dec = plt.fill_between(epochs.times, R_vec_avg)
    # plt.show()
    report.add_evokeds(evo, titles=f"Evoked for sub {subject} ")
    report.add_figure(fig, title=f"decoding for subject {subject}")
    # report.add_figure(dec, subject, tags="word")
    report.save("./figures/reading_decoding_embeddings.html", open_browser=False, overwrite=True)

    print("Finished!")



Subjects for which the decoding will be tested: 

['1', '2', '3', '4', '5', '6']
Subject 1's decoding started
.Running the script on RAW data:
run 01, subject: 1


  raw = mne_bids.read_raw_bids(bids_path)
  raw = mne_bids.read_raw_bids(bids_path)
  raw = mne_bids.read_raw_bids(bids_path)
  epochs = mne.Epochs(
  epochs = mne.concatenate_epochs(epochs)


      index  Unnamed: 0     word   onset  duration  \
0         0           0  Lorsque    0.70       0.3   
1         1           1   javais    1.05       0.3   
2         2           2      six    1.40       0.3   
3         3           3      ans    1.75       0.3   
4         4           4      jai    2.10       0.3   
...     ...         ...      ...     ...       ...   
1406   1459        1459       ne  511.35       0.3   
1407   1460        1460     peut  511.70       0.3   
1408   1461        1461      pas  512.05       0.3   
1409   1462        1462    aller  512.40       0.3   
1410   1463        1463     bien  512.75       0.3   

                               trial_type    start  label  kind word_index  \
0     {'kind': 'word', 'word': 'Lorsque'}   44.534  run_1  word       None   
1      {'kind': 'word', 'word': 'javais'}   44.889  run_1  word       None   
2         {'kind': 'word', 'word': 'six'}   45.189  run_1  word         16   
3         {'kind': 'word', 'word': 'ans

AttributeError: 'BaseEpochs' object has no attribute 'metadat'

In [9]:
epochs.metadata

Unnamed: 0.1,index,Unnamed: 0,word,onset,duration,trial_type,start,label,kind,word_index,sequence_id,sequence_uid,closing_,match_token
0,0,0,Lorsque,0.70,0.3,"{'kind': 'word', 'word': 'Lorsque'}",44.534,run_1,word,,,,,
1,1,1,javais,1.05,0.3,"{'kind': 'word', 'word': 'javais'}",44.889,run_1,word,,,,,
2,2,2,six,1.40,0.3,"{'kind': 'word', 'word': 'six'}",45.189,run_1,word,16,2,Sint (VN (CLS-SUJ 1=j'avais) (DET 2=six) (NC 3...,3,six
3,3,3,ans,1.75,0.3,"{'kind': 'word', 'word': 'ans'}",45.489,run_1,word,22,2,Sint (VN (CLS-SUJ 1=j'avais) (DET 2=six) (NC 3...,2,ans
4,4,4,jai,2.10,0.3,"{'kind': 'word', 'word': 'jai'}",45.806,run_1,word,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1406,1459,1459,ne,511.35,0.3,"{'kind': 'word', 'word': 'ne'}",528.061,run_1,word,111,164,Sint (PP-MOD (PP-MOD (P 1=avec) (NP (ADV+ (DET...,0,ne
1407,1460,1460,peut,511.70,0.3,"{'kind': 'word', 'word': 'peut'}",528.378,run_1,word,117,164,Sint (PP-MOD (PP-MOD (P 1=avec) (NP (ADV+ (DET...,0,peut
1408,1461,1461,pas,512.05,0.3,"{'kind': 'word', 'word': 'pas'}",528.694,run_1,word,124,164,Sint (PP-MOD (PP-MOD (P 1=avec) (NP (ADV+ (DET...,0,pas
1409,1462,1462,aller,512.40,0.3,"{'kind': 'word', 'word': 'aller'}",529.011,run_1,word,136,164,Sint (PP-MOD (PP-MOD (P 1=avec) (NP (ADV+ (DET...,0,aller
