# Benchmark Analyses - Surgical Outcome Predictors

Using spectral analyses, we might be interested in how the "distribution" of a certain frequency band changes as a result of the resection. For example, we may look at the different frequency bands of each channel decomposed using the Morlet wavelet transform.

In [32]:
import numpy as np
import pandas as pd
import mne
import os
import json
import os.path as op
from pathlib import Path
import collections
from pprint import pprint
from natsort import natsorted

from mne.io import RawArray
from mne import create_info
from mne_bids import BIDSPath, get_entity_vals, read_raw_bids
import mne
from mne.time_frequency import read_tfrs

mne.utils.use_log_level('error')

import ptitprince as pt
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib as mpl

import pingouin as pg
import dabest
from hyppo.independence import MGC
from hyppo.ksample import KSample

from eztrack.utils import Normalize
from eztrack.io.base import _add_desc_to_bids_fname, concatenate_derivs
from eztrack.viz import _load_turbo, generate_heatmap
from eztrack.posthoc.hypo import compute_null

_load_turbo()

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
np.random.seed(12345)

In [6]:
import warnings
warnings.filterwarnings('ignore')

In [7]:
def load_concat_derivs(deriv_path, subject, session, desc, tasks=None):
    # get all the subject/sessions in the derivative path
    subjects = get_entity_vals(deriv_path, 'subject')
    ignore_subjects = [sub for sub in subjects if sub != subject]
    sessions = get_entity_vals(deriv_path, 'session', ignore_subjects=ignore_subjects)
    ignore_sessions = [ses for ses in sessions if ses != session]

    # store all derivatives found in a list
    derivs = []
    onsets = []
    descriptions = []
    prevlen = 0
    
    if session == 'extraoperative':
        tasks = ['interictal', 'ictal']
        
    if tasks is None:
        # get all the tasks associated if not passed in
        tasks = get_entity_vals(deriv_path, 'task', 
                            ignore_subjects=ignore_subjects,
                            ignore_sessions=ignore_sessions
                           )
    for task in tasks:
        # get all file paths for this subject
        search_str = f'*ses-{session}*task-{task}*desc-{desc}*.json'
        deriv_fpaths = natsorted(list((deriv_path / f'sub-{subject}').glob(search_str)))

#         print(f'Found {len(deriv_fpaths)} derivative file paths')

        for idx, deriv_fpath in enumerate(deriv_fpaths):
            deriv = read_derivative_npy(deriv_fpath, preload=True, 
                                        verbose=False)

            if 'ch_axis' not in deriv.info:
                deriv.info['ch_axis'] = [0]
                rowderiv.info['ch_axis'] = [0 ]

            # create derivative structure
            onsets.append(prevlen + len(deriv))
            descriptions.append(f'ses-{session}-task-{task}-run-{idx+1}')
            prevlen += len(deriv)
    #         if derivative is None:
    #             derivative = deriv.copy()
    #             derivs.append(deriv.copy())
    #         else:
                # check that all channel names are ordered
    #                 if derivative.ch_names != deriv.ch_names:
    #                     deriv.reorder_channels(derivative.ch_names)
    #                     rowderiv.reorder_channels(derivative.ch_names)
    #                 assert derivative.ch_names == deriv.ch_names
    #                 derivative.append(deriv.copy())
            derivs.append(deriv.copy())
    return derivs, onsets, descriptions

In [9]:
def generate_deriv_list(derivs, rowderivs, derivtype='col',
                       baseline_mean=None, baseline_std=None):
    derivative = None
    new_derivs = []

    # loop through each derivative
    for deriv, rowderiv in zip(derivs, rowderivs):
        deriv.normalize()
        rowderiv.normalize()
        orig_filenames = deriv._filenames
        
        coldata = deriv.get_data()
        rowdata = rowderiv.get_data()

        if derivtype == 'col':
            data = coldata.copy()
        elif derivtype == 'row':
            data = rowdata.copy()
        elif derivtype == 'abs':
            # combine the data if wanted
            data = np.abs(coldata - rowdata)
        elif derivtype == 'prod':
            data = np.multiply(coldata, rowdata)
            
        # re-create a new derivative
        new_deriv = DerivativeArray(data, info=deriv.info, verbose=False)

        if derivative is None:
            if baseline_mean is not None:
                # subtract baseline vector from each time point
                data = data - baseline_mean[:, None]
            if baseline_std is not None:
                data = data / baseline_std[:, None]
            new_deriv = DerivativeArray(data, info=deriv.info, verbose=False)

            derivative = new_deriv.copy()
        else:
            print('Adding new data...')
            if not all([ch in deriv.ch_names for ch in derivative.ch_names]):
                # get the set difference of channels
                add_chs = list(set(derivative.ch_names) - set(deriv.ch_names))
                ch_type = derivative.get_channel_types()[0]
                info = create_deriv_info(ch_names=add_chs, sfreq=derivative.info['sfreq'], 
                                         ch_types=ch_type, 
                                         description=derivative.description,
                                        ch_axis=[0])
                addderiv = DerivativeArray(np.ones((len(add_chs), len(deriv)))*-1, 
                                           info=info, verbose=False)

                # add derivative chs
                new_deriv = new_deriv.add_channels([addderiv])

            if derivative.ch_names != new_deriv.ch_names:
                # add channels and reorder if necessary
                new_deriv.reorder_channels(derivative.ch_names)

                nonrz_inds = [idx for idx, ch in enumerate(new_deriv.ch_names) 
                             if ch not in add_chs]
                data = new_deriv.get_data()
                info = new_deriv.info
                if baseline_mean is not None:
                    data[nonrz_inds, :] = data[nonrz_inds, :] - baseline_mean[nonrz_inds, np.newaxis]
                if baseline_std is not None:
                    data[nonrz_inds, :] = data[nonrz_inds, :] / baseline_std[nonrz_inds, np.newaxis]
                
                # make sure data that was disconnected is hardcode set to nan
                rz_inds = [idx for idx in range(data.shape[0]) if idx not in nonrz_inds]
                data[rz_inds, :] = np.nan
                
                new_deriv = DerivativeArray(data, info=info, verbose=False)
            else:
                if baseline_mean is not None:
                    # subtract baseline vector from each time point
                    data = data - baseline_mean[:, None]
                if baseline_std is not None:
                    data = data / baseline_std[:, None]

                new_deriv = DerivativeArray(data, info=deriv.info, verbose=False)
            derivative.append(new_deriv) 
            
        # make sure filenames persist
        new_deriv._filenames = orig_filenames
        new_derivs.append(new_deriv)
    return new_derivs                   

In [10]:
def compute_block_bootstrap_stats(pre_deriv, post_deriv, subject, df_summ=None, threshold=None):
    # compute effect size difference with sub-sampling
    pre_blocks = pre_deriv.subsample_blocks()
    post_blocks = post_deriv.subsample_blocks()

    cohensd = []
    stats = []
    pvals = []

    for preb, postb in zip(pre_blocks, post_blocks):
        if threshold is not None:
            preb[np.abs(preb) < threshold] = np.nan
            postb[np.abs(postb) < threshold] = np.nan

        # drop any nans
        preb = preb[~np.isnan(preb)]
        postb = postb[~np.isnan(postb)]

        stat, pvalue = KSample("Dcorr").test(preb, postb)
        stats.append(stat)
        pvals.append(pvalue)

        es = pg.compute_effsize(preb, postb, 
                                paired=False, eftype='cohen')

        cohensd.append(es)
        
        if df_summ is not None:
            df_summ.append([subject, es, stat, pvalue])
    return cohensd, stats, pvals

# Define Paths and Parameters for Analysis

In [24]:
# paths to BIDS dataset / derivatives
root = Path('/Users/adam2392/OneDrive - Johns Hopkins/sickkids/')
deriv_root = root / 'derivatives'

# derivative experiment markers
reference = 'average'
deriv_chain = Path('tfr') / reference
deriv_path = deriv_root / deriv_chain

# where to save the data
figures_path = deriv_root / 'figures'

# all session to analyze
sessions = ['extraoperative', 'preresection', 
            'intraresection', 'postresection']

# the derivative ``desc`` entity description
desc = 'gamma'  # which frequency band

threshold = None
baseline = False

# Visualize Spectrograms

In [25]:
print(deriv_path)

/Users/adam2392/OneDrive - Johns Hopkins/sickkids/derivatives/tfr/average


In [26]:
subject= 'E1'
session = 'preresection'
task = 'pre'

# get all file paths for this subject
search_str = f'sub-{subject}_ses-{session}*task-{task}*desc-{desc}*.h5'

deriv_fpaths = natsorted(list(deriv_path.rglob(search_str)))

for deriv_fpath in deriv_fpaths:
    power = read_tfrs(deriv_fpath)[0]
    
    print(power)
    

<AverageTFR | time : [0.000000, 320.123120], freq : [30.000000, 90.000000], nave : 1, channels : 98, ~490.7 MB>


In [37]:
fig, ax = plt.subplots()
sns.heatmap(
#     Normalize().compute_fragilitymetric(
    10*np.log10(power.data.mean(axis=1)), ax=ax,
#     invert=True), 
            cmap='viridis', yticklabels=power.ch_names)
ax.set(
    yscale='log',
)

[None]

ValueError: math domain error

# Initialize dataframe for plotting over all subjects

In [12]:
# keep track of the dataframe summary
df = pd.DataFrame()

# Load data for all subjects

In [234]:
# get all subjects analyze
subjects = get_entity_vals(deriv_path, 'subject')

print(f'All subjects analyzed are: {subjects}')

All subjects analyzed are: ['E1', 'E3', 'E4', 'E5', 'E6', 'E7']


In [235]:
featurename = 'Col Fragility'
# featurename = 'Row Fragility'
# featurename = 'Absolute Fragility'

# cbarlabel = 'Absolute Diff Fragility'
cbarlabel = 'Col Fragility'
# cbarlabel = 'Row Fragility'

In [236]:
subj_derivlists = dict()

for subject in subjects:
    print(subject)
    # get list of sessions for subject
    ignore_subjects = [sub for sub in subjects if sub != subject]
#     sessions = get_entity_vals(deriv_path, 'session', ignore_subjects=ignore_subjects)
    sessions = ['extraoperative', 'preresection', 'intraresection', 'postresection']
    print(f'Sessions in the deriv path {sessions}')
    
    # compute the channel's mean row perturbation values during interictal awake periods
    # as a baseline
    mean_vec, std_vec = compute_baseline(subject, deriv_root, deriv_chain,
                                     task='interictalawake', 
                                         desc='rowperturbmatrix')

    # load all the column perturbation derivatives
    derivs = []
    onsets = []
    descriptions = []
    
    # keep track of data frame summary statistics
    df_summ = []

    for session in sessions:
        print(session)
        derivs_, onsets_, descrips_ = load_concat_derivs(deriv_path, subject, session, 
                                    desc='perturbmatrix')
        derivs.extend(derivs_)
        onsets.extend(onsets_)
        descriptions.extend(descrips_)
    
    # load all the row perturbation derivatives
    rowderivs = []
    for session in sessions:
        print(session)
        derivs_, _, _ = load_concat_derivs(deriv_path, subject, session, 
                                    desc='rowperturbmatrix')
        rowderivs.extend(derivs_)
    print(len(rowderivs))
    
    # create list of all the onset times 
    onsets = []
    prevonset = 0
    for deriv in derivs:
        onsets.append(len(deriv) + prevonset)
        prevonset += len(deriv)
        
    # read in the resected channels for the dataset
    bids_path = BIDSPath(subject=subject, root=root,
                     suffix='channels', extension='.tsv')
    ch_fpaths = bids_path.match()

    # read in sidecar channels.tsv
    channels_pd = pd.read_csv(ch_fpaths[0], sep='\t')
    description_chs = pd.Series(channels_pd.description.values, index=channels_pd.name).to_dict()
    resected_chs = [ch for ch, description in description_chs.items() if description == 'resected']
    resected_inds = [idx for idx, ch in enumerate(deriv.ch_names) if ch in resected_chs]
    nrz_inds = [idx for idx in range(len(deriv.ch_names)) if idx not in resected_inds]
    
    # generate concatenated list of derivatives
    if baseline:
        baseline_kwargs=dict(baseline_mean=mean_vec,
                                   baseline_std=std_vec
                            )
    else:
        baseline_kwargs = dict()
    derivlist = generate_deriv_list(derivs.copy(), rowderivs.copy(), 
                                    derivtype=derivtype,
                                   **baseline_kwargs
                                   )
    
    # only get the pre/post resection data
    pre_deriv = [deriv for deriv in derivlist if 'task-pre' in deriv.filenames[0]][0]
    post_deriv = [deriv for deriv in derivlist if 'task-post' in deriv.filenames[0]][0]

    pre_data = pre_deriv.get_data()
    post_data = post_deriv.get_data()
    print(pre_data.shape, post_data.shape)
    
    subj_derivlists[subject] = derivlist
    
    # compute block-bootstrap statistics
    cohensd, stats, pvals = compute_block_bootstrap_stats(pre_deriv, post_deriv, subject=subject,
                                                          df_summ=df_summ, threshold=threshold)
    
    # create 
    subj_df = pd.DataFrame(df_summ, columns=['subject', 'es', 'stat', 'pval'])
    if df.empty:
        df = subj_df
    else:
        df = pd.concat((df, subj_df), axis=0)
    
display(subj_df.head())
print(df.shape, subj_df.shape)

E1
Sessions in the deriv path ['extraoperative', 'preresection', 'intraresection', 'postresection']
extraoperative
preresection
intraresection
postresection
extraoperative
preresection
intraresection
postresection
9
Adding new data...
Adding new data...
Adding new data...
Adding new data...
Adding new data...
Adding new data...
Adding new data...
Adding new data...
(98, 1279) (98, 120)
E3
Sessions in the deriv path ['extraoperative', 'preresection', 'intraresection', 'postresection']
extraoperative
preresection
intraresection
postresection
extraoperative
preresection
intraresection
postresection
8
Adding new data...
Adding new data...
Adding new data...
Adding new data...
Adding new data...
Adding new data...
Adding new data...
(83, 485) (83, 282)
E4
Sessions in the deriv path ['extraoperative', 'preresection', 'intraresection', 'postresection']
extraoperative
preresection
intraresection
postresection
extraoperative
preresection
intraresection
postresection
7
Adding new data...
Adding 

Unnamed: 0,subject,es,stat,pval
0,E7,-0.066054,0.005525,0.0009107672
1,E7,0.329984,0.023254,5.226604e-11
2,E7,0.291601,0.019912,1.156708e-09
3,E7,0.129869,0.003109,0.01004537
4,E7,0.263963,0.017543,1.046223e-08


(600, 4) (100, 4)
