In [None]:
import os

# Turn off HDF5 file locking before xarray (and its netCDF/HDF5 backend) is imported.
# NOTE(review): presumably needed because the data lives on a mounted volume (/Volumes/...).
os.environ.update(HDF5_USE_FILE_LOCKING="FALSE")
import xarray as xr

# Example NNXo baseline recording for one subject.
# NOTE(review): `da` is not referenced by any later cell shown here.
path = "/Volumes/opto_loc/Data/ACR_39/swi-bl-NNXo.nc"
da = xr.open_dataarray(path)

In [None]:
import pubplots as pp

from acr.utils import NNXR_GRAY, NREM_RED, PAPER_FIGURE_ROOT, SOM_BLUE, HALO_GREEN

# Matplotlib style sheet for labelled figures.
# NOTE(review): not applied by any cell shown here.
style_path = "/Users/driessen2@ad.wisc.edu/kdriessen/acr_dev/acr/src/acr/plot_styles/acrvec_labels.mplstyle"


# -------------------- ADJUST HERE --------------------
import os
from pathlib import Path

# Output folder for figures. A later cell repoints nbroot at an 'acr' subfolder.
nbroot = os.path.join(PAPER_FIGURE_ROOT, "response_to_review", "tfr")
os.makedirs(nbroot, exist_ok=True)

In [None]:
from pathlib import Path

import pingouin as pg
from scipy.stats import shapiro

# Reload edited project modules (acr, kdephys, helper modules) automatically before each cell runs.
%reload_ext autoreload
%autoreload 2

import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
import kdephys as kde
import acr

# NOTE(review): this silences every warning category for the whole session. It can hide
# numpy/pandas/xarray problems (e.g. means of empty slices); consider narrowing by category.
warnings.filterwarnings("ignore")

In [None]:
from matplotlib.colors import LogNorm
from matplotlib.colors import TwoSlopeNorm

In [None]:
# Load the publication helper modules from the PUBLICATION__ACR checkout by file path.
# NOTE(review): the plain `import pub_utils as pu` / `import data_agg as dag` lines only succeed
# if import_publication_functions makes the module importable under the given name (presumably
# by registering it in sys.modules). Keep each helper call before its import.
pub_utils = acr.utils.import_publication_functions(
    "/Users/driessen2@ad.wisc.edu/kdriessen/gh_master/PUBLICATION__ACR/pub_utils.py",
    "pub_utils",
)
import pub_utils as pu

data_agg = acr.utils.import_publication_functions(
    "/Users/driessen2@ad.wisc.edu/kdriessen/gh_master/PUBLICATION__ACR/data_agg.py",
    "data_agg",
)
import data_agg as dag

In [None]:
from pub_utils import get_event_data_stacks

In [None]:
from wavelet_tfr import *
from wavetf_utils import *
import wavetf_utils as wu

In [None]:
import zarr

In [None]:
# Sampling rate (samples per second). Later cells use it to turn window durations (s) into
# sample counts (int(fs * dur)) and to build time axes (1 / fs per sample).
# NOTE(review): hard-coded; presumably the native LFP rate of these recordings — confirm.
fs = 400.23053278688525

In [None]:
from acr.utils import SOM_BLUE, ACR_BLUE, NNXR_GRAY

# Experiment / cohort selection for this notebook.
MAIN_EXP = 'swi'         # experiment name passed to pu.get_subject_list
SUBJECT_TYPE = 'acr'     # subject cohort passed to pu.get_subject_list
MAIN_COLOR = ACR_BLUE    # trace color for the atomic_lfp figures

In [None]:
# Subjects and their matching experiment names for this cohort (paired by position).
subjects, exps = pu.get_subject_list(type=SUBJECT_TYPE, exp=MAIN_EXP)

# Figures from here on are written to <PAPER_FIGURE_ROOT>/response_to_review/tfr/acr.
nbroot = os.path.join(PAPER_FIGURE_ROOT, "response_to_review", "tfr", 'acr')
os.makedirs(nbroot, exist_ok=True)

# Wavelet Bands

In [None]:
# Quick look at the stimulation info for one subject. The per-subject `pons` cell further down
# is what the analysis actually uses.
# On a fresh kernel `subject`/`exp` do not exist yet (they are only set by later loops), so
# choose the first subject explicitly instead of relying on a leaked loop variable.
subject, exp = subjects[0], exps[0]
ss, se, pon, poff, ton, toff = acr.stim.get_all_stim_info(subject, exp, trn_idx=True)

In [None]:
# Per-subject hypnograms: the full-experiment hypnogram (hf) and the ACR hypnogram dict (hd).
hf, hd = {}, {}
for subject, exp in zip(subjects, exps):
    full_hyp = acr.io.load_hypno_full_exp(subject, exp)
    hf[subject] = full_hyp
    hd[subject] = acr.hypnogram_utils.create_acr_hyp_dict(subject, exp)

In [None]:
# Wavelet transforms, indexed as wave[subject][probe][cond].
# Loading order per subject: NNXo/stim, NNXo/cbl, NNXr/stim, NNXr/cbl. `exp` is not passed to the loader.
wave = {}
for subject, exp in zip(subjects, exps):
    wave[subject] = {
        probe: {cond: wu.load_wavelets(subject, probe, cond) for cond in ('stim', 'cbl')}
        for probe in ('NNXo', 'NNXr')
    }

In [None]:
# Frequency bands in Hz as (low, high). Both the xarray label slices and the np.where masks
# that use these bounds treat them as inclusive, so 16 Hz counts toward both sigma and beta.
bands = dict(
    theta=(4, 8),
    sigma=(9, 16),
    beta=(16, 30),
    gamma=(30, 80),
)

In [None]:
# Baseline ('cbl') wavelet power summed over each band's frequencies, stored as
# cbl_vals[subject][probe][band]. The matching datetime axis goes in cbl_dt[subject][probe].
# Only NNXo is processed.
cbl_vals = {}
cbl_dt = {}
for subject, exp in zip(subjects, exps):
    cbl_vals[subject] = {}
    cbl_dt[subject] = {}
    for probe in ['NNXo']:
        cbl_wave = wave[subject][probe]['cbl']
        cbl_dt[subject][probe] = cbl_wave.datetime.values
        per_band = {}
        cbl_vals[subject][probe] = per_band
        for band, (f_lo, f_hi) in bands.items():
            print(subject, probe, band)
            # Label-based slice on the frequency coordinate (both endpoints included).
            band_power = cbl_wave.lfp.sel(frequency=slice(f_lo, f_hi))
            per_band[band] = band_power.sum(dim='frequency').values

In [None]:
# Stim-day ('stim') wavelet power summed over each band's frequencies, stored as
# stim_vals[subject][probe][band]. The matching datetime axis goes in stim_dt[subject][probe].
# Only NNXo is processed.
stim_vals = {}
stim_dt = {}
for subject, exp in zip(subjects, exps):
    stim_vals[subject] = {}
    stim_dt[subject] = {}
    for probe in ['NNXo']:
        stim_wave = wave[subject][probe]['stim']
        stim_dt[subject][probe] = stim_wave.datetime.values
        per_band = {}
        stim_vals[subject][probe] = per_band
        for band, (f_lo, f_hi) in bands.items():
            print(subject, probe, band)
            # Label-based slice on the frequency coordinate (both endpoints included).
            band_power = stim_wave.lfp.sel(frequency=slice(f_lo, f_hi))
            per_band[band] = band_power.sum(dim='frequency').values

In [None]:
# Per-subject `pon` times from get_all_stim_info (presumably pulse onsets), keyed by subject.
pons = {}
for subject, exp in zip(subjects, exps):
    stim_info = acr.stim.get_all_stim_info(subject, exp, trn_idx=True)
    ss, se, pon, poff, ton, toff = stim_info
    pons[subject] = pon

In [None]:
# Reference averages from 'cbl_avgs', per subject and probe.
cbl_means = {}
for subject, exp in zip(subjects, exps):
    cbl_means[subject] = {
        probe: wu.load_avgs('cbl_avgs', subject, probe) for probe in ('NNXo', 'NNXr')
    }

In [None]:
# Reference averages from 'sd_avgs', per subject and probe (used to normalize the stim stacks).
lsd_means = {}
for subject, exp in zip(subjects, exps):
    lsd_means[subject] = {
        probe: wu.load_avgs('sd_avgs', subject, probe) for probe in ('NNXo', 'NNXr')
    }

In [None]:
# OFF-period duration binning: 10 ms bins starting at 55 ms. The bin edges are the same for
# every subject, so they are defined once here (they remain notebook globals as before).
dur_group_borders = np.arange(0.05, 0.30, 0.01)  # NOTE(review): not used by the cells shown here
start = 0.055
stop = 0.355
bin_w = 0.010

oodfs = {}
for subject, exp in zip(subjects, exps):
    oodf = dag.compute_hybrid_off_df(subject, exp, chan_threshold=12)
    oodf = acr.oo_utils.enhance_oodf(oodf, hf[subject], hd[subject])

    # dgroup: 0-based bin index for durations in [start, stop + bin_w). -1 marks out-of-range OFFs.
    oodf = oodf.with_columns(
        pl.when((pl.col("duration") >= start) & (pl.col("duration") < stop + bin_w))
        .then(((pl.col("duration") - start) / bin_w).floor().cast(pl.Int32))
        .otherwise(pl.lit(-1, dtype=pl.Int32))
        .alias("dgroup")
    )
    # dg: bin center in ms (bin k spans [55 + 10k, 65 + 10k) ms, so its center is 60 + 10k).
    # Out-of-range OFFs keep the -1 sentinel. Previously they got 60 + 10 * -1 = 50, which
    # looks like a real 50 ms label even for very long OFFs.
    oodf = oodf.with_columns(
        pl.when(pl.col("dgroup") >= 0)
        .then(pl.col("dgroup") * 10 + 60)
        .otherwise(pl.lit(-1, dtype=pl.Int32))
        .alias("dg")
    )
    # off_int: seconds from this OFF's end to the next row's start. The last row gets null.
    # NOTE(review): shift(-1) runs over the whole frame. If oodf mixes probes/conditions,
    # the last OFF of one group is paired with the first OFF of the next. Confirm the row
    # ordering, or compute this per probe with `.over(...)`.
    oodf = oodf.with_columns(
        ((pl.col("start_datetime").shift(-1) - pl.col("end_datetime"))
        .dt.total_milliseconds() / 1000)
        .alias("off_int")
    )
    oodfs[subject] = oodf

In [None]:
# Concatenate every subject's OFF-period table and save them as a single parquet file.
oodf_full = pl.concat(list(oodfs.values()))

oodf_full_path = Path('/Users/driessen2@ad.wisc.edu/kdriessen/acr_dev/acr_revs/src_dat/oodfs/all_acr.parquet')
# write_parquet does not create missing directories; make sure the target folder exists.
oodf_full_path.parent.mkdir(parents=True, exist_ok=True)
oodf_full.write_parquet(oodf_full_path)

In [None]:
def _normalized_band_stacks(vals_by_band, times, onsets, ref_means, freqs,
                            ev_duration=0, dur_before=0.0, dur_after=0.350):
    """Lock each band's summed power to `onsets` and divide by a reference band power.

    Parameters
    ----------
    vals_by_band : dict
        band name -> summed band-power array (as in cbl_vals / stim_vals). Each array is
        transposed before slicing.
    times : array
        Datetime axis matching the arrays in ``vals_by_band``.
    onsets : array
        Event datetimes to lock the windows to.
    ref_means : array
        Reference averages indexed as ``ref_means[:, freq_index]``.
    freqs : array
        Frequency values used to select the columns of ``ref_means`` for each band.
        NOTE(review): assumes ref_means' second axis follows this frequency grid.
    ev_duration, dur_before, dur_after :
        Passed straight to ``wu.get_tfr_stacks`` (window lengths in seconds).

    Returns
    -------
    dict
        band name -> ``stack / band_ref``, with band_ref broadcast along the stack's last axis.
    """
    stacks = {}
    for band, (f_lo, f_hi) in bands.items():
        stack = wu.get_tfr_stacks(vals_by_band[band].T, times, onsets,
                                  ev_duration, dur_before, dur_after, fs)
        band_indices = np.where((freqs >= f_lo) & (freqs <= f_hi))[0]
        band_ref = ref_means[:, band_indices].sum(axis=1)
        stacks[band] = stack / band_ref[np.newaxis, np.newaxis, :]
    return stacks


sub_stims = {}
sub_cbls = {}
for subject in subjects:
    print(subject)
    oodf = oodfs[subject]
    # Natural ON onsets: ends of NNXo circ_bl OFFs that are followed by > 350 ms without another OFF.
    natural_on_starts = (
        oodf.filter(pl.col('condition') == 'circ_bl')
        .filter(pl.col('off_int') > 0.350)
        .prb('NNXo')['end_datetime']
        .to_numpy()
    )
    # NOTE(review): this drops the last qualifying onset. The final row's null off_int is
    # already removed by the filter, so the reason is not visible here — confirm intent.
    natural_on_starts = natural_on_starts[:-1]
    # Induced ON onsets: first 1800 `pons` entries shifted by 180 ms.
    induced_on_starts = pons[subject][:1800] + np.timedelta64(180, 'ms')

    freqs = wave[subject]['NNXo']['cbl'].frequency.values
    # Stim windows are normalized by lsd_means ('sd_avgs'); baseline windows by cbl_means.
    sub_stims[subject] = _normalized_band_stacks(
        stim_vals[subject]['NNXo'], stim_dt[subject]['NNXo'], induced_on_starts,
        lsd_means[subject]['NNXo'], freqs,
    )
    sub_cbls[subject] = _normalized_band_stacks(
        cbl_vals[subject]['NNXo'], cbl_dt[subject]['NNXo'], natural_on_starts,
        cbl_means[subject]['NNXo'], freqs,
    )

In [None]:
# For each band, average every subject's stack over axes 0 and 2, then stack the subjects
# along a new first axis. Results go in sub_stim[band] / sub_cbl[band].
sub_cbl = {}
sub_stim = {}
for band in bands.keys():
    sub_stim[band] = np.stack(
        [np.mean(sub_stims[s][band][:, :, :], axis=(0, 2)) for s in subjects]
    )
    sub_cbl[band] = np.stack(
        [np.mean(sub_cbls[s][band][:, :, :], axis=(0, 2)) for s in subjects]
    )

In [None]:

band_colors = {
    'theta': 'red',
    'sigma': 'orange',
    'beta': 'blue',
    'gamma': 'green',
}

# Draw only theta and sigma (beta/gamma are skipped), using the first n_show samples of each trace.
n_show = 100

f, ax = plt.subplots(2, 1, figsize=(12, 8))
for band in bands.keys():
    if band in ('gamma', 'beta'):
        continue
    # Row 0: baseline windows at natural ON onsets. Row 1: stim windows at induced ON onsets.
    for row, per_subject in enumerate((sub_cbl[band], sub_stim[band])):
        trace_mean = per_subject.mean(axis=0)[:n_show]
        trace_sem = per_subject.std(axis=0)[:n_show] / np.sqrt(per_subject.shape[0])
        # Size the x axis to the data so traces shorter than n_show samples still plot.
        x = np.arange(trace_mean.shape[0])
        ax[row].fill_between(x, trace_mean - trace_sem, trace_mean + trace_sem,
                             color=band_colors[band], alpha=0.3)
        ax[row].plot(x, trace_mean, color=band_colors[band], label=band)

ax[0].set_title('Baseline: natural ON onsets')
ax[1].set_title('Stim: induced ON onsets')
for row_ax in ax:
    row_ax.set_xlabel('Sample from window start')
    row_ax.set_ylabel('Band power / reference')
    row_ax.legend(loc='upper right')

# Bandpass-Filtered Data

In [None]:
# Baseline raw LFP (NNXo): the first 12 h of each subject's '<exp>-bl' recording.
lfpbl = {}
for subject, exp in zip(subjects, exps):
    baseline_raw = acr.io.load_raw_data(subject, f'{exp}-bl', store='NNXo')
    start = baseline_raw.datetime.values.min()
    end = start + pd.Timedelta('12h')
    lfpbl[subject] = baseline_raw.sel(datetime=slice(start, end))

In [None]:
# Stim-day raw LFP (NNXo), from the start of the recording until 30 min after the last `pons` entry.
lfpx = {}
for subject, exp in zip(subjects, exps):
    stim_raw = acr.io.load_raw_data(subject, f'{exp}', store='NNXo')
    se = pons[subject][-1] + pd.Timedelta('30m')
    start = stim_raw.datetime.values.min()
    lfpx[subject] = stim_raw.sel(datetime=slice(start, se))

In [None]:
# Band-pass ranges in Hz as (low, high), one filtered trace per band.
# These are the same values as `bands` but are defined separately.
bp_bands = dict(
    theta=(4, 8),
    sigma=(9, 16),
    beta=(16, 30),
    gamma=(30, 80),
)

In [None]:
# Band-pass filter the baseline (bpbl) and stim-day (bpx) LFP for every band in bp_bands.
bpbl = {}
bpx = {}
for subject in subjects:
    bpbl[subject] = {}
    bpx[subject] = {}
    for band, band_range in bp_bands.items():
        print(subject, band)
        bpbl[subject][band] = kde.xr.spectral.bandpass_filter_raw_data(lfpbl[subject], band_range)
        bpx[subject][band] = kde.xr.spectral.bandpass_filter_raw_data(lfpx[subject], band_range)

In [None]:
# Window around each ON onset, in seconds (converted to samples with int(fs * dur)).
dur_before, dur_after = 0.0, 0.500

In [None]:
# Band-passed stim-day LFP locked to induced ON onsets (first 600 `pons` entries + 180 ms),
# stored as band_stacks_exp[subject][band].
# To restrict to frontal probes, skip subjects whose
# acr.utils.sub_probe_locations[subject] != 'frontal'.
band_stacks_exp = {}
for subject in subjects:
    band_stacks_exp[subject] = {}
    induced_on_starts = pons[subject][:600] + np.timedelta64(180, 'ms')
    # Assumes every band shares the theta trace's time axis (all are filtered from lfpx[subject]).
    dt_vals = bpx[subject]['theta'].datetime.values
    for band in bp_bands:
        print(subject, band)
        band_vals = bpx[subject][band].values
        band_stacks_exp[subject][band] = wu.find_indices_and_slice_array(
            band_vals, dt_vals, induced_on_starts, int(fs * dur_before), int(fs * dur_after)
        )

In [None]:
# Band-passed baseline LFP locked to natural ON onsets: ends of NNXo circ_bl OFFs followed by
# > 500 ms without another OFF. Stored as band_stacks_bl[subject][band].
# To restrict to frontal probes, skip subjects whose
# acr.utils.sub_probe_locations[subject] != 'frontal'.
band_stacks_bl = {}
for subject in subjects:
    oodf = oodfs[subject]
    band_stacks_bl[subject] = {}
    natural_on_starts = (
        oodf.cdn('circ_bl')
        .prb('NNXo')
        .filter(pl.col('off_int') > 0.500)['end_datetime']
        .to_numpy()
    )
    # Assumes every band shares the theta trace's time axis (all are filtered from lfpbl[subject]).
    dt_vals = bpbl[subject]['theta'].datetime.values
    for band in bp_bands:
        print(subject, band)
        band_vals = bpbl[subject][band].values
        band_stacks_bl[subject][band] = wu.find_indices_and_slice_array(
            band_vals, dt_vals, natural_on_starts, int(fs * dur_before), int(fs * dur_after)
        )

In [None]:
# Per band: average each subject's induced stack over axis 0, then stack the subjects
# along a new first axis. ACR_16 is excluded.
sub_means_exp = {
    band: np.stack([
        np.mean(band_stacks_exp[s][band][:, :, :], axis=0)
        for s in band_stacks_exp
        if s != 'ACR_16'
    ])
    for band in bp_bands
}

In [None]:
# Per band: average each subject's baseline stack over axis 0, then stack the subjects
# along a new first axis. ACR_16 is excluded.
sub_means_bl = {
    band: np.stack([
        np.mean(band_stacks_bl[s][band][:, :, :], axis=0)
        for s in band_stacks_bl
        if s != 'ACR_16'
    ])
    for band in bp_bands
}

In [None]:
# Time axis in seconds for the event-locked traces plotted below, one entry per sample.
# The sample count comes from the stacks themselves, so this cell no longer needs
# `sigma_mean`, which is only defined by the next cell. sub_means_exp[band] stacks the
# per-subject means along axis 0, so axis 1 holds the samples (trailing axis presumably channels).
n_samples = sub_means_exp['sigma'].shape[1]
end = n_samples * (1 / fs)
# Sample k is at k / fs. The old linspace(0, end, n) spaced samples n / (fs * (n - 1)) apart.
# The shift makes t = 0 the ON onset; it assumes find_indices_and_slice_array starts each
# window int(fs * dur_before) samples before the event (TODO confirm). With dur_before = 0 it is k / fs.
time_vec = (np.arange(n_samples) - int(fs * dur_before)) / fs

In [None]:
# Sigma band at induced ON onsets: mean ± SEM across subjects.
sigma_by_subject = sub_means_exp['sigma']
n_subjects = sigma_by_subject.shape[0]
sigma_mean = sigma_by_subject.mean(axis=0)
sigma_sems = sigma_by_subject.std(axis=0) / np.sqrt(n_subjects)
f, ax = kde.plot.main.atomic_lfp(
    sigma_mean.T,
    times=time_vec,
    sems=sigma_sems.T,
    figsize=(3, 6),
    color=MAIN_COLOR,
)
save_kwargs = dict(transparent=True, dpi=300, bbox_inches='tight')
fig_name = 'sigma_on_period_induced.svg'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

# Reference marks on ax[8] (green bar 0.2-0.3 s at y=10, red line at t=0.4 s spanning ±50),
# saved as a separate LABELLED__ copy. NOTE(review): assumes atomic_lfp returned at least 9 axes.
scale_ax = ax[8]
scale_ax.hlines(10, 0.2, 0.3, color='green', linewidth=3)
scale_ax.vlines(0.4, -50, 50, color='red', linewidth=3)
plt.show()
fig_name = f'LABELLED__{fig_name}'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

In [None]:
# Sigma band at natural (baseline) ON onsets: mean ± SEM across subjects.
sigma_by_subject = sub_means_bl['sigma']
n_subjects = sigma_by_subject.shape[0]
sigma_mean = sigma_by_subject.mean(axis=0)
sigma_sems = sigma_by_subject.std(axis=0) / np.sqrt(n_subjects)
f, ax = kde.plot.main.atomic_lfp(
    sigma_mean.T,
    times=time_vec,
    sems=sigma_sems.T,
    figsize=(3, 6),
    color=MAIN_COLOR,
)
save_kwargs = dict(transparent=True, dpi=300, bbox_inches='tight')
fig_name = 'sigma_on_period_cbl.svg'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

# Reference marks on ax[8] (green bar 0.2-0.3 s at y=10, red line at t=0.4 s spanning ±50),
# saved as a separate LABELLED__ copy. NOTE(review): assumes atomic_lfp returned at least 9 axes.
scale_ax = ax[8]
scale_ax.hlines(10, 0.2, 0.3, color='green', linewidth=3)
scale_ax.vlines(0.4, -50, 50, color='red', linewidth=3)
plt.show()
fig_name = f'LABELLED__{fig_name}'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

In [None]:
# Theta band at induced ON onsets: mean ± SEM across subjects.
theta_by_subject = sub_means_exp['theta']
n_subjects = theta_by_subject.shape[0]
theta_mean = theta_by_subject.mean(axis=0)
theta_sems = theta_by_subject.std(axis=0) / np.sqrt(n_subjects)
f, ax = kde.plot.main.atomic_lfp(
    theta_mean.T,
    times=time_vec,
    sems=theta_sems.T,
    figsize=(3, 6),
    color=MAIN_COLOR,
)
# Fixed y-limits on every axis.
# NOTE(review): these values look copied from an earlier output — document where they come from.
theta_ylim = (-80.04596654386401, 50.43603376894117)
for theta_ax in ax:
    theta_ax.set_ylim(*theta_ylim)
save_kwargs = dict(transparent=True, dpi=300, bbox_inches='tight')
fig_name = 'theta_on_period_induced.svg'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

# Reference marks on ax[8] (green bar 0.2-0.3 s at y=10, red line at t=0.4 s spanning ±50),
# saved as a separate LABELLED__ copy. NOTE(review): assumes atomic_lfp returned at least 9 axes.
scale_ax = ax[8]
scale_ax.hlines(10, 0.2, 0.3, color='green', linewidth=3)
scale_ax.vlines(0.4, -50, 50, color='red', linewidth=3)
plt.show()
fig_name = f'LABELLED__{fig_name}'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

In [None]:
# Theta band at natural (baseline) ON onsets: mean ± SEM across subjects.
theta_by_subject = sub_means_bl['theta']
n_subjects = theta_by_subject.shape[0]
theta_mean = theta_by_subject.mean(axis=0)
theta_sems = theta_by_subject.std(axis=0) / np.sqrt(n_subjects)
f, ax = kde.plot.main.atomic_lfp(
    theta_mean.T,
    times=time_vec,
    sems=theta_sems.T,
    figsize=(3, 6),
    color=MAIN_COLOR,
)
save_kwargs = dict(transparent=True, dpi=300, bbox_inches='tight')
fig_name = 'theta_on_period_cbl.svg'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

# Reference marks on ax[8] (green bar 0.2-0.3 s at y=10, red line at t=0.4 s spanning ±50),
# saved as a separate LABELLED__ copy. NOTE(review): assumes atomic_lfp returned at least 9 axes.
scale_ax = ax[8]
scale_ax.hlines(10, 0.2, 0.3, color='green', linewidth=3)
scale_ax.vlines(0.4, -50, 50, color='red', linewidth=3)
plt.show()
fig_name = f'LABELLED__{fig_name}'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

In [None]:
# Per-subject mean raw-LFP trace locked to induced ON onsets (first 600 `pons` entries + 180 ms),
# stacked along a new first axis. ACR_16 is excluded.
# To restrict to frontal probes, also skip subjects whose
# acr.utils.sub_probe_locations[subject] != 'frontal'.
per_subject_means = []
for subject in subjects:
    if subject == 'ACR_16':
        continue
    induced_on_starts = pons[subject][:600] + np.timedelta64(180, 'ms')
    dt_vals = lfpx[subject].datetime.values
    print(subject)
    lfp_vals = lfpx[subject].values
    lfp_stack = wu.find_indices_and_slice_array(
        lfp_vals, dt_vals, induced_on_starts, int(fs * dur_before), int(fs * dur_after)
    )
    per_subject_means.append(lfp_stack.mean(axis=0))
lfp_stacks_exp = np.stack(per_subject_means)

In [None]:
# Per-subject mean raw-LFP trace locked to natural ON onsets: ends of NNXo circ_bl OFFs
# followed by > 500 ms without another OFF. Stacked along a new first axis; ACR_16 is excluded.
# To restrict to frontal probes, also skip subjects whose
# acr.utils.sub_probe_locations[subject] != 'frontal'.
per_subject_means = []
for subject in subjects:
    if subject == 'ACR_16':
        continue
    oodf = oodfs[subject]
    natural_on_starts = (
        oodf.cdn('circ_bl')
        .prb('NNXo')
        .filter(pl.col('off_int') > 0.500)['end_datetime']
        .to_numpy()
    )
    dt_vals = lfpbl[subject].datetime.values
    print(subject, len(natural_on_starts))
    lfp_vals = lfpbl[subject].values
    lfp_stack = wu.find_indices_and_slice_array(
        lfp_vals, dt_vals, natural_on_starts, int(fs * dur_before), int(fs * dur_after)
    )
    per_subject_means.append(lfp_stack.mean(axis=0))
lfp_stacks_bl = np.stack(per_subject_means)

In [None]:
# Raw LFP at natural (baseline) ON onsets: mean ± SEM across subjects.
n_subjects = lfp_stacks_bl.shape[0]
lfp_mean = lfp_stacks_bl.mean(axis=0)
lfp_sems = lfp_stacks_bl.std(axis=0) / np.sqrt(n_subjects)
f, ax = kde.plot.main.atomic_lfp(
    lfp_mean.T,
    times=time_vec,
    sems=lfp_sems.T,
    figsize=(3, 6),
    color=MAIN_COLOR,
)
save_kwargs = dict(transparent=True, dpi=300, bbox_inches='tight')
fig_name = 'LFP_on_period_cbl.svg'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

# Reference marks on ax[8] (green bar 0.2-0.3 s at y=10, red line at t=0.4 s spanning ±50),
# saved as a separate LABELLED__ copy. NOTE(review): assumes atomic_lfp returned at least 9 axes.
scale_ax = ax[8]
scale_ax.hlines(10, 0.2, 0.3, color='green', linewidth=3)
scale_ax.vlines(0.4, -50, 50, color='red', linewidth=3)
plt.show()
fig_name = f'LABELLED__{fig_name}'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

In [None]:
# Raw LFP at induced ON onsets: mean ± SEM across subjects.
n_subjects = lfp_stacks_exp.shape[0]
lfp_mean = lfp_stacks_exp.mean(axis=0)
lfp_sems = lfp_stacks_exp.std(axis=0) / np.sqrt(n_subjects)
f, ax = kde.plot.main.atomic_lfp(
    lfp_mean.T,
    times=time_vec,
    sems=lfp_sems.T,
    figsize=(3, 6),
    color=MAIN_COLOR,
)
save_kwargs = dict(transparent=True, dpi=300, bbox_inches='tight')
fig_name = 'LFP_on_period_exp.svg'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)

# Reference marks on ax[8] (green bar 0.2-0.3 s at y=10, red line at t=0.4 s spanning ±50),
# saved as a separate LABELLED__ copy. NOTE(review): assumes atomic_lfp returned at least 9 axes.
scale_ax = ax[8]
scale_ax.hlines(10, 0.2, 0.3, color='green', linewidth=3)
scale_ax.vlines(0.4, -50, 50, color='red', linewidth=3)
plt.show()
fig_name = f'LABELLED__{fig_name}'
f.savefig(os.path.join(nbroot, fig_name), **save_kwargs)