"2024-01-08 - Batch Run Progress Tracker.ipynb"

Serves to keep track of the status, results, and effects of various runs of the pipeline. For example, when a certain run on a certain machine throws an error and produces a stacktrace, at a minimum that stacktrace should be accessible later in a manner that is at least as convenient.



In [None]:
%config IPCompleter.use_jedi = False
%pdb off
%load_ext autoreload
%autoreload 3
import sys
from copy import deepcopy
from typing import List, Dict, Optional, Union, Callable
from pathlib import Path
import pathlib
import numpy as np
import pandas as pd
import tables as tb
from datetime import datetime, timedelta
from attrs import define, field, Factory

# required to enable non-blocking interaction:
%gui qt5

In [None]:
import pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions
# from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions import ComputationFunctionRegistryHolder # should include ComputationFunctionRegistryHolder and all specifics
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.ComputationFunctionRegistryHolder import ComputationFunctionRegistryHolder
from pyphoplacecellanalysis.General.Pipeline.Stages.ComputationFunctions.MultiContextComputationFunctions.MultiContextComputationFunctions import _wrap_multi_context_computation_function

from pyphocorehelpers.exception_helpers import CapturedException # used in _execute_computation_functions for error handling
from pyphocorehelpers.programming_helpers import metadata_attributes
from pyphocorehelpers.function_helpers import function_attributes

from pyphoplacecellanalysis.General.Pipeline.Stages.Computation import FunctionsSearchMode
from pyphoplacecellanalysis.General.Model.SpecificComputationValidation import SpecificComputationValidator, SpecificComputationResultsSpecification

In [None]:
@define(slots=False, repr=False)
class ComputationFunctionManager:
    """Registry of the pipeline's computation functions, keyed by their registered names.

    Built from `ComputedPipelineStage`.

    Global and non-global functions are held in two separate dictionaries. Dictionary insertion
    order is the registration order, which matters because computation functions are order-dependent.
    """
    # registered name -> function, for the non-global computation functions:
    registered_computation_function_dict: Dict = field(default=Factory(dict), repr=True)
    # registered name -> function, for the global computation functions:
    registered_global_computation_function_dict: Dict = field(default=Factory(dict), repr=True)

    # Non-global accessors ________________________________________________________________________________________________ #
    @property
    def registered_computation_functions(self):
        """The registered non-global functions, in registration order."""
        return list(self.registered_computation_function_dict.values())

    @property
    def registered_computation_function_names(self):
        """The registered names of the non-global functions, in registration order."""
        return list(self.registered_computation_function_dict.keys())

    # Global accessors ____________________________________________________________________________________________________ #
    @property
    def registered_global_computation_functions(self):
        """The registered global functions, in registration order."""
        return list(self.registered_global_computation_function_dict.values())

    @property
    def registered_global_computation_function_names(self):
        """The registered names of the global functions, in registration order."""
        return list(self.registered_global_computation_function_dict.keys())

    # 'merged' refers to the fact that both global and non-global computation functions are included _____________________ #
    @property
    def registered_merged_computation_function_dict(self):
        """Build a NEW merged dictionary containing both global and non-global functions.

        Global entries come first. When a name is registered in both dictionaries the non-global
        function wins, because it is the right-hand operand of the `|` merge.
        """
        return (self.registered_global_computation_function_dict | self.registered_computation_function_dict)

    @property
    def registered_merged_computation_functions(self):
        """All registered functions (global first, then non-global)."""
        return list(self.registered_merged_computation_function_dict.values())

    @property
    def registered_merged_computation_function_names(self):
        """All registered names (global first, then non-global)."""
        return list(self.registered_merged_computation_function_dict.keys())

    def get_merged_computation_function_validators(self) -> Dict[str, SpecificComputationValidator]:
        """Build a `SpecificComputationValidator` for every registered function that declares one.

        A function is included only when it has a `validate_computation_test` attribute that is not None.

        Returns:
            dict mapping registered name -> validator, in merged-registration order.
        """
        return {
            a_name: SpecificComputationValidator.init_from_decorated_fn(a_fn)
            for a_name, a_fn in self.registered_merged_computation_function_dict.items()
            if getattr(a_fn, 'validate_computation_test', None) is not None
        }

    def reload_default_computation_functions(self):
        """Re-register every computation function exposed by `ComputationFunctionRegistryHolder`.

        Non-global registry classes are processed first, then the global ones.

        Note: execution ORDER MATTERS for the computation functions, so they are enumerated in a fixed
        order and never sorted alphabetically. Both the registry classes and each class's functions
        (requested in definition order) are walked in REVERSE.

        Re-registering an already-known name replaces the function but keeps that name's existing
        position; call `unregister_all_computation_functions()` first for a clean rebuild.
        """
        self._register_all_from_registry(ComputationFunctionRegistryHolder.get_non_global_registry_items(), is_global=False)
        self._register_all_from_registry(ComputationFunctionRegistryHolder.get_global_registry_items(), is_global=True)

    def _register_all_from_registry(self, registry_items, is_global: bool):
        """Register every function of every class in `registry_items` (class name -> class), walking both levels in reverse."""
        for _, a_computation_class in reversed(registry_items.items()):
            for (a_computation_fn_name, a_computation_fn) in reversed(a_computation_class.get_all_functions(use_definition_order=True)):
                self.register_computation(a_computation_fn_name, a_computation_fn, is_global=is_global)

    def register_computation(self, registered_name, computation_function, is_global:bool):
        """Store `computation_function` under `registered_name` in the global or non-global dictionary.

        Args:
            registered_name: key the function is stored under.
            computation_function: the function object; it is tagged in place with `.is_global`.
            is_global: selects which of the two dictionaries receives the function.
        """
        # Set the .is_global attribute on the function object itself, since functions are 1st-class objects in Python:
        computation_function.is_global = is_global

        target_dict_name = 'registered_global_computation_function_dict' if is_global else 'registered_computation_function_dict'
        try:
            target_dict = getattr(self, target_dict_name)
        except AttributeError:
            # Create a new dictionary if the instance does not have one yet, then register into it:
            target_dict = dict()
            setattr(self, target_dict_name, target_dict)
        target_dict[registered_name] = computation_function

    def unregister_all_computation_functions(self):
        """Drop all registered computation functions (global and non-global) so they can be reloaded fresh."""
        self.registered_global_computation_function_dict = dict()
        self.registered_computation_function_dict = dict()

    def find_registered_computation_functions(self, registered_names_list, search_mode:FunctionsSearchMode=FunctionsSearchMode.ANY, names_list_is_excludelist:bool=False):
        """Find the registered function objects selected by `registered_names_list`.

        A function matches a listed name by either its registered name or its alias, where the alias
        is its `short_name` attribute if present and otherwise its `__name__`.

        Args:
            registered_names_list: list<str> - the function names to look up.
            search_mode: which registry to search:
                GLOBAL_ONLY -> `registered_global_computation_function_dict`,
                NON_GLOBAL_ONLY -> `registered_computation_function_dict`,
                ANY -> the merged dictionary of both.
            names_list_is_excludelist: bool - if True, `registered_names_list` is treated as an excludelist
                and all functions are returned EXCEPT those that are in it.

        Returns:
            list of the matching function objects, in the order of the searched registry.

        Raises:
            NotImplementedError: if `search_mode` is not one of the three handled modes.

        Usage:
            active_computation_functions = self.find_registered_computation_functions(computation_functions_name_includelist, search_mode=FunctionsSearchMode.ANY)
        """
        # We want to reload the new/modified versions of the functions:
        self.reload_default_computation_functions()

        # Modes are compared by `.name` rather than by identity (presumably so the comparison survives `%autoreload` re-creating the enum class -- TODO confirm).
        if search_mode.name == FunctionsSearchMode.GLOBAL_ONLY.name:
            active_registered_computation_function_dict = self.registered_global_computation_function_dict
        elif search_mode.name == FunctionsSearchMode.NON_GLOBAL_ONLY.name:
            active_registered_computation_function_dict = self.registered_computation_function_dict
        elif search_mode.name == FunctionsSearchMode.ANY.name:
            # build a merged function dictionary containing both global and non-global functions:
            active_registered_computation_function_dict = self.registered_merged_computation_function_dict
        else:
            raise NotImplementedError(f'unsupported search_mode: {search_mode!r}')

        def _is_listed(a_registered_name, a_computation_fn) -> bool:
            """True when the function is named in `registered_names_list` by registered name or by alias."""
            if a_registered_name in registered_names_list:
                return True
            # The `__name__` fallback is itself optional so callables that lack it (e.g. `functools.partial`) do not raise:
            an_alias = getattr(a_computation_fn, 'short_name', getattr(a_computation_fn, '__name__', None))
            return (an_alias is not None) and (an_alias in registered_names_list)

        if names_list_is_excludelist:
            # excludelist-style operation: return all registered functions EXCEPT those that are listed
            active_computation_function_dict = {a_name: a_fn for (a_name, a_fn) in active_registered_computation_function_dict.items() if not _is_listed(a_name, a_fn)}
        else:
            # default includelist-style operation: return only the listed functions
            active_computation_function_dict = {a_name: a_fn for (a_name, a_fn) in active_registered_computation_function_dict.items() if _is_listed(a_name, a_fn)}

        return list(active_computation_function_dict.values())


# Build a fresh manager and populate it from the default registries:
a_man = ComputationFunctionManager()
a_man.reload_default_computation_functions()

## Specify the computations and the requirements to validate them.

## Hardcoded comp_specifiers
# Execution order is currently determined by `_comp_specifiers` order and not the order the `include_includelist` lists them (which is good),
# but the `curr_active_pipeline.registered_merged_computation_function_dict` has them registered in *REVERSE* order for the specific
# computation function called, so the validators are reversed here.
_comp_specifiers = list(a_man.get_merged_computation_function_validators().values())[::-1]

_comp_specifiers

In [None]:
## Collect the short names of all the functions that can be found, and compare them against the manually maintained lists.

# Names that are never run, split by scope:
always_disabled_global_comp_names = ['PBE_stats']
always_disabled_non_global_comp_names = [
    '_perform_specific_epochs_decoding',
    'velocity_vs_pf_simplified_count_density',
    'placefield_overlap',
    '_DEP_ratemap_peaks',
    'recursive_latent_pf_decoding',
    'EloyAnalysis',
]

# Manually maintained reference lists that the discovered names are checked against:
check_manual_non_global_comp_names = [
    'pf_computation',
    'pfdt_computation',
    'firing_rate_trends',
    'pf_dt_sequential_surprise',
    'ratemap_peaks_prominence2d',
    'position_decoding',
    'position_decoding_two_step',
    'spike_burst_detection',
    'extended_stats',
]
check_manual_global_comp_names = [
    'long_short_decoding_analyses',
    'jonathan_firing_rate_analysis',
    'long_short_fr_indicies_analyses',
    'short_long_pf_overlap_analyses',
    'long_short_post_decoding',
    'long_short_rate_remapping',
    'long_short_inst_spike_rate_groups',
    'pf_dt_sequential_surprise',
    'long_short_endcap_analysis',
    'split_to_directional_laps',
    'merged_directional_placefields',
    'rank_order_shuffle_analysis',
]  # , 'long_short_rate_remapping'

## Discovered short names, minus the always-disabled ones:
non_global_comp_names = [
    a_specifier.short_name
    for a_specifier in _comp_specifiers
    if not (a_specifier.is_global or (a_specifier.short_name in always_disabled_non_global_comp_names))
]
global_comp_names = [
    a_specifier.short_name
    for a_specifier in _comp_specifiers
    if (a_specifier.is_global and (a_specifier.short_name not in always_disabled_global_comp_names))
]

non_global_comp_names
global_comp_names

# Discovered names that are absent from the manual lists (set difference, so the order is arbitrary):
missing_global_comp_names = list(
    set(global_comp_names) - set(check_manual_global_comp_names)
)
missing_non_global_comp_names = list(
    set(non_global_comp_names) - set(check_manual_non_global_comp_names)
)

print(f'missing_global_comp_names: {missing_global_comp_names}')
print(f'missing_non_global_comp_names: {missing_non_global_comp_names}')

In [None]:
# Run configuration read by the batch cells of this notebook.
# progress_print: gates the informational prints (e.g. the final 'done with all ...' message).
progress_print = True
# force_recompute: forwarded to the computation calls; also makes `_subfn_on_already_computed` raise to trigger recomputation.
force_recompute = False
# Names whose results are force-recomputed even when `force_recompute` is False (checked for global specifiers); the batch cell replaces None with [] before use.
force_recompute_override_computations_includelist = None
# None is treated by the batch cell as "include all" (non-global + global names).
include_includelist = None

# dry_run: when True the batch loops only print 'dry-run: ...' lines instead of calling `try_computation_if_needed`.
dry_run = True
# include_global_functions: when False, specifiers with `is_global` set are skipped.
include_global_functions = True

# NOTE(review): `curr_active_pipeline` is not assigned anywhere in this view -- it must already exist in the notebook namespace or this line raises NameError.
curr_active_pipeline

In [None]:
# Two candidate per-specifier actions are assigned in turn; the second assignment replaces the first,
# so `active_comp_specifier_fcn` ends up bound to the `try_check_missing_provided_keys` variant.
active_comp_specifier_fcn = lambda _comp_specifier, *args, **kwargs: _comp_specifier.try_computation_if_needed(*args, **kwargs)
active_comp_specifier_fcn = lambda _comp_specifier, *args, **kwargs: _comp_specifier.try_check_missing_provided_keys(*args, **kwargs)

# Every requested name starts out as "remaining"; names are removed as their specifier is encountered.
remaining_include_function_names = dict.fromkeys(include_includelist.copy(), False)

for _comp_specifier in _comp_specifiers:
    # Global specifiers are skipped entirely when global functions are disabled:
    if _comp_specifier.is_global and (not include_global_functions):
        continue
    # Only specifiers requested by short name or by computation function name are processed:
    if (_comp_specifier.short_name not in include_includelist) and (_comp_specifier.computation_fn_name not in include_includelist):
        continue

    if _comp_specifier.is_global:
        # Global-Only:
        # force_recompute for this specific result if either of its names is included in `force_recompute_override_computations_includelist`
        _curr_force_recompute = force_recompute or (_comp_specifier.short_name in force_recompute_override_computations_includelist) or (_comp_specifier.computation_fn_name in force_recompute_override_computations_includelist)

        # Check for existing result:
        is_known_missing_provided_keys: bool = _comp_specifier.try_check_missing_provided_keys(curr_active_pipeline)
        if is_known_missing_provided_keys:
            print(f'{_comp_specifier.short_name} -- is_known_missing_provided_keys = True!')
    else:
        # Not Global-only, need to check for all `included_computation_filter_names`:
        for a_computation_filter_name in included_computation_filter_names:
            if dry_run:
                print(f'dry-run: {_comp_specifier.short_name}, computation_filter_name={a_computation_filter_name}, force_recompute={force_recompute}')
            else:
                # newly_computed_values += _comp_specifier.try_computation_if_needed(curr_active_pipeline, computation_filter_name=a_computation_filter_name, on_already_computed_fn=_subfn_on_already_computed, fail_on_exception=fail_on_exception, progress_print=progress_print, debug_print=debug_print, force_recompute=force_recompute)
                newly_computed_values += _comp_specifier.try_check_missing_provided_keys(curr_active_pipeline, computation_filter_name=a_computation_filter_name, progress_print=progress_print, force_recompute=force_recompute)

    # Mark the request as handled under whichever name matched (short name takes priority):
    if _comp_specifier.short_name in include_includelist:
        remaining_include_function_names.pop(_comp_specifier.short_name)
    elif _comp_specifier.computation_fn_name in include_includelist:
        remaining_include_function_names.pop(_comp_specifier.computation_fn_name)
    else:
        raise NotImplementedError

if remaining_include_function_names:
    print(f'WARNING: after execution of all _comp_specifiers found the functions: {remaining_include_function_names} still remain! Are they correct and do they have proper validator decorators?')
if progress_print:
    print('done with all batch_extended_computations(...).')


In [None]:
def _subfn_on_already_computed(_comp_name, computation_filter_name):
    """Callback for a computation whose result already exists.

    Captures `progress_print` and `force_recompute` from the enclosing namespace.

    Raises:
        AttributeError: when `force_recompute` is true, to trigger recomputation.
    """
    if progress_print:
        print(f'{_comp_name}, {computation_filter_name} already computed.')
    if not force_recompute:
        return
    if progress_print:
        print('\tforce_recompute is true so recomputing anyway')
    # The exception itself is the signal: raising AttributeError triggers recomputation.
    raise AttributeError

# Accumulates whatever the computation calls return over the whole batch run:
newly_computed_values = []
# Normalize None to an empty list so that the `in` membership tests below are always valid:
force_recompute_override_computations_includelist = force_recompute_override_computations_includelist or []

# Hardcoded short names of the computations to run, split by scope:
non_global_comp_names = ['pf_computation', 'pfdt_computation', 'firing_rate_trends', 'pf_dt_sequential_surprise', 'ratemap_peaks_prominence2d', 'position_decoding', 'position_decoding_two_step', 'spike_burst_detection']
global_comp_names = ['long_short_decoding_analyses', 'jonathan_firing_rate_analysis', 'long_short_fr_indicies_analyses', 'short_long_pf_overlap_analyses', 'long_short_post_decoding', 'long_short_rate_remapping', 'long_short_inst_spike_rate_groups', 'pf_dt_sequential_surprise', 'long_short_endcap_analysis',
                     'split_to_directional_laps', 'merged_directional_placefields', 'rank_order_shuffle_analysis'] # , 'long_short_rate_remapping'

# 'firing_rate_trends', 'pf_dt_sequential_surprise'
# '_perform_firing_rate_trends_computation', '_perform_time_dependent_pf_sequential_surprise_computation'

if include_includelist is None:
    # include all:
    include_includelist = non_global_comp_names + global_comp_names
else:
    print(f'included includelist is specified: {include_includelist}, so only performing these extended computations.')

## Get computed relative entropy measures:
_, _, global_epoch_name = curr_active_pipeline.find_LongShortGlobal_epoch_names()
# global_epoch_name = curr_active_pipeline.active_completed_computation_result_names[-1] # 'maze'

# NOTE(review): `included_computation_filter_names` is not assigned anywhere in this view before being read here -- confirm that an earlier cell defines it (e.g. as None), otherwise this raises NameError.
if included_computation_filter_names is None:
    included_computation_filter_names = [global_epoch_name] # use only the global epoch: e.g. ['maze']
    if progress_print:
        print(f'Running batch_extended_computations(...) with global_epoch_name: "{global_epoch_name}"')
else:
    if progress_print:
        print(f'Running batch_extended_computations(...) with included_computation_filter_names: "{included_computation_filter_names}"')



## Specify the computations and the requirements to validate them.

## Hardcoded comp_specifiers
_comp_specifiers = list(curr_active_pipeline.get_merged_computation_function_validators().values())
## Execution order is currently determined by `_comp_specifiers` order and not the order the `include_includelist` lists them (which is good) but the `curr_active_pipeline.registered_merged_computation_function_dict` has them registered in *REVERSE* order for the specific computation function called, so we need to reverse these
# Materialized as a list: a bare `reversed(...)` iterator would be exhausted after a single pass, leaving any later loop over `_comp_specifiers` silently empty.
_comp_specifiers = list(reversed(_comp_specifiers))


def try_run_compute_comp_specifiers(_comp_specifiers, curr_active_pipeline, include_global_functions: bool = True, ):
    """Run (or dry-run) every requested computation specifier and report the requested names that were never matched.

    Captures from the enclosing namespace: `include_includelist`, `included_computation_filter_names`,
    `dry_run`, `force_recompute`, `force_recompute_override_computations_includelist`, `global_epoch_name`,
    `newly_computed_values`, `_subfn_on_already_computed`, `fail_on_exception`, `progress_print`, `debug_print`.

    Args:
        _comp_specifiers: iterable of specifier objects exposing `is_global`, `short_name`,
            `computation_fn_name`, `try_computation_if_needed(...)` and `try_check_missing_provided_keys(...)`.
        curr_active_pipeline: the pipeline passed through to each specifier call.
        include_global_functions: when False, specifiers with `is_global` set are skipped.

    Returns:
        dict whose keys are the entries of `include_includelist` that no processed specifier matched
        (every value is False).

    Side effects:
        When not in dry-run mode, extends the captured `newly_computed_values` list in place with the
        results returned by each `try_computation_if_needed` call.
    """
    remaining_include_function_names = {k: False for k in include_includelist}

    for _comp_specifier in _comp_specifiers:
        if _comp_specifier.is_global and (not include_global_functions):
            continue # global functions are disabled
        if (_comp_specifier.short_name not in include_includelist) and (_comp_specifier.computation_fn_name not in include_includelist):
            continue # not requested under either of its names

        if (not _comp_specifier.is_global):
            # Not Global-only, need to compute for all `included_computation_filter_names`:
            for a_computation_filter_name in included_computation_filter_names:
                if not dry_run:
                    # `.extend` mutates the captured list in place. A `+=` here would make `newly_computed_values` local to this function and raise UnboundLocalError.
                    newly_computed_values.extend(_comp_specifier.try_computation_if_needed(curr_active_pipeline, computation_filter_name=a_computation_filter_name, on_already_computed_fn=_subfn_on_already_computed, fail_on_exception=fail_on_exception, progress_print=progress_print, debug_print=debug_print, force_recompute=force_recompute))
                else:
                    print(f'dry-run: {_comp_specifier.short_name}, computation_filter_name={a_computation_filter_name}, force_recompute={force_recompute}')
        else:
            # Global-Only:
            # force_recompute for this specific result if either of its names is included in `force_recompute_override_computations_includelist`
            _curr_force_recompute = force_recompute or ((_comp_specifier.short_name in force_recompute_override_computations_includelist) or (_comp_specifier.computation_fn_name in force_recompute_override_computations_includelist))
            if not dry_run:
                newly_computed_values.extend(_comp_specifier.try_computation_if_needed(curr_active_pipeline, computation_filter_name=global_epoch_name, on_already_computed_fn=_subfn_on_already_computed, fail_on_exception=fail_on_exception, progress_print=progress_print, debug_print=debug_print, force_recompute=_curr_force_recompute))
            else:
                print(f'dry-run: {_comp_specifier.short_name}, force_recompute={force_recompute}, curr_force_recompute={_curr_force_recompute}')
                # Check for existing result:
                is_known_missing_provided_keys: bool = _comp_specifier.try_check_missing_provided_keys(curr_active_pipeline)
                if is_known_missing_provided_keys:
                    print(f'{_comp_specifier.short_name} -- is_known_missing_provided_keys = True!')

        # The specifier was handled, so clear BOTH of its aliases: a function requested under both names must not be
        # reported as remaining, and `pop(..., None)` tolerates a name that an earlier specifier already cleared.
        remaining_include_function_names.pop(_comp_specifier.short_name, None)
        remaining_include_function_names.pop(_comp_specifier.computation_fn_name, None)

    return remaining_include_function_names
	
# `curr_active_pipeline` is a required positional argument of `try_run_compute_comp_specifiers`, and the configured `include_global_functions` flag is forwarded rather than relying on the parameter default:
remaining_include_function_names = try_run_compute_comp_specifiers(_comp_specifiers, curr_active_pipeline, include_global_functions=include_global_functions)

# Any name still present was requested but never matched by a processed specifier:
if remaining_include_function_names:
    print(f'WARNING: after execution of all _comp_specifiers found the functions: {remaining_include_function_names} still remain! Are they correct and do they have proper validator decorators?')
if progress_print:
    print('done with all batch_extended_computations(...).')