In [None]:
# Paths - Update locally!
# git_path: root of the local kurteff2024_code checkout. It is joined below with
#   sub-folders such as "figures", "preprocessing" and "analysis", and is also used
#   to extend sys.path for the repository's helper modules.
# data_path: dataset root; passed as `subj_dir` to img_pipe.freeCoG below.
git_path = '/path/to/git/kurteff2024_code/'
data_path = '/path/to/bids/dataset/'

In [None]:
import mne
import numpy as np
import pandas as pd
import os
import re
import csv
from tqdm.notebook import tqdm
import warnings
from img_pipe import img_pipe
import librosa
import h5py
import pymf3

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
# Write Type 42 (TrueType) fonts into PDF output.
rc['pdf.fonttype'] = 42
# Matplotlib 3.6 renamed the bundled seaborn style sheets to 'seaborn-v0_8*' and
# deprecated the old 'seaborn' name, which newer releases no longer ship.
# Use whichever name this installation provides so the notebook runs on both.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
%matplotlib inline

import sys
sys.path.append(os.path.join(git_path,"figures"))
import plotting_utils
sys.path.append(os.path.join(git_path,"preprocessing","events","textgrids"))
import textgrid
sys.path.append(os.path.join(git_path,"preprocessing","imaging"))
import imaging_utils

In [None]:
# Subject IDs come from the entry names of the events CSV folder: keep names that
# contain "TCH" or "S0", then drop anything listed in `exclude`.
exclude = ["TCH8"]
no_imaging = ["S0010"]  # NOTE(review): defined here but not used in the visible cells
events_csv_dir = os.path.join(git_path, "preprocessing", "events", "csv")
subjs = []
for entry in os.listdir(events_csv_dir):
    looks_like_subject = any(tag in entry for tag in ("TCH", "S0"))
    if looks_like_subject and entry not in exclude:
        subjs.append(entry)

# For every subject, collect the block label (the text after the last "_") of each
# entry in that subject's events folder whose name contains "<subj>_B" and which
# holds a "<entry>_spkr_sn_all.txt" file.
# NOTE(review): this reads from "analysis/events/csv" while `subjs` is listed from
# "preprocessing/events/csv" -- confirm that both locations are intended.
blocks = {}
for s in subjs:
    subj_events_dir = os.path.join(git_path, "analysis", "events", "csv", s)
    blocks[s] = []
    for b in os.listdir(subj_events_dir):
        if f"{s}_B" not in b:
            continue
        spkr_file = os.path.join(subj_events_dir, b, f"{b}_spkr_sn_all.txt")
        if os.path.isfile(spkr_file):
            blocks[s].append(b.split("_")[-1])

# Record which hemisphere label(s) apply to each subject, based on the sign of the
# first column of the electrode matrix returned by img_pipe.
# NOTE(review): presumably column 0 is the left-right axis with positive = right;
# confirm against the coordinate convention of the imaging data.
hems = {}
for s in subjs:
    pt = img_pipe.freeCoG(f"{s}_complete", hem='stereo', subj_dir=data_path)
    elecs = pt.get_elecs()['elecmatrix']
    first_coord = elecs[:, 0]
    covered = []
    if np.any(first_coord > 0):  # at least one electrode with a positive coordinate
        covered.append('rh')
    if np.any(first_coord < 0):  # at least one electrode with a negative coordinate
        covered.append('lh')
    hems[s] = covered

# Plot colours: one hex string per condition, looked up by 'color_id' in the
# repository's palette CSV.
color_palette = pd.read_csv(os.path.join(git_path, "figures", "color_palette.csv"))


def _palette_hex(color_id):
    """Return the 'hex' value of the first palette row whose 'color_id' matches."""
    matching_rows = color_palette.loc[color_palette['color_id'] == color_id]
    return matching_rows['hex'].values[0]


spkr_color = _palette_hex('perception')
mic_color = _palette_hex('production')
click_color = _palette_hex('click')

### load cNMF results

In [None]:
# Load NMF results from h5
bname = 'spkrmicclick'
epoch_types = ['spkr', 'mic', 'click']
plt_colors = [spkr_color, mic_color, click_color]  # same order as epoch_types

cnmf_dir = os.path.join(git_path, "analysis", "cnmf", "h5")
h5_fpath = os.path.join(cnmf_dir, f"NMF_grouped_{bname}.hf5")

# `nmf` ends up holding: 'resp' (one array per epoch type), 'pve', one {'W', 'H'}
# dict per entry of `num_bases` (keyed by that entry), and the 'ch_names' /
# 'all_subjs' string arrays read from the companion text files.
nmf = {'resp': {}}
with h5py.File(h5_fpath, 'r') as f:
    num_bases = np.array(f.get('num_bases'))
    nmf['pve'] = np.array(f.get('pve'))
    for nb in num_bases:
        nmf[nb] = {
            'W': np.array(f.get(f'{nb}_bases/W')),
            'H': np.array(f.get(f'{nb}_bases/H')),
        }
    for epoch_type in epoch_types:
        nmf['resp'][epoch_type] = np.array(f.get(f'resp/{epoch_type}'))

ch_names_fpath = os.path.join(cnmf_dir, f"NMF_grouped_{bname}_ch_names.txt")
all_subjs_fpath = os.path.join(cnmf_dir, f"NMF_grouped_{bname}_all_subjs.txt")
nmf['ch_names'] = np.loadtxt(ch_names_fpath, dtype=str)
nmf['all_subjs'] = np.loadtxt(all_subjs_fpath, dtype=str)

In [None]:
# Convert resp into a subject-by-subject format
# resp[subj][epoch_type] and ch_names[subj] hold only the rows of the grouped NMF
# arrays whose 'all_subjs' entry equals that subject; subjects with no rows are
# skipped with a warning.
resp, ch_names = {}, {}
for s in subjs:
    nmf_inds = np.where(nmf['all_subjs'] == s)[0]
    if nmf_inds.size == 0:
        warnings.warn(f"Subject {s} missing from {bname} NMF, skipping...")
        continue
    ch_names[s] = nmf['ch_names'][nmf_inds]
    resp[s] = {
        epoch_type: nmf['resp'][epoch_type][nmf_inds, :] for epoch_type in epoch_types
    }

In [None]:
# Truncate weights according to a kneepoint
nmf['k'] = 9 # 86% PVE
# Take the factorisation stored under key k and build `clusters`: for every column
# of W, the weights, subject labels, channel names and responses are re-ordered so
# that the rows with the largest weight in that column come first.
clust = nmf[nmf['k']]
W = clust['W']
# Bug fix: the closing bracket was misplaced (`nmf['resp'][epoch_types[0].shape[1]]`),
# which asked a str for `.shape` and raised AttributeError. The length of the last
# axis has to be read from the response array itself.
n_times = nmf['resp'][epoch_types[0]].shape[1]

clusters = dict()
clusters['W'] = np.zeros(W.shape)
clusters['all_subjs'] = np.zeros(W.shape).astype(str)
clusters['ch_names'] = np.zeros(W.shape).astype(str)
clusters['resp'] = {
    epoch_type: np.zeros((nmf['k'], W.shape[0], n_times)) for epoch_type in epoch_types
}

# Column order: ascending summed weight (np.argsort sorts low to high).
clust_sort = np.argsort(W.sum(0))
for ri, ai in enumerate(clust_sort):  # relative/absolute index
    # Row indices of column `ai`, largest weight first.
    sorted_idxs = np.flip(np.argsort(W[:, ai]))
    clusters['W'][:, ri] = W[sorted_idxs, ai]
    clusters['all_subjs'][:, ri] = nmf['all_subjs'][sorted_idxs]
    clusters['ch_names'][:, ri] = nmf['ch_names'][sorted_idxs]
    for epoch_type in epoch_types:
        clusters['resp'][epoch_type][ri, :] = nmf['resp'][epoch_type][sorted_idxs]

### all clusters' top elecs, averaged

In [None]:
# One panel per cluster: for each epoch type, the mean response over the first
# `n_top_elecs` rows of that cluster (rows were ordered largest-weight-first above).
n_top_elecs = 16
n_cols = 3
n_rows = int(np.ceil(nmf['k'] / n_cols))  # 3 x 3 grid for k = 9
plt.figure(figsize=(10,10))
# `ci` rather than `clust`, so the `clust` dict from the truncation cell is not overwritten.
for ci in np.arange(nmf['k']):
    plt.subplot(n_rows, n_cols, ci+1)
    # NOTE(review): assumes each epoch spans -1 to 2 (presumably seconds around the
    # event at 0) -- confirm against the epoching parameters.
    x = np.linspace(-1,2,clusters['resp']['mic'][ci].shape[1])
    for i,epoch_type in enumerate(epoch_types):
        # Bug fix: the original referenced an undefined name `top_elecs` (NameError).
        plt.plot(x,clusters['resp'][epoch_type][ci][:n_top_elecs].mean(0),color=plt_colors[i])
    # Per-panel decorations are drawn once, not once per epoch type.
    plt.axvline(0,color='k')
    plt.title(f"Cluster {ci+1} top {n_top_elecs} elecs avg")
plt.tight_layout();
# Make sure the output folder exists so savefig does not fail on a fresh checkout.
pdf_dir = os.path.join(git_path,"figures","supplemental_figure_3","pdf")
os.makedirs(pdf_dir, exist_ok=True)
plt.savefig(os.path.join(pdf_dir,"cluster_avg_resps.pdf"))