# Import packages

In [None]:

%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
# Locate the project root: the first directory named 'Multifirefly-Project'
# among the current working directory and its ancestors (nearest first).
project_root = next(
    (
        candidate
        for candidate in (Path.cwd(), *Path.cwd().parents)
        if candidate.name == 'Multifirefly-Project'
    ),
    None,
)
if project_root is not None:
    # Work from the project root and make the analysis packages importable.
    os.chdir(project_root)
    sys.path.insert(0, str(project_root / 'multiff_analysis/multiff_code/methods'))
    
from data_wrangling import specific_utils, process_monkey_information, general_utils, combine_info_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils, ml_methods_utils
from planning_analysis.show_planning import nxt_ff_utils, show_planning_utils
from neural_data_analysis.neural_analysis_tools.gpfa_methods import elephant_utils, fit_gpfa_utils, plot_gpfa_utils, gpfa_helper_class
from neural_data_analysis.neural_analysis_tools.align_trials import time_resolved_regression, time_resolved_gpfa_regression,plot_time_resolved_regression
from neural_data_analysis.neural_analysis_tools.align_trials import align_trial_utils
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural.pn_decoding import pn_decoding_utils, plot_pn_decoding, pn_decoding_model_specs

import sys
import math
import gc
import subprocess
from pathlib import Path
from importlib import reload
import warnings

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca
import quantities as pq

# Notebook-wide display configuration.
#
# Reset matplotlib to its defaults FIRST. rcParams.update(rcParamsDefault)
# overwrites every rcParam, so any customisation made before it (such as the
# animation settings) would be silently discarded.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Render animations as interactive JavaScript/HTML in the notebook. This was
# the last 'animation.html' value set in the previous version of this cell,
# so it is the one kept.
rc('animation', html='jshtml')
# Effectively remove the size cap on animations embedded in the notebook.
matplotlib.rcParams['animation.embed_limit'] = 2**128

# Allow more than one OpenMP runtime to be loaded in the same process.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Print floats with five decimals in pandas and without scientific notation
# in numpy.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
print("done")

%load_ext autoreload
%autoreload 2

# specs and funcs

In [None]:
# Behavioural columns kept for decoding. The order is significant only in
# that it is preserved wherever this list is used as a column selection.
# (Group labels below are inferred from the column names.)
key_features = [
    # bin / segment bookkeeping
    'new_bin',
    'new_segment',
    'whether_test',
    # cur_eye_* and nxt_eye_* columns
    'cur_eye_hor_l', 'cur_eye_ver_l',
    'cur_eye_hor_r', 'cur_eye_ver_r',
    'nxt_eye_hor_l', 'nxt_eye_ver_l',
    'nxt_eye_hor_r', 'nxt_eye_ver_r',
    # LD* / RD* columns
    'LDz', 'RDz',
    'LDx', 'RDx',
    # gaze_mky_view_* columns
    'gaze_mky_view_x',
    'gaze_mky_view_y',
    'gaze_mky_view_angle',
    # current / next firefly columns
    'cur_opt_arc_dheading',
    'cur_ff_distance',
    'cur_ff_rel_x', 'cur_ff_rel_y',
    'nxt_ff_rel_x', 'nxt_ff_rel_y',
    'nxt_ff_distance',
    # firefly counts
    'num_ff_visible',
    'num_ff_in_memory',
    # *_at_ref columns
    'cur_ff_distance_at_ref',
    'cur_ff_angle_boundary_at_ref',
    'nxt_ff_distance_at_ref',
    # movement columns
    'speed', 'ang_speed',
    'accel', 'ang_accel',
    'monkey_speeddummy',
    'curv_of_traj',
    'angle_from_cur_ff_to_nxt_ff',
    # timing / identifiers
    'time_since_last_capture',
    'bin_mid_time_rel_to_event',
    'time',
    'target_index',
    # visibility / memory flags
    'cur_vis', 'nxt_vis',
    'nxt_in_memory',
    'any_ff_visible',
]


# Iterate through run_segment_split_regression_cv

In [None]:
# Behavioural-data switch, forwarded to pn.get_concat_data_for_regression.
use_lagged_rebinned_behav_data = False

# Neural-data switches, forwarded to pn.get_concat_data_for_regression.
# When use_raw_spike_data_instead is False the session loop also fetches
# GPFA trajectories.
use_raw_spike_data_instead = True
apply_pca_on_raw_spike_data = False
use_lagged_raw_spike_data = True

In [None]:
# For every session of monkey_Bruno: build the event-aligned planning/neural
# data, keep the key behavioural features, run
# ml_methods_utils.run_segment_split_regression_cv, and plot the results.
import traceback  # stdlib; imported here so this cell is self-contained

raw_data_dir_name = 'all_monkey_data/raw_monkey_data'

sessions_df_for_one_monkey = combine_info_utils.make_sessions_df_for_one_monkey(
    raw_data_dir_name, 'monkey_Bruno')

# Session-independent settings: assigned once instead of on every iteration.
reduce_y_var_lags = False
planning_data_by_point_exists_ok = True
y_data_exists_ok = True
bin_width = 0.1

# Features shown in their own bar plot, ahead of the remaining features.
key_features2 = ['cur_ff_distance', 'log1p_cur_ff_distance', 'speed', 'accel', 'time_since_last_capture']

for index, row in sessions_df_for_one_monkey.iterrows():
    print('='*100)
    print('='*100)
    print(row['data_name'])
    raw_data_folder_path = os.path.join(
        raw_data_dir_name, row['monkey_name'], row['data_name'])

    try:
        pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(
            raw_data_folder_path=raw_data_folder_path, bin_width=bin_width)
        pn.prep_data_to_analyze_planning(
            planning_data_by_point_exists_ok=planning_data_by_point_exists_ok)
        pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
            pn.planning_data_by_point)
        pn.get_x_and_y_data_for_modeling(
            exists_ok=y_data_exists_ok, reduce_y_var_lags=reduce_y_var_lags)

        pn.prepare_seg_aligned_data()
        if not use_raw_spike_data_instead:
            # GPFA trajectories are only fetched when raw spikes are not used.
            pn.get_gpfa_traj(latent_dimensionality=7, exists_ok=True)

        pn.get_concat_data_for_regression(
            use_raw_spike_data_instead=use_raw_spike_data_instead,
            use_lagged_rebinned_behav_data=use_lagged_rebinned_behav_data,
            apply_pca_on_raw_spike_data=apply_pca_on_raw_spike_data,
            use_lagged_raw_spike_data=use_lagged_raw_spike_data,
        )

        pn.concat_behav_trials, added_cols = pn_decoding_utils.prep_behav(pn.concat_behav_trials)
        pn.rebinned_behav_data, _ = pn_decoding_utils.prep_behav(pn.rebinned_behav_data)
        # De-duplicate while keeping first-seen order. list(set(...)) would
        # give a column order that can change from one kernel to the next,
        # which makes tables and plots hard to compare across runs.
        overall_key_feats = list(dict.fromkeys(key_features + added_cols))
        pn.concat_behav_trials = pn.concat_behav_trials[overall_key_feats].copy()
        pn.rebinned_behav_data = pn.rebinned_behav_data[overall_key_feats].copy()

        pn.print_data_dimensions()

        # Keep only bins whose mid time is after the alignment event.
        # NOTE(review): the same mask is applied to concat_neural_trials, which
        # assumes it is row-aligned with concat_behav_trials — confirm.
        # rebinned_behav_data is not filtered here.
        mask = pn.concat_behav_trials['bin_mid_time_rel_to_event'] > 0
        pn.concat_behav_trials = pn.concat_behav_trials[mask]
        pn.concat_neural_trials = pn.concat_neural_trials[mask]

        pn.separate_test_and_control_data()
        # columns_of_interest = ['whether_test', 'cur_ff_distance', 'cur_ff_angle', 'nxt_ff_distance', 'nxt_ff_rel_y', 'nxt_opt_arc_dheading', 'nxt_ff_rel_x', 'nxt_ff_angle', 'nxt_ff_angle_at_ref']
        columns_of_interest = pn.concat_behav_trials.columns
        all_results = []
        for test_or_control in ['both']: #['test', 'control', 'both']:
            x_var, y_var = pn.get_concat_x_and_y_var_for_lr(test_or_control=test_or_control)

            results_summary = ml_methods_utils.run_segment_split_regression_cv(
                x_var,
                y_var,
                columns_of_interest,
                num_folds=5,
            )
            results_summary['test_or_control'] = test_or_control
            all_results.append(results_summary)

        all_results = pd.concat(all_results)

        # Split the summary by model type for plotting.
        reg_results = all_results[all_results['Model'] == 'Linear Regression']
        class_results = all_results[all_results['Model'] == 'Logistic Regression']

        # first only plot key_features2
        for metric in ['test_r2']:
            ml_methods_utils.make_barplot_to_compare_results(
                reg_results,
                metric=metric,
                features=key_features2,
            )
            print('='*100)
            print('='*100)

        rest_of_features = [c for c in reg_results['Feature'].unique() if c not in key_features2]

        # regression results
        for metric in ['test_r2']:
            ml_methods_utils.make_barplot_to_compare_results(
                reg_results,
                metric=metric,
                features=rest_of_features,
            )
            print('='*100)
            print('='*100)

        # classification results
        for metric in ['test_roc_auc']:
            ml_methods_utils.make_barplot_to_compare_results(
                class_results,
                metric=metric,
            )

    except Exception as e:
        # Best-effort batch run: report the failure and move on to the next
        # session. The traceback is printed as well because the message alone
        # (e.g. a bare column name for a KeyError) rarely locates the problem.
        print(f"Error processing {row['data_name']}: {e}")
        traceback.print_exc()
        continue

    # Save the current state of the notebook
    

# TODO: try GPFA next (for the analysis above)

# Iterate through 'decode_by_num_ff_visible_or_in_memory'

## Note: I can try various combinations of neural data formats

In [None]:
# Column that decoding is conditioned on; passed as `ff_visibility_col` to the
# decoding and heatmap-plotting calls.
ff_visibility_col = 'num_ff_visible'

# Behavioural-data switch, forwarded to pn.get_concat_data_for_regression.
use_lagged_rebinned_behav_data = False

# Neural-data switches, forwarded to pn.get_concat_data_for_regression.
use_raw_spike_data_instead = True
apply_pca_on_raw_spike_data = False
use_lagged_raw_spike_data = True

In [None]:

# Re-execute the pn_decoding_utils module so that edits made to it on disk are
# picked up without restarting the kernel.
reload(pn_decoding_utils)


In [None]:
# For monkey_Bruno's sessions: build the event-aligned planning/neural data,
# then, for every model in pn_decoding_model_specs.MODEL_SPECS, decode the key
# features conditioned on `ff_visibility_col` and plot the result heatmaps.
import traceback  # stdlib; imported here so this cell is self-contained

raw_data_dir_name = 'all_monkey_data/raw_monkey_data'

sessions_df_for_one_monkey = combine_info_utils.make_sessions_df_for_one_monkey(
    raw_data_dir_name, 'monkey_Bruno')

# Session-independent settings: assigned once instead of on every iteration.
reduce_y_var_lags = False
planning_data_by_point_exists_ok = True
y_data_exists_ok = True
bin_width = 0.1

for index, row in sessions_df_for_one_monkey.iterrows():
    print('='*100)
    print('='*100)
    print(row['data_name'])
    raw_data_folder_path = os.path.join(
        raw_data_dir_name, row['monkey_name'], row['data_name'])

    try:
        pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(
            raw_data_folder_path=raw_data_folder_path, bin_width=bin_width)
        pn.prep_data_to_analyze_planning(
            planning_data_by_point_exists_ok=planning_data_by_point_exists_ok)
        pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
            pn.planning_data_by_point)
        pn.get_x_and_y_data_for_modeling(
            exists_ok=y_data_exists_ok, reduce_y_var_lags=reduce_y_var_lags)

        pn.prepare_seg_aligned_data()
        if not use_raw_spike_data_instead:
            # GPFA trajectories are only fetched when raw spikes are not used.
            pn.get_gpfa_traj(latent_dimensionality=7, exists_ok=True)

        # for regression later
        pn.get_concat_data_for_regression(
            use_raw_spike_data_instead=use_raw_spike_data_instead,
            use_lagged_rebinned_behav_data=use_lagged_rebinned_behav_data,
            apply_pca_on_raw_spike_data=apply_pca_on_raw_spike_data,
            use_lagged_raw_spike_data=use_lagged_raw_spike_data,
        )

        # get key_features2
        pn.concat_behav_trials, cols_added = pn_decoding_utils.add_interaction_terms_and_features(
            pn.concat_behav_trials)
        key_features2 = (
            ['cur_ff_distance', 'log1p_cur_ff_distance', 'speed',
             'accel', 'time_since_last_capture']
            + cols_added
        )

        x_var, y_var = pn.get_concat_x_and_y_var_for_lr(test_or_control='both')
        # The column lists returned by the next two calls are not used:
        # key_features2 was already built from `cols_added` above, so they get
        # their own names instead of overwriting it.
        # NOTE(review): interaction terms were already added to
        # pn.concat_behav_trials above — confirm y_var does not already
        # contain them before they are added again here.
        y_var, _y_interaction_cols = pn_decoding_utils.add_interaction_terms_and_features(y_var)
        y_var, _y_prep_cols = pn_decoding_utils.prep_behav(y_var)

        all_results = []

        save_path = os.path.join(
            pn.planning_and_neural_folder_path, 'pn_decoding', 'conditioned_on_ff_visibility')
        for model_name, spec in pn_decoding_model_specs.MODEL_SPECS.items():

            # Plain string: there is nothing to interpolate, so no f-prefix.
            print('model_name: ', model_name)

            config = pn_decoding_utils.DecodingRunConfig(
                model_class=spec['model_class'],
                model_kwargs=spec['model_kwargs'],
                use_early_stopping=False,
                make_plots=False,
            )

            # Silence all warnings for the duration of the decoding call only.
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore')

                results_df = pn_decoding_utils.decode_by_num_ff_visible_or_in_memory(
                    x_var,
                    y_var,
                    key_features2,
                    config=config,
                    save_path=save_path,
                    ff_visibility_col=ff_visibility_col,
                )

            results_df['model_name'] = model_name
            all_results.append(results_df)

        all_results_df = pd.concat(all_results, ignore_index=True)

        plot_pn_decoding.plot_decoding_heatmaps_with_n(
            all_results_df,
            ff_visibility_col)

    except Exception as e:
        # Best-effort batch run: report the failure and move on to the next
        # session. The traceback is printed as well because the message alone
        # rarely locates the problem.
        print(f"Error processing {row['data_name']}: {e}")
        traceback.print_exc()
        continue

    # Stop after the first session that completes without an error; failed
    # sessions `continue` past this point and the next one is tried.
    break
    # Save the current state of the notebook
    

## shuffled control

In [None]:
# Shuffled control for the conditioned decoding: same pipeline as the
# unshuffled run, but decode_by_num_ff_visible_or_in_memory is called with
# shuffle_y=True.
import traceback  # stdlib; imported here so this cell is self-contained

raw_data_dir_name = 'all_monkey_data/raw_monkey_data'

sessions_df_for_one_monkey = combine_info_utils.make_sessions_df_for_one_monkey(
    raw_data_dir_name, 'monkey_Bruno')

# Session-independent settings: assigned once instead of on every iteration.
reduce_y_var_lags = False
planning_data_by_point_exists_ok = True
y_data_exists_ok = True
bin_width = 0.1

for index, row in sessions_df_for_one_monkey.iterrows():
    print('='*100)
    print('='*100)
    print(row['data_name'])
    raw_data_folder_path = os.path.join(
        raw_data_dir_name, row['monkey_name'], row['data_name'])

    try:
        pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(
            raw_data_folder_path=raw_data_folder_path, bin_width=bin_width)
        pn.prep_data_to_analyze_planning(
            planning_data_by_point_exists_ok=planning_data_by_point_exists_ok)
        pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
            pn.planning_data_by_point)
        pn.get_x_and_y_data_for_modeling(
            exists_ok=y_data_exists_ok, reduce_y_var_lags=reduce_y_var_lags)

        pn.prepare_seg_aligned_data()
        if not use_raw_spike_data_instead:
            # GPFA trajectories are only fetched when raw spikes are not used.
            pn.get_gpfa_traj(latent_dimensionality=7, exists_ok=True)

        # for regression later
        pn.get_concat_data_for_regression(
            use_raw_spike_data_instead=use_raw_spike_data_instead,
            use_lagged_rebinned_behav_data=use_lagged_rebinned_behav_data,
            apply_pca_on_raw_spike_data=apply_pca_on_raw_spike_data,
            use_lagged_raw_spike_data=use_lagged_raw_spike_data,
        )

        # get key_features2
        pn.concat_behav_trials, cols_added = pn_decoding_utils.add_interaction_terms_and_features(
            pn.concat_behav_trials)
        key_features2 = (
            ['cur_ff_distance', 'log1p_cur_ff_distance', 'speed',
             'accel', 'time_since_last_capture']
            + cols_added
        )

        x_var, y_var = pn.get_concat_x_and_y_var_for_lr(test_or_control='both')
        # The column lists returned by the next two calls are not used:
        # key_features2 was already built from `cols_added` above, so they get
        # their own names instead of overwriting it.
        # NOTE(review): interaction terms were already added to
        # pn.concat_behav_trials above — confirm y_var does not already
        # contain them before they are added again here.
        y_var, _y_interaction_cols = pn_decoding_utils.add_interaction_terms_and_features(y_var)
        y_var, _y_prep_cols = pn_decoding_utils.prep_behav(y_var)

        all_results = []

        # NOTE(review): this is the same folder the unshuffled run writes to.
        # Confirm decode_by_num_ff_visible_or_in_memory keeps shuffled results
        # separate, so that neither run overwrites or reuses the other's files.
        save_path = os.path.join(
            pn.planning_and_neural_folder_path, 'pn_decoding', 'conditioned_on_ff_visibility')
        for model_name, spec in pn_decoding_model_specs.MODEL_SPECS.items():

            # Plain string: there is nothing to interpolate, so no f-prefix.
            print('model_name: ', model_name)

            config = pn_decoding_utils.DecodingRunConfig(
                model_class=spec['model_class'],
                model_kwargs=spec['model_kwargs'],
                use_early_stopping=False,
                make_plots=False,
            )

            # Silence all warnings for the duration of the decoding call only.
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore')

                results_df = pn_decoding_utils.decode_by_num_ff_visible_or_in_memory(
                    x_var,
                    y_var,
                    key_features2,
                    config=config,
                    save_path=save_path,
                    ff_visibility_col=ff_visibility_col,
                    shuffle_y=True,
                )

            results_df['model_name'] = model_name
            all_results.append(results_df)

        all_results_df = pd.concat(all_results, ignore_index=True)

        plot_pn_decoding.plot_decoding_heatmaps_with_n(
            all_results_df,
            ff_visibility_col)

    except Exception as e:
        # Best-effort batch run: report the failure and move on to the next
        # session. The traceback is printed as well because the message alone
        # rarely locates the problem.
        print(f"Error processing {row['data_name']}: {e}")
        traceback.print_exc()
        continue

    # Stop after the first session that completes without an error; failed
    # sessions `continue` past this point and the next one is tried.
    break
    # Save the current state of the notebook
    

# Conditional decoding

## First try: for the same session, use multiple combinations of neural data

In [None]:
# key_features = [
#     'new_bin', 'new_segment', 'whether_test',
#     'cur_eye_hor_l', 'cur_eye_ver_l', 'cur_eye_hor_r', 'cur_eye_ver_r',
#     'nxt_eye_hor_l', 'nxt_eye_ver_l', 'nxt_eye_hor_r', 'nxt_eye_ver_r',
#     'LDz', 'RDz', 'LDx', 'RDx',
#     'gaze_mky_view_x', 'gaze_mky_view_y', 'gaze_mky_view_angle',
#     'cur_opt_arc_dheading', 'cur_ff_distance',
#     'cur_ff_rel_x', 'cur_ff_rel_y', 'nxt_ff_rel_x', 'nxt_ff_rel_y', 'nxt_ff_distance',
#     'num_ff_visible', 'num_ff_in_memory',
#     'cur_ff_distance_at_ref', 'cur_ff_angle_boundary_at_ref', 'nxt_ff_distance_at_ref',
#     'speed', 'ang_speed', 'accel', 'ang_accel',
#     'monkey_speeddummy', 'curv_of_traj',
#     'angle_from_cur_ff_to_nxt_ff',
#     'time_since_last_capture', 'bin_mid_time_rel_to_event',
#     'time', 'target_index',
#     'cur_vis', 'nxt_vis', 'nxt_in_memory', 'any_ff_visible',
# ]


In [None]:
# Settings for the single-session run below.
bin_width = 0.1  # passed as `bin_width` when constructing PlanningAndNeuralEventAligned

# `exists_ok`-style switches forwarded to the data-preparation calls.
planning_data_by_point_exists_ok = True
y_data_exists_ok = True

# Forwarded to pn.get_x_and_y_data_for_modeling.
reduce_y_var_lags = False

In [None]:

# Session analysed in the single-session cells below (path is relative to the
# current working directory).
raw_data_folder_path = 'all_monkey_data/raw_monkey_data/monkey_Bruno/data_0316'

In [None]:
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural.pn_decoding.interactions.band_conditioned.specified_pairs import CONTINUOUS_INTERACTIONS, DISCRETE_INTERACTIONS


In [None]:
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural.pn_decoding.interactions import add_interactions, discrete_decoders, interaction_decoding
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural.pn_decoding.interactions.band_conditioned import conditional_decoding_clf, conditional_decoding_reg, conditional_decoding_plots

In [None]:
# Build the event-aligned planning/neural object for the selected session and
# load its modeling inputs.
pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(
    raw_data_folder_path=raw_data_folder_path,
    bin_width=bin_width,
)
pn.prep_data_to_analyze_planning(
    planning_data_by_point_exists_ok=planning_data_by_point_exists_ok,
)
pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
    pn.planning_data_by_point,
)
pn.get_x_and_y_data_for_modeling(
    exists_ok=y_data_exists_ok,
    reduce_y_var_lags=reduce_y_var_lags,
)

# Neural-data format for this run, forwarded as keyword arguments.
flags = dict(
    use_raw_spike_data_instead=True,
    apply_pca_on_raw_spike_data=False,
    use_lagged_raw_spike_data=False,
)
pn.get_concat_data_for_regression(**flags)

# Work on a copy so pn.concat_behav_trials itself is left untouched.
df, added_cols = pn_decoding_utils.prep_behav(pn.concat_behav_trials.copy())
df = add_interactions.add_behavior_bands(df)

# Base features plus whatever columns prep_behav reported as added.
key_features2 = [
    'cur_ff_distance',
    'log1p_cur_ff_distance',
    'speed',
    'accel',
    'time_since_last_capture',
] + added_cols