In [None]:
%config IPCompleter.use_jedi = False
%pdb off
# %load_ext viztracer
# from viztracer import VizTracer
%load_ext autoreload
%autoreload 3
import sys
from pathlib import Path
from benedict import benedict

# required to enable non-blocking interaction:
%gui qt5

from copy import deepcopy
from numba import jit
import numpy as np
import pandas as pd
# from benedict import benedict # https://github.com/fabiocaccamo/python-benedict#usage

# Pho's Formatting Preferences
# from pyphocorehelpers.preferences_helpers import set_pho_preferences, set_pho_preferences_concise, set_pho_preferences_verbose
# set_pho_preferences_concise()

## Pho's Custom Libraries:
from pyphocorehelpers.general_helpers import CodeConversion
from pyphocorehelpers.function_helpers import function_attributes
from pyphocorehelpers.print_helpers import print_keys_if_possible, print_value_overview_only, document_active_variables, objsize, print_object_memory_usage, debug_dump_object_member_shapes, TypePrintMode
from pyphocorehelpers.print_helpers import get_now_day_str, get_now_time_str, get_now_time_precise_str
from pyphocorehelpers.Filesystem.path_helpers import find_first_extant_path

# NeuroPy (Diba Lab Python Repo) Loading
# from neuropy import core
from neuropy.analyses.placefields import PlacefieldComputationParameters
from neuropy.core.epoch import NamedTimerange
from neuropy.core.ratemap import Ratemap
from neuropy.core.session.Formats.BaseDataSessionFormats import DataSessionFormatRegistryHolder
from neuropy.core.session.Formats.Specific.BapunDataSessionFormat import BapunDataSessionFormatRegisteredClass
from neuropy.core.session.Formats.Specific.KDibaOldDataSessionFormat import KDibaOldDataSessionFormatRegisteredClass
from neuropy.core.session.Formats.Specific.RachelDataSessionFormat import RachelDataSessionFormat
from neuropy.core.session.Formats.Specific.HiroDataSessionFormat import HiroDataSessionFormatRegisteredClass

## For computation parameters:
from neuropy.analyses.placefields import PlacefieldComputationParameters
from neuropy.utils.dynamic_container import DynamicContainer
from neuropy.utils.result_context import IdentifyingContext
from neuropy.core.session.Formats.BaseDataSessionFormats import find_local_session_paths

# pyPhoPlaceCellAnalysis:
from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import NeuropyPipeline # get_neuron_identities
from pyphoplacecellanalysis.General.Mixins.ExportHelpers import export_pyqtgraph_plot
from pyphoplacecellanalysis.General.Batch.NonInteractiveWrapper import batch_load_session, batch_extended_computations, SessionBatchProgress, batch_programmatic_figures, batch_extended_programmatic_figures
from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import PipelineSavingScheme
from pyphoplacecellanalysis.Pho2D.matplotlib.visualize_heatmap import visualize_heatmap

# from PendingNotebookCode import _perform_batch_plot, _build_batch_plot_kwargs

# Batch bookkeeping dictionaries. `session_batch_status` is keyed by session base directory and filled in by the "Load Data" cell below.
session_batch_status = {}
session_batch_errors = {}
# NOTE(review): this flag is not read by any cell visible in this part of the notebook — confirm it is still needed.
enable_saving_to_disk = False

# Candidate data roots for the different machines this notebook runs on (Windows drive, Linux mount, macOS volume).
# Presumably `find_first_extant_path` returns the first candidate that exists — verify against its implementation.
global_data_root_parent_path = find_first_extant_path([Path(r'W:\Data'), Path(r'/media/MAX/Data'), Path(r'/Volumes/MoverNew/data')])
assert global_data_root_parent_path.exists(), f"global_data_root_parent_path: {global_data_root_parent_path} does not exist! Is the right computer's config commented out above?"

# Load Pipeline

In [None]:
# ==================================================================================================================== #
# Load Data                                                                                                            #
# ==================================================================================================================== #

# Session format identifier; used to build the root `IdentifyingContext` here and passed to `batch_load_session` when the pipeline is loaded.
active_data_mode_name = 'kdiba'

## Data must be pre-processed using the MATLAB script located here:
#     neuropy/data_session_pre_processing_scripts/KDIBA/IIDataMat_Export_ToPython_2022_08_01.m
# From pre-computed .mat files:

# Root context/path shared by every KDIBA session; animal- and experiment-specific parts are added below.
local_session_root_parent_context = IdentifyingContext(format_name=active_data_mode_name) # , animal_name='', configuration_name='one', session_name=self.session_name
local_session_root_parent_path = global_data_root_parent_path.joinpath('KDIBA')

# Animal `gor01`, experiment `one` (inactive alternative — uncomment to use instead of the active selection below):
# local_session_parent_context = local_session_root_parent_context.adding_context(collision_prefix='animal', animal='gor01', exper_name='one') # IdentifyingContext<('kdiba', 'gor01', 'one')>
# local_session_parent_path = local_session_root_parent_path.joinpath(local_session_parent_context.animal, local_session_parent_context.exper_name) # 'gor01', 'one'
# local_session_paths_list, local_session_names_list =  find_local_session_paths(local_session_parent_path, blacklist=['PhoHelpers', 'Spike3D-Minimal-Test', 'Unused'])

# Animal `gor01`, experiment `two` (ACTIVE selection): the parent path is <KDIBA root>/<animal>/<exper_name>.
local_session_parent_context = local_session_root_parent_context.adding_context(collision_prefix='animal', animal='gor01', exper_name='two')
local_session_parent_path = local_session_root_parent_path.joinpath(local_session_parent_context.animal, local_session_parent_context.exper_name)
# Discover the session folders under the parent path (nothing is blacklisted for this animal/experiment).
local_session_paths_list, local_session_names_list =  find_local_session_paths(local_session_parent_path, blacklist=[])

# ## Animal `vvp01`:
# local_session_parent_context = local_session_root_parent_context.adding_context(collision_prefix='animal', animal='vvp01', exper_name='one')
# local_session_parent_path = local_session_root_parent_path.joinpath(local_session_parent_context.animal, local_session_parent_context.exper_name)
# local_session_paths_list, local_session_names_list =  find_local_session_paths(local_session_parent_path, blacklist=[])

# local_session_parent_context = local_session_root_parent_context.adding_context(collision_prefix='animal', animal='vvp01', exper_name='two')
# local_session_parent_path = local_session_root_parent_path.joinpath(local_session_parent_context.animal, local_session_parent_context.exper_name)
# local_session_paths_list, local_session_names_list =  find_local_session_paths(local_session_parent_path, blacklist=[])

# ### Animal `pin01`:
# local_session_parent_context = local_session_root_parent_context.adding_context(collision_prefix='animal', animal='pin01', exper_name='one')
# local_session_parent_path = local_session_root_parent_path.joinpath(local_session_parent_context.animal, local_session_parent_context.exper_name) # no exper_name ('one' or 'two') folders for this animal.
# local_session_paths_list, local_session_names_list =  find_local_session_paths(local_session_parent_path, blacklist=['redundant','showclus','sleep','tmaze'])

## One session-level context per discovered session name, derived from the animal/experiment parent context:
local_session_contexts_list = [
    local_session_parent_context.adding_context(collision_prefix='sess', session_name=a_session_name)
    for a_session_name in local_session_names_list
]

## Seed `session_batch_status`: sessions that already carry a status keep it, all others start as NOT_STARTED.
for curr_session_basedir in local_session_paths_list:
    curr_session_status = session_batch_status.get(curr_session_basedir)
    if curr_session_status is not None:
        continue # a status was already recorded for this session; leave it untouched
    session_batch_status[curr_session_basedir] = SessionBatchProgress.NOT_STARTED
    # session_batch_status[curr_session_basedir] = SessionBatchProgress.COMPLETED # alternative default status

session_batch_status

In [None]:
%pdb off
# %%viztracer
# Select the session to load: the first entry of the discovered session paths.
basedir = local_session_paths_list[0] # NOT 3
print(f'basedir: {str(basedir)}')

## Read if possible: do not save anything and do not force a reload.
saving_mode = PipelineSavingScheme.SKIP_SAVING
force_reload = False

# # Force write (alternative settings: overwrite the saved pipeline and force a reload):
# saving_mode = PipelineSavingScheme.OVERWRITE_IN_PLACE
# force_reload = True

# ==================================================================================================================== #
# Load Pipeline                                                                                                        #
# ==================================================================================================================== #
# `None` means no epoch-name whitelist is passed to `batch_load_session`; the commented line shows an explicit alternative.
# epoch_name_whitelist = ['maze']
epoch_name_whitelist = None
# Names of the computation functions passed to `batch_load_session` as its `computation_functions_name_whitelist`.
# NOTE: every entry needs its own trailing comma. Python silently concatenates adjacent string literals, which previously
# fused the sequential-surprise and two-step-decoding names into a single, non-existent function name.
active_computation_functions_name_whitelist = [
    '_perform_baseline_placefield_computation',
    '_perform_time_dependent_placefield_computation',
    '_perform_extended_statistics_computation',
    '_perform_position_decoding_computation',
    '_perform_firing_rate_trends_computation',
    # '_perform_pf_find_ratemap_peaks_computation',
    '_perform_time_dependent_pf_sequential_surprise_computation',
    '_perform_two_step_position_decoding_computation',
    # '_perform_recursive_latent_placefield_decoding'
]
# Load (or compute) the pipeline for the selected session using the settings chosen above.
# Extended batch computations are skipped here, and `fail_on_exception=True` makes any computation error raise immediately.
curr_active_pipeline = batch_load_session(global_data_root_parent_path, active_data_mode_name, basedir, epoch_name_whitelist=epoch_name_whitelist,
                                          computation_functions_name_whitelist=active_computation_functions_name_whitelist,
                                          saving_mode=saving_mode, force_reload=force_reload,
                                          skip_extended_batch_computations=True, debug_print=False, fail_on_exception=True)

In [None]:
from neuropy.analyses.placefields import PfND
from neuropy.analyses.laps import estimate_session_laps

## 2023-04-07 - Builds the laps using estimate_session_laps(...) for each epoch, and then rebuilds each epoch's placefields with the laps as their epochs so the occupancy is correct.
def constrain_to_laps(curr_active_pipeline):
	""" 2023-04-07 - Constrains the placefields to just the laps, computing the laps if needed.

	For each of the long, short, and global epochs (as named by `find_LongShortGlobal_epoch_names()`):
		1. the laps are estimated with `estimate_session_laps`,
		2. they are converted to a non-overlapping epoch object and filtered by duration (1.0 to 10.0),
		3. that epoch object is stored as the epoch's `computation_config.pf_params.computation_epochs`,
		4. the epoch's `pf1D` and `pf2D` results are replaced by new `PfND` objects computed on those laps.

	Args:
		curr_active_pipeline: the loaded pipeline; its `active_configs` and per-epoch computed results are modified in place.

	Returns:
		The same `curr_active_pipeline` object, returned only after ALL three epochs have been processed.
		(Previously the `return` sat inside the loop, so only the first (long) epoch was ever constrained.)

	Other laps-related things?
		# ??? pos_df = sess.compute_position_laps() # ensures the laps are computed if they need to be:
		# DataSession.compute_position_laps(self)
		# DataSession.compute_laps_position_df(position_df, laps_df)
	"""
	def _rebuild_pf_on_epochs(a_pf, epochs_obj):
		""" Builds a new PfND from a deep copy of `a_pf`'s spikes/position/config, restricted to a copy of `epochs_obj`. """
		a_pf_copy = deepcopy(a_pf) # copy first so the new PfND is built from copied inputs rather than the old object's own attributes
		return PfND(spikes_df=a_pf_copy.spikes_df, position=a_pf_copy.position, epochs=deepcopy(epochs_obj), config=a_pf_copy.config, compute_on_init=True)

	long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
	epoch_names = (long_epoch_name, short_epoch_name, global_epoch_name)
	# Resolve every session/result up front so a missing epoch fails before anything is modified:
	epoch_sessions = [curr_active_pipeline.filtered_sessions[an_epoch_name] for an_epoch_name in epoch_names]
	epoch_results = [curr_active_pipeline.computation_results[an_epoch_name]['computed_data'] for an_epoch_name in epoch_names]

	for a_name, a_sess, a_result in zip(epoch_names, epoch_sessions, epoch_results):
		a_sess = estimate_session_laps(a_sess, should_plot_laps_2d=False)
		curr_laps_obj = a_sess.laps.as_epoch_obj() # set this to the laps object
		curr_laps_obj = curr_laps_obj.get_non_overlapping()
		curr_laps_obj = curr_laps_obj.filtered_by_duration(1.0, 10.0) # the lap must be at least 1 second long and at most 10 seconds long
		# curr_laps_obj = a_sess.estimate_laps().as_epoch_obj()
		curr_active_pipeline.active_configs[a_name].computation_config.pf_params.computation_epochs = curr_laps_obj

		curr_pf1D, curr_pf2D = a_result.pf1D, a_result.pf2D
		lap_filtered_curr_pf1D = _rebuild_pf_on_epochs(curr_pf1D, curr_laps_obj)
		lap_filtered_curr_pf2D = _rebuild_pf_on_epochs(curr_pf2D, curr_laps_obj)
		a_result.pf1D = lap_filtered_curr_pf1D
		a_result.pf2D = lap_filtered_curr_pf2D

	# Return only once every epoch has been handled.
	return curr_active_pipeline

# Constrain the placefields to laps; `constrain_to_laps` returns the pipeline object it was given, so the name is simply rebound.
curr_active_pipeline = constrain_to_laps(curr_active_pipeline)

In [None]:
# Resolve the long / short / global epoch names, then pull each epoch's filtered session, computed results,
# 1D placefields, and (for long and short only) the 1D one-step decoder if one has been computed.
long_epoch_name, short_epoch_name, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()

long_session = curr_active_pipeline.filtered_sessions[long_epoch_name]
short_session = curr_active_pipeline.filtered_sessions[short_epoch_name]
global_session = curr_active_pipeline.filtered_sessions[global_epoch_name]

long_results = curr_active_pipeline.computation_results[long_epoch_name]['computed_data']
short_results = curr_active_pipeline.computation_results[short_epoch_name]['computed_data']
global_results = curr_active_pipeline.computation_results[global_epoch_name]['computed_data']

long_pf1D = long_results.pf1D
short_pf1D = short_results.pf1D
global_pf1D = global_results.pf1D

# `None` when the decoder has not been computed yet:
long_one_step_decoder_1D = long_results.get('pf1D_Decoder', None)
short_one_step_decoder_1D = short_results.get('pf1D_Decoder', None)

recalculate_anyway = True

In [None]:
# (Re)compute the 1D position decoding for the long and short epochs only.
curr_active_pipeline.perform_specific_computation(computation_functions_name_whitelist=['_perform_position_decoding_computation'], computation_kwargs_list=[dict(ndim=1)], enabled_filter_names=[long_epoch_name, short_epoch_name], fail_on_exception=True, debug_print=True)
long_pf1D, short_pf1D, global_pf1D = long_results.pf1D, short_results.pf1D, global_results.pf1D
# long_one_step_decoder_1D, short_one_step_decoder_1D  = [results_data.get('pf1D_Decoder', None) for results_data in (long_results, short_results)]
# Work on deep copies so that conforming the short decoder below does not modify the decoders stored in the pipeline results.
long_one_step_decoder_1D, short_one_step_decoder_1D  = [deepcopy(results_data.get('pf1D_Decoder', None)) for results_data in (long_results, short_results)]
# Make the 1D Decoders conform between the long and the short epochs (the short decoder is conformed to the long decoder's position bins):
short_one_step_decoder_1D, did_recompute = short_one_step_decoder_1D.conform_to_position_bins(long_one_step_decoder_1D, force_recompute=True)

# ## Build or get the two-step decoders for both the long and short:
# long_two_step_decoder_1D, short_two_step_decoder_1D  = [results_data.get('pf1D_TwoStepDecoder', None) for results_data in (long_results, short_results)]
# if recalculate_anyway or did_recompute or (long_two_step_decoder_1D is None) or (short_two_step_decoder_1D is None):
#     curr_active_pipeline.perform_specific_computation(computation_functions_name_whitelist=['_perform_two_step_position_decoding_computation'], computation_kwargs_list=[dict(ndim=1)], enabled_filter_names=[long_epoch_name, short_epoch_name], fail_on_exception=True, debug_print=True)
#     long_two_step_decoder_1D, short_two_step_decoder_1D  = [results_data.get('pf1D_TwoStepDecoder', None) for results_data in (long_results, short_results)]
#     assert (long_two_step_decoder_1D is not None and short_two_step_decoder_1D is not None)

# Reuse the long decoder's own time bin size for the rest of the analysis.
decoding_time_bin_size = long_one_step_decoder_1D.time_bin_size # 1.0/30.0 # 0.03333333333333333
# decoding_time_bin_size = 0.03 # 0.03333333333333333
print(f'decoding_time_bin_size: {decoding_time_bin_size}')

#### Get 2D Decoders for validation and comparisons:

In [None]:
# (Re)compute the 2D position decoding for the long and short epochs only.
curr_active_pipeline.perform_specific_computation(computation_functions_name_whitelist=['_perform_position_decoding_computation'], computation_kwargs_list=[dict(ndim=2)], enabled_filter_names=[long_epoch_name, short_epoch_name], fail_on_exception=True, debug_print=True)
# Make the 2D Placefields and Decoders conform between the long and the short epochs:
long_pf2D, short_pf2D, global_pf2D = long_results.pf2D, short_results.pf2D, global_results.pf2D
# NOTE(review): unlike the 1D cell, these decoders are NOT deep-copied, so conforming may affect the objects stored in the pipeline results — confirm this is intended.
long_one_step_decoder_2D, short_one_step_decoder_2D  = [results_data.get('pf2D_Decoder', None) for results_data in (long_results, short_results)]
short_one_step_decoder_2D, did_recompute = short_one_step_decoder_2D.conform_to_position_bins(long_one_step_decoder_2D)

# ## Build or get the two-step decoders for both the long and short:
# long_two_step_decoder_2D, short_two_step_decoder_2D  = [results_data.get('pf2D_TwoStepDecoder', None) for results_data in (long_results, short_results)]
# if recalculate_anyway or did_recompute or (long_two_step_decoder_2D is None) or (short_two_step_decoder_2D is None):
#     curr_active_pipeline.perform_specific_computation(computation_functions_name_whitelist=['_perform_two_step_position_decoding_computation'], computation_kwargs_list=[dict(ndim=2)], enabled_filter_names=[long_epoch_name, short_epoch_name], fail_on_exception=True, debug_print=True)
#     long_two_step_decoder_2D, short_two_step_decoder_2D  = [results_data.get('pf2D_TwoStepDecoder', None) for results_data in (long_results, short_results)]
#     assert (long_two_step_decoder_2D is not None and short_two_step_decoder_2D is not None)
# Sanity check: the total posterior mass of the 2D decoder's x-marginal should be similar to that of the 1D decoder:
print(f'{np.sum(long_one_step_decoder_2D.marginal.x.p_x_given_n) =},\t {np.sum(long_one_step_decoder_1D.p_x_given_n) = }') # 31181.999999999996 vs 31181.99999999999

## Validate that the 2D decoder's x-marginal has the same shapes as the 1D decoder's outputs:
assert long_one_step_decoder_2D.marginal.x.p_x_given_n.shape == long_one_step_decoder_1D.p_x_given_n.shape, f"Must equal but: {long_one_step_decoder_2D.marginal.x.p_x_given_n.shape =} and {long_one_step_decoder_1D.p_x_given_n.shape =}"
assert long_one_step_decoder_2D.marginal.x.most_likely_positions_1D.shape == long_one_step_decoder_1D.most_likely_positions.shape, f"Must equal but: {long_one_step_decoder_2D.marginal.x.most_likely_positions_1D.shape =} and {long_one_step_decoder_1D.most_likely_positions.shape =}"

## validate values (currently disabled; only the shapes are checked above):
# assert np.allclose(long_one_step_decoder_2D.marginal.x.p_x_given_n, long_one_step_decoder_1D.p_x_given_n), f"1D Decoder should have an x-posterior equal to its own posterior"
# assert np.allclose(curr_epoch_result['marginal_x']['most_likely_positions_1D'], curr_epoch_result['most_likely_positions']), f"1D Decoder should have an x-posterior with most_likely_positions_1D equal to its own most_likely_positions"

In [None]:
from neuropy.analyses.time_dependent_placefields import PfND_TimeDependent

# (Re)compute the time-dependent placefields for all three epochs (long, short, and global).
curr_active_pipeline.perform_specific_computation(computation_functions_name_whitelist=['_perform_time_dependent_placefield_computation'], enabled_filter_names=[long_epoch_name, short_epoch_name, global_epoch_name], fail_on_exception=True, debug_print=True)

# Scratch lines kept for inspecting / resetting the long epoch's time-dependent 1D placefield:
# long_results.pf1D_dt.reset()
# long_results.pf1D_dt._included_thresh_neurons_indx


In [None]:
# Save the pipeline using the TEMP_THEN_OVERWRITE saving scheme.
curr_active_pipeline.save_pipeline(saving_mode=PipelineSavingScheme.TEMP_THEN_OVERWRITE)

# Error observed when running this cell:
# TypeError: cannot pickle 'MplMultiTab' object
# The section below clears the pipeline's display outputs so that it can be saved.


#### 2023-04-13 - Clear Display Outputs from pipeline so it can be saved
##### TODO: Ignore `curr_active_pipeline.display_output_history_list`

In [None]:
# Number of display outputs currently held by the pipeline.
len(curr_active_pipeline.display_output)

In [None]:
# Inspect the display output history; its entries are used as keys into `display_output` by the clearing cell below.
curr_active_pipeline.display_output_history_list

In [None]:
## Clear any hanging display outputs:
# Each entry of the history list is treated as a key of `display_output` and deleted from it.
# Open question from the author: do the outputs need to be closed before they are removed?
# NOTE(review): `del` raises KeyError if a key appears twice in the history or is already gone — confirm keys are unique.
for a_display_output_key in curr_active_pipeline.display_output_history_list:
	# a_display_output.close()
	del curr_active_pipeline.display_output[a_display_output_key]

# Empty the history now that the referenced outputs have been removed.
curr_active_pipeline.display_output_history_list.clear()


In [None]:
## Get all matplotlib figures?
from pyphocorehelpers.plotting.figure_management import PhoActiveFigureManager2D, capture_new_figures_decorator
# Initialize a new figure manager and call `close_all()` on it.
fig_man = PhoActiveFigureManager2D(name='fig_man')
fig_man.close_all()

In [None]:
from pyphocorehelpers.print_helpers import print_object_memory_usage
# Report the memory usage of the whole pipeline object.
print_object_memory_usage(curr_active_pipeline)

In [None]:
%help %whos

In [None]:
# List the currently available line and cell magics.
%lsmagic

In [None]:
# Return a sorted list of the names of all interactive variables in the kernel namespace.
%who_ls

# 2023-04-11 - Review of Year's Progress


In [None]:
# Switch matplotlib to the Qt backend so figures open in their own interactive windows.
%matplotlib qt
import matplotlib.pyplot as plt

In [None]:
# test for interactive: inspect the documentation dictionary of the registered display functions
curr_active_pipeline.registered_display_function_docs_dict

In [None]:
# Reload the pipeline's default display functions.
curr_active_pipeline.reload_default_display_functions()

In [None]:
# Shorthand for the pipeline's `plot` accessor; the cells below call display functions through it.
plotter = curr_active_pipeline.plot


In [None]:
# Switch matplotlib back to the inline backend so figures render inside the notebook.
%matplotlib inline


In [None]:
# Call the `_display_normal` display function for the long epoch.
plotter._display_normal(active_session_configuration_context=long_epoch_name)

In [None]:
# Call the short/long firing rate index comparison display function through the `plot` accessor.
curr_active_pipeline.plot._display_short_long_firing_rate_index_comparison()

In [None]:
# Inspect (without calling) the attribute named without the leading underscore.
# NOTE(review): the cell above uses `_display_short_long_firing_rate_index_comparison`; confirm this un-prefixed name resolves.
curr_active_pipeline.plot.display_short_long_firing_rate_index_comparison

In [None]:
# Choose which extended computations to run: `None` would run all of them, the list below runs only the named ones.
# extended_computations_include_whitelist=None # do all
extended_computations_include_whitelist=['jonathan_firing_rate_analysis', 'short_long_pf_overlap_analyses', 'long_short_fr_indicies_analyses'] # do only specified
newly_computed_values = batch_extended_computations(curr_active_pipeline, include_whitelist=extended_computations_include_whitelist, include_global_functions=True, fail_on_exception=True, progress_print=True, debug_print=True)


In [None]:
# Run the short/long firing rate analysis, then read its results from the global computation results.
curr_active_pipeline.perform_specific_computation(computation_functions_name_whitelist=['_perform_short_long_firing_rate_analyses'], fail_on_exception=True, debug_print=False) # fail_on_exception MUST be True or error handling is all messed up
print(f'\t done.')
long_short_fr_indicies_analysis_results = curr_active_pipeline.global_computation_results.computed_data['long_short_fr_indicies_analysis']
x_frs_index, y_frs_index = long_short_fr_indicies_analysis_results['x_frs_index'], long_short_fr_indicies_analysis_results['y_frs_index'] # use the all_results_dict as the computed data value
active_context = long_short_fr_indicies_analysis_results['active_context']

In [None]:
# Unpack the 'jonathan_firing_rate_analysis' global result into its individual tables / lookups.
curr_jonathan_firing_rate_analysis = curr_active_pipeline.global_computation_results.computed_data['jonathan_firing_rate_analysis']
neuron_replay_stats_df = curr_jonathan_firing_rate_analysis['neuron_replay_stats_df']
rdf = curr_jonathan_firing_rate_analysis['rdf']['rdf'] # nested: the 'rdf' entry holds both the rdf itself and its aclu lookup
aclu_to_idx = curr_jonathan_firing_rate_analysis['rdf']['aclu_to_idx']
irdf = curr_jonathan_firing_rate_analysis['irdf']['irdf']

In [None]:
# Unpack the 'short_long_pf_overlap_analyses' global result (read here by attribute access on `computed_data`).
short_long_pf_overlap_analyses = curr_active_pipeline.global_computation_results.computed_data.short_long_pf_overlap_analyses
# Three overlap measures are stored: convolution, product, and relative entropy; two of them also have a scalars dataframe.
conv_overlap_dict = short_long_pf_overlap_analyses['conv_overlap_dict']
conv_overlap_scalars_df = short_long_pf_overlap_analyses['conv_overlap_scalars_df']
prod_overlap_dict = short_long_pf_overlap_analyses['product_overlap_dict']
relative_entropy_overlap_dict = short_long_pf_overlap_analyses['relative_entropy_overlap_dict']
relative_entropy_overlap_scalars_df = short_long_pf_overlap_analyses['relative_entropy_overlap_scalars_df']

In [None]:
# Re-read the long/short firing rate index results (same lookups as in the cell that runs `_perform_short_long_firing_rate_analyses`).
long_short_fr_indicies_analysis_results = curr_active_pipeline.global_computation_results.computed_data['long_short_fr_indicies_analysis']
x_frs_index, y_frs_index = long_short_fr_indicies_analysis_results['x_frs_index'], long_short_fr_indicies_analysis_results['y_frs_index'] # use the all_results_dict as the computed data value
active_context = long_short_fr_indicies_analysis_results['active_context']
# long_short_fr_indicies_analysis_results

In [None]:
# Use the Qt backend for this figure.
%matplotlib qt
active_identifying_session_ctx = curr_active_pipeline.sess.get_context() # 'bapun_RatN_Day4_2019-10-15_11-30-06'

# Render the batch replay firing rate comparison figure and unpack the returned figure, axes, and plot data.
graphics_output_dict = curr_active_pipeline.display('_display_batch_pho_jonathan_replay_firing_rate_comparison', active_identifying_session_ctx)
fig, axs, plot_data = graphics_output_dict['fig'], graphics_output_dict['axs'], graphics_output_dict['plot_data']
# NOTE: this rebinds `rdf`, `aclu_to_idx`, and `irdf`, which an earlier cell set from the computed 'jonathan_firing_rate_analysis' result.
neuron_df, rdf, aclu_to_idx, irdf = plot_data['df'], plot_data['rdf'], plot_data['aclu_to_idx'], plot_data['irdf']
# Grab the output axes (only the first entry of `axs` is used here):
curr_axs_dict = axs[0]
curr_firing_rate_ax, curr_lap_spikes_ax, curr_placefield_ax = curr_axs_dict['firing_rate'], curr_axs_dict['lap_spikes'], curr_axs_dict['placefield'] # Extract variables from the `curr_axs_dict` dictionary to the local workspace
## TODO 2023-04-11 - Set Window Title for '_display_batch_pho_jonathan_replay_firing_rate_comparison'

In [None]:
# Plot long|short firing rate index:
# NOTE(review): this is a hard-coded absolute path on one specific machine; consider making it configurable.
fig_save_parent_path = Path(r'E:\Dropbox (Personal)\Active\Kamran Diba Lab\Pho-Kamran-Meetings\Results from 2023-04-11')
curr_active_pipeline.display('_display_short_long_firing_rate_index_comparison', curr_active_pipeline.sess.get_context(), fig_save_parent_path=fig_save_parent_path)
# plt.close() # closes the current figure

In [None]:
# 2023-04-11 - Debug why `_display_short_long_firing_rate_index_comparison` is empty
# Plot long|short firing rate index:

# Re-read the stored results and display the whole result object as the cell output for inspection.
long_short_fr_indicies_analysis_results = curr_active_pipeline.global_computation_results.computed_data['long_short_fr_indicies_analysis']
x_frs_index, y_frs_index = long_short_fr_indicies_analysis_results['x_frs_index'], long_short_fr_indicies_analysis_results['y_frs_index'] # use the all_results_dict as the computed data value
active_context = long_short_fr_indicies_analysis_results['active_context']
long_short_fr_indicies_analysis_results


In [None]:
# Re-run the short/long firing rate index comparison display, passing the `fig_save_parent_path` set in an earlier cell.
curr_active_pipeline.display('_display_short_long_firing_rate_index_comparison', curr_active_pipeline.sess.get_context(), fig_save_parent_path=fig_save_parent_path)

In [None]:
# Generate the batch of programmatic figures; returns the session context, the figures output path, and the list of output figures.
active_identifying_session_ctx, active_session_figures_out_path, active_out_figures_list = batch_programmatic_figures(curr_active_pipeline)


In [None]:
# Generate the extended batch of programmatic figures for this pipeline.
batch_extended_programmatic_figures(curr_active_pipeline)

## Explore display function documentation

In [None]:
# Show the documentation dictionary of the registered display functions.
curr_active_pipeline.registered_display_function_docs_dict

In [None]:
# Attributes a display function may carry:
# func.short_name, func.tags, func.creation_date, func.input_requires, func.output_provides, func.uses, func.used_by
# Map each registered display function name to its `tags` attribute (`None` when the function has no `tags`).
{a_fn_name:getattr(a_fn, 'tags', None) for a_fn_name, a_fn in curr_active_pipeline.registered_display_function_dict.items()}

In [None]:
# Look up the tags of one specific display function (raises AttributeError if that function has no `tags`).
curr_active_pipeline.registered_display_function_dict['_display_2d_placefield_result_plot_ratemaps_2D'].tags

# 2023-03-16 - Explore passing in long/short decoders specifically to `perform_full_session_leave_one_out_decoding_analysis`

In [None]:
# Get existing long/short decoders from the cell under "# 2023-02-24 Decoders"
# Deep copies are taken so later pruning does not modify the original one-step decoders.
long_decoder, short_decoder = deepcopy(long_one_step_decoder_1D), deepcopy(short_one_step_decoder_1D)
# Both decoders must share identical x bins (the short decoder was conformed to the long one earlier).
assert np.all(long_decoder.xbin == short_decoder.xbin)

In [None]:
## backup existing replay objects
# long_session.replay_backup, short_session.replay_backup, global_session.replay_backup = [deepcopy(a_session.replay) for a_session in [long_session, short_session, global_session]]
# null-out the replay objects
# long_session.replay, short_session.replay, global_session.replay = [None, None, None]

# Compute/estimate replays if missing from session:
# Only the global session is checked; when it has no replays, replays are estimated for all three sessions and stored as dataframes.
if not global_session.has_replays:
    print(f'Replays missing from sessions. Computing replays...')
    long_session.replay, short_session.replay, global_session.replay = [a_session.estimate_replay_epochs(min_epoch_included_duration=0.06, max_epoch_included_duration=None, maximum_speed_thresh=None, min_inclusion_fr_active_thresh=0.01, min_num_unique_aclu_inclusions=3).to_dataframe() for a_session in [long_session, short_session, global_session]]


### Need to prune to only the cells active in both epochs ahead of time:

In [None]:
from pyphoplacecellanalysis.Analysis.Decoder.reconstruction import BasePositionDecoder, BayesianPlacemapPositionDecoder
from pyphoplacecellanalysis.Analysis.Decoder.decoder_result import perform_full_session_leave_one_out_decoding_analysis
from pyphoplacecellanalysis.Analysis.Decoder.decoder_result import SurpriseAnalysisResult
from pyphoplacecellanalysis.General.Mixins.CrossComputationComparisonHelpers import build_neurons_color_map # for plot_short_v_long_pf1D_comparison

# Prune to the shared aclus in both epochs (short/long):
# First build `BasePositionDecoder` instances from the existing decoders, then prune both to their shared aclus.
long_shared_aclus_only_decoder, short_shared_aclus_only_decoder = [BasePositionDecoder.init_from_stateful_decoder(a_decoder) for a_decoder in (long_decoder, short_decoder)]
shared_aclus, (long_shared_aclus_only_decoder, short_shared_aclus_only_decoder), long_short_pf_neurons_diff = BasePositionDecoder.prune_to_shared_aclus_only(long_shared_aclus_only_decoder, short_shared_aclus_only_decoder)
n_neurons = len(shared_aclus)
# for plotting purposes, build colors only for the common (present in both, the intersection) neurons:
neurons_colors_array = build_neurons_color_map(n_neurons, sortby=None, cmap=None)
print(f'{n_neurons = }, {neurons_colors_array.shape =}')

In [None]:
# with VizTracer(output_file=f"viztracer_{get_now_time_str()}-full_session_LOO_decoding_analysis.json", min_duration=200, tracer_entries=3000000, ignore_frozen=True) as tracer:
# Run the leave-one-out decoding analysis on the GLOBAL session twice: once with the long decoder and once with the short decoder.
# Both runs use a fixed 0.025 decoding time bin size, allow loading from cache, and use distinct cache suffixes.
long_results_obj = perform_full_session_leave_one_out_decoding_analysis(global_session, original_1D_decoder=long_shared_aclus_only_decoder, decoding_time_bin_size = 0.025, cache_suffix = '_long', perform_cache_load=True) # , perform_cache_load=False
short_results_obj = perform_full_session_leave_one_out_decoding_analysis(global_session, original_1D_decoder=short_shared_aclus_only_decoder, decoding_time_bin_size = 0.025, cache_suffix = '_short', perform_cache_load=True) # , perform_cache_load=False

## 2023-04-13 - Shuffled Surprise
""" 
Relevant Functions:
`perform_full_session_leave_one_out_decoding_analysis`:
	`perform_leave_one_aclu_out_decoding_analysis`:	from pyphoplacecellanalysis.Analysis.Decoder.decoder_result import perform_leave_one_aclu_out_decoding_analysis
	`_analyze_leave_one_out_decoding_results`: from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.DefaultComputationFunctions import _analyze_leave_one_out_decoding_results
"""


In [None]:
# ## for each cell:
# for i, left_out_aclu in enumerate(original_1D_decoder.neuron_IDs):
# 	# aclu = original_1D_decoder.neuron_IDs[i]
# 	left_out_neuron_IDX = original_1D_decoder.neuron_IDXs[i] # should just be i, but just to be safe
# 	## TODO: only look at bins where the cell fires (is_cell_firing_time_bin[i])
# 	curr_cell_pf_curve = original_1D_decoder.pf.ratemap.tuning_curves[left_out_neuron_IDX]
# 	# curr_cell_spike_curve = original_1D_decoder.pf.ratemap.spikes_maps[unit_IDX] ## not occupancy weighted... is this the right one to use for computing the expected spike rate? NO... doesn't seem like it

# 	shuffled_cell_pf_curve = original_1D_decoder.pf.ratemap.tuning_curves[shuffle_IDXs[i]]

# 	left_out_decoder_result = one_left_out_filter_epochs_decoder_result_dict[left_out_aclu]

# active_epoch = results_obj.active_filter_epochs[epoch_idx]
# # Get a conversion between the epoch indicies and the flat indicies
# flat_bin_indicies = results_obj.split_by_epoch_reverse_flattened_time_bin_indicies[epoch_idx]


# for decoded_epoch_idx in np.arange(n_epochs):
# 	num_curr_epoch_time_bins = [len(self.result.one_left_out_posterior_to_pf_surprises[aclu][decoded_epoch_idx]) for aclu in self.original_1D_decoder.neuron_IDs] # get the number of time bins in this epoch to build the reverse indexing array

# for decoded_epoch_idx in np.arange(left_out_decoder_result.num_filter_epochs):
# 	long_results_obj.active_filter_epochs


# `left_out_decoder_result` only exists in the commented-out exploration above, so iterating over it raised a NameError
# when the notebook is run top to bottom. Iterate over the epoch count of the decoded long result instead.
for decoded_epoch_idx in np.arange(long_results_obj.all_included_filter_epochs_decoder_result.num_filter_epochs):
	long_results_obj.active_filter_epochs # placeholder body: no per-epoch work is implemented yet



In [None]:
# Work on a deep copy of the long results so the exploration below cannot modify `long_results_obj`.
results_obj = deepcopy(long_results_obj)
results_obj.all_included_filter_epochs_decoder_result

In [None]:
# Extract important things from the decoded data, like the time bins which are the same for all:
n_epochs = results_obj.all_included_filter_epochs_decoder_result.num_filter_epochs
# Total number of time bins: the per-epoch `nbins` values summed over all epochs.
n_timebins = np.sum(results_obj.all_included_filter_epochs_decoder_result.nbins)
shared_timebin_containers = results_obj.all_included_filter_epochs_decoder_result.time_bin_containers
# shared_timebin_
n_timebins

In [None]:
from attrs import define
## 0. Precompute the active neurons in each timebin, and the epoch-timebin-flattened decoded posteriors makes it easier to compute for a given time bin:
@define(slots=False, repr=False)
class TimebinnedNeuronActivity:
	""" Keeps track of which neurons are active and inactive in each decoded timebin.

	Each of the four per-timebin fields is a list with one numpy array per row of `is_non_firing_time_bin.T`
	(presumably one row per decoded timebin — TODO confirm the array is shaped (n_neurons, n_timebins)).
	"""
	n_timebins: int # sum of the per-epoch `nbins` of the decoded result
	active_IDXs: list # per timebin: the decoder `neuron_IDXs` selected where `is_non_firing_time_bin` is False
	active_aclus: list # per timebin: the decoder `neuron_IDs` at `active_IDXs`
	inactive_IDXs: list # per timebin: the decoder `neuron_IDXs` selected where `is_non_firing_time_bin` is True
	inactive_aclus: list # per timebin: the decoder `neuron_IDs` at `inactive_IDXs`

	@classmethod
	def init_from_results_obj(cls, results_obj: SurpriseAnalysisResult):
		""" Builds the per-timebin active/inactive neuron lists from a leave-one-out decoding result.

		Args:
			results_obj: provides `all_included_filter_epochs_decoder_result.nbins`, `original_1D_decoder.neuron_IDXs`,
				`original_1D_decoder.neuron_IDs`, and the boolean `is_non_firing_time_bin` array.

		Returns:
			A new `TimebinnedNeuronActivity` instance.
		"""
		n_timebins = int(np.sum(results_obj.all_included_filter_epochs_decoder_result.nbins)) # plain int, matching the field annotation
		# Convert the decoder's neuron identifiers once instead of once per timebin:
		neuron_IDXs = np.array(results_obj.original_1D_decoder.neuron_IDXs)
		neuron_IDs = np.array(results_obj.original_1D_decoder.neuron_IDs)
		is_non_firing_time_bin = results_obj.is_non_firing_time_bin
		is_firing_time_bin = np.logical_not(is_non_firing_time_bin)

		# a list of arrays where each array contains the neurons that are active during that timebin:
		timebins_active_neuron_IDXs = [neuron_IDXs[a_timebin_is_cell_firing] for a_timebin_is_cell_firing in is_firing_time_bin.T]
		timebins_active_aclus = [neuron_IDs[an_IDX] for an_IDX in timebins_active_neuron_IDXs]

		# and the complementary arrays of neurons that are NOT firing during that timebin:
		timebins_inactive_neuron_IDXs = [neuron_IDXs[a_timebin_is_cell_not_firing] for a_timebin_is_cell_not_firing in is_non_firing_time_bin.T]
		timebins_inactive_aclus = [neuron_IDs[an_IDX] for an_IDX in timebins_inactive_neuron_IDXs]
		# timebins_p_x_given_n = np.hstack(results_obj.all_included_filter_epochs_decoder_result.p_x_given_n_list) # # .shape: (239, 5) - (n_x_bins, n_epoch_time_bins)  --TO-->  .shape: (63, 4146) - (n_x_bins, n_flattened_all_epoch_time_bins)
		return cls(n_timebins=n_timebins, active_IDXs=timebins_active_neuron_IDXs, active_aclus=timebins_active_aclus, inactive_IDXs=timebins_inactive_neuron_IDXs, inactive_aclus=timebins_inactive_aclus)

# Attach the per-timebin neuron activity to each results object as a new `timebinned_neuron_info` attribute.
long_results_obj.timebinned_neuron_info = TimebinnedNeuronActivity.init_from_results_obj(long_results_obj)
short_results_obj.timebinned_neuron_info = TimebinnedNeuronActivity.init_from_results_obj(short_results_obj)
# Both results were decoded from the same global session, so their total timebin counts must agree.
assert long_results_obj.timebinned_neuron_info.n_timebins == short_results_obj.timebinned_neuron_info.n_timebins

In [None]:
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.DefaultComputationFunctions import LeaveOneOutDecodingResult
from scipy.spatial import distance # for Jensen-Shannon distance in `_subfn_compute_leave_one_out_analysis`
import random # for random.choice(mylist)
from PendingNotebookCode import _scramble_curve

# @define(slots=False, repr=False)
# class PlacefieldPosteriorComputationHelper:

# 	def compute(self, curr_cell_pf_curve, curr_timebin_p_x_given_n):
# 		result.one_left_out_posterior_to_pf_surprises[timebin_IDX].append(distance.jensenshannon(curr_cell_pf_curve, curr_timebin_p_x_given_n))
# 		result.one_left_out_posterior_to_pf_correlations[timebin_IDX].append(distance.correlation(curr_cell_pf_curve, curr_timebin_p_x_given_n))


# The "surprise" metric: a distance between a placefield curve and a decoded posterior.
# Alternatives that can be swapped into the body: distance.jensenshannon(pf, p_x_given_n), distance.correlation(pf, p_x_given_n)
def active_surprise_metric_fn(pf, p_x_given_n):
	""" squared Euclidean distance between the curve `pf` and the posterior `p_x_given_n` """
	return distance.sqeuclidean(pf, p_x_given_n)


#p_x_given_n_list
#DecodedFilterEpochsResult
result = LeaveOneOutDecodingResult(shuffle_IDXs=None)
timebinned_neuron_info = long_results_obj.timebinned_neuron_info

# one entry per position bin, e.g. (59, )
pf_shape = (len(long_results_obj.original_1D_decoder.pf.ratemap.xbin_centers),)

# One uniform-random curve per timebin (rows), each as long as a placefield curve:
result.random_noise_curves = np.random.uniform(low=0, high=1, size=(timebinned_neuron_info.n_timebins,) + pf_shape)

# unit max normalization: divide every row by its own maximum (sum-normalization would use np.sum instead of np.max)
result.random_noise_curves = result.random_noise_curves / np.max(result.random_noise_curves, axis=1, keepdims=True)

# Loop-invariant lookup, hoisted out of the (timebin x active cell) loops:
_unit_max_tuning_curves = long_results_obj.original_1D_decoder.pf.ratemap.unit_max_tuning_curves # normalized_tuning_curves

# Flattened leave-one-out posteriors, keyed by the left-out aclu. `flatten()` only depends on the left-out cell, so it is evaluated
# once per aclu instead of once per (timebin, cell) pair.
# NOTE(review): assumes `flatten()` is deterministic for a given decoder result — confirm. This trades memory for speed by holding
# one flattened posterior per aclu until the loop finishes.
_flat_p_x_given_n_by_aclu = {}

for index in np.arange(timebinned_neuron_info.n_timebins):
	# iterate through timebins
	## Pre loop: add empty lists for accumulation
	if index not in result.one_left_out_posterior_to_pf_surprises:
		result.one_left_out_posterior_to_pf_surprises[index] = []
	if index not in result.one_left_out_posterior_to_scrambled_pf_surprises:
		result.one_left_out_posterior_to_scrambled_pf_surprises[index] = []

	# The comparison curve for this timebin is its pre-generated (unit-max normalized) random noise curve.
	# Alternatives that were tried instead:
	# 	a) a random non-firing cell's placefield: tuning_curves[random.choice(timebinned_neuron_info.inactive_IDXs[index])]
	# 	b) a scrambled version of the real curve: _scramble_curve(curr_cell_pf_curve)
	curr_random_not_firing_cell_pf_curve = result.random_noise_curves[index]

	for neuron_IDX, aclu in zip(timebinned_neuron_info.active_IDXs[index], timebinned_neuron_info.active_aclus[index]):
		# iterate through only the active cells
		# 1. Get set of cells active in a given time bin, for each compute the surprise of its placefield with the leave-one-out decoded posterior.
		left_out_decoder_result = long_results_obj.one_left_out_filter_epochs_decoder_result_dict[aclu]
		# curr_cell_spike_curve = original_1D_decoder.pf.ratemap.spikes_maps[unit_IDX] ## not occupancy weighted... is this the right one to use for computing the expected spike rate? NO... doesn't seem like it
		curr_cell_pf_curve = _unit_max_tuning_curves[neuron_IDX]

		if aclu not in _flat_p_x_given_n_by_aclu:
			_, _, _flat_p_x_given_n_by_aclu[aclu] = left_out_decoder_result.flatten()
		curr_timebins_p_x_given_n = _flat_p_x_given_n_by_aclu[aclu]
		curr_timebin_p_x_given_n = curr_timebins_p_x_given_n[:, index] # .shape: (239, 5) - (n_x_bins, n_epoch_time_bins)
		assert curr_timebin_p_x_given_n.shape[0] == curr_cell_pf_curve.shape[0], f"{curr_timebin_p_x_given_n.shape = } == {curr_cell_pf_curve.shape = }"

		# The active cell's placefield vs. the posterior decoded with that cell left out:
		result.one_left_out_posterior_to_pf_surprises[index].append(active_surprise_metric_fn(curr_cell_pf_curve, curr_timebin_p_x_given_n))
		# result.one_left_out_posterior_to_pf_correlations[timebin_IDX].append(distance.correlation(curr_cell_pf_curve, curr_timebin_p_x_given_n))

		# 2. The comparison curve vs. the same posterior:
		result.one_left_out_posterior_to_scrambled_pf_surprises[index].append(active_surprise_metric_fn(curr_random_not_firing_cell_pf_curve, curr_timebin_p_x_given_n))
		# result.one_left_out_posterior_to_scrambled_pf_correlations[timebin_IDX].append(distance.correlation(curr_random_not_firing_cell_pf_curve, curr_timebin_p_x_given_n))

	# END Neuron Loop
	## Post neuron loops: convert lists to np.arrays
	result.one_left_out_posterior_to_pf_surprises[index] = np.array(result.one_left_out_posterior_to_pf_surprises[index])
	result.one_left_out_posterior_to_scrambled_pf_surprises[index] = np.array(result.one_left_out_posterior_to_scrambled_pf_surprises[index])

# Release the cached posteriors. The loop variables (`curr_cell_pf_curve`, `curr_random_not_firing_cell_pf_curve`,
# `curr_timebin_p_x_given_n`, `curr_timebins_p_x_given_n`) intentionally stay defined because the diagnostic-plot cell reads them.
del _flat_p_x_given_n_by_aclu

# End Timebin Loop
## Post timebin loops compute mean variables:
# Timebins without any active cell hold empty arrays and are left out of the means.
result.one_left_out_posterior_to_pf_surprises_mean = {k:np.mean(v) for k, v in result.one_left_out_posterior_to_pf_surprises.items() if np.size(v) > 0}
result.one_left_out_posterior_to_scrambled_pf_surprises_mean = {k:np.mean(v) for k, v in result.one_left_out_posterior_to_scrambled_pf_surprises.items() if np.size(v) > 0}
assert len(result.one_left_out_posterior_to_scrambled_pf_surprises_mean) == len(result.one_left_out_posterior_to_pf_surprises_mean), "normal and scrambled means cover a different number of timebins"
assert list(result.one_left_out_posterior_to_scrambled_pf_surprises_mean.keys()) == list(result.one_left_out_posterior_to_pf_surprises_mean.keys()), "normal and scrambled means cover different timebins"

# Explicit integer dtype so that an empty result is still a valid index array (np.array([]) would default to float):
valid_time_bin_indicies = np.array(list(result.one_left_out_posterior_to_pf_surprises_mean.keys()), dtype=int)
one_left_out_posterior_to_pf_surprises_mean = np.array(list(result.one_left_out_posterior_to_pf_surprises_mean.values()))
one_left_out_posterior_to_scrambled_pf_surprises_mean = np.array(list(result.one_left_out_posterior_to_scrambled_pf_surprises_mean.values()))

# One row per valid timebin; `epoch_IDX` is looked up from the flat-timebin -> epoch index array of the long results object.
result_df = pd.DataFrame({'time_bin_indices': valid_time_bin_indicies, 'epoch_IDX': long_results_obj.all_epochs_reverse_flat_epoch_indicies_array[valid_time_bin_indicies], 'posterior_to_pf_mean_surprise': one_left_out_posterior_to_pf_surprises_mean, 'posterior_to_scrambled_pf_mean_surprise': one_left_out_posterior_to_scrambled_pf_surprises_mean})
# Positive values mean the comparison curve was further from the posterior than the real placefields were:
result_df['surprise_diff'] = result_df['posterior_to_scrambled_pf_mean_surprise'] - result_df['posterior_to_pf_mean_surprise']

# 24.9 seconds to compute

## Compute Aggregate Dataframe for Epoch means:
# Group by 'epoch_IDX' and compute means of all columns
result_df_grouped = result_df.groupby('epoch_IDX').mean()
result_df_grouped

In [None]:
import pyphoplacecellanalysis.External.pyqtgraph as pg
## Create a diagnostic plot that plots a stack of the three curves used for computations in the given epoch:

def _add_plot(win, data, name:str):
	plot = win.addPlot() # PlotItem has to be built first?
	curve = plot.plot(data, name=name, label=name)
	plot.setLabel('top', name)
	return plot, curve

# Registry of the created plots, filled in by `_initialize_plots`:
plot_dict = {}
# Initial curves: whatever the three variables hold after the computation loop of the previous cell has finished.
plot_data = {
	'curr_cell_pf_curve': curr_cell_pf_curve,
	'curr_random_not_firing_cell_pf_curve': curr_random_not_firing_cell_pf_curve,
	'curr_timebin_p_x_given_n': curr_timebin_p_x_given_n,
}
win = pg.GraphicsLayoutWidget(title='diagnostic_plot', show=True)

def _initialize_plots(plot_data):
	""" Creates one plot per entry of `plot_data` and links the y-axis of every later plot to the first one.

	Captures `win` (the layout the plots are added to) and the module-level `plot_dict`, which is filled in-place and also returned.

	Args:
		plot_data (dict): maps a plot name to the data initially drawn on that plot.

	Returns:
		dict: `plot_dict`, mapping each name to {'plot_item': ..., 'curve': ...}.
	"""
	first_plot_item = None
	for name, data in plot_data.items():
		plot_item, curve = _add_plot(win, data=data, name=name)
		plot_dict[name] = {'plot_item':plot_item,'curve':curve}
		if first_plot_item is None:
			first_plot_item = plot_item
		else:
			# Link against the plot object itself. Linking by name only resolves for views that were registered under that
			# name, and `_add_plot` creates its plots without one, so a by-name link would never take effect.
			plot_item.setYLink(first_plot_item)
	return plot_dict


def _update_plots(plot_dict, updated_plot_data):
	""" updates the plots created with `_initialize_plots`"""
	for i, (name, data) in enumerate(updated_plot_data.items()):
		curr_plot = plot_dict[name]['plot_item']
		curr_curve = plot_dict[name]['curve']
		curr_curve.setData(data)

def update_function(index):
	""" Slider callback: redraws the diagnostic plots for the timebin at `index`.

	Only the first active cell of the timebin is shown (see `hardcoded_sub_epoch_item_idx`). The curves drawn are the ones the
	displayed surprise values were computed from: the cell's unit-max tuning curve, the timebin's random noise curve, and the
	posterior decoded with that cell left out.

	Captures plot_dict, and all data variables (result, timebinned_neuron_info, long_results_obj)
	"""
	# print(f'Slider index: {index}')
	hardcoded_sub_epoch_item_idx = 0

	if not (0 <= index < timebinned_neuron_info.n_timebins):
		# The slider range is independent of the data, so guard against indices past the last timebin.
		print(f'index {index} is out of range for {timebinned_neuron_info.n_timebins} timebins')
		return

	curr_random_not_firing_cell_pf_curve = result.random_noise_curves[index]
	neuron_IDX, aclu = timebinned_neuron_info.active_IDXs[index], timebinned_neuron_info.active_aclus[index]
	print(f'{neuron_IDX = }, {aclu = }')
	if len(neuron_IDX) > 0:
		# Get first index
		neuron_IDX = neuron_IDX[hardcoded_sub_epoch_item_idx]
		aclu = aclu[hardcoded_sub_epoch_item_idx]
		# Same curve the surprise was computed with (unit-max normalized, not the raw tuning curve):
		curr_cell_pf_curve = long_results_obj.original_1D_decoder.pf.ratemap.unit_max_tuning_curves[neuron_IDX]
		# Use the posterior decoded with *this* cell left out, rather than whichever flattened posterior the computation loop touched last:
		_, _, curr_timebins_p_x_given_n = long_results_obj.one_left_out_filter_epochs_decoder_result_dict[aclu].flatten()
		curr_timebin_p_x_given_n = curr_timebins_p_x_given_n[:, index]
		print(f'{curr_random_not_firing_cell_pf_curve.shape = }, {curr_cell_pf_curve.shape = }, {curr_timebin_p_x_given_n.shape = }')
		normal_surprise, random_surprise = result.one_left_out_posterior_to_pf_surprises[index][hardcoded_sub_epoch_item_idx], result.one_left_out_posterior_to_scrambled_pf_surprises[index][hardcoded_sub_epoch_item_idx]

		updated_plot_data = {'curr_cell_pf_curve': curr_cell_pf_curve, 'curr_random_not_firing_cell_pf_curve': curr_random_not_firing_cell_pf_curve, 'curr_timebin_p_x_given_n': curr_timebin_p_x_given_n}
		_update_plots(plot_dict, updated_plot_data)

		# Show each curve's surprise value underneath its plot:
		plot_dict['curr_cell_pf_curve']['plot_item'].setLabel('bottom', f"{normal_surprise}")
		plot_dict['curr_random_not_firing_cell_pf_curve']['plot_item'].setLabel('bottom', f"{random_surprise}")
	else:
		# Invalid period: no cell is active in this timebin, so only the labels are updated.
		plot_dict['curr_cell_pf_curve']['plot_item'].setLabel('bottom', f"NO ACTIVITY")
		plot_dict['curr_random_not_firing_cell_pf_curve']['plot_item'].setLabel('bottom', f"NO ACTIVITY")

# Build the diagnostic plots once from the initial `plot_data`; `update_function` afterwards updates their curves and labels in place.
plot_dict = _initialize_plots(plot_data=plot_data)

# win.graphicsItem().setLabel(axis='left', text='Short v. Long - Expected vs. Observed # Spikes')
# win.graphicsItem().setLabel(axis='bottom', text='time')

# def update(


import ipywidgets as widgets
from IPython.display import display

def integer_slider(update_func, min_value: int = 0, max_value: int = 100, initial_value: int = 0, description: str = 'Slider:'):
    """ 2023-04-13 - WORKS!!!
    Displays an integer slider that the user can adjust.

    Args:
        update_func (function): A user-provided update function that will be called with the current slider index.
        min_value (int): Smallest selectable index. Defaults to 0.
        max_value (int): Largest selectable index (inclusive). Defaults to 100.
        initial_value (int): Index the slider starts at. Defaults to 0.
        description (str): Label displayed next to the slider. Defaults to 'Slider:'.
    """
    slider = widgets.IntSlider(description=description, min=min_value, max=max_value, value=initial_value)

    def on_slider_change(change):
        """Callback function for slider value change."""
        if change['type'] == 'change' and change['name'] == 'value':
            # Call the user-provided update function with the current slider index
            update_func(change['new'])

    # Only subscribe to changes of the slider's value instead of to every trait:
    slider.observe(on_slider_change, names='value')
    display(slider)

# Define an update function that will be called with the current slider index
# def update_function(index):
#     print(f'Slider index: {index}')

# Call the integer_slider function with the update function
# NOTE: the slider range used here is 0-100, so only the first 101 timebins can be inspected with it.
integer_slider(update_function)


## 2023-04-13 - Desired Plotting Interface Idea:
```python
# Rows of plots can be constructed trivially through lists:
row_of_plots = [pg.plot(curr_cell_pf_curve, label='curr_cell_pf_curve'), pg.plot(curr_random_not_firing_cell_pf_curve, label='curr_random_not_firing_cell_pf_curve'), ...] 
	# I'd guess behind the scenes they would be converted into a helper.row([...]) object

# If you want a column instead, use helper.column
column_of_plots = helper.column([pg.plot(curr_cell_pf_curve, label='curr_cell_pf_curve'), pg.plot(curr_random_not_firing_cell_pf_curve, label='curr_random_not_firing_cell_pf_curve'), ...])

# The returned objects are composable:
row_of_layouts = [row_of_plots, column_of_plots] # stacks the layout objects just like they were plot objects

# Showing the result is easy, as is combining separate results in a new place:
whole_figure_window = [row_of_layouts] 
whole_figure_window.show()
```

""" 
Relevant Functions:
`perform_full_session_leave_one_out_decoding_analysis`:
	`perform_leave_one_aclu_out_decoding_analysis`:	from pyphoplacecellanalysis.Analysis.Decoder.decoder_result import perform_leave_one_aclu_out_decoding_analysis
	`_analyze_leave_one_out_decoding_results`: from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.DefaultComputationFunctions import _analyze_leave_one_out_decoding_results
"""



In [None]:

# 1. Get set of cells active in a given time bin, for each compute the surprise of its placefield with the leave-one-out decoded posterior.

# 2. From the remainder of cells (those not active), randomly choose one to grab the placefield of and compute the surprise with that and the same posterior.


# Expectation: The cells that are active in the time bin are expected to have a lower surprise (i.e. a placefield that is closer to / more similar to the posterior) than the comparison curve has.

In [None]:
import pyphoplacecellanalysis.External.pyqtgraph as pg
# 'time_bin_indices': valid_time_bin_indicies, 'posterior_to_pf_mean_surprise': one_left_out_posterior_to_pf_surprises_mean, 'posterior_to_scrambled_pf_mean_surprise': one_left_out_posterior_to_scrambled_pf_surprises_mean}

# make a separate symbol_brush color for each cell:
# cell_color_symbol_brush = [pg.intColor(i,hues=9, values=3, alpha=180) for i, aclu in enumerate(long_results_obj.original_1D_decoder.neuron_IDs)] # maxValue=128
# All properties in common:
win = pg.plot()
win.setWindowTitle('Long Sanity Check - Leave-one-out Custom Surprise Plot')
# None -> the legend auto-sizes to its contents; use e.g. (80,60) for a fixed size legend
legend_size = None
# a standalone LegendItem parented to the plot is used instead of `.addLegend`
legend = pg.LegendItem(size=legend_size, offset=(-1,0))
legend.setParentItem(win.graphicsItem())

plots = {}
long_short_symbol_list = ['t', 'o'] # note: 's' is a square. 'o', 't1': triangle pointing upwards0
label_prefix_list = ['normal', 'scrambled']


# Use mean time_bin and surprise for each epoch
# plots['normal'] = win.plot(x=valid_time_bin_indicies, y=one_left_out_posterior_to_pf_surprises_mean, pen=None, symbol='t', symbolBrush=pg.intColor(1,6,maxValue=128), name=f'normal', alpha=0.5) #  symbolBrush=pg.intColor(i,6,maxValue=128) , symbol=curr_symbol, symbolBrush=cell_color_symbol_brush[unit_IDX]
# plots['scrambled'] = win.plot(x=valid_time_bin_indicies, y=one_left_out_posterior_to_scrambled_pf_surprises_mean, pen=None, symbol='t', symbolBrush=pg.intColor(2,6,maxValue=128), name=f'scrambled', alpha=0.5) #  symbolBrush=pg.intColor(i,6,maxValue=128) , symbol=curr_symbol, symbolBrush=cell_color_symbol_brush[unit_IDX]

# Per-timebin difference (scrambled - normal):
curr_surprise_difference = one_left_out_posterior_to_scrambled_pf_surprises_mean - one_left_out_posterior_to_pf_surprises_mean

# The per-epoch means are what gets plotted (for the per-timebin version use x=valid_time_bin_indicies, y=curr_surprise_difference):
x = result_df_grouped['time_bin_indices'].to_numpy()
y = result_df_grouped['surprise_diff'].to_numpy()

# scatter only (no connecting line): one triangle per epoch
plots['difference'] = win.plot(
	x=x, y=y,
	pen=None, symbol='t', symbolBrush=pg.intColor(2,6,maxValue=128),
	name='difference', alpha=0.5,
)

# register every plotted series with the legend under its dictionary key
for k, v in plots.items():
	legend.addItem(v, str(k))

win.graphicsItem().setLabel(axis='left', text='Normal v. Random - Surprise (Custom)')
win.graphicsItem().setLabel(axis='bottom', text='time')
# return win, plots_tuple, legend

In [None]:
# Inspect the flattened timebin indices that had at least one active cell (the only ones a mean surprise was computed for).
valid_time_bin_indicies

In [None]:
# Inspect the per-timebin mean surprise of the active cells' placefield curves against the leave-one-out posteriors.
one_left_out_posterior_to_pf_surprises_mean

## Pre 2023-04-13

In [None]:
## Compute Fresh (don't load from cache)
# Both runs skip loading the cache (perform_cache_load=False) but still save their result (skip_cache_save=False).
long_results_obj = perform_full_session_leave_one_out_decoding_analysis(
	global_session,
	original_1D_decoder=long_shared_aclus_only_decoder,
	decoding_time_bin_size=0.025,
	cache_suffix='_long',
	perform_cache_load=False,
	skip_cache_save=False,
)
short_results_obj = perform_full_session_leave_one_out_decoding_analysis(
	global_session,
	original_1D_decoder=short_shared_aclus_only_decoder,
	decoding_time_bin_size=0.025,
	cache_suffix='_short',
	perform_cache_load=False,
	skip_cache_save=False,
)
# # (time_bins, neurons), (epochs, neurons), (epochs)
# all_epochs_computed_one_left_out_posterior_to_pf_surprises, all_epochs_computed_cell_one_left_out_posterior_to_pf_surprises_mean, all_epochs_all_cells_one_left_out_posterior_to_pf_surprises_mean

In [None]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib qt

from pyphoplacecellanalysis.Analysis.Decoder.decoder_result import plot_kourosh_activity_style_figure

from neuropy.core.neurons import NeuronType
# # Include only pyramidal aclus:
# print(f'all shared_aclus: {len(shared_aclus)}\nshared_aclus: {shared_aclus}')
# shared_aclu_neuron_type = long_session.neurons.neuron_type[np.isin(long_session.neurons.neuron_ids, shared_aclus)]
# assert len(shared_aclu_neuron_type) == len(shared_aclus)
# # Find only the aclus that are pyramidal:
# is_shared_aclu_pyramidal = (shared_aclu_neuron_type == NeuronType.PYRAMIDAL)
# pyramidal_only_shared_aclus = shared_aclus[is_shared_aclu_pyramidal]
# print(f'num pyramidal_only_shared_aclus: {len(pyramidal_only_shared_aclus)}\npyramidal_only_shared_aclus: {pyramidal_only_shared_aclus}')


## Keep only the pyramidal cells, selecting from all of the long session's aclus (not just the shared ones):
neuron_type = long_session.neurons.neuron_type
all_aclus = deepcopy(long_session.neurons.neuron_ids)
assert len(neuron_type) == len(all_aclus)
# Boolean mask over `all_aclus` that is True for the pyramidal cells:
is_aclu_pyramidal = (neuron_type == NeuronType.PYRAMIDAL)
pyramidal_only_all_aclus = all_aclus[is_aclu_pyramidal]
print(f'num pyramidal_only_all_aclus: {len(pyramidal_only_all_aclus)}\npyramidal_only_all_aclus: {pyramidal_only_all_aclus}')

# Earlier variants of the call below:
# app, win, plots, plots_data = plot_kourosh_activity_style_figure(long_results_obj, long_session, shared_aclus, epoch_idx=5, callout_epoch_IDXs=[0,1,2,3], skip_rendering_callouts=True)
# app, win, plots, plots_data = plot_kourosh_activity_style_figure(long_results_obj, long_session, pyramidal_only_shared_aclus, epoch_idx=2, callout_epoch_IDXs=[0,4], skip_rendering_callouts=False)
app, win, plots, plots_data = plot_kourosh_activity_style_figure(
	long_results_obj, long_session, pyramidal_only_all_aclus,
	epoch_idx=3, callout_epoch_IDXs=[2,4,6], skip_rendering_callouts=False,
)

In [None]:
# Same figure function and pyramidal-only aclus as the call in the previous cell, here for epoch_idx=11 with callout_epoch_IDXs 0 through 5.
app, win, plots, plots_data = plot_kourosh_activity_style_figure(long_results_obj, long_session, pyramidal_only_all_aclus, epoch_idx=11, callout_epoch_IDXs=[0,1,2, 3, 4, 5], skip_rendering_callouts=False)

# 2023-04-13 Show Surprise 

In [None]:
import pyphoplacecellanalysis.External.pyqtgraph as pg
from PendingNotebookCode import plot_long_short, plot_long_short_any_values, plot_long_short_expected_vs_observed_firing_rates

# plot_long_short(long_results_obj, short_results_obj)

In [None]:
# Plot the long vs. short results with the function's default x/y selection (no `x`/`y`/`limit_aclus` passed here).
plot_long_short_any_values(long_results_obj=long_results_obj, short_results_obj=short_results_obj)

In [None]:
# NOTE(review): judging by its name this compares expected vs. observed firing rates for the long and short results — confirm against PendingNotebookCode.
plot_long_short_expected_vs_observed_firing_rates(long_results_obj=long_results_obj, short_results_obj=short_results_obj)

In [None]:
def x_fn(a_results_obj):
	""" x-values: the first column of the results object's `all_epochs_decoded_epoch_time_bins_mean` """
	return a_results_obj.all_epochs_decoded_epoch_time_bins_mean[:,0]

def y_fn(a_results_obj):
	""" y-values: the results object's `all_epochs_all_cells_one_left_out_posterior_to_pf_surprises_mean`

	Other candidates: `all_epochs_all_cells_one_left_out_posterior_to_scrambled_pf_surprises_mean`, `all_epochs_computed_one_left_out_posterior_to_pf_surprises`
	"""
	return a_results_obj.all_epochs_all_cells_one_left_out_posterior_to_pf_surprises_mean

# (time_bins, neurons), (epochs, neurons), (epochs)
# all_epochs_computed_one_left_out_posterior_to_pf_surprises, all_epochs_computed_cell_one_left_out_posterior_to_pf_surprises_mean, all_epochs_all_cells_one_left_out_posterior_to_pf_surprises_mean
win, plots_tuple, legend = plot_long_short_any_values(
	long_results_obj, short_results_obj,
	x=x_fn, y=y_fn, limit_aclus=[20],
)

# 2023-04-13 - Find Good looking epochs:

In [None]:
%matplotlib qt
from pyphoplacecellanalysis.General.Pipeline.Stages.DisplayFunctions.DecoderPredictionError import plot_decoded_epoch_slices

# Stacked decoded-epoch slices for the long results, limited to 32 slices via debug_test_max_num_slices:
laps_plot_tuple = plot_decoded_epoch_slices(
	long_results_obj.active_filter_epochs,
	long_results_obj.all_included_filter_epochs_decoder_result,
	global_pos_df=global_session.position.df,
	variable_name='lin_pos',
	xbin=long_results_obj.original_1D_decoder.xbin,
	name='stacked_epoch_slices_long_results_obj',
	debug_print=True,
	debug_test_max_num_slices=32,
)

In [None]:
# NOTE(review): presumably hand-picked indices of good-looking epochs from the figure above — confirm; the list is not assigned to anything.
[23, 27, 29, ]

In [None]:
## Pagination support for `plot_decoded_epoch_slices`
from pyphocorehelpers.indexing_helpers import compute_paginated_grid_config

n_epochs = long_results_obj.active_filter_epochs.n_epochs
# Single-column layout with at most 32 subplots per page; the data indices are the epoch indices of `result_df_grouped`.
subplot_no_pagination_configuration, included_combined_indicies_pages, page_grid_sizes = compute_paginated_grid_config(
	n_epochs,
	max_num_columns=1,
	max_subplots_per_page=32,
	data_indicies=result_df_grouped.index.to_numpy(),
	last_figure_subplots_same_layout=True,
)
num_pages = len(included_combined_indicies_pages)
num_pages
# a list of length `num_pages`; each item lists the included epoch index (the last element) of every 4-tuple on that page
included_epoch_indicies_pages = [
	[curr_included_epoch_index for (_, _, _, curr_included_epoch_index) in a_page]
	for a_page in included_combined_indicies_pages
]