In [4]:
import mne
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.linear_model import RidgeCV
from wordfreq import zipf_frequency
from Levenshtein import editops

## Initial textual data

The text we use to compare GPT-2 and MEG activations is *Le Petit Prince*.
Let's build an array containing each word.

In [5]:
# Le Petit Prince text: the second column (positional index 1) is used as the word list.
file = '~/Downloads/lpp.csv'
df = pd.read_csv(file)
list_words = df.iloc[:, 1]

# Words of run 1: the first 1610 entries (same count as the 1610 epochs found for run 1).
list_words_run1 = list_words.iloc[:1610]

## MEG Activations

In [8]:
# Recording + annotation files for sub-220628 / ses-01 / run-01.
# NOTE(review): absolute, machine-specific path — consider a configurable root.
bids_meg_dir = '/home/co/workspace_LPP/data/MEG/LPP/LPP_bids/sub-220628/ses-01/meg'
raw_file = f'{bids_meg_dir}/sub-220628_ses-01_task-rest_run-01_meg.fif'
event_file = f'{bids_meg_dir}/sub-220628_ses-01_task-rest_run-01_events.tsv'

# Read raw file (allow_maxshield lets MNE open recordings made with MaxShield on).
raw = mne.io.read_raw_fif(raw_file, allow_maxshield=True)

# Keep MEG + stimulus channels, load into memory and band-pass filter 0.5-20 Hz.
raw.pick_types(meg=True, stim=True)
raw.load_data()
raw = raw.filter(.5, 20)

# Load the annotation table and the hardware triggers from STI101.
meta = pd.read_csv(event_file, sep='\t')
events = mne.find_events(raw, stim_channel='STI101',
                         shortest_event=1, min_duration=0.0010001)

# Match triggers and metadata through their inter-onset intervals.
# Only triggers with value > 1 are kept; MEG sample indices are converted to
# seconds and both interval series are rounded to integers before matching.
word_events = events[events[:, 2] > 1]
meg_delta = np.round(np.diff(word_events[:, 0] / raw.info['sfreq']))
meta_delta = np.round(np.diff(meta.onset.values))

# NOTE: `match_list` is defined in the "Utils" section at the end of the
# notebook; that cell must run first (move it up for Restart & Run All).
# Matching is done on intervals, so index k selects the event that opens interval k.
i, j = match_list(meg_delta, meta_delta)
events = word_events[i]
meta = meta.iloc[j].reset_index()

# Epoch on the aligned events (-0.3 to 0.8 s, decimated by 10, baseline -0.2 to 0 s).
epochs = mne.Epochs(raw, events, metadata=meta,
                    tmin=-.3, tmax=.8, decim=10, baseline=(-0.2, 0.0))

# Load once, then copy the data out. Calling get_data() before load_data()
# made MNE extract every epoch from the raw recording twice.
epochs.load_data()
data = epochs.get_data()

# Robust-scale each channel, clip outliers at +/- sigma scaled units, then map
# back to the original units and write the result into the loaded epochs.
n_words, n_chans, n_times = data.shape
vec = data.transpose(0, 2, 1).reshape(-1, n_chans)  # (n_words * n_times, n_chans)
scaler = RobustScaler()
# Fit the scaler on a random subset of 20k time samples; the generator is
# seeded so the fitted scaler (and therefore the clipping) is reproducible.
rng = np.random.default_rng(0)
idx = rng.permutation(len(vec))
vec = scaler.fit(vec[idx[:20_000]]).transform(vec)
sigma = 7
vec = np.clip(vec, -sigma, sigma)
# NOTE(review): writes through the private `_data` attribute of Epochs.
epochs._data[:, :, :] = scaler.inverse_transform(vec)\
    .reshape(n_words, n_times, n_chans).transpose(0, 2, 1)

Opening raw data file /home/co/workspace_LPP/data/MEG/LPP/LPP_bids/sub-220628/ses-01/meg/sub-220628_ses-01_task-rest_run-01_meg.fif...
    Read a total of 13 projection items:
        grad_ssp_upright.fif : PCA-v1 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v2 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v3 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v4 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v5 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v1 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v2 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v3 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v4 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v5 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v6 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v7 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v8 (1 x 306)  idle
    Range : 23000 ... 649999 =     23.000 ...   649.999 secs
Ready.
Reading 0 ... 626999  =      0.000 ...   626.999 secs...


  raw = mne.io.read_raw_fif(raw_file,allow_maxshield = True)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 20 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 20.00 Hz
- Upper transition bandwidth: 5.00 Hz (-6 dB cutoff frequency: 22.50 Hz)
- Filter length: 6601 samples (6.601 sec)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.1s remaining:    0.0s


1616 events found
Event IDs: [  1 128]


[Parallel(n_jobs=1)]: Done 306 out of 306 | elapsed:    4.7s finished


Adding metadata with 4 columns
1610 matching events found
Applying baseline correction (mode: mean)
Created an SSP operator (subspace dimension = 13)
13 projection items activated
Using data from preloaded Raw for 1610 events and 1101 original time points (prior to decimation) ...
0 bad epochs dropped
Using data from preloaded Raw for 1610 events and 1101 original time points (prior to decimation) ...


In [12]:
# Show the Epochs summary (number of events, time range, baseline).
epochs

0,1
Number of events,1610
Events,128: 1610
Time range,-0.300 – 0.800 sec
Baseline,-0.200 – 0.000 sec


## GPT-2 Activations

In [17]:
import torch
from transformers import GPT2Config, GPT2Model, GPT2Tokenizer


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

model = GPT2Model.from_pretrained("gpt2")
model.eval()  # inference only: disable dropout

# One GPT-2 vector per word of run 1: run the model on every word and average
# the final-layer hidden states over the word's sub-word tokens.
# NOTE(review): each word is encoded in isolation (no preceding context and no
# leading space), so these are not in-context activations — encode the running
# text instead if contextual representations are the goal.
word_embeddings = []
with torch.no_grad():  # no backpropagation needed; avoids building autograd graphs
    for word in list_words_run1:
        inputs = tokenizer(word, return_tensors='pt')
        outputs = model(**inputs)
        last_hidden_states = outputs.last_hidden_state  # (1, n_tokens, hidden_size)
        word_embeddings.append(last_hidden_states[0].mean(dim=0).numpy())

# (n_words, hidden_size); row k corresponds to the k-th word of list_words_run1.
# `inputs`, `outputs` and `last_hidden_states` still refer to the last word.
word_embeddings = np.stack(word_embeddings)

In [18]:
# Inspect the GPT-2 output object from the previous cell
# (last_hidden_state plus the cached past_key_values).
outputs

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-0.1565, -0.0245, -0.2290,  ..., -0.1638,  0.0065, -0.0811],
         [-0.4882,  0.0754,  0.1907,  ..., -0.1567,  0.2462,  0.2180]]],
       grad_fn=<ViewBackward0>), past_key_values=((tensor([[[[-1.0082,  2.2626,  0.1203,  ..., -1.8167, -0.6454,  1.3689],
          [-2.4185,  2.7140,  2.0556,  ..., -0.3029, -0.6959,  2.1567]],

         [[-0.2322,  0.0790, -0.6130,  ...,  0.0453,  1.6757,  1.3713],
          [-0.7260, -0.6771, -2.7542,  ..., -0.0202,  3.4651,  0.7425]],

         [[ 0.6486,  0.0628,  0.8847,  ..., -1.5250, -1.6209,  0.3861],
          [ 0.2770, -0.3390,  0.9356,  ..., -2.3814, -0.2867,  1.8097]],

         ...,

         [[ 0.1187, -0.4680, -0.2099,  ...,  0.4519,  0.6347,  0.6258],
          [ 0.1671, -0.1267, -0.0718,  ...,  0.5801,  0.7477,  0.6586]],

         [[ 0.7430,  0.9419, -0.7024,  ..., -0.5742,  0.7436, -1.1344],
          [ 1.0718,  0.6798, -0.4590,  ..., -0.6027,  1.0389, -0.4415]],



# Utils

In [7]:
# Utils
def match_list(A, B, on_replace="delete"):
    """Match two lists of different sizes and return corresponding indices.

    The two sequences are aligned with a Levenshtein edit script
    (``editops``). Elements that the script inserts or deletes are dropped;
    the remaining elements are paired in order.

    Parameters
    ----------
    A : list | array, shape (n,) | str
        The values of the first list.
    B : list | array, shape (m,) | str
        The values of the second list.
    on_replace : {"delete", "keep"}
        How to treat elements the edit script substitutes: "delete" drops
        them from both lists, "keep" leaves them paired.

    Returns
    -------
    A_idx : array of int
        The indices of the A list that match those of the B.
    B_idx : array of int
        The indices of the B list that match those of the A.

    Raises
    ------
    NotImplementedError
        If ``on_replace`` is neither "delete" nor "keep".
    """
    # Validate up front: otherwise an invalid value is only caught when the
    # edit script happens to contain a substitution.
    if on_replace not in ("delete", "keep"):
        raise NotImplementedError

    # editops works on strings: encode each distinct value as one character.
    if not isinstance(A, str):
        unique = np.unique(np.r_[A, B])
        label_encoder = {k: v for v, k in enumerate(unique)}

        def int_to_unicode(array: np.ndarray) -> str:
            return "".join(chr(label_encoder[ii]) for ii in array)

        A = int_to_unicode(A)
        B = int_to_unicode(B)

    # Start with every position kept, then drop those the edit script removes.
    keep_A = np.ones(len(A), dtype=bool)
    keep_B = np.ones(len(B), dtype=bool)
    for type_, val_a, val_b in editops(A, B):
        if type_ == "insert":
            keep_B[val_b] = False
        elif type_ == "delete":
            keep_A[val_a] = False
        elif on_replace == "delete":
            # Substitution: the two elements do not really match; drop both.
            keep_A[val_a] = False
            keep_B[val_b] = False
        # on_replace == "keep": substituted elements stay paired.

    A_sel = np.flatnonzero(keep_A).astype(int)
    B_sel = np.flatnonzero(keep_B).astype(int)
    assert len(B_sel) == len(A_sel)
    return A_sel, B_sel

