In [None]:
# Local paths: edit both before running the notebook.
#   git_path  -> checkout of the kurteff2024_code repository
#   data_path -> root of the BIDS-formatted dataset
git_path, data_path = (
    '/path/to/git/kurteff2024_code/',
    '/path/to/bids/dataset/',
)

In [None]:
import mne
import numpy as np
import pandas as pd
import os
import re
import csv
from tqdm.notebook import tqdm
import warnings
import h5py
import pymf3

from img_pipe import img_pipe

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
import matplotlib.patheffects as PathEffects
rc['pdf.fonttype'] = 42
plt.style.use('seaborn')
%matplotlib inline

import sys
sys.path.append(os.path.join(git_path,"figures"))
import plotting_utils
sys.path.append(os.path.join(git_path,"analysis","mtrf"))
import mtrf_utils

In [None]:
# Subjects are the entries of the events/csv directory whose names contain 'TCH' or 'S0'.
# os.listdir order is arbitrary, so sort to make loop order (and output row order) reproducible.
subjs = sorted(
    s for s in os.listdir(os.path.join(git_path,"preprocessing","events","csv"))
    if "TCH" in s or "S0" in s
)
exclude = ["TCH8"]      # dropped from every analysis below
no_imaging = ["S0010"]  # subjects without imaging data (name-based; confirm)
subjs = [s for s in subjs if s not in exclude]

# Recording blocks per subject: the suffix after the last '_' of every `<subj>_B...` folder that
# contains a `<folder>_spkr_sn_all.txt` file. Sorted because os.listdir order is arbitrary and
# `blocks[s][0]` is later used to choose which .fif file supplies the channel names.
# NOTE(review): subjects are listed from preprocessing/events/csv but blocks are read from
# analysis/events/csv -- confirm both directories are intended.
def _events_blocks(subj):
    """Return the sorted block IDs of `subj` that have a `spkr_sn_all` events file."""
    subj_dir = os.path.join(git_path,"analysis","events","csv",subj)
    return sorted(
        b.split("_")[-1] for b in os.listdir(subj_dir)
        if f"{subj}_B" in b and os.path.isfile(os.path.join(subj_dir, b, f"{b}_spkr_sn_all.txt"))
    )

blocks = {s: _events_blocks(s) for s in subjs}

# Hemisphere(s) containing electrodes for each subject, inferred from the sign of each
# electrode's x coordinate (x > 0 -> 'rh', x < 0 -> 'lh'); a subject can get both.
hems = {s: [] for s in subjs}
for s in subjs:
    if s in no_imaging:
        # `no_imaging` was defined but never used; presumably these subjects have no electrode
        # reconstruction for freeCoG to load, so leave their hemisphere list empty (which also
        # excludes them from the brain-plot export).
        continue
    pt = img_pipe.freeCoG(f"{s}_complete",hem='stereo',subj_dir=data_path)
    elecs = pt.get_elecs()['elecmatrix']
    if np.any(elecs[:,0] > 0):
        hems[s].append('rh')
    if np.any(elecs[:,0] < 0):
        hems[s].append('lh')

color_palette = pd.read_csv(os.path.join(git_path,"figures","color_palette.csv"))

def _palette_hex(color_id):
    """Return the 'hex' value of the first palette row whose 'color_id' equals `color_id`."""
    matching = color_palette.loc[color_palette['color_id'] == color_id, 'hex']
    return matching.values[0]

spkr_color = _palette_hex('perception')  # colour used for the 'perception' condition
mic_color = _palette_hex('production')   # colour used for the 'production' condition

models = ['model1','model2','model3','model4']
# Feature list of every mTRF model, as returned by mtrf_utils.get_feats in "ecog" mode.
features = {}
for model_name in models:
    features[model_name] = mtrf_utils.get_feats(model_name, mode="ecog")
# Per-channel model performance table (columns used below: subject, model, channel, r_value, p_value).
results = pd.read_csv(os.path.join(git_path,"analysis","mtrf","results.csv"))

### Load mTRF results

In [None]:
# For each subject collect: the channel list (from the model1 rows of results.csv), the .fif
# channel order, and, per model, per-channel r/p values and weight matrices.
mtrf = dict()
for s in tqdm(subjs):
    if not blocks[s]:
        raise ValueError(f"No recording block with a spkr_sn_all events file found for {s}")
    mtrf[s] = dict()
    subj_results = results.loc[results['subject'] == s]
    mtrf[s]['ch_names'] = subj_results.loc[subj_results['model'] == 'model1', 'channel'].values
    blockid = "_".join([s, blocks[s][0]])
    # Only the channel list is needed, so the recording itself is not loaded (preload=False).
    mtrf[s]['fif_ch_names'] = mne.io.read_raw_fif(os.path.join(data_path,f"sub-{s}",blockid,
        "HilbAA_70to150_8band", "ecog_hilbAA70to150.fif"), preload=False, verbose=False).info['ch_names']
    # MNE channel names are unique, so this is equivalent to list.index(ch) but O(1) per lookup.
    fif_idx = {ch: i for i, ch in enumerate(mtrf[s]['fif_ch_names'])}
    # Open the weights file once per subject (not once per model); f[m] raises a clear KeyError
    # for a missing model, where f.get(m) silently produced an unusable np.array(None).
    with h5py.File(os.path.join(git_path,"analysis","mtrf","h5","weights",f"{s}_weights.hdf5"),'r') as f:
        subj_wts = {m: np.array(f[m]) for m in models}
    for m in models:
        # Index this model's rows by channel once (keeping the first row per channel, as the
        # original `.values[0]` did) instead of re-filtering the full table for every channel.
        model_res = (subj_results.loc[subj_results['model'] == m]
                     .drop_duplicates('channel')
                     .set_index('channel'))
        mtrf[s][m] = {'r': dict(), 'p': dict(), 'w': dict()}
        for ch in mtrf[s]['ch_names']:
            mtrf[s][m]['r'][ch] = model_res.at[ch, 'r_value']
            mtrf[s][m]['p'][ch] = model_res.at[ch, 'p_value']
            # Weights are indexed by .fif channel position on the last axis.
            mtrf[s][m]['w'][ch] = subj_wts[m][:, :, fif_idx[ch]].T

In [None]:
# Normalize r across two models: for every unordered pair ("dyad") of models, map each model's
# per-channel r onto [0, 1] so that a pair of normalized values can index a 2D colormap.
# NOTE(review): `rmin`/`rmax` were used here but never defined anywhere in the notebook
# (NameError on a fresh kernel). The bounds below -- floor at r = 0, saturate at the largest r
# observed across all subjects, models and channels -- are a placeholder choice; confirm them
# against the values used for the published figure.
all_r = np.array([r for s in subjs for m in models for r in mtrf[s][m]['r'].values()], dtype=float)
rmin = 0.
rmax = float(np.nanmax(all_r)) if all_r.size else 1.

def _normalize_r(r, lo, hi):
    """Map r linearly from [lo, hi] onto [0, 1]; values at or beyond either bound are clamped."""
    if r >= hi:
        return 1.
    if r <= lo:
        return 0.
    return (r - lo) / (hi - lo)

finished = []  # dyad names ('modelA_modelB', members sorted) in the order they were built
normed_r_across = {s: dict() for s in subjs}
for i, xm in enumerate(models):
    for ym in models[i + 1:]:  # visit each unordered pair exactly once
        dyad = "_".join(sorted([xm, ym]))
        for s in subjs:
            normed_r_across[s][dyad] = {
                m: {ch: _normalize_r(mtrf[s][m]['r'][ch], rmin, rmax) for ch in mtrf[s]['ch_names']}
                for m in (xm, ym)
            }
        finished.append(dyad)

In [None]:
# Read in the 2D colormap from .png: row index follows the first model of a dyad, column index
# the second, both from their normalized r in [0, 1].
# NOTE(review): this path sits under "figure_3" at the repo root while other figure assets live
# under "figures/" -- confirm the location.
cmap_2d = plt.imread(os.path.join(git_path,"figure_3","RdBuPr_splinesqrt22.png"))
# Scale by the image's real size; the previous hard-coded 255 assumed a >=256x256 image
# (IndexError when smaller, only a corner of the colormap used when larger).
x_max_idx = cmap_2d.shape[0] - 1
y_max_idx = cmap_2d.shape[1] - 1
# Colour (all image channels) for every subject's channel, per dyad.
across_model_cmap = dict()
for dyad in finished:
    across_model_cmap[dyad] = dict()
    xm, ym = dyad.split("_")
    for s in subjs:
        x_norm = normed_r_across[s][dyad][xm]
        y_norm = normed_r_across[s][dyad][ym]
        across_model_cmap[dyad][s] = {
            ch: cmap_2d[round(x_norm[ch] * x_max_idx), round(y_norm[ch] * y_max_idx), :]
            for ch in x_norm
        }

In [None]:
# Save to csv: for each model contrast, one CSV listing every plotted electrode (all subjects,
# both hemispheres) with its coordinates and its 2D-colormap colour.
# NOTE(review): `imaging_utils` was used here but never imported; this assumes the module is
# importable from a directory already on sys.path (e.g. git_path/figures) -- confirm.
import imaging_utils

save_contrasts = ['model1-model2','model1-model3','model1-model4']
# Electrode-file prefix passed to both clipping helpers. The S0021 check below previously
# referenced this name without it ever being defined (NameError).
elecfile_prefix = "TDT_elecs_all_warped"
csv_columns = ['subj','hem','ch_name','x','y','z','r','g','b','a']

def _first_index_by_stripped_name(names):
    """Map each dash-stripped name to the index of its first occurrence in `names`.

    Channel names differ in '-' usage between the .fif and electrode files, so they are compared
    with dashes removed; first occurrence matches the `list.index` lookups this replaces.
    """
    index = {}
    for i, name in enumerate(names):
        index.setdefault(name.replace('-',''), i)
    return index

# Rows are accumulated across all subjects/hemispheres and written once per contrast: the output
# filename carries neither subject nor hemisphere, so writing inside the loops (as before) left
# each file holding only the last subject/hemisphere processed.
csv_rows = {dyad: [] for dyad in save_contrasts}
for hem in ['lh','rh']:
    for s in [subj for subj in subjs if hem in hems[subj]]:
        fif_names = mtrf[s]['fif_ch_names']
        fif_match_names = fif_names
        if s == "S0020":
            # One device for S0020 is named incorrectly in the .fif, so rename it for matching
            # against the electrode file. Done on a copy: colours in `across_model_cmap` are keyed
            # by the original .fif names, and `mtrf` is left untouched.
            fif_match_names = [c.replace("AIPOF'","APIOF'") for c in fif_names]
        fif_idx_by_name = _first_index_by_stripped_name(fif_match_names)
        pt = img_pipe.freeCoG(f'{s}_complete',hem=hem, subj_dir=data_path)
        elec_xyz, anat = imaging_utils.clip_4mm_elecs(pt,hem=hem,elecfile_prefix=elecfile_prefix)
        elec_xyz, anat = imaging_utils.clip_outside_brain_elecs(pt,elecmatrix=elec_xyz,anatomy=anat,
                                                                hem=hem,elecfile_prefix=elecfile_prefix)
        fs_ch_names = [aa[0][0] for aa in anat]
        if s == "S0021" and elecfile_prefix == "TDT_elecs_all":
            # Naming error with a device in the 'TDT_elecs_all' electrode file that we need to fix.
            fs_ch_names = [c.replace("APPI'","ASPPI'") for c in fs_ch_names]
        elec_idx_by_name = _first_index_by_stripped_name(fs_ch_names)
        no_color = set()
        for dyad in save_contrasts:
            # Model names, e.g. 'model1', 'model2' (the old int() parse raised ValueError); the
            # colormap dict is keyed by the sorted names joined with '_'.
            xm, ym = dyad.split('-')
            ch_colors = across_model_cmap["_".join(sorted([xm, ym]))][s]
            for ch in fs_ch_names:
                key = ch.replace('-','')
                if key not in fif_idx_by_name:
                    continue
                fif_name = fif_names[fif_idx_by_name[key]]
                if fif_name not in ch_colors:
                    # In the recording but not in the mTRF results: no colour to plot.
                    no_color.add(fif_name)
                    continue
                x, y, z = elec_xyz[elec_idx_by_name[key], :]
                # Keep RGB only: a PNG colormap may carry a 4th (alpha) channel; alpha is fixed at 1.
                r, g, b = ch_colors[fif_name][:3]
                csv_rows[dyad].append({'subj':s,'hem':hem,'ch_name':ch,'x':x,'y':y,'z':z,
                                       'r':r,'g':g,'b':b,'a':1.})
        if no_color:
            warnings.warn(f"{s} ({hem}): no mTRF colour for channels {sorted(no_color)}; skipped")
# Build each frame once (DataFrame.append was removed in pandas 2.0 and was quadratic anyway).
for dyad, dyad_rows in csv_rows.items():
    pd.DataFrame(dyad_rows, columns=csv_columns).to_csv(
        os.path.join(git_path,"figures","figure_5","csv",f"figure_5_cmap_across_models_{dyad}.csv"),
        index=False)