# Import packages

In [None]:
current_path

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import os
# Move the working directory to the project root (a folder named
# 'Multifirefly-Project') before the project-local imports that follow.
if Path.cwd().name != 'Multifirefly-Project':
    if Path.cwd().name != 'notebooks':
        # Not in 'notebooks' either: step up one level before importing the
        # `add_path` helper (presumably it lives in 'notebooks' — confirm).
        os.chdir('..')
    from add_path import find_path
    current_path = find_path()
    os.chdir(current_path)
else:
    # Already at the project root. Bind `current_path` here as well so that
    # cells which display it do not fail with NameError on this branch.
    current_path = os.getcwd()

from data_wrangling import specific_utils, process_monkey_information
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.neural_analysis_by_topic.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.planning_and_neural import planning_and_neural_class, planning_and_neural_utils
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting

import sys
import math
import gc
import subprocess
from pathlib import Path
from importlib import reload

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# KMP_DUPLICATE_LIB_OK is an OpenMP runtime switch; it is set here exactly as
# before (presumably to avoid the "duplicate OpenMP library" abort — confirm).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Reset matplotlib to its defaults FIRST. `rcParams.update(rcParamsDefault)`
# overwrites every rc key, so any animation setting made before it is
# silently discarded; the animation options are therefore applied afterwards.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# 'jshtml' is the value that used to win over the earlier 'html5' assignment,
# so only that one is kept.
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128

# Show floats with 5 decimals in pandas and suppress scientific notation in
# numpy printouts.
pd.set_option('display.float_format', '{:.5f}'.format)
np.set_printoptions(suppress=True)
print("done")

%load_ext autoreload
%autoreload 2

# Retrieve data

In [2]:
# Relative path to one recording session (monkey_Bruno, data_0330).
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330"

In [5]:
# Relative path to another recording session (monkey_Schro, data_0416).
# NOTE: this rebinds `raw_data_folder_path`; if the monkey_Bruno cell was run
# before this one, its path is replaced by this value.
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0416"

In [None]:
# Analysis settings.
# NOTE(review): none of the names below is passed to any call in this cell —
# confirm whether they should be forwarded to `PlanningAndNeural` or are
# left over from an earlier version.
ref_point_mode = 'time after cur ff visible'
ref_point_value = 0.1
normalize = False
eliminate_outliers = False
use_curvature_to_ff_center = False
curv_of_traj_mode = 'distance'
window_for_curv_of_traj = [-25, 25]
truncate_curv_of_traj_by_time_of_capture = True

# Build the analysis object for the selected session, then run its two
# preparation steps in order.
pn = planning_and_neural_class.PlanningAndNeural(
    raw_data_folder_path=raw_data_folder_path,
)
pn.prep_data_to_analyze_planning()
pn.get_x_and_y_data_for_modeling(exists_ok=True)

### check final VIF

In [None]:
pn.y_var_reduced

In [None]:
# Build the table returned by `get_vif_df` for `pn.y_var_reduced`
# (presumably variance-inflation factors per column — confirm) and display it.
vif_df = drop_high_vif_vars.get_vif_df(pn.y_var_reduced)
vif_df

In [None]:
# Same check on `pn.y_var_lags_reduced`. Note that `vif_df` is rebound here,
# so it always holds the table from whichever VIF cell ran last.
vif_df = drop_high_vif_vars.get_vif_df(pn.y_var_lags_reduced)
vif_df

# Linear regression (LR)

In [None]:
# Linear-regression result table using the lagged x variables.
# The result is stored on `pn` and displayed as the cell output.
pn.y_var_lr_result_df = neural_data_modeling.get_y_var_lr_result_df(
    pn.x_var_lags_reduced,
    pn.y_var_reduced,
)
pn.y_var_lr_result_df

In [None]:
# Linear-regression result table using the x variables without lags.
# This writes to the same attribute, `pn.y_var_lr_result_df`, so it replaces
# any result already stored there.
pn.y_var_lr_result_df = neural_data_modeling.get_y_var_lr_result_df(
    pn.x_var_reduced,
    pn.y_var_reduced,
)
pn.y_var_lr_result_df

In [None]:
# Plot linear regression on X and y for the first few features listed in
# `pn.y_var_lr_result_df`.
max_plot_number = 3

# Slicing caps the number of features visited, replacing the manual
# enumerate/break. The cap counts features visited, not figures drawn:
# `min_r_squared_to_plot` is forwarded to `plot_regression` unchanged.
for column in pn.y_var_lr_result_df.feature.values[:max_plot_number]:
    plot_neural_data.plot_regression(
        pn.y_var,
        column,
        pn.x_var,
        bins_to_plot=None,
        min_r_squared_to_plot=0.3,
    )

# CCA

Background reading on CCA: https://medium.com/@pozdrawiamzuzanna/canonical-correlation-analysis-simple-explanation-and-python-example-a5b8e97648d2

## No lagging

In [41]:
cca_no_lag = cca_class.CCAclass(X1=pn.x_var_reduced, X2=pn.y_var_reduced, lagging_included=False)

In [None]:
# Run CCA on the no-lag instance, then point `cca_inst` at it. The plotting
# cells read `cca_inst`, so they show whichever instance was assigned last.
cca_no_lag.conduct_cca()
cca_inst = cca_no_lag

## With lags

In [43]:
# This instance is built from the lagged matrices (`*_lags_reduced`), so it is
# flagged as such; previously it passed lagging_included=False, identical to
# the no-lag instance.
# NOTE(review): assumes `lagging_included` tells CCAclass that the inputs
# contain lagged columns — confirm against CCAclass.
cca_lags = cca_class.CCAclass(
    X1=pn.x_var_lags_reduced,
    X2=pn.y_var_lags_reduced,
    lagging_included=True,
)

In [None]:
# Run CCA on the lagged instance, then point `cca_inst` at it. The plotting
# cells read `cca_inst`, so they show whichever instance was assigned last.
cca_lags.conduct_cca()
cca_inst = cca_lags

## loadings

### neurons

In [None]:
cca_inst.plot_ranked_loadings(X1_or_X2='X1', squared=False)

### behavior

In [None]:
cca_inst.plot_ranked_loadings(X1_or_X2='X2', squared=False)

## squared loadings

### neurons

In [None]:
cca_inst.plot_ranked_loadings(X1_or_X2='X1')

### behavior

In [None]:
cca_inst.plot_ranked_loadings(X1_or_X2='X2')

## abs weights ranked

### neurons

In [None]:
# Select the X1 set explicitly, as the other "neurons" cells do, instead of
# relying on the method's default for `X1_or_X2`.
cca_inst.plot_ranked_weights(X1_or_X2='X1')

### behavior

In [None]:
cca_inst.plot_ranked_weights(X1_or_X2='X2')

## plot real weights

### neurons

In [None]:
# Select the X1 set explicitly, as the other "neurons" cells do, instead of
# relying on the method's default for `X1_or_X2`.
cca_inst.plot_ranked_weights(X1_or_X2='X1', abs_value=False)

### behavior

In [None]:
cca_inst.plot_ranked_weights(X1_or_X2='X2', abs_value=False)

In [None]:
# Deliberate halt (presumably so "Run All" stops before the cells below).
# Raising explicitly keeps the file parseable, unlike the bare text that
# relied on a SyntaxError to stop execution.
raise RuntimeError("stop here!")

## distribution of each feature

In [None]:
cca_inst.X2_sc.shape

In [None]:
# Wrap `cca_inst.X2_sc` in a DataFrame that reuses the column labels of
# `cca_inst.X2`, then show summary statistics per column.
X2_sc_df = pd.DataFrame(data=cca_inst.X2_sc, columns=cca_inst.X2.columns)
X2_sc_df.describe()

In [None]:
# One horizontal boxplot per column of `X2_sc_df`, each in its own 8x2 figure.
for column in X2_sc_df.columns:
    fig, ax = plt.subplots(figsize=(8, 2))
    sns.boxplot(X2_sc_df[column], orient='h', ax=ax)
    plt.show()
    

## heatmap of weights
Raw canonical coefficients are interpreted analogously to regression coefficients. For example (using numbers from a different dataset): a one-unit increase in `reading` leads to a 0.0446 decrease in the first canonical variate of set 2 when all other variables are held constant.

In [40]:
# Work on a copy of the X2 weight table: index it by 'feature' and drop the
# 'feature_category' column.
weight_df = (
    cca_inst.X2_weight_df
    .copy()
    .set_index('feature')
    .drop(columns='feature_category')
)

In [None]:
# Annotated heatmap of the first 20 rows and first 10 columns of `weight_df`.
fig, ax = plt.subplots(figsize=(15, 25))
heatmap_values = weight_df.iloc[:20, :10]
sns.heatmap(heatmap_values, cmap='coolwarm', annot=True, linewidths=1, ax=ax)
plt.show()

## train test split

In [None]:
# Number of canonical components requested from rcca.
nComponents = 10

# Hold out 30% of the rows of both scaled matrices (same split for both,
# fixed seed), train on the rest, and validate on the held-out part.
train1, test1, train2, test2 = train_test_split(
    cca_inst.X1_sc,
    cca_inst.X2_sc,
    test_size=0.3,
    random_state=42,
)
cca2 = rcca.CCA(kernelcca=False, reg=0., numCC=nComponents)
cca2.train([train1, train2])
testcorrs = cca2.validate([test1, test2])
testcorrs

## compute explained variance

In [None]:
cca2.compute_ev([test1, test2])

## test for p values

In [None]:
# Run statsmodels' CanCorr on the two scaled matrices, print the summary of
# its `corr_test()`, then print both sets of canonical coefficients.
# NOTE(review): CanCorr's positional arguments are (endog, exog). Per the
# statsmodels documentation, `y_cancoef` holds the endog coefficients and
# `x_cancoef` the exog coefficients. With this argument order, the weights
# labelled 'X' come from cca_inst.X2_sc and those labelled 'Z' from
# cca_inst.X1_sc — confirm the labels are as intended.
stats_cca = CanCorr(cca_inst.X1_sc, cca_inst.X2_sc)
print(stats_cca.corr_test().summary())
neural_data_modeling.print_weights('X', stats_cca.x_cancoef)
neural_data_modeling.print_weights('Z', stats_cca.y_cancoef)