In [None]:
# Paths - Update locally! (or set the environment variables below instead of editing this cell)
import os
# Root of the kurteff2024_code repository: event files, color palette and figure output live here
git_path = os.environ.get("KURTEFF2024_GIT_PATH", '/path/to/git/kurteff2024_code/')
# Root of the dataset: per-subject .fif/audio folders and `{subj}_complete` imaging folders
data_path = os.environ.get("KURTEFF2024_DATA_PATH", '/path/to/bids/dataset/')

In [None]:
import mne
import numpy as np
import pandas as pd
import os
import re
import csv
from tqdm.notebook import tqdm
import warnings
from img_pipe import img_pipe
import librosa

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
rc['pdf.fonttype'] = 42
plt.style.use('seaborn')
%matplotlib inline

import sys
sys.path.append(os.path.join(git_path,"figures"))
import plotting_utils
sys.path.append(os.path.join(git_path,"preprocessing","events","textgrids"))
import textgrid
sys.path.append(os.path.join(git_path,"preprocessing","imaging"))
import imaging_utils

In [None]:
# Subjects = per-subject folders of event CSVs (TCH* / S0* IDs), sorted for a reproducible order.
event_csv_dir = os.path.join(git_path,"preprocessing","events","csv")
subjs = sorted(s for s in os.listdir(event_csv_dir) if "TCH" in s or "S0" in s)
exclude = ["TCH8"]
no_imaging = ["S0010"]  # presumably subjects without an electrode reconstruction - skipped below
subjs = [s for s in subjs if s not in exclude]

def _usable_blocks(s):
    '''Block IDs (e.g. "B2") of subject `s` that have a `{block}_spkr_sn_all.txt` event file.

    Looks in the same preprocessing/events/csv tree that the epoching code reads events from.
    '''
    subj_dir = os.path.join(event_csv_dir, s)
    return [
        b.split("_")[-1] for b in sorted(os.listdir(subj_dir))
        if f"{s}_B" in b and os.path.isfile(os.path.join(subj_dir, b, f"{b}_spkr_sn_all.txt"))
    ]

blocks = {s: _usable_blocks(s) for s in subjs}

# Hemisphere(s) with electrodes per subject, from the sign of the x coordinate
# (any x > 0 -> 'rh', any x < 0 -> 'lh'). Subjects without imaging keep an empty list.
hems = {s:[] for s in subjs}
for s in subjs:
    if s in no_imaging or not os.path.isdir(os.path.join(data_path,f"{s}_complete")):
        continue
    pt = img_pipe.freeCoG(f"{s}_complete",hem='stereo',subj_dir=data_path)
    elec_x = pt.get_elecs()['elecmatrix'][:,0]
    if np.any(elec_x > 0):
        hems[s].append('rh')
    if np.any(elec_x < 0):
        hems[s].append('lh')

color_palette = pd.read_csv(os.path.join(git_path,"figures","color_palette.csv"))

def _palette_hex(color_id):
    '''First hex color listed for `color_id` in figures/color_palette.csv.'''
    return color_palette.loc[color_palette['color_id']==color_id]['hex'].values[0]

spkr_color = _palette_hex('perception')
mic_color = _palette_hex('production')
click_color = _palette_hex('click')

In [None]:
def epoch_data(subj,blocks,git_path,data_path,channel='spkr',level='sn',condition='all',
               tmin=-.5, tmax=2, baseline=None, click=False):
    '''
    Epoch one subject's high-gamma (HilbAA 70-150 Hz) recordings around task events and
    concatenate the epochs across blocks.

    Parameters
    ----------
    subj : str
        Subject ID, e.g. "S0006".
    blocks : sequence of str
        Block IDs, e.g. ["B1", "B2"]; block `b` is read from
        `{data_path}/sub-{subj}/{subj}_{b}/HilbAA_70to150_8band/ecog_hilbAA70to150.fif`.
    git_path : str
        Repository root; events come from `{git_path}/preprocessing/events/csv/{subj}/{subj}_{b}/`.
    data_path : str
        Dataset root holding the .fif files.
    channel, level, condition : str
        Select the event file `{subj}_{b}_{channel}_{level}_{condition}.txt`
        (ignored when `click` is True).
    tmin, tmax : float
        Epoch window in seconds relative to each event onset.
    baseline : tuple or None
        Passed unchanged to `mne.Epochs`.
    click : bool
        If True, use `{subj}_{b}_click_eve.txt`, whose onset, offset and event id are in
        columns 0, 2 and 4.

    Returns
    -------
    Preloaded epochs of all blocks, concatenated in block order.

    Raises
    ------
    ValueError
        If `blocks` is empty (there would be nothing to concatenate).
    '''
    if len(blocks) == 0:
        raise ValueError(f"No blocks given for {subj}; nothing to epoch.")
    # Column positions of onset (s), offset (s) and event id in the tab-separated event file
    if click:
        onset_index, offset_index, id_index = 0,2,4
    else:
        onset_index, offset_index, id_index = 0,1,2
    event_suffix = "click_eve" if click else f"{channel}_{level}_{condition}"
    # Use the `subj` argument (not a notebook-global loop variable) to locate event files
    event_dir = os.path.join(git_path,"preprocessing","events","csv",subj)
    epochs = []
    for b in blocks:
        blockid = f'{subj}_{b}'
        raw_fpath = os.path.join(
            data_path,f"sub-{subj}",blockid,"HilbAA_70to150_8band","ecog_hilbAA70to150.fif")
        eventfile = os.path.join(event_dir,blockid,f"{blockid}_{event_suffix}.txt")
        raw = mne.io.read_raw_fif(raw_fpath,preload=True,verbose=False)
        fs = raw.info['sfreq']
        with open(eventfile,'r') as f:
            rows = [row for row in csv.reader(f,delimiter='\t') if row]  # skip blank lines
        # Seconds -> sample indices (rounded up); rows are (onset, offset, event id)
        events = np.array([[int(np.ceil(float(row[onset_index])*fs)),
                            int(np.ceil(float(row[offset_index])*fs)),
                            int(row[id_index])] for row in rows])
        epochs.append(mne.Epochs(raw,events,tmin=tmin,tmax=tmax,baseline=baseline,preload=True,verbose=False))
    return mne.concatenate_epochs(epochs,verbose=False)

### Task schematic waveforms

In [None]:
# QC / task schematic: one example trial - a produced sentence (mic) followed by its playback
# (spkr) - shown as audio waveforms (top) and one electrode's high-gamma response (bottom).
subj, block, elec = "S0006", "B2", "RPPST13"
blockid = "_".join([subj,block])
textgrid_dir = os.path.join(git_path,"preprocessing","events","textgrids",subj,blockid)
fig_dir = os.path.join(git_path,"figures","figure_2","pdf")

def _read_sentence_tier(tg_fpath):
    '''Intervals of the first tier of a TextGrid whose label is not '' or 'sp'.

    Returns a string array with one row per interval: columns 0-1 are the interval
    times in seconds, column 2 is the label.
    '''
    with open(tg_fpath,'r') as f:
        tg = textgrid.TextGrid(f.read())
    return np.array([row for row in tg.tiers[0].simple_transcript if row[2] not in ['','sp']])

# Read textgrids: keep mic sentences 1..nsents-1 (first half, minus the first one) and
# truncate the spkr sentences to the same count.
mic_sentences = _read_sentence_tier(os.path.join(textgrid_dir,f"{blockid}_mic_sentence.TextGrid"))
nsents = np.ceil(mic_sentences.shape[0]/2).astype(int)
mic_sentences = mic_sentences[1:nsents]
spkr_sentences = _read_sentence_tier(
    os.path.join(textgrid_dir,f"{blockid}_spkr_sentence.TextGrid"))[:nsents-1]
# Load audio at its native sampling rate
audio_folder = os.path.join(data_path, f"sub-{subj}", blockid, "Audio")
mic_audio, wav_fs = librosa.load(os.path.join(audio_folder,f"{blockid}_mic.wav"),sr=None)
spkr_audio, wav_fs = librosa.load(os.path.join(audio_folder,f"{blockid}_spkr.wav"),sr=None)
# Load high-gamma and keep only the electrode of interest (read once)
raw = mne.io.read_raw_fif(os.path.join(data_path,f"sub-{subj}",blockid,"HilbAA_70to150_8band",
                                       "ecog_hilbAA70to150.fif"), preload=True, verbose=False)
raw_fs = raw.info['sfreq']
elec_trace = raw.get_data(picks=[elec]).squeeze()

# Example trial = second sentence of each tier (times in seconds)
mic_tmin, mic_tmax = mic_sentences[1,:2].astype(float)
spkr_tmin, spkr_tmax = spkr_sentences[1,:2].astype(float)
mic_tmin = 24.8  # hand-picked start of the plotted window (overrides the TextGrid onset)
sen_trs = spkr_sentences[1,2]

# Audio: zero the mic after the produced sentence and the speaker channel from 27 s
# (hand-picked for this example) to the end of production, then crop both to mic_tmin..spkr_tmax.
audio_win = slice(int(wav_fs*mic_tmin), int(wav_fs*spkr_tmax))
mic_clip = mic_audio.copy()
mic_clip[int(wav_fs*mic_tmax):] = 0
mic_clip = mic_clip[audio_win]
spkr_clip = spkr_audio.copy()
spkr_clip[int(wav_fs*27):int(wav_fs*mic_tmax)] = 0
spkr_clip = spkr_clip[audio_win]
comb_resp = elec_trace[int(raw_fs*mic_tmin):int(raw_fs*spkr_tmax)]

# Plot; every trace is divided by its own maximum
os.makedirs(fig_dir, exist_ok=True)
fig, (ax_audio, ax_resp) = plt.subplots(2, 1, figsize=(8,4))
ax_audio.plot(mic_clip/mic_clip.max(), color=mic_color)
ax_audio.plot(spkr_clip/spkr_clip.max(), color=spkr_color)
ax_audio.axis('off')
ax_audio.set_title(sen_trs[0] + sen_trs[1:].lower())  # sentence case
ax_resp.plot(comb_resp/comb_resp.max(), color='k')
ax_resp.axvline(0,color=mic_color)  # start of the window (mic_tmin)
ax_resp.axvline(int((spkr_tmin-mic_tmin)*raw_fs),color=spkr_color)  # spkr onset, in samples
ax_resp.axis('off')
fig.tight_layout()
fig.savefig(os.path.join(fig_dir,f"single_trial_waveform_{blockid}_{elec}.pdf"))

### Individual electrode plots

In [None]:
def sem(epochs, ddof=0):
    '''
    Mean -/+ one standard error of the mean (SEM) across epochs.

    Parameters
    ----------
    epochs : array-like, shape (n_epochs, n_samples)
        Single-trial data; statistics are taken over axis 0.
    ddof : int, default 0
        Delta degrees of freedom of the standard deviation. 0 (population std) reproduces
        this function's historical output; scipy.stats.sem uses ddof=1.

    Returns
    -------
    sem_below, sem_above : np.ndarray, shape (n_samples,)
        mean - SEM and mean + SEM at every sample.
    '''
    epochs = np.asarray(epochs)
    mean = epochs.mean(0)
    margin = epochs.std(0, ddof=ddof) / np.sqrt(epochs.shape[0])
    return mean - margin, mean + margin

In [None]:
subj, elec = "S0018", "PST-PI'5" # Update accordingly
tmin, tmax = -0.5, 1.0
# Single-electrode response: trial mean +/- SEM around sentence onset for perception (spkr)
# and production (mic) trials.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    spkr_epochs = epoch_data(subj,blocks[subj],git_path,data_path,channel='spkr', tmin=tmin, tmax=tmax)
    mic_epochs = epoch_data(subj,blocks[subj],git_path,data_path,channel='mic', tmin=tmin, tmax=tmax)
ch_names = spkr_epochs.info['ch_names']
if elec not in ch_names:
    raise ValueError(f"{elec} is not a channel of {subj}")
# Load anat: FreeSurfer label of the electrode (anatomy rows hold the name at [0][0] and
# the label at [3][0]); fall back to "anatomy unknown" when no imaging or no match exists.
fs_roi = "anatomy unknown"
if os.path.isdir(os.path.join(data_path,f"{subj}_complete")):
    patient = img_pipe.freeCoG(f"{subj}_complete",hem='stereo',subj_dir=data_path)
    anat = patient.get_elecs()['anatomy']
    fs_roi = next((a[3][0] for a in anat if a[0][0] == elec), fs_roi)
x = spkr_epochs.times
fig, ax = plt.subplots(figsize=(5,5))
for cond_epochs, color in ((spkr_epochs, spkr_color), (mic_epochs, mic_color)):
    # (n_epochs, n_times); indexing (not squeeze) keeps the epoch axis even for a single trial
    trials = cond_epochs.get_data(picks=[elec])[:, 0, :]
    below, above = sem(trials)
    ax.plot(x, trials.mean(0), color=color)
    ax.fill_between(x, below, above, color=color, alpha=0.3)
# Plt decorations
ax.set_title(f"{subj} {elec} {fs_roi}", fontsize=14)
ax.axvline(0,color='k')  # event onset
# Plt settings
xlims = [x[0], x[-1]]
ax.set_xlim(xlims)
xticks = np.round(np.arange(xlims[0],xlims[-1]+.5,.5),decimals=1)
ax.set_xticks(xticks)
ax.set_xticklabels(xticks,fontsize=12)
fig.savefig(os.path.join(git_path,"figures","figure_2","pdf",f"{subj}_{elec}_onset_sup.pdf"))

### Colorbar

In [None]:
# Diverging purple -> white -> green gradient used to color electrodes by suppression index.
# Edit this gradient at https://eltos.github.io/gradient/#0:954997-20:954997-45:FFFFFF-50:FFFFFF-55:FFFFFF-80:117632-100:117632
si_purple = (0.584, 0.286, 0.592)
si_white = (1.000, 1.000, 1.000)
si_green = (0.067, 0.463, 0.196)
si_gradient_stops = (
    (0.000, si_purple), (0.200, si_purple),
    (0.450, si_white), (0.500, si_white), (0.550, si_white),
    (0.800, si_green), (1.000, si_green),
)
cmap = LinearSegmentedColormap.from_list('my_gradient', si_gradient_stops)
# Render the colormap as a 3-row strip of the values 0..99 and save it as a standalone colorbar
colorbar_strip = np.tile(np.arange(100), (3, 1))
plt.figure(figsize=(12,2))
plt.imshow(colorbar_strip, aspect='auto', cmap=cmap)
plt.axis('off');
plt.savefig(os.path.join(git_path,"figures","figure_2","pdf","colorbar.pdf"))

### Strip plot

In [None]:
# Epoch data
si_tmin = 0; si_tmax = 1; erp_tmin = -.5; erp_tmax = 2; reject = None; baseline = None
epochs = dict(); ch_names = dict()
for s in tqdm(subjs):
    epochs[s] = dict(); spkr_epochs, mic_epochs = [], []
    for b in blocks[s]:
        blockid = f'{s}_{b}'
        raw_fpath = os.path.join(data_path,f"sub-{s}",s,blockid,"HilbAA_70to150_8band",
                                 "ecog_hilbAA70to150.fif")
        raw = mne.io.read_raw_fif(raw_fpath,preload=True,verbose=False)
        ch_names[s] = raw.info['ch_names']
        fs = raw.info['sfreq']
        # Spkr events
        eventfile = os.path.join(git_path,"preprocessing","events","csv",s,blockid,
                                 f"{blockid}_spkr_sn_all.txt")
        with open(eventfile,'r') as f:
            c = csv.reader(f,delimiter='\t')
            events = np.array([[int(float(row[0])*fs),int(float(row[1])*fs),int(row[2])] for row in c])
        spkr_epochs.append(mne.Epochs(raw,events,tmin=erp_tmin,tmax=erp_tmax,
                                      baseline=baseline,reject=reject,verbose=False))
        # Mic events
        eventfile = os.path.join(git_path,"preprocessing","events","csv",s,blockid,
                                 f"{blockid}_mic_sn_all.txt")
        with open(eventfile,'r') as f:
            c = csv.reader(f,delimiter='\t')
            events = np.array([[int(float(row[0])*fs),int(float(row[1])*fs),int(row[2])] for row in c])
        mic_epochs.append(mne.Epochs(raw,events,tmin=erp_tmin,tmax=erp_tmax,
                                      baseline=baseline,reject=reject,verbose=False))
    epochs[s]['spkr'] = mne.concatenate_epochs(spkr_epochs)
    epochs[s]['mic'] = mne.concatenate_epochs(mic_epochs)

In [None]:
# Calc SI
x = epochs[s]['mic'].times
si_inds = [np.where(x==si_tmin)[0][0],np.where(x==si_tmax)[0][0]]
si = dict()
for s in subjs:
    subj_si = []
    for i,ch in enumerate(epochs[s]['mic'].info['ch_names']):
        spkr_resp = epochs[s]['spkr'].get_data(picks=[ch]).squeeze()[:,si_inds[0]:si_inds[1]].mean(0).mean(0)
        mic_resp = epochs[s]['mic'].get_data(picks=[ch]).squeeze()[:,si_inds[0]:si_inds[1]].mean(0).mean(0)
        subj_si.append(spkr_resp-mic_resp)
    si[s] = np.array(subj_si)
# Normalize between 0 and 1
si_min = np.hstack((list(si.values()))).min()
si_max = np.hstack((list(si.values()))).max()
for s in subjs:
    si[s] = ((si[s]-si_min)/(si_max-si_min)) * 2 - 1

In [None]:
# Create colormap
# Shift SI values so that the max value is 1 without normalizing
# Just for cmap! We still report native cmap values
# Max val for spkr - we shift so this is 1
max_val = np.hstack((list(si.values()))).max(); min_val = np.hstack((list(si.values()))).min()
# Apply the shift
normed_si = dict()
for s in subjs:
    subj_si = []
    for ch in si[s]:
        subj_si.append((ch-min_val)/(max_val-min_val))
    # Normalize
    normed_si[s] = np.array(subj_si)
elec_colors_rgb = dict()
elec_colors_hex = dict()
cmap = LinearSegmentedColormap.from_list('my_gradient', (
    # Edit this gradient at https://eltos.github.io/gradient/#0:954997-20:954997-45:FFFFFF-50:FFFFFF-55:FFFFFF-80:117632-100:117632
    (0.000, (0.584, 0.286, 0.592)),
    (0.200, (0.584, 0.286, 0.592)),
    (0.450, (1.000, 1.000, 1.000)),
    (0.500, (1.000, 1.000, 1.000)),
    (0.550, (1.000, 1.000, 1.000)),
    (0.800, (0.067, 0.463, 0.196)),
    (1.000, (0.067, 0.463, 0.196))))
for s in subjs:
    elec_colors_rgb[s] = [cmap(f) for f in normed_si[s]]
    elec_colors_hex[s] = np.array(
        [plotting_utils.rgb_to_hex(int(ec[0]*255),int(ec[1]*255),int(ec[2]*255)) for ec in elec_colors_rgb[s]])

In [None]:
# Get values
excluded_chs = dict()
excl_df = pd.read_csv(os.path.join(git_path,"analysis","all_excluded_electrodes.csv"))
# Initialize dicts for plot
si_by_roi = pd.DataFrame(columns=["subj","ch_name","hem","fs_roi","gross_anat",
                                  "condensed_roi","si_value","cmap_value"])
for hem in ['lh','rh']:
    excluded_chs[hem] = dict()
    for s in tqdm(hems.keys()):
        pt = img_pipe.freeCoG(f"{s}_complete",hem='stereo',subj_dir=data_path)
        fs_labels = [a[3][0] for a in pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")['anatomy']]
        if hem in hems[s]:
            excluded_chs[hem][s] = []
            # Load anatomy
            # Load df to get SI information of each channel
            ch_names = np.array(epochs[s]['spkr'].info['ch_names'])
            excl_ch_names = list(excl_df.loc[excl_df['subject']==s]['channel'].values)
            for fif_idx,ch in enumerate(ch_names):
                if ch in fs_labels:
                    fs_idx = fs_labels.index(ch)
                    condensed_roi = imaging_utils.condense_roi(fs_labels[fs_idx])
                if ch not in excl_ch_names:
                    gross_anat = imaging_utils.gross_anat(fs_labels[fs_idx])
                    if condensed_roi in ['frontal','temporal','parietal','occipital','precentral',
                                         'postcentral','insula']:
                        new_row = pd.DataFrame({"subj":[s],"ch_name":[ch],"hem":[hem],
                            "fs_roi":[fs_labels[fs_idx]],"gross_anat":[gross_anat],
                            "condensed_roi":[condensed_roi],"si_value":[si[s][fif_idx]],
                            "cmap_value":[normed_si[s][fif_idx]]})
                        si_by_roi = si_by_roi.append(new_row, ignore_index=True)
                    elif condensed_roi == "whitematter":
                        new_row = pd.DataFrame({"subj":[s],"ch_name":[ch],"hem":[hem],
                            "fs_roi":[fs_labels[fs_idx]],"gross_anat":["subcort"],
                            "condensed_roi":[condensed_roi],"si_value":[si[s][fif_idx]],
                            "cmap_value":[normed_si[s][fif_idx]]})
                        si_by_roi = si_by_roi.append(new_row, ignore_index=True)

In [None]:
all_rois = list(np.unique(si_by_roi['condensed_roi']))
# ROIs never shown in the strip plots. White-matter channels are tagged "whitematter" when
# si_by_roi is built, so that label is listed alongside the shorthand 'wm'.
ignore_rois = ['wm', 'whitematter', 'outside_brain']
nroi_thresh = 3  # an ROI is plotted for a hemisphere only with MORE than this many channels there

def _rois_to_plot(hem):
    '''ROIs (in `all_rois` order) with more than `nroi_thresh` channels in `hem`, minus `ignore_rois`.'''
    counts = si_by_roi.loc[si_by_roi['hem']==hem, 'condensed_roi'].value_counts()
    return [roi for roi in all_rois if roi not in ignore_rois and counts.get(roi, 0) > nroi_thresh]

plot_rois_lh = _rois_to_plot('lh')
plot_rois_rh = _rois_to_plot('rh')
print(f"{len(all_rois)} condensed FreeSurfer ROIs in database. Plotting {len(plot_rois_rh)} for RH and {len(plot_rois_lh)} for LH.")
si_by_roi_rh = si_by_roi.loc[(si_by_roi['hem']=='rh') & si_by_roi['condensed_roi'].isin(plot_rois_rh)]
si_by_roi_lh = si_by_roi.loc[(si_by_roi['hem']=='lh') & si_by_roi['condensed_roi'].isin(plot_rois_lh)]

In [None]:
# A version collapsed across hemispheres (for use in Figure 1)
plot_rois = list(np.unique(plot_rois_rh + plot_rois_lh))
# Reorder ROIs in a more logical way
roi_order = [
    0,6,9,10,4,21, # temporal: HG/PT/STG/STS/MTG/temp_pole
    19,17,18,7,8,3,1,2,5, # frontal: subcent/precg/cs/sfg/sfs/mfg/ifg/ifs/ofc
    11,20,16, # parietal: angular/supramar/postcg
    12,14,15,13# insular: ant/post/sup/inf
]
plot_rois = list(np.array(plot_rois)[roi_order])
si_by_roi_condensed = si_by_roi.loc[[r in plot_rois for r in si_by_roi['condensed_roi_md']]]
# Remake the palette
cmap = LinearSegmentedColormap.from_list('my_gradient', (
    # Edit this gradient at https://eltos.github.io/gradient/#0:954997-20:954997-45:FFFFFF-50:FFFFFF-55:FFFFFF-80:117632-100:117632
    (0.000, (0.584, 0.286, 0.592)),
    (0.200, (0.584, 0.286, 0.592)),
    (0.450, (1.000, 1.000, 1.000)),
    (0.500, (1.000, 1.000, 1.000)),
    (0.550, (1.000, 1.000, 1.000)),
    (0.800, (0.067, 0.463, 0.196)),
    (1.000, (0.067, 0.463, 0.196))))
si_native_values = []; si_norm_values = []
for row in si_by_roi_condensed.values:
    s, ch, _, _, _, _, si_native, si_norm = row
    si_native_values.append(si_native)
    si_norm_values.append(si_norm)
strip_palette = dict()
for i,v in enumerate(si_norm_values):
    strip_palette[si_native_values[i]] = cmap(v)
# Plot
fig = plt.figure(figsize=(6,14))
cax = sns.stripplot(
    data=si_by_roi_condensed, x='si_value', y='condensed_roi_md',
    order=plot_rois,hue='si_value',palette=strip_palette,linewidth=0.5,edgecolor='k',size=10
)
cax.get_legend().remove()
plt.axvline(0,color='k',ls='--',lw=0.75)
plt.gca().set_xlim([-1.1,1.1])
plt.xlabel("Suppression index (Z)", fontsize=12)
plt.ylabel("Condensed ROI", fontsize=12)
plt.gca().set_yticks(np.arange(len(plot_rois)))
plt.gca().set_yticklabels(plot_rois);
plt.savefig(os.path.join("figures","figure_2","pdf","si_by_roi_scatter.pdf"))