In [None]:
# Paths - Update locally, or set the KURTEFF2024_GIT_PATH / KURTEFF2024_DATA_PATH environment variables
import os
# Root of this code repository (event files, figure outputs) and of the BIDS dataset (.fif data)
git_path = os.environ.get('KURTEFF2024_GIT_PATH', '/path/to/git/kurteff2024_code/')
data_path = os.environ.get('KURTEFF2024_DATA_PATH', '/path/to/bids/dataset/')

In [None]:
import mne
import numpy as np
import pandas as pd
import os
import re
import csv
from tqdm.notebook import tqdm
import warnings
from img_pipe import img_pipe
import librosa

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
rc['pdf.fonttype'] = 42
plt.style.use('seaborn')
%matplotlib inline

import sys
sys.path.append(os.path.join(git_path,"figures"))
import plotting_utils
sys.path.append(os.path.join(git_path,"preprocessing","events","textgrids"))
import textgrid
sys.path.append(os.path.join(git_path,"preprocessing","imaging"))
import imaging_utils

In [None]:
subjs = [s for s in os.listdir(
    os.path.join(git_path,"preprocessing","events","csv")) if "TCH" in s or "S0" in s]
exclude = ["TCH8"]
no_imaging = ["S0010"]  # subjects without an img_pipe reconstruction; skipped when reading electrode positions
subjs = [s for s in subjs if s not in exclude]

def _blocks_with_eventfile(subj, suffix):
    '''
    Block ids (e.g. "B1") of `subj` whose block directory contains f"{subj}_{block}_{suffix}".
    Searches preprocessing/events/csv, the same tree that the event files are later read from.
    '''
    subj_dir = os.path.join(git_path,"preprocessing","events","csv",subj)
    return [b.split("_")[-1] for b in os.listdir(subj_dir)
            if f"{subj}_B" in b and os.path.isfile(os.path.join(subj_dir,b,f"{b}_{suffix}"))]

# Blocks with perception/production events, and blocks with speech motor control (SMC) events
blocks = {s: _blocks_with_eventfile(s, "spkr_sn_all.txt") for s in subjs}
smc_blocks = {s: _blocks_with_eventfile(s, "smc_mic.txt") for s in subjs}

# Hemisphere(s) with electrodes: 'rh' if any electrode has x > 0, 'lh' if any has x < 0
hems = {s:[] for s in subjs}
for s in subjs:
    if s in no_imaging:
        continue  # no reconstruction to read; hemisphere list stays empty
    pt = img_pipe.freeCoG(f"{s}_complete",hem='stereo',subj_dir=data_path)
    elecs = pt.get_elecs()['elecmatrix']
    if np.any(elecs[:,0] > 0):
        hems[s].append('rh')
    if np.any(elecs[:,0] < 0):
        hems[s].append('lh')

color_palette = pd.read_csv(os.path.join(git_path,"figures","color_palette.csv"))
def _palette_hex(color_id):
    '''Hex color listed for `color_id` in figures/color_palette.csv.'''
    return color_palette.loc[color_palette['color_id']==color_id,'hex'].values[0]
spkr_color = _palette_hex('perception')
mic_color = _palette_hex('production')
click_color = _palette_hex('click')

In [None]:
def epoch_data(subj,blocks,git_path,data_path,channel='spkr',level='sn',condition='all',
               tmin=-.5, tmax=2, baseline=None, click=False):
    '''
    Epoch one subject's HilbAA_70to150_8band data around task events, concatenated across blocks.

    subj: subject id, e.g. "S0018"
    blocks: block ids, e.g. ["B1", "B2"]; block b is read from directory f"{subj}_{b}"
    git_path: repository root; events are read from preprocessing/events/csv/<subj>/<subj>_<block>/
    data_path: dataset root; data is read from sub-<subj>/<subj>_<block>/HilbAA_70to150_8band/ecog_hilbAA70to150.fif
    channel, level, condition: pick the event file f"{subj}_{block}_{channel}_{level}_{condition}.txt"
        (ignored when click=True)
    tmin, tmax, baseline: passed to mne.Epochs
    click: if True, use f"{subj}_{block}_click_eve.txt" as the event file instead
    returns: mne.concatenate_epochs of the per-block mne.Epochs
    raises: ValueError if `blocks` is empty
    '''
    if len(blocks) == 0:
        raise ValueError(f"No blocks to epoch for {subj}.")
    # Column positions of (onset time, offset time, event id) in the tab-separated event file
    if click:
        onset_index, offset_index, id_index = 0,2,4
    else:
        onset_index, offset_index, id_index = 0,1,2
    epochs = []
    for b in blocks:
        blockid = f'{subj}_{b}'
        raw_fpath = os.path.join(
            data_path,f"sub-{subj}",blockid,"HilbAA_70to150_8band","ecog_hilbAA70to150.fif")
        event_dir = os.path.join(git_path,"preprocessing","events","csv",subj,blockid)
        if click:
            eventfile = os.path.join(event_dir,f"{blockid}_click_eve.txt")
        else:
            eventfile = os.path.join(event_dir,f"{blockid}_{channel}_{level}_{condition}.txt")
        raw = mne.io.read_raw_fif(raw_fpath,preload=True,verbose=False)
        fs = raw.info['sfreq']
        with open(eventfile,'r',newline='') as f:
            rows = [row for row in csv.reader(f,delimiter='\t') if row]  # skip blank lines
        # Times (s) -> sample indices, rounded up; third column is the integer event id
        events = np.array([[int(np.ceil(float(row[onset_index])*fs)),
                            int(np.ceil(float(row[offset_index])*fs)),
                            int(row[id_index])] for row in rows], dtype=int)
        epochs.append(mne.Epochs(raw,events,tmin=tmin,tmax=tmax,baseline=baseline,preload=True,verbose=False))
    return mne.concatenate_epochs(epochs,verbose=False)

### Individual electrode plots

In [None]:
def sem(epochs, ddof=0):
    '''
    Band of +/- one standard error of the mean (SEM) around the mean across epochs.

    epochs: array of shape (epochs, samples)
    ddof: delta degrees of freedom for the standard deviation. The default 0 keeps the
        original population-std behavior; ddof=1 gives the conventional sample SEM
        (as in scipy.stats.sem).
    returns: (mean - SEM, mean + SEM), each of shape (samples,)
    '''
    epochs = np.asarray(epochs)
    mean = epochs.mean(0)
    margin = epochs.std(0, ddof=ddof) / np.sqrt(epochs.shape[0])
    return mean - margin, mean + margin

In [None]:
subj, elec = "S0018", "PST-PI'5" # Update accordingly
tmin, tmax = -0.5, 1.0
# Epoch the electrode to perception (spkr) and production (mic) events without MNE warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    spkr_epochs = epoch_data(subj, blocks[subj], git_path, data_path, channel='spkr', tmin=tmin, tmax=tmax)
    mic_epochs = epoch_data(subj, blocks[subj], git_path, data_path, channel='mic', tmin=tmin, tmax=tmax)
ch_names = spkr_epochs.info['ch_names']
ch_idx = ch_names.index(elec)  # raises ValueError early if `elec` is not a recorded channel
# Anatomical label (fs_roi) for the title, when the subject has an img_pipe reconstruction
if not os.path.isdir(os.path.join(data_path, f"{subj}_complete")):
    fs_roi = "anatomy unknown"
else:
    patient = img_pipe.freeCoG(f"{subj}_complete", hem='stereo', subj_dir=data_path)
    anat = patient.get_elecs()['anatomy']
    anat_idx = [entry[0][0] for entry in anat].index(elec)
    fs_roi = anat[anat_idx][3][0]
x = spkr_epochs.times
fig, ax = plt.subplots(figsize=(5,5))
# Mean +/- SEM across trials, perception first, then production
spkr_y = spkr_epochs.get_data(picks=[elec]).squeeze()
spkr_y_below, spkr_y_above = sem(spkr_y)
ax.plot(x, spkr_y.mean(0), color=spkr_color)
ax.fill_between(x, spkr_y_below, spkr_y_above, color=spkr_color, alpha=0.3)
mic_y = mic_epochs.get_data(picks=[elec]).squeeze()
mic_y_below, mic_y_above = sem(mic_y)
ax.plot(x, mic_y.mean(0), color=mic_color)
ax.fill_between(x, mic_y_below, mic_y_above, color=mic_color, alpha=0.3)
# Title, event-onset line at t=0, and x ticks every 0.5 s spanning the epoch
ax.set_title(f"{subj} {elec} {fs_roi}", fontsize=14)
ax.axvline(0, color='k')
xlims = [x[0], x[-1]]
ax.set_xlim(xlims)
xticks = np.arange(xlims[0], xlims[-1] + .5, .5).round(1)
ax.set_xticks(xticks)
ax.set_xticklabels(xticks, fontsize=12)
fig.savefig(os.path.join(git_path, "figures", "figure_3", "pdf", f"{subj}_{elec}_dual_onset.pdf"))

### Colorbar

In [None]:
# Create legends
def _white_to(rgb, name):
    '''Colormap ramping from white at 0.0 to `rgb` at 0.7, then constant `rgb` up to 1.0.'''
    return LinearSegmentedColormap.from_list(name, (
        (0.000, (1.00, 1.000, 1.000)),
        (0.700, rgb),
        (1.000, rgb)))
mic_rgb, spkr_rgb = (0.584, 0.286, 0.592), (0.067, 0.463, 0.196)
mic_cmap = _white_to(mic_rgb, 'mic_gradient')
spkr_cmap = _white_to(spkr_rgb, 'spkr_gradient')
# Diverging map: mic_rgb at the low end, white through the middle, spkr_rgb at the high end
dual_onset_cmap = LinearSegmentedColormap.from_list('dual_onset_gradient', (
    # Edit this gradient at https://eltos.github.io/gradient/#0:954997-15:954997-40:FFFFFF-50:FFFFFF-60:FFFFFF-85:117632-100:117632
    (0.000, mic_rgb),
    (0.150, mic_rgb),
    (0.400, (1.000, 1.000, 1.000)),
    (0.500, (1.000, 1.000, 1.000)),
    (0.600, (1.000, 1.000, 1.000)),
    (0.850, spkr_rgb),
    (1.000, spkr_rgb)))
no_onset_cmap = plt.cm.Greys  # `cm` is never imported in this notebook; use pyplot's plt.cm

# Draw each colormap as a horizontal bar (3 identical rows of 0..99), stacked in one figure
legend_gradient = np.tile(np.arange(100), (3, 1))
plt.figure(figsize=(8,2))
for i, cmap in enumerate((mic_cmap, spkr_cmap, dual_onset_cmap), start=1):
    plt.subplot(3,1,i)
    plt.imshow(legend_gradient,aspect='auto',cmap=cmap)
    plt.axis('off')
plt.savefig(os.path.join(git_path,"figures","figure_3","pdf","onset_categorical_legend.pdf"))

### Individual electrode plots (with speech motor control rasters)

In [None]:
# Speech motor task labels, one per line. np.loadtxt(delimiter='\n') is rejected by NumPy >= 1.23,
# so the file is read line by line (blank lines skipped)
tasks_fpath = os.path.join(git_path,"task","speechmotor","SupplementalFiles","tasks.txt")
with open(tasks_fpath) as f:
    labels_ext = [line.strip() for line in f if line.strip()]
# The first 9 lines define the label set; every line is mapped to its index in that set
labels = labels_ext[:9]
labels_ext = np.array([labels.index(l) for l in labels_ext])
event_ids =  np.arange(9)

In [None]:
subj, elec = "TCH06", "RPILG5"
# This cell reads one raw file, so the subject must have exactly one SMC block
if len(smc_blocks[subj]) > 1:
    raise Exception("This code block doesn't support block concatenation.")
if len(smc_blocks[subj]) == 0:
    raise Exception(f"No speech motor control (smc_mic) block found for {subj}.")
block = smc_blocks[subj][0]
smc_tmin, smc_tmax = -2, 3.
blockid = "_".join([subj,block])
raw_fpath = os.path.join(
    data_path,f"sub-{subj}",blockid,"HilbAA_70to150_8band","ecog_hilbAA70to150.fif")
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    raw = mne.io.read_raw_fif(raw_fpath, preload=True, verbose=False)
    resp = raw.get_data(picks=[elec]).squeeze(); fs = raw.info['sfreq']
    ch_names = raw.info['ch_names']; elec_idx = ch_names.index(elec)
smc_event_dir = os.path.join(git_path,"preprocessing","events","csv",subj,blockid)

def _load_smc_events(fpath):
    '''
    Read a whitespace-delimited event file into an int array of shape (n_events, 3):
    column 0 and column 1 multiplied by fs (truncated to samples), and column 2 as an integer id.
    '''
    ev = np.loadtxt(fpath, ndmin=2)  # ndmin=2 keeps a one-event file two-dimensional
    return np.vstack(((ev[:,0]*fs).astype(int),(ev[:,1]*fs).astype(int),ev[:,2].astype(int))).T

# Get SMC events
# NOTE(review): epoch_data(click=True) reads this same click file with the offset in column 2
# and the id in column 4, whereas here column 2 is treated as the id -- confirm the file layout.
evs = _load_smc_events(os.path.join(smc_event_dir,f"{blockid}_click_eve.txt"))
# Regroup click events by id 0-8 (file order kept within each id); other ids are dropped
evs = np.concatenate([evs[evs[:,2]==i] for i in range(9)])
voc_evs = _load_smc_events(os.path.join(smc_event_dir,f"{blockid}_smc_mic.txt"))
# The last len(voc_evs) regrouped click events are treated as the trials with a vocalization
n_split = max(evs.shape[0]-voc_evs.shape[0], 0)
evs_no_voc = evs[:n_split]; evs_yes_voc = evs[n_split:]
evs_yes_voc = evs_yes_voc[np.argsort(evs_yes_voc[:,0])]

# Match each vocalized click to the first vocalization onset at or after it
matched_idxs = []; latencies = []
for ev in evs_yes_voc:
    all_latencies = voc_evs[:,0]-ev[0]
    candidates = np.flatnonzero(all_latencies >= 0)
    if candidates.size == 0:
        raise ValueError(f"No vocalization onset at or after the click at sample {ev[0]}.")
    min_idx = candidates[np.argmin(all_latencies[candidates])]
    matched_idxs.append(min_idx); latencies.append(all_latencies[min_idx])
# Order trials from longest to shortest click-to-vocalization latency. Click rows are
# reordered by their own positions and vocalization rows by their matched indices, so
# row k of evs_yes_voc, voc_evs and sorted_latencies all describe the same trial.
order = np.argsort(latencies)[::-1]
sorted_latencies = np.array(latencies)[order]
sorted_idxs = list(np.array(matched_idxs, dtype=int)[order])
evs_yes_voc = evs_yes_voc[order]; voc_evs = voc_evs[sorted_idxs]

# Cut windows [smc_tmin, smc_tmax) (in samples) around each event's first column
start_offset, stop_offset = int(smc_tmin*fs), int(smc_tmax*fs)
def _epoch_smc(events_arr):
    '''Stack resp[onset+start_offset : onset+stop_offset] per event -> (n_events, stop_offset-start_offset).'''
    out = np.zeros((events_arr.shape[0], stop_offset-start_offset))
    for i, ev in enumerate(events_arr):
        out[i,:] = resp[ev[0]+start_offset:ev[0]+stop_offset]
    return out
epochs_no_voc = _epoch_smc(evs_no_voc)
epochs_voc_click = _epoch_smc(evs_yes_voc)
epochs_voc_onset = _epoch_smc(voc_evs)

In [None]:
ymin, ymax = -0.75,0.95; vmax = 1.5
raster_cmap = sns.color_palette("light:gray",as_cmap=True)  # shared by the three single-trial rasters
fig = plt.figure(figsize=(6,8));
# Click ERP waveform
plt.subplot(4,1,1)
plt.title(f"{subj} {elec}", fontsize=14)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    spkr_epochs = epoch_data(subj,blocks[subj],git_path,data_path,channel='spkr',tmin=smc_tmin,tmax=smc_tmax)
    mic_epochs = epoch_data(subj,blocks[subj],git_path,data_path,channel='mic',tmin=smc_tmin,tmax=smc_tmax)
ch_names = spkr_epochs.info['ch_names']; x = spkr_epochs.times
# Index of t=0 in x; the raster epochs also start at smc_tmin, so it is used as their t=0 column too
t0 = np.where(x==0)[0][0]
spkr_y = spkr_epochs.get_data(picks=[elec]).squeeze(); spkr_y_below, spkr_y_above = sem(spkr_y)
plt.plot(x,spkr_y.mean(0),color=spkr_color)
plt.fill_between(x,spkr_y_below,spkr_y_above,color=spkr_color,alpha=0.3)
mic_y = mic_epochs.get_data(picks=[elec]).squeeze(); mic_y_below, mic_y_above = sem(mic_y)
plt.plot(x,mic_y.mean(0),color=mic_color)
plt.fill_between(x,mic_y_below,mic_y_above,color=mic_color,alpha=0.3)
plt.axvline(0,color=click_color)
plt.gca().set_xticklabels([]);
xlims = [x[0], x[-1]]; plt.gca().set_xlim(xlims); plt.gca().set_ylim([ymin,ymax])
plt.ylabel("ERP\n(Hγ±Z)")
# SMC w/o vocalization raster
plt.subplot(4,1,2)
plt.imshow(epochs_no_voc,aspect='auto',interpolation='nearest',
           cmap=raster_cmap,vmin=-vmax,vmax=vmax); plt.grid(False)
plt.axvline(t0,color=click_color)
plt.gca().set_xticklabels([]); plt.gca().set_yticklabels([])
plt.ylabel("Single-trial raster\n(No voc.)")
# SMC w/ vocalization raster (epoched to click onset); red ticks mark each trial's vocalization onset
plt.subplot(4,1,3)
plt.imshow(epochs_voc_click,aspect='auto',interpolation='nearest',
           cmap=raster_cmap,vmin=-vmax,vmax=vmax); plt.grid(False)
for y,l in enumerate(sorted_latencies):
    plt.plot(t0+l,y,marker='|',color='r',ms=20,mec='r',mew=2)
plt.axvline(t0,color=click_color)
plt.gca().set_xticklabels([]); plt.gca().set_yticklabels([])
plt.ylabel("Single-trial raster\n(Voc., t=0 is click)")
# SMC w/ vocalization raster (epoched to vocalization onset); ticks mark each trial's click
plt.subplot(4,1,4)
plt.imshow(epochs_voc_onset,aspect='auto',interpolation='nearest',
           cmap=raster_cmap,vmin=-vmax,vmax=vmax); plt.grid(False)
for y,l in enumerate(sorted_latencies):
    plt.plot(t0-l,y,marker='|',color=click_color,ms=20,mec=click_color,mew=2)
plt.axvline(t0,color='r')
plt.gca().set_yticklabels([]);
plt.gca().set_xticks([0,np.where(x==-1)[0][0],t0,np.where(x==1)[0][0],
                      np.where(x==2)[0][0],len(x)])
plt.gca().set_xticklabels([smc_tmin,-1,0,1,2,smc_tmax]);
plt.ylabel("Single-trial raster\n(Voc., t=0 is voc. onset)"); plt.xlabel("Time (sec)"); plt.tight_layout();
plt.savefig(os.path.join(git_path,"figures","figure_3","pdf",f"{subj}_{elec}_smc_imshow.pdf"))

### Speech motor colorbar

In [None]:
# SMC legend: the grayscale raster colormap drawn as a standalone horizontal bar (3 rows of 0..99)
plt.figure(figsize=(12,2))
plt.imshow(np.tile(np.arange(100), (3, 1)), aspect='auto',
           cmap=sns.color_palette("light:gray", as_cmap=True))
plt.axis('off')
plt.savefig(os.path.join(git_path,"figures","figure_3","pdf","legend_smc.pdf"))

### Onset quantification schematic

In [None]:
subj, elec = "S0007", "PSF-PI3"; thresh_tmin, thresh_tmax = 0.5, 1.5
onset_tmin, onset_tmax = 0., 0.3  # the onset response must first cross threshold in [onset_tmin, onset_tmax) s
ymin, ymax = -0.75, 0.95  # same y-limits as the SMC figure; set here so this cell runs on its own
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    spkr_epochs = epoch_data(subj,blocks[subj],git_path,data_path,channel='spkr')
    mic_epochs = epoch_data(subj,blocks[subj],git_path,data_path,channel='mic')
ch_names = spkr_epochs.info['ch_names']; ch_idx = ch_names.index(elec)
# Load anat from the img_pipe subject directory under data_path
if os.path.isdir(os.path.join(data_path,f"{subj}_complete")):
    patient = img_pipe.freeCoG(f"{subj}_complete",hem='stereo',subj_dir=data_path)
    anat = patient.get_elecs()['anatomy']; anat_idx = [a[0][0] for a in anat].index(elec)
    fs_roi = anat[anat_idx][3][0]
else:
    fs_roi = "anatomy unknown"
x = spkr_epochs.times; fig = plt.figure(figsize=(7,3))
spkr_y = spkr_epochs.get_data(picks=[elec]).squeeze(); mic_y = mic_epochs.get_data(picks=[elec]).squeeze()
xlims = [x[0], x[-1]]
xticks = np.round(np.arange(xlims[0],xlims[-1]+.5,.5),decimals=1)

def _time_idx(t):
    '''Index of the sample in `x` closest to time `t` (s); avoids exact float equality on times.'''
    return int(np.argmin(np.abs(x - t)))

# Calculate threshold: mean over all trials and samples from thresh_tmin up to (not incl.) thresh_tmax
spkr_onset_threshold = spkr_y[:,_time_idx(thresh_tmin):_time_idx(thresh_tmax)].mean()
mic_onset_threshold = mic_y[:,_time_idx(thresh_tmin):_time_idx(thresh_tmax)].mean()
idx0, idx300 = _time_idx(onset_tmin), _time_idx(onset_tmax)

def _find_onset(mean_resp, threshold):
    '''
    Locate the onset response in a trial-averaged trace.
    Returns (enter, exit): `enter` is the first supra-threshold sample in [idx0, idx300) and
    `exit` is the last sample of the supra-threshold run that starts at `enter`.
    Returns None when no sample crosses in the window, or when the run is a single
    sample (the [enter:exit] slice used for the peak and shading would be empty).
    '''
    above = mean_resp > threshold
    in_window = np.flatnonzero(above[idx0:idx300])
    if in_window.size == 0:
        return None
    enter = idx0 + int(in_window[0])
    below_after = np.flatnonzero(~above[enter:])
    exit_ = enter + int(below_after[0]) - 1 if below_after.size else len(mean_resp) - 1
    return None if exit_ == enter else (enter, exit_)

def _plot_onset_panel(y, color, title, threshold):
    '''Plot mean +/- SEM of `y` (trials x samples), the threshold line, and the shaded onset with its peak marker.'''
    mean_resp = y.mean(0)
    y_below, y_above = sem(y)
    plt.plot(x, mean_resp, color=color)
    plt.fill_between(x, y_below, y_above, color=color, alpha=0.3)
    plt.title(title, fontsize=12)
    plt.axvline(0, color='k')
    onset = _find_onset(mean_resp, threshold)
    if onset is not None:
        enter, exit_ = onset
        peak_idx = enter + int(np.argmax(mean_resp[enter:exit_]))
        plt.plot(x[peak_idx], y_above[peak_idx], 'v', color='k', ms=10, label="onset peak amplitude")
        plt.fill_between(x[enter:exit_], threshold, mean_resp[enter:exit_],
                         color='orange', label="onset response", alpha=1)
    plt.axhline(threshold, color='k', ls='--')
    plt.gca().set_xlim(xlims); plt.gca().set_xticks(xticks)
    plt.gca().set_xticklabels(xticks, fontsize=12); plt.gca().set_ylim([ymin, ymax])

# Plot spkr
plt.subplot(1,2,1)
_plot_onset_panel(spkr_y, spkr_color, "Perception", spkr_onset_threshold)
# Plot mic
plt.subplot(1,2,2)
_plot_onset_panel(mic_y, mic_color, "Production", mic_onset_threshold)
plt.legend(frameon=True,framealpha=1);
plt.suptitle(f"{subj} {elec}", fontsize=14); plt.tight_layout();
plt.savefig(os.path.join(git_path, "figures", "figure_3", "pdf", f"{subj}_{elec}_onset_quant.pdf"))

### Strip plot

In [None]:
# Load onset df
onset_df = pd.read_csv(os.path.join(git_path,"stats","onset_stats.csv"))
# Clip to only ROIs with >3 elecs
nroi_thresh = 3 * 2 # multiply by 2 as each elec has two rows in onset_df
all_rois = list(np.unique(onset_df['condensed_roi']))
ignore_rois = ['wm', 'outside_brain']

def _rois_to_plot(hem):
    '''ROIs (in all_rois order) with more than nroi_thresh rows in hemisphere `hem`, excluding ignore_rois.'''
    counts = onset_df.loc[onset_df['hem']==hem, 'condensed_roi'].value_counts()
    return [roi for roi in all_rois if roi not in ignore_rois and counts.get(roi, 0) > nroi_thresh]

plot_rois_lh, plot_rois_rh = _rois_to_plot('lh'), _rois_to_plot('rh')
print(
    f"{len(all_rois)} condensed FreeSurfer ROIs in database. Plotting {len(plot_rois_rh)} for RH and {len(plot_rois_lh)} for LH.")
onset_df_lh = onset_df.loc[onset_df['condensed_roi'].isin(plot_rois_lh) & (onset_df['hem']=='lh')]
onset_df_rh = onset_df.loc[onset_df['condensed_roi'].isin(plot_rois_rh) & (onset_df['hem']=='rh')]
# Condensed both hemispheres into a df
plot_rois = list(np.unique(plot_rois_rh + plot_rois_lh))
roi_order = [
    0,6,9,10,4,21, # temporal: HG/PT/STG/STS/MTG/temp_pole
    19,17,18,7,8,3,1,2,5, # frontal: subcent/precg/cs/sfg/sfs/mfg/ifg/ifs/ofc
    11,20,16, # parietal: angular/supramar/postcg
    12,14,15,13# insular: ant/post/sup/inf
]
# roi_order indexes the alphabetically sorted plot_rois, so it is only valid for this exact ROI set
assert sorted(roi_order) == list(range(len(plot_rois))), (
    f"roi_order expects {len(roi_order)} ROIs but {len(plot_rois)} passed the threshold: {plot_rois}")
plot_rois = list(np.array(plot_rois)[roi_order])
onset_df_condensed = onset_df.loc[onset_df['condensed_roi'].isin(plot_rois)]

In [None]:
# Calc onset diffs
def _parse_onset_times(onset_times):
    '''Parse an onset_times string such as "[0.1, 0.4]" into [start, stop] floats ("nan" parses to nan).'''
    return [float(re.sub(r'\[|\]','',d)) for d in onset_times.split(', ')]

# Unique (subj, ch_name) pairs, ordered as their f"{subj}_{ch_name}" strings sort
ch_pairs = sorted(set(zip(onset_df_condensed['subj'], onset_df_condensed['ch_name'])),
                  key=lambda pair: f"{pair[0]}_{pair[1]}")
all_chs = np.array([f"{s}_{ch}" for s, ch in ch_pairs])
diff_records, spkronly_records, miconly_records = [], [], []
for s, ch in ch_pairs:
    df_clip = onset_df_condensed.loc[(onset_df_condensed['subj']==s)&(onset_df_condensed['ch_name']==ch)]
    spkr_clip = df_clip.loc[df_clip['condition']=='spkr']
    mic_clip = df_clip.loc[df_clip['condition']=='mic']
    hem = df_clip['hem'].values[0]; fs_roi = df_clip['fs_roi'].values[0]
    gross_anat = df_clip['gross_anat'].values[0]; condensed_roi = df_clip['condensed_roi'].values[0]
    meta = {'subj':s,'ch_name':ch,'hem':hem,'fs_roi':fs_roi,'gross_anat':gross_anat,'condensed_roi':condensed_roi}
    spkr_onset_start, spkr_onset_stop = _parse_onset_times(spkr_clip['onset_times'].values[0])
    mic_onset_start, mic_onset_stop = _parse_onset_times(mic_clip['onset_times'].values[0])
    spkr_peak_amplitude = spkr_clip['peak_amplitude'].values[0]
    spkr_peak_latency = spkr_clip['peak_latency'].values[0]
    mic_peak_amplitude = mic_clip['peak_amplitude'].values[0]
    mic_peak_latency = mic_clip['peak_latency'].values[0]
    # Dual onset: both conditions have onset times -> store perception minus production
    if not np.isnan([spkr_onset_start, spkr_onset_stop, mic_onset_start, mic_onset_stop]).any():
        diff_records.append({**meta,
            "onset_start":spkr_onset_start-mic_onset_start,"onset_stop":spkr_onset_stop-mic_onset_stop,
            "peak_amplitude":spkr_peak_amplitude-mic_peak_amplitude,
            "peak_latency":spkr_peak_latency-mic_peak_latency})
    # Single-condition onset: one condition has a peak amplitude and the other is NaN
    if np.isnan(mic_peak_amplitude) and not np.isnan(spkr_peak_amplitude):
        spkronly_records.append({**meta,'condition':'spkr',
            'onset_times':[spkr_onset_start,spkr_onset_stop],
            'peak_amplitude':spkr_peak_amplitude,'peak_latency':spkr_peak_latency})
    if np.isnan(spkr_peak_amplitude) and not np.isnan(mic_peak_amplitude):
        miconly_records.append({**meta,'condition':'mic',
            'onset_times':[mic_onset_start,mic_onset_stop],
            'peak_amplitude':mic_peak_amplitude,'peak_latency':mic_peak_latency})
# Build each frame once from records (DataFrame.append was removed in pandas 2.0)
onset_diffs = pd.DataFrame(diff_records, columns=['subj','ch_name','hem','fs_roi','gross_anat','condensed_roi',
    'onset_start','onset_stop','peak_amplitude','peak_latency'])
onset_df_spkronly = pd.DataFrame(spkronly_records, columns=list(onset_df.columns))
onset_df_miconly = pd.DataFrame(miconly_records, columns=list(onset_df.columns))

In [None]:
# Make a dict of colors for the electrodes on the strip plot
def _value_palette(values, cmap):
    '''
    Map each value to cmap((value - min) / (max - min)), for use as a seaborn hue palette.
    Returns {} for no values; if all values are equal they map to the colormap midpoint.
    '''
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return {}
    lo, hi = values.min(), values.max()
    span = hi - lo
    return {v: cmap((v - lo) / span if span > 0 else 0.5) for v in values.tolist()}

diff_native_values = list(onset_diffs['peak_amplitude'])
dual_onset_palette = _value_palette(diff_native_values, dual_onset_cmap)
spkr_native_values = list(onset_df_spkronly['peak_amplitude'])
spkronly_palette = _value_palette(spkr_native_values, spkr_cmap)
mic_native_values = list(onset_df_miconly['peak_amplitude'])
miconly_palette = _value_palette(mic_native_values, mic_cmap)

In [None]:
def _drop_legend(ax):
    '''Remove the hue legend if seaborn drew one (hue duplicates the x-axis here).'''
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

plt.figure(figsize=(10,8)); msize=10
# dual onset
plt.subplot(1,3,3); plt.title("Dual onset", fontsize=14)
cax = sns.stripplot(data=onset_diffs,x='peak_amplitude',y='condensed_roi',hue='peak_amplitude',order=plot_rois,
                    palette=dual_onset_palette,linewidth=0.5,edgecolor='k',size=msize)
_drop_legend(cax); plt.axvline(0,color='k',lw=0.5,ls='--')
plt.gca().set_xlim([-3,1.5]); plt.gca().set_xticks([-3,-2,-1,0,1]); plt.gca().set_yticklabels([])
plt.ylabel(''); plt.xlabel("Peak amplitude difference\n(Z; perception-production)")
# spkr only
plt.subplot(1,3,1); plt.title("Perception onset only", fontsize=14)
cax = sns.stripplot(data=onset_df_spkronly,x='peak_amplitude',y='condensed_roi',hue='peak_amplitude',
                    order=plot_rois,palette=spkronly_palette,linewidth=0.5,edgecolor='k',size=msize)
_drop_legend(cax)
plt.xlabel("Peak amplitude (Z)"); plt.gca().set_xlim([0,2.5]); plt.gca().set_xticks([0,1,2])
# mic only
plt.subplot(1,3,2); plt.title("Production onset only", fontsize=14)
cax = sns.stripplot(data=onset_df_miconly,x='peak_amplitude',y='condensed_roi',hue='peak_amplitude',
                    order=plot_rois, palette=miconly_palette, linewidth=0.5, edgecolor='k',size=msize)
_drop_legend(cax)
plt.ylabel(''); plt.xlabel("Peak amplitude (Z)")
plt.gca().set_xlim([0,2]); plt.gca().set_xticks([0,1,2]);plt.gca().set_yticklabels([])
plt.tight_layout(); plt.savefig(os.path.join(git_path,"figures","figure_3","pdf","peak_amplitude_scatter.pdf"))