# Import packages

In [None]:

%load_ext autoreload
%autoreload 2

from pathlib import Path
import os
# Move the working directory to the path returned by add_path.find_path()
# unless the notebook is already running from 'Multifirefly-Project'.
cwd_name = Path.cwd().name
if cwd_name != 'Multifirefly-Project':
    # When started from somewhere other than 'notebooks', step up one level
    # before importing add_path (presumably add_path lives there; TODO confirm).
    if cwd_name != 'notebooks':
        os.chdir('..')
    from add_path import find_path
    current_path = find_path()
    os.chdir(current_path)

from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.neural_analysis_by_topic.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.neural_analysis_by_topic.planning_and_neural import planning_neural_class, planning_neural_utils, planning_neural_helper_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, regression_utils2, ml_methods_class, classification_utils, ml_plotting_utils


import sys
import math
import gc
import subprocess
from pathlib import Path
from importlib import reload

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# Reset matplotlib to its defaults FIRST. Previously this reset ran after the
# animation settings and silently discarded them (animation.html fell back to
# its default value).
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Animation rendering in the notebook. The earlier "html5" assignment was always
# overridden by this 'jshtml' setting, so only the effective value is kept.
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128
# NOTE(review): this variable is set after torch/numpy were imported above;
# confirm it still takes effect at this point or move it before those imports.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Show floats with 5 decimals in pandas and suppress scientific notation in numpy.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
print("done")

%load_ext autoreload
%autoreload 2

# Retrieve data

## get data

In [None]:
# Session option 1 (monkey_Bruno / data_0330). The next cell assigns the same
# variable, so run only the cell for the session you want to analyze.
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330"

In [None]:
# Session option 2 (monkey_Schro / data_0416). Overrides the previous cell's
# choice if both cells are run in order.
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0416"

In [None]:
# --- switches consumed by the calls at the bottom of this cell ---
planning_data_by_point_exists_ok = True
y_data_exists_ok = True
reduce_y_var_lags = True

# --- planning parameters: defined here but NOT passed to any call in this cell ---
ref_point_mode = 'time after cur ff visible'
ref_point_value = 0.1
curv_of_traj_mode = 'distance'
window_for_curv_of_traj = [-25, 25]
truncate_curv_of_traj_by_time_of_capture = True
use_curv_to_ff_center = False
eliminate_outliers = False
normalize = False

# Build the planning/neural object for the selected session and prepare its data.
pn = planning_neural_class.PlanningAndNeural(
    raw_data_folder_path=raw_data_folder_path,
)
pn.prep_data_to_analyze_planning(
    planning_data_by_point_exists_ok=planning_data_by_point_exists_ok,
)

# Replace planning_data_by_point with the version returned by the NaN-column filter.
nan_filter_result = general_utils.drop_columns_with_many_nans(pn.planning_data_by_point)
pn.planning_data_by_point, cols_to_drop = nan_filter_result

pn.get_x_and_y_data_for_modeling(
    exists_ok=y_data_exists_ok,
    reduce_y_var_lags=reduce_y_var_lags,
)

## Compare the distribution of every variable (test vs. control)

In [None]:
# Split y_var by the 'whether_test' flag (1 -> test rows, 0 -> control rows).
test_data = pn.y_var.loc[pn.y_var['whether_test'] == 1]
ctr_data = pn.y_var.loc[pn.y_var['whether_test'] == 0]

# Shared histogram settings: percentage scale, KDE overlay, 100 bins.
hist_kwargs = dict(alpha=0.5, stat='percent', kde=True, bins=100)

# One figure per column, with the test and control histograms overlaid.
# (To inspect a single column, iterate over e.g. ['target_index'] instead.)
for col in test_data.columns:
    plt.figure(figsize=(6, 4))
    for label, subset in (('test', test_data), ('ctr', ctr_data)):
        sns.histplot(subset[col].values, label=label, **hist_kwargs)
    plt.title(f'{col}', fontsize=14)
    plt.xlabel(f'{col}', fontsize=12)
    plt.ylabel('Percentage', fontsize=12)
    plt.legend()
    plt.show()
    

## check NA

In [None]:
# Run the project's NA check on the planning data; what gets reported is
# defined by general_utils.check_na_in_df.
general_utils.check_na_in_df(pn.planning_data_by_point)

## Correlations between variables

In [None]:
# Correlate every y_var column with the columns whose name contains 'nxt':
# rows are all correlated columns, columns are the 'nxt' ones.
nxt_cols = [name for name in pn.y_var.columns if 'nxt' in name]
full_corr = pn.y_var.corr()
corr_df = full_corr[nxt_cols]

In [None]:
# Draw the heatmap for corr_df (built in the previous cell).
ml_plotting_utils.plot_correlation_heatmap(corr_df)

# LR: on all data together

### Just nxt ff vars

In [None]:
# Option A: model only the 'whether_test' column. The next cell assigns the
# same variable, so running it afterwards replaces this choice.
columns_of_interest = ['whether_test']

In [None]:
# Option B: model every y_var column whose name contains 'nxt'.
columns_of_interest = [col for col in pn.y_var.columns if 'nxt' in col]

In [None]:
# Pick which trial subset to regress on: 'test' or 'control'.
test_or_control = 'control'

if test_or_control == 'test':
    x_var, y_var = pn.test_x_var_lags_reduced, pn.test_y_var
elif test_or_control == 'control':
    x_var, y_var = pn.control_x_var_lags_reduced, pn.control_y_var

# Restrict the targets to the chosen columns, then build the per-target
# results table from the lag-reduced x variables.
y_var_sub = y_var[columns_of_interest]
y_var_lr_df = neural_data_modeling.get_y_var_lr_df(
    x_var,
    y_var_sub,
    verbose=True,
)

## plot

In [None]:
# Plot regressions for the leading features listed in y_var_lr_df.
max_plot_number = 10
count = 0
# NOTE(review): bins_to_plot is computed here but the call below passes
# bins_to_plot=None; confirm which one is intended.
bins_to_plot = range(len(y_var))

# Features are taken in the order stored in y_var_lr_df (presumably sorted by
# correlation; TODO confirm); at most max_plot_number of them are considered.
for column in y_var_lr_df.feature.values[:max_plot_number]:
    plot_neural_data.plot_regression(
        y_var, column, x_var,
        bins_to_plot=None,
        min_r_squared_to_plot=0.3,
    )

# LR: train-test split

## regularized

In [None]:
# Targets for the regularized-regression cell below: only 'whether_test'.
columns_of_interest = ['whether_test']

In [None]:
# Re-import regression_utils2 so edits to the module are picked up without
# restarting the kernel.
reload(regression_utils2)

In [None]:
# Trial subset to use: 'test', 'control', or anything else (e.g. None) for all data.
test_or_control = None

if test_or_control not in ('test', 'control'):
    x_var, y_var = pn.x_var_lags_reduced, pn.y_var
elif test_or_control == 'test':
    x_var, y_var = pn.test_x_var_lags_reduced, pn.test_y_var
else:
    x_var, y_var = pn.control_x_var_lags_reduced, pn.control_y_var

split = planning_neural_utils.train_test_split_based_on_targets(x_var, y_var)
X_train, X_test, y_train, y_test = split

# For every target column: fit one ridge model, report its metrics, then
# compare the regularized models offered by regression_utils2.
for y_var_column in columns_of_interest:
    print('y_var_column:', y_var_column)
    y_train_var = y_train[y_var_column]
    y_test_var = y_test[y_var_column]

    # Single ridge fit with alpha=1.0.
    results, results_df, y_pred_train, y_pred_test = regression_utils2.regularized_regression(
        X_train, y_train_var, X_test, y_test_var,
        method='ridge',
        alpha=1.0,
    )
    print(results_df)

    # Metrics report (with plots) on the held-out predictions.
    report = regression_utils2.regression_metrics_report(
        y_test_var, y_pred_test,
        model_name="Ridge Regression",
        show_plots=True,
    )

    # Model comparison; note this overwrites results / results_df from the ridge fit.
    results, results_df = regression_utils2.compare_regularized_models(
        X_train, y_train_var, X_test, y_test_var,
        verbose=True,
        show_plots=False,
    )
    regression_utils2.print_model_comparison_summary(results)
    

## split based on targets

In [None]:
# select 'random_dummy'
pn.y_var = planning_neural_utils.randomly_assign_random_dummy_based_on_targets(pn.y_var)
columns_of_interest = ['random_dummy']

In [None]:
# Option: model only the 'whether_test' column.
columns_of_interest = ['whether_test']

In [None]:
# select all nxt_ff variables
columns_of_interest = [col for col in pn.y_var.columns if 'nxt' in col]

In [None]:
# select all cur_ff variables
columns_of_interest = [col for col in pn.y_var.columns if 'cur' in col]

In [None]:
# Trial subset to use: 'test', 'control', or anything else (e.g. None) for all data.
test_or_control = None

if test_or_control == 'test':
    x_var, y_var = pn.test_x_var_lags_reduced, pn.test_y_var
elif test_or_control == 'control':
    x_var, y_var = pn.control_x_var_lags_reduced, pn.control_y_var
else:
    x_var, y_var = pn.x_var_lags_reduced, pn.y_var

X_train, X_test, y_train, y_test = planning_neural_utils.train_test_split_based_on_targets(
    x_var, y_var)

for y_var_column in columns_of_interest:
    print('y_var_column:', y_var_column)
    # Choose the model from the number of distinct training values:
    # 1 -> error, 2 -> logistic regression (dummy variable), otherwise linear regression.
    n_unique = y_train[y_var_column].nunique()
    if n_unique == 1:
        raise ValueError(f"y_var_column {y_var_column} has only one unique value")
    if n_unique == 2:
        conf_matrix = classification_utils._use_logistic_regression(
            X_train, X_test, y_train[y_var_column], y_test[y_var_column])
    else:
        summary_df, y_pred, results, r2_test = regression_utils.use_linear_regression(
            X_train, X_test, y_train[y_var_column], y_test[y_var_column],
            show_plot=True, y_var_name=y_var_column)

## random split

In [None]:
# Trial subset to use: 'test', 'control', or anything else (e.g. None) for all data.
test_or_control = 'test'

if test_or_control == 'test':
    x_var = pn.test_x_var_lags_reduced
    y_var = pn.test_y_var
elif test_or_control == 'control':
    x_var = pn.control_x_var_lags_reduced
    y_var = pn.control_y_var
else:
    # Fallback added for consistency with the other split cells, so x_var / y_var
    # are never left stale from an earlier cell.
    x_var = pn.x_var_lags_reduced
    y_var = pn.y_var

# Take the target columns from the selected y_var so every name is guaranteed
# to exist in the frame that is indexed below.
columns_of_interest = [col for col in y_var.columns if 'nxt' in col]

ml_inst = ml_methods_class.MlMethods()
for y_var_column in columns_of_interest:
    print('y_var_column:', y_var_column)
    # Fit on the SELECTED subset. Previously this always used pn.x_var_lags_reduced
    # and pn.y_var, so the test/control choice above had no effect.
    ml_inst.split_and_use_linear_regression(x_var, y_var[[y_var_column]])


# CCA

Reference: [Canonical Correlation Analysis: simple explanation and Python example](https://medium.com/@pozdrawiamzuzanna/canonical-correlation-analysis-simple-explanation-and-python-example-a5b8e97648d2)

## conduct cca

In [None]:
# CCA without lags: reduced neural variables (X1) vs. reduced behavioral variables (X2).
cca_no_lag = cca_class.CCAclass(X1=pn.x_var_reduced, X2=pn.y_var_reduced, lagging_included=False)
cca_no_lag.conduct_cca()

# CCA with lags; the 'bin' column is dropped from X1 when present.
cca_lags = cca_class.CCAclass(X1=pn.x_var_lags_reduced.drop(columns='bin', errors='ignore'), X2=pn.y_var_lags_reduced, lagging_included=True)
# For all columns that END with _0, rename them to the column name without the _0.
# The pattern is anchored with `$` so that a '_0' occurring elsewhere in a name
# (e.g. 'a_0_b' or 'a_05') is left untouched; the previous un-anchored replace
# removed every occurrence.
cca_lags.X2.columns = cca_lags.X2.columns.str.replace(r'_0$', '', regex=True)
cca_lags.conduct_cca()


# NOTE(review): the first line reports pn.x_var_lags, while the CCA above was
# fitted on pn.x_var_lags_reduced; confirm which shape is meant to be shown.
print(f'pn.x_var_lags.shape: {pn.x_var_lags.shape}')
print(f'pn.y_var_lags_reduced.shape: {pn.y_var_lags_reduced.shape}')

# Default instance used by the later cells.
cca_inst = cca_lags

## compare lag vs no lag

In [None]:
# Put the canonical correlations of both fits side by side, one row per component.
can_load_df = pd.DataFrame(cca_no_lag.canon_corr, columns=['no_lag'])
can_load_df['with_lags'] = cca_lags.canon_corr
can_load_df['component'] = [f'CC {k}' for k in range(1, cca_lags.n_components + 1)]

# Long format: one row per (component, lag setting) pair, ready for seaborn.
can_load_df_long = can_load_df.melt(
    id_vars=['component'],
    var_name='lag',
    value_name='canon_coeff',
)

In [None]:
# make a sns bar plot on can_load_df_long
plt.figure(figsize=(8, 6))
sns.barplot(x='component', y='canon_coeff', data=can_load_df_long, hue='lag')
plt.show()

## cca_inst (choose one between lags and no lag)

In [None]:
# choose no lag
cca_inst = cca_no_lag

In [None]:
# choose lags
cca_inst = cca_lags

## test for p values

In [None]:
# Run the significance test implemented by CCAclass on the selected instance.
cca_inst.test_for_p_values()

# Heatmap of loadings

## X1 loadings

In [None]:
# cca_inst.plot_X1_loadings()

## X2 loadings

In [None]:
# Re-import the plotting / CCA modules so source edits take effect without a
# kernel restart. NOTE: objects created before the reload (e.g. cca_inst) keep
# the class definition they were built with.
reload(ml_plotting_utils)
reload(cca_plotting)
reload(cca_class)

In [None]:
# Display the column names of the lag-reduced behavioral variables.
pn.y_var_lags_reduced.columns

In [None]:
# Plot the X2 loadings of the selected CCA instance.
cca_inst.plot_X2_loadings()

# Canonical Variate scatterplots

In [None]:
# Scatter the X1 vs. X2 canonical variates for components 1 through 4,
# with the y = x reference line enabled.
components = range(1, 5)
cca_plotting.plot_cca_component_scatter(
    cca_inst.X1_c,
    cca_inst.X2_c,
    components=components,
    show_y_eq_x=True,
)


# Transform vars (e.g. use basis functions)

In [None]:
# Attach a transformed copy of X2 (as returned by transform_behav_data) to each CCA instance.
for cca_obj in (cca_no_lag, cca_lags):
    cca_obj.X2_tf_df = transform_vars.transform_behav_data(cca_obj.X2)

In [None]:
# If need to use the data
cca_inst = cca_lags
X1_df = cca_inst.X1_sc_df
X2_df = cca_inst.X2_tf_df


# Lags vs no lag & train vs test

## Get data

In [None]:
# Cross-validated CCA for both the no-lag and the lagged instance, combined into two frames.
cv_settings = dict(n_components=7, reg=0.1, n_splits=7)
combined_cross_view_df, combined_can_load_df = cca_cv_utils.combine_cv_results(
    cca_no_lag,
    cca_lags,
    **cv_settings,
)

## cross-view X1

In [None]:
# dataset_name = 'X1'
# cross_view_sub = combined_cross_view_df[combined_cross_view_df['dataset'] == dataset_name]
# cca_plot_lag_vs_no_lag.plot_cca_lag_vs_nolag_and_train_vs_test(cross_view_sub, dataset_name, mode='lag_offset')


## cross-view X2

In [None]:
# Keep only the rows for the X2 dataset and plot lag vs. no-lag, train vs. test.
dataset_name = 'X2'
is_dataset = combined_cross_view_df['dataset'] == dataset_name
cross_view_sub = combined_cross_view_df[is_dataset]
cca_plot_lag_vs_no_lag.plot_cca_lag_vs_nolag_and_train_vs_test(
    cross_view_sub,
    dataset_name,
    mode='lag_offset',
)


# Just train vs test

## cross-view X1

In [None]:
# filter_significant = True
# sort_by_significance = True
# significance_threshold = 4
# whether_lag = 'lag'

# combined_cross_view_df_sub = combined_cross_view_df[combined_cross_view_df['whether_lag'] == whether_lag]

# # X1
# cca_plot_cv.plot_cca_cv_results(combined_cross_view_df_sub, data_type='X1',
#                                     filter_significant=filter_significant, sort_by_significance=sort_by_significance, significance_threshold=significance_threshold)


## cross-view X2

In [None]:
# Plot settings for the train-vs-test comparison.
whether_lag = 'lag'
filter_significant = True
sort_by_significance = True
significance_threshold = 1

# Restrict the cross-view results to the chosen lag setting.
lag_mask = combined_cross_view_df['whether_lag'] == whether_lag
combined_cross_view_df_sub = combined_cross_view_df[lag_mask]

# X2
cca_plot_cv.plot_cca_cv_results(
    combined_cross_view_df_sub,
    data_type='X2',
    filter_significant=filter_significant,
    sort_by_significance=sort_by_significance,
    significance_threshold=significance_threshold,
)


# Appendix

## reduce y_var only by vif

In [None]:
# Re-run the y_var reduction with only the all-columns VIF filter switched on
# (correlation-by-all-columns and VIF-by-subsets are off); exists_ok=False and
# save_data=True are passed through to pn.reduce_y_var.
reduce_kwargs = dict(
    save_data=True,
    corr_threshold_for_lags_of_a_feature=0.97,
    vif_threshold_for_initial_subset=5,
    vif_threshold=5,
    verbose=True,
    filter_corr_by_all_columns=False,
    filter_vif_by_subsets=False,
    filter_vif_by_all_columns=True,
    exists_ok=False,
)
pn.reduce_y_var(**reduce_kwargs)

## check final VIF

In [None]:
# VIF table for the reduced (no-lag) behavioral variables; the bare name on the
# last line displays it as the cell output.
vif_df = drop_high_vif_vars.get_vif_df(pn.y_var_reduced)
vif_df

In [None]:
# VIF table for the reduced lagged behavioral variables (overwrites vif_df from
# the previous cell).
vif_df = drop_high_vif_vars.get_vif_df(pn.y_var_lags_reduced)
vif_df

## check correlations in y_var_lags

In [None]:
# sort pn.y_var_lags by column str a to z
pn.y_var_lags2 = pn.y_var_lags.reindex(sorted(pn.y_var_lags.columns), axis=1)

# # sort pn.y_var_lags by column str z to a
# pn.y_var_lags_reduced = pn.y_var_lags_reduced.reindex(sorted(pn.y_var_lags_reduced.columns, reverse=True), axis=1)

In [None]:
# Display the first 10 columns of the name-sorted lag frame.
pn.y_var_lags2.iloc[:, :10]

In [None]:
# Pairwise correlations among those first 10 columns.
pn.y_var_lags2.iloc[:, :10].corr()

## check high corr within feature's lagged columns

In [None]:
# For every base feature, look at its lagged columns and print the feature name
# when the 2nd and 3rd lag columns are (numerically) perfectly correlated.
df_with_lags = pn.y_var_lags2.copy()
num_original_columns = len(df_with_lags.columns)
base_features = drop_high_corr_vars.get_base_feature_names(df_with_lags)
columns_dropped = []
top_values_by_feature = pd.DataFrame()
for feature in base_features:
    df_with_lags_sub = drop_high_corr_vars._find_subset_of_df_with_lags_for_current_feature(
        df_with_lags, feature)
    # temp_columns_to_drop, top_values_of_feature = drop_high_corr_vars._drop_lags_for_feature(
    #     df_with_lags, feature, corr_threshold, vif_threshold, use_vif_instead_of_corr, drop_lag_0_last_in_vif)
    lag_corr = df_with_lags_sub.corr()
    # A feature with fewer than three (numeric) lag columns has no [1, 2] entry;
    # skip it instead of failing with an IndexError.
    if lag_corr.shape[0] < 3:
        continue
    # Correlations are floats, so compare with a tolerance rather than `== 1`;
    # a NaN correlation still evaluates to False here.
    if np.isclose(lag_corr.iloc[1, 2], 1):
        print(feature)

## Compare columns in behav_data (target_decoder) and final_behavioral_data (neural_vs_behavioral_class)

In [None]:
# Build pn's behavioral data; exists_ok=False presumably forces recomputation
# instead of loading a stored copy. TODO confirm in PlanningAndNeural.
pn.get_behav_data(exists_ok=False)

In [None]:
# Build the neural-vs-behavioral object for the same session and run its
# preparation pipeline; its final_behavioral_data is compared with pn below.
data_item = neural_vs_behavioral_class.NeuralVsBehavioralClass(raw_data_folder_path=raw_data_folder_path)
data_item.streamline_preparing_neural_and_behavioral_data()

In [None]:
# Columns in pn.behav_data_all but not in data_item.final_behavioral_data
only_in_pn = set(pn.behav_data_all.columns) - set(data_item.final_behavioral_data.columns)
print("Columns only in pn.behav_data_all:")
only_in_pn = np.array(sorted(only_in_pn))
print(only_in_pn)
print('\n \n')

# Columns in data_item.final_behavioral_data but not in pn.behav_data_all
final_behavioral_data_columns = data_item.final_behavioral_data.columns
# remove all 'avg_bin_' prefix
final_behavioral_data_columns = [col.replace('avg_bin_', '') for col in final_behavioral_data_columns]
only_in_data_item = set(final_behavioral_data_columns) - set(pn.behav_data_all.columns)
print("Columns only in data_item.final_behavioral_data:")
only_in_data_item = np.array(sorted(only_in_data_item))
print(only_in_data_item)

In [None]:
# Display the column names of pn.cur_and_nxt_ff_df.
pn.cur_and_nxt_ff_df.columns

In [None]:
# Display the column names of the test planning-data instance's frame.
pn.test_plan_data_inst.df.columns

## just get planning_timestep_data

In [None]:
# Stand-alone run of the behavioral planning preparation through the helper class.
planning_data_by_point_exists_ok = False

# Session to load (hard-coded here, independent of the selection cells at the top).
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330"

# Planning parameters forwarded to prep_behav_data_to_analyze_planning below.
ref_point_mode = 'time after cur ff visible'
ref_point_value = 0.1
# `normalize` is defined for parity with the earlier cell but is not passed on here.
normalize = False
eliminate_outliers = False
use_curv_to_ff_center = False
curv_of_traj_mode = 'distance'
window_for_curv_of_traj = [-25, 25]
truncate_curv_of_traj_by_time_of_capture = True

# Binning parameters forwarded to the helper's constructor.
bin_width = 0.02
window_width = 0.25
one_behav_idx_per_bin = True

# get behavioral_data
ph = planning_neural_helper_class.PlanningAndNeuralHelper(
    raw_data_folder_path=raw_data_folder_path,
    bin_width=bin_width,
    window_width=window_width,
    one_behav_idx_per_bin=one_behav_idx_per_bin,
)

ph.load_raw_data(raw_data_folder_path)
ph.prep_behav_data_to_analyze_planning(
    ref_point_mode=ref_point_mode,
    ref_point_value=ref_point_value,
    curv_of_traj_mode=curv_of_traj_mode,
    window_for_curv_of_traj=window_for_curv_of_traj,
    truncate_curv_of_traj_by_time_of_capture=truncate_curv_of_traj_by_time_of_capture,
    use_curv_to_ff_center=use_curv_to_ff_center,
    eliminate_outliers=eliminate_outliers,
    planning_data_by_point_exists_ok=planning_data_by_point_exists_ok,
)