In [None]:
import csv
import os
import pickle
import re
from glob import glob

import h5py
import mne
import numpy as np
import pandas as pd
import tqdm

import sys
sys.path.append('../preprocessing/utils/')
import strf
from utils import make_delayed, zs

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
from matplotlib.colors import Normalize
from matplotlib import cm
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap
# matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and 3.8 removed the
# old name, so use the new name when it exists and fall back to the old one otherwise.
if 'seaborn-v0_8' in plt.style.available:
    plt.style.use('seaborn-v0_8')
else:
    plt.style.use('seaborn')
rc['pdf.fonttype'] = 42  # embed TrueType (Type 42) fonts so text in saved PDFs stays editable
import seaborn as sns

In [None]:
# Change these paths for running the notebook locally, or set the environment variables
# EEG_DATA_PATH, SIS_GIT_PATH and SIS_H5_PATH to override them without editing the notebook.
# The paths are concatenated directly with file names below (e.g. f"{h5_path}{s}_weights.hdf5"),
# so each one is normalized to end with a path separator.
eeg_data_path = os.path.join(os.environ.get('EEG_DATA_PATH', '/path/to/dataset/'), '') # downloadable from OSF: https://doi.org/10.17605/OSF.IO/FNRD9
git_path = os.path.join(os.environ.get('SIS_GIT_PATH', '/path/to/git/speaker_induced_suppression_EEG/'), '')
# Where the output of train_linear_model.ipynb is saved. Run that first if you haven't already.
h5_path = os.path.join(os.environ.get('SIS_H5_PATH', '/path/to/h5/'), '')

In [None]:
perception_color = '#117733'
production_color = '#332288'
all_color = 'gray'
picks = ['F1','Fz','F2','FC1','FCz','FC2','C1','Cz','C2']
# STRF delay window (s). NOTE: tmin/tmax are re-assigned to the epoch window in a later cell.
tmin,tmax = -.3, .5
strf_fs = 128  # sampling rate (Hz) in which the STRF delays are expressed
delays = np.arange(np.floor(tmin*strf_fs),np.ceil(tmax*strf_fs),dtype=int)
# Nominal delay-window edges (s), used to label the x axis of the weight plots
delay_min, delay_max = tmin, tmax
exclude = ['OP0001','OP0002','OP0004','OP0017','OP0020']
# Subject IDs are the names of the OP0* sub-directories of eventfiles/ (matched on the
# directory name itself, not anywhere in the full path).
subjs = np.sort([os.path.basename(p) for p in glob(f'{git_path}eventfiles/*')
                 if os.path.isdir(p)
                 and os.path.basename(p).startswith('OP0')
                 and os.path.basename(p) not in exclude])
models = ['model1','model1e','model2','model2e','model3','model3e','model4','model4e']
single_model = 'model2'
features = {model_number:strf.get_feats(model_number=model_number,extend_labels=True) for model_number in models}
n_feats = {model_number:len(features[model_number]) for model_number in models}

## Load data

In [None]:
# Load data from hdf5, pandas
# wts[model][subject]      : STRF weights saved by train_linear_model.ipynb; indexed below as
#                            (delays, features, channels) with channels on the last axis.
# corrs/pvals/alphas[m][s] : per-channel r value, p value and best ridge alpha from lem_results.csv,
#                            ordered like the channels of the subject's B1 recording.
# sig_wts/sig_corrs[m][s]  : copies of the above with channels where p >= sig_p_threshold zeroed.
sig_p_threshold = 0.01
wts, corrs, pvals, sig_wts, sig_corrs, alphas = dict(), dict(), dict(), dict(), dict(), dict()
results_csv_fpath = f"{git_path}stats/lem_results.csv"
df = pd.read_csv(results_csv_fpath)

# Channel names depend only on the subject, so read each header once rather than once per model.
# Stored under its own name so it does not clobber the `ch_names` dict built in a later cell.
subj_ch_names = dict()
for s in subjs:
    blockid = f"{s}_B1"
    subj_ch_names[s] = mne.io.read_raw_brainvision(f"{eeg_data_path}{s}/{blockid}/{blockid}_cca.vhdr",
                                                   preload=False,verbose=False).info['ch_names']

for m in models:
    wts[m], corrs[m], pvals[m], sig_wts[m], sig_corrs[m], alphas[m] = dict(), dict(), dict(), dict(), dict(), dict()
    b = tqdm.tqdm(subjs)  # `tqdm` is imported as a module, so the progress bar is tqdm.tqdm
    for s in b:
        b.set_description(f'Loading STRF for {s} {m}')
        with h5py.File(f"{h5_path}{s}_weights.hdf5",'r') as f:
            # f[m] raises KeyError for a missing model instead of silently giving np.array(None)
            wts[m][s] = f[m][()]
        # One results row per channel for this subject/model, reordered to the recording's channels
        rows = df[(df['subject']==s) & (df['model']==m)].set_index('channel')
        if not rows.index.is_unique:
            raise ValueError(f"Duplicate channel rows for {s} {m} in {results_csv_fpath}")
        rows = rows.loc[subj_ch_names[s]]  # raises KeyError if a channel is missing from the table
        corrs[m][s] = rows['r_value'].to_numpy(dtype=float)
        pvals[m][s] = rows['p_value'].to_numpy(dtype=float)
        alphas[m][s] = rows['best_alpha'].to_numpy(dtype=float)

        # Extract significant weights, corrs: keep only channels (last weight axis) with p < threshold
        nchans = wts[m][s].shape[2]
        if nchans != len(pvals[m][s]):
            raise ValueError(f"{s} {m}: weights have {nchans} channels but the results table has "
                             f"{len(pvals[m][s])}")
        is_sig = pvals[m][s] < sig_p_threshold
        sig_wts[m][s] = np.zeros(wts[m][s].shape)
        sig_wts[m][s][:, :, is_sig] = wts[m][s][:, :, is_sig]
        sig_corrs[m][s] = np.where(is_sig, corrs[m][s], 0.0)

In [None]:
# Load raw from MNE vhdr
# Preloaded B1 (CCA) recordings and their channel-name lists, both keyed by subject ID
raws, ch_names = {}, {}
for s in subjs:
    blockid = f"{s}_B1"
    vhdr_fpath = f"{eeg_data_path}{s}/{blockid}/{blockid}_cca.vhdr"
    raws[s] = mne.io.read_raw_brainvision(vhdr_fpath, preload=True, verbose=False)
    ch_names[s] = raws[s].info['ch_names']

#### Calculate prediction of held-out EEG

In [None]:
# Predict each subject's held-out EEG from three subsets of the `single_model` weights:
#   spkr_pred: perception-specific + shared features
#   mic_pred : production-specific + shared features
#   all_pred : shared features only
# NOTE(review): spkr_feats, mic_feats, all_feats, vStim and vResp are not defined in any visible
# cell; they must come from elsewhere. `+` below assumes the *_feats are Python lists of
# feature indices (concatenation) — confirm.
pred_feature_sets = {
    'spkr_pred': spkr_feats + all_feats,
    'mic_pred': mic_feats + all_feats,
    'all_pred': all_feats,
}


def predict_from_features(s, feat_idx):
    """Predict subject ``s``'s validation EEG using only the weight columns ``feat_idx``.

    Weights come from ``wts[single_model][s]`` (indexed as delays x features x channels), are
    flattened to (delays*features, channels) and applied with ``strf.predict_response`` to the
    delayed validation stimulus of the same model. Returns the prediction that
    ``strf.predict_response`` gives as its second output.
    """
    subset_wts = wts[single_model][s][:, feat_idx, :]
    n_dly, n_ft, n_ch = subset_wts.shape
    delayed_stim = make_delayed(vStim[s][single_model][:, feat_idx], delays)
    _, pred = strf.predict_response(subset_wts.reshape((n_dly * n_ft, n_ch)),
                                    delayed_stim,
                                    zs(vResp[s][single_model]))
    return pred


preds = dict()
for s in subjs:
    preds[s] = {name: predict_from_features(s, feat_idx)
                for name, feat_idx in pred_feature_sets.items()}

#### Calculate correlations between perception- and production-specific weights

In [None]:
# For every channel, correlate (over delays) each perception-specific weight with its
# production-specific counterpart: feature i of spkr_feats is paired with feature i of mic_feats.
# wt_corrs[s] has shape (n_feats, n_chans).
wt_corrs = dict()
# NOTE: this int replaces the per-model `n_feats` dict from the config cell; later cells read it.
n_feats = len(all_feats)
if len(spkr_feats) != len(mic_feats):
    raise ValueError(f"Cannot pair {len(spkr_feats)} perception-specific with "
                     f"{len(mic_feats)} production-specific features")
for s in subjs:
    n_delays, _, n_chans = wts[single_model][s].shape
    spkr_wts = wts[single_model][s][:,spkr_feats,:]
    mic_wts = wts[single_model][s][:,mic_feats,:]
    wt_corrs[s] = np.zeros((n_feats,n_chans))
    for ci in range(n_chans):
        for fi in range(n_feats):
            wt_corrs[s][fi,ci] = np.corrcoef(spkr_wts[:,fi,ci], mic_wts[:,fi,ci])[0,1]

#### Epoch the prediction data
This may take a while.

In [None]:
# Example electrode (channel name) plotted for each example subject below
example_elecs = dict(
    OP0007="F7",
    OP0010="FT7",
    OP0014="F3",
)

In [None]:
# Get the sentence onsets
# For every subject, find where each held-out (validation) sentence starts within the
# concatenation of the validation sentences' samples; these positions are used to epoch the
# validation response and predictions in the next cell.
all_onsets = dict()
# NOTE: this re-assigns the module-level tmin/tmax (the STRF delay window in the config cell);
# from here on they are the epoch window (s) relative to sentence onset.
tmin, tmax = 0.0, 0.5
nsentences = 50
tv_split = 40 # 40 sentences IDs to train, remaining 10 to validate
split_seed = 6655321
two_block_subjs = ['OP0015', 'OP0016']  # subjects whose events also include a second block (B2)


def read_sn_events(fpath, fs, sample_offset=0):
    """Read a tab-separated sentence event file.

    Each row holds ``onset (s), offset (s), sentence id``. Times are converted to sample indices
    with ``int(t * fs)`` and shifted by ``sample_offset``.

    Returns a list of ``(onset_sample, offset_sample, sentence_id)`` tuples in file order.
    """
    with open(fpath, 'r') as f:
        return [(int(float(row[0]) * fs) + sample_offset,
                 int(float(row[1]) * fs) + sample_offset,
                 int(row[2]))
                for row in csv.reader(f, delimiter='\t')]


def val_onset_positions(events, val_ids, n_samps):
    """Positions of sentence onsets within the concatenated validation samples.

    For each id in ``val_ids`` (in order) and each of its events (in list order), the inclusive
    sample range ``max(onset, 0) .. min(offset, n_samps - 1)`` is appended to a virtual
    concatenation. An onset position is recorded at the start of a segment only when the
    event's onset sample itself lies inside the recording. Ids without events contribute nothing.
    """
    positions, pos = [], 0
    for sid in val_ids:
        for onset, offset, ev_id in events:
            if ev_id != sid:
                continue
            lo, hi = max(onset, 0), min(offset, n_samps - 1)
            if lo > hi:
                continue
            if onset == lo:
                positions.append(pos)
            pos += hi - lo + 1
    return np.array(positions, dtype=int)


for s in subjs:
    blockid = s + "_B1"
    n_samps = raws[s].n_times
    fs = raws[s].info['sfreq']
    if s in two_block_subjs:
        # B2 event samples are offset by `last_samp` of the B1 downsampled recording
        last_samp = mne.io.read_raw_brainvision(
            f"{eeg_data_path}{s}/{blockid}/{blockid}_downsampled.vhdr",preload=False,verbose=False
        ).last_samp
        b2_blockid = f"{s}_B2"
    # Event order fixes the concatenation order: spkr B1, spkr B2, mic B1, mic B2
    events = []
    for cond in ['spkr', 'mic']:
        events += read_sn_events(f"{git_path}eventfiles/{s}/{blockid}/{blockid}_{cond}_sn_all.txt", fs)
        if s in two_block_subjs:
            events += read_sn_events(
                f"{git_path}eventfiles/{s}/{b2_blockid}/{b2_blockid}_{cond}_sn_all.txt", fs, last_samp)
    # Validation sentence IDs: a private RandomState gives the same permutation as
    # np.random.seed(split_seed) + np.random.permutation without touching the global RNG.
    # NOTE(review): this split presumably must match the one used to build vResp — confirm.
    val_sn_ids = np.random.RandomState(split_seed).permutation(nsentences)[tv_split:]
    all_onsets[s] = val_onset_positions(events, val_sn_ids, n_samps)

In [None]:
# "Epoch" to sentence onsets
# For each example subject/electrode, cut the [tmin, tmax) s window around every validation
# sentence onset out of each prediction and out of the actual validation response, then z-score
# the stacked epochs with `zs`.


def epoch_at_onsets(signal, onsets, ch_idx, n_pre, n_post):
    """Stack ``signal[onset - n_pre : onset + n_post, ch_idx]`` over ``onsets``.

    ``signal`` is indexed as (samples, channels). Onsets whose window would extend past either
    end of ``signal`` are skipped, so every epoch has the same length.
    """
    n_samps = signal.shape[0]
    return np.array([signal[onset - n_pre:onset + n_post, ch_idx]
                     for onset in onsets
                     if onset - n_pre >= 0 and onset + n_post <= n_samps])


epochs = dict()
for s in [subj for subj in subjs if subj in example_elecs]:
    epochs[s] = dict()
    ch_idx = ch_names[s].index(example_elecs[s])
    # Use this subject's own sampling rate, the same one its onsets were computed with
    subj_fs = raws[s].info['sfreq']
    n_pre = int(abs(tmin) * subj_fs)
    n_post = int(tmax * subj_fs)
    for p in preds[s].keys():
        epochs[s][p] = zs(epoch_at_onsets(preds[s][p], all_onsets[s], ch_idx, n_pre, n_post))
    # Epoch the actual resp too (z-scored once, after all windows are collected)
    epochs[s]['vResp'] = zs(epoch_at_onsets(vResp[s][single_model], all_onsets[s], ch_idx, n_pre, n_post))

## Plot

In [None]:
def sem(epochs, ddof=1, axis=0):
    '''
    Standard error of the mean across epochs, returned as a band around the mean.

    Parameters
    ----------
    epochs : array-like
        Epoched data; with the default ``axis=0`` its shape is (epochs, samples).
    ddof : int
        Delta degrees of freedom for the standard deviation. The default 1 (sample standard
        deviation) is the conventional SEM estimator, as in ``scipy.stats.sem``; pass 0 for
        the population standard deviation.
    axis : int
        Axis along which the epochs are stacked.

    Returns
    -------
    (sem_below, sem_above) : tuple of ndarray
        ``mean - SEM`` and ``mean + SEM`` over ``axis``.
    '''
    epochs = np.asarray(epochs)
    mean = epochs.mean(axis)
    err = epochs.std(axis, ddof=ddof) / np.sqrt(epochs.shape[axis])
    return mean - err, mean + err

In [None]:
# Time axis (s) for the epoch plots: one point per epoch sample, spanning [tmin, tmax].
# All epoched subjects must share one epoch length for this common axis to apply.
epoch_lengths = {subj_epochs['vResp'].shape[1] for subj_epochs in epochs.values()}
if len(epoch_lengths) != 1:
    raise ValueError(f"Expected one common epoch length across subjects, got {sorted(epoch_lengths)}")
x = np.linspace(tmin, tmax, epoch_lengths.pop())

In [None]:
# "Grand average" prediction plots
# For each example subject/electrode, mean ± SEM across sentence epochs of the predicted EEG:
# perception-specific vs combined weights, production-specific vs combined, and
# perception- vs production-specific.
pred_colors = {'spkr_pred': perception_color, 'mic_pred': production_color, 'all_pred': all_color}
# One entry per panel (left to right): predictions overlaid, drawn in this order
pred_panels = [('spkr_pred', 'all_pred'), ('mic_pred', 'all_pred'), ('spkr_pred', 'mic_pred')]
# Symmetric y-limits of ±|min of the mean combined-weight prediction| × a hand-tuned scale
ylim_scales = {'OP0010': 1.5, 'OP0007': 1.08}
default_ylim_scale = 1.15

for s in [subj for subj in subjs if subj in example_elecs]:
    fig, axes = plt.subplots(1, 3, figsize=(7, 3))
    ylim = abs(epochs[s]['all_pred'].mean(0).min()) * ylim_scales.get(s, default_ylim_scale)
    for ax, keys in zip(axes, pred_panels):
        for key in keys:
            sem_below, sem_above = sem(epochs[s][key])
            ax.plot(x, epochs[s][key].mean(0), color=pred_colors[key])
            ax.fill_between(x, sem_below, sem_above, color=pred_colors[key], alpha=0.3)
        ax.set_xlim(x[0], x[-1])
        ax.set_ylim([-ylim, ylim])
        ax.set_yticks([-ylim, -ylim/2, 0, ylim/2, ylim])
        ax.set_yticklabels([])
        ax.set_xticks(np.linspace(tmin, tmax, 5))
        ax.set_xticklabels([tmin, '', '', '', tmax])
    # Only the leftmost panel gets y tick labels and axis labels
    axes[0].set_yticklabels(["%.2f" % -ylim, '', 0, '', "%.2f" % ylim])
    axes[0].set_xlabel("Time (s)", fontsize=10)
    axes[0].set_ylabel("Z-scored EEG ± SEM", fontsize=10)
    fig.suptitle(f"{s} {example_elecs[s]}", fontsize=12)
    fig.tight_layout();

In [None]:
# Legend
# Stand-alone legend for the prediction plots; the zero-size bars only provide legend handles.
fig, ax = plt.subplots(figsize=(7, 1))
ax.axis(False)
ax.bar(0, 0, color=perception_color, label="Pred. EEG (perception-specific weights)")
ax.bar(0, 0, color=production_color, label="Pred. EEG (production-specific weights)")
ax.bar(0, 0, color=all_color, label="Pred. EEG (combined weights)")
ax.legend(fontsize=10);

In [None]:
# Weight correlations
# For each example subject/electrode: perception-specific weights (left), the per-feature
# correlation between perception- and production-specific weights (middle), and the
# production-specific weights (right).
corr_vmin = -1
corr_vmax = 0
corr_cmap = LinearSegmentedColormap.from_list('my_gradient', (
    # Edit this gradient at https://eltos.github.io/gradient/#41047F-EAEAF2
    (0.000, (0.255, 0.016, 0.498)),
    (1.000, (0.918, 0.918, 0.949))))
wt_fs = 128  # Hz; sampling rate in which `delays` (config cell) are expressed
zero_delay_idx = np.where(delays == 0)[0][0]
# NOTE(review): assumes features[single_model] lists the labels of all features of the model in
# weight-column order, so indexing it with spkr_feats labels exactly the rows plotted below.
spkr_feat_labels = [features[single_model][fi] for fi in spkr_feats]


def plot_wts(ax, ch_wts):
    """Show a (delays x features) weight matrix on a colormap symmetric about zero."""
    vmax = np.abs(ch_wts).max()
    ax.imshow(ch_wts.T, aspect='auto', interpolation='nearest', vmin=-vmax, vmax=vmax, cmap=cm.RdBu_r)
    ax.axvline(zero_delay_idx, color='k')
    ax.grid(False)


for s in [subj for subj in subjs if subj in example_elecs]:
    ch_idx = ch_names[s].index(example_elecs[s])
    subj_wts = wts[single_model][s]
    fig = plt.figure(figsize=(8,4))
    gs = GridSpec(8, 17, figure=fig)
    # Spkr feats
    ax = fig.add_subplot(gs[:, :8])
    spkr_ch_wts = subj_wts[:, spkr_feats, ch_idx]
    plot_wts(ax, spkr_ch_wts)
    ax.set_yticks(np.arange(spkr_ch_wts.shape[1]))
    ax.set_yticklabels(spkr_feat_labels)
    # Ticks at the first delay, zero delay and last delay (labels in seconds)
    ax.set_xticks([0, zero_delay_idx, len(delays) - 1])
    ax.set_xticklabels(["%.2f" % (delays[0] / wt_fs), 0.0, "%.2f" % (delays[-1] / wt_fs)])
    ax.set_title("Perception-specific phnfeat")
    # Correlation (one column: features x 1)
    ax = fig.add_subplot(gs[:, 8])
    ax.imshow(wt_corrs[s][:, ch_idx][:, np.newaxis],
              aspect='auto', interpolation='nearest', vmin=corr_vmin, vmax=corr_vmax, cmap=corr_cmap)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("corr")
    ax.grid(False)
    # Mic feats
    ax = fig.add_subplot(gs[:, 9:])
    plot_wts(ax, subj_wts[:, mic_feats, ch_idx])
    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_title("Production-specific phnfeat")
    fig.suptitle(f"{s} {ch_names[s][ch_idx]}", fontsize=12)
    gs.tight_layout(figure=fig)

In [None]:
# legend
# Stand-alone horizontal colorbars: left for the weight-correlation map (corr_cmap over
# [corr_vmin, corr_vmax]), right for the RdBu_r weight maps. The 1x1 images only supply the
# mappables; their axes are hidden.
# NOTE(review): the RdBu_r image is all zeros with no vmin/vmax, so its colorbar range is
# not tied to the weight plots' scale — confirm this is intended.
plt.figure(figsize=(9,5))
colorbar_specs = [dict(cmap=corr_cmap, vmin=corr_vmin, vmax=corr_vmax),
                  dict(cmap=cm.RdBu_r)]
for panel_idx, imshow_kwargs in enumerate(colorbar_specs, start=1):
    plt.subplot(1, 2, panel_idx)
    plt.imshow(np.zeros((1, 1)), aspect='auto', interpolation='nearest', **imshow_kwargs)
    plt.gca().set_visible(False)
    plt.colorbar(orientation='horizontal')
plt.tight_layout()

In [None]:
# Violin plots
# Per-subject distribution of the correlations between production- and perception-specific
# weights (all features x channels of wt_corrs).
# NOTE(review): `all_wt_corrs` was not defined in any visible cell; it is rebuilt here as one
# flattened wt_corrs array per subject — confirm this matches the intended data (e.g. whether
# only significant channels should be included).
all_wt_corrs = {s: wt_corrs[s].ravel() for s in subjs}
wt_corr_frame = pd.DataFrame.from_dict(all_wt_corrs, orient='index').T  # one column per subject
plt.figure(figsize=(2,9))
# A single `color` gives every violin the same grey without the deprecated palette-without-hue usage
sns.violinplot(data=wt_corr_frame, orient='h', color='#b2b2b2', scale='width')
plt.axvline(0,color='k',ls='--',lw=0.5)
plt.gca().set_yticklabels(wt_corr_frame.columns,fontsize=8)
plt.ylabel("Subject",fontsize=12)
plt.gca().set_xlim([-1.15,1.15])
plt.gca().set_xticks([-1,-.5,0,.5,1])
plt.gca().set_xticklabels([-1.0,-.5,0.0,0.5,1.0],fontsize=8)
plt.xlabel("r",fontsize=12)
plt.title("Correlation of production-specific\nand perception-specific weights", fontsize=14);