# Load Data

In [None]:
#-------------------------- Standard Imports --------------------------#
%reload_ext autoreload
%autoreload 2
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import acr
import kdephys as kde
from kdephys.utils.main import td
import warnings
import pingouin as pg
from acr.utils import SOM_BLUE, ACR_BLUE, LASER_BLUE, NNXR_GRAY, NNXO_BLUE, EMG_SLATE, BACKUP_RED
import scipy
import kcsd
from acr.plots import pub, lrg
#--------------------------------- Import Publication Functions ---------------------------------#
pub_utils = acr.utils.import_publication_functions('/home/kdriessen/gh_master/PUBLICATION__ACR/pub_utils.py', 'pub_utils')
from pub_utils import *
data_agg = acr.utils.import_publication_functions('/home/kdriessen/gh_master/PUBLICATION__ACR/data_agg.py', 'data_agg')
from data_agg import *

#--------------------------------- PLOTTING STUFF ---------------------------------#
# Base matplotlib style for the notebook; later cells layer pub()/lrg() and the
# acr_pub style file on top of it.
plt.style.use('fast')

# Display order of the two probes; every per-probe loop below iterates this list.
probe_ord = ['NNXr', 'NNXo']

In [None]:
# Animal and experiment analysed by this notebook, plus that experiment's recordings.
subject, exp = 'ACR_19', 'swi'
recs = acr.info_pipeline.get_exp_recs(subject, exp)

In [None]:
# Raw field-potential data for this subject/experiment; later cells slice it with
# fp.prb(<probe>).ts(start, end). (Concatenated across recordings per the loader's name — TODO confirm.)
fp = acr.io.load_concat_raw_data(subject, [exp])

In [None]:
# MUA peak table for both probes; the firing-rate cells read its 'datetime' column
# via mua.ts(...).prb(<probe>).
mua = acr.mua.load_concat_peaks_df(subject, exp, probes=probe_ord)

In [None]:
# Hypnogram for the full experiment, the per-condition hypnogram dict, and the
# experiment landmark times used below (e.g. stim_start / reb_start bound the MUA window).
h = acr.io.load_hypno_full_exp(subject, exp)
hd = acr.hypnogram_utils.create_acr_hyp_dict(subject, exp)
# reb_start is taken solely from the landmark helper; it is not derived from hd['rebound'].
sd_start, stim_start, stim_end, reb_start, full_x_start = acr.info_pipeline.get_sd_exp_landmarks(subject, exp, update=False, return_early=False)

# Evoked FPs + Firing Rate

In [None]:
# Onset/offset times of every individual stimulation pulse, and (per the helper's
# name) the same pulses grouped into trains. ton/toff are not used in the cells shown here.
pon, poff = acr.stim.get_individual_pulse_times(subject, exp)
ton, toff = acr.stim.get_pulse_train_times(pon, poff)

# Older selection, superseded by the epoch-based filtering below:
# take the first 200 pulses at 80 ms per pulse.
#ons = pon[2703:2903]
#offs = poff[2703:2903]

In [None]:
# Hand-picked (start, end) windows in TDT time (presumably seconds — TODO confirm)
# whose pulses are used for the evoked averages.
acceptable_epocs = [
    (2153, 2176),
    (2218, 2227),
    (2233, 2240),
    (2264, 2275),
    (2285, 2297),
    (2412, 2421),
]

# Same windows converted to datetimes so they can be compared with pulse times.
epocs_dt = []
for tdt_start, tdt_end in acceptable_epocs:
    start_as_dt = acr.utils.dt_from_tdt(subject, rec=exp, tdt_time=tdt_start)
    end_as_dt = acr.utils.dt_from_tdt(subject, rec=exp, tdt_time=tdt_end)
    epocs_dt.append((start_as_dt, end_as_dt))

In [None]:
# Collect the pulses whose onset lies inside an acceptable epoch (inclusive at both
# ends), grouped epoch by epoch in the order the epochs are listed.
filtered_pon = []
filtered_poff = []
for window_start, window_end in epocs_dt:
    for pulse_idx, pulse_on in enumerate(pon):
        if not (window_start <= pulse_on <= window_end):
            continue
        filtered_pon.append(pon[pulse_idx])
        filtered_poff.append(poff[pulse_idx])

# Only the first 200 selected pulses are analysed.
ons, offs = filtered_pon[:200], filtered_poff[:200]

In [None]:
# Evoked field potentials: for each probe, average the signal around every selected
# pulse (td(0.08) before onset to td(0.08) after offset), plot each channel's mean with
# an error band, then estimate a 1-D kCSD of the trial-averaged response.
pub()
# Frameless, tick-less axes for schematic panels. These settings do not depend on the
# probe, so they are applied once rather than on every loop iteration.
plt.rcParams.update({
    'axes.spines.bottom': False,
    'axes.spines.left': False,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.grid': False,
    'xtick.major.size': 0,
    'ytick.major.size': 0,
    'figure.facecolor': 'white',
    'axes.facecolor': 'None',
})


def _plot_evoked_trace(mean, band, color):
    """Plot one channel's trial-averaged trace with a shaded ``mean +/- band`` region.

    Returns the new (figure, axes). The y-limits are fixed so all channels share a scale.
    """
    f, ax = plt.subplots(figsize=(8, 4))
    line = ax.plot(mean, color=color, linewidth=4)
    # fill_between(x, y1, y2): exactly one lower and one upper bound.
    ax.fill_between(range(len(mean)), mean - band, mean + band,
                    alpha=0.3, color=line[0].get_color())
    plt.tight_layout()
    ax.set_ylim(-820, 405)
    return f, ax


for probe_to_plot in probe_ord:
    opt = fp.prb(probe_to_plot).ts(pon[0], poff[-1])
    stim_vals = []
    for on, off in zip(ons, offs):
        dat = opt.ts(on - td(0.08), off + td(0.08)).values
        # Windows can come out a sample long; trim so all trials stack into one array.
        # NOTE(review): a window SHORTER than 96 samples would still make np.array fail.
        if dat.shape[0] != 96:
            dat = dat[:96, :]
        stim_vals.append(dat)
    stim_dat = np.array(stim_vals)  # (trial, sample, channel), per the indexing below

    # Replace one channel per probe with the mean of its neighbours
    # (presumably a bad channel on each probe — TODO confirm).
    if probe_to_plot == 'NNXo':
        stim_dat[:, :, 7] = stim_dat[:, :, [5, 6, 8, 9]].mean(axis=2)
    if probe_to_plot == 'NNXr':
        stim_dat[:, :, 6] = stim_dat[:, :, [4, 5, 7, 8]].mean(axis=2)

    stim_means = stim_dat.mean(axis=0)
    stim_stds = stim_dat.std(axis=0)
    stim_sems = stim_stds / np.sqrt(stim_dat.shape[0])

    col = ACR_BLUE if probe_to_plot == 'NNXo' else NNXR_GRAY

    # Every channel: mean +/- 1 SEM.
    for ch in range(16):
        f, ax = _plot_evoked_trace(stim_means[:, ch], stim_sems[:, ch], col)
        f.savefig(f'{PAPER_FIGURE_ROOT}/schems/opto_evoked_plots/{subject}-{exp}-{probe_to_plot}--evoked_fp_CH{ch}.png', dpi=600, bbox_inches='tight', transparent=True, facecolor='none')

    # Top and bottom channels: mean +/- 2.576 SEM (99% normal CI) with stim/scale markers.
    for ch in [0, 15]:
        f, ax = _plot_evoked_trace(stim_means[:, ch], stim_sems[:, ch] * 2.576, col)
        ax.axvline(32, color='red', linestyle='--')
        ax.axvline(65, color='red', linestyle='--')
        ax.axvspan(5, 45, ymin=0.4, ymax=0.6, color='red', alpha=0.2)
        ax.axhspan(-250, 250, xmin=0.92, xmax=0.98, color='green', alpha=0.2)
        f.savefig(f'{PAPER_FIGURE_ROOT}/schems/opto_evoked_plots/{subject}-{exp}-{probe_to_plot}--evoked_fp_CH{ch}_optrode--SCALED_100ms_500mV_WithStim.png', dpi=600, bbox_inches='tight', transparent=True, facecolor='none')

    # -------------------------------------------------------- KCSD -------------------------------------------------------- #
    # 16 electrode positions spaced gdx apart, starting at gdx
    # (units are whatever kCSD's gdx/R_init use — presumably mm, TODO confirm).
    gdx = .05
    ele_pos = np.arange(start=gdx, step=gdx, stop=((gdx*16)+.001))
    k = kcsd.KCSD1D(ele_pos.reshape(-1, 1), stim_means.T, gdx=0.01, sigma=1, R_init=0.23)
    k.L_curve()  # choose the regularisation parameters via kCSD's L-curve method

    # Compute the estimate once; rows are estimation points along the probe and
    # columns are time samples (presumably, per kCSD's values() layout).
    csd_vals = k.values("CSD")
    new_y_coords = np.linspace(ele_pos.min(), ele_pos.max(), csd_vals.shape[1])

    # Labelled DataArray holding the CSD estimate (one row per estimation point).
    csd = xr.DataArray(
        data=np.zeros((len(csd_vals), len(new_y_coords))),
        dims=['channel', 'time'],
        coords={
            'channel': np.arange(len(csd_vals)),
            'time': np.arange(len(new_y_coords)),
        }
    )
    csdl = xr.zeros_like(csd)
    csdl.values = csd_vals

    f, ax = plt.subplots(figsize=(6, 10))
    # aspect > 1 stretches each pixel vertically so the CSD map is taller than wide.
    ax.imshow(csdl.values, cmap='PiYG', origin='upper', alpha=0.8, aspect=(10/6))
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_xticks([])
    ax.set_yticks([])
    f.savefig(f'{PAPER_FIGURE_ROOT}/schems/opto_evoked_plots/{subject}-{exp}-{probe_to_plot}--evoked_fp--KCSD_CSD.png', dpi=600, bbox_inches='tight', transparent=True, facecolor='none')

In [None]:
# Two identical rows of 0..1 values: a horizontal strip for drawing a colormap swatch.
gradient = np.linspace(0, 1, 256)
gradient = np.vstack((gradient, gradient))


def plot_color_gradients(name='PiYG'):
    """Draw colormap ``name`` as a borderless horizontal swatch.

    Presumably used as a stand-alone colour scale for the PiYG kCSD plots,
    whose colorbars are not drawn.

    Parameters
    ----------
    name : str
        A key of ``matplotlib.colormaps``.

    Returns
    -------
    fig, ax
        The newly created Figure and its Axes. (Previously this returned the
        unrelated global ``f`` from an earlier cell instead of ``fig``.)
    """
    fig, ax = plt.subplots(figsize=(10, 1))
    fig.subplots_adjust(top=0.99, bottom=0.01,
                        left=0.01, right=0.99)

    ax.imshow(gradient, aspect='auto', cmap=mpl.colormaps[name])

    # Hide every tick, label and spine: only the colour strip should remain.
    ax.set_axis_off()
    plt.tight_layout()
    return fig, ax


f, ax = plot_color_gradients(name='PiYG')

# Firing Rate

In [None]:
# Publication style: 'fast' first, then the acr_pub style file layered over it,
# with bottom x-ticks switched back on.
plt.style.use(['fast', '/home/kdriessen/gh_master/acr/acr/plot_styles/acr_pub.mplstyle'])
plt.rcParams.update({'xtick.bottom': True})

In [None]:
# Don't change this code
# Builds, for each selected pulse, three (start, end) windows used by the raster /
# firing-rate cell: before = (on - td(.08), on), pulse = (on, off),
# after = (off, off + td(.08)); td(.08) is presumably 80 ms.
ons = pd.to_datetime(ons)
offs = pd.to_datetime(offs)
# MUA restricted to the stimulation period (stim_start .. reb_start).
mua_short = mua.ts(stim_start, reb_start)
# NOTE(review): dto / dtr are not used in the cells shown here.
dto = mua_short.prb('NNXo')['datetime'].to_numpy()
dtr = mua_short.prb('NNXr')['datetime'].to_numpy()
befores = []
pulses = []
afters = []
for on, off in zip(ons, offs):
    befores.append((on-td(.08), on))
    pulses.append((on, off))
    afters.append((off, off+td(.08)))

In [None]:
def _spikes_in_window(spike_times, window):
    """Return the spikes of sorted ``spike_times`` with ``window[0] <= t < window[1]``."""
    start_idx = np.searchsorted(spike_times, window[0])
    end_idx = np.searchsorted(spike_times, window[1])
    return spike_times[start_idx:end_idx]


# Per probe: spike raster over all trials, normalised spike count per 10 ms bin,
# and per-trial spike counts in the before / pulse / after zones.
for prb in probe_ord:
    pub()
    col = ACR_BLUE if prb == 'NNXo' else NNXR_GRAY
    dtprb = pd.to_datetime(mua_short.prb(prb)['datetime'].to_numpy())
    # np.searchsorted requires ascending times; the MUA frame's row order is not
    # guaranteed to be time-sorted here, so sort only if needed.
    if not dtprb.is_monotonic_increasing:
        dtprb = dtprb.sort_values()

    # Slice this probe's spikes into each trial's before / pulse / after window.
    befores_dtprb = [_spikes_in_window(dtprb, b) for b in befores]
    pulses_dtprb = [_spikes_in_window(dtprb, p) for p in pulses]
    afters_dtprb = [_spikes_in_window(dtprb, a) for a in afters]
    n_trials = len(befores_dtprb)  # at most 200 (ons was truncated to 200 upstream)

    # Long-form table of spike offsets (ms) with the 1-based trial number.
    # Offsets are measured from the trial's window start (onset - td(.08)), NOT from
    # the trial's first spike, so all trials share the same time axis and a trial
    # without spikes contributes nothing instead of raising IndexError.
    trial_frames = []
    for trl in range(n_trials):
        full_spikes = pd.DatetimeIndex(np.concatenate([
            befores_dtprb[trl],
            pulses_dtprb[trl],
            afters_dtprb[trl],
        ]))
        window_start = befores[trl][0]
        full_offsets = (full_spikes - window_start).total_seconds() * 1000
        trial_frames.append(pd.DataFrame({
            'spike_times': full_offsets,
            'spike_vals': np.repeat([trl + 1], len(full_spikes)),
        }))
    # Single concat (repeated concat inside the loop was quadratic).
    if trial_frames:
        spike_df_final = pd.concat(trial_frames, ignore_index=True)
    else:
        spike_df_final = pd.DataFrame(columns=['spike_times', 'spike_vals'])

    f, ax = plt.subplots(figsize=(4, 3.5))
    alpha = 0.9 if prb == 'NNXo' else 0.7
    sns.scatterplot(data=spike_df_final, x='spike_times', y='spike_vals', ax=ax, color=col, s=1.5, alpha=alpha)
    ax.set_xticks([0, 50, 100, 150, 200])
    ax.set_yticks([0, 50, 100, 150, 200])
    ax.set_xlim(-1, 240)
    ax.set_ylim(0, 201)

    f.savefig(f'{PAPER_FIGURE_ROOT}/schems/opto_evoked_plots/{subject}-{exp}-{prb}--200Trials--spike_raster.png', dpi=600, bbox_inches='tight', transparent=True, facecolor='none')
    ax.axvspan(81, 160, color='red', alpha=0.2)
    f.savefig(f'{PAPER_FIGURE_ROOT}/schems/opto_evoked_plots/{subject}-{exp}-{prb}--200Trials--spike_raster--STIM_LABELLED.png', dpi=600, bbox_inches='tight', transparent=True, facecolor='none')

    # -------------------------------------------------------- Spike Count by Time -------------------------------------------------------- #
    # ------------------------------------------------------------------------------------------------------------------------------------- #
    # 10 ms bins over 0-240 ms, labelled by their centres (5, 15, ..., 235).
    time_group_starts = np.arange(0, 240, 10)
    time_group_ends = np.arange(10, 250, 10)
    time_group_labels = np.arange(5, 236, 10)
    assert time_group_labels.shape == time_group_starts.shape == time_group_ends.shape
    spike_df_final['time_period'] = pd.cut(
        spike_df_final['spike_times'],
        bins=np.concatenate([time_group_starts, [time_group_ends[-1]]]),
        labels=time_group_labels,
        include_lowest=True,  # keep spikes exactly at the window start (offset 0)
    )
    # observed=False keeps (bin, trial) combinations with no spikes as zero counts;
    # this is pandas' historical default, made explicit because it is deprecated.
    tg_counts = spike_df_final.groupby(['time_period', 'spike_vals'], observed=False).count().reset_index()
    # Normalise so the bin with the highest mean count equals 1.
    tg_counts['spike_times'] = tg_counts['spike_times'] / tg_counts.groupby('time_period', observed=False)['spike_times'].mean().max()
    tg_means = tg_counts.groupby('time_period', observed=False).mean()
    tg_sems = tg_counts.groupby('time_period', observed=False).sem()
    means = tg_means['spike_times'].values
    sems = tg_sems['spike_times'].values
    tps = tg_means.index.values

    lrg()
    f, ax = plt.subplots(figsize=(8, 2.5))
    plt.errorbar(tps,
                means,
                yerr=sems,
                color=col,
                linewidth=0,
                marker='o',
                markersize=0,
                markerfacecolor=col,
                markeredgecolor=col,
                capsize=4,
                capthick=3.5,
                elinewidth=3.5,
                ecolor=col,
                alpha=0.9,  # This alpha applies to the error bars
                zorder=1)

    # Add a separate line plot for the connecting line with different alpha
    plt.plot(tps, means,
            color=col,
            linewidth=4.5,
            alpha=0.6)  # Adjust this value for the desired line transparency

    ax.set_yticks([0, 0.5, 1])
    ax.set_ylim(0, 1.1)
    ax.set_yticks([0.25, 0.75], minor=True)

    ax.set_xticks([0, 50, 100, 150, 200])
    ax.set_xlim(0, 240)
    plt.tight_layout()
    f.savefig(f'{PAPER_FIGURE_ROOT}/schems/opto_evoked_plots/{subject}-{exp}-{prb}--SpikeCount-by-Time.png', dpi=600, bbox_inches='tight', transparent=True, facecolor='none')

    # Assign each binned spike to a zone by its bin centre (centres never equal 81/160).
    spike_df_final = spike_df_final.dropna(subset=['time_period']).copy()
    spike_df_final['time_period'] = spike_df_final['time_period'].astype(int)
    spike_df_final['time_zone'] = 'None'

    spike_df_final.loc[spike_df_final['time_period'] < 81, 'time_zone'] = 'before'
    spike_df_final.loc[(spike_df_final['time_period'] > 81) & (spike_df_final['time_period'] < 160), 'time_zone'] = 'pulse'
    spike_df_final.loc[spike_df_final['time_period'] > 160, 'time_zone'] = 'after'

    # Per-trial spike counts per zone. Reindex over every (zone, trial) pair so a trial
    # with no spikes in a zone (e.g. fully silenced during the pulse) counts as 0
    # instead of silently disappearing from that zone's distribution.
    zone_order = ['before', 'pulse', 'after']
    full_index = pd.MultiIndex.from_product(
        [zone_order, range(1, n_trials + 1)], names=['time_zone', 'spike_vals'])
    sdf_counts = (
        spike_df_final.groupby(['time_zone', 'spike_vals']).count()
        .reindex(full_index, fill_value=0)
        .reset_index()
    )
    sdf_counts['spike_times_rel'] = sdf_counts['spike_times'] / sdf_counts['spike_times'].max()

    # FINAL VIOLIN PLOT
    f, ax = plt.subplots(figsize=(8, 3.5))
    sns.violinplot(data=sdf_counts, x='time_zone', y='spike_times_rel', ax=ax,
                color=col, alpha=0.9,
                order=zone_order,
                hue='time_zone', palette=[NNXR_GRAY, ACR_BLUE, NNXR_GRAY],
                hue_order=zone_order,
                split=False, width=0.8,
                inner_kws=dict(box_width=10, whis_width=3, color=".8"))
    ax.set_ylim(-0.1, 1.1)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticks([0.25, 0.75], minor=True)
    plt.tight_layout()
    f.savefig(f'{PAPER_FIGURE_ROOT}/schems/opto_evoked_plots/{subject}-{exp}-{prb}--SpikeCount-by-ZONE--Violin.png', dpi=600, bbox_inches='tight', transparent=True, facecolor='none')