In [None]:
# Paths - Update locally! (or set KURTEFF2024_GIT_PATH / KURTEFF2024_DATA_PATH instead of
# editing the notebook; the literals below are only placeholders used when those are unset)
import os

# Root of the kurteff2024_code git checkout (event files, figures, helper modules)
git_path = os.environ.get("KURTEFF2024_GIT_PATH", '/path/to/git/kurteff2024_code/')
# Root of the BIDS dataset (sub-<subj>/... recordings and imaging)
data_path = os.environ.get("KURTEFF2024_DATA_PATH", '/path/to/bids/dataset/')

In [None]:
import mne
import numpy as np
import pandas as pd
import os
import re
import csv
from tqdm.notebook import tqdm
import warnings
from img_pipe import img_pipe
import librosa

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
rc['pdf.fonttype'] = 42
plt.style.use('seaborn')
%matplotlib inline

import sys
sys.path.append(os.path.join(git_path,"figures"))
import plotting_utils
sys.path.append(os.path.join(git_path,"preprocessing","events","textgrids"))
import textgrid
sys.path.append(os.path.join(git_path,"preprocessing","imaging"))
import imaging_utils

In [None]:
# Subjects: every TCH*/S0* folder that has preprocessed event CSVs (sorted for a stable order).
events_csv_dir = os.path.join(git_path, "preprocessing", "events", "csv")
subjs = sorted(s for s in os.listdir(events_csv_dir) if "TCH" in s or "S0" in s)
# NOTE(review): other TCH IDs are zero-padded (e.g. 'TCH06') — confirm this should not be 'TCH08'.
exclude = ["TCH8"]
# Subjects without an imaging reconstruction; freeCoG is never loaded for these.
no_imaging = ["S0010"]
# these are the subjects that don't have inconsistent playback trials and
# are therefore excluded from analysis in this notebook
no_sh = ['S0023', 'TCH06']
subjs = [s for s in subjs if s not in exclude and s not in no_sh]


def _blocks_with_events(subj):
    """Return block names (e.g. 'B2') of `subj` that have a `<subj>_<block>_spkr_sn_all.txt` event file.

    Searches the same preprocessing/events/csv tree the subject list comes from, which is also
    where `epoch_data` reads its event files.
    """
    subj_dir = os.path.join(events_csv_dir, subj)
    return [
        b.split("_")[-1] for b in sorted(os.listdir(subj_dir))
        if f"{subj}_B" in b and os.path.isfile(os.path.join(subj_dir, b, f"{b}_spkr_sn_all.txt"))
    ]


blocks = {s: _blocks_with_events(s) for s in subjs}

# Hemispheres containing at least one electrode, from the sign of the x coordinate (x > 0 -> 'rh').
hems = {s: [] for s in subjs}
for s in subjs:
    if s in no_imaging:
        continue  # nothing to load; leave this subject's hemisphere list empty
    pt = img_pipe.freeCoG(f"{s}_complete", hem='stereo', subj_dir=data_path)
    elec_x = pt.get_elecs()['elecmatrix'][:, 0]
    if np.any(elec_x > 0):
        hems[s].append('rh')
    if np.any(elec_x < 0):
        hems[s].append('lh')

# Shared figure colours from figures/color_palette.csv
color_palette = pd.read_csv(os.path.join(git_path, "figures", "color_palette.csv"))


def _palette_hex(color_id):
    """Return the 'hex' value of the first palette row whose 'color_id' equals `color_id`."""
    return color_palette.loc[color_palette['color_id'] == color_id]['hex'].values[0]


spkr_color = _palette_hex('perception')
mic_color = _palette_hex('production')
el_color = _palette_hex('consistent')
sh_color = _palette_hex('inconsistent')

In [None]:
def epoch_data(subj,blocks,git_path,data_path,channel='spkr',level='sn',condition='all',
               tmin=-.5, tmax=2, baseline=None, click=False):
    """Epoch one subject's `HilbAA_70to150_8band` data across blocks and concatenate the result.

    Parameters
    ----------
    subj : str
        Subject ID, e.g. 'S0006'.
    blocks : list of str
        Block names such as 'B2'; block `b` is read from ``sub-<subj>/<subj>_<b>``.
    git_path : str
        Repository root; event files live in ``preprocessing/events/csv/<subj>/<subj>_<b>/``.
    data_path : str
        Dataset root holding ``sub-<subj>/<subj>_<b>/HilbAA_70to150_8band/ecog_hilbAA70to150.fif``.
    channel, level, condition : str
        Select the event file ``<subj>_<b>_<channel>_<level>_<condition>.txt``.
    tmin, tmax, baseline
        Passed unchanged to ``mne.Epochs``.
    click : bool
        If True, read ``<subj>_<b>_click_eve.txt`` instead (channel/level/condition are ignored).

    Returns
    -------
    The per-block ``mne.Epochs`` joined with ``mne.concatenate_epochs``.
    """
    # Event files are tab-separated with times in seconds. Click files keep onset/offset/id in
    # columns 0/2/4; the condition files keep them in columns 0/1/2.
    onset_index, offset_index, id_index = (0, 2, 4) if click else (0, 1, 2)
    # Use the `subj` argument (not a notebook-level loop variable) to locate the event files.
    subj_events_dir = os.path.join(git_path, "preprocessing", "events", "csv", subj)
    epochs = []
    for b in blocks:
        blockid = f'{subj}_{b}'
        raw_fpath = os.path.join(
            data_path, f"sub-{subj}", blockid, "HilbAA_70to150_8band", "ecog_hilbAA70to150.fif")
        if click:
            event_fname = f"{blockid}_click_eve.txt"
        else:
            event_fname = f"{blockid}_{channel}_{level}_{condition}.txt"
        eventfile = os.path.join(subj_events_dir, blockid, event_fname)
        raw = mne.io.read_raw_fif(raw_fpath, preload=True, verbose=False)
        fs = raw.info['sfreq']
        with open(eventfile, 'r', newline='') as f:
            # skip blank lines (e.g. a trailing empty line) instead of failing on row[...]
            rows = [row for row in csv.reader(f, delimiter='\t') if row]
        # MNE events are (sample, middle, id). MNE generally ignores the middle column when
        # epoching, so it carries the offset sample here. Seconds -> samples rounds up.
        events = np.array([[int(np.ceil(float(row[onset_index]) * fs)),
                            int(np.ceil(float(row[offset_index]) * fs)),
                            int(row[id_index])] for row in rows], dtype=int)
        epochs.append(mne.Epochs(raw, events, tmin=tmin, tmax=tmax, baseline=baseline,
                                 preload=True, verbose=False))
    return mne.concatenate_epochs(epochs, verbose=False)

### Task schematic
**TODO:** double-check the sentence onset/offset times used in this schematic.

In [None]:
# QC: waveform of one single-sentence trial from a block (used in the task schematic)
subj, block, elec = "S0006", "B2", "RPPST13"
blockid = "_".join([subj, block])
# Read textgrid
spkr_grid = os.path.join(git_path, "preprocessing", "events", "textgrids", subj, blockid,
                         f"{blockid}_spkr_sentence.TextGrid")
with open(spkr_grid, 'r') as f:
    tg = textgrid.TextGrid(f.read())
# (start, stop, transcript) rows of the first tier, without silences ('') and short pauses ('sp').
# Only the second interval (row 1) is used below.
spkr_sentences = np.array([row for row in tg.tiers[0].simple_transcript if row[2] not in ['', 'sp']])
# Load audio
audio_folder = os.path.join(data_path, f"sub-{subj}", blockid, "Audio")
spkr_audio, wav_fs = librosa.load(os.path.join(audio_folder, f"{blockid}_spkr.wav"), sr=None)
# Load raw
raw = mne.io.read_raw_fif(os.path.join(data_path, f"sub-{subj}", blockid, "HilbAA_70to150_8band",
                                       "ecog_hilbAA70to150.fif"), preload=True, verbose=False)
raw_fs = raw.info['sfreq']

# Onset/offset (s) and transcript of the selected sentence
spkr_tmin, spkr_tmax = spkr_sentences[1, :2].astype(float)
sen_trs = spkr_sentences[1, 2]
spkr_tmin_samp, spkr_tmax_samp = int(spkr_tmin * raw_fs), int(spkr_tmax * raw_fs)
# High-gamma trace of `elec` over the same interval (kept for inspection; not plotted below)
spkr_resp = raw.get_data(picks=[elec]).squeeze()[spkr_tmin_samp:spkr_tmax_samp]

# Zero the audio from `mute_from_s` seconds (block time) up to the sentence offset, then cut the
# sentence out. NOTE(review): presumably blanks audio that is not part of this sentence — confirm 27 s.
mute_from_s = 27
spkr_clip = spkr_audio.copy()
spkr_clip[int(wav_fs * mute_from_s):int(wav_fs * spkr_tmax)] = 0
spkr_clip = spkr_clip[int(wav_fs * spkr_tmin):int(wav_fs * spkr_tmax)]
# Plot
fig_dir = os.path.join(git_path, "figures", "figure_4", "pdf")
os.makedirs(fig_dir, exist_ok=True)
plt.figure(figsize=(8, 4))
plt.plot(spkr_clip / spkr_clip.max(), color=spkr_color)
plt.axis('off')
# Sentence case: keep the first character, lower-case the rest
plt.title(sen_trs[0] + sen_trs[1:].lower())
plt.savefig(os.path.join(fig_dir, f"single_trial_waveform_{blockid}_{elec}.pdf"))

### Single electrode plots with CI visualized
Consistency index ($CI$) for electrode $n$ is defined as:

$$CI_{n} = \frac{1}{T}\sum_{t=1}^{T}\left(H\gamma^{Incon}_{n,t} - H\gamma^{Con}_{n,t}\right),$$

where $H\gamma_{n,t}$ is the trial-averaged high gamma activity of electrode $n$ at time sample $t$, for either inconsistent ($Incon$) or consistent ($Con$) playback. $T$ is the number of samples in the time window of interest. Here we calculate $CI$ in a 0-1000 ms window relative to sentence onset. The values, pooled across all electrodes, are then rescaled to the range $[-1, 1]$.

In [None]:
# Epoch data: per subject, consistent ('el') and inconsistent ('sh') playback, production ('mic')
# and click trials, all from -0.5 to 2 s around event onset.
epoch_window = dict(tmin=-0.5, tmax=2)
epochs = dict(); ch_names = dict()
for s in tqdm(subjs):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Every call uses this iteration's subject `s` (not a leftover `subj` from another cell).
        # NOTE(review): 'mic' is passed as `condition` with the default channel='spkr', i.e. the
        # `<block>_spkr_sn_mic.txt` event file — confirm that is the intended production file.
        epochs[s] = {
            'el': epoch_data(s, blocks[s], git_path, data_path, condition='el', **epoch_window),
            'sh': epoch_data(s, blocks[s], git_path, data_path, condition='sh', **epoch_window),
            'mic': epoch_data(s, blocks[s], git_path, data_path, condition='mic', **epoch_window),
            'click': epoch_data(s, blocks[s], git_path, data_path, condition='spkr', click=True,
                                **epoch_window),
        }

In [None]:
# Calc CI: per electrode, mean high gamma of inconsistent ('sh') minus consistent ('el') trials
# over the [ci_tmin, ci_tmax) window (seconds relative to sentence onset).
ci_tmin = 0; ci_tmax = 1
ci = dict()
for s in subjs:
    times = epochs[s]['el'].times
    # Nearest samples to the window edges; exact float equality against `times` is fragile.
    ci_inds = [int(np.abs(times - ci_tmin).argmin()), int(np.abs(times - ci_tmax).argmin())]
    win = slice(ci_inds[0], ci_inds[1])
    subj_ci = []
    for ch in epochs[s]['el'].info['ch_names']:
        # get_data(picks=[ch]) is (n_epochs, 1, n_times); index the channel axis instead of
        # squeeze() so a condition with a single epoch keeps its trial axis.
        el_resp = epochs[s]['el'].get_data(picks=[ch])[:, 0, win].mean(0).mean(0)
        sh_resp = epochs[s]['sh'].get_data(picks=[ch])[:, 0, win].mean(0).mean(0)
        subj_ci.append(sh_resp - el_resp)
    ci[s] = np.array(subj_ci)

# Rescale all CIs (pooled over subjects) to [-1, 1].
# NOTE(review): min-max rescaling keeps 0 at 0 only when ci_min == -ci_max — confirm this is
# intended, since the strip plot draws its reference line at 0.
pooled_raw_ci = np.hstack(list(ci.values()))
ci_min, ci_max = pooled_raw_ci.min(), pooled_raw_ci.max()
for s in subjs:
    ci[s] = ((ci[s] - ci_min) / (ci_max - ci_min)) * 2 - 1
# Normalize values 0 to 1 (for plotting / colormap lookup)
pooled_scaled_ci = np.hstack(list(ci.values()))
max_val, min_val = pooled_scaled_ci.max(), pooled_scaled_ci.min()
normed_ci = {s: (ci[s] - min_val) / (max_val - min_val) for s in subjs}

cmap = LinearSegmentedColormap.from_list('my_gradient', (
    # Edit this gradient at https://eltos.github.io/gradient/#0:6991EF-20:6991EF-45:FFFFFF-50:FFFFFF-55:FFFFFF-80:D38043-100:D38043
    (0.000, (0.412, 0.569, 0.937)),
    (0.200, (0.412, 0.569, 0.937)),
    (0.450, (1.000, 1.000, 1.000)),
    (0.500, (1.000, 1.000, 1.000)),
    (0.550, (1.000, 1.000, 1.000)),
    (0.800, (0.827, 0.502, 0.263)),
    (1.000, (0.827, 0.502, 0.263))))


def _rgb_to_hex(rgba):
    """Return '#rrggbb' for an RGB(A) tuple of floats in [0, 1]; channels are truncated via int(c * 255).

    NOTE(review): stands in for `gkc.rgb_to_hex`, which is never imported in this notebook;
    assumes that helper produced the same '#rrggbb' string — confirm.
    """
    return '#{:02x}{:02x}{:02x}'.format(*(int(c * 255) for c in rgba[:3]))


elec_colors_rgb = dict()
elec_colors_hex = dict()
for s in subjs:
    elec_colors_rgb[s] = [cmap(f) for f in normed_ci[s]]  # RGBA tuples from the colormap
    elec_colors_hex[s] = np.array([_rgb_to_hex(ec) for ec in elec_colors_rgb[s]])

In [None]:
subj, elec = "S0019", "LG28"
x = epochs[subj]['el'].times
ch_names = epochs[subj]['el'].info['ch_names']
ch_idx = ch_names.index(elec)
# Load anat
if subj not in no_imaging:
    patient = img_pipe.freeCoG(f"{subj}_complete",hem='stereo',subj_dir=data_path)
    anat = patient.get_elecs()['anatomy']; anat_idx = [a[0][0] for a in anat].index(elec)
    fs_roi = anat[anat_idx][3][0]
else:
    fs_roi = "anatomy unknown"
r,g,b = elec_colors_rgb[s][ch_idx]; elec_ci = ci[s][ch_idx]
fig = plt.figure(figsize=(6,5)); gs = GridSpec(12, 10, figure=fig)
# plot erp
ax = fig.add_subplot(gs[:,:-1])
y2 = epochs[subj]['el'].get_data(picks=[elec]).squeeze(); y2a,y2b = sem(y2)
ax.plot(x,y2.mean(0),color=el_color); ax.fill_between(x,y2a,y2b,color=el_color,alpha=0.3)
y3 = epochs[subj]['sh'].get_data(picks=[elec]).squeeze(); y3a,y3b = sem(y3)
ax.plot(x,y3.mean(0),color=sh_color); ax.fill_between(x,y3a,y3b,color=sh_color,alpha=0.3)
ax.set_title(f"{subj} {elec}\n{fs_roi}", fontsize=14, fontweight='bold', color=(r,g,b))
ax.axvline(0,color='k')
xlims = [x[0], x[-1]]; ax.set_xlim(xlims); ax.set_ylim([-0.65,1])
xticks = np.round(np.arange(xlims[0],xlims[-1]+.5,.5),decimals=1)
ax.set_xticks(xticks); ax.set_xticklabels(xticks,fontsize=12)
ax = fig.add_subplot(gs[:,-1])
cbar = np.array([cmap(f)[:3] for f in np.linspace(0,1,100)])
for y,(cr,cg,cb) in enumerate(cbar):
    ax.fill_between([0,1],y,y+1,color=(cr,cg,cb))
ax.axhline(value_norm*100,color='k')
ax.set_xlim([0,1]); ax.set_ylim([0,100]); ax.axis('off'); gs.tight_layout(figure=fig);
plt.savefig(os.path.join(git_path,"figures","figure_4","pdf",f"{subj}_{elec}_ci.pdf"))

### Strip plot

In [None]:
# Get values: one row per analysed electrode with hemisphere, FreeSurfer ROI labels and CI.
excl_df = pd.read_csv(os.path.join(git_path, "analysis", "all_excluded_electrodes.csv"))
cortical_rois = ['frontal', 'temporal', 'parietal', 'occipital', 'precentral', 'postcentral', 'insula']
roi_rows = []
for s in tqdm(hems.keys()):
    if s in no_imaging:
        continue  # no reconstruction, hence no anatomical labels
    # ci[s] / normed_ci[s] follow the channel order of epochs[s]['el'], so index with that order
    # (instead of re-reading some block's .fif file).
    ci_index = {ch: i for i, ch in enumerate(epochs[s]['el'].info['ch_names'])}
    excl_ch_names = set(excl_df.loc[excl_df['subject'] == s]['channel'].values)
    pt = img_pipe.freeCoG(f"{s}_complete", hem='stereo', subj_dir=data_path)
    warped = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")
    # Channel names and labels come from 'anatomy'; 'elecmatrix' holds the coordinates.
    for xyz, anat_row in zip(warped['elecmatrix'], warped['anatomy']):
        ch, fs_label = anat_row[0][0], anat_row[3][0]
        # One hemisphere per electrode, using the x-sign rule used for `hems`, so electrodes of
        # bilateral subjects are not listed under both hemispheres.
        hem = 'rh' if xyz[0] > 0 else ('lh' if xyz[0] < 0 else None)
        if hem not in hems[s] or ch in excl_ch_names or ch not in ci_index:
            continue
        condensed_roi = imaging_utils.condense_roi(fs_label)
        gross_anat = imaging_utils.gross_anat(fs_label)
        # NOTE(review): this filter compares `condensed_roi` with lobe-level names, while the strip
        # plot's `roi_order` expects ~20 fine-grained ROIs — confirm it shouldn't test `gross_anat`.
        if condensed_roi == "whitematter":
            gross_anat = "subcort"
        elif condensed_roi not in cortical_rois:
            continue
        ci_idx = ci_index[ch]
        roi_rows.append({
            "subj": s, "ch_name": ch, "hem": hem, "fs_roi": fs_label,
            "condensed_roi": condensed_roi, "gross_anat": gross_anat,
            "ci_norm": normed_ci[s][ci_idx], "ci_native": ci[s][ci_idx]})
# Build the frame once (DataFrame.append was removed in pandas 2.0)
ci_by_roi = pd.DataFrame(roi_rows, columns=["subj", "ch_name", "hem", "fs_roi", "condensed_roi",
                                            "gross_anat", "ci_norm", "ci_native"])

In [None]:
# Per hemisphere, keep the condensed ROIs that hold more than `nroi_thresh` electrodes
all_rois = list(np.unique(ci_by_roi['condensed_roi']))
ignore_rois = ['wm', 'outside_brain']
nroi_thresh = 3


def _rois_over_thresh(hem_id):
    """ROIs of `all_rois` (sorted order) with > nroi_thresh rows in hemisphere `hem_id`, excluding ignore_rois."""
    n_per_roi = ci_by_roi.loc[ci_by_roi['hem'] == hem_id, 'condensed_roi'].value_counts()
    return [roi for roi in all_rois if roi not in ignore_rois and n_per_roi.get(roi, 0) > nroi_thresh]


plot_rois_lh = _rois_over_thresh('lh')
plot_rois_rh = _rois_over_thresh('rh')
print(f"{len(all_rois)} condensed FreeSurfer ROIs in database. Plotting {len(plot_rois_rh)} for RH and {len(plot_rois_lh)} for LH.")
# Rows of each hemisphere restricted to that hemisphere's kept ROIs
is_rh = ci_by_roi['hem'] == 'rh'
is_lh = ci_by_roi['hem'] == 'lh'
ci_by_roi_rh = ci_by_roi.loc[ci_by_roi['condensed_roi'].isin(plot_rois_rh) & is_rh]
ci_by_roi_lh = ci_by_roi.loc[ci_by_roi['condensed_roi'].isin(plot_rois_lh) & is_lh]

In [None]:
# ROIs kept in either hemisphere, rearranged into a fixed anatomical display order
plot_rois = list(np.unique(plot_rois_rh + plot_rois_lh))
# Indices into the alphabetically sorted `plot_rois` above
roi_order = [
    0,6,9,10,4, # temporal: HG/PT/STG/STS/MTG
    17,18,7,8,3,1,2,5, # frontal: PreCG/CS/SFG/SFS/MFG/IFG/IFS/OFC
    11,19,16, # parietal: angular,supramar, postCG
    12,14,15,13 # insular: ant/post/sup/inf
]
# The index list only fits this exact ROI set; fail loudly instead of silently mislabelling ROIs.
if sorted(roi_order) != list(range(len(plot_rois))):
    raise ValueError(f"roi_order has {len(roi_order)} entries but {len(plot_rois)} ROIs passed "
                     f"the count threshold: {plot_rois}")
plot_rois = [plot_rois[i] for i in roi_order]
ci_by_roi_condensed = ci_by_roi.loc[ci_by_roi['condensed_roi'].isin(plot_rois)]
# Redefined here so this cell does not depend on whichever `cmap` another cell left behind
cmap = LinearSegmentedColormap.from_list('my_gradient', (
    # Edit this gradient at https://eltos.github.io/gradient/#0:6991EF-20:6991EF-45:FFFFFF-50:FFFFFF-55:FFFFFF-80:D38043-100:D38043
    (0.000, (0.412, 0.569, 0.937)),
    (0.200, (0.412, 0.569, 0.937)),
    (0.450, (1.000, 1.000, 1.000)),
    (0.500, (1.000, 1.000, 1.000)),
    (0.550, (1.000, 1.000, 1.000)),
    (0.800, (0.827, 0.502, 0.263)),
    (1.000, (0.827, 0.502, 0.263))))
# Colour for each native CI value, looked up from its 0-1 normalised value
strip_palette = {ci_native: cmap(ci_norm)
                 for ci_native, ci_norm in zip(ci_by_roi['ci_native'], ci_by_roi['ci_norm'])}
# Plot
fig, ax = plt.subplots(figsize=(4, 12))
sns.stripplot(data=ci_by_roi_condensed, x='ci_native', y='condensed_roi', order=plot_rois,
              hue='ci_native', palette=strip_palette, linewidth=0.5, edgecolor='k', size=10, ax=ax)
legend = ax.get_legend()
if legend is not None:  # one legend entry per CI value is meaningless
    legend.remove()
ax.axvline(0, color='k', ls='--', lw=0.75)
ax.set_xlim([-1.1, 1.1])
# NOTE(review): CI values are min-max rescaled to [-1, 1], not z-scored — confirm the "(Z)" label.
ax.set_xlabel("Consistency index (Z)", fontsize=12); ax.set_ylabel("Condensed ROI", fontsize=12)
ax.set_yticks(np.arange(len(plot_rois))); ax.set_yticklabels(plot_rois);
fig_dir = os.path.join(git_path, "figures", "figure_4", "pdf")
os.makedirs(fig_dir, exist_ok=True)
fig.savefig(os.path.join(fig_dir, "ci_by_roi_scatter.pdf"))

### Single electrode plots including production and click responses

In [None]:
subj, elec = "TCH14", "LPIALG9"
if subj not in no_imaging:
    pt = img_pipe.freeCoG(f"{subj}_complete",hem="stereo",subj_dir=data_path)
    anat = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")['anatomy']
    fs_ch_names = [a[0][0] for a in anat]; fs_rois = [a[3][0] for a in anat]
    if elec in fs_ch_names:
        fs_idx = fs_ch_names.index(elec)
        fs_roi = fs_rois[fs_idx]
    else:
        fs_roi = "Anatomy unavailable"
else:
    fs_roi = "Anatomy unavailable"
fig = plt.figure(figsize=(5,5))
x = epochs[s]['mic'].times
y1 = epochs[s]['mic'].get_data(picks=[elec]).squeeze(); y1b,y1a = sem(y1)
plt.plot(x,y1.mean(0),color=mic_color); plt.fill_between(x,y1b,y1a,color=mic_color,alpha=0.3)
y2 = epochs[s]['el'].get_data(picks=[elec]).squeeze(); y2a,y2b = sem(y2)
plt.plot(x,y2.mean(0),color=el_color); plt.fill_between(x,y2a,y2b,color=el_color,alpha=0.3)
y3 = epochs[s]['sh'].get_data(picks=[elec]).squeeze(); y3a,y3b = sem(y3)
plt.plot(x,y3.mean(0),color=sh_color); plt.fill_between(x,y3a,y3b,color=sh_color,alpha=0.3)
y4 = epochs[s]['click'].get_data(picks=[elec]).squeeze(); y4a,y4b = sem(y4)
plt.plot(x,y4.mean(0),color=click_color); plt.fill_between(x,y4a,y4b,color=click_color,alpha=0.3)
plt.axvline(0,color='k'); plt.gca().set_xlim([-0.5,2]); plt.gca().set_ylim([-0.7,1.6])
plt.title(f"{subj} {elec}\n{fs_roi}");
plt.savefig(os.path.join(git_path,"figures","figure_4","pdf",f"{subj}_{elec}.pdf"))

### Colorbars

In [None]:
# 1-D legend: three-stop blue -> white -> yellow gradient.
# NOTE: the editor link below describes a 7-stop, grey-centred variant; the colormap actually
# built here uses only the three stops listed.
# Edit this gradient at https://eltos.github.io/gradient/#0:9ECBFF-25:9ECBFF-40:B2B2B2-50:B2B2B2-60:B2B2B2-75:DDCB76-100:DDCB76
gradient_stops = (
    (0.000, (0.620, 0.796, 1.000)),
    (0.500, (1.000, 1.000, 1.000)),
    (1.000, (0.867, 0.796, 0.463)),
)
cmap = LinearSegmentedColormap.from_list('my_gradient', gradient_stops)
# Image of 3 identical rows holding 0..99, so each column samples the colormap once
legend_img = np.tile(np.arange(100), (3, 1))
plt.figure(figsize=(12, 2))
plt.imshow(legend_img, aspect='auto', cmap=cmap)
plt.axis('off');
plt.savefig(os.path.join(git_path, "figures", "figure_4", "pdf", "legend_1d.pdf"))

In [None]:
# 2-D legend: load the colormap image stored with figure 5 and re-save it as a figure 4 PDF
legend_png = os.path.join(git_path, "figures", "figure_5", "msh_elsh_splinesqrt_2d_bu_or.png")
cmap_2d = plt.imread(legend_png)
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(cmap_2d)
ax.grid(False)
ax.set_xlabel("RGB 0-255")
ax.set_ylabel("RGB 0-255");
fig.savefig(os.path.join(git_path, "figures", "figure_4", "pdf", "legend_2d.pdf"))