In [None]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LinearSegmentedColormap
from pynwb import NWBFile, TimeSeries, NWBHDF5IO
from scipy.io import loadmat
from scipy.stats import zscore
import ast
from utils.plot_utils import combine_pdf_big
from utils.beh_functions import session_dirs, parseSessionID, load_model_dv, makeSessionDF, get_session_tbl, get_unit_tbl, get_history_from_nwb
from utils.ephys_functions import*
from utils.lick_utils import load_licks
from utils.combine_tools import apply_qc, to_str_intlike

from open_ephys.analysis import Session
from pathlib import Path
import glob

import json
import seaborn as sns
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm
import re
from aind_dynamic_foraging_basic_analysis.plot.plot_foraging_session import plot_foraging_session
from aind_dynamic_foraging_data_utils.nwb_utils import load_nwb_from_filename
from hdmf_zarr.nwb import NWBZarrIO

import pandas as pd
import pickle
import scipy.stats as stats
from joblib import Parallel, delayed
from multiprocessing import Pool
from functools import partial
import time
import shutil 
from aind_ephys_utils import align
%matplotlib inline

In [None]:
# Session tables from both data sources, stacked into a single frame.
dfs = [
    pd.read_csv('/root/capsule/code/data_management/session_assets.csv'),
    pd.read_csv('/root/capsule/code/data_management/hopkins_session_assets.csv'),
]
df = pd.concat(dfs)

# Keep only behavior sessions (IDs starting with 'behavior'); the animal ID is
# the second '_'-separated token of the session ID.
session_list = [sess for sess in df['session_id'].values.tolist()
                if str(sess).startswith('behavior')]
ani_list = [str(sess).split('_')[1] for sess in session_list]
ani_session_df = pd.DataFrame({'animal': ani_list, 'session_id': session_list})

In [None]:
def assign_lick_label(start_time, end_time, peak_time, licks_L, licks_R):
    """Label a time window by lick side and find the lick nearest its peak.

    Parameters
    ----------
    start_time, end_time : float
        Inclusive window bounds, in the same time base as the lick arrays.
    peak_time : float
        Peak offset relative to ``start_time``; the reference time used for
        the nearest-lick search is ``start_time + peak_time``.
    licks_L, licks_R : array-like
        Left / right lick timestamps. Either (or both) may be empty.

    Returns
    -------
    has_lick : int
        1 if any right lick lies in [start_time, end_time], otherwise -1 if any
        left lick does, otherwise 0 (right takes precedence when both do).
    min_time_diff : float
        Signed offset (lick time - reference time) of the lick, from either
        side, closest to the reference time. Ties resolve to the right side.
        NaN when both lick arrays are empty.
    """
    licks_L = np.asarray(licks_L)
    licks_R = np.asarray(licks_R)
    ref_time = start_time + peak_time

    has_lick = 0
    if np.any((licks_L >= start_time) & (licks_L <= end_time)):
        has_lick = -1
    if np.any((licks_R >= start_time) & (licks_R <= end_time)):
        has_lick = 1

    def _nearest_offset(licks):
        # Signed offset of the lick closest to ref_time; NaN if there are none
        # (np.argmin raises on an empty array).
        if licks.size == 0:
            return np.nan
        return licks[np.argmin(np.abs(licks - ref_time))] - ref_time

    diff_L = _nearest_offset(licks_L)
    diff_R = _nearest_offset(licks_R)
    if licks_R.size == 0:
        min_time_diff = diff_L  # also NaN when both sides are empty
    elif licks_L.size > 0 and np.abs(diff_L) < np.abs(diff_R):
        min_time_diff = diff_L
    else:
        # Strict '<' above means ties go to the right side.
        min_time_diff = diff_R
    return has_lick, min_time_diff
# NOTE(review): a later notebook cell redefines `pairplot_color_code`; after
# Run All that later definition shadows this one. Keep only one copy.
def pairplot_color_code(start, max_point, color_code, feature_name, fig, subplot_gs,
                        v_range=None, center_xys=None, bins_xy=50, bins_c=50,
                        equal_aspect=True):
    """Plot lick end points colored by a feature, with marginal histograms.

    Inside ``subplot_gs`` this draws:

    * a main panel with a faint segment from each start point to its end
      point, start points in gray and end points colored by ``color_code``;
    * histograms of the end-point coordinates above and to the right of it,
      aligned with the main panel's limits;
    * a horizontal colorbar, and a log-scale histogram of ``color_code`` with
      the color limits marked.

    Parameters
    ----------
    start, max_point : array-like, shape (2, N)
        Start / end points. Row 0 goes on the horizontal axis (labelled "Y"),
        row 1 on the vertical axis (labelled "X").
    color_code : array-like of float, length N
        Values used to color the end points. Non-finite values are left out
        of the ``color_code`` histogram.
    feature_name : str
        Used in the titles and the colorbar label.
    fig : matplotlib.figure.Figure
        Figure to draw into.
    subplot_gs : matplotlib.gridspec.SubplotSpec
        Region of ``fig`` that is subdivided for the panels.
    v_range : (vmin, vmax), optional
        Color limits. Defaults to the 5th and 95th percentiles of
        ``color_code`` (NaNs ignored).
    center_xys : array-like, shape (K, 2) or (2,), optional
        Points drawn as red crosses. Column 0 goes on the horizontal axis and
        column 1 on the vertical axis. A single pair is also accepted.
    bins_xy : int
        Number of bins for each marginal histogram.
    bins_c : int
        Number of bin *edges* (i.e. ``bins_c - 1`` bins) for the
        ``color_code`` histogram.
    equal_aspect : bool
        If True, use an equal data aspect ratio on the main panel.

    Returns
    -------
    tuple of Axes
        ``(ax_main, ax_xhist, ax_yhist, ax_cbar, ax_hist_color)``.
    """
    from matplotlib.collections import LineCollection

    start = np.asarray(start, float)
    max_point = np.asarray(max_point, float)
    color_code = np.asarray(color_code, float)
    finite_cc = color_code[np.isfinite(color_code)]

    # ---- layout: row 0 x-marginal, rows 1-3 main panel + y-marginal,
    # row 4 colorbar, row 5 color_code histogram ----
    gs_sub = gridspec.GridSpecFromSubplotSpec(
        6, 4, subplot_spec=subplot_gs,
        height_ratios=[1, 1, 1, 1, 0.5, 0.5],
        hspace=0.4, wspace=0.4
    )
    ax_main = fig.add_subplot(gs_sub[1:4, 0:3])

    # ---- main plot ----
    # Start -> end segments as a single collection instead of one Line2D per
    # point: same appearance, far fewer artists for large N.
    segments = np.stack([start.T, max_point.T], axis=1)  # (N, 2 endpoints, 2 coords)
    ax_main.add_collection(
        LineCollection(segments, colors='k', alpha=0.1, linewidths=0.05)
    )

    if v_range is None:
        v_range = (np.nanquantile(color_code, 0.05), np.nanquantile(color_code, 0.95))

    ax_main.scatter(start[0, :], start[1, :], color='gray', s=2, label='Start Point')
    sg = ax_main.scatter(
        max_point[0, :], max_point[1, :],
        c=color_code, s=2, vmin=v_range[0], vmax=v_range[1]
    )

    if center_xys is not None:
        centers = np.atleast_2d(np.asarray(center_xys, float))
        ax_main.scatter(centers[:, 0], centers[:, 1], color='red', marker='x', s=25, label='Center of Mass')

    ax_main.set_xlabel("Y")
    ax_main.set_ylabel("X")
    ax_main.set_title(f'Lick locations colored by {feature_name}')

    # ---- lock limits to the finite end points (keeps the marginals aligned) ----
    x = max_point[0, :]
    y = max_point[1, :]
    ok = np.isfinite(x) & np.isfinite(y)
    x, y = x[ok], y[ok]
    if x.size:
        xmin, xmax = x.min(), x.max()
        ymin, ymax = y.min(), y.max()
        # 2% padding; fall back to +-1 when every point shares a coordinate.
        pad_x = 0.02 * (xmax - xmin) if xmax > xmin else 1.0
        pad_y = 0.02 * (ymax - ymin) if ymax > ymin else 1.0
        ax_main.set_xlim(xmin - pad_x, xmax + pad_x)
        ax_main.set_ylim(ymin - pad_y, ymax + pad_y)

    if equal_aspect:
        # adjustable='box' resizes the axes box rather than the data limits.
        ax_main.set_aspect('equal', adjustable='box')

    x_lims = ax_main.get_xlim()
    y_lims = ax_main.get_ylim()

    # ---- marginals: created exactly once, after the main limits are fixed ----
    ax_xhist = fig.add_subplot(gs_sub[0, 0:3], sharex=ax_main)
    ax_yhist = fig.add_subplot(gs_sub[1:4, 3], sharey=ax_main)
    ax_xhist.hist(x, bins=bins_xy, color='black', alpha=0.6)
    ax_yhist.hist(y, bins=bins_xy, orientation='horizontal', color='black', alpha=0.6)

    # hist() autoscales, and these axes are shared with ax_main: restore limits.
    ax_xhist.set_xlim(x_lims)
    ax_yhist.set_ylim(y_lims)
    ax_xhist.set_autoscale_on(False)
    ax_yhist.set_autoscale_on(False)

    # Marginals only show bar heights; hide redundant ticks/labels.
    ax_xhist.tick_params(axis='x', labelbottom=False)
    ax_xhist.tick_params(axis='y', left=False, labelleft=False)
    ax_yhist.tick_params(axis='y', labelleft=False)
    ax_yhist.tick_params(axis='x', bottom=False, labelbottom=False)

    # ---- horizontal colorbar (own axis) ----
    ax_cbar = fig.add_subplot(gs_sub[4, 0:3])
    cbar = fig.colorbar(sg, cax=ax_cbar, orientation='horizontal')
    cbar.set_label(feature_name)

    # ---- histogram of color_code (finite values only, matching the bins) ----
    ax_hist_color = fig.add_subplot(gs_sub[5, 0:3])
    if finite_cc.size:
        bins = np.linspace(finite_cc.min(), finite_cc.max(), bins_c)
        ax_hist_color.hist(finite_cc, bins=bins, color='skyblue', edgecolor=None)
    ax_hist_color.axvline(v_range[0], color='b', linestyle='--')
    ax_hist_color.axvline(v_range[1], color='y', linestyle='--')
    ax_hist_color.set_yscale('log')
    ax_hist_color.set_title(f'Histogram of {feature_name}')

    return ax_main, ax_xhist, ax_yhist, ax_cbar, ax_hist_color


In [None]:
# NOTE(review): an earlier notebook cell also defines `pairplot_color_code`;
# this definition shadows it after Run All. Keep only one copy.
def pairplot_color_code(start, max_point, color_code, feature_name, fig, subplot_gs,
                        v_range=None, center_xys=None, bins_xy=50, bins_c=50,
                        equal_aspect=True):
    """Plot lick end points colored by a feature, with marginal histograms.

    Inside ``subplot_gs`` this draws:

    * a main panel with a faint segment from each start point to its end
      point, start points in gray and end points colored by ``color_code``;
    * histograms of the end-point coordinates above and to the right of it,
      aligned with the main panel's limits;
    * a horizontal colorbar, and a log-scale histogram of ``color_code`` with
      the color limits marked.

    Parameters
    ----------
    start, max_point : array-like, shape (2, N)
        Start / end points. Row 0 goes on the horizontal axis (labelled "Y"),
        row 1 on the vertical axis (labelled "X").
    color_code : array-like of float, length N
        Values used to color the end points. Non-finite values are left out
        of the ``color_code`` histogram.
    feature_name : str
        Used in the titles and the colorbar label.
    fig : matplotlib.figure.Figure
        Figure to draw into.
    subplot_gs : matplotlib.gridspec.SubplotSpec
        Region of ``fig`` that is subdivided for the panels.
    v_range : (vmin, vmax), optional
        Color limits. Defaults to the 5th and 95th percentiles of
        ``color_code`` (NaNs ignored).
    center_xys : array-like, shape (K, 2) or (2,), optional
        Points drawn as red crosses. Column 0 goes on the horizontal axis and
        column 1 on the vertical axis. A single pair is also accepted.
    bins_xy : int
        Number of bins for each marginal histogram.
    bins_c : int
        Number of bin *edges* (i.e. ``bins_c - 1`` bins) for the
        ``color_code`` histogram.
    equal_aspect : bool
        If True, use an equal data aspect ratio on the main panel.

    Returns
    -------
    tuple of Axes
        ``(ax_main, ax_xhist, ax_yhist, ax_cbar, ax_hist_color)``.
    """
    from matplotlib.collections import LineCollection

    start = np.asarray(start, float)
    max_point = np.asarray(max_point, float)
    color_code = np.asarray(color_code, float)
    finite_cc = color_code[np.isfinite(color_code)]

    # ---- layout: row 0 x-marginal, rows 1-3 main panel + y-marginal,
    # row 4 colorbar, row 5 color_code histogram ----
    gs_sub = gridspec.GridSpecFromSubplotSpec(
        6, 4, subplot_spec=subplot_gs,
        height_ratios=[1, 1, 1, 1, 0.5, 0.5],
        hspace=0.4, wspace=0.4
    )
    ax_main = fig.add_subplot(gs_sub[1:4, 0:3])

    # ---- main plot ----
    # Start -> end segments as a single collection instead of one Line2D per
    # point: same appearance, far fewer artists for large N.
    segments = np.stack([start.T, max_point.T], axis=1)  # (N, 2 endpoints, 2 coords)
    ax_main.add_collection(
        LineCollection(segments, colors='k', alpha=0.1, linewidths=0.05)
    )

    if v_range is None:
        v_range = (np.nanquantile(color_code, 0.05), np.nanquantile(color_code, 0.95))

    ax_main.scatter(start[0, :], start[1, :], color='gray', s=2, label='Start Point')
    sg = ax_main.scatter(
        max_point[0, :], max_point[1, :],
        c=color_code, s=2, vmin=v_range[0], vmax=v_range[1]
    )

    if center_xys is not None:
        centers = np.atleast_2d(np.asarray(center_xys, float))
        ax_main.scatter(centers[:, 0], centers[:, 1], color='red', marker='x', s=25, label='Center of Mass')

    ax_main.set_xlabel("Y")
    ax_main.set_ylabel("X")
    ax_main.set_title(f'Lick locations colored by {feature_name}')

    # ---- lock limits to the finite end points (keeps the marginals aligned) ----
    x = max_point[0, :]
    y = max_point[1, :]
    ok = np.isfinite(x) & np.isfinite(y)
    x, y = x[ok], y[ok]
    if x.size:
        xmin, xmax = x.min(), x.max()
        ymin, ymax = y.min(), y.max()
        # 2% padding; fall back to +-1 when every point shares a coordinate.
        pad_x = 0.02 * (xmax - xmin) if xmax > xmin else 1.0
        pad_y = 0.02 * (ymax - ymin) if ymax > ymin else 1.0
        ax_main.set_xlim(xmin - pad_x, xmax + pad_x)
        ax_main.set_ylim(ymin - pad_y, ymax + pad_y)

    if equal_aspect:
        # adjustable='box' resizes the axes box rather than the data limits.
        ax_main.set_aspect('equal', adjustable='box')

    x_lims = ax_main.get_xlim()
    y_lims = ax_main.get_ylim()

    # ---- marginals: created exactly once, after the main limits are fixed ----
    ax_xhist = fig.add_subplot(gs_sub[0, 0:3], sharex=ax_main)
    ax_yhist = fig.add_subplot(gs_sub[1:4, 3], sharey=ax_main)
    ax_xhist.hist(x, bins=bins_xy, color='black', alpha=0.6)
    ax_yhist.hist(y, bins=bins_xy, orientation='horizontal', color='black', alpha=0.6)

    # hist() autoscales, and these axes are shared with ax_main: restore limits.
    ax_xhist.set_xlim(x_lims)
    ax_yhist.set_ylim(y_lims)
    ax_xhist.set_autoscale_on(False)
    ax_yhist.set_autoscale_on(False)

    # Marginals only show bar heights; hide redundant ticks/labels.
    ax_xhist.tick_params(axis='x', labelbottom=False)
    ax_xhist.tick_params(axis='y', left=False, labelleft=False)
    ax_yhist.tick_params(axis='y', labelleft=False)
    ax_yhist.tick_params(axis='x', bottom=False, labelbottom=False)

    # ---- horizontal colorbar (own axis) ----
    ax_cbar = fig.add_subplot(gs_sub[4, 0:3])
    cbar = fig.colorbar(sg, cax=ax_cbar, orientation='horizontal')
    cbar.set_label(feature_name)

    # ---- histogram of color_code (finite values only, matching the bins) ----
    ax_hist_color = fig.add_subplot(gs_sub[5, 0:3])
    if finite_cc.size:
        bins = np.linspace(finite_cc.min(), finite_cc.max(), bins_c)
        ax_hist_color.hist(finite_cc, bins=bins, color='skyblue', edgecolor=None)
    ax_hist_color.axvline(v_range[0], color='b', linestyle='--')
    ax_hist_color.axvline(v_range[1], color='y', linestyle='--')
    ax_hist_color.set_yscale('log')
    ax_hist_color.set_title(f'Histogram of {feature_name}')

    return ax_main, ax_xhist, ax_yhist, ax_cbar, ax_hist_color