# Import packages

In [None]:
# NOTE(review): `cost_results` is not defined anywhere in the visible notebook, and this
# cell sits above the import cell — presumably a leftover from another session; confirm
# whether it should be removed.
cost_results.sort_values(by='num_caught_ff', ascending=False)

In [None]:
from pathlib import Path
import os
import sys
import warnings

# Find the repository root by walking up from the current working directory
# (the directory itself first, then each parent).
project_root = next(
    (p for p in [Path.cwd(), *Path.cwd().parents] if p.name == 'Multifirefly-Project'),
    None,
)
if project_root is None:
    # Say so explicitly: otherwise the project imports further down fail with a
    # ModuleNotFoundError that gives no hint about the working directory.
    warnings.warn(
        "Could not find a 'Multifirefly-Project' directory at or above "
        f"{Path.cwd()}; project imports will likely fail."
    )
else:
    # Relative data paths used throughout the notebook are resolved from the root.
    os.chdir(project_root)
    methods_dir = str(project_root / 'multiff_analysis/multiff_code/methods')
    # Re-running the cell must not keep prepending the same entry.
    if methods_dir not in sys.path:
        sys.path.insert(0, methods_dir)

%load_ext autoreload
%autoreload 2

from data_wrangling import specific_utils, combine_info_utils, general_utils, further_processing_class
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features, category_class
from decision_making_analysis.ff_data_acquisition import cluster_replacement_utils
from decision_making_analysis.data_compilation import miss_events_class
from decision_making_analysis.ff_data_acquisition import ff_data_utils
from decision_making_analysis.data_compilation import miss_events_across_sessions
from decision_making_analysis.data_enrichment import miss_events_enricher

from decision_making_analysis.data_compilation import miss_events_class
from decision_making_analysis.data_compilation import miss_events_across_sessions
from visualization.matplotlib_tools import plot_trials, plot_behaviors_utils
from visualization.animation import animation_class
from null_behaviors import show_null_trajectory, find_best_arc, curvature_utils, curv_of_traj_utils
from machine_learning.ml_methods import regression_utils, classification_utils, prep_ml_data_utils, hyperparam_tuning_class
from visualization.plotly_polar_tools import plotly_for_ff_polar, plotly_for_trajectory_polar
from machine_learning.ml_methods import ml_methods_class
from machine_learning.ml_methods.advanced_ml_methods import advanced_regression_utils, advanced_classification_utils, reg_feat_importance

import os, sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from os.path import exists
import math
import copy
import matplotlib.pyplot as plt
import pandas as pd
import itertools
import matplotlib.pyplot as plt
import gc
from scipy import stats
from IPython.display import HTML
from matplotlib import rc
from sklearn.svm import SVC
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import warnings
import os, sys, sys
from importlib import reload
from sklearn.exceptions import ConvergenceWarning
import seaborn as sns

# Reset matplotlib to its defaults FIRST. `rcParams.update(rcParamsDefault)` rewrites
# every key, so any animation setting made before it is silently discarded.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Embed animations as JS/HTML. (The earlier "html5" assignment was overridden by this
# call even before the reset, so only the effective setting is kept.)
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128
# NOTE(review): presumably works around the duplicate-OpenMP-runtime abort; confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Display options: fixed 5-decimal floats, no scientific notation, at most 50 rows.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
pd.options.display.max_rows = 50

# Run overnight

In [None]:
# Predict rcap vs rsw: build the combined rsw / rcap x-dataframes for each monkey.
# NOTE(review): `cgts` and both result frames are rebound on every iteration, so only
# the last monkey's objects remain in the namespace afterwards — presumably the method
# persists its output elsewhere; confirm.
for monkey_name in ['monkey_Bruno', 'monkey_Schro']:
    cgts = miss_events_across_sessions.MissEventsAcrossSessions()
    combd_rsw_x_df, combd_rcap_x_df = cgts.streamline_getting_combd_rsw_or_rcap_x_df(monkey_name=monkey_name)

# data across sessions

In [None]:
# Get the combined decision-making basic ff info across sessions for one monkey and
# store a cleaned copy (presumably with rows containing any NA dropped — inferred from
# the helper's name; confirm).
monkey_name = 'monkey_Schro'
cgts = miss_events_across_sessions.MissEventsAcrossSessions()
cgts.streamline_getting_combd_decision_making_basic_ff_info(monkey_name=monkey_name, exists_ok=True)
decision_making_basic_ff_info_cleaned = general_utils.drop_rows_with_any_na(cgts.combd_decision_making_basic_ff_info)

# data for one session

In [None]:
# Single-session alternative to the across-session cell: running this REPLACES
# `decision_making_basic_ff_info_cleaned` with the data of one Bruno session.
# `exists_ok` is not used here; it is read by the later
# `streamline_getting_rsw_or_rcap_x_df` calls.
exists_ok = True
cgt = miss_events_class.MissEventsClass(ref_point_mode='time', 
                                            raw_data_folder_path='all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330',
                                            ref_point_value=-1.5,)
cgt.make_decision_making_basic_ff_info()
# Copy so later edits do not touch the frame held by `cgt`.
decision_making_basic_ff_info_cleaned = cgt.decision_making_basic_ff_info_cleaned.copy()

## also cur & alt ff info

In [None]:
# Collect miss-event info for one session, then build the rsw-vs-rcap ML inputs.
raw_data_folder_path = 'all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330'
gc_kwargs = miss_events_enricher.gc_kwargs.copy()  # copy: keep the module-level kwargs untouched

gcc = miss_events_class.MissEventsClass(
    raw_data_folder_path=raw_data_folder_path,
    gc_kwargs=gc_kwargs,
    new_point_index_start=0,
)
_ = gcc.streamline_process_to_collect_info_from_one_session(miss_events_info_exists_ok=True)

# shared part with 'data from all sessions'
gcc.process_current_and_alternative_ff_info()
more_ff_attributes = ['ff_distance', 'ff_angle', 'curv_diff']
# "last seen" variants of the attributes above, plus two monkey-position features.
ff_last_seen_attributes = [f'last_seen_{name}' for name in more_ff_attributes]
ff_last_seen_attributes.extend([
    'distance_from_monkey_now_to_monkey_when_ff_last_seen',
    'angle_from_monkey_now_to_monkey_when_ff_last_seen',
])

gcc.prepare_data_to_predict_rsw_vs_rcap(
    add_arc_info=True,
    add_current_curv_of_traj=True,
    use_alt_ff_only=True,
    ff_attributes=[
        'ff_distance', 'ff_angle', 'time_since_last_vis', 'duration_of_last_vis_period',
        *ff_last_seen_attributes,
    ],
)

gcc.prepare_data_for_machine_learning(furnish_with_trajectory_data=False)




In [None]:
# Train/evaluate the advanced classifiers on the single-session `gcc` data, using only
# the columns whose name contains 'last' or 'mask'.
resume = False
tune = False

cols_to_use = [col for col in gcc.X_all_df.columns if ('last' in col) or ('mask' in col)]
ml_inst = ml_methods_class.MlMethods(x_var_df= gcc.X_all_df[cols_to_use],
                                     y_var_df=gcc.y_var_df)

ml_inst.use_train_test_split(ml_inst.x_var_df, ml_inst.y_var_df)
# NOTE(review): the same `pred_num_stops/cls_runs` checkpoint folder pattern is used by the
# later "Advanced classification" cell with a different feature set — if `resume` is ever
# turned on, confirm that stale checkpoints from the other experiment cannot be picked up.
model, y_pred, model_comparison_df = advanced_classification_utils.use_advanced_model_for_classification(
    ml_inst.X_train, ml_inst.y_train, ml_inst.X_test, ml_inst.y_test,
    kfold_cv=5,
    tune=tune,                # turn tuning on/off
    n_iter=30,                # ~3k samples sweet spot
    tune_scoring="balanced_accuracy",
    checkpoint_dir=f"all_monkey_data/decision_making/{gcc.monkey_name}/pred_num_stops/cls_runs",   # folder to save progress
    resume=resume,              # skip finished models on rerun
    n_jobs=-1,
    verbose=True
)

In [None]:
gcc.X_all_df.columns

In [None]:
sns.histplot(gcc.miss_event_cur_ff['ff_distance'])

In [None]:
sns.histplot(gcc.X_all_df['ff_distance_0'])

# exp

In [None]:
gcc.miss_event_cur_ff.columns

In [None]:
gcc.miss_event_cur_ff

In [None]:
gcc.miss_event_cur_ff.groupby('point_index').size()

# want: next capture

But perhaps also just the info of the ff available up to the moment of the stop? Especially since the next ff might not be visible yet.

I guess we can separate the trials where the next ff was available at that point from the trials where it wasn't, and compare the two groups.

In [None]:
# # to get 'nxt_captured_ff' info

# cgt._get_rcap_df()
# cgt._get_rsw_df()

# cgt.new_rcap_df = cgt.rcap_df[cgt.rcap_df['ff_index'] < len(cgt.ff_caught_T_new) - 1].reset_index(drop=True)
# cgt.new_rcap_df['nxt_captured_ff'] = cgt.new_rcap_df['ff_index'] + 1

# cgt.new_rsw_df = cgt.rsw_df[cgt.rsw_df['ff_index'] < len(cgt.ff_caught_T_new) - 1].reset_index(drop=True)
# cgt.new_rsw_df['nxt_captured_ff'] = np.searchsorted(cgt.ff_caught_T_new, cgt.new_rsw_df['stop_time'].values)
# cgt.new_rsw_df['nxt_capture_time'] = cgt.ff_caught_T_new[cgt.new_rsw_df['nxt_captured_ff'].values]
# cgt.new_rsw_df['prev_capture_time'] = cgt.ff_caught_T_new[cgt.new_rsw_df['nxt_captured_ff'].values - 1]

# get x_df

In [None]:
# Build the rcap and rsw x-dataframes on the single-session `cgt` object; they are read
# back later as `cgt.rcap_x_df` / `cgt.rsw_x_df`. Requires `cgt` and `exists_ok` from the
# 'data for one session' cell.
cgt.streamline_getting_rsw_or_rcap_x_df(rsw_or_rcap='rcap', exists_ok=exists_ok)
cgt.streamline_getting_rsw_or_rcap_x_df(rsw_or_rcap='rsw', exists_ok=exists_ok)

In [None]:
pd.set_option('display.max_rows', 200)
pd.set_option('display.max_columns', 200)

In [None]:
# list(cgt.rsw_x_df.columns)

In [None]:
cgt.rcap_x_df.head(3)

In [None]:
cgt.rcap_events_df.head(3)

In [None]:
cgt.rcap_df.head(3)

# SELECT features

In [None]:
decision_making_basic_ff_info_cleaned.columns

In [None]:
# Predictor columns used by the regression / classification cells below: all describe
# the ff at the time it was last seen.
attributes = ['ff_distance_ff_last_seen',
            'ff_angle_ff_last_seen',
            'ff_angle_boundary_ff_last_seen',
            'time_since_ff_last_seen']

In [None]:
#attributes = ['ff_distance', 'ff_angle','ff_angle_boundary', 'time_since_last_vis']
'''
ff_distance cannot be used, because the reference point is the point on the trajectory closest to the stop.
Using the distance to the ff would largely reveal whether the event is an rcap (though not with 100% certainty).
'''

## regress on dwell time

In [None]:
# Regress stop duration on the selected ff attributes plus the switch indicator.
x_var_df = decision_making_basic_ff_info_cleaned[attributes + ['whether_switched']]
# Double brackets keep the target as a one-column DataFrame.
y_var_df = decision_making_basic_ff_info_cleaned[['stop_id_duration']].copy()


ml_inst = ml_methods_class.MlMethods(x_var_df=x_var_df,
                                     y_var_df=y_var_df)

In [None]:
## Or: restrict to rows with whether_switched == 1 and drop the indicator from the
## predictors. Running this cell replaces `x_var_df`, `y_var_df` and `ml_inst` from the
## previous cell.
data_sub = decision_making_basic_ff_info_cleaned[decision_making_basic_ff_info_cleaned['whether_switched'] == 1].copy()
x_var_df = data_sub[attributes]
y_var_df = data_sub[['stop_id_duration']].copy()

ml_inst = ml_methods_class.MlMethods(x_var_df=x_var_df,
                                     y_var_df=y_var_df)

In [None]:
# Fit the listed regression models on whichever inputs `ml_inst` currently holds; the
# comparison table is read from `ml_inst.model_comparison_df` in the next cell.
ml_inst.use_ml_model_for_regression(ml_inst.x_var_df, ml_inst.y_var_df, model_names=['linreg', 'svr', 'dt', 'bagging', 'boosting', 'grad_boosting', 'rf'])

In [None]:
ml_inst.model_comparison_df

In [None]:
# NOTE(review): `stop` is not defined in the visible notebook, so evaluating it raises
# NameError — presumably a deliberate sentinel to halt "Run All" here; confirm.
stop

In [None]:
decision_making_basic_ff_info_cleaned

## classify type

In [None]:
# Classification setup: predict `whether_switched` from the selected ff attributes.
# Replaces the regression `ml_inst` built above.
x_var_df = decision_making_basic_ff_info_cleaned[attributes]
y_var_df = decision_making_basic_ff_info_cleaned[['whether_switched']].copy()

ml_inst = ml_methods_class.MlMethods(x_var_df=x_var_df,
                                     y_var_df=y_var_df)

In [None]:
# Run the project's classification routine on the inputs held by `ml_inst`.
ml_inst.use_ml_model_for_classification(ml_inst.x_var_df, ml_inst.y_var_df)

In [None]:
import statsmodels.api as sm

# Logistic regression via statsmodels, to get a coefficient / p-value summary table.
X2 = sm.add_constant(ml_inst.x_var_df)  # add intercept
logit_model = sm.Logit(ml_inst.y_var_df, X2)  # logistic regression
result = logit_model.fit()

print(result.summary())

# model's feature selection

In [None]:
import numpy as np
from sklearn.linear_model import LogisticRegression

# L1-penalised logistic regression as a feature selector: features whose coefficient
# is exactly zero after fitting are dropped.
model = LogisticRegression(penalty='l1', solver='liblinear')
# scikit-learn expects a 1-D target. `ml_inst.y_var_df` is built from a one-column
# DataFrame, so flatten it instead of passing an (n, 1) column vector.
model.fit(ml_inst.x_var_df, np.ravel(ml_inst.y_var_df))
important = model.coef_[0] != 0  # boolean mask over the feature columns
X_new = ml_inst.x_var_df.loc[:, important]
X_new.columns

In [None]:
## Hmmm I don't think this is needed at the moment

# from sklearn.feature_selection import SequentialFeatureSelector
# from sklearn.ensemble import RandomForestClassifier

# model = RandomForestClassifier()
# sfs = SequentialFeatureSelector(model, n_features_to_select=min(10, len(ml_inst.x_var_df.columns)-1), direction="forward")
# X_new = sfs.fit_transform(ml_inst.x_var_df, ml_inst.y_var_df.values.ravel())

In [None]:

# # Boolean mask of selected features
# mask = sfs.get_support()

# # Names of selected features
# selected_features = ml_inst.x_var_df.columns[mask]

# print("Selected features:")
# print(selected_features)

## statsmodels, logreg

In [None]:
import statsmodels.api as sm

# Same statsmodels logistic regression as in the 'classify type' section, re-run on
# whatever `ml_inst` currently holds.
X2 = sm.add_constant(ml_inst.x_var_df)  # add intercept
logit_model = sm.Logit(ml_inst.y_var_df, X2)  # logistic regression
result = logit_model.fit()

print(result.summary())

## random forest

In [None]:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# Fit a seeded random forest on all rows of `ml_inst` (no train/test split here) and
# rank the features by impurity-based importance.
rf = RandomForestClassifier(n_estimators=100, random_state=42)
# Flatten the one-column target: scikit-learn expects y with shape (n_samples,).
rf.fit(ml_inst.x_var_df, np.ravel(ml_inst.y_var_df))

importances = rf.feature_importances_
feature_importance = pd.DataFrame({
    "feature": ml_inst.x_var_df.columns,
    "importance": importances
}).sort_values(by="importance", ascending=False)

print(feature_importance)

## grad_boosting (so that we can see feature importance)

In [None]:
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import seaborn as sns


# Define the model. With subsample=0.5 each boosting stage draws a random half of the
# rows, so a fixed seed is needed for the importance ranking to be reproducible (the
# random-forest cell is seeded the same way).
model = GradientBoostingClassifier(
    learning_rate=0.05, max_depth=7, max_features='sqrt',
    min_samples_leaf=2, min_samples_split=7,
    n_estimators=500, subsample=0.5,
    random_state=42,
)


# Fit on all rows of `ml_inst`. The one-column target is flattened because
# scikit-learn expects y with shape (n_samples,).
model.fit(ml_inst.x_var_df, np.ravel(ml_inst.y_var_df))

# Feature importances as a table, sorted from most to least important.
feature_importances = model.feature_importances_
feature_importances_df = pd.DataFrame({
    'Feature': ml_inst.x_var_df.columns,
    'Importance': feature_importances
}).sort_values(by='Importance', ascending=False)

# Plot feature importances
plt.figure(figsize=(10, 18))
sns.barplot(x='Importance', y='Feature', data=feature_importances_df)
plt.title('Feature Importances')
plt.show()

# "Significant" here is only a threshold on importance (> 0.01), not a statistical
# test. The table stays sorted by importance and is reused by a later plotting cell.
significant_features = feature_importances_df[feature_importances_df['Importance'] > 0.01]
print("Significant features:")
print(significant_features)

# Advanced classification

In [None]:
# Advanced classification on the inputs currently held by `ml_inst`.
tune = False

ml_inst.use_train_test_split(ml_inst.x_var_df, ml_inst.y_var_df)
# NOTE(review): the checkpoint path is keyed on `cgts.monkey_name` (the across-session
# object). If `ml_inst` was built from the single-session `cgt` data instead, the monkey
# in the path may not match the data — confirm. `cgts` must also exist, i.e. one of the
# across-session cells has to have been run.
model, y_pred, model_comparison_df = advanced_classification_utils.use_advanced_model_for_classification(
    ml_inst.X_train, ml_inst.y_train, ml_inst.X_test, ml_inst.y_test,
    kfold_cv=5,
    tune=tune,                # turn tuning on/off
    n_iter=30,                # ~3k samples sweet spot
    tune_scoring="balanced_accuracy",
    checkpoint_dir=f"all_monkey_data/decision_making/{cgts.monkey_name}/pred_num_stops/cls_runs",   # folder to save progress
    resume=False,              # skip finished models on rerun
    n_jobs=-1,
    verbose=True
)

# In the future

In [None]:
# btw, what's the reference point ??? (like at which point are we predicting rcap vs rsw?)

In [None]:
 # what might be interesting to add from rcap_x_df:
 
 'angle_from_stop_to_nxt_ff',
 'angle_from_cur_ff_to_nxt_ff',
 
 
 # monkey's own curvature info?
 # so like find curv of traj based on [-25, 0] or something,
 
 
 # rather than ff last seen, what about at ref?
 # (btw, I'm not gonna use exactly the below...will tweak it)
 'cur_ff_angle_diff_boundary_at_ref',
 'cur_ff_flash_duration_at_ref',
 'cur_ff_earliest_flash_rel_time_at_ref',
 'cur_ff_latest_flash_rel_time_at_ref',
 
 
 # and also all the eye-related features

# Compare distributions of features

## basic features

In [None]:
# Option 1 for the distribution comparison: the "ff last seen" attributes.
attributes = ['ff_distance_ff_last_seen', 
            'time_since_ff_last_seen',
            'ff_angle_ff_last_seen', 
            'ff_angle_boundary_ff_last_seen',
            ]

In [None]:
# Option 2: these attributes instead. Running this cell replaces the `attributes`
# list defined in the previous cell.
attributes = ['ff_distance', 'ff_angle', 'ff_angle_boundary', 'time_since_last_vis']

In [None]:
# For each feature in `attributes`, plot its histogram separately for each value of
# `whether_switched`. common_norm=False normalises each class on its own, so the two
# distributions are comparable even when the classes differ in size.
for feature in attributes:
    sns.histplot(x=feature, data=decision_making_basic_ff_info_cleaned, stat='probability', kde=False, hue='whether_switched', common_norm=False)
    plt.title(f'{feature} histogram')
    plt.show()

## complex features

In [None]:
# Stack the rcap and rsw x-dataframes into one frame, labelled by event type
# (whether_rcap: 1 = rcap, 0 = rsw). Copies keep the frames on `cgt` unmodified.
rcap = cgt.rcap_x_df.copy()
rcap['whether_rcap'] = 1
rsw = cgt.rsw_x_df.copy()
rsw['whether_rcap'] = 0
# ignore_index gives the stacked frame a fresh, unique index. Without it the row labels
# of both inputs are repeated, which can break label-aligned operations on the result,
# including the seaborn plots below.
both_df = pd.concat([rcap, rsw], axis=0, ignore_index=True)

In [None]:
# For the two reference-point features below, plot the histogram separately for rcap
# and rsw events (each class normalised on its own via common_norm=False).
for feature in ['cur_ff_distance_at_ref', 'cur_ff_angle_at_ref']:
    sns.histplot(x=feature, data=both_df, stat='probability', kde=False, hue='whether_rcap', common_norm=False)
    plt.title(f'{feature} histogram')
    plt.show()

In [None]:
# Plot class-conditional histograms for the top-ranked significant features only.
# `significant_features` is sorted by importance (descending), so `.head(n)` takes the
# n most important ones, and unlike the previous counter-and-break loop it plots
# nothing when the limit is 0.
# NOTE(review): `significant_features` comes from a model fitted on `ml_inst.x_var_df`,
# while the data plotted here is `both_df` — confirm the feature names exist in both.
max_features_to_plot = 3
for feature in significant_features['Feature'].head(max_features_to_plot):
    sns.histplot(x=feature, data=both_df, stat='probability', kde=False, hue='whether_rcap', common_norm=False)
    plt.title(f'{feature} histogram')
    plt.show()

# check vif

In [None]:
# # can skip this if only wanting ML results
# pd.set_option('display.max_rows', 100)
# ml_inst.use_vif(ml_inst.x_var_df)
# features_w_big_vif = ml_inst.vif_df[ml_inst.vif_df['vif'] > 100].feature.values
# #ml_inst.x_var_df = ml_inst.x_var_df.drop(columns=features_w_big_vif)
# ml_inst.vif_df.head(20)

In [None]:
# specific_columns = ml_inst.vif_df[ml_inst.vif_df["VIF"] > 2000].feature.values
# ml_inst.show_correlation_heatmap(specific_columns=specific_columns)
# ml_inst.show_correlation_heatmap()

# cProfile

In [None]:
# # test and see what is taking so long in running a function
# import cProfile

# cProfile.run("cgt.streamline_getting_rsw_or_rcap_x_df(rsw_or_rcap='rsw', exists_ok=False)", sort='cumtime')

# #ncalls  tottime  percall  cumtime  percall