# Import packages

In [None]:
# !pip install -r multiff_analysis/requirements.txt

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Root of the project checkout. Set the MULTIFF_PROJECT_FOLDER environment
# variable to override it on another machine; the original path is the fallback.
project_folder = os.environ.get('MULTIFF_PROJECT_FOLDER', '/Users/dusiyi/Documents/Multifirefly-Project')
# Later cells pass relative paths (e.g. raw_data_folder_path), so the working
# directory is moved to the project root here.
os.chdir(project_folder)
# Make the analysis packages importable. Skip the append when the entry is
# already present so re-running this cell does not pile up duplicates in sys.path.
methods_folder = os.path.join(project_folder, 'multiff_analysis', 'methods')
if methods_folder not in sys.path:
    sys.path.append(methods_folder)

from data_wrangling import general_utils, specific_utils, process_monkey_information
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from non_behavioral_analysis.neural_data_analysis.get_neural_data import neural_data_processing
from non_behavioral_analysis.neural_data_analysis.visualize_neural_data import plot_neural_data, plot_modeling_result
from non_behavioral_analysis.neural_data_analysis.model_neural_data import cca_class, cca_utils, cca_utils2, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from non_behavioral_analysis.neural_data_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from non_behavioral_analysis.neural_data_analysis.planning_neural import planning_neural_class, planning_neural_utils
from non_behavioral_analysis.neural_data_analysis.decode_targets import behav_features_to_keep, decode_target_class, plot_gpfa_utils, decode_target_utils, fit_gpfa_utils, gpfa_regression_utils

import sys
import math
import gc
import subprocess
from pathlib import Path

# Third-party imports
import numpy as np
from numpy import var
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
import numpy as np
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util




# Notebook-wide display configuration for matplotlib, pandas and numpy.
# NOTE(review): statement order matters in this cell. 'animation.html' is set
# twice ("html5", then 'jshtml' via rc), and the rcParams.update(rcParamsDefault)
# call that follows restores matplotlib's defaults, which presumably undoes
# both — confirm which animation backend is actually wanted.
plt.rcParams["animation.html"] = "html5"
# Presumably a workaround for the "duplicate OpenMP runtime" abort — TODO confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
rc('animation', html='jshtml')
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Set after the reset above, so this value is the one that stays in effect.
matplotlib.rcParams['animation.embed_limit'] = 2**128
# Show pandas floats with five decimals.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
# Print small numpy floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)
print("done")


# Cap how many rows / columns pandas renders when displaying a frame.
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)


%load_ext autoreload
%autoreload 2

# Get data

In [3]:
# Session to analyse; swap which of the two lines is active to switch sessions.
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0416"
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0328"

# Decoding object for that session.
# NOTE(review): bin_width / window_width are presumably in seconds — confirm in DecodeTargetClass.
dec = decode_target_class.DecodeTargetClass(
    raw_data_folder_path=raw_data_folder_path,
    bin_width=0.02,
    window_width=0.05,
)

In [None]:
# Build the x / y variables on `dec`, then reduce the lagged y variables.
# Presumably these two calls populate the x_var, y_var_reduced, x_var_lags and
# y_var_lags_reduced attributes that the CCA cells read — verify in DecodeTargetClass.
dec.get_x_and_y_var()
dec.reduce_y_var_lags()

# CCA

Reference (CCA explained with a Python example): https://medium.com/@pozdrawiamzuzanna/canonical-correlation-analysis-simple-explanation-and-python-example-a5b8e97648d2

## No lagging

In [5]:
# CCA without lagged copies: X1 is dec.x_var, X2 is dec.y_var_reduced.
cca_no_lag = cca_class.CCAclass(
    X1=dec.x_var,
    X2=dec.y_var_reduced,
    lagging_included=False,
)

In [None]:
cca_no_lag.conduct_cca()

## with lags

In [None]:
# CCA on the lagged variables; the 'bin' column is dropped from the X1 side
# before it is handed to the CCA.
cca_lags = cca_class.CCAclass(
    X1=dec.x_var_lags.drop(columns='bin'),
    X2=dec.y_var_lags_reduced,
    lagging_included=True,
)
# The first shape is taken from dec.x_var_lags itself, so it still counts the 'bin' column.
print(f'dec.x_var_lags.shape: {dec.x_var_lags.shape}')
print(f'dec.y_var_lags_reduced.shape: {dec.y_var_lags_reduced.shape}')

In [None]:
cca_lags.conduct_cca()

## compare lag vs no lag

In [9]:
# One row per canonical component, one column per lag setting.
canon_df = pd.DataFrame(cca_no_lag.canon_corr, columns=['no_lag'])
# Assigning a sequence of a different length raises, so both CCA fits must
# have produced the same number of components.
canon_df['with_lags'] = cca_lags.canon_corr
# Label the rows from the frame's own length so the labels always match the
# rows that actually exist.
canon_df['component'] = [f'CC {i + 1}' for i in range(len(canon_df))]
# Long format (component, lag, canon_coeff) for seaborn's hue-grouped bar plot.
canon_df_long = pd.melt(canon_df, id_vars=['component'], var_name='lag', value_name='canon_coeff')

In [None]:
# Grouped bars from canon_df_long: one group per canonical component,
# one bar per value of the 'lag' column.
plt.figure(figsize=(8, 6))
sns.barplot(
    data=canon_df_long,
    x='component',
    y='canon_coeff',
    hue='lag',
)
plt.show()

## cca_inst (choose one between lags and no lag)

In [11]:
# Option A: inspect the no-lag fit in the cells below.
# Run only one of the two chooser cells; if both run in order, the later
# assignment (cca_lags) is the one that sticks.
cca_inst = cca_no_lag

In [12]:
# Option B: inspect the lagged fit in the cells below.
# This rebinds cca_inst, so it overrides the no-lag choice if that cell ran first.
cca_inst = cca_lags

## loadings

### neurons

In [None]:
cca_inst.plot_ranked_loadings(X1_or_X2='X1', squared=False)

### behavior

In [None]:
cca_inst.plot_ranked_loadings(X1_or_X2='X2', squared=False)

## squared loadings

### neurons

In [None]:
cca_inst.plot_ranked_loadings(X1_or_X2='X1')

### behavior

In [None]:
cca_inst.plot_ranked_loadings(X1_or_X2='X2')

## abs weights ranked

### neurons

In [None]:
cca_inst.plot_ranked_weights()

### behavior

In [None]:
cca_inst.plot_ranked_weights(X1_or_X2='X2')

## plot real weights

### neurons

In [None]:
cca_inst.plot_ranked_weights(abs_value=False)

### behavior

In [None]:
cca_inst.plot_ranked_weights(X1_or_X2='X2', abs_value=False)

In [None]:
stop here!

## distribution of each feature

In [None]:
cca_inst.X2_sc.shape

In [None]:
# Wrap cca_inst.X2_sc in a frame that reuses the column names of cca_inst.X2,
# then display per-column summary statistics.
X2_sc_df = pd.DataFrame(
    cca_inst.X2_sc,
    columns=cca_inst.X2.columns,
)
X2_sc_df.describe()

In [None]:
# One short, horizontal box plot per column of X2_sc_df, each in its own figure.
for feature_name in X2_sc_df.columns:
    plt.figure(figsize=(8, 2))
    sns.boxplot(X2_sc_df[feature_name], orient='h')
    plt.show()
    

## heatmap of weights
Raw canonical coefficients are interpreted analogously to regression coefficients. For example (taken from an unrelated dataset): a one-unit increase in reading leads to a 0.0446 decrease in the first canonical variate of set 2 when all other variables are held constant.

In [None]:
# Work on a copy of the X2 weight table, indexed by feature name and without
# the non-numeric 'feature_category' column, so it can go straight into a heatmap.
weight_df = (
    cca_inst.X2_weight_df
    .copy()
    .set_index('feature')
    .drop(columns='feature_category')
)

In [None]:
# Annotated heatmap of the first 20 rows and first 10 columns of weight_df.
plt.subplots(figsize=(15, 25))
sns.heatmap(
    weight_df.iloc[:20, :10],
    cmap='coolwarm',
    annot=True,
    linewidths=1,
)
plt.show()

## test for p values

In [None]:
# Significance tests for the canonical correlations via statsmodels' CanCorr.
stats_cca = CanCorr(cca_inst.X1_sc, cca_inst.X2_sc)
print(stats_cca.corr_test().summary())

# NOTE(review): CanCorr's signature is (endog, exog); per the statsmodels docs
# x_cancoef belongs to exog (the second argument, X2_sc here) and y_cancoef to
# endog (X1_sc) — confirm the 'X' / 'Z' labels mean what is intended.
for weight_label, cancoef in (('X', stats_cca.x_cancoef), ('Z', stats_cca.y_cancoef)):
    neural_data_modeling.print_weights(weight_label, cancoef)

## compute explained variance

In [None]:
# NOTE(review): cca2, test1 and test2 are only created by the "train test split"
# cells further down, so this cell raises NameError on a top-to-bottom run;
# run one of those cells first. The result reflects whichever split ran last.
cca2.compute_ev([test1, test2])

In [None]:
stop pls

# train test split

## no lag

In [None]:
# Hold out 30% of the rows; the fixed random_state keeps the split reproducible.
# NOTE(review): train_test_split shuffles rows by default, so neighbouring time
# bins can end up on both sides of the split — confirm that is acceptable here.
train1, test1, train2, test2 = train_test_split(
    cca_no_lag.X1_sc,
    cca_no_lag.X2_sc,
    test_size=0.3,
    random_state=42,
)

# Fit rcca's CCA (kernelcca off, reg = 0) on the training rows only.
nComponents = 10
cca2 = rcca.CCA(kernelcca=False, reg=0., numCC=nComponents)
cca2.train([train1, train2])

# Validate on both portions and keep the results on the CCA object.
cca_no_lag.traincorrs = cca2.validate([train1, train2])
cca_no_lag.testcorrs = cca2.validate([test1, test2])

# Train-vs-test comparison plots.
for plot_train_vs_test in (
    cca_utils.plot_cca_prediction_accuracy_train_test_bars,
    cca_utils.plot_cca_prediction_accuracy_train_test_stacked_bars,
):
    plot_train_vs_test(cca_no_lag.traincorrs, cca_no_lag.testcorrs)

# NOTE(review): the "test_w_bars" helper receives the *training* correlations — confirm intent.
cca_utils.plot_cca_prediction_accuracy_test_w_bars(cca_no_lag.traincorrs)
cca_utils.plot_cca_prediction_accuracy_w_scatter(cca_no_lag.testcorrs)

## w lags

In [None]:
# Hold out 30% of the rows; the fixed random_state keeps the split reproducible.
# NOTE(review): train_test_split shuffles rows by default; with lagged features,
# shuffled neighbouring bins share information across the split — confirm this
# does not inflate the test correlations.
train1, test1, train2, test2 = train_test_split(
    cca_lags.X1_sc,
    cca_lags.X2_sc,
    test_size=0.3,
    random_state=42,
)

# Fit rcca's CCA (kernelcca off, reg = 0) on the training rows only.
nComponents = 10
cca2 = rcca.CCA(kernelcca=False, reg=0., numCC=nComponents)
cca2.train([train1, train2])

# Validate on both portions and keep the results on the CCA object.
cca_lags.traincorrs = cca2.validate([train1, train2])
cca_lags.testcorrs = cca2.validate([test1, test2])

# Train-vs-test comparison plots.
for plot_train_vs_test in (
    cca_utils.plot_cca_prediction_accuracy_train_test_bars,
    cca_utils.plot_cca_prediction_accuracy_train_test_stacked_bars,
):
    plot_train_vs_test(cca_lags.traincorrs, cca_lags.testcorrs)

# NOTE(review): the "test_w_bars" helper receives the *training* correlations — confirm intent.
cca_utils.plot_cca_prediction_accuracy_test_w_bars(cca_lags.traincorrs)
cca_utils.plot_cca_prediction_accuracy_w_scatter(cca_lags.testcorrs)

## compare lags vs no lag

In [23]:
def plot_cca_prediction_accuracy_train_test_bars_for_lags_and_no_lags(lags_testcorrs, lags_traincorrs, no_lags_testcorrs, no_lags_traincorrs):
    """Overlay with-lag and without-lag prediction correlations as bar charts.

    For each of the two data sets (index 0 and 1) two figures are shown, in
    this order: one for the test correlations and one for the train
    correlations. In each figure the with-lags and without-lags bars are
    drawn semi-transparently on the same axes, one bar per canonical
    component.

    Args:
        lags_testcorrs: indexable by 0 and 1; each entry is the sequence of
            per-component test correlations of the model fitted with lags.
        lags_traincorrs: same layout, train correlations of the lagged model.
        no_lags_testcorrs: same layout, test correlations of the no-lag model.
        no_lags_traincorrs: same layout, train correlations of the no-lag model.

    Note that the test correlations come first in each pair of parameters.
    """
    # (split label, with-lags values, without-lags values), in display order.
    panels = (
        ('Test', lags_testcorrs, no_lags_testcorrs),
        ('Train', lags_traincorrs, no_lags_traincorrs),
    )
    for i in range(2):
        for split, with_lags, without_lags in panels:
            plt.figure(figsize=(10, 6))
            plt.bar(range(len(with_lags[i])), with_lags[i], alpha=0.3, label=f'{split} with lags')
            plt.bar(range(len(without_lags[i])), without_lags[i], alpha=0.3, label=f'{split} without lags')
            plt.xlabel('Canonical component index')
            plt.ylabel('Prediction correlation')
            # The title follows the split being plotted; the train figure was
            # previously mislabelled as "Test prediction accuracy".
            plt.title(f'{split} prediction accuracy for set {i+1}')
            plt.legend()
            plt.show()

In [None]:
# NOTE(review): arguments are passed as (lags train, lags test, no-lag train,
# no-lag test), whereas the notebook-local function of the same name declares
# the test correlations first in each pair — check the parameter order of the
# cca_utils version so train and test are not swapped in the plots.
cca_utils.plot_cca_prediction_accuracy_train_test_bars_for_lags_and_no_lags(cca_lags.traincorrs, cca_lags.testcorrs, cca_no_lag.traincorrs, cca_no_lag.testcorrs)

# refactored more

In [193]:
combined_X1_df, combined_X2_df = cca_utils2.combine_data_to_compare_train_and_test(cca_no_lag, cca_lags)

In [None]:
cca_utils2.plot_lag_offset_train_test_overlap(combined_X2_df, 'DatasetName', mode='lag_offset')


In [None]:
cca_utils2.plot_lag_offset_train_test_overlap(combined_X2_df, 'DatasetName', mode='train_offset')


# sparse-CCA

In [None]:
!pip install cca-zoo


In [None]:
# Sparse-CCA smoke test on synthetic data (does not touch the neural data).
# NOTE(review): the class name `SparseCCA`, its `c` argument and the `weights`
# attribute have not been checked against the installed cca-zoo release —
# verify against that version's API before relying on this cell.
import numpy as np
from cca_zoo.models import SparseCCA

# Generate synthetic data (global seed fixed so the draws below are reproducible)
np.random.seed(42)
n_samples = 100
n_features_x = 20
n_features_y = 15

# Start from independent Gaussian noise for both views
X = np.random.randn(n_samples, n_features_x)
Y = np.random.randn(n_samples, n_features_y)

# Inject correlation in the first 3 variables: Y[:, i] becomes X[:, i] plus small noise
for i in range(3):
    Y[:, i] = X[:, i] + 0.1 * np.random.randn(n_samples)

# Initialize Sparse CCA model with a single latent dimension
model = SparseCCA(latent_dims=1, c=[0.1, 0.1])  # c controls sparsity penalty (smaller = sparser)

# Fit model on the two views
model.fit((X, Y))

# Get canonical weights (index 0 -> X view, index 1 -> Y view)
w_x = model.weights[0]
w_y = model.weights[1]

print("Sparse CCA weights for X:")
print(w_x)

print("\nSparse CCA weights for Y:")
print(w_y)


# Now: could you use neural data to decode target position?

What about the 2nd target's position?
(We can either reuse the 1st target's decoder, or train a separate decoder for the 2nd target.)

also...try GPFA at some point 