In [None]:
#-------------------------- Standard Imports --------------------------#
%reload_ext autoreload
%autoreload 2
import kdephys as kde
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import acr
import xarray as xr
from scipy import stats
# ---------------------------- EXTRAS --------------------------------#
from kdephys.plot.main import _title, bp_plot
import kdephys.utils.spectral as sp
# Short alias for kdephys.utils.spectral.bands.
bands = sp.bands
from scipy.stats import normaltest
import warnings
# NOTE(review): this suppresses every warning for the whole kernel session,
# including deprecation warnings from pandas/polars/matplotlib.
warnings.filterwarnings('ignore')
import matplotlib as mpl
# NOTE(review): star import -- every public name of acr.utils lands in the notebook
# namespace and can silently shadow names bound above.
from acr.utils import *
# Publication helper modules are loaded from absolute, machine-specific paths;
# the notebook will not run elsewhere without editing these two lines.
pu = acr.utils.import_publication_functions('/home/kdriessen/gh_master/PUBLICATION__ACR/pub_utils.py', 'pu')
dag = acr.utils.import_publication_functions('/home/kdriessen/gh_master/PUBLICATION__ACR/data_agg.py', 'dag')
import pingouin as pg
from scipy.stats import shapiro

# Reset matplotlib rcParams to the library defaults, then call the project's plot
# setup helper (presumably applies the publication style -- confirm in acr.plots).
plt.rcdefaults()
acr.plots.supl()

In [None]:
# Notebook-wide configuration: cohort labels, the experiment types compared below,
# and the color used for the optrode (cohort) data in figures.
from acr.utils import SOM_BLUE, ACR_BLUE, NNXR_GRAY

SUBJECT_TYPE = 'som'
SUB_TYPE = 'SOM'
EXP_TYPES = ['offind', 'tonic']
MAIN_COLOR = SOM_BLUE

In [None]:
# Subjects (and their experiments) for this cohort in the 'swisin' experiment.
# The cohort comes from SUBJECT_TYPE so it cannot drift from the label used in figure names.
subjects, exps = pu.get_subject_list(type=SUBJECT_TYPE, exp='swisin')

In [None]:
# Directory (under the paper figure root) that receives every figure saved by this notebook.
notebook_figure_root = '{}/swa/tonic_comp'.format(acr.utils.PAPER_FIGURE_ROOT)

In [None]:
def read_comp_spg(subject, tag):
    """Read one subject's comp-spg parquet file into a polars DataFrame.

    The file is located at
    ``<acr.utils.pub_data_root>/comp_spg-dfs/<subject>--<tag>.parquet``.

    Args:
        subject: subject identifier used in the file name.
        tag: experiment tag used in the file name (called with the entries of
            EXP_TYPES in this notebook).

    Returns:
        The DataFrame returned by ``pl.read_parquet``.
    """
    parquet_file = f'{acr.utils.pub_data_root}/comp_spg-dfs/{subject}--{tag}.parquet'
    return pl.read_parquet(parquet_file)

In [None]:
# Full-experiment hypnograms, keyed '<subject>-offind' and '<subject>-tonic'.
# The experiment name is the first entry of each subject's list in the acr.utils lookup tables.
fh = {}
for subject in subjects:
    offind_exp = acr.utils.sub_swi_exps[subject][0]
    fh[f'{subject}-offind'] = acr.io.load_hypno_full_exp(subject, offind_exp)
    tonic_exp = acr.utils.sub_swisin_exps[subject][0]
    fh[f'{subject}-tonic'] = acr.io.load_hypno_full_exp(subject, tonic_exp)


In [None]:
# Known bad channels, keyed by subject; passed to pu.drop_sub_channels further down.
scd = {'ACR_44': [9, 11, 16]}

#relbp = drop_sub_channels(df, scd)

In [None]:
# Load every subject's spectrogram for each experiment type, label it with the NREM
# full-baseline flag, and tag each frame with its experiment type and subject.
dfs = []
for et in EXP_TYPES:
    for subject in subjects:
        print(subject, et)
        df = acr.hypnogram_utils.label_df_with_full_bl(read_comp_spg(subject, et), state='NREM')
        df = df.with_columns(xtype=pl.lit(et), subject=pl.lit(subject))
        dfs.append(df)

In [None]:
# Keep only rebound rows plus rows flagged as full baseline
# (the flag is compared against the string 'True').
spgs = [
    df.filter((pl.col('condition') == 'rebound') | (pl.col('full_bl') == 'True'))
    for df in dfs
]

In [None]:
# Restrict each spectrogram to the low (delta) frequencies: strictly between 0.4 and 3.9.
lows = [
    spg_reb.filter(pl.col('frequency').is_between(0.4, 3.9, closed='none'))
    for spg_reb in spgs
]

In [None]:
# Band power: sum 'power' over the retained frequency bins for every
# subject/store/channel/datetime (xtype, full_bl and condition are carried along as keys).
bps = [
    low.group_by(['subject', 'store', 'channel', 'datetime', 'xtype', 'full_bl', 'condition'])
       .agg(pl.col('power').sum())
    for low in lows
]

In [None]:
# Stack all per-subject band-power frames into one table, ordered by time, probe and channel.
bp = pl.concat(bps).sort(['datetime', 'store', 'channel'])

In [None]:
# The result carries a 'power_rel' column that the following cells use.
# Presumably 'power' relative to the mean of the full_bl == 'True' rows within each
# xtype/subject/store/channel group -- confirm in dag.relativize_df.
relbp = dag.relativize_df(
    bp,
    'full_bl',
    'True',
    'mean',
    'power',
    ['xtype', 'subject', 'store', 'channel'],
)

In [None]:
# Only the rebound period is analysed from here on.
rebbp = relbp.filter(pl.col('condition') == 'rebound')

In [None]:
# Mean relative power over the rebound, per subject / experiment type / probe / channel.
rm = (
    rebbp
    .group_by(['subject', 'xtype', 'store', 'channel'])
    .agg(pl.col('power_rel').mean())
    .sort(['subject', 'xtype', 'store', 'channel'])
)

In [None]:
# Apply the per-subject bad-channel table (presumably drops the listed channels --
# confirm in pu.drop_sub_channels).
reb_means = pu.drop_sub_channels(
    rm,
    scd,
)

In [None]:
# Collapse channels: one mean value per subject / experiment type / probe.
# The sort keeps the per-probe arrays extracted below in the same subject order.
reb_means = (
    reb_means
    .group_by(['subject', 'xtype', 'store'])
    .agg(pl.col('power_rel').mean())
    .sort(['subject', 'xtype', 'store'])
)

In [None]:
EXP_TYPE2P = 'offind'
fig_id = f'{EXP_TYPE2P}_0.5-4Hz_rebound_all_chans'
fig_name = f'{SUBJECT_TYPE}__{fig_id}'
fig_path = f'{notebook_figure_root}/{fig_name}.png'

# Per-subject rebound means for the optrode (NNXo) and contralateral control (NNXr) probes.
nnxo = reb_means.xt(EXP_TYPE2P).prb('NNXo')['power_rel'].to_numpy()
nnxr = reb_means.xt(EXP_TYPE2P).prb('NNXr')['power_rel'].to_numpy()

# Optrode value expressed as a fraction of the control probe; reused in the
# tonic-vs-offind comparison further down.
nnxo_off = nnxo / nnxr

# Raw paired comparison: displayed only, not saved.
f, ax = acr.plots.gen_paired_boxplot(nnxr, nnxo)
plt.show()

# Control-normalized comparison (control is identically 1): this is the saved figure.
f, ax = acr.plots.gen_paired_boxplot(nnxr / nnxr, nnxo / nnxr, one_sided=True)
ax.set_ylim(0.9, 1.005)
ax.set_xticklabels(['Contra. Control', 'Optrode'])
f.savefig(fig_path, dpi=600, transparent=True, bbox_inches='tight')
plt.show()

In [None]:
# =============================
# ========== STATS ============
# =============================
# Paired comparison of control (nnxr) vs optrode (nnxo) for the figure defined in the
# previous cell; relies on nnxr, nnxo and fig_name from that cell.
write = True

# Shapiro-Wilk on the paired differences (normality assumption of the paired t-test).
# The p-value is only printed; the t-test below is run regardless.
diffs = nnxr - nnxo
shap_stat, shap_p = shapiro(diffs)
print(f'shapiro_p-value: {shap_p}')

# Named ttest_res rather than `stats` so the `from scipy import stats` import is not clobbered.
ttest_res = pg.ttest(nnxr, nnxo, paired=True)
# ttest_res = pg.wilcoxon(nnxr, nnxo)

hg = pg.compute_effsize(nnxr, nnxo, paired=True, eftype='hedges')
print(f'hedges g: {hg}')

#r = acr.stats.calculate_wilx_r(ttest_res['W-val'][0], len(nnxr))


if write:
    # ==== Write Stats Results ====
    stats_name = f'{fig_name}'
    # .iloc[0]: the result row is label-indexed, so integer [0] relied on pandas'
    # deprecated positional fallback.
    acr.stats.write_stats_result(
        stats_name,
        'paired_ttest',
        ttest_res['T'].iloc[0],
        ttest_res['p-val'].iloc[0],
        'g',
        hg,
    )

    # ===== Write Source Data =====
    source_data = pd.DataFrame({'contra_control': nnxr, 'off_induction': nnxo, 'subject': np.arange(len(nnxr))})
    pu.write_source_data(source_data, stats_name)
ttest_res

In [None]:
EXP_TYPE2P = 'tonic'
fig_id = f'{EXP_TYPE2P}_0.5-4Hz_rebound_all_chans'
fig_name = f'{SUBJECT_TYPE}__{fig_id}'
fig_path = f'{notebook_figure_root}/{fig_name}.png'

# Per-subject rebound means for the optrode (NNXo) and contralateral control (NNXr) probes.
nnxo = reb_means.xt(EXP_TYPE2P).prb('NNXo')['power_rel'].to_numpy()
nnxr = reb_means.xt(EXP_TYPE2P).prb('NNXr')['power_rel'].to_numpy()

# Optrode value expressed as a fraction of the control probe; reused in the
# tonic-vs-offind comparison further down.
nnxo_ton = nnxo / nnxr

# Raw paired comparison: displayed only, not saved.
f, ax = acr.plots.gen_paired_boxplot(
    nnxr,
    nnxo,
    colors=[NNXR_GRAY, MAIN_COLOR],
    alphas=[0.9, 0.6],
    fsize=(3.5, 4),
)
plt.show()

# Control-normalized comparison (control is identically 1): this is the saved figure.
f, ax = acr.plots.gen_paired_boxplot(
    nnxr / nnxr,
    nnxo / nnxr,
    one_sided=True,
    colors=[NNXR_GRAY, MAIN_COLOR],
    alphas=[0.9, 0.6],
    fsize=(3.5, 4),
)
ax.set_ylim(0.9, 1.005)
ax.set_xticklabels(['Contra. Control', 'Optrode'])
f.savefig(fig_path, dpi=600, transparent=True, bbox_inches='tight')
plt.show()

In [None]:
# =============================
# ========== STATS ============
# =============================
# Paired comparison of control (nnxr) vs optrode (nnxo) for the tonic figure defined in
# the previous cell; relies on nnxr, nnxo and fig_name from that cell.
write = True

# Shapiro-Wilk on the paired differences (normality assumption of the paired t-test).
# The p-value is only printed; the t-test below is run regardless.
diffs = nnxr - nnxo
shap_stat, shap_p = shapiro(diffs)
print(f'shapiro_p-value: {shap_p}')

# Named ttest_res rather than `stats` so the `from scipy import stats` import is not clobbered.
ttest_res = pg.ttest(nnxr, nnxo, paired=True)
# ttest_res = pg.wilcoxon(nnxr, nnxo)

hg = pg.compute_effsize(nnxr, nnxo, paired=True, eftype='hedges')
print(f'hedges g: {hg}')

#r = acr.stats.calculate_wilx_r(ttest_res['W-val'][0], len(nnxr))


if write:
    # ==== Write Stats Results ====
    stats_name = f'{fig_name}'
    # .iloc[0]: the result row is label-indexed, so integer [0] relied on pandas'
    # deprecated positional fallback.
    acr.stats.write_stats_result(
        stats_name,
        'paired_ttest',
        ttest_res['T'].iloc[0],
        ttest_res['p-val'].iloc[0],
        'g',
        hg,
    )

    # ===== Write Source Data =====
    # The optrode column holds tonic-experiment values here, so it is labelled 'tonic'
    # (it was 'off_induction', carried over from the offind cell).
    source_data = pd.DataFrame({'contra_control': nnxr, 'tonic': nnxo, 'subject': np.arange(len(nnxr))})
    pu.write_source_data(source_data, stats_name)
ttest_res

In [None]:
EXP_TYPE2P = 'TONIC-vs-OFFIND'
fig_id = f'{EXP_TYPE2P}_0.5-4Hz_rebound_all_chans'
fig_name = f'{SUBJECT_TYPE}__{fig_id}'
fig_path = f'{notebook_figure_root}/{fig_name}.png'

# Compare the control-normalized optrode values of the two experiment types
# (nnxo_ton and nnxo_off come from the tonic and offind plotting cells above).
f, ax = acr.plots.gen_paired_boxplot(
    nnxo_ton,
    nnxo_off,
    colors=[MAIN_COLOR, MAIN_COLOR],
    alphas=[0.6, 0.9],
    fsize=(4, 5),
)
ax.set_xticklabels(['Tonic', 'OFF Induction'])
f.savefig(fig_path, dpi=600, transparent=True, bbox_inches='tight')
plt.show()

In [None]:
# =============================
# ========== STATS ============
# =============================
# Paired comparison of the control-normalized optrode values, tonic vs OFF induction:
# the same two arrays that the TONIC-vs-OFFIND figure plots. The test previously ran on
# nnxr/nnxo, which at this point still hold the tonic control/optrode values, so it
# duplicated the tonic statistics under this figure's name.
write = True

# Shapiro-Wilk on the paired differences (normality assumption of the paired t-test).
# The p-value is only printed; the t-test below is run regardless.
diffs = nnxo_ton - nnxo_off
shap_stat, shap_p = shapiro(diffs)
print(f'shapiro_p-value: {shap_p}')

# Named ttest_res rather than `stats` so the `from scipy import stats` import is not clobbered.
ttest_res = pg.ttest(nnxo_ton, nnxo_off, paired=True)
# ttest_res = pg.wilcoxon(nnxo_ton, nnxo_off)

hg = pg.compute_effsize(nnxo_ton, nnxo_off, paired=True, eftype='hedges')
print(f'hedges g: {hg}')

#r = acr.stats.calculate_wilx_r(ttest_res['W-val'][0], len(nnxo_ton))


if write:
    # ==== Write Stats Results ====
    stats_name = f'{fig_name}'
    # .iloc[0]: the result row is label-indexed, so integer [0] relied on pandas'
    # deprecated positional fallback.
    acr.stats.write_stats_result(
        stats_name,
        'paired_ttest',
        ttest_res['T'].iloc[0],
        ttest_res['p-val'].iloc[0],
        'g',
        hg,
    )

    # ===== Write Source Data =====
    # Columns are named after what they contain: optrode/control ratios per experiment type.
    source_data = pd.DataFrame({'tonic': nnxo_ton, 'off_induction': nnxo_off, 'subject': np.arange(len(nnxo_ton))})
    pu.write_source_data(source_data, stats_name)
ttest_res

# Full SPG plots

In [None]:
# Full spectrum for plotting: keep frequency bins strictly between -1 and 20.
fulls = [
    spg_reb.filter(pl.col('frequency').is_between(-1, 20, closed='none'))
    for spg_reb in spgs
]

In [None]:
# Stack all per-subject spectrograms into one table, ordered by time, probe and channel.
spg = pl.concat(fulls).sort(['datetime', 'store', 'channel'])

In [None]:
# Same relativization as for the band power, but grouped per frequency bin as well.
# NOTE: this rebinds relbp (and the next cells rebind rebbp / rm) from the band-power
# section above.
relbp = dag.relativize_df(
    spg,
    'full_bl',
    'True',
    'mean',
    'power',
    ['xtype', 'subject', 'store', 'channel', 'frequency'],
)

In [None]:
# Only the rebound period is plotted.
rebbp = relbp.filter(pl.col('condition') == 'rebound')

In [None]:
# Mean relative power over the rebound, per subject / experiment type / probe / channel / frequency bin.
rm = (
    rebbp
    .group_by(['subject', 'xtype', 'store', 'channel', 'frequency'])
    .agg(pl.col('power_rel').mean())
    .sort(['subject', 'xtype', 'store', 'channel', 'frequency'])
)

In [None]:
# Apply the per-subject bad-channel table (presumably drops the listed channels --
# confirm in pu.drop_sub_channels).
rm = pu.drop_sub_channels(
    rm,
    scd,
)

In [None]:
# Collapse channels: one spectrum (mean power_rel per frequency bin) per subject / experiment type / probe.
rm_probe = (
    rm
    .group_by(['subject', 'xtype', 'store', 'frequency'])
    .agg(pl.col('power_rel').mean())
)

In [None]:
# group_by does not guarantee row order, so sort explicitly before plotting/exporting.
rm_probe = rm_probe.sort(['subject', 'xtype', 'store', 'frequency'])

In [None]:
fig_id = 'TONIC_full_spg_rebound'
fig_name = f'{SUBJECT_TYPE}__{fig_id}'
fig_path = f'{notebook_figure_root}/{fig_name}.png'

# Start from matplotlib defaults, re-apply the project plot setup, then turn x ticks on.
plt.rcdefaults()
acr.plots.supl()
plt.rcParams['xtick.bottom'] = True

# Rebound spectrum for the tonic experiments: one line per probe (control first, then
# optrode), with standard-error bands across the per-subject spectra.
f, ax = plt.subplots(1, 1, figsize=(4.5, 3))
sns.lineplot(
    data=rm_probe.xt('tonic'),
    x='frequency',
    y='power_rel',
    hue='store',
    hue_order=['NNXr', 'NNXo'],
    palette=[NNXR_GRAY, MAIN_COLOR],
    errorbar='se',
    lw=4,
    alpha=0.6,
    ax=ax,
)

ax.set_xlim(0, 20)
ax.set_xticks([0, 5, 10, 15, 20])
ax.set_ylim(0.9, 1.825)
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('')
# The figure is published without a legend.
ax.legend_.remove()
#ax.axhline(y=1.59, xmin=0.09, xmax=0.15)
#ax.axvspan(0.3, 2.4, ymin=0.985, ymax=1.0)
#ax.axvspan(2.6, 3.5, ymin=0.985, ymax=1.0)

plt.savefig(fig_path, dpi=600, bbox_inches='tight', transparent=True)

In [None]:
# =============================
# ===== Write Source Data =====
# =============================
# Build the source-data table for the tonic spectrum figure: for every frequency bin,
# one row per subject with the control (NNXr) and optrode (NNXo) relative power.
# Cell-specific names are used so the notebook-level dfs, f, nnxr and nnxo from the
# earlier cells are not overwritten.
tonic_probe = rm_probe.xt('tonic')
# unique() gives no ordering guarantee by default; sort so the exported bins ascend.
freqs = tonic_probe['frequency'].unique().sort().to_list()

source_frames = []
for freq in freqs:
    # Sorting by subject within each probe keeps the two arrays paired row by row.
    freq_df = tonic_probe.filter(pl.col('frequency') == freq).sort(['store', 'subject'])
    control_power = freq_df.filter(pl.col('store') == 'NNXr')['power_rel'].to_numpy()
    optrode_power = freq_df.filter(pl.col('store') == 'NNXo')['power_rel'].to_numpy()
    source_frames.append(pd.DataFrame({
        'freq_bin': round(freq, 1),
        'contra_control_power': control_power,
        'optrode_power': optrode_power,
        'subject': np.arange(len(control_power)),
    }))
spg_source_dat = pd.concat(source_frames)

In [None]:
# Export the tonic spectrum source data built in the previous cell.
pu.write_source_data(
    spg_source_dat,
    'SOM_full_spg_source_data--TONIC',
)

In [None]:
fig_id = 'OFFIND_full_spg_rebound'
fig_name = f'{SUBJECT_TYPE}__{fig_id}'
fig_path = f'{notebook_figure_root}/{fig_name}.png'

# Start from matplotlib defaults, re-apply the project plot setup, then turn x ticks on.
plt.rcdefaults()
acr.plots.supl()
plt.rcParams['xtick.bottom'] = True

# Rebound spectrum for the OFF-induction experiments: one line per probe (control first,
# then optrode), with standard-error bands across the per-subject spectra.
f, ax = plt.subplots(1, 1, figsize=(4.5, 3))
sns.lineplot(
    data=rm_probe.xt('offind'),
    x='frequency',
    y='power_rel',
    hue='store',
    hue_order=['NNXr', 'NNXo'],
    palette=[NNXR_GRAY, MAIN_COLOR],
    errorbar='se',
    lw=4,
    ax=ax,
)

ax.set_xlim(0, 20)
ax.set_xticks([0, 5, 10, 15, 20])
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('')
# The figure is published without a legend.
ax.legend_.remove()
# Same y-range as the tonic spectrum figure so the two panels are directly comparable.
ax.set_ylim(0.9, 1.825)
#ax.axhline(y=1.59, xmin=0.09, xmax=0.15)
#ax.axvspan(0.3, 2.4, ymin=0.985, ymax=1.0)
#ax.axvspan(2.6, 3.5, ymin=0.985, ymax=1.0)

plt.savefig(fig_path, dpi=600, bbox_inches='tight', transparent=True)

In [None]:
# =============================
# ===== Write Source Data =====
# =============================
# Build the source-data table for the OFF-induction spectrum figure: for every frequency
# bin, one row per subject with the control (NNXr) and optrode (NNXo) relative power.
# Cell-specific names are used so the notebook-level dfs, f, nnxr and nnxo from the
# earlier cells are not overwritten.
offind_probe = rm_probe.xt('offind')
# unique() gives no ordering guarantee by default; sort so the exported bins ascend.
freqs = offind_probe['frequency'].unique().sort().to_list()

source_frames = []
for freq in freqs:
    # Sorting by subject within each probe keeps the two arrays paired row by row.
    freq_df = offind_probe.filter(pl.col('frequency') == freq).sort(['store', 'subject'])
    control_power = freq_df.filter(pl.col('store') == 'NNXr')['power_rel'].to_numpy()
    optrode_power = freq_df.filter(pl.col('store') == 'NNXo')['power_rel'].to_numpy()
    source_frames.append(pd.DataFrame({
        'freq_bin': round(freq, 1),
        'contra_control_power': control_power,
        'optrode_power': optrode_power,
        'subject': np.arange(len(control_power)),
    }))
spg_source_dat = pd.concat(source_frames)

In [None]:
# Export the OFF-induction spectrum source data built in the previous cell.
pu.write_source_data(
    spg_source_dat,
    'SOM_full_spg_source_data--OFFIND',
)