In [None]:
# Local paths -- edit these for your machine before running the notebook.
#   git_path:  checkout of the kurteff2024_code repository
#   data_path: root directory of the BIDS dataset
git_path, data_path = '/path/to/git/kurteff2024_code/', '/path/to/bids/dataset/'

In [None]:
import mne
import numpy as np
import pandas as pd
import os
import re
import csv
from tqdm.notebook import tqdm
import warnings

from img_pipe import img_pipe

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
import matplotlib.patheffects as PathEffects
rc['pdf.fonttype'] = 42
plt.style.use('seaborn')
%matplotlib inline

import sys
sys.path.append(os.path.join(git_path,"figures"))
import plotting_utils

In [None]:
# Subject IDs left out of this notebook.
exclude = ["TCH8"]
# Subjects listed as having no imaging data.
no_imaging = ["S0010"]
# Subjects without inconsistent-playback trials; this notebook contrasts
# consistent vs. inconsistent playback, so they are excluded here.
no_sh = ['S0023', 'TCH06']

# Candidate subjects are the entries of preprocessing/events/csv whose names
# contain "TCH" or "S0", minus the exclusions above (os.listdir order is kept).
subjs = [
    s for s in os.listdir(os.path.join(git_path, "preprocessing", "events", "csv"))
    if ("TCH" in s or "S0" in s) and s not in exclude and s not in no_sh
]

# Recording blocks per subject (the suffix after the last "_" in "<subj>_B..."
# folder names). A block is kept only if its "<block>_spkr_sn_all.txt" event file
# exists. Look in the same preprocessing/events/csv tree that the subject list and
# the epoching cell read from; the original looked under analysis/events/csv,
# which nothing else in this notebook uses. The listing is sorted because
# os.listdir order is unspecified and blocks[s][0] is used later as "the first block".
events_csv_dir = os.path.join(git_path, "preprocessing", "events", "csv")
blocks = {
    s: [
        b.split("_")[-1]
        for b in sorted(os.listdir(os.path.join(events_csv_dir, s)))
        if f"{s}_B" in b
        and os.path.isfile(os.path.join(events_csv_dir, s, b, f"{b}_spkr_sn_all.txt"))
    ]
    for s in subjs
}

# Hemisphere(s) covered by each subject's electrodes, judged from the sign of the
# x coordinate in img_pipe's electrode matrix: any x > 0 -> 'rh', any x < 0 -> 'lh'
# (a stereo implant can get both).
# Subjects in `no_imaging` are skipped. They presumably have no img_pipe data to
# load (TODO confirm). Their empty list keeps them out of any later
# `hem in hems[s]` loop.
hems = {s: [] for s in subjs}
for s in subjs:
    if s in no_imaging:
        continue
    pt = img_pipe.freeCoG(f"{s}_complete", hem='stereo', subj_dir=data_path)
    elecs = pt.get_elecs()['elecmatrix']
    if np.any(elecs[:, 0] > 0):
        hems[s].append('rh')
    if np.any(elecs[:, 0] < 0):
        hems[s].append('lh')

# Shared figure colors: speaker playback uses the 'perception' entry and
# microphone uses the 'production' entry (first matching row of each).
color_palette = pd.read_csv(os.path.join(git_path, "figures", "color_palette.csv"))
spkr_color, mic_color = (
    color_palette.loc[color_palette['color_id'] == color_id, 'hex'].iloc[0]
    for color_id in ('perception', 'production')
)

### Epoch data: consistent ('el') and inconsistent ('sh') playback trials

In [None]:
# Epoching parameters, in seconds. [ci_tmin, ci_tmax) is the response window that
# the colormap cell averages; [erp_tmin, erp_tmax] is the span epoched here.
ci_tmin = 0; ci_tmax = 1; erp_tmin = -.5; erp_tmax = 2; reject = None; baseline = None


def _read_events(eventfile, fs):
    """Load a tab-separated event file as an integer (n_events, 3) array.

    Each non-blank row is ``onset_s, offset_s, event_id``. The two times are
    converted from seconds to sample indices at sampling rate ``fs``.
    """
    with open(eventfile, 'r') as f:
        rows = [row for row in csv.reader(f, delimiter='\t') if row]
    # Use round(), not int() truncation: float products such as 0.29 * 100
    # evaluate to 28.999999999999996, which int() would move one sample early.
    events = [[int(round(float(row[0]) * fs)), int(round(float(row[1]) * fs)), int(row[2])]
              for row in rows]
    return np.array(events, dtype=int).reshape(-1, 3)


def _epoch_events(raw, eventfile, fs):
    """Epoch ``raw`` around the events in ``eventfile`` using the parameters above."""
    return mne.Epochs(raw, _read_events(eventfile, fs), tmin=erp_tmin, tmax=erp_tmax,
                      baseline=baseline, reject=reject, verbose=False)


epochs = dict(); ch_names = dict()
for s in tqdm(subjs):
    el_epochs, sh_epochs = [], []
    for b in blocks[s]:
        blockid = f'{s}_{b}'
        raw_fpath = os.path.join(data_path, f"sub-{s}", s, blockid, "HilbAA_70to150_8band",
                                 "ecog_hilbAA70to150.fif")
        raw = mne.io.read_raw_fif(raw_fpath, preload=True, verbose=False)
        ch_names[s] = raw.info['ch_names']
        fs = raw.info['sfreq']
        event_dir = os.path.join(git_path, "preprocessing", "events", "csv", s, blockid)
        # Consistent events
        el_epochs.append(_epoch_events(raw, os.path.join(event_dir, f"{blockid}_spkr_sn_el.txt"), fs))
        # Inconsistent events
        sh_epochs.append(_epoch_events(raw, os.path.join(event_dir, f"{blockid}_spkr_sn_sh.txt"), fs))
    # One Epochs object per condition, concatenated across this subject's blocks.
    epochs[s] = {'el': mne.concatenate_epochs(el_epochs),
                 'sh': mne.concatenate_epochs(sh_epochs)}

### Make 2-D colormap of per-electrode responses (consistent vs. inconsistent)

In [None]:
# 2-D lookup image: the row is picked by an electrode's normalized consistent ('el')
# response and the column by its normalized inconsistent ('sh') response.
cmap_2d = plt.imread(os.path.join(git_path, "figures", "figure_4", "msh_elsh_splinesqrt_2d_bu_or.png"))


def _window_mean(ep, tmin, tmax, picks):
    """Per-channel mean of ``ep`` over all epochs and all samples in [tmin, tmax)."""
    # time_as_index with rounding instead of exact float equality against ep.times.
    start, stop = ep.time_as_index([tmin, tmax], use_rounding=True)
    # Slice the channel axis explicitly instead of squeeze(), which would also drop
    # the epoch axis when a subject has a single epoch.
    return ep.get_data(picks=picks)[:, :, start:stop].mean(axis=(0, 2))


# Window sample indices on the last subject's time axis, kept for downstream cells.
# _window_mean recomputes them for each subject, so differing sampling rates are handled.
ci_inds = [int(i) for i in epochs[subjs[-1]]['el'].time_as_index([ci_tmin, ci_tmax], use_rounding=True)]

# Mean window response per electrode, with all subjects stacked in subjs order.
all_el, all_sh, all_subj, all_ch = [], [], [], []
for s in subjs:
    subj_chs = epochs[s]['el'].info['ch_names']
    all_el.append(_window_mean(epochs[s]['el'], ci_tmin, ci_tmax, subj_chs))
    all_sh.append(_window_mean(epochs[s]['sh'], ci_tmin, ci_tmax, subj_chs))
    all_subj.extend([s] * len(subj_chs))
    all_ch.extend(subj_chs)
all_el = np.concatenate(all_el); all_sh = np.concatenate(all_sh)
all_subj = np.array(all_subj); all_ch = np.array(all_ch)

# Min-max normalize within each subject, using one shared range across both conditions.
all_el_norm, all_sh_norm = [], []
for s in subjs:
    idxs = np.where(all_subj == s)[0]
    both = np.hstack((all_el[idxs], all_sh[idxs]))
    resp_min = both.min()
    span = both.max() - resp_min
    if span == 0:
        # Flat responses map to 0 rather than NaN (which would crash int() below).
        span = 1.0
    all_el_norm.append((all_el[idxs] - resp_min) / span)
    all_sh_norm.append((all_sh[idxs] - resp_min) / span)
all_el_norm = np.hstack(all_el_norm); all_sh_norm = np.hstack(all_sh_norm)

# Format to be by-subject. Colors are looked up using the image's actual size
# (the original hard-coded 255, which assumes a 256x256 image).
n_rows, n_cols = cmap_2d.shape[:2]
elec_colors_rgb_2d, values = dict(), dict()
for s in subjs:
    idxs = np.where(all_subj == s)[0]
    values[s] = {'el': {'norm': all_el_norm[idxs], 'native': all_el[idxs]},
                 'sh': {'norm': all_sh_norm[idxs], 'native': all_sh[idxs]}}
    elec_colors_rgb_2d[s] = [
        cmap_2d[int(all_el_norm[i] * (n_rows - 1)), int(all_sh_norm[i] * (n_cols - 1)), :]
        for i in idxs
    ]

In [None]:
# Save to pandas
df = pd.DataFrame(columns=['subj','hem','ch_name','x','y','z','r','g','b','a'])
for hem in ['lh','rh']:
    for s in [ss for s in subjs if hem in hems[ss]]:
        blockid = "_".join([s,blocks[s][0]])
        fif_ch_names = mne.io.read_raw_fif(os.path.join(data_path,f"sub-{s}",blockid,"HilbAA_70to150_8band",
            "ecog_hilbAA70to150.fif"), preload=False, verbose=False).info['ch_names']
        pt = img_pipe.freeCoG(f'{s}_complete',hem=hem, subj_dir=ip)
        e, a = imaging_utils.clip_4mm_elecs(pt,hem=hem,elecfile_prefix="TDT_elecs_all_warped")
        e, a = imaging_utils.clip_outside_brain_elecs(pt,elecmatrix=e,anatomy=a,hem=hem,
                                                      elecfile_prefix="TDT_elecs_all_warped")
        fs_ch_names = [aa[0][0] for aa in a]
        for ch in fs_ch_names:
            if ch.replace('-','') in [c.replace('-','') for c in fif_ch_names]:
                fif_idx = [c.replace('-','') for c in fif_ch_names].index(ch.replace('-',''))
                r,g,b,a = elec_colors_rgb_2d[s][fif_idx]
                elecfile_idx = [c.replace('-','') for c in fs_ch_names].index(ch.replace('-',''))
                x,y,z = e[elecfile_idx,:]
                new_row = pd.DataFrame({'subj':[s],'hem':[hem],'ch_name':[ch],'x':[x],'y':[y],'z':[z],
                                       'r':[r],'g':[g],'b':[b],'a':[a]})
                df = df.append(new_row, ignore_index=True)
df.to_csv(os.path.join(git_path,"figures","figure_4","csv","figure_4_cmap.csv"),index=False)