In [None]:
# Channels to exclude, keyed by subject ID; handed to drop_sub_channels() and
# create_reb_spg(drop_chans=...) further down.
scd = {
    #'ACR_52': [9, 10, 11, 12, 14],
    'ACR_50': [10],
    'ACR_53': [8],
}

In [None]:
from acr.utils import SOM_BLUE, NNXR_GRAY, HALO_GREEN

# Experiment name: passed to get_subject_list(exp=...) and embedded in saved figure names.
MAIN_EXP = 'swisin'
# Subject group: passed to get_subject_list(type=...) and embedded in saved figure names.
SUBJECT_TYPE = 'halo'
# Colour used for the 'NNXo' probe in the plots (the 'NNXr' probe is drawn in NNXR_GRAY).
MAIN_COLOR = HALO_GREEN

In [None]:
#-------------------------- Standard Imports --------------------------#
%reload_ext autoreload
%autoreload 2
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pingouin as pg
import polars as pl
import seaborn as sns
from scipy.stats import shapiro, normaltest

import acr

# NOTE(review): this silences every warning for the rest of the session, including
# pandas / seaborn deprecation notices — consider narrowing the filter.
warnings.filterwarnings('ignore')
# Probe ('store') order and the matching colours; passed as hue_order / palette in the plots below.
probe_ord = ['NNXr', 'NNXo']
hue_ord = [NNXR_GRAY, MAIN_COLOR]

from acr.plots import lrg, supl

# Order matters here: base style first, then the acr style helper, then the tick override.
plt.style.use('fast')
supl()
plt.rcParams['xtick.bottom'] = False
#--------------------------------- Import Publication Functions ---------------------------------#
# NOTE(review): absolute, user-specific paths — the notebook only runs on a machine with this layout.
# The star imports presumably supply the bare helper names used later (get_subject_list,
# load_raw_bp_df, drop_sub_channels, PAPER_FIGURE_ROOT, ...) — TODO confirm and import them explicitly.
pub_utils = acr.utils.import_publication_functions('/home/kdriessen/gh_master/PUBLICATION__ACR/pub_utils.py', 'pub_utils')
from pub_utils import *
data_agg = acr.utils.import_publication_functions('/home/kdriessen/gh_master/PUBLICATION__ACR/data_agg.py', 'data_agg')
from data_agg import *

In [None]:
# Subjects of the configured type together with their experiment names; the two
# sequences are iterated in lockstep (zip) by the loading cells below.
subjects, exps = get_subject_list(type=SUBJECT_TYPE, exp=MAIN_EXP)

In [None]:
# Per-subject hypnogram data, both keyed by subject ID:
#   full_hyps -> hypnogram for the full experiment
#   hyp_dicts -> condition dictionary later used to label the band-power frames
full_hyps, hyp_dicts = {}, {}
for subject, exp in zip(subjects, exps):
    full_hyps[subject] = acr.io.load_hypno_full_exp(subject, exp)
    hyp_dicts[subject] = acr.hypnogram_utils.create_acr_hyp_dict(
        subject,
        exp,
        true_stim=True,
        extra_rebounds=False,
    )

In [None]:
# Raw band-power frame for every subject, keyed by subject ID.
# Condition / state / baseline labelling is applied in the next cell, not here.
bp_dfs = {}
for sub, exp in zip(subjects, exps):
    bp_dfs[sub] = load_raw_bp_df(sub, exp)

In [None]:
def _relative_bandpower_for(subject):
    """Label one subject's raw band-power frame and express it relative to baseline.

    Adds the hypnogram-condition, state and full-baseline labels, then rescales
    against the mean of NREM rows flagged as full baseline. Progress is printed
    at the same points as before.
    """
    print(subject)
    labelled = acr.hypnogram_utils.label_df_with_hypno_conditions(bp_dfs[subject], hyp_dicts[subject])
    print('Done conditions')
    labelled = acr.hypnogram_utils.label_df_with_states(labelled, full_hyps[subject])
    labelled = acr.hypnogram_utils.label_df_with_full_bl(labelled)
    # NOTE(review): value_to_use is the string 'True', not the boolean — presumably
    # matches how the 'full_bl' column is encoded; verify.
    relative = make_raw_bp_df_relative_to_baseline(
        labelled,
        col_to_use='full_bl',
        value_to_use='True',
        state_to_use='NREM',
        method='mean',
    )
    print('Done states/bl')
    return relative


# One relative band-power frame per subject, stacked into a single polars frame.
reldfs = [_relative_bandpower_for(subject) for subject in subjects]
relbp = pl.concat(reldfs)

In [None]:
# Remove the per-subject bad channels listed in `scd` ...
relbp = drop_sub_channels(relbp, scd)
# ... and keep a rebound-only view for the SWA averages below.
rebbp = relbp.filter(pl.col('condition').eq('rebound'))

In [None]:
# Sanity check that every remaining channel looks reasonable: per-channel delta
# power during the rebound, one row of box plots per subject, split by probe.
delta_rebound = (
    relbp
    .filter(pl.col('band') == 'delta')
    .filter(pl.col('condition') == 'rebound')
    .to_pandas()
)
mean_marker = {
    "marker": "_",
    'markerfacecolor': 'gold',
    'markeredgecolor': 'gold',
    'markersize': '24',
    "markeredgewidth": "3",
}
g = sns.catplot(
    data=delta_rebound,
    x='channel',
    y='bandpower_rel',
    hue='store',
    hue_order=probe_ord,
    palette=hue_ord,
    kind='box',
    row='subject',
    row_order=subjects,
    showfliers=False,
    showmeans=True,
    meanprops=mean_marker,
    height=10,
    aspect=3,
)

#for ax in g.axes.flat:
    #ax.set_ylim(0, 5)

In [None]:
# Rebound means per subject / probe / band, converted to pandas and passed
# through sub_regions_to_df().
reb_avgs = sub_regions_to_df(
    rebbp
    .group_by(['subject', 'store', 'band'])
    .mean()
    .to_pandas()
)

# SWA, simple average of all channels

In [None]:
# Start from matplotlib defaults, apply the large acr plot style, then hide the
# bottom x tick marks for the paired box plots.
plt.rcdefaults()
acr.plots.lrg()
plt.rcParams.update({'xtick.bottom': False})

In [None]:
# Leftover interactive help lookup removed so Restart & Run All does not open the
# pager; run `acr.plots.gen_paired_boxplot?` by hand to see the helper's signature.

In [None]:
# ---- Figure identity -------------------------------------------------------
BAND = 'delta'
SUBJECT_GROUP = 'ALL_SUBJECTS'
CHANNEL_GROUP = 'ALL_CHANNELS'
FIG_ID = f'{BAND}_rebound_{MAIN_EXP}'
figure_name = '--'.join([MAIN_EXP, SUBJECT_TYPE.upper(), SUBJECT_GROUP, CHANNEL_GROUP, FIG_ID]) + '.png'

# ---- Data: one rebound mean per subject and probe for the chosen band -------
# Sorting by subject keeps the two arrays aligned as pairs.
data = reb_avgs[reb_avgs['band'] == BAND].sort_values(by=['subject', 'store'])
nnxr = data.loc[data['store'] == 'NNXr', 'bandpower_rel'].values
nnxo = data.loc[data['store'] == 'NNXo', 'bandpower_rel'].values

# ---- Plot and save ----------------------------------------------------------
f, ax = acr.plots.gen_paired_boxplot(nnxr, nnxo, colors=[NNXR_GRAY, MAIN_COLOR])
print(ax.get_yticks())  # automatic ticks, shown before they are overridden
ax.set_yticks([1.5, 1.7, 1.9, 2.1])
print(ax.get_ylim())
ax.set_xlim(0.30, 0.70)
ax.set_xticklabels(['Contra Control', 'Optrode'])

save_path = f'{PAPER_FIGURE_ROOT}/swa/{figure_name}'
plt.savefig(save_path, dpi=600, bbox_inches='tight', transparent=True)

In [None]:
# =============================
# ========== STATS ============
# =============================
# Paired t-test (plus Hedges' g) on the nnxr / nnxo arrays and `figure_name`
# defined by the plotting cell above; run that cell first.
write = True  # when True, the stats result and the source data are written out

diffs = nnxr - nnxo
shap_stat, shap_p = shapiro(diffs) # test the paired differences for normality
#agostino_stat, agostino_p = normaltest(diffs) # test the paired differences for normality
print(f'shapiro_p-value: {shap_p}')
#print(f'd,agostino_p-value: {agostino_p}')
stats = pg.ttest(nnxr, nnxo, paired=True)
hg = pg.compute_effsize(nnxr, nnxo, paired=True, eftype='hedges')
print(f'hedges g: {hg}')
if write:
    # ===== Write Stats =====
    stats_name = figure_name
    # The result frame has a single, string-labelled row, so take it by position with
    # .iloc[0]; `series[0]` on a non-integer index is a deprecated positional fallback
    # in pandas (and the warning is hidden by the global warnings filter).
    acr.stats.write_stats_result(
        stats_name,
        'paired_ttest',
        stats['T'].iloc[0],
        stats['p-val'].iloc[0],
        'g',
        hg,
    )

    # ===== Write Source Data =====
    # 'subject' is a positional index (0..n-1), not the subject ID.
    source_data = pd.DataFrame({'contra_control': nnxr, 'off_induction': nnxo, 'subject': np.arange(len(nnxr))})
    pub_utils.write_source_data(source_data, stats_name)
stats

In [None]:
# ---- Figure identity -------------------------------------------------------
BAND = 'delta'
SUBJECT_GROUP = 'ALL_SUBJECTS'
CHANNEL_GROUP = 'ALL_CHANNELS'
FIG_ID = f'{BAND}_rebound_{MAIN_EXP}_NORMALIZED'
figure_name = '--'.join([MAIN_EXP, SUBJECT_TYPE.upper(), SUBJECT_GROUP, CHANNEL_GROUP, FIG_ID]) + '.png'

# ---- Data: same per-subject rebound means as the un-normalized figure -------
data = reb_avgs[reb_avgs['band'] == BAND].sort_values(by=['subject', 'store'])
nnxr = data.loc[data['store'] == 'NNXr', 'bandpower_rel'].values
nnxo = data.loc[data['store'] == 'NNXo', 'bandpower_rel'].values

# Express each subject's values relative to its own NNXr value.
nnxr_norm = nnxr / nnxr
nnxo_norm = nnxo / nnxr

# ---- Plot and save ----------------------------------------------------------
f, ax = acr.plots.gen_paired_boxplot(
    nnxr_norm,
    nnxo_norm,
    colors=[NNXR_GRAY, MAIN_COLOR],
    fsize=(4.5, 5.5),
    one_sided=True,
)
ax.set_xlim(0.30, 0.70)
ax.set_xticklabels(['Contra Control', 'Optrode'])
#ax.set_ylim(0.95, 1.05)
ax.set_yticks([0.95, 1.0, 1.05, 1.1, 1.15])

save_path = f'{PAPER_FIGURE_ROOT}/swa/{figure_name}'
plt.savefig(save_path, dpi=600, bbox_inches='tight', transparent=True)

# Full Spectrograms

In [None]:
# Rebound spectrogram frame per subject, built with that subject's bad channels
# dropped (drop_chans=None when the subject has no entry in `scd`).
reb_spgs = {}
for subject, exp in zip(subjects, exps):
    reb_spgs[subject] = create_reb_spg(subject, exp, drop_chans=scd.get(subject))

# Stack all subjects into one pandas frame.
full_df = pd.concat(list(reb_spgs.values()))

In [None]:
# Mean of every numeric column per probe and frequency, across all rows of full_df.
full_means = (
    full_df
    .groupby(['store', 'frequency'])
    .mean(numeric_only=True)
    .reset_index()
)

In [None]:
from statsmodels.stats.multitest import multipletests
# Paired NNXr-vs-NNXo test at every frequency bin of the rebound spectrogram.
# Fills `freqs` and `p_vals` (same order) for the FDR correction in the next cell.
t_tests = {}
freqs = []
p_vals = []
for freq in full_df['frequency'].unique():
    # Sorting by store then subject keeps the two arrays aligned as pairs
    # (assumes every subject has a row for both probes — TODO confirm).
    freq_df = full_df.loc[full_df['frequency'] == freq].sort_values(['store', 'subject'])
    nnxr = freq_df.query('store == "NNXr"')['power_rel'].values
    nnxo = freq_df.query('store == "NNXo"')['power_rel'].values

    # will use wilcoxon across all frequencies to avoid differing normality assumptions across different frequencies, etc
    stat = pg.wilcoxon(nnxr, nnxo)
    # Single string-labelled result row: take it by position with .iloc[0] rather than
    # the deprecated `series[0]` positional fallback.
    p_val = stat['p-val'].iloc[0]
    freqs.append(freq)
    p_vals.append(p_val)
    print(freq, p_val)

In [None]:
# Benjamini–Hochberg false-discovery-rate correction over all frequency bins.
rej, p_fdr, _, _ = multipletests(p_vals, alpha=0.05, method='fdr_bh')

# Report only the bins that remain significant after correction.
for freq, p_cor in zip(freqs, p_fdr):
    if not p_cor < 0.05:
        continue
    print(round(freq, 1), round(p_cor, 3))


In [None]:
# =============================
# ===== Write Source Data =====
# =============================
# One long table with the paired per-subject power values for every frequency bin.
# This cell only reads full_df; it no longer appends to `freqs`, which belongs to
# the stats cell and used to grow by one full set of bins on every run here.
dfs = []
for freq in full_df['frequency'].unique():
    freq_df = full_df.loc[full_df['frequency'] == freq].sort_values(['store', 'subject'])
    nnxr = freq_df.query('store == "NNXr"')['power_rel'].values
    nnxo = freq_df.query('store == "NNXo"')['power_rel'].values
    freq_bin = round(freq, 1)  # separate name so the figure handle `f` is not overwritten
    # 'subject' is a positional index (0..n-1), not the subject ID.
    fdf = pd.DataFrame({'freq_bin': freq_bin, 'contra_control_power': nnxr, 'optrode_power': nnxo, 'subject': np.arange(len(nnxr))})
    dfs.append(fdf)
spg_source_dat = pd.concat(dfs)
pub_utils.write_source_data(spg_source_dat, 'HALO_full_spg_source_data--TONIC')

In [None]:
# ---- Figure identity -------------------------------------------------------
SUBJECT_GROUP = 'ALL_SUBJECTS'
CHANNEL_GROUP = 'ALL_CHANNELS'
FIG_ID = 'FULL__SPG__REBOUND'
figure_name = '--'.join([MAIN_EXP, SUBJECT_TYPE.upper(), SUBJECT_GROUP, CHANNEL_GROUP, FIG_ID]) + '.png'

# ---- Style: defaults + large acr style, with x tick marks shown for this plot
plt.rcdefaults()
acr.plots.lrg()
plt.rcParams.update({'xtick.bottom': True})

# ---- Plot: relative power vs frequency, one line per probe ('se' error bars) -
f, ax = plt.subplots(1, 1, figsize=(4.5, 3))
sns.lineplot(
    data=full_df,
    x='frequency',
    y='power_rel',
    hue='store',
    hue_order=probe_ord,
    palette=[NNXR_GRAY, MAIN_COLOR],
    errorbar='se',
    lw=4,
    ax=ax,
)

ax.set_xlim(0, 20)
ax.set_xticks([0, 5, 10, 15, 20])
ax.set_yticks([1, 1.5, 2])
#ax.set_ylabel('')

# No legend on the published panel.
ax.get_legend().remove()
#ax.axhline(y=1.59, xmin=0.09, xmax=0.15)
#ax.axvspan(0.3, 2.4, ymin=0.985, ymax=1.0)
#ax.axvspan(2.6, 3.5, ymin=0.985, ymax=1.0)

save_path = f'{PAPER_FIGURE_ROOT}/swa/{figure_name}'
plt.savefig(save_path, dpi=600, bbox_inches='tight', transparent=True)

# During Inhibition

In [None]:
# Output directory for the during-inhibition figure saved below.
# NOTE(review): earlier cells save to f'{PAPER_FIGURE_ROOT}/swa/...' using the bare name —
# presumably the same location as acr.utils.PAPER_FIGURE_ROOT; confirm.
nbroot = f'{acr.utils.PAPER_FIGURE_ROOT}/swa'

In [None]:
# Back to matplotlib defaults + the large acr style, with the bottom x tick
# marks hidden again for the paired box plots in this section.
plt.rcdefaults()
acr.plots.lrg()
plt.rcParams.update({'xtick.bottom': False})

In [None]:
# Delta-band rows only ...
delt = relbp.filter(pl.col('band').eq('delta'))
# ... with 'bandpower' re-expressed relative to the 'early_sd' condition mean,
# computed separately for each subject / probe / channel.
delt_rel = data_agg.relativize_df(
    delt,
    ref_to_col='condition',
    ref_to_val='early_sd',
    avg_method='mean',
    col_to_relativize='bandpower',
    on=['subject', 'store', 'channel'],
)

In [None]:
# NOTE(review): `sdf` is rebuilt from `relbp` in the next cell, and `stimdf` / `stimmeans`
# are reassigned in the 'stim' cell further down, so nothing computed here appears to
# reach the saved figure or stats — confirm whether this cell is still needed.
conds = ['early_sd','late_sd', 'stim']
# `delt_rel` derives from the delta-only frame `delt`, so the band filter is presumably redundant.
sdf = delt_rel.filter(pl.col('condition').is_in(conds)).filter(pl.col('band')=='delta')
# Drop early_sd / late_sd rows whose relative power exceeds 1.5; 'stim' rows are not thresholded.
sdf = sdf.filter(~((pl.col('condition').is_in(['early_sd', 'late_sd']))&(pl.col('bandpower_rel') > 1.5)))
sdf = sdf.filter(pl.col('channel') >= 9) #using only the deep channels here to avoid the inevitable omnetics-caused artifact on the superficial channels during periods of intense movement
# NOTE(review): named `stimdf` but selects the 'early_sd' condition — verify intent.
stimdf = sdf.filter(pl.col('condition') == 'early_sd')
# Mean relative power per subject and probe, sorted so the probes pair up by subject.
stimmeans = stimdf.group_by(['subject', 'store']).agg(pl.col('bandpower_rel').mean()).sort(['subject', 'store'])

In [None]:
# Wake-period (early_sd / late_sd / stim) delta power with obvious artifact removed.
conds = ['early_sd', 'late_sd', 'stim']
# Restrict to the delta band first: `relbp` holds every band, and neither the
# per-channel means below nor the later per-subject means group by 'band', so
# without this filter the "delta during stim" figure and stats would average
# across all bands.
sdf = relbp.filter(pl.col('band') == 'delta').filter(pl.col('condition').is_in(conds))
# Try to clean out as much obvious artifact from the wake data as possible:
# drop rows whose relative power exceeds 10, for the SD conditions and for each probe.
sdf = sdf.filter(~((pl.col('condition').is_in(['early_sd', 'late_sd']))&(pl.col('bandpower_rel') > 10)))
sdf = sdf.filter(~((pl.col('store') == 'NNXr')&(pl.col('bandpower_rel') > 10)))
sdf = sdf.filter(~((pl.col('store') == 'NNXo')&(pl.col('bandpower_rel') > 10)))
# Per-channel means for the per-subject diagnostic bar plots.
sdfm = sdf.group_by(['condition','subject', 'store', 'channel']).agg(pl.col('bandpower_rel').mean())

In [None]:
# Diagnostic figure per subject: per-channel mean relative power for each
# condition, NNXo on the top axis and NNXr on the bottom, with a reference line at 1.
for sub in subjects:
    subdf = sdfm.filter(pl.col('subject') == sub)
    f, ax = plt.subplots(2, 1, figsize=(18, 6))
    # (probe, colour of the last hue level) for the top and bottom axes.
    panels = (('NNXo', MAIN_COLOR), ('NNXr', 'black'))
    for panel_ax, (store, stim_color) in zip(ax, panels):
        # .prb(...) is presumably a project-registered polars helper selecting one probe.
        sns.barplot(
            data=subdf.prb(store).to_pandas(),
            x='channel',
            y='bandpower_rel',
            hue='condition',
            hue_order=conds,
            ax=panel_ax,
            palette=['gray', NNXR_GRAY, stim_color],
        )
        panel_ax.set_title(f'{sub}-{store}')
        panel_ax.axhline(y=1, color='red', linestyle='--')
        panel_ax.set_ylim(0, 3)
        panel_ax.legend().remove()
    plt.show()

In [None]:
# Rows recorded during the 'stim' condition ...
stimdf = sdf.filter(pl.col('condition') == 'stim')
# ... reduced to one mean per subject and probe, sorted so the probes pair up by subject.
stimmeans = (
    stimdf
    .group_by(['subject', 'store'])
    .agg(pl.col('bandpower_rel').mean())
    .sort(['subject', 'store'])
)

In [None]:
# ---- Figure identity -------------------------------------------------------
fig_id = 'delta_during_stim'
fig_name = '__'.join([SUBJECT_TYPE, MAIN_EXP, fig_id])
fig_path = f'{nbroot}/{fig_name}.png'

# ---- Data: per-subject stim means for each probe (stimmeans is subject-sorted)
nnxo = stimmeans.filter(pl.col('store') == 'NNXo')['bandpower_rel'].to_numpy()
nnxr = stimmeans.filter(pl.col('store') == 'NNXr')['bandpower_rel'].to_numpy()
xlabels = ['Contra. Control', 'Optrode']

# ---- Raw values: displayed only, not saved ----------------------------------
f, ax = acr.plots.gen_paired_boxplot(nnxr, nnxo, colors=[NNXR_GRAY, MAIN_COLOR])
ax.set_xticklabels(xlabels)
plt.show()

# ---- Normalized to each subject's NNXr value: this is the saved figure -------
f, ax = acr.plots.gen_paired_boxplot(
    nnxr / nnxr,
    nnxo / nnxr,
    colors=[NNXR_GRAY, MAIN_COLOR],
    one_sided=True,
)
ax.set_xticklabels(xlabels)
ax.set_ylim(0.87, 2.5)
ax.set_yticks([1, 1.5, 2, 2.5])
print(ax.get_ylim())
plt.show()

# `f` now refers to the normalized figure.
f.savefig(fig_path, dpi=600, bbox_inches='tight', transparent=True)

In [None]:
# =============================
# ========== STATS ============
# =============================
# Paired t-test (plus Hedges' g) on the nnxr / nnxo arrays and `fig_name`
# defined by the during-stim plotting cell above; run that cell first.
write = True  # when True, the stats result and the source data are written out

diffs = nnxr - nnxo
shap_stat, shap_p = shapiro(diffs) # test the paired differences for normality
print(f'shapiro_p-value: {shap_p}')

stats = pg.ttest(nnxr, nnxo, paired=True)
# stats = pg.wilcoxon(nnxr, nnxo)

hg = pg.compute_effsize(nnxr, nnxo, paired=True, eftype='hedges')
print(f'hedges g: {hg}')

#r = acr.stats.calculate_wilx_r(stats['W-val'][0], len(nnxr))


if write:
    # ==== Write Stats Results ====
    stats_name = f'{fig_name}'
    # The result frame has a single, string-labelled row, so take it by position with
    # .iloc[0]; `series[0]` on a non-integer index is a deprecated positional fallback
    # in pandas (and the warning is hidden by the global warnings filter).
    acr.stats.write_stats_result(
        stats_name,
        'paired_ttest',
        stats['T'].iloc[0],
        stats['p-val'].iloc[0],
        'g',
        hg,
    )

    # ===== Write Source Data =====
    # 'subject' is a positional index (0..n-1), not the subject ID.
    source_data = pd.DataFrame({'contra_control': nnxr, 'off_induction': nnxo, 'subject': np.arange(len(nnxr))})
    pub_utils.write_source_data(source_data, stats_name)
stats