In [None]:
import os

# Paths - update locally, or set the environment variables below so the notebook
# runs without edits. The literal strings are placeholders used when a variable is unset.
git_path = os.environ.get('KURTEFF2024_GIT_PATH', '/path/to/git/kurteff2024_code/')  # code checkout
data_path = os.environ.get('KURTEFF2024_DATA_PATH', '/path/to/bids/dataset/')  # BIDS dataset root

In [None]:
import mne
import numpy as np
import pandas as pd
import os
import re
import csv
from tqdm.notebook import tqdm
import warnings
import random
import itertools as itools

from img_pipe import img_pipe

from matplotlib import pyplot as plt
from matplotlib import rcParams as rc
import matplotlib.patheffects as PathEffects
# Embed fonts in saved PDFs as TrueType (Type 42) rather than Type 3.
rc['pdf.fonttype'] = 42
# matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and 3.8 removed
# the old name; use whichever name this matplotlib installation provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
%matplotlib inline

import sys
sys.path.append(os.path.join(git_path,"analysis","mtrf"))
import mtrf_utils
sys.path.append(os.path.join(git_path,"preprocessing","imaging"))
import imaging_utils

In [None]:
# Subject IDs are the sub-directories of the event-file tree that follow the dataset's
# naming schemes ("TCH*" / "S0*"), minus explicit exclusions. Listings are sorted
# because os.listdir order is arbitrary; sorting keeps iteration (and output row)
# order reproducible across machines.
events_root = os.path.join(git_path, "preprocessing", "events", "csv")
exclude = ["TCH8"]
no_imaging = ["S0010"]  # subjects skipped when assigning hemispheres below
subjs = sorted(
    s for s in os.listdir(events_root)
    if ("TCH" in s or "S0" in s) and s not in exclude
)

# Blocks per subject: keep block directories ("<subj>_B*") that contain the default
# event file epoch_data reads (channel='spkr', level='sn', condition='all').
# The check must look in the same tree epoch_data reads events from.
blocks = {
    s: [
        b.split("_")[-1]
        for b in sorted(os.listdir(os.path.join(events_root, s)))
        if f"{s}_B" in b
        and os.path.isfile(os.path.join(events_root, s, b, f"{b}_spkr_sn_all.txt"))
    ]
    for s in subjs
}

# Hemispheres with coverage per subject. A contact counts as right hemisphere when
# column 0 of its elecmatrix coordinate is > 0 and left when < 0; a hemisphere needs
# more than one contact to be included. ignore_lh / ignore_rh force-drop a hemisphere.
hems = {s: [] for s in subjs if s not in no_imaging}
ignore_lh = []
ignore_rh = ['S0024']
for s in hems:
    pt = img_pipe.freeCoG(f"{s}_complete", hem='stereo', subj_dir=data_path)
    elecs = pt.get_elecs()['elecmatrix']
    if np.sum(elecs[:, 0] > 0) > 1 and s not in ignore_rh:
        hems[s].append('rh')
    if np.sum(elecs[:, 0] < 0) > 1 and s not in ignore_lh:
        hems[s].append('lh')

In [None]:
def epoch_data(subj,blocks,git_path,data_path,channel='spkr',level='sn',condition='all',
               tmin=-.5, tmax=2, baseline=None, click=False):
    epochs = []
    for b in blocks:
        blockid = f'{subj}_{b}'
        raw_fpath = os.path.join(
            data_path,f"sub-{subj}",blockid,"HilbAA_70to150_8band","ecog_hilbAA70to150.fif")
        if click:
            eventfile = os.path.join(git_path,"preprocessing","events","csv",s,blockid,
                                     f"{blockid}_click_eve.txt")
        else:
            eventfile = os.path.join(git_path,"preprocessing","events","csv",s,blockid,
                                     f"{blockid}_{channel}_{level}_{condition}.txt")
        raw = mne.io.read_raw_fif(raw_fpath,preload=True,verbose=False)
        fs = raw.info['sfreq']
        if click:
            onset_index, offset_index, id_index = 0,2,4
        else:
            onset_index, offset_index, id_index = 0,1,2
        with open(eventfile,'r') as f:
            r = csv.reader(f,delimiter='\t')
            events = np.array([[np.ceil(float(row[onset_index])*fs).astype(int),
                                np.ceil(float(row[offset_index])*fs).astype(int),
                                int(row[id_index])] for row in r])
        epochs.append(mne.Epochs(raw,events,tmin=tmin,tmax=tmax,baseline=baseline,preload=True,verbose=False))
    return mne.concatenate_epochs(epochs,verbose=False)

In [None]:
# Window (seconds, relative to event onset) in which a response onset must begin.
onset_tmin, onset_tmax = 0., 0.3
excl_df = pd.read_csv(os.path.join(git_path, "analysis", "all_excluded_electrodes.csv"))
# Gross-anatomy labels kept in the output table.
keep_gross_anat = ['frontal', 'temporal', 'parietal', 'occipital',
                   'precentral', 'postcentral', 'insula', 'whitematter']
df_columns = ['subj', 'ch_name', 'hem', 'fs_roi', 'gross_anat', 'condensed_roi', 'condition',
              'onset_times', 'peak_amplitude', 'peak_latency']


def _nearest_index(times, t):
    """Index of the sample in ``times`` closest to ``t`` (exact float equality is fragile)."""
    return int(np.argmin(np.abs(np.asarray(times) - t)))


def _find_onset(mean_resp, times, imin, imax, thresh_mult=1.5):
    """Locate the first supra-threshold response run that starts inside a window.

    ``mean_resp`` is a 1-D trial-averaged response aligned with ``times``. The threshold
    is ``thresh_mult * std(mean_resp)``. The run must start at a sample in
    ``[imin, imax]`` (inclusive) and extends over consecutive supra-threshold samples.
    Single-sample runs are rejected.

    Returns ``(onset_tmin, onset_tmax, peak_latency, peak_amplitude)`` as Python floats,
    all NaN when no qualifying run exists.
    """
    no_onset = (np.nan, np.nan, np.nan, np.nan)
    above = mean_resp > np.std(mean_resp) * thresh_mult
    hits = np.flatnonzero(above[imin:imax + 1])
    if hits.size == 0:
        return no_onset
    enter = imin + int(hits[0])
    # Walk forward to the last sample of this contiguous supra-threshold run.
    exit_ = enter
    while exit_ + 1 < above.size and above[exit_ + 1]:
        exit_ += 1
    if exit_ == enter:
        return no_onset
    # Peak is searched over the whole run, including its last sample.
    run = mean_resp[enter:exit_ + 1]
    peak = int(np.argmax(run))
    return (float(times[enter]), float(times[exit_]),
            float(times[enter + peak]), float(run[peak]))


# Collect one record per (channel, condition); build the DataFrame once at the end.
records = []
for hem in ['lh', 'rh']:
    for s in tqdm(list(hems), desc=hem):
        if hem not in hems[s]:
            continue
        # Epoch data (spkr and mic)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            spkr_epochs = epoch_data(s, blocks[s], git_path, data_path, channel='spkr')
            spkr_resp = spkr_epochs.get_data()
            mic_epochs = epoch_data(s, blocks[s], git_path, data_path, channel='mic')
            mic_resp = mic_epochs.get_data()
        x = spkr_epochs.times
        imin, imax = _nearest_index(x, onset_tmin), _nearest_index(x, onset_tmax)
        # Trial-averaged responses, indexed [channel, time] below.
        spkr_mean = spkr_resp.mean(0)
        mic_mean = mic_resp.mean(0)

        fif_ch_names = spkr_epochs.info['ch_names']
        if s == "S0020":
            # One device for S0020 is named incorrectly in the .fif; rename it so it matches.
            fif_ch_names = [c.replace("AIPOF'", "APIOF'") for c in fif_ch_names]
        # Channel names are compared with hyphens stripped; keep the first index per name.
        fif_lookup = {}
        for i, c in enumerate(fif_ch_names):
            fif_lookup.setdefault(c.replace('-', ''), i)
        excl_norm = {c.replace('-', '')
                     for c in excl_df.loc[excl_df['subject'] == s, 'channel'].values}

        # Get anat
        pt = img_pipe.freeCoG(f"{s}_complete", hem='stereo', subj_dir=data_path)
        elec_data = pt.get_elecs(elecfile_prefix="TDT_elecs_all_warped")
        elecs_all, anat_all = elec_data['elecmatrix'], elec_data['anatomy']
        elecs, anat = imaging_utils.clip_hem_elecs(pt, hem=hem, elecmatrix=elecs_all, anatomy=anat_all)
        fs_ch_names = [a[0][0] for a in anat]
        fs_labels = [a[3][0] for a in anat]
        if s == "S0021":
            # Naming error with a few devices in this subject's anatomy file.
            fs_ch_names = [c.replace("APPI'", "ASPPI'") for c in fs_ch_names]

        for elecfile_idx, ch_name in enumerate(fs_ch_names):
            ch_norm = ch_name.replace('-', '')
            if ch_norm not in fif_lookup or ch_norm in excl_norm:
                continue
            fs_roi = fs_labels[elecfile_idx]
            gross = imaging_utils.gross_anat(fs_roi)
            if gross not in keep_gross_anat:
                continue
            # White-matter contacts are reported under a single 'subcort' ROI.
            condensed_roi = 'subcort' if gross == 'whitematter' else imaging_utils.condense_roi(fs_roi)
            fif_idx = fif_lookup[ch_norm]
            base = {'subj': s, 'ch_name': ch_name, 'hem': hem, 'fs_roi': fs_roi,
                    'gross_anat': gross, 'condensed_roi': condensed_roi}
            for condition, mean_resp in (('spkr', spkr_mean), ('mic', mic_mean)):
                on_t0, on_t1, peak_latency, peak_amplitude = _find_onset(
                    mean_resp[fif_idx], x, imin, imax)
                records.append({**base, 'condition': condition,
                                'onset_times': [on_t0, on_t1],
                                'peak_amplitude': peak_amplitude,
                                'peak_latency': peak_latency})

df = pd.DataFrame(records, columns=df_columns)
# Save df
stats_dir = os.path.join(git_path, "stats")
os.makedirs(stats_dir, exist_ok=True)
df.to_csv(os.path.join(stats_dir, "onset_stats.csv"), index=False)