In [None]:
import mne
import numpy as np
from glob import glob
import csv
import h5py
from tqdm.notebook import tqdm
import pandas as pd
import sys
sys.path.append('../../preprocessing/utils/')
import strf

In [None]:
# Change these paths before running the notebook locally.
# Each one is a directory and must end with a trailing '/': later cells build file
# names by plain string concatenation (e.g. f"{h5_path}{s}_weights.hdf5").
eeg_data_path = '/path/to/dataset/' # downloadable from OSF: https://doi.org/10.17605/OSF.IO/FNRD9
git_path  = '/path/to/git/speaker_induced_suppression_EEG/'
# Where the output of train_linear_model.ipynb is saved. Run that first if you haven't already.
h5_path = '/path/to/h5/'

In [None]:
# Channel names handed to `get_data(picks=...)` when extracting epoch data
picks = [
    'F1', 'Fz', 'F2',
    'FC1', 'FCz', 'FC2',
    'C1', 'Cz', 'C2',
]
# Epoch limits and baseline interval passed to mne.Epochs in the ERP cells
tmin = -1.5
tmax = 3.5
baseline = (None, 0)
reject_thresh = 10 # SD
# Subject ID left out of the ERP analysis
exclude = 'OP0020'
# Subject IDs = last six characters of every eventfiles/ entry whose path contains
# 'OP0' and does not contain the excluded ID, sorted alphabetically
subjs = np.sort([
    path[-6:]
    for path in glob(f'{git_path}eventfiles/*')
    if not ('OP0' not in path or exclude in path)
])

## Load ERP data

In [None]:
def interpolate_bads(subj,block,raw,git_path,eeg_data_path):
    '''
    Return `raw` expanded to the reference channel list, interpolating missing channels.

    CCA .vhdr files do not include bad channels, so every channel that is present in
    the reference recording (OP0001, block B1) but absent from `raw` is inserted as an
    all-zero row, marked as bad and then filled in with `interpolate_bads()`.
    The end result is that each subject has the same channel list as the reference
    recording, which makes plotting easier.
    Filter the data before running this!

    Parameters
    ----------
    subj : str
        Subject ID; only used in the "has no bads" message.
    block : str
        Block ID. Currently unused; kept so existing callers keep working.
    raw : MNE raw object (e.g. the result of mne.io.read_raw_brainvision)
        Recording whose missing channels should be restored.
    git_path : str
        Currently unused; kept so existing callers keep working.
    eeg_data_path : str
        Dataset root directory (with trailing '/'); the reference header is read from
        `{eeg_data_path}OP0001/OP0001_B1/OP0001_B1_cca.vhdr`.

    Returns
    -------
    The input `raw` itself when no channel is missing; otherwise a new
    mne.io.RawArray built on the reference recording's info.
    '''
    nsamps = len(raw)
    # Only the header of the reference recording is needed, hence preload=False
    info = mne.io.read_raw_brainvision(f'{eeg_data_path}OP0001/OP0001_B1/OP0001_B1_cca.vhdr',
                                       preload=False,verbose=False).info
    ch_names = info['ch_names']
    present = set(raw.info['ch_names'])
    bads = [c for c in ch_names if c not in present]
    if not bads:
        print(subj, 'has no bads.')
        return raw
    # Rebuild the data in reference channel order, with zero rows as placeholders
    new_data = []
    for ch in ch_names:
        if ch in present:
            new_data.append(raw.get_data(picks=ch)[0])
        else:
            print('Interpolating', ch)
            new_data.append(np.zeros(nsamps))
    raw = mne.io.RawArray(np.vstack(new_data),info)
    # Flag the placeholder rows so interpolate_bads() overwrites them
    raw.info['bads'] = bads
    raw.interpolate_bads()
    return raw

In [None]:
def _read_event_rows(event_file, block_shift=0, sfreq=128):
    '''
    Parse a tab-delimited event file into [onset, offset, sn_id] integer rows.

    Columns 0 and 1 are multiplied by `sfreq` and shifted by `block_shift` before
    being truncated to int; column 2 is read as an integer id.
    (Presumably columns 0/1 are times in seconds and 128 is the sampling rate in Hz;
    confirm against the event-file generator.)
    '''
    rows = []
    with open(event_file, 'r') as my_csv:
        for row in csv.reader(my_csv, delimiter='\t'):
            onset = int((float(row[0])*sfreq)+block_shift)
            offset = int((float(row[1])*sfreq)+block_shift)
            sn_id = int(row[2])
            rows.append([onset,offset,sn_id])
    return rows

# Containers are keyed [subject][channel][condition]; the spkr_* dicts are only filled
# for the 'spkr' channel and the mic_* dicts only for the 'mic' channel.
raws, spkr_epochs, mic_epochs, spkr_resp, mic_resp = dict(), dict(), dict(), dict(), dict()
conditions = ['el','sh','all']
channels = ['spkr', 'mic']
multi_block_subjs = ['OP0015','OP0016'] # mtpl blocks
subj_bar = tqdm(subjs)
for s in subj_bar:
    spkr_epochs[s], mic_epochs[s], spkr_resp[s], mic_resp[s] = dict(), dict(), dict(), dict()
    # OP0002 is loaded from block B2, every other subject from B1
    block = 'B2' if s == 'OP0002' else 'B1'
    blockid = s + '_' + block
    raw_path = f'{eeg_data_path}{s}/{blockid}/{blockid}_cca.vhdr'
    raw = mne.io.read_raw_brainvision(raw_path,preload=True,verbose=False)
    raw.filter(l_freq=1,h_freq=30,verbose=False)
    raws[s] = interpolate_bads(s,block,raw,git_path,eeg_data_path)
    if s in multi_block_subjs:
        # B2 events are shifted by the last sample index of the B1 downsampled
        # recording. Only the header is needed for that, so the file is read once per
        # subject without preloading the data.
        b1_fpath = f'{eeg_data_path}{s}/{s}_B1/{s}_B1_downsampled.vhdr'
        b1_last_samp = mne.io.read_raw_brainvision(b1_fpath,preload=False,verbose=False).last_samp
        block_shifts = {'B1': 0, 'B2': b1_last_samp}
    for ch in channels:
        spkr_epochs[s][ch], mic_epochs[s][ch], spkr_resp[s][ch], mic_resp[s][ch] = dict(), dict(), dict(), dict()
        for co in conditions:
            subj_bar.set_description(f'Epoching {s} {ch} {co}')
            if s in multi_block_subjs:
                # Concatenate the events of both blocks on a common sample axis
                events = []
                for b in ['B1','B2']:
                    event_file = f'{git_path}eventfiles/{s}/{s}_{b}/{s}_{b}_{ch}_sn_{co}.txt'
                    events.extend(_read_event_rows(event_file, block_shifts[b]))
            else:
                event_file = f'{git_path}eventfiles/{s}/{blockid}/{blockid}_{ch}_sn_{co}.txt'
                events = _read_event_rows(event_file)
            # NOTE(review): the middle column carries the offset sample, whereas MNE's
            # events convention documents it as the previous event value - confirm
            # that nothing downstream relies on it.
            events = np.array(events,dtype=int)
            # First pass without rejection: derive the threshold from the SD of all
            # picked-channel epoch data, then re-epoch with that threshold.
            # reject_thresh is expressed in SD; the factor 2 presumably converts it to
            # the peak-to-peak amplitude that MNE's `reject` expects - TODO confirm.
            unrejected = mne.Epochs(raws[s], events, tmin=tmin, tmax=tmax,reject=None,
                                    baseline=baseline,verbose=False)
            threshold = np.std(unrejected.get_data(picks=picks))*(reject_thresh*2)
            reject = dict(eeg=threshold)
            epochs = mne.Epochs(raws[s],events,tmin=tmin,tmax=tmax,reject=reject,
                                baseline=baseline,verbose=False)
            if ch == 'spkr':
                spkr_epochs[s][ch][co] = epochs
                spkr_resp[s][ch][co] = epochs.get_data(picks=picks)
            if ch == 'mic':
                mic_epochs[s][ch][co] = epochs
                mic_resp[s][ch][co] = epochs.get_data(picks=picks)

In [None]:
# Extract windows of interest
N100_window = [0.08,0.15]
P200_window = [0.15,0.25]
# Time axis rebuilt from the epoch limits and the number of samples per epoch
n_times = spkr_resp['OP0001']['spkr']['all'].shape[2]
t = np.linspace(tmin,tmax,num=n_times)
# Sample indices strictly inside each window
N100_inds = np.flatnonzero((t>N100_window[0]) & (t<N100_window[1]))
P200_inds = np.flatnonzero((t>P200_window[0]) & (t<P200_window[1]))

# One entry per subject in every dict; each value holds one number per epoch
N100_spkr_peak_amplitude, N100_mic_peak_amplitude = {}, {}
P200_spkr_peak_amplitude, P200_mic_peak_amplitude = {}, {}
N100_spkr_peak_latency, N100_mic_peak_latency = {}, {}
P200_spkr_peak_latency, P200_mic_peak_latency = {}, {}
peak_to_peak_spkr, peak_to_peak_mic = {}, {}

# (response dict, channel key, output dicts) - SPKR first, then MIC
peak_targets = (
    (spkr_resp, 'spkr',
     N100_spkr_peak_amplitude, N100_spkr_peak_latency,
     P200_spkr_peak_amplitude, P200_spkr_peak_latency, peak_to_peak_spkr),
    (mic_resp, 'mic',
     N100_mic_peak_amplitude, N100_mic_peak_latency,
     P200_mic_peak_amplitude, P200_mic_peak_latency, peak_to_peak_mic),
)
for subj in subjs:
    for resp, key, n100_amp, n100_lat, p200_amp, p200_lat, p2p in peak_targets:
        # Average over axis 1 (the picked channels in MNE's epochs x channels x times
        # layout), leaving one trace per epoch
        trace = resp[subj][key]['all'].mean(1)
        n100_segment = trace[:,N100_inds]
        p200_segment = trace[:,P200_inds]
        # N100: minimum inside its window, and the time at which it occurs
        n100_amp[subj] = n100_segment.min(1)
        n100_lat[subj] = t[N100_inds[n100_segment.argmin(1)]]
        # P200: maximum inside its window, and the time at which it occurs
        p200_amp[subj] = p200_segment.max(1)
        p200_lat[subj] = t[P200_inds[p200_segment.argmax(1)]]
        p2p[subj] = np.abs(p200_amp[subj] - n100_amp[subj])

In [None]:
# One entry per subject in every dict; each value holds one number per epoch
N100_el_peak_amplitude, N100_sh_peak_amplitude = {}, {}
P200_el_peak_amplitude, P200_sh_peak_amplitude = {}, {}
N100_el_peak_latency, N100_sh_peak_latency = {}, {}
P200_el_peak_latency, P200_sh_peak_latency = {}, {}
peak_to_peak_el, peak_to_peak_sh = {}, {}

# (condition key, output dicts) - el first, then sh; both come from the 'spkr' channel
condition_targets = (
    ('el',
     N100_el_peak_amplitude, N100_el_peak_latency,
     P200_el_peak_amplitude, P200_el_peak_latency, peak_to_peak_el),
    ('sh',
     N100_sh_peak_amplitude, N100_sh_peak_latency,
     P200_sh_peak_amplitude, P200_sh_peak_latency, peak_to_peak_sh),
)
for subj in subjs:
    if subj == 'OP0015': # no sh
        continue
    for cond, n100_amp, n100_lat, p200_amp, p200_lat, p2p in condition_targets:
        # Average over axis 1 (the picked channels), leaving one trace per epoch
        trace = spkr_resp[subj]['spkr'][cond].mean(1)
        n100_segment = trace[:,N100_inds]
        p200_segment = trace[:,P200_inds]
        # N100: minimum inside its window, and the time at which it occurs
        n100_amp[subj] = n100_segment.min(1)
        n100_lat[subj] = t[N100_inds[n100_segment.argmin(1)]]
        # P200: maximum inside its window, and the time at which it occurs
        p200_amp[subj] = p200_segment.max(1)
        p200_lat[subj] = t[P200_inds[p200_segment.argmax(1)]]
        p2p[subj] = np.abs(p200_amp[subj] - n100_amp[subj])

In [None]:
# write to CSV
def _peak_rows(subj, condition, n100_amp, p200_amp, n100_latency, p200_latency, peak_to_peak):
    '''Return one [subj, condition, <five peak measures>] row per trial.'''
    return [[subj, condition, *trial]
            for trial in zip(n100_amp, p200_amp, n100_latency, p200_latency, peak_to_peak)]

def _write_csv(fname, rows):
    '''Write `rows` to `fname` as comma-separated values, replacing any existing file.'''
    # newline='' is required by the csv module; without it every row is followed by a
    # blank line on platforms that translate '\n' to '\r\n'.
    with open(fname,'w',newline='') as my_csv:
        csv.writer(my_csv,delimiter=',').writerows(rows)

# Perception (speaker events) vs. production (microphone events)
# `csvheader` starts as the header row and accumulates the per-trial rows below it
csvheader = [['Subject','Cond','N100_amp','P200_amp','N100_latency','P200_latency','peak_to_peak']]
csv_fname = f'{git_path}stats/lme/csvs/perception_production.csv'
conditions = ['Perc','Prod']
# Column order of each tuple matches the header: N100 amp, P200 amp, N100 lat, P200 lat, p2p
measures = {
    'Perc': (N100_spkr_peak_amplitude, P200_spkr_peak_amplitude,
             N100_spkr_peak_latency, P200_spkr_peak_latency, peak_to_peak_spkr),
    'Prod': (N100_mic_peak_amplitude, P200_mic_peak_amplitude,
             N100_mic_peak_latency, P200_mic_peak_latency, peak_to_peak_mic),
}
for subj in subjs:
    for condition in conditions:
        csvheader.extend(_peak_rows(subj, condition, *(d[subj] for d in measures[condition])))
_write_csv(csv_fname, csvheader)

# Echo ('el') vs. shuffled ('sh') speaker events
csvheader = [['Subject','Cond','N100_amp','P200_amp','N100_latency','P200_latency','peak_to_peak']]
csv_fname = f'{git_path}stats/lme/csvs/predictable_unpredictable.csv'
conditions = ['Echo','Shuff']
measures = {
    'Echo': (N100_el_peak_amplitude, P200_el_peak_amplitude,
             N100_el_peak_latency, P200_el_peak_latency, peak_to_peak_el),
    'Shuff': (N100_sh_peak_amplitude, P200_sh_peak_amplitude,
              N100_sh_peak_latency, P200_sh_peak_latency, peak_to_peak_sh),
}
for subj in subjs:
    if subj == 'OP0015': # no sh
        continue
    for condition in conditions:
        csvheader.extend(_peak_rows(subj, condition, *(d[subj] for d in measures[condition])))
_write_csv(csv_fname, csvheader)

## Load LEM data

In [None]:
# model1, model1e, model2, model2e, ... model4, model4e
models = [f'model{number}{suffix}' for number in range(1, 5) for suffix in ('', 'e')]
# NOTE: `exclude`, `subjs`, `tmin` and `tmax` are rebound here with new values; the ERP
# cells above read the same names, so re-run the notebook top to bottom rather than
# jumping back to those cells after executing this one.
exclude = ['OP0001','OP0002','OP0020']
subjs = list(filter(lambda subj_id: subj_id not in exclude, subjs))
tmin, tmax = -0.3, 0.5
# Integer delays from floor(tmin*128) up to, but not including, ceil(tmax*128)
first_delay = np.floor(tmin*128)
stop_delay = np.ceil(tmax*128)
delays = np.arange(first_delay, stop_delay, dtype=int)
# Feature labels per model, and how many there are
features = {}
for model_number in models:
    features[model_number] = strf.get_feats(model_number,extend_labels=True)
n_feats = {model_number: len(labels) for model_number, labels in features.items()}

In [None]:
# Load data from hdf5, pandas
# Every dict below is keyed [model][subject]
wts, corrs, pvals, sig_wts, sig_corrs, alphas = dict(), dict(), dict(), dict(), dict(), dict()
results_csv_fpath = f"{git_path}stats/lem_results.csv"
df = pd.read_csv(results_csv_fpath)
# Accept h5_path with or without a trailing '/'
h5_dir = h5_path if h5_path.endswith('/') else f"{h5_path}/"
# Channel names do not depend on the model, so each header is read only once
subj_ch_names = dict()
for m in models:
    wts[m], corrs[m], pvals[m], sig_wts[m], sig_corrs[m], alphas[m] = dict(), dict(), dict(), dict(), dict(), dict()
    b = tqdm(subjs)
    for s in b:
        blockid = f"{s}_B1"
        b.set_description(f'Loading STRF for {s} {m}')
        with h5py.File(f"{h5_dir}{s}_weights.hdf5",'r') as f:
            # f[m] raises a KeyError naming the model if it is missing from the file
            wts[m][s] = np.array(f[m])
        if s not in subj_ch_names:
            subj_ch_names[s] = mne.io.read_raw_brainvision(
                f"{eeg_data_path}{s}/{blockid}/{blockid}_cca.vhdr",
                preload=False,verbose=False).info['ch_names']
        ch_names = subj_ch_names[s]
        subj_corrs, subj_best_alphas, subj_pvals = np.zeros(len(ch_names)), np.zeros(len(ch_names)), np.zeros(len(ch_names))
        subj_model_rows = df[(df['subject']==s) & (df['model']==m)]
        for i, ch in enumerate(ch_names):
            tgt_row = subj_model_rows[subj_model_rows['channel']==ch]
            if len(tgt_row) != 1:
                raise ValueError(
                    f"Expected exactly one row for subject={s}, model={m}, channel={ch} "
                    f"in {results_csv_fpath}, found {len(tgt_row)}")
            # Take the scalar explicitly: assigning a one-element Series to an array
            # slot relies on an array-to-scalar conversion deprecated in NumPy >= 1.25
            subj_corrs[i] = tgt_row['r_value'].iloc[0]
            subj_best_alphas[i] = tgt_row['best_alpha'].iloc[0]
            subj_pvals[i] = tgt_row['p_value'].iloc[0]
        corrs[m][s] = np.array(subj_corrs)
        pvals[m][s] = np.array(subj_pvals)
        alphas[m][s] = np.array(subj_best_alphas)
    # Extract significant weights, corrs
    for s in subjs:
        # Channels are on the last axis of the weights: (delays, features, channels)
        nchans = wts[m][s].shape[2]
        sig_wts[m][s] = np.zeros((len(delays),n_feats[m],nchans))
        sig_corrs[m][s] = np.zeros((nchans))
        for i in np.arange(nchans):
            if pvals[m][s][i] < 0.01:
                # Copy channel i (axis 2); non-significant channels stay at zero
                sig_wts[m][s][:,:,i] = wts[m][s][:,:,i]
                sig_corrs[m][s][i] = corrs[m][s][i]

In [None]:
# Export the per-channel correlations of two models in long format (one row per
# subject x channel x model)
xm, ym = 'model1', 'model2'
# `header` starts as the header row and accumulates the data rows below it
header = [['r','model','subject','channel']]
for s in subjs:
    ch_names = mne.io.read_raw_brainvision(
        f'{eeg_data_path}{s}/{s}_B1/{s}_B1_cca.vhdr',
        preload=False,verbose=False).info['ch_names']
    for i,c in enumerate(ch_names):
        for model in (xm, ym):
            header.append([corrs[model][s][i],model,s,c])
csv_fname = f'{git_path}stats/lme/csvs/{xm}_{ym}.csv'
# newline='' is required by the csv module; without it every row is followed by a
# blank line on platforms that translate '\n' to '\r\n'.
with open(csv_fname,'w',newline='') as f:
    csvWriter = csv.writer(f)
    csvWriter.writerows(header)