# Funcs + Imports

In [None]:
#-------------------------- Standard Imports --------------------------#
import kdephys as kde
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import acr
from acr.plots import pub, lrg
from acr.utils import NNXR_GRAY, SOM_BLUE
lrg()
plt.style.use('fast')
from scipy import stats
# ---------------------------- EXTRAS --------------------------------#
from kdephys.plot.main import _title, bp_plot
import kdephys.utils.spectral as sp
bands = sp.bands
import matplotlib.animation as animation
from matplotlib.collections import PathCollection
from IPython.display import HTML
import pickle

import warnings
warnings.filterwarnings('ignore')
def _plt_style(style_path='/home/kdriessen/gh_master/kdephys/kdephys/plot/acr_plots.mplstyle'):
    """Apply matplotlib's 'fast' style, then layer a project style sheet on top.

    Parameters
    ----------
    style_path : str, optional
        Style name or path to a ``.mplstyle`` file applied after 'fast'.
        Defaults to the kdephys ``acr_plots.mplstyle`` in the author's local
        checkout; pass another path when running on a different machine.
    """
    plt.style.use('fast')
    plt.style.use(style_path)


In [None]:
def quick_trace_plot(data, times, stim_starts, stim_ends, color=SOM_BLUE, hspace=-0.6, figsize=(28, 10),
                     ylim=(-830, 870)):
    """Quick plot of raw data with stimulations

    Draws one stacked subplot per channel and shades every stimulation window
    on every subplot.

    Parameters
    ----------
    data : np.ndarray
        Raw data to plot, of shape (n_channels, n_samples)
    times : np.ndarray
        Times of the data, of shape (n_samples,)
    stim_starts : np.ndarray
        Start times of stimulations
    stim_ends : np.ndarray
        End times of stimulations, paired positionally with ``stim_starts``
    color : str, optional
        Color of the traces, by default SOM_BLUE
    hspace : float, optional
        Space between traces, by default -0.6 (negative values overlap subplots)
    figsize : tuple, optional
        Size of the figure, by default (28, 10)
    ylim : tuple of (float, float) or None, optional
        y-limits applied to every channel, by default (-830, 870).
        Pass None to keep matplotlib's autoscaled y-limits.

    Returns
    -------
    f : matplotlib.figure.Figure
        The figure.
    ax : np.ndarray of Axes, or Axes
        The axes exactly as returned by ``plt.subplots`` (a single Axes when
        ``data`` has one channel).

    Notes
    -----
    Modifies the global ``plt.rcParams`` (spines, grid, tick size, face
    colors); the change persists for figures created afterwards.
    """
    plt.rcParams['axes.spines.bottom'] = False
    plt.rcParams['axes.spines.left'] = False
    plt.rcParams['axes.spines.right'] = False
    plt.rcParams['axes.spines.top'] = False
    plt.rcParams['axes.grid'] = False
    plt.rcParams['xtick.major.size'] = 0
    plt.rcParams['figure.facecolor'] = 'white'
    plt.rcParams['axes.facecolor'] = 'None'

    f, ax = plt.subplots(data.shape[0], 1, figsize=figsize)
    # plt.subplots returns a bare Axes (not an array) for a single row.
    axes = ax if data.shape[0] > 1 else [ax]

    for i, a in enumerate(axes):
        a.plot(times, data[i, :], color=color)
    plt.subplots_adjust(hspace=hspace)

    # Limits are applied per Axes (not on the array of Axes) and regardless of
    # whether any stimulation falls in the window.
    for a in axes:
        a.set_xlim(times[0], times[-1])
        if ylim is not None:
            a.set_ylim(*ylim)
        for on, off in zip(stim_starts, stim_ends):
            a.axvspan(on, off, color='cornflowerblue', ymin=0.325, ymax=0.712, alpha=0.5)
    return f, ax
def simple_raster(data, stim_start, stim_ends, xname='datetime', yname='negchan', color='blue', figsize=(28, 4)):
    """Static raster: one dot per row of ``data`` at (``xname``, ``yname``), with stims shaded.

    Parameters
    ----------
    data : pandas.DataFrame
        Table containing at least the ``xname`` and ``yname`` columns.
    stim_start, stim_ends : list-like
        Stimulation start and end times, paired positionally; each pair is
        shaded with ``axvspan``.
    xname, yname : str, optional
        Columns used for the x and y coordinates.
    color : str, optional
        Dot color, by default 'blue'.
    figsize : tuple, optional
        Figure size, by default (28, 4).

    Returns
    -------
    f, ax : matplotlib Figure and Axes

    Raises
    ------
    AssertionError
        If ``xname`` or ``yname`` is not a column of ``data``.

    Notes
    -----
    Updates the global ``plt.rcParams`` before drawing.
    """
    plt.rcParams.update({
        'axes.spines.bottom': False,
        'axes.spines.left': False,
        'axes.spines.right': False,
        'axes.spines.top': False,
        'axes.grid': False,
        'xtick.major.size': 0,
        'figure.facecolor': 'white',
        'axes.facecolor': 'None',
    })

    for column, label in ((xname, 'xname'), (yname, 'yname')):
        assert column in data.columns, f"{label} {column} not in data"

    fig, axis = plt.subplots(figsize=figsize)
    axis = sns.scatterplot(data, x=xname, y=yname, linewidth=0, alpha=0.7, s=60, ax=axis, color=color)
    axis.set_yticks([])
    axis.set_xticks([])
    fig.tight_layout()

    for on, off in zip(stim_start, stim_ends):
        axis.axvspan(on, off, color='cornflowerblue', alpha=0.5)
    return fig, axis

In [None]:
def animated_trace_numpy(data, times, stim_starts, stim_ends, color=SOM_BLUE, 
                         hspace=-0.6, figsize=(14, 5), fps=30, duration=7, 
                         save_path=None, ylims=None):
    """
    Create an animated plot of neural traces with stimulus overlays.

    Each channel gets its own stacked subplot. The traces are revealed left to
    right over the course of the animation (the last frame shows the full
    trace), and each stimulation span is added to every channel once the
    revealed trace reaches the stimulation's start.

    Parameters
    ----------
    data : np.ndarray
        Raw data to plot, of shape (n_channels, n_samples)
    times : np.ndarray
        Times of the data, of shape (n_samples,). Either ``pd.Timestamp``
        values, or values directly comparable with the stim times.
    stim_starts : list-like
        Start times of stimulations
    stim_ends : list-like
        End times of stimulations, paired positionally with ``stim_starts``
    color : str, optional
        Color of the traces, by default SOM_BLUE
    hspace : float, optional
        Space between traces, by default -0.6 (negative values overlap subplots)
    figsize : tuple, optional
        Size of the figure, by default (14, 5)
    fps : int, optional
        Frames per second for the animation, by default 30
    duration : float, optional
        Duration of the animation in seconds, by default 7
    save_path : str, optional
        Path to save the animation, by default None. A ``.mov`` path is
        written with ffmpeg's ``prores_ks`` codec (profile 4444, alpha pixel
        format); any other path uses ffmpeg's defaults.
    ylims : sequence of (float, float), optional
        Per-channel (y_min, y_max), indexed by channel. By default None, in
        which case each channel's own data min/max is used.

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        The animation object

    Notes
    -----
    Modifies the global ``plt.rcParams`` (left ticks and all spines off).
    """
    # These rcParams are read when Axes are created, so they must be set
    # before plt.subplots() to affect this figure.
    plt.rcParams['ytick.left'] = False
    plt.rcParams['axes.spines.bottom'] = False
    plt.rcParams['axes.spines.left'] = False
    plt.rcParams['axes.spines.right'] = False
    plt.rcParams['axes.spines.top'] = False

    n_channels = data.shape[0]
    f, axes = plt.subplots(n_channels, 1, figsize=figsize)
    if n_channels == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single row
    plt.subplots_adjust(hspace=hspace)

    # Numeric clock used only to decide when each stimulation span appears.
    # Timestamps are converted to seconds since the first sample; any other
    # type is assumed to be directly comparable with the stim times.
    if isinstance(times[0], pd.Timestamp):
        t_start = times[0]
        numeric_times = np.array([(t - t_start).total_seconds() for t in times])
        stim_onsets = np.array([(t - t_start).total_seconds() for t in stim_starts])
    else:
        numeric_times = times
        stim_onsets = np.array([t for t in stim_starts])
    n_samples = len(numeric_times)

    # Spans are drawn with the original (unconverted) stim times so they sit on
    # the same x-axis as the traces.
    stim_windows = list(zip(stim_starts, stim_ends))

    y_padding = .001  # pad each side by 0.1% of the channel's y-range
    for i, ax in enumerate(axes):
        if ylims is not None:
            y_min, y_max = ylims[i]
        else:
            y_min, y_max = np.min(data[i, :]), np.max(data[i, :])
        y_range = y_max - y_min
        ax.set_ylim(y_min - y_range * y_padding, y_max + y_range * y_padding)
        ax.set_xlim(times[0], times[-1])
        ax.set_yticks([])

    # One initially-empty line per channel.
    lines = [ax.plot([], [], color=color, lw=1.5)[0] for ax in axes]

    spans = []           # every span artist drawn so far (one per stim per channel)
    drawn_stims = set()  # indices of stimuli whose spans already exist

    n_frames = max(int(fps * duration), 1)

    def init():
        for line in lines:
            line.set_data([], [])
        return lines

    def update(frame):
        # Map frames 0..n_frames-1 onto samples 0..n_samples-1 so the final
        # frame shows the complete trace.
        progress = frame / max(n_frames - 1, 1)
        current_idx = min(int(progress * n_samples), n_samples - 1)

        for i, line in enumerate(lines):
            line.set_data(times[:current_idx + 1], data[i, :current_idx + 1])

        current_time = numeric_times[current_idx]
        for i, (onset, (on, off)) in enumerate(zip(stim_onsets, stim_windows)):
            if i not in drawn_stims and current_time >= onset:
                drawn_stims.add(i)
                for ax in axes:
                    spans.append(ax.axvspan(on, off, color='cornflowerblue',
                                            ymin=0.325, ymax=0.712, alpha=0.4))

        # With blit=True, every artist that must remain visible is returned.
        return lines + spans

    anim = animation.FuncAnimation(
        f, update, frames=n_frames, init_func=init, blit=True, interval=1000/fps
    )

    if save_path:
        if save_path.endswith('.mov'):
            # ProRes 4444 with an alpha pixel format keeps transparency.
            writer = animation.FFMpegWriter(
                fps=fps,
                metadata=dict(artist='Matplotlib'),
                bitrate=1800,
                codec='prores_ks',
                extra_args=['-profile:v', '4444', '-pix_fmt', 'yuva444p10le']
            )
            anim.save(save_path, dpi=300, writer=writer)
        else:
            anim.save(save_path, writer='ffmpeg', fps=fps)

    plt.close(f)
    return anim
def animated_raster(data, stim_start, stim_ends, xname='datetime', yname='negchan', color='blue', 
                    figsize=(14, 2), fps=30, duration=7, save_path=None, trange=None, t0=None):
    """
    Create an animated raster plot where data points and stimulation markers appear in real-time.

    The x-axis is in seconds relative to a time origin. Points and stimulation
    spans are revealed as the animation clock passes them, and the final frame
    shows everything up to the end of the x-range.

    Parameters:
    -----------
    data : pandas DataFrame
        DataFrame containing spike times and channel information
    stim_start : list-like
        List of stimulus start times
    stim_ends : list-like
        List of stimulus end times, paired positionally with ``stim_start``
    xname : str
        Column name for x-axis (time)
    yname : str
        Column name for y-axis (channels)
    color : str
        Color for spikes
    figsize : tuple
        Figure size
    fps : int
        Frames per second for the animation
    duration : float
        Duration of the animation in seconds
    save_path : str, optional
        If provided, save the animation to this path. ``.mov`` paths are
        written with ffmpeg's ``prores_ks`` codec (profile 4444, alpha pixel
        format).
    trange : float, optional
        Length of the x-axis in seconds. By default, this is the span from the
        time origin to the last data point.
    t0 : same type as ``data[xname]`` values, optional
        Time origin for the x-axis. By default, this is the earliest value in
        ``data[xname]``. Pass the start of the plotted window to align the
        raster with another animation that starts at that time.

    Returns:
    --------
    animation : matplotlib.animation.FuncAnimation
        The animation object

    Notes
    -----
    Modifies the global ``plt.rcParams``, including making saved figures
    transparent; the changes persist for later figures.
    """
    plt.rcParams['axes.spines.bottom'] = False
    plt.rcParams['axes.spines.left'] = False
    plt.rcParams['axes.spines.right'] = False
    plt.rcParams['axes.spines.top'] = False
    plt.rcParams['axes.grid'] = False
    plt.rcParams['xtick.major.size'] = 0
    plt.rcParams['figure.facecolor'] = 'None'
    plt.rcParams['axes.facecolor'] = 'None'
    # Make plot completely transparent
    plt.rcParams['savefig.transparent'] = True
    plt.rcParams['figure.frameon'] = False
    plt.rcParams['axes.edgecolor'] = 'None'
    plt.rcParams['figure.subplot.wspace'] = 0
    plt.rcParams['figure.subplot.hspace'] = 0

    assert xname in data.columns, f"xname {xname} not in data"
    assert yname in data.columns, f"yname {yname} not in data"

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_yticks([])
    ax.set_xticks([])
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # Express every time as seconds (or plain numeric offsets) from the origin.
    x = data[xname]
    origin = x.min() if t0 is None else t0
    if isinstance(x.iloc[0], pd.Timestamp):
        time_range = (x.max() - origin).total_seconds()
        data_times = [(t - origin).total_seconds() for t in x]
        stim_starts_sec = [(t - origin).total_seconds() for t in stim_start]
        stim_ends_sec = [(t - origin).total_seconds() for t in stim_ends]
    else:
        # Assuming numeric time values
        time_range = x.max() - origin
        data_times = [t - origin for t in x]
        stim_starts_sec = [t - origin for t in stim_start]
        stim_ends_sec = [t - origin for t in stim_ends]

    # Arrays allow a vectorized visibility mask per frame instead of a Python
    # scan with per-row .iloc lookups.
    data_times = np.asarray(data_times, dtype=float)
    y_values = data[yname].to_numpy()

    if trange is not None:
        time_range = trange

    ax.set_xlim(0, time_range)
    ax.set_ylim(data[yname].min() - 0.5, data[yname].max() + 0.5)

    scatter = ax.scatter([], [], alpha=0.85, s=35, color=color, linewidth=0)

    spans = []           # span artists drawn so far, in the order they appeared
    drawn_stims = set()  # indices of stimuli already drawn (robust to unsorted stims)

    n_frames = max(int(fps * duration), 1)

    def init():
        scatter.set_offsets(np.empty((0, 2)))
        return [scatter]

    def animate(frame):
        # Frames 0..n_frames-1 map onto 0..time_range, so the final frame
        # reveals every point up to the end of the x-range.
        current_time = time_range * frame / max(n_frames - 1, 1)

        visible = data_times <= current_time
        if visible.any():
            scatter.set_offsets(np.column_stack([data_times[visible], y_values[visible]]))

        for i, (start, end) in enumerate(zip(stim_starts_sec, stim_ends_sec)):
            if i not in drawn_stims and start <= current_time:
                drawn_stims.add(i)
                spans.append(ax.axvspan(start, end, color='cornflowerblue', alpha=0.4))

        # With blit=True, every artist that must remain visible is returned.
        return [scatter] + spans

    anim = animation.FuncAnimation(
        fig, animate, init_func=init, frames=n_frames,
        interval=1000/fps, blit=True
    )

    if save_path:
        if save_path.endswith('.mov'):
            # ProRes 4444 with an alpha pixel format keeps transparency.
            writer = animation.FFMpegWriter(
                fps=fps,
                metadata=dict(artist='Matplotlib'),
                bitrate=1800,
                codec='prores_ks',
                extra_args=['-profile:v', '4444', '-pix_fmt', 'yuva444p10le']
            )
            anim.save(save_path, dpi=300, writer=writer)
        else:
            anim.save(save_path, writer='ffmpeg', fps=fps)

    plt.close(fig)
    return anim

# Check + Animate

Candidate animation start times (the value assigned to `start_time` below):

OFF_induction --> 16890

SD-wake --> 92



In [None]:
subject, exp = 'ACR_25', 'swi'

#---------------- Adjust Parameters Here -----------------#
stores = ['NNXo', 'NNXr']  # store names; combined with `exp` to build sort ids
rel_state = 'NREM'
#---------------------------------------------------------#
# Folder that the rendered animation files are written into
notebook_save_root = (
    '/Volumes/opto_loc/Data/ACR_PROJECT_MATERIALS/'
    'plots_presentations_etc/paper_figures/animations/raw_data'
)

In [None]:
# ----------------------------------------- subject_info + Hypno -----------------------------------------
# Load everything the cells below need for this subject/experiment. All loaders come from
# the project's `acr` package; their return types are not visible in this notebook.
h = acr.io.load_hypno_full_exp(subject, exp)
hd = acr.hypnogram_utils.create_acr_hyp_dict(subject, exp)
si = acr.info_pipeline.load_subject_info(subject)
sort_ids = [f'{exp}-{store}' for store in stores]  # one id per store, e.g. 'swi-NNXo', 'swi-NNXr'
recordings = acr.info_pipeline.get_exp_recs(subject, exp)
stim_store = si['stim-exps'][exp][0]  # first entry listed under si['stim-exps'][exp]

# Experiment landmark times. NOTE(review): update=True is forwarded to the loader; whether it
# recomputes/overwrites stored landmarks is not visible here — confirm.
sd_true_start, stim_start, stim_end, rebound_start, full_exp_start = acr.info_pipeline.get_sd_exp_landmarks(subject, exp, update=True)


# Data used by the plotting/animation cells below: `mua` (sliced with .prb/.ts, then
# .to_pandas), `fp` (sliced with .prb/.ts), and the pulse onset/offset times `pon` / `poff`.
mua = acr.mua.load_concat_peaks_df(subject, exp)
fp = acr.io.load_concat_raw_data(subject, recordings)
pon, poff = acr.stim.get_individual_pulse_times(subject, exp)


In [None]:
anim_name = 'SD_wake'
probe_to_animate = 'NNXr'
# Start time (as passed to acr.utils.dt_from_tdt) for each named animation window.
# Looking it up by name keeps the output file name and the animated data in sync;
# previously 'SD_wake' was paired with 16890, the time noted for OFF_induction.
anim_start_times = {'OFF_induction': 16890, 'SD_wake': 92}
start_time = anim_start_times[anim_name]
duration = '10s'
st = acr.utils.dt_from_tdt(subject, exp, start_time)

In [None]:
# PLOT THE RASTER
# Static preview of the MUA for `probe_to_animate` over [st, st + duration), with stim pulses shaded.
m2p = mua.prb(probe_to_animate).ts(st, st+pd.Timedelta(duration)).to_pandas()
# Pulse onsets in [st, end) and pulse offsets in (st, end], selected independently.
# NOTE(review): a pulse straddling either window edge contributes an onset without its offset
# (or vice versa). simple_raster pairs the two with zip(), which would then shade mismatched
# on/off pairs. Confirm that pon/poff are paired per pulse; if so, select both with the onset mask.
pon2p = pon[(pon>=st)&(pon<st+pd.Timedelta(duration))]
poff2p = poff[(poff>st)&(poff<=st+pd.Timedelta(duration))]

f, ax = simple_raster(m2p, pon2p, poff2p, color=SOM_BLUE, figsize=(28, 4))


In [None]:
# PLOT THE TRACES TO BE ANIMATED
# Static preview of the field-potential traces for the same window.
# Reuses pon2p/poff2p computed in the raster cell above.
plt.rcParams['ytick.left'] = False
fp2p = fp.prb(probe_to_animate).ts(st, st+pd.Timedelta(duration))

# quick_trace_plot expects (n_channels, n_samples); the transpose suggests fp2p.data
# is (samples, channels) — TODO confirm.
data = fp2p.data.T
times = fp2p.datetime.data
f, ax =quick_trace_plot(data, times, pon2p, poff2p)
# Apply the same fixed x/y limits to every channel's axes.
for a in ax:
    a.set_xlim(times[0], times[-1])
    a.set_ylim(-830, 870)

In [None]:
#limits = pickle.load(open(f'{subject}_{exp}_nnxo_lims.pkl', 'rb'))
# Same fixed (y_min, y_max) display limits for each of the 16 channels
limits = [(-830, 870)] * 16

In [None]:
# Re-select the window so this cell does not depend on the preview cells.
fp2p = fp.prb(probe_to_animate).ts(st, st+pd.Timedelta(duration))
data = fp2p.data.T
times = fp2p.datetime.data
# NOTE(review): these are the same independent onset/offset masks as in the raster preview.
# A pulse straddling a window edge leaves pon2p/poff2p misaligned when they are zipped into spans.
pon2p = pon[(pon>=st)&(pon<st+pd.Timedelta(duration))]
poff2p = poff[(poff>st)&(poff<=st+pd.Timedelta(duration))]

# Trace color depends on the probe: SOM_BLUE for NNXo, NNXR_GRAY for anything else.
col = SOM_BLUE if probe_to_animate == 'NNXo' else NNXR_GRAY
# Create the animation
# NOTE(review): animated_trace_numpy is defined twice in this notebook; which definition
# runs here depends on the order in which cells were executed.
trace_anim = animated_trace_numpy(
    data=data,
    times=times,
    stim_starts=pon2p,
    stim_ends=poff2p,
    color=col,
    fps=30,
    duration=int(duration.split('s')[0]),  # seconds parsed from `duration`, e.g. '10s' -> 10
    save_path=f'{notebook_save_root}/{anim_name}--LFP--{probe_to_animate}.mov',
    ylims=limits  # per-channel (y_min, y_max) from the `limits` cell
)

# To display in the notebook instead of saving
# HTML(trace_anim.to_jshtml())

In [None]:
# select data for raster plot
m2p = mua.prb(probe_to_animate).ts(st, st+pd.Timedelta(duration)).to_pandas()
pon2p = pon[(pon>=st)&(pon<st+pd.Timedelta(duration))]
poff2p = poff[(poff>st)&(poff<=st+pd.Timedelta(duration))]
# Use the total duration of the field-potential data as the raster's x-range,
# so both animations span the same length of time.
timespd = pd.to_datetime(times)
tr = (timespd.max() - timespd.min()).total_seconds()
# NOTE(review): animated_raster measures time from the earliest spike in m2p (data[xname].min()),
# while the LFP animation's x-axis starts at times[0]. `tr` only fixes the x-range, so the two
# movies can be offset by the gap between the window start and the first spike — confirm.

# Create the animated raster plot
# Generate the animation - adjust parameters as needed
anim = animated_raster(
    data=m2p, 
    stim_start=pon2p, 
    stim_ends=poff2p, 
    color=col, 
    figsize=(14, 2),
    fps=30,
    duration=int(duration.split('s')[0]),  # seconds parsed from `duration` (e.g. '10s'), same as the LFP animation
    save_path=f'{notebook_save_root}/{anim_name}--RASTER--{probe_to_animate}.mov',  # Comment out to preview instead of saving
    trange=tr
)

# Display the animation in the notebook
# HTML(anim.to_jshtml())

# Uncomment the above line to display the animation in the notebook
# Or use the save_path parameter to save to a file

# Generate the pickle file of per-channel y-limits

In [None]:
def animated_trace_numpy(data, times, stim_starts, stim_ends, color=SOM_BLUE, 
                         hspace=-0.6, figsize=(28, 10), fps=30, duration=7, 
                         save_path=None, ylims=None, lims_path=None):
    """
    Create an animated plot of neural traces with stimulus overlays, and pickle
    each channel's data-derived y-limits.

    Each channel gets its own stacked subplot. The traces are revealed left to
    right over the course of the animation (the last frame shows the full
    trace), and each stimulation span is added to every channel once the
    revealed trace reaches the stimulation's start.

    Parameters
    ----------
    data : np.ndarray
        Raw data to plot, of shape (n_channels, n_samples)
    times : np.ndarray
        Times of the data, of shape (n_samples,). Either ``pd.Timestamp``
        values, or values directly comparable with the stim times.
    stim_starts : list-like
        Start times of stimulations
    stim_ends : list-like
        End times of stimulations, paired positionally with ``stim_starts``
    color : str, optional
        Color of the traces, by default SOM_BLUE
    hspace : float, optional
        Space between traces, by default -0.6 (negative values overlap subplots)
    figsize : tuple, optional
        Size of the figure, by default (28, 10)
    fps : int, optional
        Frames per second for the animation, by default 30
    duration : float, optional
        Duration of the animation in seconds, by default 7
    save_path : str, optional
        Path to save the animation with ffmpeg, by default None. When
        given, the figure is closed after saving.
    ylims : sequence of (float, float), optional
        Per-channel (y_min, y_max) used for display. By default None, in which
        case each channel's data min/max is used. The pickled limits are always
        the data-derived ones.
    lims_path : str, optional
        Where to pickle the list of per-channel (data_min, data_max) tuples.
        Defaults to ``f"{subject}_{exp}_nnxo_lims.pkl"``, built from the
        module-level ``subject`` and ``exp`` variables.

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        The animation object

    Notes
    -----
    Modifies the global ``plt.rcParams`` (left ticks and all spines off) and
    writes the pickle file described under ``lims_path``.
    """
    # These rcParams are read when Axes are created, so they must be set
    # before plt.subplots() to affect this figure.
    plt.rcParams['ytick.left'] = False
    plt.rcParams['axes.spines.bottom'] = False
    plt.rcParams['axes.spines.left'] = False
    plt.rcParams['axes.spines.right'] = False
    plt.rcParams['axes.spines.top'] = False

    n_channels = data.shape[0]
    f, axes = plt.subplots(n_channels, 1, figsize=figsize)
    if n_channels == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single row
    plt.subplots_adjust(hspace=hspace)

    # Numeric clock used only to decide when each stimulation span appears.
    if isinstance(times[0], pd.Timestamp):
        t_start = times[0]
        numeric_times = np.array([(t - t_start).total_seconds() for t in times])
        stim_onsets = np.array([(t - t_start).total_seconds() for t in stim_starts])
    else:
        numeric_times = times
        stim_onsets = np.array([t for t in stim_starts])
    n_samples = len(numeric_times)

    # Spans are drawn with the original (unconverted) stim times.
    stim_windows = list(zip(stim_starts, stim_ends))

    nnxo_lims = []      # data-derived (min, max) per channel; this is what gets pickled
    y_padding = .001    # pad each side by 0.1% of the y-range
    for i, ax in enumerate(axes):
        channel_data = data[i, :]
        data_min, data_max = np.min(channel_data), np.max(channel_data)
        nnxo_lims.append((data_min, data_max))
        y_min, y_max = (data_min, data_max) if ylims is None else ylims[i]
        y_range = y_max - y_min
        ax.set_ylim(y_min - y_range * y_padding, y_max + y_range * y_padding)
        ax.set_xlim(times[0], times[-1])
        ax.set_yticks([])

    if lims_path is None:
        lims_path = f"{subject}_{exp}_nnxo_lims.pkl"
    # Use a dedicated name for the file handle: reusing `f` would replace the
    # figure that FuncAnimation and plt.close need below.
    with open(lims_path, 'wb') as lims_file:
        pickle.dump(nnxo_lims, lims_file)

    lines = [ax.plot([], [], color=color, lw=1)[0] for ax in axes]

    spans = []           # every span artist drawn so far (one per stim per channel)
    drawn_stims = set()  # indices of stimuli whose spans already exist

    n_frames = max(int(fps * duration), 1)

    def init():
        for line in lines:
            line.set_data([], [])
        return lines

    def update(frame):
        # Map frames 0..n_frames-1 onto samples 0..n_samples-1 so the final
        # frame shows the complete trace.
        progress = frame / max(n_frames - 1, 1)
        current_idx = min(int(progress * n_samples), n_samples - 1)

        for i, line in enumerate(lines):
            line.set_data(times[:current_idx + 1], data[i, :current_idx + 1])

        current_time = numeric_times[current_idx]
        for i, (onset, (on, off)) in enumerate(zip(stim_onsets, stim_windows)):
            if i not in drawn_stims and current_time >= onset:
                drawn_stims.add(i)
                for ax in axes:
                    spans.append(ax.axvspan(on, off, color='cornflowerblue',
                                            ymin=0.325, ymax=0.712, alpha=0.5))

        # With blit=True, every artist that must remain visible is returned.
        return lines + spans

    anim = animation.FuncAnimation(
        f, update, frames=n_frames, init_func=init, blit=True, interval=1000/fps
    )

    if save_path:
        anim.save(save_path, writer='ffmpeg', fps=fps)
        plt.close(f)

    return anim
