In [None]:
# Local paths: point these at your clone of the code repository and at the
# root of the BIDS dataset before running anything else in this notebook.
git_path, data_path = (
    '/path/to/git/kurteff2024_code/',
    '/path/to/bids/dataset/',
)

In [None]:
import csv
import os
import re
import warnings

import h5py
import librosa
import matplotlib.patheffects as PathEffects
import mne
import numpy as np
import pandas as pd
import pymf3
import seaborn as sns
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
from tqdm.notebook import tqdm

from img_pipe import img_pipe
rc['pdf.fonttype'] = 42
plt.style.use('seaborn')
%matplotlib inline

import sys
sys.path.append(os.path.join(git_path,"figures"))
import plotting_utils
sys.path.append(os.path.join(git_path,"analysis","mtrf"))
import mtrf_utils
sys.path.append(os.path.join(git_path,"preprocessing","events","textgrids"))
import textgrid

In [None]:
# Subjects, recording blocks, hemispheres and plotting/model constants shared by all cells.
# Subject and block folders are read from the same events-CSV tree so the two lists agree.
events_csv_dir = os.path.join(git_path,"preprocessing","events","csv")
# Sort directory listings: os.listdir order is arbitrary, and blocks[s][0] is used
# later as "the" block to read channel names from, so it must be reproducible.
all_subjs = sorted(s for s in os.listdir(events_csv_dir) if "TCH" in s or "S0" in s)
exclude = ["TCH8"]
no_imaging = ["S0010"]  # subjects without an electrode reconstruction
# NOTE(review): other cells use zero-padded IDs (e.g. "TCH06"); warn if an exclusion
# matches no subject so a mistyped ID cannot silently keep a subject in the analysis.
unmatched_exclusions = set(exclude) - set(all_subjs)
if unmatched_exclusions:
    warnings.warn(f"Excluded subject IDs not found on disk: {sorted(unmatched_exclusions)}")
subjs = [s for s in all_subjs if s not in exclude]

# Blocks per subject: folders named "<subj>_B*" that contain a spkr_sn_all event file.
blocks = {
    s: [
        b.split("_")[-1] for b in sorted(os.listdir(os.path.join(events_csv_dir,s)))
        if f"{s}_B" in b and os.path.isfile(os.path.join(events_csv_dir,s,b,f"{b}_spkr_sn_all.txt"))
    ] for s in subjs
}

# Hemispheres with at least one electrode, from the sign of the electrode x coordinate
# (x > 0 -> 'rh', x < 0 -> 'lh'). Subjects without imaging keep an empty list.
hems = {s:[] for s in subjs}
for s in subjs:
    if s in no_imaging:
        continue  # no reconstruction to read electrode coordinates from
    pt = img_pipe.freeCoG(f"{s}_complete",hem='stereo',subj_dir=data_path)
    elecs = pt.get_elecs()['elecmatrix']
    if np.any(elecs[:,0] > 0):
        hems[s].append('rh')
    if np.any(elecs[:,0] < 0):
        hems[s].append('lh')

color_palette = pd.read_csv(os.path.join(git_path,"figures","color_palette.csv"))
spkr_color = color_palette.loc[color_palette['color_id']=='perception']['hex'].values[0]
mic_color = color_palette.loc[color_palette['color_id']=='production']['hex'].values[0]

models = ['model1','model2','model3','model4']
features = {m : mtrf_utils.get_feats(m, mode="ecog") for m in models}
results = pd.read_csv(os.path.join(git_path,"analysis","mtrf","results.csv"))

### Load mTRF results

In [None]:
# Per-subject mTRF results: r and p per channel come from results.csv, and weights come
# from the per-subject HDF5 file. `mtrf` is keyed by channel name. The arrays in
# `ch_names`, `wts`, `corrs` and `pvals` are indexed in .fif channel order, which is
# how the figure cells below index them (e.g. `corrs[s][m][fif_idx]`).
mtrf = dict()
ch_names, wts, corrs, pvals = dict(), dict(), dict(), dict()
for s in tqdm(subjs):
    subj_results = results.loc[results['subject']==s]
    mtrf[s] = dict()
    mtrf[s]['ch_names'] = subj_results.loc[subj_results['model']=='model1']['channel'].values
    blockid = "_".join([s,blocks[s][0]])
    mtrf[s]['fif_ch_names'] = mne.io.read_raw_fif(os.path.join(data_path,f"sub-{s}",blockid,
        "HilbAA_70to150_8band", "ecog_hilbAA70to150.fif"), preload=False, verbose=False).info['ch_names']
    ch_names[s] = mtrf[s]['fif_ch_names']
    n_fif = len(ch_names[s])
    # Read every model's weights in one pass instead of reopening the file per model.
    with h5py.File(os.path.join(git_path,"analysis","mtrf","h5","weights",f"{s}_weights.hdf5"),'r') as f:
        subj_wts_by_model = {m: np.array(f.get(m)) for m in models}
    wts[s], corrs[s], pvals[s] = dict(), dict(), dict()
    for m in models:
        subj_wts = subj_wts_by_model[m]
        # One lookup table per model. Keep the first row per channel, which matches
        # the previous `.values[0]` lookup.
        model_results = (subj_results.loc[subj_results['model']==m]
                         .drop_duplicates('channel').set_index('channel'))
        mtrf[s][m] = {'r': dict(), 'p': dict(), 'w': dict()}
        # assumes weights are (delay, feature, fif channel); the last axis is indexed by fif position
        wts[s][m] = subj_wts
        # NaN marks .fif channels with no row in results.csv (e.g. excluded channels).
        corrs[s][m], pvals[s][m] = np.full(n_fif, np.nan), np.full(n_fif, np.nan)
        for ch in mtrf[s]['ch_names']:
            r_val, p_val = model_results.loc[ch, ['r_value','p_value']]
            ch_idx = ch_names[s].index(ch)
            mtrf[s][m]['r'][ch], mtrf[s][m]['p'][ch] = r_val, p_val
            mtrf[s][m]['w'][ch] = subj_wts[:,:,ch_idx].T
            corrs[s][m][ch_idx], pvals[s][m][ch_idx] = r_val, p_val

# Extract significant weights and correlations: keep channels with p < SIG_P and zero the rest.
SIG_P = 0.01
sig_wts, sig_corrs = dict(), dict()
for s in tqdm(subjs):
    sig_wts[s], sig_corrs[s] = dict(), dict()
    for m in models:
        # NaN p-values compare False, so channels without results are never kept.
        sig_mask = pvals[s][m] < SIG_P
        sig_wts[s][m] = np.zeros(wts[s][m].shape)
        sig_wts[s][m][:,:,sig_mask] = wts[s][m][:,:,sig_mask]
        sig_corrs[s][m] = np.zeros(wts[s][m].shape[2])
        sig_corrs[s][m][sig_mask] = corrs[s][m][sig_mask]

### Regression schematic (example trial)

In [None]:
# Example trial shown in the regression schematic (update as needed).
subj, block, elec = "S0006", "B2", "RPPST13"
blockid = f"{subj}_{block}"
tg_dir = os.path.join(git_path, "preprocessing", "events", "textgrids", subj, blockid)
mic_grid = os.path.join(tg_dir, f"{blockid}_mic_sentence.TextGrid")
spkr_grid = os.path.join(tg_dir, f"{blockid}_spkr_sentence.TextGrid")
mic_ph_grid = os.path.join(tg_dir, f"{blockid}_mic.TextGrid")
spkr_ph_grid = os.path.join(tg_dir, f"{blockid}_spkr.TextGrid")

def _read_textgrid(path):
    """Parse one TextGrid file with the project's `textgrid` module."""
    with open(path, 'r') as handle:
        return textgrid.TextGrid(handle.read())

def _labelled_rows(tier, skip):
    """Array of a tier's simple_transcript rows whose label (row[2]) is not in `skip`."""
    return np.array([row for row in tier.simple_transcript if row[2] not in skip])

def _rows_in_window(rows):
    """Indices of rows whose start/stop (row[0], row[1]) fall inside [mic_tmin, spkr_tmax]."""
    return [idx for idx, row in enumerate(rows)
            if float(row[0]) >= mic_tmin and float(row[1]) <= spkr_tmax]

# Sentence intervals. The second kept sentence of each channel defines the time windows.
mic_sentences = _labelled_rows(_read_textgrid(mic_grid).tiers[0], ['','sp'])
nsents = np.ceil(mic_sentences.shape[0]/2).astype(int)
mic_sentences = mic_sentences[1:nsents]
spkr_sentences = _labelled_rows(_read_textgrid(spkr_grid).tiers[0], ['','sp'])[:nsents-1]
mic_tmin, mic_tmax = mic_sentences[1,:2].astype(float)
spkr_tmin, spkr_tmax = spkr_sentences[1,:2].astype(float)
sen_trs = spkr_sentences[1,2]

# Audio at native sample rate; each clip uses the rate returned by its own load.
audio_folder = os.path.join(data_path, f"sub-{subj}", blockid, "Audio")
mic_audio, wav_fs = librosa.load(os.path.join(audio_folder, f"{blockid}_mic.wav"), sr=None)
mic_clip = mic_audio[int(wav_fs*mic_tmin):int(wav_fs*mic_tmax)]
spkr_audio, wav_fs = librosa.load(os.path.join(audio_folder, f"{blockid}_spkr.wav"), sr=None)
spkr_clip = spkr_audio[int(wav_fs*spkr_tmin):int(wav_fs*spkr_tmax)]

# Neural response of the chosen electrode over each sentence window.
raw = mne.io.read_raw_fif(os.path.join(data_path, f"sub-{subj}", blockid, "HilbAA_70to150_8band",
                                       "ecog_hilbAA70to150.fif"), preload=True, verbose=False)
raw_fs = raw.info['sfreq']
elec_trace = raw.get_data(picks=[elec]).squeeze()
mic_tmin_samp, mic_tmax_samp = int(mic_tmin*raw_fs), int(mic_tmax*raw_fs)
spkr_tmin_samp, spkr_tmax_samp = int(spkr_tmin*raw_fs), int(spkr_tmax*raw_fs)
mic_resp = elec_trace[mic_tmin_samp:mic_tmax_samp]
spkr_resp = elec_trace[spkr_tmin_samp:spkr_tmax_samp]

# Word and phone alignments. Note that the tier order differs between the mic and spkr files.
tg = _read_textgrid(mic_ph_grid)
mic_words = _labelled_rows(tg.tiers[1], ['','sp','{NS}'])
mic_wr_idxs = _rows_in_window(mic_words)
mic_phones = _labelled_rows(tg.tiers[0], ['','sp','ns'])
mic_ph_idxs = _rows_in_window(mic_phones)
tg = _read_textgrid(spkr_ph_grid)
spkr_words = _labelled_rows(tg.tiers[0], ['','sp','{NS}'])
spkr_wr_idxs = _rows_in_window(spkr_words)
spkr_phones = _labelled_rows(tg.tiers[1], ['','sp','ns'])
spkr_ph_idxs = _rows_in_window(spkr_phones)

In [None]:
# Phonological feature -> phone labels. The first len(features_dict) rows of the
# schematic feature matrix follow this key order.
features_dict = {
                'dorsal': ['y','w','k','kcl', 'g','gcl','eng','ng'],
                'coronal': ['ch','jh','sh','zh','s','z','t','tcl','d','dcl','n','th','dh','l','r'],
                'labial': ['f','v','p','pcl','b','bcl','m','em','w'],
                'high': ['uh','ux','uw','iy','ih','ix','ey','eh','oy'],
                'front': ['iy','ih','ix','ey','eh','ae','ay'],
                'low': ['aa','ao','ah','ax','ae','aw','ay','axr','ow','oy'],
                'back': ['aa','ao','ow','ah','ax','ax-h','uh','ux','uw','axr','aw'],
                'plosive': ['p','pcl','t','tcl','k','kcl','b','bcl','d','dcl','g','gcl','q'],
                'fricative': ['f','v','th','dh','s','sh','z','zh','hh','hv','ch','jh'],
                'syllabic': ['aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay','eh','ey','ih', 'ix', 'iy','ow', 'oy','uh', 'uw', 'ux'],
                'nasal': ['m','em','n','en','ng','eng','nx'],
                'voiced':   ['aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay','eh','ey','ih', 'ix', 'iy','ow', 'oy','uh', 'uw', 'ux','w','y','el','l','r','dh','z','v','b','bcl','d','dcl','g','gcl','m','em','n','en','eng','ng','nx','q','jh','zh'],
                'obstruent': ['b', 'bcl', 'ch', 'd', 'dcl', 'dh', 'dx','f', 'g', 'gcl', 'hh', 'hv','jh', 'k', 'kcl', 'p', 'pcl', 'q', 's', 'sh','t', 'tcl', 'th','v','z', 'zh','q'],
                'sonorant': ['aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay','eh','ey','ih', 'ix', 'iy','ow', 'oy','uh', 'uw', 'ux','w','y','el','l','r','m', 'n', 'ng', 'eng', 'nx','en','em'],
        }

def _clean_label(label):
    """Drop digits from an alignment label (e.g. stress markers) and lowercase it."""
    # Raw string: "\d" in a normal string literal is an invalid escape sequence.
    return re.sub(r"\d+", "", label).lower()

def _to_sample(t):
    """Convert an absolute time in seconds to an audio-sample index relative to mic_tmin."""
    return int((float(t)-mic_tmin)*wav_fs)

# Binary design matrix (features x audio samples) spanning mic_tmin..spkr_tmax.
n_samples = int(wav_fs*(spkr_tmax-mic_tmin))
feats = np.zeros((len(features['model1']), n_samples))
for ph_idx in mic_ph_idxs:
    onset, offset = _to_sample(mic_phones[ph_idx,0]), _to_sample(mic_phones[ph_idx,1])
    label = _clean_label(mic_phones[ph_idx,2])
    for feat_idx, feature in enumerate(features_dict):
        if label in features_dict[feature]:
            feats[feat_idx,onset:offset] = 1
    # NOTE(review): row 15 is set for every mic phone; presumably the production indicator row.
    feats[15,onset:offset] = 1
for ph_idx in spkr_ph_idxs:
    onset, offset = _to_sample(spkr_phones[ph_idx,0]), _to_sample(spkr_phones[ph_idx,1])
    label = _clean_label(spkr_phones[ph_idx,2])
    for feat_idx, feature in enumerate(features_dict):
        if label in features_dict[feature]:
            feats[feat_idx,onset:offset] = 1
    # NOTE(review): rows 14 and 16 are set for every speaker phone; confirm against features['model1'].
    feats[14,onset:offset] = 1; feats[16,onset:offset] = 1

fig = plt.figure(figsize=(16,7))
# Bottom panel: feature matrix with phone labels and a boundary line at each phone onset.
ax_feat = fig.add_subplot(2,1,2)
ax_feat.imshow(feats,aspect='auto',interpolation='nearest')
ax_feat.set_xlim([0,n_samples])
for phones, ph_idxs, color in [(mic_phones, mic_ph_idxs, mic_color),
                               (spkr_phones, spkr_ph_idxs, spkr_color)]:
    for ii, ph_idx in enumerate(ph_idxs):
        onset = _to_sample(phones[ph_idx,0])
        # Alternate label rows (ii % 2) so adjacent phone labels do not overlap.
        ax_feat.text(onset,-1-(ii%2),_clean_label(phones[ph_idx,2]),ha='left',va='center',
                     color=color,fontweight='bold')
        ax_feat.axvline(onset,color='k',lw=0.5)
ax_feat.set_facecolor('white'); ax_feat.grid(False)
ax_feat.set_yticks(np.arange(len(features['model1']))); ax_feat.set_yticklabels(features['model1'])
ax_feat.set_ylabel("Feature")
ax_feat.set_xticklabels([])

# Top panel: mic and speaker waveforms over the same window, each scaled by its max.
ax_wav = fig.add_subplot(2,1,1)
mic_clip = mic_audio.copy(); mic_clip[int(wav_fs*mic_tmax):] = 0
mic_clip = mic_clip[int(wav_fs*mic_tmin):int(wav_fs*spkr_tmax)]
# NOTE(review): the 27 s mute start is hard-coded for this example trial; update with the trial.
spkr_clip = spkr_audio.copy(); spkr_clip[int(wav_fs*27):int(wav_fs*mic_tmax)] = 0
spkr_clip = spkr_clip[int(wav_fs*mic_tmin):int(wav_fs*spkr_tmax)]
ax_wav.plot(mic_clip/mic_clip.max(), color=mic_color); ax_wav.plot(spkr_clip/spkr_clip.max(), color=spkr_color)
ax_wav.set_xlim([0,n_samples]); ax_wav.set_ylim([-.9,1.35])
for words, wr_idxs, color in [(spkr_words, spkr_wr_idxs, spkr_color),
                              (mic_words, mic_wr_idxs, mic_color)]:
    for ii, wr_idx in enumerate(wr_idxs):
        ax_wav.text(_to_sample(words[wr_idx,0]),1.05+(ii%2)*.2,_clean_label(words[wr_idx,2]),
                    ha='left',va='center',color=color,fontweight='bold')
ax_wav.axis('off'); fig.tight_layout()
fig.savefig(os.path.join(git_path,"figures","figure_5","pdf","regression_schematic.pdf"))

### Bar plot of model correlation (r) by anatomical region

In [None]:
# Build `rdf`: one row per (responsive electrode, model) with anatomy labels, r and p.
# An electrode counts as responsive if model1 r > .1, or if its bootstrap
# speaker/mic p-value is < .05.
pvals_df = pd.read_csv(os.path.join(git_path,"stats","bootstraps","csv",
                                    "seeg_elec_significance_16_subjs_1000_boots.csv"))
excl_df = pd.read_csv(os.path.join(git_path,"analysis","all_excluded_electrodes.csv"))
excl_ch_names = {s: list(excl_df.loc[excl_df['subject']==s]["channel"].values) for s in subjs}
rdf_cols = ['subj','ch_name','fif_idx','fs_idx','fs_roi','condensed_roi','gross_anat','hem','model','r','p']
cortical_lobes = ['frontal', 'temporal', 'parietal', 'occipital', 'precentral', 'postcentral', 'insula']

def _is_responsive(s, ch, fif_idx):
    """True if model1 r > .1 or any bootstrap spkr/mic p-value for this electrode is < .05."""
    elec_p = pvals_df.loc[(pvals_df['subj']==s)&(pvals_df['ch_name']==ch)]
    # .any() avoids truth-testing a possibly empty or multi-element array.
    return bool(corrs[s]['model1'][fif_idx] > .1
                or (elec_p['spkr_p'].values < 0.05).any()
                or (elec_p['mic_p'].values < 0.05).any())

rdf_records = []
for hem in ['lh','rh']:
    for s in [ss for ss in hems.keys() if hem in hems[ss]]:
        blockid = "_".join([s,blocks[s][0]])
        fif_ch_names = mne.io.read_raw_fif(os.path.join(data_path,f"sub-{s}",blockid,"HilbAA_70to150_8band",
            "ecog_hilbAA70to150.fif"), preload=False, verbose=False).info['ch_names']
        pt = img_pipe.freeCoG(f"{s}_complete",hem=hem,subj_dir=data_path)
        anat = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")['anatomy']
        fs_ch_names = [a[0][0].replace("-","") for a in anat]
        if s == "S0020":
            # Align this probe's name in the imaging file with the name used in the .fif.
            fs_ch_names = [c.replace("APIOF","AIPOF") for c in fs_ch_names]
        # Anatomy and responsiveness do not depend on the model, so resolve them once per channel.
        kept = []
        for elecfile_idx, ch in enumerate(fs_ch_names):
            if ch in excl_ch_names[s] or ch not in fif_ch_names:
                continue
            fif_idx = fif_ch_names.index(ch)
            if not _is_responsive(s, ch, fif_idx):
                continue
            fs_roi = anat[elecfile_idx][3][0]
            condensed_roi = imaging_utils.condense_roi(fs_roi)
            gross_anat = imaging_utils.gross_anat(fs_roi)
            # Lobe membership is a gross-anatomy property. condensed_roi holds the finer
            # labels (e.g. 'STG') that the plot_rois mapping below matches against.
            if gross_anat in cortical_lobes:
                pass
            elif condensed_roi == 'whitematter':
                gross_anat = "subcort"
            else:
                continue
            kept.append((ch, fif_idx, elecfile_idx, fs_roi, condensed_roi, gross_anat))
        for m in models:
            for ch, fif_idx, elecfile_idx, fs_roi, condensed_roi, gross_anat in kept:
                rdf_records.append({'subj':s,'ch_name':ch,'fif_idx':fif_idx,'fs_idx':elecfile_idx,
                    'fs_roi':fs_roi,'condensed_roi':condensed_roi,'gross_anat':gross_anat,'hem':hem,
                    'model':m,'r':corrs[s][m][fif_idx],'p':pvals[s][m][fif_idx]})
# Build once from records (DataFrame.append was removed in pandas 2.0).
rdf = pd.DataFrame(rdf_records, columns=rdf_cols)

In [None]:
# Plot ROI -> list of condensed_roi labels pooled into it (dict order = bar order).
plot_rois = {'heschls' : ['HG'],'planum_temporale' : ['PT'],'STG' : ['STG'],'STS' : ['STS'],
    'middle_temporal' : ['MTG'],'temporal_pole' : ['temp_pole'],'subcentral' : ['subcentral'],
    'precentral_gyrus' : ['preCG'],'precentral_sulcus' : ['preCS'],'superior_frontal' : ['SFG', 'SFS'],
    'middle_frontal' : ['MFG', 'MFS'],'inferior_frontal' : ['IFG', 'IFS', 'front_operculum'],
    'orbitofrontal' : ['OFC'],'superiorparietal' : ['SPL'],'angular' : ['angular'],
    'supramarginal' : ['supramar'],'postcentral_gyrus' : ['postCG'],'anterior_insula' : ['insula_ant'],
    'inferior_insula' : ['insula_inf'],'posterior_insula'  : ['insula_post'],'superior_insula' : ['insula_sup']}

# One row per rdf row whose condensed_roi belongs to a plot ROI, ordered by plot_rois,
# then by ROI label, then by rdf order. Built from records because DataFrame.append
# was removed in pandas 2.0; rdf rows matching no plot ROI are dropped.
rdf_condensed_cols = ['subj','ch_name','plot_roi','hem','model','r','p']
rdf_condensed = pd.DataFrame(
    [{'subj':row['subj'], 'ch_name':row['ch_name'], 'plot_roi':plot_roi, 'hem':row['hem'],
      'model':row['model'], 'r':row['r'], 'p':row['p']}
     for plot_roi, rois in plot_rois.items()
     for roi in rois
     for _, row in rdf.loc[rdf['condensed_roi']==roi].iterrows()],
    columns=rdf_condensed_cols)

In [None]:
# Gross-anatomy colour scheme (hex).
cmap = {'frontal' : '#4b72db', # blue
        'temporal' : '#84e16c', # green
        'parietal' : '#9c53b7', # purple
        'occipital' : '#d0c358', # yellow
        'precentral' : '#6ec2d5', # cyan
        'postcentral' : '#cb4779', # magenta
        'insula' : '#e14333', # red
        'subcort' : '#b0b0b0', # light grey
        'whitematter' : '#5a5a5a', # dark grey
        'outside_brain' : '#262626'} # black
# One bar colour per plot_rois entry, in plot_rois order:
# 6 temporal, 3 precentral, 4 frontal, 3 parietal, 1 postcentral, 4 insula.
palette = ([cmap['temporal']]*6)+([cmap['precentral']]*3)+([cmap['frontal']]*4)+([cmap['parietal']]*3)+([
    cmap['postcentral']])+([cmap['insula']]*4)
# The palette is positional, so a change to plot_rois must be mirrored here.
assert len(palette) == len(plot_rois), "palette needs exactly one colour per plot ROI"
fig, ax = plt.subplots(figsize=(4,12))
sns.barplot(data=rdf_condensed.loc[rdf_condensed['model']=='model1'], x='r', y='plot_roi',
            order=list(plot_rois.keys()), palette=palette, ax=ax)
ax.set_ylabel("Condensed ROI", fontsize=12); ax.set_xlabel("Linear correlation coefficient")
ax.set_yticks(np.arange(len(plot_rois)))
ax.set_yticklabels(list(plot_rois.keys()))
fig.savefig(os.path.join(git_path,"figures","figure_5","pdf","corr_by_roi_bar.pdf"))

### Scatter plot of r values: full model vs. expanded/ablated models

In [None]:
# Per-channel correlation of each expanded/ablated model (y) against the full model
# (x, model1), pooled over subjects. Points above the identity line fit better than model1.
plt.figure(figsize=(5,5)); corr_min,corr_max = -.3,.7
ycolors = {'model2' : '#c01786','model3' : '#5fd7b9','model4' : '#dda34a'}
xm = 'model1'; xcorrs = np.hstack([corrs[s][xm] for s in subjs])
yms = list(ycolors.keys())
plt.xlabel('Model correlation\n Full model'); plt.ylabel('Model correlation\n Expanded/ablated model')
for i,ym in enumerate(yms):
    ycorrs = np.hstack([corrs[s][ym] for s in subjs])
    # Points are paired by position, so both arrays must list the same channels.
    assert ycorrs.shape == xcorrs.shape, f"{ym} and {xm} correlation arrays differ in length"
    # Label each series so the legend identifies which colour is which model.
    plt.scatter(xcorrs,ycorrs,s=5,color=ycolors[ym],zorder=3-i,label=ym)
plt.plot([corr_min,corr_max],[corr_min,corr_max],color='k',alpha=0.5)  # identity line
plt.gca().set_xlim([corr_min,corr_max]); plt.gca().set_ylim([corr_min,corr_max])
plt.legend(loc='upper left', markerscale=3);

### Single-electrode receptive fields (TRF weights per model)

In [None]:
model = "model1"; subj, elec = "TCH06", "RTG5" # Update accordingly
# Anatomical label for the title, when this subject has an imaging reconstruction.
fs_roi = "Anatomy unavailable"
if subj not in no_imaging:
    pt = img_pipe.freeCoG(f"{subj}_complete", hem="stereo", subj_dir=data_path)
    anat = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")['anatomy']
    fs_ch_names = [a[0][0].replace("-","") for a in anat]; fs_labels = [a[3][0] for a in anat]
    if elec.replace("-","") in fs_ch_names:
        fs_idx = fs_ch_names.index(elec.replace("-",""))
        fs_roi = imaging_utils.condense_roi(fs_labels[fs_idx])
ch_idx = ch_names[subj].index(elec); trf = wts[subj][model][:,:,ch_idx].T
# Symmetric limits from the largest |weight|: with trf.max() alone, strong negative
# weights would be clipped, and an all-negative TRF would give vmin > vmax.
vmax = np.abs(trf).max()
plt.figure(figsize=(6,6))
plt.imshow(trf, aspect="auto", interpolation="nearest", cmap=cm.RdBu_r, vmin=-vmax, vmax=vmax); plt.grid(False)
plt.gca().set_yticks(np.arange(len(features[model]))); plt.gca().set_yticklabels(features[model])
# Dashed lines separate the feature groups (rows 0-13, 14-15, 16+).
plt.axhline(13.5,color='k',ls='--'); plt.axhline(15.5,color='k',ls='--')
plt.axvline(np.where(delays==0)[0][0],color='k')  # zero-lag column
plt.gca().set_xticks([0,np.where(delays==0)[0][0],delays.shape[0]])
plt.gca().set_xticklabels([delay_min,0,delay_max])
plt.title(f"{subj} {elec}\n{fs_roi}\nr=%.2f"%(corrs[subj][model][ch_idx])); plt.colorbar();
plt.savefig(os.path.join(git_path,"figures","figure_5","pdf","single_trfs",f"{subj}_{elec}_{model}.pdf"))

In [None]:
model = "model2"; subj, elec = "TCH06", "RTG5" # Update accordingly
# Anatomical label for the title, when this subject has an imaging reconstruction.
fs_roi = "Anatomy unavailable"
if subj not in no_imaging:
    pt = img_pipe.freeCoG(f"{subj}_complete", hem="stereo", subj_dir=data_path)
    anat = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")['anatomy']
    fs_ch_names = [a[0][0].replace("-","") for a in anat]; fs_labels = [a[3][0] for a in anat]
    if elec.replace("-","") in fs_ch_names:
        fs_idx = fs_ch_names.index(elec.replace("-",""))
        fs_roi = imaging_utils.condense_roi(fs_labels[fs_idx])
ch_idx = ch_names[subj].index(elec); trf = wts[subj][model][:,:,ch_idx].T
# Symmetric limits from the largest |weight|: with trf.max() alone, strong negative
# weights would be clipped, and an all-negative TRF would give vmin > vmax.
vmax = np.abs(trf).max()
plt.figure(figsize=(6,12))
plt.imshow(trf, aspect="auto", interpolation="nearest", cmap=cm.RdBu_r, vmin=-vmax, vmax=vmax); plt.grid(False)
plt.gca().set_yticks(np.arange(len(features[model]))); plt.gca().set_yticklabels(features[model])
# Dashed lines separate the feature groups of the expanded model.
plt.axhline(13.5,color='k',ls='--'); plt.axhline(27.5,color='k',ls='--')
plt.axhline(41.5,color='k',ls='--'); plt.axhline(43.5,color='k',ls='--')
plt.axvline(np.where(delays==0)[0][0],color='k')  # zero-lag column
plt.gca().set_xticks([0,np.where(delays==0)[0][0],delays.shape[0]])
plt.gca().set_xticklabels([delay_min,0,delay_max])
plt.title(f"{subj} {elec}\n{fs_roi}\nr=%.2f"%(corrs[subj][model][ch_idx])); plt.colorbar();
plt.savefig(os.path.join(git_path,"figures","figure_5","pdf","single_trfs",f"{subj}_{elec}_{model}.pdf"))

In [None]:
model = "model3"; subj, elec = "TCH06", "RTG5" # Update accordingly
# Anatomical label for the title, when this subject has an imaging reconstruction.
fs_roi = "Anatomy unavailable"
if subj not in no_imaging:
    pt = img_pipe.freeCoG(f"{subj}_complete", hem="stereo", subj_dir=data_path)
    anat = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")['anatomy']
    fs_ch_names = [a[0][0].replace("-","") for a in anat]; fs_labels = [a[3][0] for a in anat]
    if elec.replace("-","") in fs_ch_names:
        fs_idx = fs_ch_names.index(elec.replace("-",""))
        fs_roi = imaging_utils.condense_roi(fs_labels[fs_idx])
ch_idx = ch_names[subj].index(elec); trf = wts[subj][model][:,:,ch_idx].T
# Symmetric limits from the largest |weight|: with trf.max() alone, strong negative
# weights would be clipped, and an all-negative TRF would give vmin > vmax.
vmax = np.abs(trf).max()
# Pad two empty rows at the bottom; they are greyed out below to mark the ablated features.
trf_exp = np.zeros((trf.shape[0]+2, trf.shape[1])); trf_exp[:trf.shape[0],:] = trf
plt.figure(figsize=(6,6))
plt.imshow(trf_exp, aspect="auto", interpolation="nearest", cmap=cm.RdBu_r, vmin=-vmax, vmax=vmax)
plt.grid(False)
plt.gca().set_yticks(np.arange(len(features[model]))); plt.gca().set_yticklabels(features[model])
plt.axhline(13.5,color='k',ls='--'); plt.axhspan(15.5,17.5,color='#262626')
plt.axvline(np.where(delays==0)[0][0],color='k')  # zero-lag column
plt.gca().set_xticks([0,np.where(delays==0)[0][0],delays.shape[0]])
plt.gca().set_xticklabels([delay_min,0,delay_max])
plt.title(f"{subj} {elec}\n{fs_roi}\nr=%.2f"%(corrs[subj][model][ch_idx])); plt.colorbar();
plt.savefig(os.path.join(git_path,"figures","figure_5","pdf","single_trfs",f"{subj}_{elec}_{model}.pdf"))

In [None]:
model = "model4"; subj, elec = "TCH06", "RTG5" # Update accordingly
# Anatomical label for the title, when this subject has an imaging reconstruction.
fs_roi = "Anatomy unavailable"
if subj not in no_imaging:
    pt = img_pipe.freeCoG(f"{subj}_complete", hem="stereo", subj_dir=data_path)
    anat = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")['anatomy']
    fs_ch_names = [a[0][0].replace("-","") for a in anat]; fs_labels = [a[3][0] for a in anat]
    if elec.replace("-","") in fs_ch_names:
        fs_idx = fs_ch_names.index(elec.replace("-",""))
        fs_roi = imaging_utils.condense_roi(fs_labels[fs_idx])
ch_idx = ch_names[subj].index(elec); trf = wts[subj][model][:,:,ch_idx].T
# Symmetric limits from the largest |weight|: with trf.max() alone, strong negative
# weights would be clipped, and an all-negative TRF would give vmin > vmax.
vmax = np.abs(trf).max()
# Re-insert two empty rows (14-15) between the first 14 rows and the last 2 rows; they
# are greyed out below to mark the ablated features.
trf_exp = np.zeros((trf.shape[0]+2, trf.shape[1])); trf_exp[:14,:] = trf[:14,:]; trf_exp[-2:,:] = trf[-2:,:]
plt.figure(figsize=(6,6))
plt.imshow(trf_exp, aspect="auto", interpolation="nearest", cmap=cm.RdBu_r, vmin=-vmax, vmax=vmax)
plt.grid(False)
plt.gca().set_yticks(np.arange(len(features[model]))); plt.gca().set_yticklabels(features[model])
plt.axhspan(13.5,15.5,color='#262626')
plt.axvline(np.where(delays==0)[0][0],color='k')  # zero-lag column
plt.gca().set_xticks([0,np.where(delays==0)[0][0],delays.shape[0]])
plt.gca().set_xticklabels([delay_min,0,delay_max])
plt.title(f"{subj} {elec}\n{fs_roi}\nr=%.2f"%(corrs[subj][model][ch_idx])); plt.colorbar();
plt.savefig(os.path.join(git_path,"figures","figure_5","pdf","single_trfs",f"{subj}_{elec}_{model}.pdf"))

### Standalone colorbar legends

In [None]:
# Sequential (Reds) legend: a 3-row strip whose columns run 0..99 across the colormap.
gradient = np.tile(np.arange(100), (3, 1))
cmap = cm.Reds
plt.figure(figsize=(12,2))
plt.imshow(gradient, aspect='auto', cmap=cmap)
plt.axis('off')
plt.savefig(os.path.join(git_path, "figures", "figure_5", "pdf", "legend_1d.pdf"))

In [None]:
# Diverging (RdBu_r) legend: a 3-row strip whose columns run 0..99 across the colormap.
gradient = np.tile(np.arange(100), (3, 1))
cmap = cm.RdBu_r
plt.figure(figsize=(12,2))
plt.imshow(gradient, aspect='auto', cmap=cmap)
plt.axis('off')
plt.savefig(os.path.join(git_path, "figures", "figure_5", "pdf", "legend_2d.pdf"))