## Figure 2

In [None]:
import mne
import numpy as np
import csv
from glob import glob
from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
from matplotlib.gridspec import GridSpec
from tqdm.notebook import tqdm
rc['pdf.fonttype'] = 42  # embed TrueType (Type 42) fonts so text in saved PDFs stays editable
# The 'seaborn' style was renamed 'seaborn-v0_8' in matplotlib 3.6 and the old
# alias removed in 3.8; prefer the new name and fall back for old matplotlib.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [None]:
# Edit these paths to run the notebook locally. Both must end with '/',
# because the rest of the notebook joins paths by plain string concatenation.
# EEG dataset is downloadable from OSF: https://doi.org/10.17605/OSF.IO/FNRD9
eeg_data_path = '/path/to/dataset/'
git_path = '/path/to/git/speaker_induced_suppression_EEG/'

In [None]:
# Colours used for each condition in the figures
perception_color = '#117733'
production_color = '#332288'
consistent_color = '#ddcc77'
inconsistent_color = '#aa4499'
# EEG channels extracted from every epoch and averaged in the plots
picks = ['F1', 'Fz', 'F2',
         'FC1', 'FCz', 'FC2',
         'C1', 'Cz', 'C2']
tmin, tmax = -1.5, 3.5  # epoch window passed to mne.Epochs (s)
baseline = (None, 0)    # mne.Epochs baseline: from epoch start up to t = 0
reject_thresh = 10      # SD multiplier for the data-driven rejection threshold
exclude = 'OP0020'      # subject ID left out of the analysis
# Subject IDs are the last 6 characters of each eventfiles/ entry whose path
# contains 'OP0' and does not contain `exclude`.
subj_dirs = [d for d in glob(f'{git_path}eventfiles/*') if 'OP0' in d and exclude not in d]
subjs = np.sort([d[-6:] for d in subj_dirs])

### Load and epoch data

In [None]:
def interpolate_bads(subj, block, raw, git_path, eeg_data_path, template_vhdr=None):
    '''
    Re-insert and interpolate channels that are missing from a CCA recording.

    CCA .vhdr files do not include bad channels, so this function rebuilds the
    full channel set from a template BrainVision header and fills the missing
    channels by interpolation. The end result is that each subject has a raw
    array with the template's channel set (64 channels), which makes plotting
    easier. Filter the data before running this!

    Parameters
    ----------
    subj : str
        Subject ID; only used in the message printed when nothing is missing.
    block : str
        Block ID (e.g. 'B1'). Unused; kept for call compatibility.
    raw : mne.io.Raw
        Preloaded, filtered recording whose channel list may be incomplete.
    git_path : str
        Unused; kept for call compatibility.
    eeg_data_path : str
        Dataset root (with trailing '/'), used to locate the default template.
    template_vhdr : str | None
        BrainVision header whose info defines the full channel set. Defaults to
        the OP0001 B1 CCA header under ``eeg_data_path``.

    Returns
    -------
    mne.io.Raw | mne.io.RawArray
        ``raw`` unchanged if no template channel is missing. Otherwise a new
        RawArray built on the template info, in template channel order, with
        the missing channels interpolated. Channels in ``raw`` that are absent
        from the template are not carried over.
    '''
    if template_vhdr is None:
        template_vhdr = f'{eeg_data_path}OP0001/OP0001_B1/OP0001_B1_cca.vhdr'
    # Only the header/info is needed, so the template data are not loaded.
    info = mne.io.read_raw_brainvision(template_vhdr, preload=False, verbose=False).info
    ch_names = info['ch_names']
    present = set(raw.info['ch_names'])
    bads = [c for c in ch_names if c not in present]
    if not bads:
        print(subj, 'has no bads.')
        return raw
    bad_set = set(bads)
    # Missing channels start as zeros; interpolate_bads() below overwrites them.
    # NOTE(review): interpolation needs sensor positions in the template info;
    # confirm the template header provides them (or set a montage).
    data = np.zeros((len(ch_names), len(raw)))
    for idx, ch in enumerate(ch_names):
        if ch in bad_set:
            print('Interpolating', ch)
        else:
            data[idx] = raw.get_data(picks=ch)[0]
    full_raw = mne.io.RawArray(data, info)
    full_raw.info['bads'] = bads
    full_raw.interpolate_bads()
    return full_raw
def sem_evoked(evoked, scale=1e-6):
    '''
    Return the lower and upper edges of a mean +/- standard-error band.

    Mean, standard deviation and sample count are all taken over axis 0 of
    ``evoked`` (callers in this notebook pass a (channels, samples) array, so
    the band spans channels at each sample). NaNs are ignored consistently:
    the SEM is nanstd / sqrt(n), where n counts only the non-NaN entries.
    Previously n was ``evoked.shape[0]``, which understated the SEM wherever
    NaNs were present.

    Parameters
    ----------
    evoked : array-like, shape (n, ...)
    scale : float
        Divisor applied to both band edges. The default 1e-6 turns volts
        (MNE's unit for EEG data) into the µV used on the plot axes.

    Returns
    -------
    (sem_below, sem_above) : tuple of ndarray, each of shape evoked.shape[1:]
    '''
    evoked = np.asarray(evoked)
    mean = np.nanmean(evoked, axis=0)
    n_valid = np.sum(~np.isnan(evoked), axis=0)
    sem = np.nanstd(evoked, axis=0) / np.sqrt(n_valid)
    return (mean - sem) / scale, (mean + sem) / scale

In [None]:
raws, spkr_epochs, mic_epochs, spkr_resp, mic_resp = dict(), dict(), dict(), dict(), dict()
conditions = ['el','sh','all']
channels = ['spkr', 'mic']
# Subjects whose events are split across blocks B1 and B2 ("mtpl blocks");
# their B2 event samples are offset by B1's last sample index.
mtpl_subjs = ['OP0015', 'OP0016']
# Factor converting event-file times to sample indices. Assumes the times are
# in seconds and the data are sampled at 128 Hz -- TODO confirm.
event_sfreq = 128


def _read_event_rows(event_file, block_shift=0):
    '''
    Parse one tab-separated event file into [onset, offset, event_id] rows.

    Columns 0 and 1 are multiplied by `event_sfreq`, shifted by `block_shift`
    samples and truncated to int; column 2 is the integer event ID (sn_id).
    '''
    rows = []
    with open(event_file, 'r') as event_fh:
        for row in csv.reader(event_fh, delimiter='\t'):
            onset = int((float(row[0])*event_sfreq)+block_shift)
            offset = int((float(row[1])*event_sfreq)+block_shift)
            rows.append([onset, offset, int(row[2])])
    return rows


subj_bar = tqdm(subjs)
for s in subj_bar:
    spkr_epochs[s], mic_epochs[s], spkr_resp[s], mic_resp[s] = dict(), dict(), dict(), dict()
    # OP0002 is analysed from block B2; every other subject from B1.
    block = 'B2' if s == 'OP0002' else 'B1'
    blockid = s + '_' + block
    raw_path = f'{eeg_data_path}{s}/{blockid}/{blockid}_cca.vhdr'
    raw = mne.io.read_raw_brainvision(raw_path,preload=True,verbose=False)
    raw.filter(l_freq=1,h_freq=30,verbose=False)  # interpolate_bads expects filtered data
    raws[s] = interpolate_bads(s,block,raw,git_path,eeg_data_path)
    if s in mtpl_subjs:
        # The B2 shift depends only on the B1 header, so read it once per
        # subject (not once per channel x condition) and skip loading the data.
        # NOTE(review): if B1 starts at sample 0 its length is last_samp + 1;
        # the shift is kept at last_samp as before -- confirm intended.
        b1_fpath = f'{eeg_data_path}{s}/{s}_B1/{s}_B1_downsampled.vhdr'
        b1_last_samp = mne.io.read_raw_brainvision(b1_fpath,preload=False,verbose=False).last_samp
        block_shifts = {'B1': 0, 'B2': b1_last_samp}
    for ch in channels:
        spkr_epochs[s][ch], mic_epochs[s][ch], spkr_resp[s][ch], mic_resp[s][ch] = dict(), dict(), dict(), dict()
        for co in conditions:
            subj_bar.set_description(f'Epoching {s} {ch} {co}')
            if s in mtpl_subjs:
                event_rows = []
                for b in ['B1','B2']:
                    event_file = f'{git_path}eventfiles/{s}/{s}_{b}/{s}_{b}_{ch}_sn_{co}.txt'
                    event_rows += _read_event_rows(event_file, block_shifts[b])
            else:
                event_file = f'{git_path}eventfiles/{s}/{blockid}/{blockid}_{ch}_sn_{co}.txt'
                event_rows = _read_event_rows(event_file)
            events = np.array(event_rows,dtype=int)
            # Data-driven rejection: epoch everything once without rejection, then
            # use 2 * reject_thresh SDs of that data (at `picks`) as the MNE
            # `reject` threshold, which MNE applies to peak-to-peak amplitude.
            unrejected = mne.Epochs(raws[s], events, tmin=tmin, tmax=tmax,reject=None,
                                    baseline=baseline,verbose=False)
            reject_sd = np.std(unrejected.get_data(picks=picks))
            reject = dict(eeg=reject_sd*(reject_thresh*2))
            epochs = mne.Epochs(raws[s],events,tmin=tmin,tmax=tmax,reject=reject,
                                baseline=baseline,verbose=False)
            if ch == 'spkr':
                spkr_epochs[s][ch][co] = epochs
                spkr_resp[s][ch][co] = epochs.get_data(picks=picks)
            elif ch == 'mic':
                mic_epochs[s][ch][co] = epochs
                mic_resp[s][ch][co] = epochs.get_data(picks=picks)

In [None]:
# Concatenate trials by first averaging across event ID, then combining the
# per-event evokeds weighted by trial count ('nave').
# Difference waves are taken within subject at each event ID, then combined.


def _paired_difference(evokeds_a, evokeds_b):
    '''
    Combine per-event-ID difference waves (a - b) into a single Evoked.

    Each evoked in ``evokeds_a`` is paired with the first evoked in
    ``evokeds_b`` that has the same ``comment``. Epochs.average(by_event_type=True)
    sets the comment to the event-type label. Event IDs without a partner are
    skipped. The pairwise differences are combined with 'nave' weighting.

    Raises
    ------
    ValueError
        If no event ID is shared between the two lists.
    '''
    partners = {}
    for evk in evokeds_b:
        partners.setdefault(evk.comment, evk)
    diffs = [mne.combine_evoked([evk, partners[evk.comment]], [1, -1])
             for evk in evokeds_a if evk.comment in partners]
    if not diffs:
        raise ValueError('No event IDs shared between the two conditions.')
    return mne.combine_evoked(diffs, 'nave')


n_times = spkr_resp[subjs[0]]['spkr']['all'].shape[2]
spkr_evoked = np.zeros((len(subjs), len(picks), n_times))
mic_evoked = np.zeros_like(spkr_evoked)
el_evoked = np.zeros_like(spkr_evoked)
sh_evoked = np.zeros_like(spkr_evoked)
spkr_mic_dw = np.zeros_like(spkr_evoked)
el_sh_dw = np.zeros_like(spkr_evoked)
for i, s in enumerate(subjs):
    spkr_e = spkr_epochs[s]['spkr']['all'].average(picks=picks, by_event_type=True)
    mic_e = mic_epochs[s]['mic']['all'].average(picks=picks, by_event_type=True)
    el_e = spkr_epochs[s]['spkr']['el'].average(picks=picks, by_event_type=True)
    sh_e = spkr_epochs[s]['spkr']['sh'].average(picks=picks, by_event_type=True)
    spkr_evoked[i] = mne.combine_evoked(spkr_e, 'nave').get_data()
    mic_evoked[i] = mne.combine_evoked(mic_e, 'nave').get_data()
    el_evoked[i] = mne.combine_evoked(el_e, 'nave').get_data()
    sh_evoked[i] = mne.combine_evoked(sh_e, 'nave').get_data()
    # Index with this loop's `i`. The previous version wrote with a stale index,
    # so every subject overwrote the last row. It also matched el events against
    # the speaker event ID instead of the el event ID.
    spkr_mic_dw[i] = _paired_difference(spkr_e, mic_e).get_data()
    el_sh_dw[i] = _paired_difference(el_e, sh_e).get_data()

### Plotting

In [None]:
# Panel B: grand-average ERPs (top row) and difference waves (bottom row).
# NOTE(review): the shaded band is sem_evoked() applied to the subject-averaged
# (channels, samples) array, i.e. the spread across `picks`, not across
# subjects, and the axis label says "±σ" -- confirm this is intended.
# Later cells read the module-level x, fontsize, y1/y2 (+ *_error_epochs) and
# y/y_error_epochs left by this cell, so the loops below keep those names.
fontsize = 16
x = np.linspace(tmin, tmax, spkr_evoked.shape[2])
panel_xticks = [tmin, 0, 1., 2., 3.]


def _style_panel_axes(ax, yticks, ylim):
    '''Apply Panel B's shared x ticks/limits plus the given y ticks/limits.'''
    ax.set_xticks(panel_xticks)
    ax.set_xticklabels(panel_xticks, fontsize=fontsize)
    ax.set_xlim([tmin, tmax])
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticks, fontsize=fontsize)
    ax.set_ylim(ylim)


plt.figure(figsize=(10, 10))

# Top row: two conditions overlaid per panel (subplot slot, data, colours).
top_row = [
    (1, spkr_evoked, mic_evoked, perception_color, production_color),
    (2, el_evoked, sh_evoked, consistent_color, inconsistent_color),
]
for slot, evoked_a, evoked_b, color_a, color_b in top_row:
    plt.subplot(2, 2, slot)
    y1 = np.nanmean(evoked_a, axis=0) / 1e-6
    y1_error_epochs = sem_evoked(np.nanmean(evoked_a, axis=0))
    y2 = np.nanmean(evoked_b, axis=0) / 1e-6
    y2_error_epochs = sem_evoked(np.nanmean(evoked_b, axis=0))
    plt.plot(x, y1.mean(0), color=color_a)
    plt.fill_between(x, *y1_error_epochs, color=color_a, alpha=0.3)
    plt.plot(x, y2.mean(0), color=color_b)
    plt.fill_between(x, *y2_error_epochs, color=color_b, alpha=0.3)
    _style_panel_axes(plt.gca(), [-1, 0, 1], [-1, 1.25])
    plt.ylabel("µV±σ", fontsize=fontsize)

# Bottom row: one difference wave per panel (slot, data, y ticks, y limits).
bottom_row = [
    (3, spkr_mic_dw, [-.1, 0, .1], [-.1, .125]),
    (4, el_sh_dw, [-.5, 0, .5], [-.5, .625]),
]
for slot, diff_wave, yticks, ylim in bottom_row:
    plt.subplot(2, 2, slot)
    y = np.nanmean(diff_wave, axis=0) / 1e-6
    y_error_epochs = sem_evoked(np.nanmean(diff_wave, axis=0))
    plt.plot(x, y.mean(0), color='grey')
    plt.fill_between(x, *y_error_epochs, alpha=0.3, color='grey')
    _style_panel_axes(plt.gca(), yticks, ylim)

# Stand-alone legend figure built from invisible zero-height bars.
plt.figure(figsize=(3, 2))
legend_entries = [
    (perception_color, "Perception"),
    (production_color, "Production"),
    (consistent_color, "Consistent playback"),
    (inconsistent_color, "Inconsistent playback"),
    ('grey', "Difference wave"),
]
for entry_color, entry_label in legend_entries:
    plt.bar(0, 0, color=entry_color, label=entry_label)
plt.grid(False)
plt.gca().set_xticks([])
plt.gca().set_yticks([])
plt.legend(fontsize=fontsize, loc='center')
plt.tight_layout()

In [None]:
# N100/P200 windows -- perception vs. production
# Zoomed views of the perception (spkr) and production (mic) grand averages.
# The traces are computed here. The previous version reused y1/y2 left over
# from the Panel B cell, which by then hold the consistent/inconsistent
# playback traces. It also referenced undefined spkr_color/mic_color.
x_window = np.linspace(tmin, tmax, spkr_evoked.shape[2])
perc_uv = np.nanmean(spkr_evoked, axis=0) / 1e-6
perc_band = sem_evoked(np.nanmean(spkr_evoked, axis=0))
prod_uv = np.nanmean(mic_evoked, axis=0) / 1e-6
prod_band = sem_evoked(np.nanmean(mic_evoked, axis=0))

# (title, x-limits in s, tick positions in s, tick labels in ms)
erp_windows = [
    ("N100 window", [.08, .15], [.08, .1, .12, .14], [80, 100, 120, 140]),
    ("P200 window", [.15, .25], [.15, .18, .21, .24], [150, 180, 210, 240]),
]

fig = plt.figure(figsize=(5, 5))
gs = GridSpec(4, 4, figure=fig)
axes = [fig.add_subplot(gs[:, :2]), fig.add_subplot(gs[:, 2:])]
for ax, (title, xlim, xticks, xticklabels) in zip(axes, erp_windows):
    ax.plot(x_window, perc_uv.mean(0), color=perception_color)
    ax.fill_between(x_window, perc_band[0], perc_band[1], color=perception_color, alpha=0.3)
    ax.plot(x_window, prod_uv.mean(0), color=production_color)
    ax.fill_between(x_window, prod_band[0], prod_band[1], color=production_color, alpha=0.3)
    ax.set_xlim(xlim)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, fontsize=fontsize)
    ax.set_xlabel("Time (ms)", fontsize=fontsize)
    ax.set_yticks([-1, 0, 1])
    ax.set_ylim([-1, 1.25])
    ax.set_title(title, fontsize=fontsize)
# Only the left (N100) panel carries y tick labels and the y-axis label.
axes[0].set_yticklabels([-1, 0, 1], fontsize=fontsize)
axes[0].set_ylabel("µV±σ", fontsize=fontsize)
axes[1].set_yticklabels([])

gs.tight_layout(figure=fig)

In [None]:
# N100 / P200 windows - consistent vs. inconsistent playback
# Zoomed views of the consistent (el) and inconsistent (sh) playback grand
# averages. They are computed explicitly from el_evoked/sh_evoked instead of
# relying on y1/y2 happening to hold these traces after the Panel B cell ran.
x_window = np.linspace(tmin, tmax, el_evoked.shape[2])
cons_uv = np.nanmean(el_evoked, axis=0) / 1e-6
cons_band = sem_evoked(np.nanmean(el_evoked, axis=0))
incons_uv = np.nanmean(sh_evoked, axis=0) / 1e-6
incons_band = sem_evoked(np.nanmean(sh_evoked, axis=0))

# (title, x-limits in s, tick positions in s, tick labels in ms)
erp_windows = [
    ("N100 window", [.08, .15], [.08, .1, .12, .14], [80, 100, 120, 140]),
    ("P200 window", [.15, .25], [.15, .18, .21, .24], [150, 180, 210, 240]),
]

fig = plt.figure(figsize=(5, 5))
gs = GridSpec(4, 4, figure=fig)
axes = [fig.add_subplot(gs[:, :2]), fig.add_subplot(gs[:, 2:])]
for ax, (title, xlim, xticks, xticklabels) in zip(axes, erp_windows):
    ax.plot(x_window, cons_uv.mean(0), color=consistent_color)
    ax.fill_between(x_window, cons_band[0], cons_band[1], color=consistent_color, alpha=0.3)
    ax.plot(x_window, incons_uv.mean(0), color=inconsistent_color)
    ax.fill_between(x_window, incons_band[0], incons_band[1], color=inconsistent_color, alpha=0.3)
    ax.set_xlim(xlim)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, fontsize=fontsize)
    ax.set_xlabel("Time (ms)", fontsize=fontsize)
    ax.set_yticks([-1, 0, 1])
    ax.set_ylim([-1, 1.25])
    ax.set_title(title, fontsize=fontsize)
# Only the left (N100) panel carries y tick labels and the y-axis label.
axes[0].set_yticklabels([-1, 0, 1], fontsize=fontsize)
axes[0].set_ylabel("µV±σ", fontsize=fontsize)
axes[1].set_yticklabels([])

gs.tight_layout(figure=fig);