In [None]:
# Paths - Update locally!
# git_path: root of the local code repository. Later cells import helper modules from it,
#   read event files and stats tables under it, and write the LME CSVs under it.
# data_path: root of the dataset. Later cells read the per-subject .fif recordings from it
#   and pass it to img_pipe as the imaging subject directory.
git_path = '/path/to/git/kurteff2024_code/'
data_path = '/path/to/bids/dataset/'

In [None]:
import csv
import itertools as itools
import os
import random
import re
import sys
import warnings

import h5py
import mne
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

from img_pipe import img_pipe

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
import matplotlib.patheffects as PathEffects
rc['pdf.fonttype'] = 42
plt.style.use('seaborn')
%matplotlib inline

sys.path.append(os.path.join(git_path,"analysis","mtrf"))
import mtrf_utils
sys.path.append(os.path.join(git_path,"preprocessing","imaging"))
import imaging_utils

In [None]:
# Folder holding one sub-folder of event files per subject. Every cell that reads event
# files uses this same location, so the block discovery below must scan it too
# (it previously scanned "analysis/events/csv", which no other cell reads from).
events_dir = os.path.join(git_path,"preprocessing","events","csv")
exclude = ["TCH8"]
no_imaging = ["S0010"]
# os.listdir returns entries in arbitrary order; sorting keeps the subject order (and the
# row order of every CSV written below) reproducible between runs and machines.
subjs = sorted(
    s for s in os.listdir(events_dir)
    if ("TCH" in s or "S0" in s)
    and s not in exclude
    and os.path.isdir(os.path.join(events_dir,s))
)

# Per subject: block suffixes (text after the last "_" of the block folder name) for every
# block folder that contains a speaker sentence-level event file.
blocks = {
    s: sorted(
        b.split("_")[-1] for b in os.listdir(os.path.join(events_dir,s))
        if f"{s}_B" in b and os.path.isfile(
            os.path.join(events_dir,s,b,f"{b}_spkr_sn_all.txt"))
    ) for s in subjs
}
models = ['model1','model2','model3','model4']

In [None]:
def epoch_data(subj,blocks,git_path,data_path,channel='spkr',level='sn',condition='all',
               tmin=-.5, tmax=2, baseline=None, click=False):
    """Epoch every block of one subject around the events of an event file and concatenate them.

    Parameters
    ----------
    subj : str
        Subject ID; used to build both the recording path and the event-file path.
    blocks : iterable of str
        Block suffixes; each block ID is ``f"{subj}_{b}"``.
    git_path : str
        Repository root; event files are read from ``preprocessing/events/csv`` under it.
    data_path : str
        Dataset root; recordings are read from ``sub-{subj}/{subj}/{blockid}/...`` under it.
    channel, level, condition : str
        Parts of the event file name ``{blockid}_{channel}_{level}_{condition}.txt``
        (ignored when ``click`` is True).
    tmin, tmax : float
        Epoch window, passed to ``mne.Epochs``.
    baseline : None | tuple
        Passed to ``mne.Epochs``.
    click : bool
        If True, read ``{blockid}_click_eve.txt`` instead, whose onset / offset / id
        columns sit at positions 0 / 2 / 4 rather than 0 / 1 / 2.

    Returns
    -------
    The object returned by ``mne.concatenate_epochs`` for all blocks.

    Raises
    ------
    ValueError
        If ``blocks`` is empty.
    """
    # Column layout differs between the two event-file formats
    if click:
        onset_index, offset_index, id_index = 0,2,4
    else:
        onset_index, offset_index, id_index = 0,1,2
    epochs = []
    for b in blocks:
        blockid = f'{subj}_{b}'
        # Same folder layout as the other cells that load recordings in this notebook
        # (sub-<subj>/<subj>/<blockid>/...); the <subj> level was missing here.
        raw_fpath = os.path.join(
            data_path,f"sub-{subj}",subj,blockid,"HilbAA_70to150_8band","ecog_hilbAA70to150.fif")
        if click:
            event_fname = f"{blockid}_click_eve.txt"
        else:
            event_fname = f"{blockid}_{channel}_{level}_{condition}.txt"
        # Use the `subj` argument, not a notebook-global `s`, so the function does not
        # depend on whatever loop variable happens to be alive in the kernel.
        eventfile = os.path.join(git_path,"preprocessing","events","csv",subj,blockid,event_fname)
        raw = mne.io.read_raw_fif(raw_fpath,preload=True,verbose=False)
        fs = raw.info['sfreq']
        with open(eventfile,'r') as f:
            r = csv.reader(f,delimiter='\t')
            # Seconds -> samples (rounded up); `if row` skips blank lines such as a
            # trailing newline, which would otherwise raise IndexError.
            events = np.array([[np.ceil(float(row[onset_index])*fs).astype(int),
                                np.ceil(float(row[offset_index])*fs).astype(int),
                                int(row[id_index])] for row in r if row])
        epochs.append(mne.Epochs(raw,events,tmin=tmin,tmax=tmax,baseline=baseline,preload=True,verbose=False))
    if not epochs:
        raise ValueError(f"No blocks to epoch for subject {subj}")
    return mne.concatenate_epochs(epochs,verbose=False)

### LME #1: onset vs. sustained responses

Layout:

| subj     | elec     |  roi     |  window         |   si     | onset_tmax |
|:--------:|:--------:|:--------:|:---------------:|:--------:|:----------:|
|  S0006   | RPPST13  |  stg     |{onset,sustained}|  0.523   |   0.323    |
|   ...    |   ...    |   ...    |      ...        |   ...    |    ...     |

* LME Equation: `si ~ window + roi + (1|subj)`
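* SI Equations: $SI_{n_{onset}} = \frac{1}{T}\sum\limits_{t=0}^{0.75}\left(H_\gamma L_{n,t}-H_\gamma S_{n,t}\right)$, and $SI_{n_{sustained}} = \frac{1}{T}\sum\limits_{t=1}^{1.75}\left(H_\gamma L_{n,t}-H_\gamma S_{n,t}\right)$, where $T$ is the number of samples in the window and $H_\gamma L_{n,t}$ / $H_\gamma S_{n,t}$ are the epoch-averaged responses of electrode $n$ at time $t$ (in seconds) to the `spkr` and `mic` events, respectively. The code below then min-max rescales the resulting values to $[-1, 1]$ across all electrodes and both windows.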

In [None]:
# Two empty tables sharing one column layout (one per event channel), plus the
# pre-computed onset statistics table read from the repository.
si_layout = ['subj','elec','roi','window','si','onset_tmax']
df_spkr = pd.DataFrame(columns=list(si_layout))
df_mic = pd.DataFrame(columns=list(si_layout))
onset_stats_fpath = os.path.join(git_path,"stats","onset_stats.csv")
onset_df = pd.read_csv(onset_stats_fpath)

In [None]:
# Epoch every block of every subject around the speaker ('spkr') and microphone ('mic')
# sentence events, then concatenate the blocks into one Epochs object per subject/channel.
erp_tmin, erp_tmax = -.5, 2
reject = None
baseline = None
epochs = {}
ch_names = {}
for s in tqdm(subjs):
    epochs[s] = {}
    spkr_epochs, mic_epochs = [], []
    for b in blocks[s]:
        blockid = f'{s}_{b}'
        raw_fpath = os.path.join(data_path,f"sub-{s}",s,blockid,"HilbAA_70to150_8band",
                                 "ecog_hilbAA70to150.fif")
        raw = mne.io.read_raw_fif(raw_fpath,preload=True,verbose=False)
        ch_names[s] = raw.info['ch_names']
        fs = raw.info['sfreq']
        # Same procedure for both event files; speaker first, then microphone
        per_channel = (
            (f"{blockid}_spkr_sn_all.txt", spkr_epochs),
            (f"{blockid}_mic_sn_all.txt", mic_epochs),
        )
        for event_fname, block_list in per_channel:
            eventfile = os.path.join(git_path,"preprocessing","events","csv",s,blockid,event_fname)
            with open(eventfile,'r') as f:
                event_rows = []
                for row in csv.reader(f,delimiter='\t'):
                    # Seconds -> samples by truncation; third column is the event id
                    event_rows.append([int(float(row[0])*fs),int(float(row[1])*fs),int(row[2])])
                events = np.array(event_rows)
            block_list.append(mne.Epochs(raw,events,tmin=erp_tmin,tmax=erp_tmax,
                                         baseline=baseline,reject=reject,verbose=False))
    epochs[s]['spkr'] = mne.concatenate_epochs(spkr_epochs)
    epochs[s]['mic'] = mne.concatenate_epochs(mic_epochs)

In [None]:
# Suppression index (SI) per electrode: mean 'spkr' response minus mean 'mic' response
# inside an onset window and a sustained window, rescaled to [-1, 1] over all electrodes.
onset_tmin, onset_tmax = 0, 0.75
sus_tmin, sus_tmax = 1., 1.75

def _window_inds(times, tmin, tmax):
    """Return [start, stop] indices of the samples closest to tmin and tmax.

    Nearest-sample lookup avoids exact float comparison against the time vector, which
    raises IndexError whenever a window edge is not exactly representable in `times`.
    """
    times = np.asarray(times)
    return [int(np.argmin(np.abs(times-tmin))), int(np.argmin(np.abs(times-tmax)))]

def _window_mean(ep, picks, inds):
    """Per-channel mean over epochs, then over samples inds[0]:inds[1] (stop excluded)."""
    data = ep.get_data(picks=picks)  # (n_epochs, n_channels, n_times) per the MNE docs
    return data[:,:,inds[0]:inds[1]].mean(0).mean(-1)

si = {'onset':{}, 'sustained':{}}
for s in tqdm(subjs):
    # Indices are derived from this subject's own time vector instead of relying on a
    # loop variable leaked from the previous cell.
    x = epochs[s]['mic'].times
    onset_inds = _window_inds(x, onset_tmin, onset_tmax)
    sus_inds = _window_inds(x, sus_tmin, sus_tmax)
    subj_chs = epochs[s]['mic'].info['ch_names']
    # One get_data call per condition (all channels at once) rather than one per channel;
    # indexing the channel axis explicitly also keeps single-epoch subjects working,
    # which .squeeze() did not.
    si['onset'][s] = (_window_mean(epochs[s]['spkr'], subj_chs, onset_inds)
                      - _window_mean(epochs[s]['mic'], subj_chs, onset_inds))
    si['sustained'][s] = (_window_mean(epochs[s]['spkr'], subj_chs, sus_inds)
                          - _window_mean(epochs[s]['mic'], subj_chs, sus_inds))

# Global min/max over both windows and all subjects, then rescale to [-1, 1]
all_si = np.hstack([si[window][s] for window in ('onset','sustained') for s in subjs])
si_min, si_max = all_si.min(), all_si.max()
for s in subjs:
    si['onset'][s] = ((si['onset'][s]-si_min)/(si_max-si_min)) * 2 - 1
    si['sustained'][s] = ((si['sustained'][s]-si_min)/(si_max-si_min)) * 2 - 1

In [None]:
# Build the long-format table for LME #1 (two rows per electrode: onset and sustained SI)
# and write it to stats/lme/csv/onset_sustained_si.csv.
def _onset_channels(onset_table, subj, condition):
    """Channels of `subj` whose first `condition` row has a non-NaN peak_amplitude.

    Filters the table once per subject/condition instead of once per channel, and simply
    leaves out channels that have no row for this condition (indexing `.values[0]` on an
    empty selection raised IndexError).
    """
    rows = onset_table.loc[(onset_table['subj']==subj)&(onset_table['condition']==condition)]
    first_rows = rows.drop_duplicates(subset='ch_name', keep='first')
    return set(first_rows.loc[first_rows['peak_amplitude'].notna(),'ch_name'])

si_columns = ['subj','elec','elec_type','roi','roi_condensed','window','si']
si_records = []
for s in subjs:
    # Get anat
    fs_ch_names, fs_rois = [], []
    if s not in no_imaging:
        pt = img_pipe.freeCoG(f"{s}_complete", hem='stereo', subj_dir=data_path)
        anat = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")['anatomy']
        fs_ch_names = [a[0][0] for a in anat]; fs_rois = [a[3][0] for a in anat]
    # Get the onset info
    onset_chs_spkr = _onset_channels(onset_df, s, 'spkr')
    onset_chs_mic = _onset_channels(onset_df, s, 'mic')
    for fif_idx, elec in enumerate(epochs[s]['mic'].info['ch_names']):
        # Get anatomy
        if elec in fs_ch_names:
            fs_roi = fs_rois[fs_ch_names.index(elec)]
            roi_condensed = imaging_utils.condense_roi(fs_roi)
        else:
            # Only warn when imaging exists but this electrode is missing from it
            if s not in no_imaging:
                warnings.warn(f"Could not locate anatomy for {s} {elec}")
            fs_roi = "roi_unavailable"; roi_condensed = "roi_unavailable"
        # Classify by which conditions show an onset response
        has_spkr, has_mic = elec in onset_chs_spkr, elec in onset_chs_mic
        if has_spkr and has_mic:
            elec_type = "dual_onset"
        elif has_spkr:
            elec_type = "spkr_only"
        elif has_mic:
            elec_type = "mic_only"
        else:
            elec_type = "no_onset"
        for window in ('onset','sustained'):
            si_records.append({'subj':s, 'elec':elec, 'elec_type':elec_type, 'roi':fs_roi,
                'roi_condensed':roi_condensed, 'window':window, 'si':si[window][s][fif_idx]})
# Build the frame once: DataFrame.append was removed in pandas 2.0 and copies the whole
# frame on every call when used inside a loop.
df = pd.DataFrame(si_records, columns=si_columns)
df.to_csv(os.path.join(git_path,'stats','lme','csv','onset_sustained_si.csv'), index=False)

### LME #2: A1, A2, and insular peak latencies

This gives a more fine-grained look at how the response profiles of these regions differ.

Layout:

| subj     | elec     |     roi          | peak_amp | peak_lat |   cond   |
|:--------:|:--------:|:----------------:|:--------:|:--------:|:--------:|
|  S0006   | RPPST13  |   {A1,A2,insula} |  0.523   |  0.323   |   spkr   |
|   ...    |   ...    |      ...         |   ...    |   ...    |   ...    |

* LME Equation: `peak_lat ~ roi + (1|subj)`

In [None]:
# Build the table for LME #2: peak amplitude and latency of the epoch-averaged 'spkr'
# response for every non-excluded electrode whose condensed ROI is in one of the three
# groups below, written to stats/lme/csv/temporal_vs_insular_spkr.csv.
a1_rois = ['HG', 'PT']; a2_rois = ['STG', 'STS', 'PP', 'MTG']
insular_rois = ['insula_inf','insula_sup','insula_post','insula_ant']
# Label written to the 'roi' column -> condensed ROI names belonging to that group.
# The insular branch previously tested a2_rois again, so insular electrodes were never
# written and every A2 electrode was written twice (once mislabelled 'insular').
roi_groups = {'a1': a1_rois, 'a2': a2_rois, 'insular': insular_rois}
peak_tmin, peak_tmax = 0, 0.5
excl_df = pd.read_csv(os.path.join(git_path,"analysis","all_excluded_electrodes.csv"))
peak_records = []
for s in tqdm([ss for ss in subjs if ss not in no_imaging]):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        spkr_epochs = epoch_data(s,blocks[s],git_path,data_path,channel='spkr',tmin=peak_tmin,tmax=peak_tmax)
    x = spkr_epochs.times
    excl_ch_names = set(excl_df.loc[excl_df['subject']==s,'channel'])
    # subj_dir passed explicitly, matching the anatomy lookup in the LME #1 cell
    pt = img_pipe.freeCoG(f"{s}_complete", hem='stereo', subj_dir=data_path)
    anat = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")['anatomy']
    fs_ch_names = [a[0][0] for a in anat]
    fs_rois = [a[3][0] for a in anat]
    for elec in spkr_epochs.info['ch_names']:
        if elec in excl_ch_names or elec not in fs_ch_names:
            continue
        condensed_roi = imaging_utils.condense_roi(fs_rois[fs_ch_names.index(elec)])
        roi_label = next((label for label, members in roi_groups.items()
                          if condensed_roi in members), None)
        if roi_label is None:
            continue  # electrode is outside the three regions of interest
        # Avg across epochs; index the channel axis explicitly so a single-epoch
        # subject still yields a time series rather than a scalar
        spkr_resp = spkr_epochs.get_data(picks=[elec])[:,0,:].mean(0)
        spkr_latency_idx = spkr_resp.argmax()
        peak_records.append({'subj':s, 'elec':elec, 'roi':roi_label,
            'peak_amp':spkr_resp.max(), 'peak_lat':x[spkr_latency_idx], 'cond':'spkr'})
# Build the frame once (DataFrame.append was removed in pandas 2.0). Column order matches
# what appending rows with the extra 'cond' key used to produce.
df = pd.DataFrame(peak_records, columns=['subj','elec','roi','peak_amp','peak_lat','cond'])
# Closing parenthesis of os.path.join restored; index=False belongs to to_csv
df.to_csv(os.path.join(git_path,"stats","lme","csv","temporal_vs_insular_spkr.csv"), index=False)

### LME #3: mTRF model comparison

In [None]:
# Location of the mTRF h5 outputs inside the repository
h5_folder = os.path.join(git_path,"analysis","mtrf","h5")
# Feature labels of every model, and how many features each model has
features = {}
n_feats = {}
for model_number in models:
    features[model_number] = mtrf_utils.get_feats(model_number,extend_labels=True,mode='ecog')
    n_feats[model_number] = len(features[model_number])
# Make delays
# Integer delays from floor(delay_min*100) up to (not including) ceil(delay_max*100);
# the factor 100 is presumably the sampling rate in Hz - TODO confirm.
delay_min, delay_max = -0.3, 0.5
first_delay = np.floor(delay_min*100)
last_delay = np.ceil(delay_max*100)
delays = np.arange(first_delay,last_delay,dtype=int)

In [None]:
# Load strf data from pandas dataframe
# For every model and subject: read the fitted weights from the subject's hdf5 file and the
# per-channel r / p / alpha values from results.csv, then build "significant-only" copies
# in which channels with p >= 0.01 keep a placeholder value of 1.
rfpath = os.path.join(git_path,"analysis","mtrf","results.csv")
df = pd.read_csv(rfpath)
wts, corrs, pvals, sig_wts, sig_corrs, alphas = dict(), dict(), dict(), dict(), dict(), dict()
for m in models:
    for store in (wts, corrs, pvals, sig_wts, sig_corrs, alphas):
        store[m] = dict()
    b = tqdm(subjs)
    for s in b:
        b.set_description(f"Loading STRF for {s} {m}")
        # h5_folder has no trailing separator, so the path must be joined; the previous
        # f-string produced ".../h5weights/..." instead of ".../h5/weights/...".
        weights_fpath = os.path.join(h5_folder,"weights",f"{s}_weights.hdf5")
        with h5py.File(weights_fpath,'r') as f:
            # f[m] raises KeyError for a missing model; f.get(m) silently returned None
            wts[m][s] = np.array(f[m])
        # Rows of this subject/model, in file order
        fit_rows = df.loc[(df['subject']==s)&(df['model']==m)]
        corrs[m][s] = fit_rows['r_value'].to_numpy()
        pvals[m][s] = fit_rows['p_value'].to_numpy()
        alphas[m][s] = fit_rows['best_alpha'].to_numpy()
        # The middle dimension is unpacked into its own name so the `n_feats` dict built
        # in the previous cell is not overwritten by an int.
        n_delays, n_model_feats, n_chans = wts[m][s].shape
        if s == "TCH14":
            n_chans -= 1 # rm EKG channel
        if len(pvals[m][s]) < n_chans:
            raise ValueError(f"results.csv has {len(pvals[m][s])} rows for {s} {m}, "
                             f"expected at least {n_chans}")
        # NOTE(review): non-significant channels are left at 1 (np.ones) - confirm that
        # this, rather than 0 or NaN, is the intended placeholder.
        sig_wts[m][s] = np.ones(wts[m][s].shape); sig_corrs[m][s] = np.ones(n_chans)
        sig_chans = np.flatnonzero(pvals[m][s][:n_chans] < 0.01)
        sig_wts[m][s][:,:,sig_chans] = wts[m][s][:,:,sig_chans]
        sig_corrs[m][s][sig_chans] = corrs[m][s][sig_chans]

In [None]:
# Update xm/ym to make a different csv
# Writes one row per (channel, model) with that model's correlation value, for the two
# models named in xm / ym, to stats/lme/csv/<xm>_<ym>.csv.
xm, ym = 'model1', 'model2'
header = ['r','model','subject','channel']
csv_rows = [header]
for s in subjs:
    # `blocks` is keyed by subject, so the first block is blocks[s][0] (blocks[0] raised
    # KeyError). Only the channel names are needed, so the recording is not preloaded.
    raw_fpath = os.path.join(data_path,f"sub-{s}",s,f"{s}_{blocks[s][0]}",
        "HilbAA_70to150_8band","ecog_hilbAA70to150.fif")
    # Own name, so the per-subject `ch_names` dict from the epoching cell is not overwritten
    subj_ch_names = mne.io.read_raw_fif(raw_fpath, preload=False, verbose=False).info['ch_names']
    if s == "TCH14":
        subj_ch_names = subj_ch_names[:-1] # rm EKG channel
    for i,c in enumerate(subj_ch_names):
        csv_rows.append([corrs[xm][s][i],xm,s,c]); csv_rows.append([corrs[ym][s][i],ym,s,c])
csv_fname = os.path.join(git_path,"stats","lme","csv",f'{xm}_{ym}.csv')
# newline='' is what the csv module requires of files handed to csv.writer; without it
# every row is followed by a blank line on Windows.
with open(csv_fname,'w',newline='') as f:
    csvWriter = csv.writer(f)
    csvWriter.writerows(csv_rows)