In [None]:
import csv
from glob import glob

import h5py
import mne
import numpy as np
import scipy
import seaborn
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
from matplotlib.gridspec import GridSpec
from tqdm.notebook import tqdm

plt.style.use('seaborn')
rc['pdf.fonttype'] = 42

In [None]:
# Change these paths to run the notebook locally.
# Both must end with a trailing slash: later cells build file names by direct
# string concatenation (e.g. f'{eeg_data_path}{subj}/...').
eeg_data_path = '/path/to/dataset/' # downloadable from OSF: https://doi.org/10.17605/OSF.IO/FNRD9
git_path  = '/path/to/git/speaker_induced_suppression_EEG/'

In [None]:
# Plot colours, one per trace type.
click_color = "#973DC2"
aux_color = "#D24E2D"
perception_color = '#117733'
production_color = '#332288'
# Channels handed to reject_epochs / Epochs.get_data as `picks` in the epoching cells.
picks = ['F1','Fz','F2','FC1','FCz','FC2','C1','Cz','C2']
# Epoch window relative to each event, passed to mne.Epochs (seconds).
tmin,tmax = -.2,.5
reject_thresh = 5 # SD
exclude = 'OP0020'
# A subject ID is the last six characters of each entry under eventfiles/.
# The 'OP0' / exclude filters run on that ID rather than on the whole path, so
# the contents of git_path itself cannot affect which subjects are selected.
subj_ids = [entry[-6:] for entry in glob(f'{git_path}eventfiles/*')]
subjs = np.sort([s for s in subj_ids if 'OP0' in s and exclude not in s])

### Load data

In [None]:
def interpolate_bads(subj,block,raw,git_path,eeg_data_path):
    '''
    CCA .vhdr files do not include bad channels, so this func interpolates them.
    The end result is each subject has a 64 channel raw array, which makes plotting easier.
    Filter the data before running this!

    The reference channel list is read from one fixed recording
    (OP0001, block B1, `*_cca.vhdr`) under `eeg_data_path`. Every reference
    channel missing from `raw` is inserted as zeros, marked bad, and then
    filled in with `raw.interpolate_bads()`.

    Parameters
    ----------
    subj : str
        Subject ID; only used in the printed status message.
    block : str
        Unused; kept so existing calls keep working.
    raw : mne.io.Raw
        Recording to complete.
    git_path : str
        Unused; kept so existing calls keep working.
    eeg_data_path : str
        Dataset root (with trailing slash) containing the reference recording.

    Returns
    -------
    mne.io.Raw
        `raw` itself when no reference channel is missing; otherwise a new
        RawArray built on the reference recording's `info`, holding only the
        reference channels, in reference order.
    '''
    ref_info = mne.io.read_raw_brainvision(f'{eeg_data_path}OP0001/OP0001_B1/OP0001_B1_cca.vhdr',
                                           preload=False,verbose=False).info
    ref_chs = ref_info['ch_names']
    present = set(raw.info['ch_names'])
    bads = [c for c in ref_chs if c not in present]
    if not bads:
        print(subj, 'has no bads.')
        return raw
    nsamps = len(raw)
    # Missing channels stay as zero rows until interpolate_bads() overwrites them.
    full_data = np.zeros((len(ref_chs), nsamps))
    for row, ch in enumerate(ref_chs):
        if ch in bads:
            print('Interpolating', ch)
        else:
            full_data[row] = raw.get_data(picks=ch)[0]
    raw = mne.io.RawArray(full_data,ref_info)
    raw.info['bads'] = bads
    raw.interpolate_bads()
    return raw
def reject_epochs(raw,events,tmin,tmax,reject_thresh=10,picks=None):
    """
    Compute a data-driven `reject` dict for mne.Epochs.

    `raw` is epoched around `events` without any rejection (baseline from the
    start of the epoch to time 0, repeated events dropped). The standard
    deviation is taken over all epochs, channels and samples of the picked
    channels at once, and scaled by ``2 * reject_thresh``.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous recording.
    events : array
        MNE events array.
    tmin, tmax : float
        Epoch window passed to mne.Epochs.
    reject_thresh : float
        Multiplier applied to the standard deviation (before the factor of 2).
    picks : list | array | None
        Channels used to estimate the standard deviation. None selects all EEG
        channels via mne.pick_types.

    Returns
    -------
    dict
        ``{'eeg': threshold}``, suitable for the `reject` argument of mne.Epochs.
    """
    unrejected = mne.Epochs(raw,events,tmin=tmin,tmax=tmax,reject=None,baseline=(None,0),verbose=False, event_repeated='drop')
    # Identity test: `picks == None` is an elementwise comparison when picks is
    # a NumPy index array, which cannot be used as an `if` condition.
    if picks is None:
        picks = mne.pick_types(raw.info,meg=False,eeg=True,eog=False)
    epoch_data = unrejected.get_data(picks=picks)
    return dict(eeg=np.std(epoch_data)*(reject_thresh)*2)
def butter_highpass(cutoff, fs, order=5):
    '''
    Design a digital Butterworth high-pass filter.

    `cutoff` is expressed in the same units as the sampling rate `fs`; it is
    normalised by the Nyquist frequency (fs / 2) before being handed to
    scipy.signal.butter. Returns the (b, a) coefficient arrays.
    '''
    nyquist = 0.5 * fs
    numer, denom = scipy.signal.butter(order, cutoff / nyquist, btype='high', analog=False)
    return numer, denom
def butter_highpass_filter(data, cutoff, fs, order=2):
    '''
    Zero-phase high-pass filter `data` along its first axis.

    Coefficients come from butter_highpass(cutoff, fs, order) and are applied
    forwards and backwards with scipy.signal.filtfilt (axis=0).
    '''
    numer, denom = butter_highpass(cutoff, fs, order=order)
    return scipy.signal.filtfilt(numer, denom, data, axis=0)

In [None]:
# Load raw data
raws,infos = dict(),dict()
block = 'B1'  # recording block; passed to interpolate_bads below (was previously undefined here)
for subj in tqdm(subjs):
    blockid = f'{subj}_{block}'
    # get raw
    raw_path = f'{eeg_data_path}{subj}/{blockid}/{blockid}_downsampled.vhdr'
    raws[subj] = mne.io.read_raw_brainvision(raw_path,preload=True,verbose=False)
    # re-reference to TP9/TP10, then notch at 60 and band-pass 1-30
    raws[subj].set_eeg_reference(['TP9','TP10'],verbose=False)
    raws[subj].notch_filter(60,verbose=False)
    raws[subj].filter(l_freq=1,h_freq=30,verbose=True)
    # find bad channels and interpolate them so that each subj has 64 chans
    raws[subj] = interpolate_bads(subj,block,raws[subj],git_path,eeg_data_path)
    infos[subj] = raws[subj].info

In [None]:
# Load CCA-corrected data
ccas = dict()
block = 'B1'  # recording block; passed to interpolate_bads below (was previously undefined here)
for subj in tqdm(subjs):
    blockid = f'{subj}_{block}'
    # get CCA recording
    cca_path = f'{eeg_data_path}{subj}/{blockid}/{blockid}_cca.vhdr'
    ccas[subj] = mne.io.read_raw_brainvision(cca_path,preload=True,verbose=False)
    ccas[subj].filter(l_freq=1,h_freq=30,verbose=False)
    # find bad channels and interpolate them so that each subj has 64 chans
    ccas[subj] = interpolate_bads(subj,block,ccas[subj],git_path,eeg_data_path)

### Epoch data

In [None]:
# EMG epochs
raw_emg,cca_emg = dict(),dict()
emg_resp = dict()
emg_events = dict()
reject_thresh = 5 # SD multiplier handed to reject_epochs
for subj in tqdm(subjs):
    print(subj)
    # Events are found by MNE's EOG event finder on the raw recording's 'hEOG' channel;
    # the same events are then used to epoch both the raw and the CCA recording.
    subj_events = mne.preprocessing.find_eog_events(raws[subj],ch_name='hEOG',verbose=False,l_freq=1,h_freq=30)
    emg_events[subj] = subj_events
    epoch_opts = dict(tmin=tmin,tmax=tmax,verbose=True,reject_by_annotation=False,proj=False,flat=None)

    # raw
    reject = reject_epochs(raws[subj],subj_events,tmin,tmax,reject_thresh=reject_thresh,picks=picks)
    epochs = mne.Epochs(raws[subj],subj_events,reject=reject,**epoch_opts)
    raw_emg[subj] = epochs.get_data(picks=picks)
    emg_resp[subj] = epochs.get_data(picks='hEOG')

    # cca
    reject = reject_epochs(ccas[subj],subj_events,tmin,tmax,reject_thresh=reject_thresh,picks=picks)
    epochs = mne.Epochs(ccas[subj],subj_events,reject=reject,**epoch_opts)
    cca_trials = epochs.get_data(picks=picks)
    if subj == 'OP0016': # cut off last epoch
        cca_trials = cca_trials[:-1]
    cca_emg[subj] = cca_trials

In [None]:
# Click epochs
raw_click, cca_click = dict(),dict()
click_events = dict()
eeg_sfreq = 128      # factor converting event-file times to EEG samples (presumably the EEG sampling rate in Hz)
audio_sfreq = 25000  # factor converting event-file times to audio samples (presumably the audio sampling rate in Hz)
for subj in tqdm(subjs):
    click_event_file = f'{git_path}eventfiles/{subj}/{subj}_B1/{subj}_B1_click_eve.txt'
    # ndmin=2 keeps the column indexing below valid for a single-row event file
    evs = np.loadtxt(click_event_file,dtype='f',usecols=(0,1,2),ndmin=2)
    # Round to the nearest sample before the integer cast (a bare cast truncates),
    # matching how the sentence events are converted elsewhere in this notebook.
    evs[:,:2] = np.round(evs[:,:2]*eeg_sfreq)
    # builtin int: the np.int alias was removed in NumPy 1.24
    click_events[subj] = evs.astype(int)
    # raw
    reject = reject_epochs(raws[subj],click_events[subj],tmin,tmax,reject_thresh=reject_thresh,picks=picks)
    epochs = mne.Epochs(raws[subj],click_events[subj],tmin=tmin,tmax=tmax,reject=reject,verbose=False)
    raw_click[subj] = epochs.get_data(picks=picks)
    # cca
    reject = reject_epochs(ccas[subj],click_events[subj],tmin,tmax,reject_thresh=reject_thresh,picks=picks)
    epochs = mne.Epochs(ccas[subj],click_events[subj],tmin=tmin,tmax=tmax,reject=reject,verbose=False)
    cca_click[subj] = epochs.get_data(picks=picks)

# Load click audio (for plotting)
# Re-express one subject's click events in audio samples so the audio recording can be epoched.
aud_onsets = np.array((click_events['OP0007'][:,0]/eeg_sfreq)*audio_sfreq).astype(int)
aud_offsets = np.array((click_events['OP0007'][:,1]/eeg_sfreq)*audio_sfreq).astype(int)
click_events_25k = np.array((aud_onsets,aud_offsets,click_events['OP0007'][:,2])).T
audio_fpath = f'{eeg_data_path}OP0007/OP0007_B1/OP0007_B1_audio.vhdr'
raw_audio = mne.io.read_raw_brainvision(audio_fpath, preload=True, verbose=False)
click_epochs = mne.Epochs(raw_audio,click_events_25k,tmin=tmin,tmax=tmax,reject=None,
                         baseline=(None,0),verbose=False)
click_resp = click_epochs.get_data(picks=['spkr','mic'])
# Transpose so time is axis 0 (the axis butter_highpass_filter filters along), then transpose back.
click_resp = butter_highpass_filter(click_resp.T,10,raw_audio.info['sfreq']).T

In [None]:
# Perception and production epochs
def _load_sentence_events(event_file, sfreq=128):
    """Read the first three tab-separated columns of `event_file` as an integer events array.

    The first two columns are multiplied by `sfreq` and rounded to the nearest
    sample; the third column is cast to int unchanged.
    """
    with open(event_file,'r') as f:
        rows = [row[:3] for row in csv.reader(f,delimiter='\t')]
    # builtin float/int: the np.float / np.int aliases were removed in NumPy 1.24
    event_samples = np.array(rows,dtype=float)
    event_samples[:,:2] = np.round(event_samples[:,:2]*sfreq)
    return event_samples.astype(int)

def _epoch_picks(raw, events):
    """Epoch `raw` around `events` with the reject_epochs threshold; return data for `picks`."""
    reject = reject_epochs(raw,events,tmin,tmax,reject_thresh=reject_thresh,picks=picks)
    epochs = mne.Epochs(raw,events,tmin=tmin,tmax=tmax,reject=reject,verbose=False)
    return epochs.get_data(picks=picks)

raw_spkr, cca_spkr = dict(), dict()
raw_mic, cca_mic = dict(),dict()
spkr_events,mic_events = dict(),dict()
for subj in tqdm(subjs):
    # Perc
    spkr_event_file = f"{git_path}eventfiles/{subj}/{subj}_B1/{subj}_B1_spkr_sn_all.txt"
    spkr_events[subj] = _load_sentence_events(spkr_event_file)
    raw_spkr[subj] = _epoch_picks(raws[subj], spkr_events[subj])
    cca_spkr[subj] = _epoch_picks(ccas[subj], spkr_events[subj])
    # Prod
    mic_event_file = f"{git_path}eventfiles/{subj}/{subj}_B1/{subj}_B1_mic_sn_all.txt"
    mic_events[subj] = _load_sentence_events(mic_event_file)
    raw_mic[subj] = _epoch_picks(raws[subj], mic_events[subj])
    cca_mic[subj] = _epoch_picks(ccas[subj], mic_events[subj])

# Load sentence audio (for plotting)
# Re-express one subject's production events in audio samples (128 -> 25000 rescaling).
aud_onsets = np.array((mic_events['OP0007'][:,0]/128)*25000).astype(int)
aud_offsets = np.array((mic_events['OP0007'][:,1]/128)*25000).astype(int)
sen_events_25k = np.array((aud_onsets,aud_offsets,mic_events['OP0007'][:,2])).T
# eeg_data_path is the dataset root defined in the path cell (this line used an undefined `data_path`)
audio_fpath = f'{eeg_data_path}OP0007/OP0007_B1/OP0007_B1_audio.vhdr'
raw_audio = mne.io.read_raw_brainvision(audio_fpath, preload=True, verbose=False)
sen_epochs = mne.Epochs(raw_audio,sen_events_25k,tmin=tmin,tmax=tmax,reject=None,
                         baseline=(None,0),verbose=False)
sen_resp = sen_epochs.get_data(picks=['mic'])
# Transpose so time is axis 0 (the axis butter_highpass_filter filters along), then transpose back.
sen_resp = butter_highpass_filter(sen_resp.T,10,raw_audio.info['sfreq']).T

## Plot

In [None]:
fontsize=16
# The sentence traces were coloured with mic_color / spkr_color, which no cell defines.
# They map onto the production / perception colours from the colour cell.
mic_color = production_color
spkr_color = perception_color
# NOTE(review): `gplt` (used for gplt.errorbars below) is not imported anywhere in the
# visible notebook; it must be brought into scope before this cell can run.

def _waveform_panel(cell, xs, ys, color, title):
    """Add a tick-less, white-background panel showing one example waveform."""
    axis = fig.add_subplot(cell)
    axis.plot(xs, ys, color=color)
    axis.grid(False)
    axis.set_xticks([])
    axis.set_yticks([])
    axis.set_facecolor('w')
    axis.set_title(title, fontsize=fontsize, loc='left', color=color)
    return axis

def _plot_erp(epoch_dict, color):
    """Pool every subject's epochs, average over channels, rescale by 1e6 and draw on the current axes."""
    y = np.vstack(list(epoch_dict.values())).mean(1)/1e-6
    gplt.errorbars(x,y,color=color)

fig = plt.figure(figsize=(9,11))
gs = GridSpec(16, 8, figure=fig)
ax = []
x = np.linspace(tmin,tmax,raw_emg['OP0007'].shape[2])
sn_x = np.linspace(tmin,tmax,sen_resp.shape[2])
click_x = np.linspace(tmin,tmax,click_resp.shape[2])
s = 'OP0007'

# Top rows: one example waveform per condition (epoch 12 of the sentence audio,
# the mean hEOG epoch of subject `s`, and the first click epoch).
ax.append(_waveform_panel(gs[0:2,:4], sn_x, sen_resp[12][0], mic_color, "Production"))
ax.append(_waveform_panel(gs[2:4,:4], x, emg_resp[s].mean(0)[0], aux_color, "EMG Peak"))
ax.append(_waveform_panel(gs[0:2,4:], sn_x, sen_resp[12][0], spkr_color, "Perception"))
ax.append(_waveform_panel(gs[2:4,4:], click_x, click_resp[0][0], click_color, "Click"))

# Raw data, left: EMG peak vs production
ax.append(fig.add_subplot(gs[4:10,:4]))
plt.ylabel("µV±σ",fontsize=fontsize)
_plot_erp(raw_emg, aux_color)
_plot_erp(raw_mic, mic_color)
plt.gca().set_xlim([tmin,tmax])
plt.xticks([tmin,0,0.2,0.5,tmax],fontsize=fontsize)
plt.yticks([-20,0,20,40,60],fontsize=fontsize)
plt.gca().set_ylim([-20,60])
plt.gca().set_xticklabels([])
plt.title(" ", fontsize=fontsize)

# Raw data, right: click vs perception
ax.append(fig.add_subplot(gs[4:10,4:]))
_plot_erp(raw_click, click_color)
plt.gca().set_xlim([tmin,tmax])
_plot_erp(raw_spkr, spkr_color)
plt.xticks([tmin,0,0.2,0.5,tmax],fontsize=fontsize)
plt.yticks([-20,0,20,40,60],fontsize=fontsize)
plt.gca().set_ylim([-20,60])
plt.gca().set_yticklabels([])
plt.gca().set_xticklabels([])

# CCA data, left: EMG peak vs production
ax.append(fig.add_subplot(gs[10:,:4]))
_plot_erp(cca_emg, aux_color)
_plot_erp(cca_mic, mic_color)
plt.gca().set_xlim([tmin,tmax])
plt.xticks([tmin,0,0.2,0.5,tmax],fontsize=fontsize)
plt.gca().set_ylim([-1.5,3.5])
plt.yticks([-1,0,1,2,3],fontsize=fontsize)
plt.ylabel("µV±σ",fontsize=fontsize)
plt.xlabel("Time (s)", fontsize=fontsize)
plt.title(" ", fontsize=fontsize)

# CCA data, right: click vs perception
ax.append(fig.add_subplot(gs[10:,4:]))
_plot_erp(cca_click, click_color)
_plot_erp(cca_spkr, spkr_color)
plt.gca().set_xlim([tmin,tmax])
plt.xticks([tmin,0,0.2,0.5,tmax],fontsize=fontsize)
plt.gca().set_ylim([-1.5,3.5])
plt.yticks([-1,0,1,2,3],fontsize=fontsize)
plt.gca().set_yticklabels([])
plt.xlabel("Time (s)", fontsize=fontsize)

# Row captions, placed in figure coordinates between the left and right columns
plt.gcf().text(0.53,0.383,"CCA data", fontsize=fontsize, ha='center')
plt.gcf().text(0.53,0.733,"Raw data", fontsize=fontsize, ha='center')

gs.tight_layout(figure=fig)

In [None]:
# Stand-alone legend: zero-height bars exist only to create one legend handle per colour.
# `colors` and `labels` were not defined anywhere in the notebook; they are rebuilt here
# from the palette and the panel titles used in the main figure.
# NOTE(review): confirm the intended legend order and wording.
colors = [production_color, aux_color, perception_color, click_color]
labels = ['Production', 'EMG Peak', 'Perception', 'Click']
plt.figure(figsize=(3,2))
for c, label in zip(colors, labels):
    plt.bar(0,0,color=c,label=label)
plt.grid(False)
plt.gca().set_xticks([])
plt.gca().set_yticks([])
plt.legend(fontsize=fontsize,loc='center');
plt.tight_layout()