In [1]:
# Load IPython's autoreload extension; the next cell enables it with `%autoreload 2`.
%load_ext autoreload

In [86]:
# Reload modules before each cell executes, so edits to the UniversalUnconsciousness package
# take effect without restarting the kernel.
%autoreload 2

from hydra import compose, initialize
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm
import h5py
import numpy as np
# NOTE(review): these star imports supply the project helpers used later in this notebook
# (load_font, get_session_plot_info, get_noise_filter_info, get_pca_chosen,
# collect_grid_indices_to_run, get_grid_params_to_use, collect_delase_indices_to_run,
# get_delase_results, get_pct_correct, perform_power_analysis). Which module defines each
# is not visible here — TODO: replace with explicit imports.
from UniversalUnconsciousness.data_utils import *
from UniversalUnconsciousness.plot_utils import *
from UniversalUnconsciousness.power_analysis import *

# Apply the project's matplotlib style sheet, then the project font
# (load_font is provided by one of the star imports above).
plt.style.use('UniversalUnconsciousness.sci_style')
load_font()

In [3]:
# Compose the DeLASE analysis configuration through Hydra's compose API.
delase_conf_dir = "../../UniversalUnconsciousness/DeLASE_analysis/conf"
with initialize(version_base="1.3", config_path=delase_conf_dir):
    cfg = compose(config_name="config")

In [4]:
# Clear the low-/high-pass settings for this analysis (presumably "no filtering" — confirm in the pipeline config).
cfg.params.low_pass = cfg.params.high_pass = None

In [5]:
# Plotting options pulled from the composed config.
plotting_cfg = cfg.plotting
anesthetic_agent_list = plotting_cfg.anesthetic_agent_list
curve_colors = plotting_cfg.curve_colors
loc_roc_colors = plotting_cfg.loc_roc_colors
img_save_dir = plotting_cfg.img_save_dir
os.makedirs(img_save_dir, exist_ok=True)  # figure output directory; no error if it already exists

In [6]:
# Load the precomputed DeLASE pipeline outputs for every (data_class, agent) pair and
# stop early if any stage (grid search, DeLASE fits) is incomplete for an agent.
verbose = False
areas = ['all']  # loop-invariant: every agent is analysed over the same area list
# Sessions dropped from the non-propofol data classes.
# NOTE(review): the reason for excluding this session is not recorded here — document it.
EXCLUDED_SESSIONS = {'PEDRI_Ketamine_20220203'}

agent_data = {}
for data_class, agent in tqdm(anesthetic_agent_list, desc='agents'):
    # The pipeline helpers read the dataset from cfg, so point it at this data class first.
    cfg.params.data_class = data_class
    if 'propofol' in cfg.params.data_class:
        mat_dir = os.path.join(cfg.params.all_data_dir, 'anesthesia', 'mat', cfg.params.data_class)
    else:
        mat_dir = os.path.join(cfg.params.all_data_dir, cfg.params.data_class, 'mat')
    # os.listdir returns entries in arbitrary order; sort so session order (and everything
    # derived from it) is reproducible across machines.
    session_list = sorted(f[:-4] for f in os.listdir(mat_dir) if f.endswith('.mat'))
    if 'propofol' not in cfg.params.data_class:
        # Non-propofol data classes hold several agents' sessions in one directory: drop excluded
        # sessions and keep those whose name contains the agent's first three letters.
        agent_tag = agent.lower()[:3]
        session_list = [
            session for session in session_list
            if session not in EXCLUDED_SESSIONS and agent_tag in session.lower()
        ]

    entry = {'session_list': session_list}
    agent_data[(data_class, agent)] = entry

    session_lists, locs, rocs, ropaps = get_session_plot_info(cfg, session_list, verbose=verbose)
    entry.update(session_lists=session_lists, locs=locs, rocs=rocs, ropaps=ropaps)

    noise_filter_info = get_noise_filter_info(cfg, session_list, verbose=verbose)
    entry['noise_filter_info'] = noise_filter_info

    pca_chosen = get_pca_chosen(cfg, session_list, areas, noise_filter_info, verbose=verbose)
    entry['pca_chosen'] = pca_chosen

    # A non-empty result means some grid-search jobs still need to run.
    all_indices_to_run = collect_grid_indices_to_run(cfg, session_list, areas, noise_filter_info, pca_chosen, verbose=verbose)
    if all_indices_to_run:
        raise ValueError(f"Sessions for agent {agent} have incomplete grid search - cannot continue")

    grid_params_to_use = get_grid_params_to_use(cfg, session_list, areas, noise_filter_info, pca_chosen, verbose=verbose)
    entry['grid_params_to_use'] = grid_params_to_use

    # A non-empty result means some DeLASE fits still need to run.
    all_indices_to_run = collect_delase_indices_to_run(cfg, session_list, areas, noise_filter_info, pca_chosen, grid_params_to_use, verbose=verbose)
    if all_indices_to_run:
        raise ValueError(f"Sessions for agent {agent} have incomplete DeLASE - cannot continue")

    delase_results = get_delase_results(cfg, session_list, areas, grid_params_to_use, pca_chosen, verbose=verbose)
    entry['delase_results'] = delase_results

  0%|          | 0/3 [00:00<?, ?it/s]

Only 0 valid windows could be found for section 'awake lever1' with times [np.float64(-36.23385444444445), np.float64(-26.050830555555557)]
Only 3 valid windows could be found for section 'early unconscious' with times [np.float64(26.003301111111107), 45]
Only 3 valid windows could be found for section 'awake lever1' with times [np.float64(-36.10079944444445), np.float64(-26.03974944444445)]
Only 0 valid windows could be found for section 'awake lever2' with times [np.float64(-10.029036111111115), 0]
Only 1 valid windows could be found for section 'awake lever1' with times [np.float64(-36.120676111111116), np.float64(-26.045618333333337)]


In [130]:
# Choose the session to inspect. Alternatives examined previously (swap in as needed):
#   ('anesthesiaLvrOdd', 'dexmedetomidine'): 'SPOCK_Dexmedetomidine_20210923', 'PEDRI_Dexmedetomidine_20220310'
#   ('anesthesiaLvrOdd', 'ketamine'):        'SPOCK_Ketamine_20210727' (high), 'SPOCK_Ketamine_20210804' (low)
#   ('propofolPuffTone', 'propofol'):        'Mary-Anesthesia-20160912-02'
# Valid sessions per agent are listed in agent_data[(data_class, agent)]['session_lists'].
data_class, agent = ('anesthesiaLvrOdd', 'dexmedetomidine')
session = 'SPOCK_Dexmedetomidine_20210916'
area = 'all'

cfg.params.data_class = data_class
delase_results = agent_data[(data_class, agent)]['delase_results']
# Fail fast with a readable message instead of a bare KeyError in the analysis/plot cells,
# which index delase_results[session].
if session not in delase_results:
    raise KeyError(
        f"Session {session!r} has no DeLASE results for {(data_class, agent)}; "
        f"available: {sorted(delase_results)}"
    )

In [131]:
# Band powers per window for the chosen session/area, plus an R^2 score per frequency band.
freq_powers, freq_r2_scores = perform_power_analysis(
    delase_results,
    cfg,
    session,
    area,
    top_percent=0.1,
    verbose=True,
)

  0%|          | 0/1103 [00:00<?, ?it/s]

In [132]:
# Close the handle left by a previous run of this cell; re-running would otherwise leak it.
if isinstance(globals().get('session_file'), h5py.File):
    session_file.close()

if 'propofol' in cfg.params.data_class:
    session_dir = os.path.join(cfg.params.all_data_dir, 'anesthesia', 'mat', cfg.params.data_class)
else:
    session_dir = os.path.join(cfg.params.all_data_dir, cfg.params.data_class, 'mat')
# Kept open read-only: later cells read session_file['sessionInfo'] and pass it to get_pct_correct.
session_file = h5py.File(os.path.join(session_dir, f'{session}.mat'), 'r')

In [None]:
freq_band_to_plot = 'delta'
# Fraction of each window's stability_params that is averaged (same value as the
# top_percent passed to perform_power_analysis earlier in this notebook).
top_fraction = 0.1
# Post-infusion windows whose pct_correct is at or below this value get shaded.
low_perf_threshold = 0.1

is_propofol = 'propofol' in cfg.params.data_class
# NOTE(review): non-propofol sessions use the last infusionStart entry, propofol sessions the
# first drugStart entry — confirm this asymmetry is intended.
if is_propofol:
    infusion_start = session_file['sessionInfo']['drugStart'][0, 0]
else:
    infusion_start = session_file['sessionInfo']['infusionStart'][0, -1]


def _top_slice(values, fraction):
    """Return the leading `fraction` of `values`, keeping at least one element.

    int(fraction * len(values)) is 0 for short inputs; the original slice was then empty,
    giving a NaN mean and a sqrt(0) divide in the SEM. Clamping to 1 avoids that.
    """
    return values[:max(1, int(fraction * len(values)))]


session_results = delase_results[session][area]
stability = session_results.stability_params
stab_means = stability.apply(lambda x: _top_slice(x, top_fraction).mean())
# SEM of the leading slice. NOTE(review): if x is a NumPy array, .std() uses ddof=0
# (population std); the textbook SEM uses ddof=1 — confirm which is intended.
stab_sems = stability.apply(
    lambda x: _top_slice(x, top_fraction).std() / np.sqrt(len(_top_slice(x, top_fraction)))
)
# Minutes relative to infusion start (assumes window_start is in seconds; x label is in min).
time_vals = (session_results.window_start - infusion_start) / 60

fig, ax = plt.subplots()
ax.plot(time_vals, stab_means, color=curve_colors[agent])
ax.fill_between(time_vals, stab_means - stab_sems, stab_means + stab_sems, alpha=0.2, color=curve_colors[agent])

if not is_propofol:
    pct_correct, pct_correct_windows = get_pct_correct(cfg, session_file, lever_window=120, stride=0.1)
    pct_correct_times = (pct_correct_windows - infusion_start) / 60
    # Shade post-infusion windows where task performance has dropped to <= low_perf_threshold
    # (i.e. LOW performance — the mask is pct_correct <= threshold, not a high-performance test).
    low_perf_mask = (pct_correct <= low_perf_threshold) & (pct_correct_windows > infusion_start)
    # x in data coordinates, y in axes fraction (0..1): the band spans the full plot height and,
    # unlike filling between the current ylim values, does not feed back into y-autoscaling.
    ax.fill_between(pct_correct_times, 0, 1, where=low_perf_mask,
                    transform=ax.get_xaxis_transform(),
                    color=loc_roc_colors[agent], alpha=0.1)

ax.set_xlabel('Time Relative to Anesthesia Start (min)')
ax.set_ylabel('Mean Instability ($s^{-1}$)')
# Band power on a twin y-axis sharing the same time axis.
ax2 = ax.twinx()
ax2.plot(time_vals, freq_powers[freq_band_to_plot], color='k', alpha=0.5)
ax2.spines['right'].set_visible(True)
ax2.set_ylabel(f'{freq_band_to_plot.capitalize()} Power')
ax.set_title(f'{session}')
plt.show()

In [127]:
# R^2 score per frequency band from perform_power_analysis.
# NOTE(review): the saved output below came from In [127], which ran before the session
# selection in In [130] — re-run after the analysis cell to refresh it.
freq_r2_scores

{'delta': 0.7838465630257663,
 'theta': 0.37548644483918736,
 'alpha': 0.2518574388192333,
 'beta': 0.017730662233818917,
 'gamma': 0.15676556636864136}