## Within-subject statistics
i.e., Wilcoxon signed-rank tests

In [None]:
import mne
import numpy as np
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import multipletests
from glob import glob
import csv
import h5py
from tqdm.notebook import tqdm
import sys
# Project-local helpers (strf) live in the preprocessing utils folder; this
# relative path assumes the notebook is run from its own directory.
sys.path.append('../preprocessing/utils/')
import strf
from matplotlib import pyplot as plt
# matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and 3.8
# removed the old name (plt.style.use('seaborn') then raises OSError), so use
# the new name whenever this matplotlib provides it.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
from matplotlib import rcParams as rc
rc['pdf.fonttype'] = 42  # Type 42 (TrueType) fonts in PDF output

In [None]:
# Change these paths to run the notebook locally.
# Paths keep their trailing slash because they are joined by string concatenation.

# Root of the EEG dataset, downloadable from OSF: https://doi.org/10.17605/OSF.IO/FNRD9
eeg_data_path = '/path/to/dataset/'

# Local clone of the speaker_induced_suppression_EEG repository.
git_path = '/path/to/git/speaker_induced_suppression_EEG/'

# Output directory of train_linear_model.ipynb -- run that notebook first if you haven't already.
h5_path = '/path/to/h5/'

In [None]:
# Electrodes whose responses are extracted (and later averaged) for the ERP measures
picks = ['F1', 'Fz', 'F2',
         'FC1', 'FCz', 'FC2',
         'C1', 'Cz', 'C2']

# Epoch window in seconds; the baseline spans from the epoch start up to 0 s
tmin, tmax = -1.5, 3.5
baseline = (None, 0)

# Rejection threshold in SDs (doubled where the epoch rejection dict is built)
reject_thresh = 10

# Subject excluded from the ERP analysis
exclude = 'OP0020'

# Subject IDs are the last six characters of each eventfiles/ entry
subjs = np.sort([
    entry[-6:]
    for entry in glob(f'{git_path}eventfiles/*')
    if 'OP0' in entry and exclude not in entry
])

## Load ERP data

In [None]:
def interpolate_bads(subj,block,raw,git_path,eeg_data_path):
    '''
    Rebuild `raw` on the reference channel set, interpolating missing channels.

    CCA .vhdr files do not include bad channels. This function reads the channel
    list from a reference recording (OP0001, block B1), inserts a zero-filled
    placeholder row for every reference channel missing from `raw`, marks those
    channels as bad and lets MNE interpolate them. The end result is that each
    subject has a 64-channel raw array, which makes plotting easier.
    Filter the data before running this!

    Parameters
    ----------
    subj : str
        Subject ID; only used in the "has no bads" message.
    block : str
        Block label; unused, kept for call compatibility.
    raw : mne.io.Raw
        Preloaded (and already filtered) recording for this subject/block.
    git_path : str
        Unused, kept for call compatibility.
    eeg_data_path : str
        Dataset root that contains the OP0001 reference recording.

    Returns
    -------
    raw : mne.io.BaseRaw
        `raw` itself when no channel is missing; otherwise a new RawArray built
        with the reference file's `info` and interpolated bad channels.
    '''
    ref_vhdr = f'{eeg_data_path}OP0001/OP0001_B1/OP0001_B1_cca.vhdr'
    info = mne.io.read_raw_brainvision(ref_vhdr, preload=False, verbose=False).info
    ch_names = info['ch_names']
    present = set(raw.info['ch_names'])
    bads = [c for c in ch_names if c not in present]
    if not bads:
        print(subj, 'has no bads.')
        return raw

    nsamps = len(raw)
    bad_set = set(bads)
    new_data = []
    for ch in ch_names:
        if ch in bad_set:
            print('Interpolating', ch)
            # Placeholder row; interpolate_bads() below overwrites its values.
            new_data.append(np.zeros((1, nsamps)))
        else:
            new_data.append(raw.get_data(picks=ch)[0])
    # NOTE(review): the rebuilt RawArray takes sfreq and all other metadata from the
    # reference file's info, not from `raw`; this assumes every recording shares
    # OP0001's sampling rate. Interpolation also needs channel positions in `info`.
    # Confirm both.
    raw = mne.io.RawArray(np.vstack(new_data), info)
    raw.info['bads'] = bads
    raw.interpolate_bads()
    return raw

In [None]:
raws, spkr_epochs, mic_epochs, spkr_resp, mic_resp = dict(), dict(), dict(), dict(), dict()
conditions = ['el','sh','all']
channels = ['spkr', 'mic']


def _read_events(event_file, block_shift=0, sfreq=128):
    """Parse one tab-separated event file into [onset, offset, id] sample rows.

    Each row of `event_file` holds an onset time (s), an offset time (s) and an
    integer id. Times are multiplied by `sfreq`, shifted by `block_shift`
    samples and then truncated toward zero by int().
    """
    rows = []
    with open(event_file, 'r') as fh:
        for row in csv.reader(fh, delimiter='\t'):
            rows.append([int((float(row[0]) * sfreq) + block_shift),
                         int((float(row[1]) * sfreq) + block_shift),
                         int(row[2])])
    return rows


subj_bar = tqdm(subjs)
for s in subj_bar:
    spkr_epochs[s], mic_epochs[s], spkr_resp[s], mic_resp[s] = dict(), dict(), dict(), dict()
    # OP0002 is epoched from block B2; every other subject from B1.
    block = 'B2' if s == 'OP0002' else 'B1'
    blockid = s + '_' + block
    raw_path = f'{eeg_data_path}{s}/{blockid}/{blockid}_cca.vhdr'
    raw = mne.io.read_raw_brainvision(raw_path,preload=True,verbose=False)
    raw.filter(l_freq=1,h_freq=30,verbose=False)
    raws[s] = interpolate_bads(s,block,raw,git_path,eeg_data_path)

    # These subjects have events in both B1 and B2 ("mtpl blocks").
    mtpl_blocks = s in ['OP0015','OP0016']
    if mtpl_blocks:
        # B2 event samples are offset by last_samp of the downsampled B1 file.
        # Only the header is needed, so skip preloading, and read it once per subject.
        # NOTE(review): for a file starting at sample 0, B1's length is last_samp + 1;
        # confirm the intended offset.
        b1_fpath = f'{eeg_data_path}{s}/{s}_B1/{s}_B1_downsampled.vhdr'
        b2_shift = mne.io.read_raw_brainvision(b1_fpath,preload=False,verbose=False).last_samp

    for ch in channels:
        spkr_epochs[s][ch], mic_epochs[s][ch], spkr_resp[s][ch], mic_resp[s][ch] = dict(), dict(), dict(), dict()
        for co in conditions:
            subj_bar.set_description(f'Epoching {s} {ch} {co}')
            if mtpl_blocks:
                events = []
                for b, block_shift in [('B1', 0), ('B2', b2_shift)]:
                    event_file = f'{git_path}eventfiles/{s}/{s}_{b}/{s}_{b}_{ch}_sn_{co}.txt'
                    events += _read_events(event_file, block_shift)
            else:
                event_file = f'{git_path}eventfiles/{s}/{blockid}/{blockid}_{ch}_sn_{co}.txt'
                events = _read_events(event_file)
            events = np.array(events,dtype=int)

            # Data-driven rejection threshold: 2 * reject_thresh * SD of the
            # unrejected epochs on the picked channels.
            unrejected = mne.Epochs(raws[s], events, tmin=tmin, tmax=tmax,reject=None,
                                    baseline=baseline,verbose=False)
            reject = dict(eeg=np.std(unrejected.get_data(picks=picks))*(reject_thresh*2))
            epochs = mne.Epochs(raws[s],events,tmin=tmin,tmax=tmax,reject=reject,
                                baseline=baseline,verbose=False)
            if ch == 'spkr':
                spkr_epochs[s][ch][co] = epochs
                spkr_resp[s][ch][co] = epochs.get_data(picks=picks)
            if ch == 'mic':
                mic_epochs[s][ch][co] = epochs
                mic_resp[s][ch][co] = epochs.get_data(picks=picks)

In [None]:
# N100 and P200 search windows (s) relative to event onset
N100_window = [0.08,0.15]
P200_window = [0.15,0.25]
# Time axis of the ERP epochs. Every subject was epoched with the same tmin/tmax,
# so take the sample count from the last subject explicitly instead of relying on
# the loop variable `s` left over from the epoching cell (it is the same subject).
# NOTE(review): tmin/tmax are reassigned in the LEM section; re-running this cell
# after that would build the wrong time axis.
n_times = spkr_resp[subjs[-1]]['spkr']['all'].shape[2]
t = np.linspace(tmin,tmax,n_times)
# Strict inequalities: a sample lying exactly on a window edge is excluded.
N100_inds = np.where((t>N100_window[0]) & (t<N100_window[1]))[0]
P200_inds = np.where((t>P200_window[0]) & (t<P200_window[1]))[0]

In [None]:
def _peak_measures(resp, n100_inds, p200_inds, times):
    """Return per-epoch N100/P200 peak measures for one condition.

    `resp` is indexed (epoch, channel, time), as returned by Epochs.get_data.
    The channel axis is averaged first. Returns a tuple
    (n100_amp, n100_lat, p200_amp, p200_lat, peak_to_peak), each a 1-D array
    with one value per epoch.
    """
    erp = resp.mean(1)                 # average over picked channels -> (epoch, time)
    n100_seg = erp[:, n100_inds]
    p200_seg = erp[:, p200_inds]
    n100_amp = n100_seg.min(1)         # N100: most negative value in its window
    n100_lat = times[n100_inds[n100_seg.argmin(1)]]
    p200_amp = p200_seg.max(1)         # P200: most positive value in its window
    p200_lat = times[p200_inds[p200_seg.argmax(1)]]
    return n100_amp, n100_lat, p200_amp, p200_lat, np.abs(p200_amp - n100_amp)


N100_el_peak_amplitude,N100_sh_peak_amplitude = dict(),dict()
P200_el_peak_amplitude,P200_sh_peak_amplitude = dict(),dict()
N100_el_peak_latency,N100_sh_peak_latency = dict(),dict()
P200_el_peak_latency,P200_sh_peak_latency = dict(),dict()
peak_to_peak_el,peak_to_peak_sh = dict(),dict()
for subj in subjs:
    if subj == 'OP0015':
        continue  # OP0015 is skipped in the ERP peak analysis
    (N100_el_peak_amplitude[subj], N100_el_peak_latency[subj],
     P200_el_peak_amplitude[subj], P200_el_peak_latency[subj],
     peak_to_peak_el[subj]) = _peak_measures(
        spkr_resp[subj]['spkr']['el'], N100_inds, P200_inds, t)
    (N100_sh_peak_amplitude[subj], N100_sh_peak_latency[subj],
     P200_sh_peak_amplitude[subj], P200_sh_peak_latency[subj],
     peak_to_peak_sh[subj]) = _peak_measures(
        spkr_resp[subj]['spkr']['sh'], N100_inds, P200_inds, t)

In [None]:
# The Wilcoxon signed-rank test needs two equal-length samples, so truncate each
# subject's echo ('el') and shuffled ('sh') measures to the smaller epoch count
# (keeping the first n_epochs of each).
# NOTE(review): el and sh epochs are separate trials, so pairing them by index is
# arbitrary; an unpaired test (e.g. Mann-Whitney U) may be more appropriate.
paired_measures = [
    (N100_el_peak_amplitude, N100_sh_peak_amplitude),
    (P200_el_peak_amplitude, P200_sh_peak_amplitude),
    (N100_el_peak_latency, N100_sh_peak_latency),
    (P200_el_peak_latency, P200_sh_peak_latency),
    (peak_to_peak_el, peak_to_peak_sh),
]
for subj in subjs:
    if subj == 'OP0015':
        continue
    el_len = N100_el_peak_amplitude[subj].shape[0]
    sh_len = N100_sh_peak_amplitude[subj].shape[0]
    n_epochs = min(el_len, sh_len)
    for el_measure, sh_measure in paired_measures:
        el_measure[subj] = el_measure[subj][:n_epochs]
        sh_measure[subj] = sh_measure[subj][:n_epochs]

In [None]:
N100_amp_pval, N100_lat_pval, P200_amp_pval, P200_lat_pval = dict(), dict(), dict(), dict()
p2p_pval = dict()
# (output dict, echo-condition values, shuffled-condition values) for each measure
erp_tests = [
    (N100_amp_pval, N100_el_peak_amplitude, N100_sh_peak_amplitude),
    (N100_lat_pval, N100_el_peak_latency, N100_sh_peak_latency),
    (P200_amp_pval, P200_el_peak_amplitude, P200_sh_peak_amplitude),
    (P200_lat_pval, P200_el_peak_latency, P200_sh_peak_latency),
    (p2p_pval, peak_to_peak_el, peak_to_peak_sh),
]
tested_subjs = [subj for subj in subjs if subj != 'OP0015']
for pval_dict, el_vals, sh_vals in erp_tests:
    raw_pvals = [wilcoxon(el_vals[subj], sh_vals[subj]).pvalue for subj in tested_subjs]
    if not raw_pvals:
        continue
    # Benjamini-Yekutieli FDR across subjects for this measure. Passing a single
    # p-value to multipletests is a no-op (one BY-adjusted p equals the raw p),
    # so each measure's subjects are corrected together as one family.
    corrected = multipletests(raw_pvals, alpha=0.05, method='fdr_by')[1]
    for subj, p in zip(tested_subjs, corrected):
        # Stored as a 1-element array so downstream code can index [0].
        pval_dict[subj] = np.array([p])

In [None]:
# Report the corrected p-value of every ERP measure, one block per subject.
measure_tables = [
    ("N100 amplitude:", N100_amp_pval),
    ("N100 latency:", N100_lat_pval),
    ("P200 amplitude:", P200_amp_pval),
    ("P200 latency:", P200_lat_pval),
    ("Peak to peak:", p2p_pval),
]
for subj in subjs:
    if subj == 'OP0015':
        continue
    print(subj)
    for label, table in measure_tables:
        print(label, table[subj][0])
    print(" ")

## Load LEM data

In [None]:
# Model names: model1..model4, each followed by its 'e' variant
# (model1 vs model1e is compared below as "EMG regress vs no regress").
models = [f'model{k}{suffix}' for k in range(1, 5) for suffix in ('', 'e')]

# Subjects left out of the linear-model analysis
exclude = ['OP0001','OP0002','OP0004','OP0017','OP0020']
subjs = [subj for subj in subjs if subj not in exclude]

# Model delay window in seconds, converted to whole-sample delays at 128 Hz.
# NOTE: this overwrites the ERP epoch tmin/tmax defined earlier.
tmin, tmax = -0.3, 0.5
delays = np.arange(np.floor(tmin * 128), np.ceil(tmax * 128), dtype=int)

# Feature labels for each model, and how many features each one has
features = {}
for model_number in models:
    features[model_number] = strf.get_feats(model_number, extend_labels=True)
n_feats = {name: len(labels) for name, labels in features.items()}

In [None]:
# Load model weights (HDF5) and per-channel fit statistics (results CSV)
import pandas as pd  # `pd` is used below but is not imported elsewhere in this notebook

wts, corrs, pvals, sig_wts, sig_corrs, alphas = dict(), dict(), dict(), dict(), dict(), dict()
results_csv_fpath = f"{git_path}stats/lem_results.csv"
df = pd.read_csv(results_csv_fpath)
sig_pval = 0.01  # channels with p below this keep their weights/correlations


def _channel_value(rows, ch, column):
    """Return `column` for channel `ch`; raises ValueError unless exactly one row matches."""
    return rows.loc[rows['channel'] == ch, column].to_numpy().item()


for m in models:
    wts[m], corrs[m], pvals[m], sig_wts[m], sig_corrs[m], alphas[m] = dict(), dict(), dict(), dict(), dict(), dict()
    b = tqdm(subjs)
    for s in b:
        blockid = f"{s}_B1"
        b.set_description(f'Loading STRF for {s} {m}')
        with h5py.File(f"{h5_path}{s}_weights.hdf5",'r') as f:
            # f[m] raises KeyError for a missing model (f.get(m) would yield None).
            wts[m][s] = np.array(f[m])
        ch_names = mne.io.read_raw_brainvision(f"{eeg_data_path}{s}/{blockid}/{blockid}_cca.vhdr",
                                               preload=False,verbose=False).info['ch_names']
        # Filter the results table once per subject/model rather than once per channel.
        rows = df[(df['subject']==s) & (df['model']==m)]
        corrs[m][s] = np.array([_channel_value(rows, ch, 'r_value') for ch in ch_names], dtype=float)
        pvals[m][s] = np.array([_channel_value(rows, ch, 'p_value') for ch in ch_names], dtype=float)
        alphas[m][s] = np.array([_channel_value(rows, ch, 'best_alpha') for ch in ch_names], dtype=float)
    # Keep weights/correlations only for channels with a significant fit.
    # wts[m][s] is indexed (delay, feature, channel), as implied by the sig_wts
    # allocation below, so channel i must be selected on the LAST axis.
    for s in subjs:
        nchans = wts[m][s].shape[2]
        sig_wts[m][s] = np.zeros((len(delays),n_feats[m],nchans))
        sig_corrs[m][s] = np.zeros((nchans))
        for i in np.arange(nchans):
            if pvals[m][s][i] < sig_pval:
                sig_wts[m][s][:, :, i] = wts[m][s][:, :, i]
                sig_corrs[m][s][i] = corrs[m][s][i]

In [None]:
# EMG regress vs no regress
comparison = ['model1', 'model1e']
ylim = [0,0.03]  # bars for larger p-values are clipped at the top of the axis
plt.figure(figsize=(8,4))
# Per subject: paired Wilcoxon signed-rank test across channels on the
# significant-channel correlations (0 for non-significant channels).
raw_pvals = [wilcoxon(sig_corrs[comparison[0]][s], sig_corrs[comparison[1]][s]).pvalue
             for s in subjs]
# Benjamini-Yekutieli FDR across subjects. Passing one p-value at a time to
# multipletests is a no-op (a single BY-adjusted p equals the raw p), so all
# subjects are corrected together as one family.
if raw_pvals:
    modelcomp_pvals = multipletests(raw_pvals, alpha=0.05, method='fdr_by')[1]
else:
    modelcomp_pvals = np.zeros(0)
for i, p in enumerate(modelcomp_pvals):
    if p < 0.05:
        color, alpha = 'mediumseagreen', 1
    else:
        color, alpha = 'firebrick', 0.5
        plt.text(i,np.mean(ylim),'n.s.', ha='center', style='italic',weight='heavy',color='w',fontsize=12)
    plt.bar(i,p,color=color,alpha=alpha)
plt.gca().set_ylim(ylim)
plt.gca().set_yticks(np.linspace(ylim[0],ylim[1],5))
plt.gca().set_yticklabels(np.linspace(ylim[0],ylim[1],5),fontsize=12)
plt.gca().set_xticks(np.arange(len(subjs)))
plt.gca().set_xticklabels(subjs,rotation='vertical',fontsize=12);
plt.xlabel('Subject',fontsize=12)
plt.ylabel('p value', fontsize=12);
plt.title(f'Within-subject comparison of {comparison[0]} and {comparison[1]} \nby multiple test-corrected Wilcoxon signed-rank test',fontsize=12);
plt.tight_layout();


In [None]:
# Identical vs. differential phn feat encoding
comparison = ['model1', 'model2']
ylim = [0,0.03]  # bars for larger p-values are clipped at the top of the axis
plt.figure(figsize=(8,4))
# Per subject: paired Wilcoxon signed-rank test across channels on the
# significant-channel correlations (0 for non-significant channels).
raw_pvals = [wilcoxon(sig_corrs[comparison[0]][s], sig_corrs[comparison[1]][s]).pvalue
             for s in subjs]
# Benjamini-Yekutieli FDR across subjects. Passing one p-value at a time to
# multipletests is a no-op (a single BY-adjusted p equals the raw p), so all
# subjects are corrected together as one family.
if raw_pvals:
    modelcomp_pvals = multipletests(raw_pvals, alpha=0.05, method='fdr_by')[1]
else:
    modelcomp_pvals = np.zeros(0)
for i, p in enumerate(modelcomp_pvals):
    if p < 0.05:
        color, alpha = 'mediumseagreen', 1
    else:
        color, alpha = 'firebrick', 0.5
        plt.text(i,np.mean(ylim),'n.s.', ha='center', style='italic',weight='heavy',color='w',fontsize=12)
    plt.bar(i,p,color=color,alpha=alpha)
plt.gca().set_ylim(ylim)
plt.gca().set_yticks(np.linspace(ylim[0],ylim[1],5))
plt.gca().set_yticklabels(np.linspace(ylim[0],ylim[1],5),fontsize=12)
plt.gca().set_xticks(np.arange(len(subjs)))
plt.gca().set_xticklabels(subjs,rotation='vertical',fontsize=12);
plt.xlabel('Subject',fontsize=12)
plt.ylabel('p value', fontsize=12);
plt.title(f'Within-subject comparison of {comparison[0]} and {comparison[1]} \nby multiple test-corrected Wilcoxon signed-rank test',fontsize=12);
plt.tight_layout();

In [None]:
# Encoding vs not encoding task predictability
comparison = ['model1', 'model3']
ylim = [0,0.04]  # bars for larger p-values are clipped at the top of the axis
plt.figure(figsize=(8,4))
# Per subject: paired Wilcoxon signed-rank test across channels on the
# significant-channel correlations (0 for non-significant channels).
raw_pvals = [wilcoxon(sig_corrs[comparison[0]][s], sig_corrs[comparison[1]][s]).pvalue
             for s in subjs]
# Benjamini-Yekutieli FDR across subjects. Passing one p-value at a time to
# multipletests is a no-op (a single BY-adjusted p equals the raw p), so all
# subjects are corrected together as one family.
if raw_pvals:
    modelcomp_pvals = multipletests(raw_pvals, alpha=0.05, method='fdr_by')[1]
else:
    modelcomp_pvals = np.zeros(0)
for i, p in enumerate(modelcomp_pvals):
    if p < 0.05:
        color, alpha = 'mediumseagreen', 1
    else:
        color, alpha = 'firebrick', 0.5
        plt.text(i,np.mean(ylim),'n.s.', ha='center', style='italic',weight='heavy',color='w',fontsize=12)
    plt.bar(i,p,color=color,alpha=alpha)
plt.gca().set_ylim(ylim)
plt.gca().set_yticks(np.linspace(ylim[0],ylim[1],5))
plt.gca().set_yticklabels(np.linspace(ylim[0],ylim[1],5),fontsize=12)
plt.gca().set_xticks(np.arange(len(subjs)))
plt.gca().set_xticklabels(subjs,rotation='vertical',fontsize=12);
plt.xlabel('Subject',fontsize=12)
plt.ylabel('p value', fontsize=12);
plt.title(f'Within-subject comparison of {comparison[0]} and {comparison[1]} \nby multiple test-corrected Wilcoxon signed-rank test',fontsize=12);
plt.tight_layout();

In [None]:
# Encoding vs not encoding task modality
comparison = ['model1', 'model4']
ylim = [0,0.03]  # bars for larger p-values are clipped at the top of the axis
plt.figure(figsize=(10,5))
# Per subject: paired Wilcoxon signed-rank test across channels on the
# significant-channel correlations (0 for non-significant channels).
raw_pvals = [wilcoxon(sig_corrs[comparison[0]][s], sig_corrs[comparison[1]][s]).pvalue
             for s in subjs]
# Benjamini-Yekutieli FDR across subjects. Passing one p-value at a time to
# multipletests is a no-op (a single BY-adjusted p equals the raw p), so all
# subjects are corrected together as one family.
if raw_pvals:
    modelcomp_pvals = multipletests(raw_pvals, alpha=0.05, method='fdr_by')[1]
else:
    modelcomp_pvals = np.zeros(0)
for i, p in enumerate(modelcomp_pvals):
    if p < 0.05:
        color, alpha = 'mediumseagreen', 1
    else:
        color, alpha = 'firebrick', 0.5
        plt.text(i,np.mean(ylim),'n.s.', ha='center', style='italic',weight='heavy',color='w',fontsize=12)
    plt.bar(i,p,color=color,alpha=alpha)
plt.gca().set_ylim(ylim)
plt.gca().set_yticks(np.linspace(ylim[0],ylim[1],5))
plt.gca().set_yticklabels(np.linspace(ylim[0],ylim[1],5),fontsize=12)
plt.gca().set_xticks(np.arange(len(subjs)))
plt.gca().set_xticklabels(subjs,rotation='vertical',fontsize=12);
plt.xlabel('Subject',fontsize=12)
plt.ylabel('p value', fontsize=12);
plt.title(f'Within-subject comparison of {comparison[0]} and {comparison[1]} \nby multiple test-corrected Wilcoxon signed-rank test',fontsize=12);
plt.tight_layout();