In [4]:
"""
@author: pho
"""
%load_ext autoreload
%autoreload 2
import sys
import importlib

import numpy as np
import pandas as pd

## Panel:
import param
import panel as pn
from panel.interact import interact, interactive, fixed, interact_manual
from panel.viewable import Viewer

pn.extension()

## Pho's Custom Libraries:
from pyphocorehelpers.general_helpers import get_arguments_as_optional_dict, inspect_callable_arguments
from pyphocorehelpers.function_helpers import compose_functions
from pyphocorehelpers.indexing_helpers import partition, build_spanning_bins, compute_spanning_bins, compute_position_grid_size
from pyphocorehelpers.print_helpers import PrettyPrintable, WrappingMessagePrinter
from pyphocorehelpers.geometry_helpers import compute_data_extent, compute_data_aspect_ratio, corner_points_from_extents

from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import * # get_neuron_identities
from pyphoplacecellanalysis.General.SessionSelectionAndFiltering import batch_filter_session
from pyphoplacecellanalysis.General.ComputationResults import ComputationResult
# from PendingNotebookCode import estimation_session_laps


from neuropy.analyses.laps import estimation_session_laps
from neuropy.core.epoch import NamedTimerange

# Neuropy:
from neuropy.analyses.placefields import PlacefieldComputationParameters, perform_compute_placefields
from neuropy.core.neuron_identities import NeuronIdentity, build_units_colormap, PlotStringBrevityModeEnum
from neuropy.utils.debug_helpers import debug_print_placefield, debug_print_spike_counts, debug_print_subsession_neuron_differences

# def estimate_session_laps_load_function(regular_load_function, a_base_dir):
#     session = regular_load_function(a_base_dir)
#     ## Estimate the Session's Laps data using my algorithm from the loaded position data.
#     session = estimation_session_laps(session)
#     return session

# Registry of the known recording formats: each entry pairs the loader used for that format with the
# base directory of the session to analyse.
# NOTE: the base directories are hard-coded Windows paths on the R: drive; edit them when running elsewhere.
known_data_session_type_dict = {
    'kdiba': KnownDataSessionTypeProperties(
        load_function=(lambda a_base_dir: DataSessionLoader.kdiba_old_format_session(a_base_dir)),
        basedir=Path(r'R:\data\KDIBA\gor01\one\2006-6-07_11-26-53'),
    ),
    'bapun': KnownDataSessionTypeProperties(
        load_function=(lambda a_base_dir: DataSessionLoader.bapun_data_session(a_base_dir)),
        # Raw string (like the 'kdiba' entry): without the r-prefix, `\d`, `\B` and `\D` are invalid escape
        # sequences, which Python warns about and will eventually reject. The resulting path is unchanged.
        basedir=Path(r'R:\data\Bapun\Day5TwoNovel'),
    ),
}

# Register `estimation_session_laps` as a post-load function for the 'kdiba' entry only.
known_data_session_type_dict['kdiba'].post_load_functions = [lambda a_loaded_sess: estimation_session_laps(a_loaded_sess)]
# known_data_session_type_dict = {'kdiba':KnownDataSessionTypeProperties(load_function=(lambda a_base_dir: estimation_session_laps(DataSessionLoader.kdiba_old_format_session(a_base_dir))),
#                                basedir=Path(r'R:\data\KDIBA\gor01\one\2006-6-07_11-26-53')),
#                 'bapun':KnownDataSessionTypeProperties(load_function=(lambda a_base_dir: DataSessionLoader.bapun_data_session(a_base_dir)),
#                                basedir=Path('R:\data\Bapun\Day5TwoNovel'))
#                }

# known_data_session_type_dict['kdiba'].name
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
%matplotlib qt
from matplotlib.transforms import IdentityTransform


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Common Setup

In [9]:
def compute_position_grid_bin_size(x, y, num_bins=(64,64), debug_print=False):
    """ Returns the grid bin size that `compute_position_grid_size` reports for the requested number of bins.

    Inputs:
        x, y: position samples, forwarded unchanged to `compute_position_grid_size`.
        num_bins: desired number of bins in each dimension.
        debug_print: when True, prints the resulting bin size.
    Returns:
        the grid bin size converted to a tuple. The bins and bin infos that `compute_position_grid_size` also returns are discarded.
    Usage:
        active_grid_bin = compute_position_grid_bin_size(curr_kdiba_pipeline.sess.position.x, curr_kdiba_pipeline.sess.position.y, num_bins=(64, 64))
    """
    grid_bin_size, _unused_bins, _unused_bins_infos = compute_position_grid_size(x, y, num_bins=num_bins)
    active_grid_bin = tuple(grid_bin_size)
    if not debug_print:
        return active_grid_bin
    print(f'active_grid_bin: {active_grid_bin}') # e.g. (3.776841861770752, 1.043326930905373)
    return active_grid_bin

In [None]:
# from scipy.stats import multivariate_normal

# multivariate_normal.pdf()


# from sklearn.neighbors import KernelDensity

# kde = KernelDensity(kernel='gaussian', bandwidth=1.0).fit(X)

In [6]:
from pyphocorehelpers.plotting.background_helpers import build_image_from_text

fig = plt.figure()

# Build one image per label with `build_image_from_text` ...
rgba1 = build_image_from_text(r"IQ: $\sigma_i=15$", color="blue", fontsize=20, dpi=200)
rgba2 = build_image_from_text(r"some other string", color="red", fontsize=20, dpi=200)

# ... then draw those text images onto the Figure using `.Figure.figimage` (offsets are in pixels).
for a_text_image, a_y_offset in ((rgba1, 50), (rgba2, 150)):
    fig.figimage(a_text_image, 100, a_y_offset)

# Text can also be drawn directly at pixel coordinates by combining
# `.Figure.text` with `.transforms.IdentityTransform`.
for a_label, a_color, a_y_position in ((r"IQ: $\sigma_i=15$", "blue", 250), (r"some other string", "red", 350)):
    fig.text(100, a_y_position, a_label, color=a_color, fontsize=20,
            transform=IdentityTransform())
plt.show()

# Bapun Format:

In [None]:
# curr_bapun_pipeline = NeuropyPipeline(name='bapun_pipeline', session_data_type='bapun', basedir=known_data_session_type_dict['bapun'].basedir, load_function=known_data_session_type_dict['bapun'].load_function)
# Build the pipeline for the 'bapun' entry of `known_data_session_type_dict`.
curr_bapun_pipeline = NeuropyPipeline.init_from_known_data_session_type('bapun', known_data_session_type_dict['bapun'])
# NOTE: a bare expression in the middle of a cell is not displayed; only the last expression of the cell is.
curr_bapun_pipeline.is_loaded
size_bytes = curr_bapun_pipeline.sess.__sizeof__() # 1753723032
# Cell output: the size reported by `__sizeof__()`, converted from bytes to MB.
f'object size: {size_bytes/(1024*1024)} MB'

In [None]:
# Look up the name of the time column used by the spikes dataframe and by the position dataframe.
# Only the second (last) expression is displayed as the cell output.
curr_bapun_pipeline.sess.spikes_df.spikes.time_variable_name # 't_rel_seconds'
curr_bapun_pipeline.sess.position.to_dataframe().position.time_variable_name # t
# .spikes.time_variable_name # 't_rel_seconds'
# spk_df.spikes.time_variable_name

In [None]:
# Bapun/DataFrame style session filter functions:
def build_bapun_any_maze_epochs_filters(sess, epoch_names=('maze1',)):
    """ Builds the named session-filter functions for a Bapun-format (DataFrame-style) session.

    Inputs:
        sess: unused by the builder itself; kept so existing callers that pass the session keep working.
        epoch_names: names of the epochs to build a filter for. Defaults to only 'maze1', which matches the
            previous behavior (a 'maze2' filter used to be defined here but was never returned).
    Returns:
        dict mapping each epoch name to a function `f(sess) -> (active_session, active_epoch)`, where
        `active_epoch` is `sess.epochs.get_named_timerange(epoch_name)` and `active_session` is the result of
        `batch_filter_session` applied to the session's position and spikes_df for that epoch.
    Usage:
        active_session_filter_configurations = build_bapun_any_maze_epochs_filters(curr_bapun_pipeline.sess)
        active_session, active_epoch = active_session_filter_configurations['maze1'](curr_bapun_pipeline.sess)
    """
    def _build_filter_for_epoch(epoch_name):
        # A factory function (instead of a closure created directly in a loop) so that every returned filter
        # keeps its own `epoch_name` rather than the last one iterated.
        def _filter_session_by_epoch(sess):
            active_epoch = sess.epochs.get_named_timerange(epoch_name)
            ## All Spikes:
            # (an older variant used `sess.filtered_by_epoch(active_epoch)` instead of `batch_filter_session`)
            active_session = batch_filter_session(sess, sess.position, sess.spikes_df, active_epoch.to_Epoch())
            return active_session, active_epoch
        return _filter_session_by_epoch

    return {an_epoch_name: _build_filter_for_epoch(an_epoch_name) for an_epoch_name in epoch_names}

# Run the bapun pipeline stages in order: filter -> compute -> prepare for display -> display.
# Each call acts on `curr_bapun_pipeline`, so the statement order matters.
active_session_filter_configurations = build_bapun_any_maze_epochs_filters(curr_bapun_pipeline.sess)
curr_bapun_pipeline.filter_sessions(active_session_filter_configurations)
# Grid bin size for a 64x64 grid over the whole session's position data.
active_grid_bin = compute_position_grid_bin_size(curr_bapun_pipeline.sess.position.x, curr_bapun_pipeline.sess.position.y, num_bins=(64, 64))
active_session_computation_config = PlacefieldComputationParameters(speed_thresh=0.0, grid_bin=active_grid_bin, smooth=(1.5, 1.5), frate_thresh=0.1, time_bin_size=0.5)
curr_bapun_pipeline.perform_computations(active_session_computation_config)
curr_bapun_pipeline.prepare_for_display() # TODO: pass a display config
curr_bapun_pipeline.display(DefaultDisplayFunctions._display_2d_placefield_result_plot_ratemaps_2D, 'maze1') # works!

In [None]:
# NOTE(review): `_display_result` is not defined or imported in the visible part of this notebook - confirm where it comes from.
_display_result(curr_bapun_pipeline.computation_results['maze1'])

In [None]:
# NOTE(review): `build_bapun_any_maze_epochs_filters` only returns a 'maze1' filter, so a 'maze2' entry is presumably
# missing from `computation_results` - confirm before running. `_display_result` is also not defined in the visible notebook.
_display_result(curr_bapun_pipeline.computation_results['maze2'])

# KDiba Format:

In [7]:
## Data must be pre-processed using the MATLAB script located here: 
# R:\data\KDIBA\gor01\one\IIDataMat_Export_ToPython_2021_11_23.m
# From pre-computed .mat files:
## 07: 
# basedir = r'R:\data\KDIBA\gor01\one\2006-6-07_11-26-53'
# # ## 08:
# basedir = r'R:\data\KDIBA\gor01\one\2006-6-08_14-26-15'
# curr_kdiba_pipeline = NeuropyPipeline(name='kdiba_pipeline', session_data_type='kdiba', basedir=known_data_session_type_dict['kdiba'].basedir, load_function=known_data_session_type_dict['kdiba'].load_function)
# Build the pipeline for the 'kdiba' entry of `known_data_session_type_dict`.
curr_kdiba_pipeline = NeuropyPipeline.init_from_known_data_session_type('kdiba', known_data_session_type_dict['kdiba'])

# curr_bapun_pipeline
# NOTE: a bare expression in the middle of a cell is not displayed; only the last expression of the cell is.
curr_kdiba_pipeline.is_loaded
size_bytes = curr_kdiba_pipeline.sess.__sizeof__() # the recorded run of this cell reported about 144.9 MB
# Cell output: the size reported by `__sizeof__()`, converted from bytes to MB.
f'object size: {size_bytes/(1024*1024)} MB'
# ## Estimate the Session's Laps data using my algorithm from the loaded position data.
# curr_kdiba_pipeline.sess = estimation_session_laps(curr_kdiba_pipeline.sess)
# curr_kdiba_pipeline.sess.epochs

basedir is already Path object.
	 basepath: R:\data\KDIBA\gor01\one\2006-6-07_11-26-53
	 session_name: 2006-6-07_11-26-53
Loading matlab import file results to R:\data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.epochs_info.mat... done.
Loading matlab import file results to R:\data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.position_info.mat... done.
Loading matlab import file results to R:\data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.spikes.mat... done.
Failure loading .position.npy. Must recompute.

Computing linear positions for all active epochs for session... Saving updated position results results to R:\data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.position.npy... 2006-6-07_11-26-53.position.npy saved
done.
	 Failure loading .interpolated_spike_positions.npy. Must recompute.

	 Saving updated interpolated spike position results results to R:\data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.interpolated_spike_positions.npy.

'object size: 144.8962745666504 MB'

In [10]:
def build_any_maze_epochs_filters(sess):
    """ Builds the 'maze1', 'maze2' and combined 'maze' session filters.

    Side effect:
        sets `sess.epochs.t_start` to 22.26 on the session that is passed in.
    Returns:
        dict mapping a filter name to a function `f(x) -> (filtered_session, epoch)`:
            'maze1' / 'maze2': filter `x` by the named timerange of the same name.
            'maze': filter `x` by a NamedTimerange running from the start of 'maze1' to the end of 'maze2'.
    """
    ## All Spikes:
    # exclude the first short period where the animal isn't on the maze yet
    sess.epochs.t_start = 22.26

    def _named_epoch_filter(epoch_name):
        def _apply(x):
            filtered_session = x.filtered_by_epoch(x.epochs.get_named_timerange(epoch_name))
            return (filtered_session, x.epochs.get_named_timerange(epoch_name))
        return _apply

    def _whole_maze_timerange(x):
        return NamedTimerange(name='maze', start_end_times=[x.epochs['maze1'][0], x.epochs['maze2'][1]])

    def _whole_maze_filter(x):
        filtered_session = x.filtered_by_epoch(_whole_maze_timerange(x))
        return (filtered_session, _whole_maze_timerange(x))

    filters_by_name = {}
    filters_by_name['maze1'] = _named_epoch_filter('maze1')
    filters_by_name['maze2'] = _named_epoch_filter('maze2')
    filters_by_name['maze'] = _whole_maze_filter
    return filters_by_name

# Run the kdiba pipeline stages in order: filter -> compute -> prepare for display.
# Each call acts on `curr_kdiba_pipeline`, so the statement order matters.
# NOTE: building the filters also sets `curr_kdiba_pipeline.sess.epochs.t_start` to 22.26.
active_session_filter_configurations = build_any_maze_epochs_filters(curr_kdiba_pipeline.sess)
curr_kdiba_pipeline.filter_sessions(active_session_filter_configurations)
# Grid bin size for a 64x64 grid over the whole session's position data.
active_grid_bin = compute_position_grid_bin_size(curr_kdiba_pipeline.sess.position.x, curr_kdiba_pipeline.sess.position.y, num_bins=(64, 64))
active_session_computation_config = PlacefieldComputationParameters(speed_thresh=2.0, grid_bin=active_grid_bin, smooth=(1.5, 1.5), frate_thresh=0.1, time_bin_size=0.5)
curr_kdiba_pipeline.perform_computations(active_session_computation_config)
curr_kdiba_pipeline.prepare_for_display() # TODO: pass a display config

Applying session filter named "maze1"...
Constraining to epoch with times (start: 22.26, end: 1739.1533641185379)
Applying session filter named "maze2"...
Constraining to epoch with times (start: 1739.1533641185379, end: 1932.4200048116618)
Applying session filter named "maze"...
Constraining to epoch with times (start: 22.26, end: 1932.4200048116618)
Performing single_computation on filtered_session with filter named "maze1"...
Recomputing active_epoch_placefields... 	 done.
Recomputing active_epoch_placefields2D... 	 done.
Performing perform_registered_computations(...) with 1 registered_computation_functions...


  C_tau_n = 1.0 / np.sum(un_normalized_result) # normalize the result
  result = C_tau_n * un_normalized_result


Performing single_computation on filtered_session with filter named "maze2"...
Recomputing active_epoch_placefields... 	 done.
Recomputing active_epoch_placefields2D... 	 done.
Performing perform_registered_computations(...) with 1 registered_computation_functions...
Performing single_computation on filtered_session with filter named "maze"...
Recomputing active_epoch_placefields... 	 done.
Recomputing active_epoch_placefields2D... 	 done.
Performing perform_registered_computations(...) with 1 registered_computation_functions...


  C_tau_n = 1.0 / np.sum(un_normalized_result) # normalize the result


The specified cmap supports less colors than n_neurons (supports 7, n_neurons: 64). An extended colormap will be built.
The specified cmap supports less colors than n_neurons (supports 7, n_neurons: 64). An extended colormap will be built.
The specified cmap supports less colors than n_neurons (supports 7, n_neurons: 64). An extended colormap will be built.


In [None]:
def build_lap_epochs_filters(sess):
    """ Builds the 'maze1', 'maze2' and combined 'maze' session filters, restricted to 'pyramidal' neurons.

    Side effect:
        sets `sess.epochs.t_start` to 22.26 on the session that is passed in.
    Returns:
        dict mapping a filter name to a function `f(x) -> (filtered_session, epoch)`. Every filter first applies
        `x.filtered_by_neuron_type('pyramidal')` and then filters by the epoch.

    NOTE(review): the all/even/odd lap epoch subsets are computed from `sess.laps` but are not used by any of the
    returned filters - presumably lap-specific filters were intended; confirm.
    """
    lap_epochs = sess.laps.as_epoch_obj()
    num_laps = len(sess.laps.lap_id)
    any_lap_specific_epochs = lap_epochs.label_slice(lap_epochs.labels[np.arange(num_laps)])
    even_lap_specific_epochs = lap_epochs.label_slice(lap_epochs.labels[np.arange(0, num_laps, 2)])
    odd_lap_specific_epochs = lap_epochs.label_slice(lap_epochs.labels[np.arange(1, num_laps, 2)])

    ## All Spikes:
    # exclude the first short period where the animal isn't on the maze yet
    sess.epochs.t_start = 22.26

    def _named_epoch_filter(epoch_name):
        def _apply(x):
            filtered_session = x.filtered_by_neuron_type('pyramidal').filtered_by_epoch(x.epochs.get_named_timerange(epoch_name))
            return (filtered_session, x.epochs.get_named_timerange(epoch_name))
        return _apply

    def _whole_maze_timerange(x):
        return NamedTimerange(name='maze', start_end_times=[x.epochs['maze1'][0], x.epochs['maze2'][1]])

    def _whole_maze_filter(x):
        filtered_session = x.filtered_by_neuron_type('pyramidal').filtered_by_epoch(_whole_maze_timerange(x))
        return (filtered_session, _whole_maze_timerange(x))

    filters_by_name = {}
    filters_by_name['maze1'] = _named_epoch_filter('maze1')
    filters_by_name['maze2'] = _named_epoch_filter('maze2')
    filters_by_name['maze'] = _whole_maze_filter
    return filters_by_name

# Re-run the kdiba pipeline stages with the pyramidal-only filters from `build_lap_epochs_filters`.
# NOTE: this reassigns `active_session_filter_configurations`, `active_grid_bin` and `active_session_computation_config`,
# so later cells see these values (speed_thresh=0.0, smooth=(0.5, 0.5), time_bin_size=0.25) if this cell is run.
active_session_filter_configurations = build_lap_epochs_filters(curr_kdiba_pipeline.sess)
curr_kdiba_pipeline.filter_sessions(active_session_filter_configurations)
active_grid_bin = compute_position_grid_bin_size(curr_kdiba_pipeline.sess.position.x, curr_kdiba_pipeline.sess.position.y, num_bins=(64, 64))
active_session_computation_config = PlacefieldComputationParameters(speed_thresh=0.0, grid_bin=active_grid_bin, smooth=(0.5, 0.5), frate_thresh=0.1, time_bin_size=0.25)
curr_kdiba_pipeline.perform_computations(active_session_computation_config)
curr_kdiba_pipeline.prepare_for_display() # TODO: pass a display config


In [11]:
from neuropy.plotting.ratemaps import enumTuningMap2DPlotVariables
# Show the 2D ratemaps for 'maze1' in figure 1, plotting the FIRING_MAPS variable with the spike overlay turned off.
curr_kdiba_pipeline.display(DefaultDisplayFunctions._display_2d_placefield_result_plot_ratemaps_2D, 'maze1', enable_spike_overlay=False, plot_variable=enumTuningMap2DPlotVariables.FIRING_MAPS, fignum=1) # works!


data_aspect_ratio: (7.988962073389786, Width_Height_Tuple(width=241.7178791533281, height=30.256480996256016))
Pagination is disabled because one of the subplots values is None. Output will be in a single figure/page.
page_grid_sizes: [RowColTuple(num_rows=22, num_columns=3)]
resolution_multiplier: 1.0, required_figure_size: (24.0, 22.0)
page_idx: 0


In [10]:
# Show the 2D ratemaps for 'maze1' in figure 2, plotting the TUNING_MAPS variable with the spike overlay turned off.
# NOTE: `enumTuningMap2DPlotVariables` is imported in a different cell (`from neuropy.plotting.ratemaps import ...`); run that cell first.
curr_kdiba_pipeline.display(DefaultDisplayFunctions._display_2d_placefield_result_plot_ratemaps_2D, 'maze1', enable_spike_overlay=False, plot_variable=enumTuningMap2DPlotVariables.TUNING_MAPS, fignum=2) # works!

data_aspect_ratio: (7.988962073389786, Width_Height_Tuple(width=241.7178791533281, height=30.256480996256016))
Pagination is disabled because one of the subplots values is None. Output will be in a single figure/page.
page_grid_sizes: [RowColTuple(num_rows=22, num_columns=3)]
resolution_multiplier: 1.0, required_figure_size: (24.0, 22.0)
page_idx: 0


In [17]:
from scipy import signal
# Plot a 51-point Parzen window.
# Window functions live in `scipy.signal.windows`; the top-level alias `signal.parzen` is deprecated
# and no longer exists in newer SciPy releases.
window = signal.windows.parzen(51)
plt.plot(window)
plt.title("Parzen window")
plt.ylabel("Amplitude")
plt.xlabel("Sample")

Text(0.5, 0, 'Sample')

In [None]:
# Open the 3D interactive tuning-curves plotter for the 'maze1' result
# (the 3D custom data explorer variant below is kept as a commented-out alternative).
# curr_kdiba_pipeline.display(DefaultDisplayFunctions._display_3d_interactive_custom_data_explorer, 'maze1') # works!
curr_kdiba_pipeline.display(DefaultDisplayFunctions._display_3d_interactive_tuning_curves_plotter, 'maze1') # works!

In [None]:
# Run the 1D placefield validation display for the 'maze1' result.
curr_kdiba_pipeline.display(DefaultDisplayFunctions._display_1d_placefield_validations, 'maze1') # works!

## Position Decoding:

In [None]:
## Stock Decoder:
from neuropy.analyses.decoders import Decode1d

# Inputs for the stock decoder: the filtered session and the 'pf1D' computed data of the chosen result label.
# NOTE: `sess` and `pf` are notebook-level names here; the functions defined in this cell use parameters of the same name.
curr_result_label = 'maze1'
sess = curr_kdiba_pipeline.filtered_sessions[curr_result_label]
pf = curr_kdiba_pipeline.computation_results[curr_result_label].computed_data['pf1D']

def stock_1d_decoder(sess, pf, curr_result_label):
    """ Builds a `Decode1d` decoder from the placefield's ratemap and the session's matching neurons.

    Inputs:
        sess: session exposing `paradigm` and `neurons`.
        pf: placefield object exposing `ratemap`.
        curr_result_label: key looked up in `sess.paradigm`.
    Returns:
        the `Decode1d` instance, built with epochs=None and bin_size=0.02.
    """
    # The looked-up paradigm entry is not used below (the ripple time-slicing that consumed it is disabled),
    # but the lookup itself is kept.
    _paradigm_entry = sess.paradigm[curr_result_label]
    # rpls = sess.ripple.time_slice(maze1[0], maze1[1])
    ratemap = pf.ratemap
    placefield_neurons = sess.neurons.get_by_id(ratemap.neuron_ids)
    return Decode1d(neurons=placefield_neurons, ratemap=ratemap, epochs=None, bin_size=0.02)
    
def validate_stock_1d_decoder(sess, decode):
    """ Plots the decoder's decoded position as a visual check, with the x-axis limited to 10000-12000.

    Inputs:
        sess: unused.
        decode: decoder object exposing `decoded_position`.
    """
    decoded_position = decode.decoded_position
    np.shape(decoded_position) # (85845,) in the recorded run; the result is discarded
    plt.plot(decoded_position)
    # the full x-range was noted as (-4292.2, 90136.2); only a short stretch of it is shown
    plt.gca().set_xlim(10000, 12000)

# `decode` was never assigned at notebook level (it only exists inside the functions above), so this line
# raised a NameError on a fresh kernel. Build the decoder from this cell's inputs first.
decode = stock_1d_decoder(sess, pf, curr_result_label)
np.shape(decode.posterior) # (48, 85845)

In [108]:
## Zhang Velocity/Position For 2-step Bayesian Decoder:
from neuropy.analyses.placefields import PfND
from pyphoplacecellanalysis.General.Decoder.decoder_result import build_position_df_discretized_binned_positions

##TODO: compute the average speed at each position x:
# Inputs for the speed-per-position computation: the 'maze1' filtered session, its position dataframe,
# and the config stored on the 'maze1' 'pf2D' computed data.
active_sess = curr_kdiba_pipeline.filtered_sessions['maze1']
active_position_df = active_sess.position.to_dataframe()
active_computation_config = curr_kdiba_pipeline.computation_results['maze1'].computed_data['pf2D'].config


def _compute_avg_speed_at_each_position_bin(active_position_df, active_computation_config, show_plots=False, debug_print=False):
    """ Computes the mean speed within each (binned_x, binned_y) position bin.

    The per-bin means of 'velocity_x' and 'acceleration_x' are computed as well (and plotted when requested),
    but only the speed array is returned.

    Inputs:
        active_position_df: position dataframe that already contains 'binned_x' and 'binned_y' columns
            as well as 'speed', 'velocity_x' and 'acceleration_x'.
        active_computation_config: currently unused.
        show_plots: when True, draws each per-bin mean with imshow in figures 8, 9 and 10.
        debug_print: currently unused.
    Returns:
        2D array of shape (len(xbin), len(ybin)) holding the mean speed of each bin.

    NOTE(review): `xbin` and `ybin` are neither parameters nor locals; they are looked up in the notebook's
    global namespace when this runs, so they must already be defined there - confirm where they come from.
    """
    def _per_bin_mean(position_df, variable_name:str = 'speed'):
        # One row per (binned_x, binned_y) group with the nan-aware sum/mean/min/max of the variable.
        grouped_stats = position_df.groupby(['binned_x','binned_y'])[variable_name].agg([np.nansum, np.nanmean, np.nanmin, np.nanmax]).reset_index()
        per_bin_values = np.zeros((len(xbin), len(ybin))) # (65, 30) in the recorded run
        # The bin labels are shifted down by one to index the output array (presumably labels start at 1 - confirm).
        x_indices = grouped_stats['binned_x'].to_numpy() - 1
        y_indices = grouped_stats['binned_y'].to_numpy() - 1
        per_bin_values[x_indices, y_indices] = grouped_stats['nanmean'].to_numpy()
        return per_bin_values

    outputs = dict()
    # (figure number, variable) pairs, processed in this order
    for a_figure_num, a_variable_name in ((8, 'speed'), (9, 'velocity_x'), (10, 'acceleration_x')):
        outputs[a_variable_name] = _per_bin_mean(active_position_df, a_variable_name)
        if show_plots:
            plt.figure(num=a_figure_num)
            plt.imshow(outputs[a_variable_name])
            plt.title(a_variable_name)

    return outputs['speed']


# NOTE: this replaces the position dataframe stored on `active_sess`, so the change persists across cells.
active_sess.position.df = build_position_df_discretized_binned_positions(active_sess.position.df, active_computation_config, debug_print=False) # update the session's position dataframe with the new columns.
# Mean speed per position bin, computed from the updated dataframe; displayed as the cell output.
avg_speed_per_pos = _compute_avg_speed_at_each_position_bin(active_sess.position.to_dataframe(), active_computation_config)
avg_speed_per_pos

# non_empty_position_bin_dependent_specific_average_velocities = position_bin_dependent_specific_average_velocities[position_bin_dependent_specific_average_velocities['nansum'].to_numpy() > 0.001]
# non_empty_position_bin_dependent_specific_average_velocities # 790 rows

array([[        nan,         nan,         nan, ...,         nan,
                nan,  0.        ],
       [        nan,         nan, 15.08023389, ...,         nan,
                nan,  0.        ],
       [        nan,         nan, 11.43465014, ...,         nan,
                nan,  0.        ],
       ...,
       [        nan,         nan,         nan, ...,         nan,
                nan,  0.        ],
       [        nan,         nan,         nan, ...,         nan,
                nan,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]])

In [86]:
avg_speed_per_pos # called U_x in the paper.

from scipy.stats import multivariate_normal

class Zhang_Two_Step:
    """ Helper computations for the two-step Bayesian position decoder.

    The two-step estimate multiplies the one-step posterior P(x_t | n_t) by a Gaussian continuity term
    P(x_{t-1} | x_t) and a normalization constant k.
    Every method is a classmethod, so the class is used as a namespace and never needs to be instantiated.
    """
    @classmethod
    def sigma_t(cls, v_t, K, V, d:float=1.0):
        """ The standard deviation of the Gaussian prior for position. Once computed and normalized, can be used such that it only requires the current position (x_t) to return the correct std_dev at a given timestamp.

        Inputs:
            v_t: speed at time t.
            K, V: constants.
            d: 0.5 for random walks and 1.0 for linear movements.
        Returns:
            K * (v_t / V) ** d
        """
        return K * np.power((v_t / V), d)

    @classmethod
    def compute_conditional_probability_x_prev_given_x_t(cls, C, x_prev, t, x=None, sigma_t=None):
        """ Gaussian continuity term: C * exp(-||x - x_prev||^2 / (2 * sigma_t^2)).

        Should return a value for all possible current locations x_t. x_prev should be a concrete position, not a matrix of them.

        Inputs:
            C: scale constant.
            x_prev: the previous position; a scalar for 1D positions or an array of coordinates.
            t: time (index) of the current step. Not used by the formula itself; kept for the existing call signature.
            x: the candidate current position(s). For a scalar `x_prev` this is an array of 1D positions; otherwise
                its last axis must hold the coordinates, e.g. shape (n_positions, n_dims).
            sigma_t: standard deviation of the Gaussian (see `sigma_t(...)`).
        Returns:
            one value per candidate position in `x`.
        Raises:
            ValueError: if `x` or `sigma_t` is not provided. The formula needs both; previously they were read
                from undefined names, so the method could not run at all.
        """
        if (x is None) or (sigma_t is None):
            raise ValueError("both `x` (the candidate positions) and `sigma_t` must be provided")
        displacement = np.asarray(x, dtype=float) - np.asarray(x_prev, dtype=float)
        if np.ndim(x_prev) == 0:
            # 1D positions: every element of x is its own candidate position.
            distance = np.abs(displacement)
        else:
            # The norm is taken over the coordinate axis only, so each candidate position keeps its own distance.
            distance = np.linalg.norm(displacement, axis=-1)
        return C * np.exp(-np.square(distance)/(2.0*np.square(sigma_t)))

    @classmethod
    def compute_bayesian_two_step_prob_single_timestep(cls, one_step_p_x_given_n, t, x_prev, C, k, x=None, sigma_t=None):
        """ Two-step probability for a single timestep: k * P(x_t | n_t) * P(x_{t-1} | x_t).

        Inputs:
            one_step_p_x_given_n: the one-step posterior for each candidate position.
            t, x_prev, C, x, sigma_t: forwarded to `compute_conditional_probability_x_prev_given_x_t`.
            k: normalization constant.
        """
        return k * one_step_p_x_given_n * cls.compute_conditional_probability_x_prev_given_x_t(C, x_prev, t, x=x, sigma_t=sigma_t)
    

def v_t(window_t):
    """ Placeholder for the velocity during the time window `window_t`; not implemented yet.

    Raises:
        NotImplementedError: always. A body made only of comments is a syntax error that stops the whole cell
            from running, so the stub raises explicitly instead.
    """
    # get position during that window step

    # use that position x_t to get the velocity during this window.
    raise NotImplementedError("v_t(window_t) is not implemented yet")

# .apply(lambda g: g.mean(skipna=False))
# # --- occupancy map calculation -----------
# if (position.ndim > 1):
#     occupancy, xedges, yedges = Pf2D._compute_occupancy(self.x, self.y, xbin, ybin, self.position_srate, smooth)
# else:
#     occupancy, xedges = Pf1D._compute_occupancy(self.x, xbin, self.position_srate, smooth[0])

array([[        nan,         nan,         nan, ...,         nan,
                nan,  0.        ],
       [        nan,         nan, 15.08023389, ...,         nan,
                nan,  0.        ],
       [        nan,         nan, 11.43465014, ...,         nan,
                nan,  0.        ],
       ...,
       [        nan,         nan,         nan, ...,         nan,
                nan,  0.        ],
       [        nan,         nan,         nan, ...,         nan,
                nan,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]])

<!-- % $$\int_{a}^b f(x)dx$$ -->
<!-- Euler's identity: $ e^{i \pi} + 1 = 0 $ -->

## One-step Bayesian Decoder:
$$P(\overrightarrow{x}_{t}|\overrightarrow{n}_{t})$$

$$P(\overrightarrow{n}|\overrightarrow{x})$$ : probability for the numbers of spikes $\overrightarrow{n}$ to occur given we know the animal is at location $\overrightarrow{x}$

## Two-step Bayesian Decoder:
$$P(\overrightarrow{x}_{t}|\overrightarrow{n}_{t}, \overrightarrow{x}_{t-1}) = k P(\overrightarrow{x}_{t}|\overrightarrow{n}_{t}) P(\overrightarrow{x}_{t-1}|\overrightarrow{x}_{t})$$

In [5]:

# Show the most-likely-position comparison display for the 'maze1' result.
curr_kdiba_pipeline.display(DefaultDisplayFunctions._display_plot_most_likely_position_comparisons, 'maze1') # works!

# fig, axs = plot_most_likely_position_comparsions(pho_custom_decoder, sess.position.to_dataframe())


# axs[0].set_xlim(500, 550)
# plt.ioff()

## Position Dataframe Binning in Time:

In [118]:
from neuropy.utils.align_util import align_data
from pyphoplacecellanalysis.General.Decoder.decoder_result import build_position_df_resampled_to_time_windows, build_position_df_time_window_idx

# Resample the 'maze1' position dataframe using the time_bin_size of the current computation config.
# NOTE(review): `active_pos_df` is reassigned further down in this cell before this value is used - confirm whether this result is still needed.
active_pos_df = build_position_df_resampled_to_time_windows(curr_kdiba_pipeline.filtered_sessions['maze1'].position.to_dataframe(), time_bin_size=active_session_computation_config.time_bin_size)

def add_timedelta_column_to_pos_df(active_pos_df):
    """ Adds a 'time_delta_sec' column holding the dataframe's time column converted to timedeltas.

    The name of the time column is read from `active_pos_df.position.time_variable_name`.
    The dataframe is modified in place and also returned.
    """
    time_column_name = active_pos_df.position.time_variable_name
    active_pos_df['time_delta_sec'] = pd.to_timedelta(active_pos_df[time_column_name], unit="sec")
    return active_pos_df

# Update the position dataframe with a time_delta column
# NOTE: this writes the 'time_delta_sec' column onto the position dataframe stored on the 'maze1' filtered session.
curr_kdiba_pipeline.filtered_sessions['maze1'].position.df = add_timedelta_column_to_pos_df(curr_kdiba_pipeline.filtered_sessions['maze1'].position.df)

active_pos_df = curr_kdiba_pipeline.filtered_sessions['maze1'].position.df.copy() # don't change the index of the actual position df, make a copy
# Index the copy by the new timedelta column.
active_pos_df = active_pos_df.set_index('time_delta_sec')

# NOTE: as the last expression this displays the whole dataframe; use `.head()` for a shorter preview.
active_pos_df

Unnamed: 0_level_0,t,x,y,lin_pos,speed,dt,velocity_x,acceleration_x,velocity_y,acceleration_y,x_smooth,y_smooth,velocity_x_smooth,acceleration_x_smooth,velocity_y_smooth,acceleration_y_smooth,binned_x,binned_y
time_delta_sec,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
0 days 00:00:22.267855,22.267855,228.503636,145.468764,-46.089873,4.527751,0.033407,-1.892589,-0.363896,4.107252,0.789719,229.101522,144.171245,-1.887591,-0.071637,4.096405,0.155464,55,21
0 days 00:00:22.301320,22.301320,228.440530,145.605714,-46.027268,4.519148,0.033465,-1.885719,0.205287,4.092343,-0.445509,229.038607,144.307781,-1.888068,-0.071449,4.097439,0.155057,55,21
0 days 00:00:22.334963,22.334963,228.377562,145.742365,-45.964681,4.509270,0.033643,-1.871642,0.418426,4.061793,-0.908058,228.975675,144.444355,-1.888121,-0.058911,4.097555,0.127846,55,21
0 days 00:00:22.367647,22.367647,228.314608,145.878987,-45.902119,4.508311,0.032684,-1.926149,-1.667703,4.180084,3.619208,228.912721,144.580977,-1.885664,0.023121,4.092224,-0.050176,55,22
0 days 00:00:22.401755,22.401755,228.251478,146.015989,-45.839591,4.520871,0.034108,-1.850876,2.206925,4.016726,-4.789417,228.849758,144.717616,-1.887591,-0.101013,4.096404,0.219215,55,22
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0 days 00:28:58.997945,1738.997945,211.383281,138.712202,-28.953781,15.351718,0.031953,-15.635207,-24.437137,3.540465,5.533591,215.388793,137.166853,-9.927697,-17.447514,3.928886,6.108043,50,15
0 days 00:28:59.031413,1739.031413,210.881820,138.825754,-28.475727,15.409162,0.033468,-14.983303,19.478434,3.392847,-4.410733,215.039589,137.304090,-10.476674,-16.326759,4.122275,5.904925,50,15
0 days 00:28:59.064597,1739.064597,210.375746,138.940351,-28.043462,15.550918,0.033184,-15.250553,-8.053578,3.453364,1.823667,214.671762,137.447850,-11.046550,-16.946740,4.317797,5.970330,50,15
0 days 00:28:59.099261,1739.099261,209.868561,139.055199,-27.704419,15.585046,0.034664,-14.631461,17.859780,3.313175,-4.044202,214.285218,137.598141,-11.571283,-15.611819,4.507992,5.820545,50,15


In [None]:
# NOTE(review): `pho_custom_decoder` is only assigned in a later cell of this notebook, so this cell fails on a
# fresh kernel when the cells are run top to bottom - confirm the intended cell order.
aligned_active_pos_df = active_pos_df.copy()

time_window_edges_time_delta = pd.to_timedelta(pho_custom_decoder.time_window_edges, unit="sec") # convert the windows to timedeltas as well to allow efficient comparison
time_window_edges_time_delta
# aligned_active_pos_df.between_time

# aligned_active_pos_df.at_time(pho_custom_decoder.time_window_edges)

# pho_custom_decoder.time_window_edges
# aligned_active_pos_df.align(time_window_edges_time_delta, fill_value=np.nan)

# build_position_df_time_window_idx(active_pos_df, pho_custom_decoder.active_time_window_centers)

# pho_custom_decoder.active_time_window_centers
# pho_custom_decoder.time_window_edges

# s21, s22 = ts_2.align(ts_1, fill_value=0)

# active_sess.position.df
# aligned_active_pos_df

In [123]:
# Fetch the 'pf2D_Decoder' computed data of the 'maze1' result; used by the plotting/normalization code below.
pho_custom_decoder = curr_kdiba_pipeline.computation_results['maze1'].computed_data['pf2D_Decoder']
enable_plots = True

print(f'most_likely_positions: {np.shape(pho_custom_decoder.most_likely_positions)}') # most_likely_positions: (3434, 2)


def _show_time_binned_heatmap(values, fig_num, title, ylabel):
    """ Draws `values` as a heatmap in the matplotlib figure numbered `fig_num` (x-axis: binned time window). """
    import matplotlib.pyplot as plt # imported here so the plots work even if matplotlib has not been imported by an earlier cell
    plt.figure(num=fig_num)
    plt.imshow(values, cmap='turbo', aspect='auto')
    plt.title(title)
    plt.xlabel('Binned Time Window')
    plt.ylabel(ylabel)


def spike_count_and_firing_rate_normalizations(pho_custom_decoder, enable_plots=True):
    """ Computes several different normalizations of binned firing rate and spike counts, optionally plotting them. 

    Inputs:
        pho_custom_decoder: object providing
            .unit_specific_time_binned_spike_counts
            .total_spike_counts_per_window
            .time_window_edges_binning_info.step
            (presumably numpy arrays of per-unit/per-window counts and the window duration in seconds -- TODO confirm)
        enable_plots: when True, each of the three results is drawn as a heatmap (matplotlib figures 5, 6 and 7).

    Returns:
        (spike_proportion_global_fr_normalized, firing_rate, firing_rate_global_fr_normalized)
        The decoder's own data is left unchanged.

    NOTE(review): a window whose total spike count is zero will divide by zero in the proportion; with numpy arrays that
        yields NaN/inf plus a RuntimeWarning -- confirm whether such windows can occur.

    Usage:
        pho_custom_decoder = curr_kdiba_pipeline.computation_results['maze1'].computed_data['pf2D_Decoder']
        enable_plots = True
        unit_specific_time_binned_outputs = spike_count_and_firing_rate_normalizations(pho_custom_decoder, enable_plots=enable_plots)
        spike_proportion_global_fr_normalized, firing_rate, firing_rate_global_fr_normalized = unit_specific_time_binned_outputs # unwrap the output tuple:
    """
    spike_counts = pho_custom_decoder.unit_specific_time_binned_spike_counts
    window_step = pho_custom_decoder.time_window_edges_binning_info.step

    # produces a fraction which indicates which proportion of the window's firing belonged to each unit (accounts for global changes in firing rate: each window is scaled by the total spikes of all cells in that window)
    unit_specific_time_binned_spike_proportion_global_fr_normalized = spike_counts / pho_custom_decoder.total_spike_counts_per_window
    if enable_plots:
        _show_time_binned_heatmap(unit_specific_time_binned_spike_proportion_global_fr_normalized, 5, 'Unit Specific Proportion of Window Spikes', 'Neuron Proportion Activity')

    # spike counts divided by the window step
    unit_specific_time_binned_firing_rate = spike_counts / window_step
    if enable_plots:
        _show_time_binned_heatmap(unit_specific_time_binned_firing_rate, 6, 'Unit Specific Binned Firing Rates', 'Neuron Firing Rate')

    # produces a unit firing rate for each window that accounts for global changes in firing rate (the per-window spike proportion divided by the window step)
    unit_specific_time_binned_firing_rate_global_fr_normalized = unit_specific_time_binned_spike_proportion_global_fr_normalized / window_step
    if enable_plots:
        _show_time_binned_heatmap(unit_specific_time_binned_firing_rate_global_fr_normalized, 7, 'Unit Specific Binned Firing Rates (Global Normalized)', 'Neuron Proportion Firing Rate')

    # Return the computed values, leaving the original data unchanged.
    return unit_specific_time_binned_spike_proportion_global_fr_normalized, unit_specific_time_binned_firing_rate, unit_specific_time_binned_firing_rate_global_fr_normalized


# Example invocation. These statements were indented into the function body after its `return`, where they could
# never execute; run them from their own cell instead:
# unit_specific_time_binned_outputs = spike_count_and_firing_rate_normalizations(pho_custom_decoder, enable_plots=enable_plots)
# spike_proportion_global_fr_normalized, firing_rate, firing_rate_global_fr_normalized = unit_specific_time_binned_outputs # unwrap the output tuple:
# active_sess.position.df

most_likely_positions: (3434, 2)


In [30]:
# from pyphoplacecellanalysis.General.Decoder.decoder_result import DecoderResultDisplayingPlot2D
# def _display_decoder_result():
#     renderer = DecoderResultDisplayingPlot2D(pho_custom_decoder, active_pos_df)
#     def animate(i):
#         # print(f'animate({i})')
#         return renderer.display(i)
#     interact(animate, i=(0, pho_custom_decoder.num_time_windows, 10))

# Show the decoder result for 'maze1' through the pipeline's display(...) entry point; the commented-out block above is the earlier manual version.
curr_kdiba_pipeline.display(DefaultDisplayFunctions._display_decoder_result, 'maze1') # works!

In [None]:
# @pn.interact(i=(0,pho_custom_decoder.num_time_windows,1,0))
# @interact(i=pn.widgets.IntSlider(start=0,end=pho_custom_decoder.num_time_windows,step=1,value=0))

# ani = FuncAnimation(renderer.fig, animate, interval=300)
# interact(animate, i=(0, pho_custom_decoder.num_time_windows, 10))

# pn.Column('**A custom interact layout**', pn.Row(layout[0], layout[1]))

In [None]:
def debug_plot_filtered_subsession_neuron_differences(sess, filtered_sess):
    """ Reports how the neurons of `filtered_sess` differ from those of the full session `sess`.

    Delegates to debug_print_subsession_neuron_differences (imported from neuropy.utils.debug_helpers at the top of the notebook).

    Inputs:
        sess: the unfiltered session; only its `.neurons` is read.
        filtered_sess: a filtered sub-session; only its `.neurons` is read.
    """
    # Compare against the session that was passed in, not an unrelated notebook global.
    debug_print_subsession_neuron_differences(sess.neurons, filtered_sess.neurons)

# `filtered_sessions` is indexed by name elsewhere in this notebook (e.g. ['maze1']), so iterate its values (the sessions) rather than its keys.
for a_filtered_sess in curr_kdiba_pipeline.filtered_sessions.values():
    debug_plot_filtered_subsession_neuron_differences(curr_kdiba_pipeline.sess, a_filtered_sess)

In [None]:
# Inspect the computation result stored under 'maze2' (shown as the cell's output).
curr_kdiba_pipeline.computation_results['maze2']

In [None]:
# Display the computation results for both named results, 'maze1' and 'maze2'.
# NOTE(review): `_display_result` is not defined in this part of the notebook -- confirm an earlier cell defines it.
_display_result(curr_kdiba_pipeline.computation_results['maze1'])
_display_result(curr_kdiba_pipeline.computation_results['maze2'])

In [None]:
import matplotlib.pyplot as plt
from neuropy.utils.misc import is_iterable
from neuropy.plotting.figure import pretty_plot
from scipy.ndimage import gaussian_filter, gaussian_filter1d, interpolation

from pyphoplacecellanalysis.Analysis.reliability import compute_lap_to_lap_reliability

from PhoPositionalData.plotting.laps import plot_laps_2d


def _test_plotRaw_v_time(active_pf, cellind, speed_thresh=False, alpha=0.5, ax=None):
    """ Plots each position dimension against time (one subplot per dimension) and overlays one cell's spikes.

    Works with both 1D and 2D placefields.

    Inputs:
        active_pf: placefield object; reads .ndim, .t, .x (and .y when ndim >= 2), .cell_ids, .speed_thresh, and either
            (.run_spk_pos, .run_spk_t) or (.spk_pos, .spk_t).
        cellind: index of the cell whose spikes are overlaid.
        speed_thresh: when truthy, the run_* spike arrays are used instead of the unfiltered ones.
        alpha: alpha channel of the spike markers.
        ax: existing axes (a single axis or an iterable of them); a new figure is created when None.

    Returns:
        the axes that were drawn on (always an iterable).
    """
    if ax is None:
        fig, ax = plt.subplots(active_pf.ndim, 1, sharex=True)
        fig.set_size_inches([23, 9.7])

    # a single axis is wrapped so the loops below can always zip over `ax`
    if not is_iterable(ax):
        ax = [ax]

    # (trajectory, y-axis label) for every position dimension that is present
    position_traces = [(active_pf.x, "X position (cm)")]
    if not (active_pf.ndim < 2):
        position_traces.append((active_pf.y, "Y position (cm)"))

    for curr_ax, (trajectory, axis_label) in zip(ax, position_traces):
        curr_ax.plot(active_pf.t, trajectory)
        curr_ax.set_xlabel("Time (seconds)")
        curr_ax.set_ylabel(axis_label)
        pretty_plot(curr_ax)

    # choose between the speed-thresholded and the unfiltered spike times/positions
    if speed_thresh:
        spike_positions, spike_times = active_pf.run_spk_pos, active_pf.run_spk_t
    else:
        spike_positions, spike_times = active_pf.spk_pos, active_pf.spk_t

    # overlay the selected cell's spikes on each trajectory
    spike_color = [0, 0, 0.8, alpha]
    for curr_ax, spike_coordinate in zip(ax, spike_positions[cellind]):
        curr_ax.plot(spike_times[cellind], spike_coordinate, ".", color=spike_color)

    # the title reports the cell id and the placefield's own speed threshold
    ax[0].set_title(f"Cell {active_pf.cell_ids[cellind]!s}:, speed_thresh={active_pf.speed_thresh!s}")
    return ax


def compute_reliability_metrics(out_indicies, out_digitized_position_bins, out_within_lap_spikes_overlap, debug_print=False, plot_results=False, lap_ids=None, axs=None, plot_horizontal=False):
    """ Takes input from compute_lap_to_lap_reliability(...) to build the actual reliability metrics

    Inputs:
        out_indicies: not used here; accepted so the tuple returned by compute_lap_to_lap_reliability(...) can be passed straight through.
        out_digitized_position_bins: position-bin coordinates; only used as the plotting axis when plot_results is True.
        out_within_lap_spikes_overlap: 2D array whose columns (axis=1) are treated as consecutive laps
            (presumably shaped (n_position_bins, n_laps) -- TODO confirm against compute_lap_to_lap_reliability).
        debug_print: when True, prints the maximum lap-to-lap difference.
        plot_results: when True, draws the results (see `axs`) and a new two-row summary figure.
        lap_ids: optional lap identifiers; only their count is used. Defaults to one entry per column of the overlap array.
        axs: optional existing per-lap axes to overlay the lap-to-lap differences on (as twin axes). When None the overlay is skipped.
        plot_horizontal: orientation of that overlay (True: twiny with values on x; False: twinx with values on y).

    Returns:
        (out_pairwise_pair_results, cum_laps_reliability, all_laps_reliability)
            out_pairwise_pair_results: same shape as the overlap; column 0 is zero, column k (k>=1) is overlap[:, k] - overlap[:, k-1].
            cum_laps_reliability: cumulative product of the overlap along the lap axis.
            all_laps_reliability: product over all laps, kept 2D with shape (n_rows, 1).
    """
    overlap = np.asarray(out_within_lap_spikes_overlap)

    # Lap-to-lap change of the overlap; stored at the index of the later lap (backwards-style difference), so column 0 stays zero.
    out_pairwise_pair_results = np.zeros_like(overlap)
    out_pairwise_pair_results[:, 1:] = np.diff(overlap, axis=1)

    if debug_print:
        print(f'max out: {np.max(out_pairwise_pair_results)}')

    # These were previously read from notebook globals; they are now explicit, optional parameters.
    if lap_ids is None:
        lap_ids = np.arange(overlap.shape[1])
    flat_lap_idxs = np.arange(len(lap_ids))

    # add to the extant plot as a new color (only possible when the caller supplies the existing axes):
    if plot_results and (axs is not None):
        for lap_idx in flat_lap_idxs:
            if plot_horizontal:
                curr_lap_alt_ax = axs[lap_idx].twiny()
                curr_lap_alt_ax.plot(out_pairwise_pair_results[:, lap_idx], out_digitized_position_bins, '--r')
            else:
                # vertical
                curr_lap_alt_ax = axs[lap_idx].twinx()
                curr_lap_alt_ax.plot(out_digitized_position_bins, out_pairwise_pair_results[:, lap_idx], '--r')

    cum_laps_reliability = np.cumprod(overlap, axis=1)
    all_laps_reliability = np.prod(overlap, axis=1, keepdims=True)

    if plot_results:
        fig_result, axs_result = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(24, 40))
        axs_result[0].plot(out_digitized_position_bins, all_laps_reliability, 'r')
        axs_result[1].plot(out_digitized_position_bins, cum_laps_reliability, 'r')

    return out_pairwise_pair_results, cum_laps_reliability, all_laps_reliability

curr_result_label = 'maze1'
# The whole (unfiltered) session is what the following cells operate on. The filtered alternative is kept here,
# disabled, because it was assigned and then immediately overwritten (a dead store):
# sess = curr_kdiba_pipeline.filtered_sessions[curr_result_label]
# NOTE(review): computed data is still looked up by `curr_result_label` further down -- confirm that mixing the
#   unfiltered session with the 'maze1' results is intended.
sess = curr_kdiba_pipeline.sess

In [None]:
# Convert the session's laps with to_dataframe() and show the result as the cell output.
curr_laps_df = sess.laps.to_dataframe()
curr_laps_df

In [None]:
# NOTE(review): `pos_df` is assigned three times below; only the last value (position_obj.to_dataframe()) is displayed.
#   The earlier calls are presumably kept for their effect on the session/position objects -- confirm.
pos_df = sess.compute_position_laps() # ensures the laps are computed if they need to be:
position_obj = sess.position
position_obj.compute_higher_order_derivatives()
pos_df = position_obj.compute_smoothed_position_info(N=20) ## Smooth the velocity curve to apply meaningful logic to it
pos_df = position_obj.to_dataframe()
pos_df

In [None]:
# Plot the session's laps in 2D, then compute the lap-to-lap reliability for one cell.
# fig, out_axes_list = plot_laps_2d(sess, legacy_plotting_mode=True)
fig, out_axes_list = plot_laps_2d(sess, legacy_plotting_mode=False)
out_axes_list[0].set_title('New Pho Position Thresholding Estimated Laps')

# Index of the cell to analyse (an index into neuron_ids, not the neuron ID itself).
curr_cell_idx = 2 
# curr_cell_idx = 3 # good for end platform analysis
curr_cell_ID = sess.spikes_df.spikes.neuron_ids[curr_cell_idx]
print(f'curr_cell_idx: {curr_cell_idx}, curr_cell_ID: {curr_cell_ID}')

# pre-filter by spikes that occur in one of the included laps for the filtered_spikes_df
# NOTE(review): no filtering is applied in this cell -- `filtered_spikes_df` is an unmodified copy of sess.spikes_df.
filtered_spikes_df = sess.spikes_df.copy()
time_variable_name = filtered_spikes_df.spikes.time_variable_name # 't_rel_seconds'

lap_ids = sess.laps.lap_id
# lap_flat_idxs = sess.laps.get_lap_flat_indicies(lap_ids)

# Uses the 'pf2D' placefields stored under `curr_result_label`; the trailing semicolon suppresses the cell's output repr.
out_indicies, out_digitized_position_bins, out_within_lap_spikes_overlap = compute_lap_to_lap_reliability(curr_kdiba_pipeline.computation_results[curr_result_label].computed_data['pf2D'], filtered_spikes_df, lap_ids, curr_cell_idx, debug_print=False, plot_results=True);

# compute_reliability_metrics(out_indicies, out_digitized_position_bins, out_within_lap_spikes_overlap, debug_print=False, plot_results=False)

# # curr_kdiba_pipeline.computation_results['maze1'].computed_data['pf2D'].plotRaw_v_time(curr_cell_idx)
# _test_plotRaw_v_time(curr_kdiba_pipeline.computation_results[curr_result_label].computed_data['pf2D'], curr_cell_idx)

# 3D Lap Plotting Experimentation

In [None]:
from itertools import islice # for Pagination class
import pyvista as pv
import pyvistaqt as pvqt
from PhoGui.InteractivePlotter.LapsVisualizationMixin import LapsVisualizationMixin
from PhoGui.PhoCustomVtkWidgets import PhoWidgetHelper
from PhoPositionalData.plotting.spikeAndPositions import perform_plot_flat_arena, _build_flat_arena_data

""" Test Drawing Spike Lines """
# from pyphoplacecellanalysis.Pho3D.spikes import draw_line_spike, lines_from_points
from pyphoplacecellanalysis.Pho3D.points import interlieve_points

from PhoPositionalData.plotting.spikeAndPositions import build_active_spikes_plot_pointdata_df

def _plot_all_lap_spikes(p, sess, included_cell_IDXs, included_lap_IDXs, debug_print=True, lap_start_z=0.0, lap_id_dependent_z_offset=10.0):
    """ Adds the spikes of the selected cells, restricted to the selected laps, to the plotter `p` as white points.

    Every spike gets a 'z' value of `lap_start_z + lap_id_dependent_z_offset * lap`, so each lap is drawn at its own height.

    Inputs:
        p: plotter that is indexed as p[0,0]; that sub-plotter must provide add_mesh(...)
            (presumably a pyvista/pyvistaqt multi-plotter -- TODO confirm).
        sess: session providing .spikes_df (with the `.spikes` accessor and the 'lap', 'aclu', 'flat_spike_idx',
            'cell_type', 'x', 'y', 'lin_pos' columns) and .laps.lap_id.
        included_cell_IDXs: indices into sess.spikes_df.spikes.neuron_ids selecting the cells to plot.
        included_lap_IDXs: indices into sess.laps.lap_id selecting the laps to plot.
        debug_print: when True, prints the lap ids that are considered and any cell that is skipped.
        lap_start_z: z value of lap 0.
        lap_id_dependent_z_offset: z spacing between consecutive lap ids.

    Returns:
        dict mapping each plotted cell ID to {'pdata': <point data passed to add_mesh>}.
        Cells that have no spikes within the included laps are skipped instead of raising a KeyError.

    Side effects:
        Writes a 'z' column into sess.spikes_df (in place) and adds one mesh per plotted cell to p[0,0].
    """
    should_reinterpolate_spike_positions = False # hard-coded off; set True to recompute the spike x/y from the position trace

    def _plot_single_spikes(p, curr_cell_spike_df, placefield_cell_index):
        """ Builds the point data for one cell's spikes and adds it to p[0,0] under an index-specific mesh name. """
        pdata = build_active_spikes_plot_pointdata_df(curr_cell_spike_df)
        p[0,0].add_mesh(pdata, name=f'plot_single_spikes_points[{placefield_cell_index}]', render_points_as_spheres=True, point_size=5.0, color='white')
        return {'pdata':pdata}

    time_variable_name = sess.spikes_df.spikes.time_variable_name # 't_rel_seconds'
    # sets the 'z' value for all spikes in sess.spikes_df, which will be used to plot them as points
    sess.spikes_df['z'] = lap_start_z + (lap_id_dependent_z_offset * sess.spikes_df.lap.to_numpy())

    if debug_print:
        print(f'sess.laps.lap_id: {sess.laps.lap_id}')

    included_cell_IDXs = np.array(included_cell_IDXs)
    included_lap_IDXs = np.array(included_lap_IDXs)

    # ensure that only lap_ids included in this session are used:
    included_lap_ids = sess.laps.lap_id[included_lap_IDXs]
    possible_included_lap_ids = np.unique(sess.spikes_df.lap.values)
    if debug_print:
        print(f'np.unique(sess.spikes_df.lap.values): {possible_included_lap_ids}')
    included_lap_ids = included_lap_ids[np.isin(included_lap_ids, possible_included_lap_ids)]
    if debug_print:
        print(f'included_lap_ids: {included_lap_ids}')

    # get the included cell IDs
    included_cell_IDs = np.array(sess.spikes_df.spikes.neuron_ids)[included_cell_IDXs]

    # POSITIONS: only curr_position_df is used below (and only when re-interpolating); the call itself is kept unconditional as before.
    curr_position_df, _lap_specific_position_dfs, _lap_specific_time_ranges, _lap_specific_position_traces = LapsVisualizationMixin._compute_laps_position_data(sess)

    # SPIKES: keep only the spikes that occur in one of the included laps
    filtered_spikes_df = sess.spikes_df.copy()
    filtered_spikes_df = filtered_spikes_df[np.isin(filtered_spikes_df['lap'], included_lap_ids)]

    # Interpolate the spikes positions again:
    if should_reinterpolate_spike_positions:
        print('Re-interpolating spike positions...')
        filtered_spikes_df = filtered_spikes_df.spikes.interpolate_spike_positions(curr_position_df['t'].to_numpy(), curr_position_df['x'].to_numpy(), curr_position_df['y'].to_numpy())

    # grouped by cell:
    cell_grouped_spikes_df = filtered_spikes_df.groupby('aclu')
    spike_columns = [time_variable_name,'aclu','lap','flat_spike_idx','cell_type','x','y','lin_pos','z']

    plot_data_by_cell_ID = {}
    for i, curr_cell_ID in enumerate(included_cell_IDs):
        # A cell may have no spikes inside the selected laps (likely when few laps are selected); get_group would raise KeyError for it.
        if curr_cell_ID not in cell_grouped_spikes_df.groups:
            if debug_print:
                print(f'skipping cell {curr_cell_ID}: no spikes within the included laps')
            continue
        curr_cell_spike_df = cell_grouped_spikes_df.get_group(curr_cell_ID)[spike_columns]
        # `i` is the position within included_cell_IDs, so mesh names stay stable even when a cell is skipped
        plot_data_by_cell_ID[curr_cell_ID] = _plot_single_spikes(p, curr_cell_spike_df, i)

    return plot_data_by_cell_ID


# sess = curr_kdiba_pipeline.filtered_sessions['maze']
# included_cell_IDXs = [6]
# included_lap_IDXs = [2, 3]
# _plot_all_lap_spikes(pActiveInteractiveLapsPlotter, sess, included_cell_IDXs, included_lap_IDXs)

from PhoGui.InteractivePlotter.InteractiveCustomDataExplorer import InteractiveCustomDataExplorer

# Module-level z-layout settings; _display_testing(...) below reads these as globals and forwards them to the 3D lap plot and the spike plot.
lap_start_z = 0.0
# lap_id_dependent_z_offset = 0.0
lap_id_dependent_z_offset = 3.0
# curr_kdiba_pipeline.active_configs['maze1'].lap_id_dependent_z_offset = 3.0

def _display_testing(sess, computation_result, active_config, extant_plotter=None):
    """ Testing of plot_lap_trajectories_3d: builds the interactive explorer, draws the selected lap
    trajectories in 3D, and then adds the spikes of the selected cells on top.

    Inputs:
        sess: the session to plot.
        computation_result: accepted but not used by this function.
        active_config: config whose plotting_config.plotter_type is set to 'MultiPlotter' (mutated in place).
        extant_plotter: optional existing plotter to reuse.

    Reads the module-level `lap_start_z` and `lap_id_dependent_z_offset`.

    Returns:
        (explorer, plotter)
    """
    print(f'active_config.plotting_config: {active_config.plotting_config}')
    single_combined_plot = True
    default_plotting = True if single_combined_plot else False
    active_config.plotting_config.plotter_type = 'MultiPlotter'
    print(f'active_config.plotting_config: {active_config.plotting_config}')

    explorer = InteractiveCustomDataExplorer(active_config, sess, extant_plotter=extant_plotter)
    plotter = explorer.plot(pActivePlotter=extant_plotter, default_plotting=default_plotting)

    # hard-coded selection for this test (indices, not IDs)
    cell_idxs = [0, 1]
    lap_idxs = [2]
    # the same z-layout is handed to both the trajectories and the spikes so they line up
    z_layout = dict(lap_start_z=lap_start_z, lap_id_dependent_z_offset=lap_id_dependent_z_offset)

    from PhoPositionalData.plotting.laps import plot_lap_trajectories_3d
    plotter, _laps_pages = plot_lap_trajectories_3d(
        sess,
        curr_num_subplots=5,
        active_page_index=0,
        included_lap_idxs=lap_idxs,
        single_combined_plot=single_combined_plot,
        existing_plotter=plotter,
        **z_layout,
    )

    # add the spikes for the curves:
    _plot_all_lap_spikes(plotter, sess, cell_idxs, lap_idxs, **z_layout)
    return explorer, plotter


# curr_kdiba_pipeline.computation_results['maze1'].computation_config
# curr_kdiba_pipeline.computation_results['maze1'].sess.config
# curr_kdiba_pipeline.active_configs['maze1']

curr_result_label = 'maze1'
# The whole (unfiltered) session is what gets plotted. The filtered alternative is kept here, disabled, because it
# was assigned and then immediately overwritten (a dead store):
# sess = curr_kdiba_pipeline.filtered_sessions[curr_result_label]
sess = curr_kdiba_pipeline.sess

# Always start from a fresh plotter. The `try: ... except NameError` existence check that used to follow could never
# trigger because the name is assigned right here; to reuse a plotter from a previous run, drop this assignment and
# guard with such a check instead.
pActiveInteractiveLapsPlotter = None
iplapsDataExplorer, pActiveInteractiveLapsPlotter = _display_testing(sess, curr_kdiba_pipeline.computation_results[curr_result_label], curr_kdiba_pipeline.active_configs[curr_result_label],
                                                                     extant_plotter=pActiveInteractiveLapsPlotter)
pActiveInteractiveLapsPlotter.show()

In [None]:
# lines_from_points(pdata.points)


# NOTE(review): `all_points` is only assigned inside commented-out code in _plot_all_lap_spikes in this part of the
#   notebook, so this cell presumably fails on a fresh kernel -- confirm it is defined elsewhere or remove the cell.
np.shape(all_points)

In [None]:
# Plot the lap trajectories of the whole (unfiltered) session in 3D, with single_combined_plot=False, then show the plotter.
# NOTE(review): `_plot_lap_trajectories_combined_plot_3d` is not defined in this part of the notebook (the imported
#   helper used above is `plot_lap_trajectories_3d`) -- confirm which function is meant.
p, laps_pages = _plot_lap_trajectories_combined_plot_3d(curr_kdiba_pipeline.sess, curr_num_subplots=5, single_combined_plot=False)
p.show()