In [None]:
# Paths - Update locally!
# Both roots are combined with os.path.join throughout the notebook.
git_path, data_path = (
    '/path/to/git/kurteff2024_code/',  # code repository (figures/, analysis/, preprocessing/)
    '/path/to/bids/dataset/',          # BIDS dataset root (raw .fif data and imaging folders)
)

In [None]:
import mne
import numpy as np
import pandas as pd
import os
import re
import csv
from tqdm.notebook import tqdm
import warnings
from img_pipe import img_pipe
import librosa
import h5py
import pymf3

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
rc['pdf.fonttype'] = 42
plt.style.use('seaborn')
%matplotlib inline

# Project-local helper modules live in folders of the code repository rather than in an
# installed package, so each folder is appended to sys.path right before its import.
# NOTE(review): plotting_utils and textgrid are not referenced elsewhere in this notebook.
import sys
sys.path.append(os.path.join(git_path,"figures"))
import plotting_utils
sys.path.append(os.path.join(git_path,"preprocessing","events","textgrids"))
import textgrid
sys.path.append(os.path.join(git_path,"preprocessing","imaging"))
import imaging_utils

In [None]:
# Subjects are the per-subject folders of the preprocessing event-CSV tree.
# The list is sorted so the iteration order does not depend on the filesystem.
event_csv_dir = os.path.join(git_path,"preprocessing","events","csv")
subjs = sorted(s for s in os.listdir(event_csv_dir) if "TCH" in s or "S0" in s)
exclude = ["TCH8"]
no_imaging = ["S0010"]  # subjects without an img_pipe "<subj>_complete" reconstruction
subjs = [s for s in subjs if s not in exclude]

# A block is kept only if the default event file that `epoch_data` reads
# (channel='spkr', level='sn', condition='all') exists for it. The check therefore
# looks in the same preprocessing tree that `epoch_data` reads from.
blocks = {
    s: [
        b.split("_")[-1] for b in os.listdir(os.path.join(event_csv_dir,s))
        if f"{s}_B" in b and os.path.isfile(os.path.join(event_csv_dir,s,b,f"{b}_spkr_sn_all.txt"))
    ] for s in subjs
}

# Hemispheres covered by each subject's electrodes, from the sign of elecmatrix column 0
# (x > 0 -> 'rh', x < 0 -> 'lh'). Subjects without imaging keep an empty list, matching
# how the anatomy lookups below skip them.
hems = {s:[] for s in subjs}
for s in subjs:
    if s in no_imaging:
        continue
    pt = img_pipe.freeCoG(f"{s}_complete",hem='stereo',subj_dir=data_path)
    x_coords = pt.get_elecs()['elecmatrix'][:,0]
    if np.any(x_coords > 0):
        hems[s].append('rh')
    if np.any(x_coords < 0):
        hems[s].append('lh')

def _palette_hex(color_id):
    '''Return the hex color registered under `color_id` in figures/color_palette.csv.'''
    return color_palette.loc[color_palette['color_id']==color_id,'hex'].values[0]

color_palette = pd.read_csv(os.path.join(git_path,"figures","color_palette.csv"))
spkr_color = _palette_hex('perception')
mic_color = _palette_hex('production')
click_color = _palette_hex('click')

### Load cNMF results

In [None]:
# Load NMF results from h5
# For every tried number of bases nb, the grouped cNMF file holds the factors under
# "{nb}_bases/W" and "{nb}_bases/H". It also holds the PVE curve and one response
# matrix per epoch type under "resp/<epoch_type>".
bname = 'spkrmicclick'
epoch_types = ['spkr','mic','click']
plt_colors = [spkr_color,mic_color,click_color]  # one plot color per entry of epoch_types
h5_dir = os.path.join(git_path,"analysis","cnmf","h5")
h5_fpath = os.path.join(h5_dir,f"NMF_grouped_{bname}.hf5")
nmf = {'resp': {}}
with h5py.File(h5_fpath,'r') as h5:
    num_bases = np.array(h5.get('num_bases'))
    nmf['pve'] = np.array(h5.get('pve'))
    for nb in num_bases:
        nmf[nb] = {'W': np.array(h5.get(f'{nb}_bases/W')),
                   'H': np.array(h5.get(f'{nb}_bases/H'))}
    nmf['resp'].update({epoch_type: np.array(h5.get(f'resp/{epoch_type}'))
                        for epoch_type in epoch_types})
# Channel name and owning subject for each row of the response matrices / rows of W
txt_prefix = os.path.join(h5_dir,f"NMF_grouped_{bname}")
nmf['ch_names'] = np.loadtxt(f"{txt_prefix}_ch_names.txt",dtype=str)
nmf['all_subjs'] = np.loadtxt(f"{txt_prefix}_all_subjs.txt",dtype=str)

In [None]:
# Convert resp into a subject-by-subject format
# ch_names[s] holds subject s's channel names, and resp[s][epoch_type] holds the matching
# rows of nmf['resp'][epoch_type]. Subjects absent from the NMF are skipped with a warning.
resp = {}
ch_names = {}
for s in subjs:
    subj_rows = np.flatnonzero(nmf['all_subjs'] == s)
    if subj_rows.size == 0:
        warnings.warn(f"Subject {s} missing from {bname} NMF, skipping...")
        continue
    ch_names[s] = nmf['ch_names'][subj_rows]
    resp[s] = {epoch_type: nmf['resp'][epoch_type][subj_rows,:] for epoch_type in epoch_types}

In [None]:
# Truncate weights according to a kneepoint
nmf['k'] = 9 # 86% PVE
# Keep the k-basis solution. Within each basis, electrodes are sorted by descending
# weight. The bases themselves are visited in ascending order of total weight
# (np.argsort of W.sum(0)).
clust = nmf[nmf['k']]
W = clust['W']  # rows index electrodes (same rows as nmf['resp'][...]); columns index bases
n_samples = nmf['resp'][epoch_types[0]].shape[1]  # time samples per response
clusters = {
    'W': np.zeros(W.shape),
    'all_subjs': np.zeros(W.shape).astype(str),
    'ch_names': np.zeros(W.shape).astype(str),
    # (basis, electrode rank, time) per epoch type
    'resp': {epoch_type: np.zeros((nmf['k'],W.shape[0],n_samples)) for epoch_type in epoch_types},
}
clust_sort = np.argsort(W.sum(0))
for ri,ai in enumerate(clust_sort): # relative/absolute index
    sorted_idxs = np.flip(np.argsort(W[:,ai]))
    clusters['W'][:,ri] = W[:,ai][sorted_idxs]
    clusters['all_subjs'][:,ri] = nmf['all_subjs'][sorted_idxs]
    clusters['ch_names'][:,ri] = nmf['ch_names'][sorted_idxs]
    for epoch_type in epoch_types:
        clusters['resp'][epoch_type][ri,:] = nmf['resp'][epoch_type][sorted_idxs]

### Kneepoint panel

In [None]:
# PVE as a function of the number of cNMF bases, with the chosen k marked.
# k is looked up in num_bases rather than assuming the sweep starts at 2 (pve[k-2]).
k_matches = np.flatnonzero(num_bases == nmf['k'])
if k_matches.size == 0:
    raise ValueError(f"k={nmf['k']} is not among the fitted numbers of bases {num_bases}")
k_pve = nmf['pve'][k_matches[0]]
fig, ax = plt.subplots(figsize=(5,4))
ax.plot(num_bases,nmf['pve'],color='r')
ax.axvline(nmf['k'],color='k',label="%.1f%% PVE k=%d"%(k_pve*100,nmf['k']))
ax.set_xticks(np.arange(num_bases[0],num_bases[-1],8))
ax.set_xlabel("Number of clusters")
ax.set_ylabel("Percent variance explained (PVE)")
ax.legend(frameon=True,framealpha=1,loc='lower right')
fig.savefig(os.path.join(git_path,"figures","figure_3","pdf","kneepoint.pdf"))

### Cluster average response

In [None]:
n_top_elecs = 16; clust = 0 # update accordingly, this is zero-indexed so 0 is dual onset clust
# You can also cross-reference with figure S3 if you need to know which cluster is which response type.
# NOTE(review): the time axis assumes responses span -1 to 2 s around onset — TODO confirm
x = np.linspace(-1,2,clusters['resp']['mic'][clust].shape[1])
plt.figure(figsize=(5,5))
for epoch_type, color in zip(epoch_types, plt_colors):
    # Electrodes are sorted by descending weight, so the first rows are the top electrodes.
    plt.plot(x,clusters['resp'][epoch_type][clust][:n_top_elecs].mean(0),color=color)
plt.axvline(0,color='k')
plt.title(f"Cluster {clust+1} top {n_top_elecs} elecs avg")
plt.savefig(os.path.join(git_path,"figures","figure_3","pdf",f"cluster_{clust+1}_avg_resp.pdf"))

### Single electrode plots

In [None]:
def epoch_data(subj,blocks,git_path,data_path,channel='spkr',level='sn',condition='all',
               tmin=-.5, tmax=2, baseline=None, click=False):
    '''
    Epoch one subject's ecog_hilbAA70to150.fif data around events, concatenated over blocks.

    Parameters
    ----------
    subj : str
        Subject ID (e.g. "S0018"). Used for both the BIDS data folder and the event folder.
    blocks : iterable of str
        Block suffixes; each block ID is f"{subj}_{b}".
    git_path, data_path : str
        Code repository root (event files) and BIDS dataset root (raw .fif files).
    channel, level, condition : str
        Select the event file f"{blockid}_{channel}_{level}_{condition}.txt"
        (ignored when click=True).
    tmin, tmax : float
        Epoch window, passed to mne.Epochs.
    baseline : None or tuple
        Passed through to mne.Epochs.
    click : bool
        If True, read f"{blockid}_click_eve.txt" instead, taking onset, offset and
        event ID from columns 0, 2 and 4.

    Returns
    -------
    mne.Epochs
        Epochs from every block, concatenated.

    Raises
    ------
    ValueError
        If `blocks` is empty.
    '''
    epochs = []
    for b in blocks:
        blockid = f'{subj}_{b}'
        raw_fpath = os.path.join(
            data_path,f"sub-{subj}",blockid,"HilbAA_70to150_8band","ecog_hilbAA70to150.fif")
        # Event files are per subject: use the `subj` argument, not a leftover global.
        event_dir = os.path.join(git_path,"preprocessing","events","csv",subj,blockid)
        if click:
            eventfile = os.path.join(event_dir,f"{blockid}_click_eve.txt")
            onset_index, offset_index, id_index = 0,2,4
        else:
            eventfile = os.path.join(event_dir,f"{blockid}_{channel}_{level}_{condition}.txt")
            onset_index, offset_index, id_index = 0,1,2
        raw = mne.io.read_raw_fif(raw_fpath,preload=True,verbose=False)
        fs = raw.info['sfreq']
        with open(eventfile,'r') as f:
            rows = [row for row in csv.reader(f,delimiter='\t') if row]  # skip blank lines
        # Event times are multiplied by fs and rounded up to sample indices (assumes seconds).
        # NOTE(review): MNE treats events[:,1] as the "previous value" column; it holds offsets here.
        events = np.array([[int(np.ceil(float(row[onset_index])*fs)),
                            int(np.ceil(float(row[offset_index])*fs)),
                            int(row[id_index])] for row in rows])
        epochs.append(mne.Epochs(raw,events,tmin=tmin,tmax=tmax,baseline=baseline,preload=True,verbose=False))
    if not epochs:
        raise ValueError(f"No blocks given for subject {subj}; nothing to epoch")
    return mne.concatenate_epochs(epochs,verbose=False)

In [None]:
def sem(epochs, ddof=0):
    '''
    Lower and upper bounds of (mean +/- standard error of the mean) across epochs.

    Parameters
    ----------
    epochs : np.ndarray
        Array of shape (epochs, samples); statistics are taken over axis 0.
    ddof : int, default 0
        Delta degrees of freedom for the standard deviation. The default (0) keeps
        the original population-std behavior; pass 1 for the sample std.

    Returns
    -------
    sem_below, sem_above : np.ndarray
        mean - SEM and mean + SEM, each of shape (samples,).
    '''
    mean = epochs.mean(0)
    half_width = epochs.std(0, ddof=ddof)/np.sqrt(epochs.shape[0])
    return mean - half_width, mean + half_width

In [None]:
subj, elec = "S0018", "PST-PI'5" # Update accordingly
tmin, tmax = -0.5, 1.0
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    spkr_epochs = epoch_data(subj,blocks[subj],git_path,data_path,channel='spkr',tmin=tmin,tmax=tmax)
    mic_epochs = epoch_data(subj,blocks[subj],git_path,data_path,channel='mic',tmin=tmin,tmax=tmax)
    click_epochs = epoch_data(subj,blocks[subj],git_path,data_path,channel='spkr',click=True,tmin=tmin,tmax=tmax)
# Fail early with a clear message if the electrode is not in the recording
if elec not in spkr_epochs.info['ch_names']:
    raise ValueError(f"Electrode {elec} not found in {subj} epochs")
# Load anat. Names are compared with hyphens removed, matching the top-electrode table
# lookup (imaging labels presumably may omit the hyphen). A miss falls back to the
# "unknown" label instead of raising.
fs_roi = "anatomy unknown"
if os.path.isdir(os.path.join(data_path,f"{subj}_complete")):
    patient = img_pipe.freeCoG(f"{subj}_complete",hem='stereo',subj_dir=data_path)
    anat = patient.get_elecs()['anatomy']
    anat_names = [a[0][0].replace("-","") for a in anat]
    elec_key = elec.replace("-","")
    if elec_key in anat_names:
        fs_roi = anat[anat_names.index(elec_key)][3][0]
x = spkr_epochs.times
fig, ax = plt.subplots(figsize=(5,5))
# Mean trace with a shaded SEM band for each condition
for cond_epochs, color in ((spkr_epochs, spkr_color), (mic_epochs, mic_color), (click_epochs, click_color)):
    # get_data -> (n_epochs, n_picks, n_times). Index the single pick explicitly:
    # squeeze() would also drop the epoch axis when only one trial exists.
    y = cond_epochs.get_data(picks=[elec])[:,0,:]
    y_below, y_above = sem(y)
    ax.plot(x,y.mean(0),color=color)
    ax.fill_between(x,y_below,y_above,color=color,alpha=0.3)
# Plt decorations
ax.set_title(f"{subj} {elec} {fs_roi}", fontsize=14)
ax.axvline(0,color='k')
# Plt settings
xlims = [x[0], x[-1]]
ax.set_xlim(xlims)
xticks = np.round(np.arange(xlims[0],xlims[-1]+.5,.5),decimals=1)
ax.set_xticks(xticks)
ax.set_xticklabels(xticks,fontsize=12)
fig.savefig(os.path.join(git_path,"figures","figure_3","pdf",f"{subj}_{elec}_spkrmicclick.pdf"))

### Colorbars
Colorbar legends for the 1D and 2D recon heatmaps.

In [None]:
# Horizontal 1D colorbar legend: a 0-99 ramp repeated over 3 rows, drawn with the Reds colormap.
# `cm` is never imported in this notebook; pyplot exposes the colormap module as plt.cm.
cmap = plt.cm.Reds
plt.figure(figsize=(12,2))
plt.imshow(np.repeat(np.expand_dims(np.arange(100),axis=1),3,axis=1).T,aspect='auto',cmap=cmap)
plt.axis('off')
plt.savefig(os.path.join(git_path,"figures","figure_3","pdf","legend_1d.pdf"))

In [None]:
# Read in colormap from .png
# The 2D colormap is stored in the repo as an image; render it as the legend.
cmap_png = os.path.join(git_path,"figures","figure_3","RdBuPr_splinesqrt22.png")
cmap_2d = plt.imread(cmap_png)
fig, ax = plt.subplots(figsize=(5,5))
ax.imshow(cmap_2d)
ax.grid(False)
ax.set_xlabel("RGB 0-255")
ax.set_ylabel("RGB 0-255")
fig.savefig(os.path.join(git_path,"figures","figure_3","pdf","legend_2d.pdf"))

### Pie charts

In [None]:
k = nmf['k']; n_top_elecs=50
unavailable = "anatomy unavailable"
anat_cache = {}  # subject -> (hyphen-free channel names, anatomy entries)

def _subject_anatomy(s):
    '''Return (hyphen-free channel names, anatomy entries) for subject `s`, loading each subject once.'''
    if s not in anat_cache:
        pt = img_pipe.freeCoG(f"{s}_complete",hem='stereo',subj_dir=data_path)
        anat = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")['anatomy']
        fs_ch_names = [a[0][0].replace("-","") for a in anat]
        if s == "S0020":
            # Rename imaging label AIPOF' to APIOF' for this subject (presumably a label mismatch)
            fs_ch_names = [c.replace("AIPOF'","APIOF'") for c in fs_ch_names]
        anat_cache[s] = (fs_ch_names, anat)
    return anat_cache[s]

# One record per (cluster, top-n electrode), with FreeSurfer ROI and condensed/gross labels.
# Rows are collected in a list and framed once (DataFrame.append was removed in pandas 2.0).
records = []
for clust in np.arange(k):
    for n in np.arange(n_top_elecs):
        elec = clusters['ch_names'][n,clust]; s = clusters['all_subjs'][n,clust]; w = clusters['W'][n,clust]
        fs_roi = fs_roi_condensed = gross_anat = unavailable
        if s not in no_imaging:
            fs_ch_names, anat = _subject_anatomy(s)
            elec_key = elec.replace("-","")
            if elec_key in fs_ch_names:
                fs_roi = anat[fs_ch_names.index(elec_key)][3][0]
                fs_roi_condensed = imaging_utils.condense_roi(fs_roi)
                gross_anat = imaging_utils.gross_anat(fs_roi)
        records.append({'subj':s,'elec':elec,'clust':clust,'elec_rank':n,'w':w,
                        'fs_roi':fs_roi,'fs_roi_condensed':fs_roi_condensed,'gross_anat':gross_anat})
top_elecs = pd.DataFrame(records,columns=["subj","elec","clust","elec_rank","w",
                                          "fs_roi","fs_roi_condensed","gross_anat"])

In [None]:
xc, xlabel = 1, "Onset suppression"; yc, ylabel =  0, "Dual onset"
# Pie/legend colors: c1 marks onset-suppression weight, c2 dual-onset weight.
# NOTE(review): c1/c2 were used but never defined in this notebook; these are
# placeholders — TODO replace with the intended Figure 3 cluster colors.
c1, c2 = "tab:blue", "tab:orange"

def _own_cluster_fraction(top_elecs, own, other):
    '''
    Per-subject share of weight belonging to cluster `own` (vs. `other`).

    Covers each subject with top electrodes in cluster `own`, over those electrodes.
    Each electrode adds its `own` weight to the numerator and denominator. It also adds
    its `other` weight to the denominator when it is a top electrode of `other` for
    the same subject. Returns one fraction per subject.
    '''
    fractions = []
    for s in np.unique(top_elecs.loc[top_elecs['clust']==own,'subj'].values):
        subj_rows = top_elecs.loc[top_elecs['subj']==s]
        own_rows = subj_rows.loc[subj_rows['clust']==own]
        other_rows = subj_rows.loc[subj_rows['clust']==other]
        num, denom = 0.0, 0.0
        for elec in np.unique(own_rows['elec'].values):
            w_own = own_rows.loc[own_rows['elec']==elec,'w'].values[0]
            num += w_own
            denom += w_own
            if elec in other_rows['elec'].values:
                denom += other_rows.loc[other_rows['elec']==elec,'w'].values[0]
        # Append once per subject, after all of its electrodes are accumulated
        fractions.append(num/denom)
    return fractions

percentages = {xc: _own_cluster_fraction(top_elecs, xc, yc),
               yc: _own_cluster_fraction(top_elecs, yc, xc)}

plt.figure(figsize=(10,4))
# Each panel: mean per-subject share of its own cluster's weight, exploded first wedge
panels = [(xc, xlabel, [c1, c2]), (yc, ylabel, [c2, c1])]
for panel_idx, (clust_id, label, pie_colors) in enumerate(panels):
    plt.subplot(1,2,panel_idx+1)
    plt.title(f"Top {n_top_elecs} {label} electrodes")
    w1 = np.array(percentages[clust_id]).mean()
    w2 = 1-w1
    plt.pie(np.array((w1,w2)), colors=pie_colors, startangle=90, explode=[0.2,0.],
            labels=["%.1f%%"%(w1*100),"%.1f%%"%(w2*100)])
# Zero-height bars only serve to create legend handles for the two cluster colors
plt.bar(0,0,color=c1,label=f"{xlabel} cluster weight")
plt.bar(0,0,color=c2,label=f"{ylabel} cluster weight")
plt.legend(frameon=True, bbox_to_anchor=(1,1))
plt.tight_layout()
plt.savefig(os.path.join(git_path,"figures","figure_3","pdf",f"{xlabel}_{ylabel}_pie.pdf"))