In [None]:
import os

# Disable HDF5 file locking before xarray (and its netCDF/HDF5 backends) is imported;
# presumably needed because the data lives on a mounted volume — TODO confirm.
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

import xarray as xr

# Example single-probe dataset. NOTE(review): `da` is not used by any later cell.
path = "/Volumes/opto_loc/Data/ACR_39/swi-bl-NNXo.nc"
da = xr.open_dataarray(path)

In [None]:
import pubplots as pp

from acr.utils import NNXR_GRAY, NREM_RED, PAPER_FIGURE_ROOT, SOM_BLUE, HALO_GREEN

# Matplotlib style sheet applied to the Figma-ready figures via `pp.destination`.
style_path = "/Users/driessen2@ad.wisc.edu/kdriessen/acr_dev/acr/src/acr/plot_styles/acrvec_labels.mplstyle"


# -------------------- ADJUST HERE --------------------
# Output directory for this notebook's figures; created (with parents) if missing.
import os
from pathlib import Path
nbroot = os.path.join(PAPER_FIGURE_ROOT, "response_to_review", "tfr")
Path(nbroot).mkdir(exist_ok=True, parents=True)

In [None]:
from pathlib import Path

import pingouin as pg
from scipy.stats import shapiro

%reload_ext autoreload
%autoreload 2

import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
import kdephys as kde
import acr

# Suppress every warning category for the rest of the session.
# NOTE(review): this also hides genuinely useful warnings (e.g. from matplotlib).
warnings.simplefilter("ignore")

In [None]:
from matplotlib.colors import LogNorm
from matplotlib.colors import TwoSlopeNorm

In [None]:
# Load the publication helper modules from the PUBLICATION__ACR checkout.
# `import_publication_functions` runs before the plain `import` statements,
# presumably to make the modules importable by name — TODO confirm.
_PUB_ACR_DIR = "/Users/driessen2@ad.wisc.edu/kdriessen/gh_master/PUBLICATION__ACR"

pub_utils = acr.utils.import_publication_functions(f"{_PUB_ACR_DIR}/pub_utils.py", "pub_utils")
import pub_utils as pu

data_agg = acr.utils.import_publication_functions(f"{_PUB_ACR_DIR}/data_agg.py", "data_agg")
import data_agg as dag

In [None]:
from pub_utils import get_event_data_stacks

In [None]:
from wavelet_tfr import *
from wavetf_utils import *
import wavetf_utils as wu

In [None]:
import zarr

In [None]:
# Sampling rate passed to every `wu.get_tfr_stacks` call below; presumably the LFP
# sampling rate in Hz — TODO confirm against the recording metadata.
fs = 400.23053278688525

In [None]:
from acr.utils import SOM_BLUE, ACR_BLUE, NNXR_GRAY

# Experiment, cohort, and accent colour used throughout this notebook.
MAIN_EXP, SUBJECT_TYPE, MAIN_COLOR = 'swi', 'som', SOM_BLUE

In [None]:
# Subjects (and their matching experiment names) for the selected cohort.
subjects, exps = pu.get_subject_list(exp=MAIN_EXP, type=SUBJECT_TYPE)

# Figure output directory (same location as configured near the top).
nbroot = os.path.join(PAPER_FIGURE_ROOT, "response_to_review", "tfr")
Path(nbroot).mkdir(exist_ok=True, parents=True)

# Single Subject Examples - ACR_41

In [None]:
# Stim timing (ss, se, pon, poff, ton, toff) is loaded per subject inside the
# stim-stack loop below. The call that used to sit here referenced `subject` and
# `exp` before any cell defined them (NameError on a fresh kernel). That loop also
# overwrote its results before they were ever used.

In [None]:
# Per-subject hypnograms: `hf` from acr.io.load_hypno_full_exp and `hd` from
# acr.hypnogram_utils.create_acr_hyp_dict (both consumed by enhance_oodf below).
hf, hd = {}, {}
for subject, exp in zip(subjects, exps):
    hf[subject], hd[subject] = (
        acr.io.load_hypno_full_exp(subject, exp),
        acr.hypnogram_utils.create_acr_hyp_dict(subject, exp),
    )

In [None]:
# Wavelet data nested as wave[subject][probe][cond].
# NOTE(review): only 'stim', 'ebl' and 'cbl' are loaded here. Later cells also
# index 'reb', 'lsd' and 'esd'.
wave = {}
for subject, exp in zip(subjects, exps):
    per_probe = wave[subject] = {}
    for probe in ('NNXo', 'NNXr'):
        per_cond = per_probe[probe] = {}
        for cond in ('stim', 'ebl', 'cbl'):
            per_cond[cond] = wu.load_wavelets(subject, probe, cond)

In [None]:
# Circadian-baseline ('cbl_avgs') averages, used to normalise the single-subject
# TFR stacks below via `wu.rel_tfr`.
#
# This loop used to read the `subject` variable left behind by an earlier loop,
# which always pointed at the *last* entry of `subjects`. Every cell that consumes
# `cbl_means[probe]` analyses ACR_29. Load every subject explicitly (mirroring
# `lsd_means`) and bind `cbl_means` to the subject those cells use.
CBL_EXAMPLE_SUBJECT = 'ACR_29'
cbl_means_by_subject = {
    subj: {prb: wu.load_avgs('cbl_avgs', subj, prb) for prb in ['NNXo', 'NNXr']}
    for subj in subjects
}
cbl_means = cbl_means_by_subject[CBL_EXAMPLE_SUBJECT]

In [None]:
# Sleep-deprivation ('sd_avgs') averages per subject and probe; used to normalise
# the stim-aligned TFR stacks.
lsd_means = {}
for subject, exp in zip(subjects, exps):
    per_probe = lsd_means[subject] = {}
    for probe in ('NNXo', 'NNXr'):
        per_probe[probe] = wu.load_avgs('sd_avgs', subject, probe)

In [None]:
# Stim-train indices and their pulse durations in ms, paired by position.
tixs = list(range(3))
tix_durs = [180, 140, 100]

In [None]:
# Stim-pulse-aligned TFR stacks for every subject (NNXo only), one per stim train.
# Each stack is normalised to that subject's 'sd_avgs' means.
dur_before = 0.25
dur_after = 0.25

stacks = {}
for subject, exp in zip(subjects, exps):
    stacks[subject] = {}
    ss, se, pon, poff, ton, toff = acr.stim.get_all_stim_info(subject, exp, trn_idx=True)
    for probe in ['NNXo']:
        stacks[subject][probe] = {}
        for cond in ['stim']:
            print(cond, probe)
            cond_wave = wave[subject][probe][cond]
            array = cond_wave.lfp.values.swapaxes(0, 2)
            time_array = cond_wave.datetime.values
            for tix, durg in zip(tixs, tix_durs):
                ev_duration = durg / 1000
                # Pulse onsets (`pon`) sliced by the train-index bounds of train `tix`.
                off_starts = pon[ton[tix]:toff[tix]]
                print(len(off_starts))
                stack = wu.get_tfr_stacks(array, time_array, off_starts, ev_duration, dur_before, dur_after, fs)
                stacks[subject][probe][tix] = wu.rel_tfr(stack, lsd_means[subject][probe])

In [None]:
# Shape of the first stim-train stack for the last subject processed above.
np.shape(stacks[subjects[-1]]['NNXo'][0])

In [None]:
# Per stim train: average each subject's stack over axis 0 (events), then stack the
# subject means so that axis 0 indexes subjects in `subjects` order.
all_sub_avgs = {
    tix: np.stack([np.mean(stacks[subj]['NNXo'][tix], axis=0) for subj in subjects])
    for tix in tixs
}
    

In [None]:
# Grand mean across subjects for each stim train; the two vertical lines bracket
# the pulse.
f, ax = plt.subplots(1, 3, figsize=(24, 8))
for tix, durg in zip(tixs, tix_durs):
    grand_mean = all_sub_avgs[tix].mean(axis=0)
    pulse_dur = durg / 1000
    ax[tix], im = wu.plot_tfr(
        grand_mean, ax[tix],
        vmin=0.5, vcenter=1, vmax=6,
        vline1=dur_before, vline2=dur_before + pulse_dur,
    )
plt.show()

In [None]:
# One row of stim-train TFRs per subject.
for subject, exp in zip(subjects, exps):
    f, ax = plt.subplots(1, 3, figsize=(24, 8))
    sub_stacks = stacks[subject]['NNXo']
    for i in tixs:
        ax[i], im = wu.plot_tfr(sub_stacks[i], ax[i], vline1=dur_before, vmax=6)
    f.suptitle(subject)
    plt.show()

# OFF Periods

In [None]:
# Per-subject ON/OFF period tables, with OFF durations binned into 10 ms groups.
#
# The binning constants are loop-invariant, so they are defined once. A row whose
# duration lies in [start, stop + bin_w) gets `dgroup` = its 0-based bin index;
# any other row gets -1. `dg` is the bin centre in ms (60, 70, ..., 360). Out-of-range
# rows get a null `dg`. The previous formula mapped dgroup == -1 to dg == 50, which
# looks like a genuine 50 ms bin but actually lumps every too-short and too-long
# OFF together.
start = 0.055
stop = 0.355
bin_w = 0.010
BIN_MS = round(bin_w * 1000)                             # 10
FIRST_BIN_CENTER_MS = round((start + bin_w / 2) * 1000)  # 60

oodfs = {}
for subject, exp in zip(subjects, exps):
    oodf = dag.compute_hybrid_off_df(subject, exp, chan_threshold=12)
    oodf = acr.oo_utils.enhance_oodf(oodf, hf[subject], hd[subject])

    oodf = oodf.with_columns(
        pl.when((pl.col("duration") >= start) & (pl.col("duration") < stop + bin_w))
        .then(((pl.col("duration") - start) / bin_w).floor().cast(pl.Int32))
        .otherwise(pl.lit(-1, dtype=pl.Int32))
        .alias("dgroup")
    )
    # when/then with no otherwise -> null for out-of-range rows.
    oodf = oodf.with_columns(
        pl.when(pl.col("dgroup") >= 0)
        .then(pl.col("dgroup") * BIN_MS + FIRST_BIN_CENTER_MS)
        .alias("dg")
    )
    # Gap in seconds from this row's end to the next row's start.
    # NOTE(review): this is computed over the whole frame. If rows from different
    # probes/conditions are interleaved, the gap at each boundary spans unrelated
    # rows — confirm the row ordering.
    oodf = oodf.with_columns(
        ((pl.col("start_datetime").shift(-1) - pl.col("end_datetime"))
         .dt.total_milliseconds() / 1000)
        .alias("off_int")
    )
    oodfs[subject] = oodf

In [None]:
# Early-baseline OFF periods from a single duration group for one example subject,
# normalised to the circadian-baseline means.
dur_before = 0.3
dur_after = 0.500

durg = 120                 # duration group (ms) to extract
ev_duration = durg / 1000  # the same duration in seconds
subject = 'ACR_29'
oodf = oodfs[subject]

stacks = {}
for probe in ['NNXo']:
    stacks[probe] = {}
    for cond in ['ebl']:
        print(cond, probe)
        cond_wave = wave[subject][probe][cond]
        array = cond_wave.lfp.values.swapaxes(0, 2)
        time_array = cond_wave.datetime.values
        is_target_off = (pl.col("status") == "off") & (pl.col("dg") == durg)
        off_starts = oodf.prb(probe).cdn('early_bl').filter(is_target_off)['start_datetime'].to_numpy()
        print(len(off_starts))
        stack = wu.get_tfr_stacks(array, time_array, off_starts, ev_duration, dur_before, dur_after, fs)
        stacks[probe][cond] = wu.rel_tfr(stack, cbl_means[probe])

In [None]:
# Event-averaged early-baseline TFR; the vertical line sits `dur_before` into the window.
f, ax = plt.subplots(figsize=(8, 8))
ax, im = wu.plot_tfr(stacks['NNXo']['ebl'], ax, vline1=dur_before)

In [None]:
# Circadian-baseline wavelet data for the example subject/probe.
subject, cond, probe = 'ACR_29', 'cbl', 'NNXo'
cond_wave = wave[subject][probe][cond]
array = cond_wave.lfp.values.swapaxes(0, 2)
time_array = cond_wave.datetime.values

In [None]:
# Circadian-baseline OFF periods, one stack per duration group, for the
# `subject`/`probe` selected in the previous cell. Each stack is normalised to the
# circadian-baseline means.
# Reads `oodfs[subject]` explicitly instead of the global `oodf` left behind by an
# earlier cell, so the OFF table always matches the wavelet data being sliced.
dur_before = 0.2
dur_after = 0.2
stacks = {}
for durg in [80, 100, 140, 180]:
    ev_duration = durg / 1000
    off_starts = oodfs[subject].prb(probe).cdn('circ_bl').filter(
        (pl.col("status") == "off") &
        (pl.col("dg") == durg)
    )['start_datetime'].to_numpy()
    print(len(off_starts))
    stack = wu.get_tfr_stacks(array, time_array, off_starts, ev_duration, dur_before, dur_after, fs)
    stacks[durg] = wu.rel_tfr(stack, cbl_means[probe])

In [None]:
# Figma-ready circadian-baseline spectrograms, one figure per OFF-duration group.
# Set SHOW_EVENT_LINES to draw the OFF onset/offset markers.
# Set SAVE_FIGS to write the PNGs into `nbroot`.
SHOW_EVENT_LINES = False
SAVE_FIGS = False

for durg in [80, 100, 140, 180]:
    with pp.destination('figma', style=style_path):
        off_dur = durg / 1000
        event_lines = dict(vline1=dur_before, vline2=dur_before + off_dur) if SHOW_EVENT_LINES else {}
        f, ax = plt.subplots(1, 1, figsize=pp.scale(1.0, 1.8))
        ax, im = wu.plot_tfr(stacks[durg], ax, vmin=0.2, vmax=6, **event_lines)
        # Hide the tick labels but keep the ticks. `tick_params` works with any tick
        # locator, whereas set_*ticklabels([None]) warns when the label count does
        # not match the tick count, and raises on a FixedLocator.
        ax.tick_params(axis='both', labelbottom=False, labelleft=False)
        # These stacks come from the 'circ_bl' OFFs. The old name said 'early_bl'
        # and would have mislabeled any saved figure.
        fname = f'{durg}ms--circ_bl--SPECTRO.png'
        if SAVE_FIGS:
            f.savefig(os.path.join(nbroot, fname), transparent=True, bbox_inches='tight', dpi=300)
        plt.show()

In [None]:
# Stim-condition wavelet data for the example subject/probe.
subject, cond, probe = 'ACR_29', 'stim', 'NNXo'
stim_wave = wave[subject][probe][cond]
array = stim_wave.lfp.values.swapaxes(0, 2)
time_array = stim_wave.datetime.values

In [None]:
# Stim timing for the example subject. Look up *this* subject's experiment: `exp`
# used to be whatever the last `for subject, exp in ...` loop left behind (the last
# subject's experiment), which need not belong to `subject`.
exp = dict(zip(subjects, exps))[subject]
ss, se, pon, poff, ton, toff = acr.stim.get_all_stim_info(subject, exp, trn_idx=True)

In [None]:
# Stim-pulse-aligned TFR stacks for each stim train of the example subject, keyed by
# pulse duration in ms. Train index i pairs with the i-th duration in the list.
dur_before = 0.2
dur_after = 0.2
stacks = {}
for i, durg in enumerate([180, 140, 100, 80]):
    ev_duration = durg / 1000
    off_starts = pon[ton[i]:toff[i]]
    n_events = len(off_starts)
    print(n_events)
    stack = wu.get_tfr_stacks(array, time_array, off_starts, ev_duration, dur_before, dur_after, fs)
    stacks[durg] = wu.rel_tfr(stack, lsd_means[subject][probe])

In [None]:
# NOTE(review): ad-hoc check of the y-limits of whichever `ax` an earlier plotting
# cell left behind; the result depends on cell execution order.
ax.get_ylim()

In [None]:
# Figma-ready stim-aligned spectrograms, one figure per pulse duration.
# Set SHOW_EVENT_LINES to draw the pulse onset/offset markers.
# Set SAVE_FIGS to write the PNGs into `nbroot`.
SHOW_EVENT_LINES = False
SAVE_FIGS = False

for durg in [180, 140, 100, 80]:
    with pp.destination('figma', style=style_path):
        off_dur = durg / 1000
        event_lines = dict(vline1=dur_before, vline2=dur_before + off_dur) if SHOW_EVENT_LINES else {}
        f, ax = plt.subplots(1, 1, figsize=pp.scale(1.0, 1.8))
        ax, im = wu.plot_tfr(stacks[durg], ax, vmin=0.5, vmax=6, **event_lines)
        # Hide only the x tick labels (the y labels stay for this set). `tick_params`
        # works with any tick locator, unlike set_xticklabels([None]).
        ax.tick_params(axis='x', labelbottom=False)
        fname = f'{durg}ms--STIM--SPECTRO.png'
        if SAVE_FIGS:
            f.savefig(os.path.join(nbroot, fname), transparent=True, bbox_inches='tight', dpi=300)
        plt.show()

In [None]:
# Stand-alone horizontal colour-bar strip (cmap `scm.vik`), exported as SVG for Figma.
# NOTE(review): no visible cell imports `scm`; presumably it comes from one of
# the `from ... import *` lines — confirm.
with pp.destination('figma', style=style_path):
    fig, ax = plt.subplots(1, 1, figsize=pp.scale(1.495, 0.3))
    gradient = np.linspace(0, 1, 256).reshape(1, -1)  # one row of 256 colour samples
    ax.imshow(gradient, aspect='auto', cmap=scm.vik)
    ax.set_axis_off()
    ax.set_position([0, 0, 1, 1])  # fill the entire figure
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
    ax.margins(0, 0)
    # Save *before* showing. With non-inline backends, plt.show() can close or
    # clear the figure, and savefig would then write an empty canvas.
    fig_path = os.path.join(nbroot, 'vik_colorbar.svg')
    fig.savefig(fig_path, transparent=True, bbox_inches='tight')
    plt.show()

# MISC

In [None]:
# Quick-look row of the per-duration stacks currently held in `stacks`.
durations = [80, 100, 140, 180]
f, ax = plt.subplots(1, len(durations), figsize=(24, 6))
for i, durg in enumerate(durations):
    ax[i], im = wu.plot_tfr(stacks[durg], ax[i], vline1=dur_before, vmax=4)
plt.show()

In [None]:
# NOTE(review): this cell reproduces the previous one exactly; one of the two can likely go.
f, ax = plt.subplots(1, 4, figsize=(24, 6))
for i, durg in zip(range(4), (80, 100, 140, 180)):
    ax[i], im = wu.plot_tfr(stacks[durg], ax[i], vmax=4, vline1=dur_before)
plt.show()

In [None]:
# Rebound-condition wavelet LFP values and timestamps per probe for `subject`.
# The wavelet-loading cell only loads 'stim', 'ebl' and 'cbl', so indexing 'reb'
# directly raised KeyError on a fresh kernel. Load it on first use and cache it
# in `wave`.
# NOTE(review): assumes 'reb' is a condition name that `wu.load_wavelets` accepts,
# i.e. the same key this cell always expected to find in `wave` — confirm.
reb_vals = {}
reb_dt = {}
for probe in ['NNXo', 'NNXr']:
    probe_waves = wave[subject][probe]
    if 'reb' not in probe_waves:
        probe_waves['reb'] = wu.load_wavelets(subject, probe, 'reb')
    reb_vals[probe] = probe_waves['reb'].lfp.values
    reb_dt[probe] = probe_waves['reb'].datetime.values

In [None]:
# Early-baseline wavelet LFP values and timestamps per probe for `subject`.
ebl_vals, ebl_dt = {}, {}
for probe in ['NNXo', 'NNXr']:
    ebl_wave = wave[subject][probe]['ebl']
    ebl_vals[probe], ebl_dt[probe] = ebl_wave.lfp.values, ebl_wave.datetime.values

In [None]:
# Circadian-baseline wavelet LFP values and timestamps per probe for `subject`.
cbl_vals = {}
cbl_dt = {}
for probe in ['NNXo', 'NNXr']:
    cbl_wave = wave[subject][probe]['cbl']
    cbl_vals[probe] = cbl_wave.lfp.values
    cbl_dt[probe] = cbl_wave.datetime.values

In [None]:
# Late ('lsd') and early ('esd') sleep-deprivation wavelet values and timestamps per
# probe for `subject`. Neither condition is loaded by the wavelet-loading cell
# (only 'stim', 'ebl' and 'cbl'), so indexing them directly raised KeyError on a
# fresh kernel. Load each on first use and cache it in `wave`.
# NOTE(review): assumes 'lsd'/'esd' are condition names that `wu.load_wavelets`
# accepts, i.e. the same keys this cell always expected in `wave` — confirm.
lsd_vals = {}
lsd_dt = {}
esd_vals = {}
esd_dt = {}
for probe in ['NNXo', 'NNXr']:
    probe_waves = wave[subject][probe]
    for sd_cond in ('lsd', 'esd'):
        if sd_cond not in probe_waves:
            probe_waves[sd_cond] = wu.load_wavelets(subject, probe, sd_cond)
    lsd_vals[probe] = probe_waves['lsd'].lfp.values
    lsd_dt[probe] = probe_waves['lsd'].datetime.values
    esd_vals[probe] = probe_waves['esd'].lfp.values
    esd_dt[probe] = probe_waves['esd'].datetime.values

In [None]:
# Condition name (as passed to `.cdn(...)` on the OFF table) -> {probe: values}
# and -> {probe: timestamps}. Key order sets the iteration order used below.
vals = {
    'circ_bl': cbl_vals,
    'rebound': reb_vals,
    'early_bl': ebl_vals,
    'lsd': lsd_vals,
    'esd': esd_vals,
}
dts = {
    'circ_bl': cbl_dt,
    'rebound': reb_dt,
    'early_bl': ebl_dt,
    'lsd': lsd_dt,
    'esd': esd_dt,
}

In [None]:
# TFRs aligned to OFF *ends*, for OFFs immediately followed by an ON period of at
# least 400 ms. Computed for every condition in `vals` and for both probes, and
# normalised to the circadian-baseline means.
# Reads `oodfs[subject]` explicitly rather than a global `oodf` left over from an
# earlier cell.
dur_before = 0.2
dur_after = 0.350

ev_duration = 0

stacks = {}
for probe in ['NNXo', 'NNXr']:
    stacks[probe] = {}
    for cond in vals.keys():
        print(cond, probe)
        array = vals[cond][probe]
        array = array.swapaxes(0, 2)
        time_array = dts[cond][probe]
        # shift(-1) looks at the next row within this probe/condition subset.
        off_ends = oodfs[subject].prb(probe).cdn(cond).filter(
            (pl.col("status") == "off") &
            (pl.col("status").shift(-1) == "on") &
            (pl.col("duration").shift(-1) >= 0.400)
        )['end_datetime'].to_numpy()
        stack = wu.get_tfr_stacks(array, time_array, off_ends, ev_duration, dur_before, dur_after, fs)
        stacks[probe][cond] = wu.rel_tfr(stack, cbl_means[probe])

In [None]:
# For the sleep-deprivation conditions, align to evenly spaced pseudo-events every
# 520 ms, dropping two at each end. The results overwrite the 'lsd'/'esd' entries
# of `stacks` built by the previous cell.
dur_before = 0
dur_after = 0.50
ev_duration = 0
for probe in ['NNXo', 'NNXr']:
    for cond in ['lsd', 'esd']:
        cond_times = dts[cond][probe]
        fake_evs = np.arange(cond_times[0], cond_times[-1], step=np.timedelta64(520, 'ms'))[2:-2]
        array = vals[cond][probe].swapaxes(0, 2)
        time_array = cond_times
        stack = wu.get_tfr_stacks(array, time_array, fake_evs, ev_duration, dur_before, dur_after, fs)
        stacks[probe][cond] = wu.rel_tfr(stack, cbl_means[probe])

In [None]:
# TFRs aligned to OFF *starts*, for OFFs immediately preceded by an ON period of at
# least 400 ms. Computed for every condition and both probes.
# The event times were previously stored in `off_ends` even though they are
# 'start_datetime' values; renamed to `off_starts`. Reads `oodfs[subject]`
# explicitly rather than a leftover global `oodf`.
dur_before = 0.35
dur_after = 0.15

ev_duration = 0

stacks = {}
for probe in ['NNXo', 'NNXr']:
    stacks[probe] = {}
    for cond in vals.keys():
        print(cond, probe)
        array = vals[cond][probe]
        array = array.swapaxes(0, 2)
        time_array = dts[cond][probe]
        # shift(1) looks at the previous row within this probe/condition subset.
        off_starts = oodfs[subject].prb(probe).cdn(cond).filter(
            (pl.col("status") == "off") &
            (pl.col("status").shift(1) == "on") &
            (pl.col("duration").shift(1) >= 0.400)
        )['start_datetime'].to_numpy()
        stack = wu.get_tfr_stacks(array, time_array, off_starts, ev_duration, dur_before, dur_after, fs)
        stacks[probe][cond] = wu.rel_tfr(stack, cbl_means[probe])

In [None]:
# TFRs for OFFs in the 200 ms duration group, for every condition and both probes.
# Reads `oodfs[subject]` explicitly rather than a leftover global `oodf`.
# NOTE(review): the other duration-group cells align on 'start_datetime' and pass
# the OFF length as `ev_duration`. This cell aligns on 'end_datetime' but still
# passes the OFF length — confirm this window is intended.
dur_before = 0.200
dur_after = 0.200

durg = 200
ev_duration = durg / 1000

stacks = {}
for probe in ['NNXo', 'NNXr']:
    stacks[probe] = {}
    for cond in vals.keys():
        print(cond, probe)
        array = vals[cond][probe]
        array = array.swapaxes(0, 2)
        time_array = dts[cond][probe]
        off_ends = oodfs[subject].prb(probe).cdn(cond).filter(
            (pl.col("status") == "off") &
            (pl.col("dg") == durg)
        )['end_datetime'].to_numpy()
        print(len(off_ends))
        stack = wu.get_tfr_stacks(array, time_array, off_ends, ev_duration, dur_before, dur_after, fs)
        stacks[probe][cond] = wu.rel_tfr(stack, cbl_means[probe])

In [None]:
# NNXo: early baseline, circadian baseline and rebound, side by side.
f, ax = plt.subplots(1, 3, figsize=(24, 8))
for k, cond_name in enumerate(['early_bl', 'circ_bl', 'rebound']):
    ax[k], im = wu.plot_tfr(stacks['NNXo'][cond_name], ax[k], vline1=dur_before)

In [None]:
# NNXr: the same three conditions as the NNXo panel.
f, ax = plt.subplots(1, 3, figsize=(24, 8))
nnxr_stacks = stacks['NNXr']
ax[0], im = wu.plot_tfr(nnxr_stacks['early_bl'], ax[0], vline1=dur_before)
ax[1], im = wu.plot_tfr(nnxr_stacks['circ_bl'], ax[1], vline1=dur_before)
ax[2], im = wu.plot_tfr(nnxr_stacks['rebound'], ax[2], vline1=dur_before)

In [None]:
# NNXo − NNXr difference of the event-averaged (axis-0 mean) TFRs, per condition.
_o, _r = stacks['NNXo'], stacks['NNXr']
difcb = _o['circ_bl'].mean(axis=0) - _r['circ_bl'].mean(axis=0)
difebl = _o['early_bl'].mean(axis=0) - _r['early_bl'].mean(axis=0)
difreb = _o['rebound'].mean(axis=0) - _r['rebound'].mean(axis=0)

In [None]:
# NNXo − NNXr differences on a symmetric scale centred at 0.
diff_scale = dict(vline1=dur_before, vcenter=0, vmin=-0.5, vmax=0.5)
f, ax = plt.subplots(1, 3, figsize=(24, 8))
for k, diff_map in enumerate([difebl, difcb, difreb]):
    ax[k], im = wu.plot_tfr(diff_map, ax[k], **diff_scale)

In [None]:
# NOTE(review): this cell reproduces the previous one exactly.
f, ax = plt.subplots(1, 3, figsize=(24, 8))
diff_maps = (difebl, difcb, difreb)
for k in range(3):
    ax[k], im = wu.plot_tfr(diff_maps[k], ax[k], vline1=dur_before, vcenter=0, vmin=-0.5, vmax=0.5)

In [None]:
# NNXr early-baseline and rebound TFRs minus the NNXr circadian-baseline TFR.
difeb, difreb = (
    stacks['NNXr'][c].mean(axis=0) - stacks['NNXr']['circ_bl'].mean(axis=0)
    for c in ('early_bl', 'rebound')
)

In [None]:
# NNXr early-baseline and rebound differences from circadian baseline, drawn on one
# shared colour scale so the two panels are comparable. The right panel used to
# have vmin=-0.01 instead of -0.1. That looks like a typo: with a centred
# (vcenter=0) norm it stretches the negative half of that panel's colour map, and
# the matching NNXo cell uses -0.1 for both panels.
shared_scale = dict(vcenter=0, vmin=-0.1, vmax=1)
f, ax = plt.subplots(1, 2, figsize=(16, 8))
ax[0], im = wu.plot_tfr(difeb, ax[0], vline1=dur_before, **shared_scale)
ax[1], im = wu.plot_tfr(difreb, ax[1], vline1=dur_before, **shared_scale)

In [None]:
# NNXo early-baseline and rebound TFRs minus the NNXo circadian-baseline TFR.
_nnxo = stacks['NNXo']
_nnxo_cb_mean = _nnxo['circ_bl'].mean(axis=0)
difeb = _nnxo['early_bl'].mean(axis=0) - _nnxo_cb_mean
difreb = _nnxo['rebound'].mean(axis=0) - _nnxo_cb_mean

In [None]:
# NNXo early-baseline / rebound differences from circadian baseline, on a shared scale.
f, ax = plt.subplots(1, 2, figsize=(16, 8))
for k, diff_map in enumerate((difeb, difreb)):
    ax[k], im = wu.plot_tfr(diff_map, ax[k], vline1=dur_before, vcenter=0, vmin=-.1, vmax=1)

## SD

In [None]:
# Sleep-deprivation TFRs (NNXo left, NNXr right), one figure per SD period.
# `[10:]` drops the first 10 events of each stack; the reason is not stated here.
for sd_cond, title in [('esd', 'Early Sleep Dep'), ('lsd', 'Late Sleep Dep')]:
    f, ax = plt.subplots(1, 2, figsize=(16, 8))
    for k, prb in enumerate(['NNXo', 'NNXr']):
        ax[k], im = wu.plot_tfr(stacks[prb][sd_cond][10:], ax[k], vline1=dur_before)
    ax[0].set_title(title)
    plt.show()

In [None]:
# Normalise each probe's rebound stacks to that probe's circadian-baseline means.
# NOTE(review): no visible cell defines `reb_stacks`, so this fails on a fresh
# kernel. It presumably held un-normalised 'rebound' stacks. Note that
# `stacks[probe]['rebound']` has already been through `wu.rel_tfr`, so it is not a
# drop-in replacement — confirm and recompute.
wavo = wu.rel_tfr(reb_stacks['NNXo'], cbl_means['NNXo'])
wavr = wu.rel_tfr(reb_stacks['NNXr'], cbl_means['NNXr'])

In [None]:
# Element-wise NNXo / NNXr ratio of the normalised rebound TFRs.
wavdiff = np.divide(wavo, wavr)

In [None]:
# Element-wise NNXo − NNXr difference of the normalised rebound TFRs.
wdiff = np.subtract(wavo, wavr)

In [None]:
# Global minimum of the difference map (used to pick the colour range below).
wdiff.min(axis=None)

In [None]:
# One figure per channel (first axis) of the NNXo − NNXr rebound difference.
# Iterates over the data's actual channel count instead of a hard-coded 16. Each
# figure is shown as soon as it is drawn, so figures do not pile up in memory.
for chan in range(len(wdiff)):
    f, ax = plt.subplots(figsize=(8, 8))
    ax, im = wu.plot_tfr(wdiff[chan], ax, cmap=scm.vik, vline1=dur_before, vline2=dur_before+ev_duration, vmin=-0.7, vmax=1.2, vcenter=0)
    plt.show()


In [None]:
# One figure per channel (first axis) of the NNXo / NNXr rebound ratio, centred at 1.
# Iterates over the data's actual channel count instead of a hard-coded 16. Each
# figure is shown as soon as it is drawn, so figures do not pile up in memory.
for chan in range(len(wavdiff)):
    f, ax = plt.subplots(figsize=(8, 8))
    ax, im = wu.plot_tfr(wavdiff[chan], ax, cmap=scm.vik, vline1=dur_before, vline2=dur_before+ev_duration, vmin=0.6, vmax=1.8, vcenter=1)
    plt.show()


In [None]:
# Channel 9 (hard-coded) of the normalised rebound TFRs: NNXr left, NNXo right.
f, ax = plt.subplots(1, 2, figsize=(16, 8))
markers = dict(vline1=dur_before, vline2=dur_before + ev_duration)
for k, chan_tfr in enumerate([wavr[9], wavo[9]]):
    ax[k], im = wu.plot_tfr(chan_tfr, ax[k], cmap=scm.vik, **markers)

In [None]:
# NOTE(review): no visible cell defines `tfr`, so this fails on a fresh kernel.
# `plot_tfr` here is unqualified (presumably from one of the `import *` lines),
# while every other cell calls `wu.plot_tfr` — confirm which data and function
# were intended.
f, ax = plt.subplots(figsize=(10, 10))
ax, im = plot_tfr(tfr[5], ax, vmax=2.5, vline1=dur_before, vline2=dur_before+(durg/1000), cmap=scm.vik)