# Import packages

In [1]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break

from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, ml_methods_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils
from neural_data_analysis.design_kits.design_by_segment import create_design_df, predictor_utils, other_feats
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event, pn_glm_utils
from neural_data_analysis.neural_analysis_tools.glm_tools.tpg import glm_bases, glm_plotting, glm_plotting2, glm_fit

from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_psth import core_stops_psth, get_stops_utils, psth_postprocessing, psth_stats, compare_events, dpca_utils
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_fit import stop_glm_fit, cv_stop_glm, glm_fit_utils, variance_explained
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_plotting import plot_spikes, plot_glm_fit, plot_tuning_func
from neural_data_analysis.design_kits.design_around_event import event_binning, stop_design, cluster_design, design_checks
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_hyperparams import compare_glm_configs, glm_hyperparams_class
from neural_data_analysis.neural_analysis_tools.glm_tools.glm_decoding_tools import glm_decoding_llr, glm_decoding
from planning_analysis.show_planning.cur_vs_nxt_ff import cvn_from_ref_class
from planning_analysis.plan_factors import build_factor_comp


import sys
import math
import gc
import subprocess
from pathlib import Path
from importlib import reload

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
from numpy import pi

from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve
# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# Notebook-wide runtime/display configuration.

# Set in the process environment; presumably a workaround for duplicate
# OpenMP runtimes being loaded — TODO confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Reset matplotlib to its defaults FIRST. Previously this reset ran after the
# animation settings and silently reverted them, so they never took effect.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)

# Animation display settings (applied after the reset so they stick).
# 'jshtml' is the value the original code set last.
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128

# Print floats with 5 decimals in pandas and without scientific notation in numpy.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
print("done")

# NOTE: autoreload is already loaded and configured at the top of this cell;
# loading it a second time here only printed an "already loaded" message.

Set up logging configuration.




done
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Retrieve data

In [41]:
# Session folder consumed by the single-session cell that follows.
raw_data_folder_path = "/".join(
    ("all_monkey_data", "raw_monkey_data", "monkey_Bruno", "data_0222")
)

In [3]:
# Build the current-vs-next-firefly helper for the selected session.
cvn = cvn_from_ref_class.CurVsNxtFfFromRefClass(
    raw_data_folder_path=raw_data_folder_path,
)

# Quick path: tries to retrieve cached results first, creates them if needed.
cvn.make_heading_info_df_without_long_process(
    test_or_control='test',            # or 'control'
    ref_point_mode='distance',         # or 'time after cur ff visible'
    ref_point_value=-100,              # or 0.0 for time mode
    heading_info_df_exists_ok=True,    # set to False to force recreation
    stops_near_ff_df_exists_ok=True,
    save_data=True,
)

# Keep the full table, plus a two-column view ordered from the largest to the
# smallest diff_in_abs_angle_to_nxt_ff with a fresh 0..n-1 index.
heading_info_df = cvn.heading_info_df
heading_df = (
    heading_info_df[['cur_ff_index', 'diff_in_abs_angle_to_nxt_ff']]
    .copy()
    .sort_values(by='diff_in_abs_angle_to_nxt_ff', ascending=False)
    .reset_index(drop=True)
)


Retrieved monkey_information
The number of points that were removed due to delta_position exceeding the ceiling is 0
Note: ff_caught_T_sorted is replaced with ff_caught_T_new
Removed 0 rows out of 766 rows where cur_ff was not visible bbas or nxt_ff was not visible both bbas and bsans
shared_stops_near_ff_df has 766 rows
Retrieving shared_stops_near_ff_df succeeded
Successfully retrieved diff_in_curv_df from all_monkey_data/planning/monkey_Bruno/data_0330/diff_in_curv_df/opt_arc_stop_closest/test/dist_-100_window_-25cm_0cm
Successfully retrieved heading_info_df from all_monkey_data/planning/monkey_Bruno/data_0330/heading_info_df/opt_arc_stop_closest/test/Bruno_dist_-100


# Decoding split by `dir_from_cur_ff_same_side` (same side vs. opposite side)

In [None]:
from data_wrangling import combine_info_utils, specific_utils

# Run cross-validated GLM decoding of 'nxt_vis' for every unfinished session of
# one monkey, separately for the "same side" and "opposite side" subsets
# selected via heading_info_df['dir_from_cur_ff_same_side'].

# Get all sessions for a specific monkey
monkey_name = "monkey_Bruno"  # or "monkey_Schro"
sessions_df = combine_info_utils.make_sessions_df_for_one_monkey(
    raw_data_dir_name='all_monkey_data/raw_monkey_data',
    monkey_name=monkey_name
)

# Iterate through each session
for _, row in sessions_df.iterrows():
    if row['finished']:
        continue  # Skip already processed sessions

    # Construct the raw_data_folder_path
    raw_data_folder_path = f"all_monkey_data/raw_monkey_data/{row['monkey_name']}/{row['data_name']}"

    print(f"Processing: {raw_data_folder_path}")

    pn = glm_decoding.init_decoding_data(raw_data_folder_path)

    heading_info_df, heading_df = pn_glm_utils.get_test_heading_df(raw_data_folder_path)

    # Adds the 'dir_from_cur_ff_same_side' column used for the split below
    # (the helper's return value was never used, so it is assumed to work in place).
    build_factor_comp.add_dir_from_cur_ff_same_side(heading_info_df)
    # A bare expression inside a loop is discarded, so print the value explicitly.
    print(f"Mean of dir_from_cur_ff_same_side: {heading_info_df['dir_from_cur_ff_same_side'].mean():.5f}")

    for same_side in [True, False]:
        print('-'*100)
        print('-'*100)
        # Named `label` rather than `str`: rebinding `str` at notebook scope
        # shadows the builtin and breaks later calls such as str(path).
        if same_side:
            label = "=========Same Side========="
        else:
            label = "=========Opposite Side========="

        rebinned_x_var, rebinned_y_var = pn_glm_utils.select_ff_subset_by_dir_from_cur_ff_same_side(
            heading_info_df, pn.rebinned_x_var, pn.rebinned_y_var,
            same_side=same_side)

        rebinned_x_var = pn_glm_utils.drop_constant_columns(rebinned_x_var)
        # Snapshot taken before the decoding helpers run; used for CV grouping below.
        data = rebinned_y_var.copy()

        df_X, df_Y = glm_decoding.get_data_for_decoding_vis(rebinned_x_var, rebinned_y_var, pn.bin_width)

        # Constant exposure (one bin width per row), passed to the GLM as a log offset.
        exposure = np.ones(len(df_Y)) * pn.bin_width
        offset_log = np.log(exposure)

        # `report` is only consumed by the (currently disabled) from-fit and
        # permutation steps; the call is kept so they can be re-enabled as is.
        report = stop_glm_fit.glm_mini_report(
            df_X=df_X, df_Y=df_Y, offset_log=offset_log,
            cov_type='HC1',
            fast_mle=True,
            do_inference=False,
            make_plots=False,
            show_plots=True,
        )

        #cols_to_decode = ['nxt_vis', 'random_0_or_1', 'cur_vis']
        cols_to_decode = ['nxt_vis']
        # Cross-validation groups come from the 'new_segment' column.
        groups = np.array(data['new_segment'])

        # # Decoding from fit
        # print(label)
        # glm_decoding.glm_decoding_from_fit(cols_to_decode, df_X, df_Y, offset_log, report)

        # CV
        print(label)
        glm_decoding.glm_decoding_cv(cols_to_decode, df_X, df_Y, groups, offset_log)

        # # permutations
        # print(label)
        # glm_decoding.glm_decoding_permutation_test(cols_to_decode, df_X, df_Y,
        #                         groups, offset_log, report, print_progress=False)


# top vs bottom (diff_in_abs_angle_to_nxt_ff)

In [None]:
from data_wrangling import combine_info_utils, specific_utils

# Run GLM decoding of 'nxt_vis' (from fit, cross-validated, and permutation
# test) for every unfinished session of one monkey, separately for the top and
# bottom subsets returned by pn_glm_utils.select_ff_subset (pct=0.5).

# Get all sessions for a specific monkey
monkey_name = "monkey_Bruno"  # or "monkey_Schro"
sessions_df = combine_info_utils.make_sessions_df_for_one_monkey(
    raw_data_dir_name='all_monkey_data/raw_monkey_data',
    monkey_name=monkey_name
)

# Iterate through each session
for _, row in sessions_df.iterrows():
    if row['finished']:
        continue  # Skip already processed sessions

    # Construct the raw_data_folder_path
    raw_data_folder_path = f"all_monkey_data/raw_monkey_data/{row['monkey_name']}/{row['data_name']}"

    print(f"Processing: {raw_data_folder_path}")

    pn = glm_decoding.init_decoding_data(raw_data_folder_path)

    heading_info_df, heading_df = pn_glm_utils.get_test_heading_df(raw_data_folder_path)

    # NOTE: a stray, discarded `heading_info_df['dir_from_cur_ff_same_side'].mean()`
    # used to sit here. This cell never adds that column and never used the
    # value, so the line could only be a no-op or raise; it has been removed.

    for top in [True, False]:
        print('-'*100)
        print('-'*100)
        # Named `label` rather than `str`: rebinding `str` at notebook scope
        # shadows the builtin and breaks later calls such as str(path).
        if top:
            label = "=========TOP TOP TOP TOP TOP========="
        else:
            label = "=========BOTTOM BOTTOM BOTTOM BOTTOM BOTTOM========="

        # Pass the loop variable through. This was hard-coded to top=False, so
        # both iterations analysed the bottom subset while the first one was
        # printed under the TOP banner.
        rebinned_x_var, rebinned_y_var = pn_glm_utils.select_ff_subset(
            heading_df, pn.rebinned_x_var, pn.rebinned_y_var,
            top=top, pct=0.5)

        rebinned_x_var = pn_glm_utils.drop_constant_columns(rebinned_x_var)
        # Snapshot taken before the decoding helpers run; used for CV grouping below.
        data = rebinned_y_var.copy()

        df_X, df_Y = glm_decoding.get_data_for_decoding_vis(rebinned_x_var, rebinned_y_var, pn.bin_width)

        # Constant exposure (one bin width per row), passed to the GLM as a log offset.
        exposure = np.ones(len(df_Y)) * pn.bin_width
        offset_log = np.log(exposure)

        report = stop_glm_fit.glm_mini_report(
            df_X=df_X, df_Y=df_Y, offset_log=offset_log,
            cov_type='HC1',
            fast_mle=True,
            do_inference=False,
            make_plots=False,
            show_plots=True,
        )

        #cols_to_decode = ['nxt_vis', 'random_0_or_1', 'cur_vis']
        cols_to_decode = ['nxt_vis']
        # Cross-validation / permutation groups come from the 'new_segment' column.
        groups = np.array(data['new_segment'])

        # Decoding from fit
        print(label)
        glm_decoding.glm_decoding_from_fit(cols_to_decode, df_X, df_Y, offset_log, report)

        # CV
        print(label)
        glm_decoding.glm_decoding_cv(cols_to_decode, df_X, df_Y, groups, offset_log)

        # permutations
        print(label)
        glm_decoding.glm_decoding_permutation_test(cols_to_decode, df_X, df_Y,
                                groups, offset_log, report, print_progress=False)



# debug

In [None]:
# Debug setup: build the event-aligned planning/neural object for one session
# and prepare its planning data.
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0222"
pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(
    raw_data_folder_path=raw_data_folder_path,
)
pn.prep_data_to_analyze_planning(
    planning_data_by_point_exists_ok=True,
)

Loaded binned_spikes_df from all_monkey_data/processed_neural_data/monkey_Bruno/data_0222/binned_spikes_df_0p05.csv
Retrieved monkey_information
The number of points that were removed due to delta_position exceeding the ceiling is 0
Note: ff_caught_T_sorted is replaced with ff_caught_T_new
Removed 0 rows out of 702 rows where cur_ff was not visible bbas or nxt_ff was not visible both bbas and bsans
shared_stops_near_ff_df has 702 rows
Retrieving shared_stops_near_ff_df succeeded
Successfully retrieved diff_in_curv_df from all_monkey_data/planning/monkey_Bruno/data_0222/diff_in_curv_df/opt_arc_stop_closest/test/cur_vis_0_1_window_-25cm_0cm
Successfully retrieved heading_info_df from all_monkey_data/planning/monkey_Bruno/data_0222/heading_info_df/opt_arc_stop_closest/test/Bruno_cur_vis_0_1
Need to make a new heading_info_df so that no data are dropped because ff_y is negative. No curature info is needed for this temporary heading_info_df.
Percentage of rows outside of [-45, 45]: 3.17%
Re

In [None]:
# Rebin the data into new segments with the settings below
# (see rebin_data_in_new_segments for the meaning of each argument).
pn.rebin_data_in_new_segments(
    cur_or_nxt='cur',
    first_or_last='first',
    time_limit_to_count_sighting=2,
    pre_event_window=0,
    post_event_window=1.5,
    rebinned_max_x_lag_number=2,
)


new_seg_duration is now 1.5, and post_event_window is now 1.5
Loaded new_seg_info from all_monkey_data/planning_and_neural/monkey_Bruno/data_0222/new_seg_info/tlim2_cur_first_pre0_post1p5.csv
Dropped 30 columns due to containing NA in rebinned_y_var via calling drop_na_cols function: ['nxt_cntr_arc_curv', 'nxt_opt_arc_curv', 'cur_cntr_arc_curv', 'cur_opt_arc_curv', 'cur_opt_arc_end_heading', 'angle_opt_cur_end_to_nxt_ff', 'angle_from_stop_to_nxt_ff', 'diff_in_angle_to_nxt_ff', 'diff_in_abs_angle_to_nxt_ff', 'traj_curv_to_stop', 'curv_from_stop_to_nxt_ff', 'opt_curv_to_cur_ff', 'curv_from_cur_end_to_nxt_ff', 'd_curv_null_arc', 'd_curv_monkey', 'abs_d_curv_null_arc', 'abs_d_curv_monkey', 'diff_in_d_curv', 'diff_in_abs_d_curv', 'abs_angle_opt_cur_end_to_nxt_ff', 'abs_angle_from_stop_to_nxt_ff', 'abs_diff_in_angle_to_nxt_ff', 'abs_diff_in_abs_angle_to_nxt_ff', 'stop_id_end_time', 'stop_id_duration', 'stop_cluster_id', 'stop_cluster_start_point', 'stop_cluster_end_point', 'stop_cluster_size

  'lag_segment_id', group_keys=False).apply(lag_group)


In [None]:
# Recompute pn.rebinned_y_var from the per-point planning data and the segment
# table. This overwrites the attribute on `pn` in place.
pn.rebinned_y_var = pn_utils.rebin_segment_data(
    pn.planning_data_by_point,
    pn.new_seg_info,
    bin_width=pn.bin_width,
)


In [None]:
# Print dtype and non-null count for the single column of interest.
pn.rebinned_y_var.loc[:, ['time_since_target_last_seen']].info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8914 entries, 0 to 8913
Data columns (total 1 columns):
 #   Column                       Non-Null Count  Dtype  
---  ------                       --------------  -----  
 0   time_since_target_last_seen  8914 non-null   float64
dtypes: float64(1)
memory usage: 69.8 KB


In [None]:
# Summary statistics for the column (the count excludes NaN rows).
pn.rebinned_y_var.loc[:, 'time_since_target_last_seen'].describe()

count   5127.00000
mean       0.34051
std        0.35699
min        0.00000
25%        0.00000
50%        0.24074
75%        0.59734
max        1.47709
Name: time_since_target_last_seen, dtype: float64

In [None]:
# Display pn.dec.behav_data_by_point as this cell's output.
pn.dec.behav_data_by_point

Unnamed: 0,monkey_x,monkey_y,time,point_index,monkey_angle,speed,accel,ang_speed,ang_accel,_contam,...,target_cluster_last_seen_angle,target_cluster_last_seen_angle_to_boundary,monkey_x_target_cluster_last_seen,monkey_y_target_cluster_last_seen,monkey_angle_target_cluster_last_seen,cum_distance_target_cluster_last_seen,target_cluster_has_disappeared_for_last_time_dummy,target_cluster_visible_dummy,curv_of_traj,target_opt_arc_dheading
0,0.00000,30.80000,0.08336,0,1.57080,0.00000,-0.00000,-127.95750,3242.61415,False,...,,,,,,,0,1,0.00000,-0.00000
1,0.00000,30.80000,0.09999,1,-1.56612,0.00000,-0.00000,-92.58555,3891.29539,False,...,,,,,,,0,1,0.00000,0.09967
2,0.00000,30.80000,0.11654,2,-1.56612,0.00000,0.00000,0.00000,2822.78260,False,...,,,,,,,0,1,0.00000,0.09967
3,0.00000,30.80000,0.13314,3,-1.56612,0.00000,-0.00000,-0.00000,0.00000,False,...,,,,,,,0,1,0.00000,0.09967
4,0.00000,30.80000,0.14975,4,-1.56612,0.00000,-0.00000,0.00000,-0.00000,False,...,,,,,,,0,1,0.00000,0.09967
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
218746,44.83097,292.44476,3597.06860,218746,-1.23116,24.51224,-570.18740,0.00000,-0.11538,False,...,0.06286,0.00000,2.86870,457.53769,-1.43900,785446.14219,1,0,0.00228,-0.00000
218747,44.93691,292.15079,3597.08525,218747,-1.23116,14.93512,-448.57057,0.00000,-0.00000,False,...,0.06286,0.00000,2.86870,457.53769,-1.43900,785446.14219,1,0,0.00226,-0.00000
218748,45.00046,291.97443,3597.10177,218748,-1.23116,9.48333,-336.53191,0.00000,-0.00000,False,...,0.06286,0.00000,2.86870,457.53769,-1.43900,785446.14219,1,0,0.00225,-0.00000
218749,45.04283,291.85681,3597.11841,218749,-1.23116,3.76618,-212.95439,-0.00000,-0.00000,False,...,0.06286,0.00000,2.86870,457.53769,-1.43900,785446.14219,1,0,0.00225,-0.00000


In [None]:
# Refresh the target tables held by pn.dec, then display target_df.
pn.dec.get_basic_data()
# exists_ok=False presumably forces target_df to be rebuilt rather than
# loaded from a saved copy — TODO confirm against the method.
pn.dec._make_or_retrieve_target_df(exists_ok=False)
pn.dec.make_or_retrieve_target_cluster_df()
# Last expression of the cell: rendered as its output.
pn.dec.target_df

Retrieved monkey_information
The number of points that were removed due to delta_position exceeding the ceiling is 0
Note: ff_caught_T_sorted is replaced with ff_caught_T_new
Retrieved ff_dataframe from all_monkey_data/processed_data/monkey_Bruno/data_0222/ff_dataframe.h5
Retrieved target_df
Retrieved target_cluster_df
Made new target_df


In [None]:
# Show the rows where time_since_target_last_seen is NaN, limited to the
# columns needed to identify them.
pn.rebinned_y_var.loc[
    pn.rebinned_y_var['time_since_target_last_seen'].isna(),
    ['cur_ff_index', 'cur_in_memory', 'time'],
]


Unnamed: 0,cur_in_memory,time
5127,1.00000,1927.03294
5128,1.00000,1927.09107
5129,1.00000,1927.14085
5130,1.00000,1927.19072
5131,1.00000,1927.24049
...,...,...
8909,1.00000,3582.19590
8910,1.00000,3582.24566
8911,1.00000,3582.29549
8912,1.00000,3582.34525
