In [None]:
# Paths - Update locally!
# git_path: root of the local clone of the code repository; the event CSVs, cNMF outputs
#           and figure assets used below are all resolved relative to it.
# data_path: root of the BIDS dataset; it is passed to img_pipe as `subj_dir` and is where
#            each subject's .fif recording is read from.
git_path = '/path/to/git/kurteff2024_code/'
data_path = '/path/to/bids/dataset/'

In [None]:
import mne
import numpy as np
import pandas as pd
import os
import re
import csv
from tqdm.notebook import tqdm
import warnings
import h5py
import pymf3

from img_pipe import img_pipe

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
import matplotlib.patheffects as PathEffects
rc['pdf.fonttype'] = 42
plt.style.use('seaborn')
%matplotlib inline

import sys
sys.path.append(os.path.join(git_path,"figures"))
import plotting_utils

In [None]:
# Candidate subjects are the entries of the preprocessing events folder whose names carry
# a "TCH" or "S0" tag; excluded subjects are dropped in the same pass.
# NOTE(review): subjects are listed from preprocessing/events/csv but their blocks are
# listed from analysis/events/csv below -- confirm both trees hold the same subjects.
exclude = ["TCH8"]
# NOTE(review): `no_imaging` is not consulted anywhere in this cell, so these subjects still
# go through the img_pipe hemisphere lookup further down -- confirm that is intended.
no_imaging = ["S0010"]
subjs = [
    entry for entry in os.listdir(os.path.join(git_path,"preprocessing","events","csv"))
    if ("TCH" in entry or "S0" in entry) and entry not in exclude
]

# For each subject keep the trailing "_"-separated token of every block folder that
# contains a "<block folder>_spkr_sn_all.txt" event file.
blocks = dict()
for s in subjs:
    subj_events_dir = os.path.join(git_path,"analysis","events","csv",s)
    blocks[s] = [
        b.split("_")[-1] for b in os.listdir(subj_events_dir)
        if f"{s}_B" in b and os.path.isfile(os.path.join(subj_events_dir,b,f"{b}_spkr_sn_all.txt"))
    ]

# A subject is tagged 'rh' when any electrode x-coordinate is positive and 'lh' when any
# is negative; a subject with both gets both tags (in that order).
hems = {s: [] for s in subjs}
for s in subjs:
    pt = img_pipe.freeCoG(f"{s}_complete",hem='stereo',subj_dir=data_path)
    elecs = pt.get_elecs()['elecmatrix']
    x_coords = elecs[:,0]
    if (x_coords > 0).any():
        hems[s].append('rh')
    if (x_coords < 0).any():
        hems[s].append('lh')

# Condition colors: first 'hex' entry of the palette row matching each color_id.
color_palette = pd.read_csv(os.path.join(git_path,"figures","color_palette.csv"))
_palette_hex = lambda color_id: color_palette.loc[color_palette['color_id']==color_id, 'hex'].values[0]
spkr_color = _palette_hex('perception')
mic_color = _palette_hex('production')

### Load cNMF results

In [None]:
# Load NMF results from h5
# NOTE(review): `click_color` is not defined in any cell shown above this one (only
# `spkr_color` and `mic_color` are) -- confirm it is defined before this cell runs.
bname = 'spkrmicclick'
epoch_types = ['spkr','mic','click']
plt_colors = [spkr_color,mic_color,click_color]

# `nmf` collects everything read from the grouped NMF output: 'resp' (per epoch type),
# 'pve', one {'W','H'} entry per number of bases, plus channel and subject labels.
nmf = {'resp': dict()}
_nmf_dir = os.path.join(git_path,"analysis","cnmf","h5")
h5_fpath = os.path.join(_nmf_dir,f"NMF_grouped_{bname}.hf5")
with h5py.File(h5_fpath,'r') as f:
    num_bases = np.array(f.get('num_bases'))
    nmf['pve'] = np.array(f.get('pve'))
    # One W/H pair is stored per entry of num_bases, under the "<nb>_bases" group.
    for nb in num_bases:
        nmf[nb] = {
            'W': np.array(f.get(f'{nb}_bases/W')),
            'H': np.array(f.get(f'{nb}_bases/H')),
        }
    for epoch_type in epoch_types:
        nmf['resp'][epoch_type] = np.array(f.get(f'resp/{epoch_type}'))

# Row labels live in plain-text files next to the h5 file.
nmf['ch_names'] = np.loadtxt(
    os.path.join(_nmf_dir,f"NMF_grouped_{bname}_ch_names.txt"), dtype=str)
nmf['all_subjs'] = np.loadtxt(
    os.path.join(_nmf_dir,f"NMF_grouped_{bname}_all_subjs.txt"), dtype=str)

In [None]:
# Convert resp into a subject-by-subject format
# resp[s][epoch_type] and ch_names[s] hold only the NMF rows labelled with subject s;
# subjects with no rows are reported with a warning and left out of both dicts.
resp, ch_names = {}, {}
for s in subjs:
    nmf_inds = np.where(nmf['all_subjs']==s)[0]
    if nmf_inds.shape[0] == 0:
        warnings.warn(f"Subject {s} missing from {bname} NMF, skipping...")
        continue
    ch_names[s] = nmf['ch_names'][nmf_inds]
    resp[s] = {
        epoch_type: nmf['resp'][epoch_type][nmf_inds,:] for epoch_type in epoch_types
    }

In [None]:
# Truncate weights according to a kneepoint
nmf['k'] = 9 # 86% PVE
# clip nmf.W to k clusters, and reorder by weight
# W is indexed [channel row, cluster column] below.
# assumes W has exactly nmf['k'] columns (the resp buffers are sized with nmf['k']) -- TODO confirm
clust = nmf[nmf['k']]; W = clust['W']
# Fix: the closing bracket used to sit after `.shape[1]`, so `.shape` was taken on the
# epoch-type *string* (AttributeError). The sample count comes from the response array.
n_samples = nmf['resp'][epoch_types[0]].shape[1]
clusters = dict()
clusters['W'] = np.zeros(W.shape)
# String buffers for the per-cluster subject / channel labels (filled in the loop below).
clusters['all_subjs'] = np.zeros(W.shape).astype(str)
clusters['ch_names'] = np.zeros(W.shape).astype(str)
clusters['resp'] = {
    epoch_type: np.zeros((nmf['k'], W.shape[0], n_samples)) for epoch_type in epoch_types
}
# np.argsort is ascending, so relative index 0 is the cluster with the smallest summed weight.
clust_sort = np.argsort(W.sum(0))
for ri,ai in enumerate(clust_sort): # relative/absolute index
    # Within a cluster, rows are ordered from largest to smallest weight; the labels and
    # responses are permuted with the same order so row i always describes one channel.
    sorted_idxs = np.flip(np.argsort(W[:,ai]))
    clusters['W'][:,ri] = W[:,ai][sorted_idxs]
    clusters['all_subjs'][:,ri] = nmf['all_subjs'][sorted_idxs]
    clusters['ch_names'][:,ri] = nmf['ch_names'][sorted_idxs]
    for epoch_type in epoch_types:
        clusters['resp'][epoch_type][ri,:] = nmf['resp'][epoch_type][sorted_idxs]

In [None]:
# normalize weights between 0 and 1
# Each cluster column is min-max scaled on its own: subtract the column minimum, then
# divide by the largest shifted value. The result is stored as float64.
_shifted_wts = clusters['W'] - clusters['W'].min(axis=0)
normed_wts = (_shifted_wts / _shifted_wts.max(axis=0)).astype(np.float64)

### Within-cluster colormaps

In [None]:
# For every cluster, map each row's normalized weight (0-1) to an RGBA tuple on the 'Reds'
# colormap. The colormap is fetched through pyplot because `cm` is not among the names
# imported in the import cell (the previous `cm.Reds` raised NameError).
reds = plt.get_cmap('Reds')
within_cluster_cmap = dict()
for clust in np.arange(nmf['k']):
    within_cluster_cmap[clust] = [reds(w) for w in normed_wts[:,clust]]

In [None]:
# Save to csv
# Writes one CSV per saved cluster with a row per electrode: subject, hemisphere, channel
# name, electrode coordinates (x, y, z) and the within-cluster RGBA color.
save_clusters = [0, # Dual onset (these are zero-indexed)
                 1, # Onset suppression
                 2] # Pre-articulatory motor
csv_columns = ['subj','hem','ch_name','x','y','z','r','g','b','a']
# Rows are pooled over every hemisphere and subject and each cluster's file is written once
# at the end. Writing inside the subject loop reused the same filename, so every subject
# overwrote the previous one.
cluster_rows = {clust: [] for clust in save_clusters}
for hem in ['lh','rh']:
    # Fix: the comprehension variable is now `ss` throughout (it iterated `s` but read `ss`).
    for s in [ss for ss in subjs if hem in hems[ss]]:
        blockid = "_".join([s,blocks[s][0]])
        fif_ch_names = mne.io.read_raw_fif(os.path.join(data_path,f"sub-{s}",blockid,"HilbAA_70to150_8band",
            "ecog_hilbAA70to150.fif"), preload=False, verbose=False).info['ch_names']
        # Fix: `ip` was never defined. data_path is what the hemisphere-detection cell
        # passes to img_pipe as subj_dir -- NOTE(review): confirm the same root applies here.
        pt = img_pipe.freeCoG(f'{s}_complete',hem=hem, subj_dir=data_path)
        # NOTE(review): `imaging_utils` is not imported in the import cell -- confirm it is
        # importable before running this cell.
        elecmatrix, anatomy = imaging_utils.clip_4mm_elecs(pt,hem=hem,elecfile_prefix="TDT_elecs_all_warped")
        elecmatrix, anatomy = imaging_utils.clip_outside_brain_elecs(pt,elecmatrix=elecmatrix,anatomy=anatomy,hem=hem,
                                                                     elecfile_prefix="TDT_elecs_all_warped")
        fs_ch_names = [aa[0][0] for aa in anatomy]
        for clust in save_clusters:
            subj_idxs = np.where(clusters['all_subjs'][:,clust]==s)[0]
            for idx in subj_idxs:
                ch_name = clusters['ch_names'][idx,clust]
                if ch_name == "EKG1":
                    continue
                # Recording and electrode-file names are matched with dashes stripped.
                lookup_name = ch_name.replace("-","")
                # Same failure mode as the old (unused) fif index lookup: a clustered
                # channel that is absent from the recording is an error.
                if lookup_name not in fif_ch_names:
                    raise ValueError(f"{lookup_name} is not a channel of {blockid}")
                # list.index raises ValueError if the channel was clipped from the electrode file.
                x,y,z = elecmatrix[fs_ch_names.index(lookup_name),:]
                # `alpha` (not `a`) so the color does not shadow other names in the cell.
                r,g,b,alpha = within_cluster_cmap[clust][idx]
                # Fix: the row used the undefined name `ch`; the channel is `ch_name`.
                cluster_rows[clust].append({'subj':s,'hem':hem,'ch_name':ch_name,'x':x,'y':y,'z':z,
                                            'r':r,'g':g,'b':b,'a':alpha})
for clust in save_clusters:
    # Built in one go from the collected records: DataFrame.append was removed in pandas 2.0.
    df = pd.DataFrame(cluster_rows[clust], columns=csv_columns)
    df.to_csv(os.path.join(git_path,"figures","figure_3","csv",
                           f"figure_3_cmap_within_clust_{clust+1}.csv"),index=False)

### Across-cluster colormap
Comparing cluster 1 (dual onset) to cluster 2 (onset suppression)

In [None]:
# Build a 2-D colormap lookup for every unordered pair of clusters (xc < yc): each channel is
# colored by its weight in cluster xc (image axis 0) against its weight in cluster yc (axis 1).
cmap_2d = plt.imread(os.path.join(git_path,"figures","figure_3","RdBuPr_splinesqrt22.png"))
across_cluster_cmap, completed_cmaps = dict(), []
across_cluster_values = {'native': dict(), 'norm': dict()}

# Hoisted out of the pair loop: "<subject>_<channel>" labels for every cluster column and a
# label -> row lookup. This replaces a list.index scan per channel per pair (quadratic in the
# number of channels) with a dict lookup. A label missing from a cluster now raises KeyError.
_labels, _row_of = dict(), dict()
for ci in range(nmf['k']):
    _labels[ci] = [f"{subj}_{name}" for subj, name in zip(clusters['all_subjs'][:,ci],
                                                          clusters['ch_names'][:,ci])]
    _row_of[ci] = dict()
    for row, label in enumerate(_labels[ci]):
        _row_of[ci].setdefault(label, row) # first occurrence wins, as with list.index

for xc in range(nmf['k']): # x-axis cluster
    for yc in range(xc+1, nmf['k']): # y-axis cluster
        # Only the upper triangle is visited, so diagonals and mirrored pairs never come up;
        # completed_cmaps still records each pair that was built.
        completed_cmaps.append([xc,yc])
        pair = f"{xc}-{yc}"
        # Reorder y_W so that the channels line up with each other
        # This means the channel name follows xc not yc
        y_inds = [_row_of[yc][label] for label in _labels[xc]]
        x_W = clusters['W'][:,xc]; y_W = clusters['W'][:,yc][y_inds]
        across_cluster_values['native'][pair] = x_W - y_W
        # Both clusters are rescaled with a shared min and max so their weights stay comparable.
        xymin = np.array([x_W.min(), y_W.min()]).min(); x_W = x_W-xymin; y_W = y_W-xymin
        xymax = np.array([x_W.max(), y_W.max()]).max(); x_W = x_W/xymax; y_W = y_W/xymax
        across_cluster_values['norm'][pair] = x_W - y_W
        # axis 2 is RGB values so we can get them by indexing ax0 by xc and ax1 by yc
        # NOTE(review): the 0-255 index range assumes the PNG is at least 256x256 -- confirm.
        across_cluster_cmap[pair] = [
            cmap_2d[int(round(xw*255)),int(round(yw*255)),:] for xw, yw in zip(x_W, y_W)]

In [None]:
# Save to csv
# Writes one CSV per saved contrast with a row per electrode: subject, hemisphere, channel
# name, electrode coordinates (x, y, z) and the across-cluster color (alpha fixed at 1).
save_contrasts = ['0-1'] # Dual onset vs onset suppression
csv_columns = ['subj','hem','ch_name','x','y','z','r','g','b','a']
# Rows are pooled over every hemisphere and subject and each contrast's file is written once
# at the end. Writing inside the subject loop reused the same filename, so every subject
# overwrote the previous one.
contrast_rows = {contrast: [] for contrast in save_contrasts}
for hem in ['lh','rh']:
    # Fix: the comprehension variable is now `ss` throughout (it iterated `s` but read `ss`).
    for s in [ss for ss in subjs if hem in hems[ss]]:
        blockid = "_".join([s,blocks[s][0]])
        fif_ch_names = mne.io.read_raw_fif(os.path.join(data_path,f"sub-{s}",blockid,"HilbAA_70to150_8band",
            "ecog_hilbAA70to150.fif"), preload=False, verbose=False).info['ch_names']
        # Fix: `ip` was never defined. data_path is what the hemisphere-detection cell
        # passes to img_pipe as subj_dir -- NOTE(review): confirm the same root applies here.
        pt = img_pipe.freeCoG(f'{s}_complete',hem=hem, subj_dir=data_path)
        # NOTE(review): `imaging_utils` is not imported in the import cell -- confirm it is
        # importable before running this cell.
        elecmatrix, anatomy = imaging_utils.clip_4mm_elecs(pt,hem=hem,elecfile_prefix="TDT_elecs_all_warped")
        elecmatrix, anatomy = imaging_utils.clip_outside_brain_elecs(pt,elecmatrix=elecmatrix,anatomy=anatomy,hem=hem,
                                                                     elecfile_prefix="TDT_elecs_all_warped")
        fs_ch_names = [aa[0][0] for aa in anatomy]
        for contrast in save_contrasts:
            xc, yc = [int(d) for d in contrast.split('-')]
            subj_idxs = np.where(clusters['all_subjs'][:,xc]==s)[0]
            for idx in subj_idxs:
                # Fix: rows of a contrast follow cluster xc's ordering, so the channel name
                # is read from column xc (it used a stale `clust` left over from another cell).
                ch_name = clusters['ch_names'][idx,xc]
                if ch_name == "EKG1":
                    continue
                # Recording and electrode-file names are matched with dashes stripped.
                lookup_name = ch_name.replace("-","")
                # Same failure mode as the old (unused) fif index lookup: a clustered
                # channel that is absent from the recording is an error.
                if lookup_name not in fif_ch_names:
                    raise ValueError(f"{lookup_name} is not a channel of {blockid}")
                # list.index raises ValueError if the channel was clipped from the electrode file.
                x,y,z = elecmatrix[fs_ch_names.index(lookup_name),:]
                # Only the first three channels are used, so an RGBA image does not break the unpack.
                r,g,b = across_cluster_cmap[contrast][idx][:3]
                # Fix: the row used `ch`, a name leaked from another cell's loop; the channel is `ch_name`.
                contrast_rows[contrast].append({'subj':s,'hem':hem,'ch_name':ch_name,'x':x,'y':y,'z':z,
                                                'r':r,'g':g,'b':b,'a':1.})
for contrast in save_contrasts:
    xc, yc = [int(d) for d in contrast.split('-')]
    # Built in one go from the collected records: DataFrame.append was removed in pandas 2.0.
    df = pd.DataFrame(contrast_rows[contrast], columns=csv_columns)
    df.to_csv(os.path.join(git_path,"figures","figure_3","csv",
                           f"figure_3_cmap_across_clusts_{xc+1}-{yc+1}.csv"),index=False)