In [None]:
from acr.utils import SOM_BLUE, NNXR_GRAY, CTRL_BLUE

# Analysis configuration for this notebook.
MAIN_EXP = 'swi'  # experiment passed to pu.get_subject_list and used in figure/stat names
SUBJECT_TYPE = 'control'  # subject cohort passed to pu.get_subject_list and used in figure/stat names
MAIN_COLOR = CTRL_BLUE  # accent colour paired with the NNXo probe in plots (NNXr uses NNXR_GRAY)

In [None]:
import numpy as np
import matplotlib as mpl
from matplotlib.colors import ListedColormap

def modified_coolwarm_low(low="#1a9850", *, N=256, name="coolgreen"):
    """
    Return a version of 'coolwarm' whose *lower* half fades into `low`.

    The upper half (centre -> red) is left exactly as in matplotlib's
    'coolwarm'. Rows 0..N//2 (inclusive) are replaced by a linear RGBA blend
    running from `low` at index 0 to the 'coolwarm' colour at index N // 2.

    Parameters
    ----------
    low : str or tuple
        Target low-end colour, in any format accepted by
        ``matplotlib.colors.to_rgba`` (e.g. "#1a9850" for green).
    N : int
        Number of discrete samples pulled from the base map (256 by default).
    name : str
        Name given to the resulting colormap object.

    Returns
    -------
    matplotlib.colors.ListedColormap
    """
    # ``mpl.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9.
    # Use the colormap registry (3.5+) with ``Colormap.resampled`` (3.6+), and
    # fall back to the old API on installs that predate either.
    try:
        base = mpl.colormaps["coolwarm"].resampled(N)
    except AttributeError:
        base = mpl.cm.get_cmap("coolwarm", N)
    colors = base(np.linspace(0, 1, N))             # RGBA array, shape (N, 4)
    mid = N // 2                                    # index of the centre colour
    lo_rgba = np.asarray(mpl.colors.to_rgba(low))   # shape (4,)

    # Replace rows [0, mid] (inclusive) with a linear blend: `low` -> centre colour.
    # The RHS is fully evaluated before assignment, so colors[mid] is read unmodified.
    t = np.linspace(0, 1, mid + 1)[:, None]         # 0 at `low`, 1 at the centre
    colors[:mid + 1] = (1 - t) * lo_rgba + t * colors[mid]

    return ListedColormap(colors, name=name)

In [None]:

#-------------------------- Standard Imports --------------------------#
%reload_ext autoreload
%autoreload 2
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
import acr
import pingouin as pg
from scipy.stats import shapiro, normaltest
import numpy as np
import kdephys as kde

import os
import warnings
# NOTE(review): this silences *all* warnings, including library deprecation warnings.
warnings.filterwarnings('ignore')

from cmcrameri import cm as scm

probe_ord = ['NNXr', 'NNXo']
hue_ord = [NNXR_GRAY, MAIN_COLOR]  # plot colours, in probe_ord order

from acr.plots import lrg
plt.rcdefaults()
lrg()
plt.rcParams['xtick.bottom'] = False
#--------------------------------- Import Publication Functions ---------------------------------#
# Root of the PUBLICATION__ACR repository that provides pub_utils.py and data_agg.py.
# Set the ACR_PUBLICATION_ROOT environment variable to use a checkout elsewhere; the
# default is the path this notebook was originally written against.
PUBLICATION_ROOT = os.environ.get('ACR_PUBLICATION_ROOT', '/home/kdriessen/gh_master/PUBLICATION__ACR')
pu = acr.utils.import_publication_functions(os.path.join(PUBLICATION_ROOT, 'pub_utils.py'), 'pub_utils')
dag = acr.utils.import_publication_functions(os.path.join(PUBLICATION_ROOT, 'data_agg.py'), 'data_agg')
# NOTE(review): star import kept because later cells may rely on names it provides;
# replace with explicit imports once those names are identified.
from acr.utils import *

In [None]:
# Subjects to analyse and their matching experiment names, selected by the config constants.
subjects, exps = pu.get_subject_list(exp=MAIN_EXP, type=SUBJECT_TYPE)

In [None]:
# Output directory for every figure saved by this notebook. Create it up front so the
# savefig calls in later cells do not fail with FileNotFoundError on a fresh checkout.
import os
notebook_figure_root = f'{pu.PAPER_FIGURE_ROOT}/synchrony/sttc'
os.makedirs(notebook_figure_root, exist_ok=True)

In [None]:
# Concatenated MUA peak data for each subject, keyed by subject name.
muas = {
    subject: acr.mua.load_concat_peaks_df(subject, exp)
    for subject, exp in zip(subjects, exps)
}

In [None]:
# Per subject: the full-experiment hypnogram, plus a dict of per-condition hypnograms
# (true_stim=True, duration='3600s') that the STTC cells index by condition name.
full_hyps, hyp_dicts = {}, {}
for subject, exp in zip(subjects, exps):
    full_hyps[subject] = acr.io.load_hypno_full_exp(subject, exp)
    hyp_dicts[subject] = acr.hypnogram_utils.create_acr_hyp_dict(
        subject, exp, true_stim=True, duration='3600s'
    )

In [None]:
# Dual-probe STTC (delta_ms=5) for each subject and condition, using that condition's
# hypnogram from hyp_dicts. Result: bmtx[subject][condition] -> dual_probe_sttc output.
bmtx = {}
for subject, subject_mua in muas.items():
    bmtx[subject] = {}
    for condition in ('early_bl', 'circ_bl', 'stim', 'rebound'):
        print(subject, condition)
        bmtx[subject][condition] = acr.sync.dual_probe_sttc(
            subject_mua, hyp_dicts[subject][condition], delta_ms=5
        )

In [None]:
# Add a 'full_bl' entry per subject: STTC over the full-baseline hypnogram restricted
# to state='NREM' (same delta_ms as the other conditions).
for subject, subject_mua in muas.items():
    bl_nrem_hypno = acr.hypnogram_utils.get_full_bl_hypno(full_hyps[subject], state='NREM')
    bmtx[subject]['full_bl'] = acr.sync.dual_probe_sttc(subject_mua, bl_nrem_hypno, delta_ms=5)

In [None]:
# Optional: mask known bad channels (scd: subject -> channel indices) in every
# condition's STTC data for both probes via acr.sync.mask_bad_channels.
# Disabled by default; set nan_out_chans = True to apply.
nan_out_chans = False
scd = {
    'ACR_37': [3, 15],
    'ACR_35': [3],
    'ACR_29': [0, 1],
}
conds = ['rebound', 'full_bl', 'early_bl', 'circ_bl', 'stim']
if nan_out_chans:
    for subject, bad_chans in scd.items():
        # scd may name subjects that are not in the current subject list; skip them
        # instead of raising KeyError on bmtx[subject].
        if subject not in bmtx:
            continue
        for probe in ['NNXo', 'NNXr']:
            for condition in conds:
                bmtx[subject][condition][probe] = acr.sync.mask_bad_channels(
                    bmtx[subject][condition][probe], bad_chans
                )

In [None]:
# One record per (subject, condition, probe, bout): the nan-mean of that bout's STTC
# matrix. Records are collected as dicts and turned into a single DataFrame once,
# instead of building and concatenating a one-row DataFrame per bout.
bout_avg_records = [
    {
        'condition': cond,
        'bout_ix': i,
        'avg': np.nanmean(bout_mat),
        'probe': probe,
        'subject': subject,
    }
    for subject, subj_conds in bmtx.items()
    for cond, cond_probes in subj_conds.items()
    for probe, bout_mats in cond_probes.items()
    for i, bout_mat in enumerate(bout_mats)
]
mtxdf = pl.from_dataframe(pd.DataFrame(bout_avg_records))
# Express 'avg' relative to the 'full_bl' condition per (subject, probe) using
# dag.relativize_df (presumably producing the 'avg_rel' column used below — confirm),
# then average the rebound bouts per subject/probe.
mtxdf = dag.relativize_df(mtxdf, 'condition', 'full_bl', 'mean', 'avg', ['subject', 'probe'])
mtxmean = (
    mtxdf.cdn('rebound')
    .group_by(['subject', 'probe'])
    .agg(pl.col('avg_rel').mean())
    .sort(['subject', 'probe'])
)

In [None]:
# Average across bouts, keeping channel-pair structure: one matrix per
# subject / condition / probe, from acr.sync.average_sttc_matrices.
cond_avgs = {
    subject: {
        cond: {
            probe: acr.sync.average_sttc_matrices(bout_mats)
            for probe, bout_mats in probe_mats.items()
        }
        for cond, probe_mats in subj_conds.items()
    }
    for subject, subj_conds in bmtx.items()
}

In [None]:
# Rebound relative to full baseline as a SIMPLE RATIO, per channel pair, using the
# bout-averaged matrices from cond_avgs.
# NOTE: the next cell reassigns reb_rel_avgs with the Fisher-z difference, so this
# ratio version is only what downstream cells see if that cell is skipped.
reb_rel_avgs = {
    subject: {
        probe: cond_avgs[subject]['rebound'][probe] / cond_avgs[subject]['full_bl'][probe]
        for probe in ('NNXr', 'NNXo')
    }
    for subject in bmtx
}

In [None]:
# Rebound relative to full baseline as a FISHER-Z DIFFERENCE, per channel pair:
# arctanh(rebound) - arctanh(full_bl). Values are first clipped to [-1 + eps, 1 - eps]
# so arctanh stays finite at +/-1.
eps = 1e-6
reb_rel_avgs = {}
for subject in bmtx:
    reb_rel_avgs[subject] = {}
    for probe in ('NNXr', 'NNXo'):
        z_bl = np.arctanh(np.clip(cond_avgs[subject]['full_bl'][probe], -1 + eps, 1 - eps))
        z_reb = np.arctanh(np.clip(cond_avgs[subject]['rebound'][probe], -1 + eps, 1 - eps))
        reb_rel_avgs[subject][probe] = z_reb - z_bl

In [None]:
#obvious outlier/artifactual pairs
#reb_rel_avgs['ACR_29']['NNXo'][5, 6] = np.nan
#reb_rel_avgs['ACR_29']['NNXr'][5, 6] = np.nan


In [None]:
# Reset rcParams, apply the acr 'lrg' plot preset, and turn off bottom x-axis ticks.
plt.rcdefaults()
acr.plots.lrg()
plt.rcParams.update({'xtick.bottom': False})

In [None]:
# Reset rcParams to matplotlib defaults, then apply the acr 'supl' plot preset.
# NOTE(review): this immediately overrides the lrg() settings from the previous cell, so
# the boxplots below use supl() styling (the heatmap cells call lrg() again).
plt.rcdefaults()
acr.plots.supl()

In [None]:
# Per-subject summary: each subject's reb_rel_avgs matrix is already relative to its own
# full-baseline value per channel pair; average it over all channel pairs to get one value
# per probe, then plot NNXr vs NNXo as a paired boxplot.
fig_id = 'REBOUND_STTC_all_channel_average'
fig_name = f'{SUBJECT_TYPE}__{MAIN_EXP}__{fig_id}'
fig_path = f'{notebook_figure_root}/{fig_name}.png'

subject_ids = list(reb_rel_avgs.keys())
nnxr = np.array([np.nanmean(reb_rel_avgs[s]['NNXr']) for s in subject_ids])
nnxo = np.array([np.nanmean(reb_rel_avgs[s]['NNXo']) for s in subject_ids])

f, ax = acr.plots.gen_paired_boxplot(
    nnxr, nnxo,
    colors=[NNXR_GRAY, MAIN_COLOR],
    fsize=(3.5, 4),
)
print(ax.get_ylim())

ax.set_xticklabels(['Contra Control', 'Optrode'])
f.savefig(fig_path, dpi=600, bbox_inches='tight', transparent=True)

In [None]:
# =============================
# ========== STATS ============
# =============================
# Paired comparison of the per-subject channel-pair averages (nnxr vs nnxo) from the
# previous cell. Set write = False to compute the stats without writing any files.
write = True

diffs = nnxr - nnxo
shap_stat, shap_p = shapiro(diffs)  # normality check on the paired differences
print(f'shapiro_p-value: {shap_p}')

stats = pg.ttest(nnxr, nnxo, paired=True)
# stats = pg.wilcoxon(nnxr, nnxo)

hg = pg.compute_effsize(nnxr, nnxo, paired=True, eftype='hedges')
print(f'hedges g: {hg}')

# r = acr.stats.calculate_wilx_r(stats['W-val'].iloc[0], len(nnxr))


if write:
    # ==== Write Stats Results ====
    stats_name = f'{fig_name}'
    # pingouin returns a one-row DataFrame with a string index, so take the first row
    # positionally: integer-key lookup ([0]) on a string index is deprecated in pandas 2.x
    # and removed in pandas 3.
    acr.stats.write_stats_result(stats_name, 'paired_ttest', stats['T'].iloc[0], stats['p-val'].iloc[0], 'g', hg)

    # ===== Write Source Data =====
    source_data = pd.DataFrame({'contra_control': nnxr, 'off_induction': nnxo, 'subject': np.arange(len(nnxr))})
    pu.write_source_data(source_data, stats_name)
stats

In [None]:
# The same per-subject averages, re-expressed relative to the contralateral (NNXr) probe:
# each NNXr value becomes nnxr - nnxr (0, or NaN if missing) and each NNXo value becomes
# its difference from that subject's NNXr value.
nnxr_norm = nnxr - nnxr
nnxo_norm = nnxo - nnxr

fig_id = 'REBOUND_STTC_all_channel_average-NORMALIZED'
fig_name = f'{SUBJECT_TYPE}__{MAIN_EXP}__{fig_id}'
fig_path = f'{notebook_figure_root}/{fig_name}.png'

f, ax = acr.plots.gen_paired_boxplot(
    nnxr_norm, nnxo_norm,
    colors=[NNXR_GRAY, MAIN_COLOR],
    fsize=(3.5, 4),
    one_sided=True,
)
print(ax.get_ylim())

f.savefig(fig_path, dpi=600, bbox_inches='tight', transparent=True)

# Matrix - Main Plots 

In [None]:
from matplotlib.colors import TwoSlopeNorm

In [None]:
# NOTE(review): on a fresh Restart & Run All, `diffs` here is the 1-D per-subject
# paired difference (nnxr - nnxo) from the STATS cell; it is only the channel-pair
# difference matrix if one of the heatmap cells below has already been run.
# Confirm which array this range check is meant to inspect.
np.nanmin(diffs)

In [None]:
# NOTE(review): on a fresh Restart & Run All, `diffs` here is the 1-D per-subject
# paired difference (nnxr - nnxo) from the STATS cell; it is only the channel-pair
# difference matrix if one of the heatmap cells below has already been run.
# Confirm which array this range check is meant to inspect.
np.nanmax(diffs)

In [None]:
plt.rcdefaults()
acr.plots.lrg()
plt.rcParams['xtick.bottom'] = False
plt.rcParams['ytick.left'] = False

# Group-averaged channel-pair map of NNXo minus NNXr. Each subject contributes its
# rebound-relative-to-full-baseline matrix from reb_rel_avgs; the matrices are averaged
# across subjects per probe with acr.sync.average_sttc_matrices.
cmap = modified_coolwarm_low(low=MAIN_COLOR)  # MAIN_COLOR is CTRL_BLUE for the control cohort
nnxo_mats = [reb_rel_avgs[subject]['NNXo'] for subject in reb_rel_avgs]
nnxr_mats = [reb_rel_avgs[subject]['NNXr'] for subject in reb_rel_avgs]
all_nnxo = acr.sync.average_sttc_matrices(nnxo_mats)
all_nnxr = acr.sync.average_sttc_matrices(nnxr_mats)
diffs = all_nnxo - all_nnxr

# Symmetric colour range centred on 0. Colorbar ticks outside [vmin, vmax] are not drawn,
# so the ticks are derived from the same limit to keep them visible.
vlim = 0.11
norm = TwoSlopeNorm(vmin=-vlim, vcenter=0, vmax=vlim)
f, ax = plt.subplots(1, 1, figsize=(5, 5))
sns.heatmap(diffs, cmap=cmap, ax=ax, norm=norm, cbar_kws={'ticks': [-vlim, 0, vlim]})

fig_id = 'REBOUND_full_channel_map_all_subjects'
fig_name = f'{SUBJECT_TYPE}__{MAIN_EXP}__{fig_id}'
fig_path = f'{notebook_figure_root}/{fig_name}.png'

f.savefig(fig_path, transparent=True, dpi=600, bbox_inches='tight')

In [None]:
plt.rcdefaults()
acr.plots.lrg()
plt.rcParams['xtick.bottom'] = False
plt.rcParams['ytick.left'] = False

# Same group-averaged NNXo - NNXr channel-pair map, drawn without a colorbar.
# The low end uses the cohort accent colour (MAIN_COLOR) rather than a fixed cohort
# colour, so it matches this notebook's SUBJECT_TYPE.
cmap = modified_coolwarm_low(low=MAIN_COLOR)
nnxo_mats = [reb_rel_avgs[subject]['NNXo'] for subject in reb_rel_avgs]
nnxr_mats = [reb_rel_avgs[subject]['NNXr'] for subject in reb_rel_avgs]
all_nnxo = acr.sync.average_sttc_matrices(nnxo_mats)
all_nnxr = acr.sync.average_sttc_matrices(nnxr_mats)
diffs = all_nnxo - all_nnxr

# NOTE(review): this range (+/-0.08) differs from the colorbar version of this map
# (+/-0.11); confirm whether the two panels are meant to share a colour scale.
vlim = 0.08
norm = TwoSlopeNorm(vmin=-vlim, vcenter=0, vmax=vlim)
f, ax = plt.subplots(1, 1, figsize=(5, 5))
sns.heatmap(diffs, cmap=cmap, ax=ax, norm=norm, cbar=False)

fig_id = 'REBOUND_full_channel_map_all_subjects_NOSCALE'
fig_name = f'{SUBJECT_TYPE}__{MAIN_EXP}__{fig_id}'
fig_path = f'{notebook_figure_root}/{fig_name}.png'

f.savefig(fig_path, transparent=True, dpi=600, bbox_inches='tight')