In [None]:
#-------------------------- Standard Imports --------------------------#
%reload_ext autoreload
%autoreload 2
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pingouin as pg
import polars as pl
import seaborn as sns

import acr
import kdephys as kde
from acr.utils import SOM_BLUE, ACR_BLUE, LASER_BLUE, NNXR_GRAY, NNXO_BLUE, EMG_SLATE, BACKUP_RED
from kdephys.plot.main import base_trace_plot, base_raster
from kdephys.utils.main import td

#--------------------------------- Import Publication Functions ---------------------------------#
pub_utils = acr.utils.import_publication_functions('/home/kdriessen/gh_master/PUBLICATION__ACR/pub_utils.py', 'pub_utils')
from pub_utils import *
data_agg = acr.utils.import_publication_functions('/home/kdriessen/gh_master/PUBLICATION__ACR/data_agg.py', 'data_agg')
from data_agg import *

#--------------------------------- PLOTTING STUFF ---------------------------------#
# Start from matplotlib's 'fast' style, then layer the publication style on top
# (styles in the list are applied in order, later ones overriding earlier ones).
plt.style.use(['fast', '/home/kdriessen/gh_master/acr/acr/plot_styles/acr_pub.mplstyle'])
probe_ord = ['NNXr', 'NNXo']  # probe order, passed to the MUA loader below


#--------------------------------- MISCELLANEOUS ---------------------------------#
#warnings.filterwarnings('ignore')

# Control Raw Data Plots

In [None]:
# Animal and experiment to plot; the loaders in the following cells are keyed on these.
subject, exp = 'ACR_25', 'swisin'
recs = acr.info_pipeline.get_exp_recs(subject, exp)

In [None]:
# Raw LFP data for this subject/experiment (the loader takes a list of experiments);
# sliced per probe and time window in the plotting loop below.
fp = acr.io.load_concat_raw_data(subject, [exp])

In [None]:
# MUA peaks for the probes in `probe_ord`; sliced per probe and time window in the plotting loop below.
mua = acr.mua.load_concat_peaks_df(subject, exp, probes=probe_ord)

In [None]:
# Full-experiment hypnogram (`h`, used to look up the state of the plotted window),
# the hypnogram dict, and the experiment's timeline landmarks.
h = acr.io.load_hypno_full_exp(subject, exp)
hd = acr.hypnogram_utils.create_acr_hyp_dict(subject, exp)
# Earliest rebound start according to the hypnogram dict. Kept under its own name:
# the landmarks call below also returns a `reb_start`, which would otherwise
# silently overwrite this value.
reb_start_hyp = hd['rebound']['start_time'].min()
sd_start, stim_start, stim_end, reb_start, full_x_start = acr.info_pipeline.get_sd_exp_landmarks(
    subject, exp, update=False, return_early=False
)

## LFPs + separate MUA

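TIMES (copy one into `state_start` in the next cell)
--------

Example window start times in TDT recording time (converted to a datetime by `acr.utils.dt_from_tdt`):

- NREM: 18817.5
- REM: 19053.5
- Wake: 5791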

In [None]:
# Window start in TDT time (see the TIMES list above), converted to a datetime.
# `state` is the hypnogram state at that instant; it is embedded in the saved figure filenames.
state_start = 18817.5
start = acr.utils.dt_from_tdt(subject, exp, state_start)
query_times = np.array([start])
states = kde.hypno.hypno.get_states(h, query_times)
state = states[0]
state

In [None]:
# Raw-data figures for the chosen window, one set per probe: stacked LFP traces
# (plain, and with scale bars) plus an MUA raster, each saved as a 600-dpi PNG.
# NOTE: `ds` / `d2p` from the LAST iteration (probe 'NNXr') are reused by the
# KCSD section below, so keep 'NNXr' last or re-select `ds` there.
dur = 4  # window length after `start`, in td() units
LFP_YLIM = (-1050, 1050)  # identical y-limits on every trace so probes are comparable
SCALE_BAR_CHANNEL = 7  # index (in grid.axes.flatten() order) of the trace that carries the scale bars
SAVE_KW = dict(dpi=600, bbox_inches='tight', transparent=True, facecolor='none')


def _plot_lfp_traces(lfp_df, color, window_start, scale_bars=False):
    """Plot stacked LFP traces with shared y-limits and return the trace grid.

    With ``scale_bars=True`` the trace at SCALE_BAR_CHANNEL also gets a red
    rectangle spanning y = -500..500 (1000 y-units tall) and a green band from
    ``window_start + td(1)`` to ``window_start + td(2)``; the saved filename
    labels these '1000mv' and '1sec'.
    """
    grid = base_trace_plot(lfp_df, height=0.5, aspect=50, hspace=-0.6, color=color)
    for idx, trace_ax in enumerate(grid.axes.flatten()):
        trace_ax.set_ylim(*LFP_YLIM)
        if scale_bars and idx == SCALE_BAR_CHANNEL:
            trace_ax.axhspan(-500, 500, xmin=0.9, xmax=0.95, color='red', alpha=0.5)
            trace_ax.axvspan(window_start + td(1), window_start + td(2),
                             color='green', alpha=0.5, linewidth=4)
    return grid


window_end = start + td(dur)  # loop-invariant, so computed once
for probe in ['NNXo', 'NNXr']:
    col = SOM_BLUE if probe == 'NNXo' else NNXR_GRAY
    ds = fp.ts(start, window_end).prb(probe)
    d2p = ds.panda()
    ms = mua.ts(start, window_end).prb(probe)
    m2p = ms.to_pandas()
    fig_stem = f'{PAPER_FIGURE_ROOT}/schems/{subject}_{exp}_{probe}_RAW'

    # LFP plot
    g = _plot_lfp_traces(d2p, col, start)
    g.figure.savefig(f'{fig_stem}--LFPs--{state}.png', **SAVE_KW)

    # Same LFP plot with scale bars
    g = _plot_lfp_traces(d2p, col, start, scale_bars=True)
    g.figure.savefig(f'{fig_stem}--LFPs--{state}__SCALED-1000mv_1sec.png', **SAVE_KW)

    # MUA raster plot
    f, ax = base_raster(m2p, color=col, figsize=(20, 2.5))
    f.savefig(f'{fig_stem}--MUA--RASTER--{state}.png', **SAVE_KW)

In [None]:
# Display this subject's channel map (reference for reading the channel-indexed plots above and the kCSD below).
acr.info_pipeline.get_channel_map(subject)

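# KCSD

Kernel current source density (kCSD, via the `kcsd` package), computed on the LFP window `ds` left over from the last iteration of the plotting loop above, i.e. probe `NNXr`.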

In [None]:
import kcsd

In [None]:
# Give each channel of `ds` a position coordinate 'y' and index it by 'time' for kCSD.
# NOTE: `ds` is whatever the plotting loop above left behind (its last probe, 'NNXr');
# re-select it explicitly to analyse a different probe.
gdx = .05  # spacing between adjacent electrode positions (kCSD's own grid step is set separately in KCSD1D)
# Integer multiples of gdx give exactly one position per channel (gdx, 2*gdx, ...),
# instead of a float-step np.arange with a hard-coded 16 channels and a +.001 fudge.
ele_pos = gdx * np.arange(1, ds.sizes['channel'] + 1)
ds = ds.assign_coords(y=("channel", ele_pos))
# The pandas view and the datetime -> time swap only apply to a not-yet-converted `ds`.
# The guard makes the cell re-runnable: on a second run 'datetime' is no longer a
# dimension and swap_dims would raise.
if 'datetime' in ds.dims:
    d2p = ds.panda()
    ds = ds.swap_dims({'datetime': 'time'})

In [None]:
# Fit a 1-D kCSD model: electrode positions as a column vector (one row per channel)
# and potentials arranged as (channel, time).
ele_coords = ds['y'].values.reshape(-1, 1)
pots = ds.transpose("channel", "time").values
k = kcsd.KCSD1D(ele_coords, pots, gdx=0.01, sigma=1, R_init=0.23)
# L-curve parameter search (see the kcsd-python documentation for what it selects).
k.L_curve()

In [None]:
# All-zero template for the kCSD estimate: one 'channel' entry per kCSD estimation
# point (k.values("CSD") has one row per point), with 'y' positions spaced evenly
# from the first to the last electrode position.
n_estm = k.values("CSD").shape[0]
y_min, y_max = ds['y'].values.min(), ds['y'].values.max()
new_y_coords = np.linspace(y_min, y_max, n_estm)
n_t = len(ds['time'])

csd = xr.DataArray(
    data=np.zeros((n_t, n_estm)),
    dims=['time', 'channel'],
    coords=dict(
        time=ds['time'],
        channel=np.arange(n_estm),
        y=('channel', new_y_coords),
    ),
)

In [None]:
# Unsmoothed kCSD estimate, drawn into an explicitly created axes.
# NOTE(review): this uses 'rocket_r' while the smoothed plot below uses 'rocket';
# confirm the reversed colormap is intended.
CSD_RAW_VLIM = 14000  # symmetric colour limit
f, ax = plt.subplots(figsize=(20, 6))
csdl = xr.zeros_like(csd)
csdl.values = k.values("CSD").T  # kcsd gives (point, time); csd is (time, channel)
csdl.plot.imshow(x='time', y='y', cmap='rocket_r', origin='upper',
                 vmin=-CSD_RAW_VLIM, vmax=CSD_RAW_VLIM, ax=ax)

In [None]:
# Gaussian-smooth the kCSD estimate and draw it as a bare panel (no labels,
# ticks, or colorbar) for figure assembly.
from scipy.ndimage import gaussian_filter

# Sigmas are in array samples. csdl's dims are ('time', 'channel'), so the sigma
# sequence passed to gaussian_filter must be ordered [time, space].
sigma_space = 8  # smoothing along the spatial ('channel' / y) axis
sigma_time = 4  # smoothing along the time axis
csdl_smoothed = csdl.copy(data=gaussian_filter(csdl.values, sigma=[sigma_time, sigma_space]))

f, ax = plt.subplots(figsize=(20, 6))
# add_colorbar=False stops xarray from creating a colorbar at all. Previously the
# colorbar was created and then the "last axes" was deleted, which breaks if the
# layout changes and may leave the main axes shrunk to make room for the
# removed colorbar.
csdl_smoothed.plot.imshow(
    x='time', y='y', cmap='rocket', origin='upper', ax=ax,
    vmin=-8000, vmax=8000, add_colorbar=False,
)
#ax.set_title(f'Smoothed CSD (Gaussian filter) sigma_space = {sigma_space}, sigma_time = {sigma_time}', fontsize=20)
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_xticks([])
ax.set_yticks([]);

