In [1]:
import mne
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.linear_model import RidgeCV
from wordfreq import zipf_frequency
from Levenshtein import editops
import torch
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


## Initial textual data

The textual data that we'll use to compare GPT2 and MEG activations is from Le Petit Prince.
Let's build an array with each words.

In [2]:
file = '~/Downloads/lpp.csv'
df = pd.read_csv(file)
list_words = (df.iloc[:,1])
list_words_run1 = list_words[:1610]

## MEG Activations

In [5]:
# Read raw file
raw_file = '/home/co/workspace_LPP/data/MEG/LPP/LPP_bids/sub-220628/ses-01/meg/sub-220628_ses-01_task-rest_run-01_meg.fif'
raw = mne.io.read_raw_fif(raw_file,allow_maxshield = True)

# Load data, filter
raw.pick_types(meg=True, stim=True)
raw.load_data()
raw = raw.filter(.5, 20)

# Load events and realign them
event_file = '/home/co/workspace_LPP/data/MEG/LPP/LPP_bids/sub-220628/ses-01/meg/sub-220628_ses-01_task-rest_run-01_events.tsv'
meta = pd.read_csv(event_file, sep='\t')
events = mne.find_events(raw, stim_channel='STI101',
                            shortest_event=1, min_duration=0.0010001)

# match events and metadata
word_events = events[events[:, 2] > 1]
meg_delta = np.round(np.diff(word_events[:, 0]/raw.info['sfreq']))
meta_delta = np.round(np.diff(meta.onset.values))

i, j = match_list(meg_delta, meta_delta)
events = word_events[i]
meta = meta.iloc[j].reset_index()

# Epoch on the aligned events and load the epoch data
epochs = mne.Epochs(raw, events, metadata=meta,
                    tmin=-.3, tmax=.8, decim=10, baseline=(-0.2, 0.0))

data = epochs.get_data()
epochs.load_data()

# Scale the data
n_words, n_chans, n_times = data.shape
vec = data.transpose(0, 2, 1).reshape(-1, n_chans)
scaler = RobustScaler()
idx = np.arange(len(vec))
np.random.shuffle(idx)
vec = scaler.fit(vec[idx[:20_000]]).transform(vec)
sigma = 7
vec = np.clip(vec, -sigma, sigma)
epochs._data[:, :, :] = scaler.inverse_transform(vec)\
    .reshape(n_words, n_times, n_chans).transpose(0, 2, 1)

Opening raw data file /home/co/workspace_LPP/data/MEG/LPP/LPP_bids/sub-220628/ses-01/meg/sub-220628_ses-01_task-rest_run-01_meg.fif...
    Read a total of 13 projection items:
        grad_ssp_upright.fif : PCA-v1 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v2 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v3 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v4 (1 x 306)  idle
        grad_ssp_upright.fif : PCA-v5 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v1 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v2 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v3 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v4 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v5 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v6 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v7 (1 x 306)  idle
        mag_ssp_upright.fif : PCA-v8 (1 x 306)  idle
    Range : 23000 ... 649999 =     23.000 ...   649.999 secs
Ready.
Reading 0 ... 626999  =      0.000 ...   626.999 secs...


  raw = mne.io.read_raw_fif(raw_file,allow_maxshield = True)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 20 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 20.00 Hz
- Upper transition bandwidth: 5.00 Hz (-6 dB cutoff frequency: 22.50 Hz)
- Filter length: 6601 samples (6.601 sec)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done 306 out of 306 | elapsed:    5.0s finished


1616 events found
Event IDs: [  1 128]
Adding metadata with 4 columns
1610 matching events found
Applying baseline correction (mode: mean)
Created an SSP operator (subspace dimension = 13)
13 projection items activated
Using data from preloaded Raw for 1610 events and 1101 original time points (prior to decimation) ...
0 bad epochs dropped
Using data from preloaded Raw for 1610 events and 1101 original time points (prior to decimation) ...


In [6]:
epochs.get_data().shape

(1610, 325, 111)

## GPT 2 Activations

In [7]:
from transformers import GPT2Config, GPT2Model, GPT2Tokenizer

# Initiate the model and its trained weights
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")

In [8]:
embeddings_array = [] # The list containing the different embeddings
# Accessible under the parameters: embeddings_array[word_n, layer_n, 0, 0]

# For each word, tokenize it and get the embeddings from it
for word in list_words_run1:
    inputs = tokenizer(word,return_tensors='pt')

    outputs = model(**inputs,output_hidden_states=True)

    embeddings_array.append(outputs.hidden_states)

In [9]:
len(embeddings_array)
(embeddings_array[0][12][0])

tensor([[-0.0967, -0.0303, -0.3699,  ..., -0.1717,  0.0229, -0.0405],
        [-0.2289, -0.3152, -0.3306,  ..., -0.1642,  0.0195,  0.2693],
        [-0.1142, -0.4238, -0.2373,  ...,  0.0927,  0.1045,  0.1797]],
       grad_fn=<SelectBackward0>)

In [11]:
embeddings_array[0][0].shape

torch.Size([1, 3, 768])

## Correlation

In [21]:
final_array = []
for tuple_ in embeddings_array:

    tensor = torch.cat(tuple_)
    print(tensor)
    tensor = tensor[:,0,:]
    tensor.reshape(13,1,768)
    tensor = tensor.detach().numpy()
    final_array.append(tensor)

final_array = np.reshape(final_array,(1610,-1))

tensor([[[ 1.3942e-01, -3.0973e-01,  1.0083e-01,  ..., -1.5958e-01,
           8.6239e-02, -3.7619e-02],
         [-1.3111e-01, -1.2052e-01,  3.8821e-02,  ..., -7.4348e-02,
           7.4390e-02,  2.9336e-02],
         [ 1.9243e-02, -2.1743e-01,  1.7124e-01,  ...,  4.6557e-02,
           2.8480e-02,  2.9179e-01]],

        [[ 7.9171e-01, -1.4054e-01, -1.6875e-01,  ..., -9.8208e-01,
           5.1664e-01,  1.7883e-01],
         [-8.4687e-01, -1.1502e+00, -2.0205e-01,  ...,  1.4145e-02,
           3.7234e-01,  9.6948e-02],
         [ 6.3530e-01, -3.0820e-01, -6.5044e-01,  ..., -7.3857e-01,
           1.7940e-01,  7.1559e-01]],

        [[ 9.4581e-04, -9.1928e-01, -4.7588e-01,  ..., -9.4059e-01,
           1.0360e+00,  2.8523e-01],
         [-1.0128e+00, -1.1867e+00, -3.5754e-01,  ..., -7.4775e-02,
           6.2627e-01,  4.6940e-01],
         [ 4.5485e-01,  5.4101e-01, -6.5192e-01,  ..., -7.6885e-01,
           6.6697e-01,  6.9065e-01]],

        ...,

        [[-2.0373e-01, -1.1441e+00,

In [27]:
X = final_array
X.shape

(1610, 9984)

In [25]:
y = epochs.get_data() 
y.shape

(1610, 325, 111)

In [35]:
R = decod(X, y)

...............................................................................................................

In [36]:

fig, ax = plt.subplots(1, figsize=[6, 6])
dec = plt.fill_between(epochs.times, R)

NameError: name 'plt' is not defined

# Utils

In [34]:
# Utils
def match_list(A, B, on_replace="delete"):
    """Match two lists of different sizes and return corresponding indice
    Parameters
    ----------
    A: list | array, shape (n,)
        The values of the first list
    B: list | array: shape (m, )
        The values of the second list
    Returns
    -------
    A_idx : array
        The indices of the A list that match those of the B
    B_idx : array
        The indices of the B list that match those of the A
    """

    if not isinstance(A, str):
        unique = np.unique(np.r_[A, B])
        label_encoder = dict((k, v) for v, k in enumerate(unique))

        def int_to_unicode(array: np.ndarray) -> str:
            return "".join([str(chr(label_encoder[ii])) for ii in array])

        A = int_to_unicode(A)
        B = int_to_unicode(B)

    changes = editops(A, B)
    B_sel = np.arange(len(B)).astype(float)
    A_sel = np.arange(len(A)).astype(float)
    for type_, val_a, val_b in changes:
        if type_ == "insert":
            B_sel[val_b] = np.nan
        elif type_ == "delete":
            A_sel[val_a] = np.nan
        elif on_replace == "delete":
            # print('delete replace')
            A_sel[val_a] = np.nan
            B_sel[val_b] = np.nan
        elif on_replace == "keep":
            # print('keep replace')
            pass
        else:
            raise NotImplementedError
    B_sel = B_sel[np.where(~np.isnan(B_sel))]
    A_sel = A_sel[np.where(~np.isnan(A_sel))]
    assert len(B_sel) == len(A_sel)
    return A_sel.astype(int), B_sel.astype(int)


def decod(X, y):
    assert len(X) == len(y)
    # define data
    model = make_pipeline(StandardScaler(), RidgeCV(alphas=np.logspace(-3, 8, 10)))
    cv = KFold(5, shuffle=True, random_state=0)

    # fit predict
    n, n_chans, n_times = y.shape
    R = np.zeros(n_times)
    for t in range(n_times):
        print(".", end="")
        y_pred = cross_val_predict(model, X, y[:,:,t], cv=cv)
        R[t] = correlate(y[:,:,t], y_pred)
    return R

# Function to correlate
def correlate(X, Y):
    correlation = np.corrcoef(X,Y)
    return np.mean(correlation)
    # if X.ndim == 1:
    #     X = np.array(X)[:, None]
    # if Y.ndim == 1:
    #     Y = np.array(Y)[:, None]
    # X = X - X.mean(0)
    # Y = Y - Y.mean(0)

    # SX2 = (X**2).sum(0) ** 0.5
    # SY2 = (Y**2).sum(0) ** 0.5
    # SXY = (X * Y).sum(0)
    # return SXY / (SX2 * SY2)
