# Import packages

In [None]:
#  !pip install elephant neo quantities

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

project_folder = '/Users/dusiyi/Documents/Multifirefly-Project'
os.chdir(project_folder)
sys.path.append(os.path.join(project_folder, 'multiff_analysis', 'methods'))

from data_wrangling import specific_utils, process_monkey_information
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from non_behavioral_analysis.neural_data_analysis.get_neural_data import neural_data_processing
from non_behavioral_analysis.neural_data_analysis.visualize_neural_data import plot_neural_data, plot_modeling_result
from non_behavioral_analysis.neural_data_analysis.model_neural_data import cca_class, pgam_class, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from non_behavioral_analysis.neural_data_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from non_behavioral_analysis.neural_data_analysis.planning_neural import planning_neural_class, planning_neural_utils
from non_behavioral_analysis.neural_data_analysis.decode_targets import behav_features_to_keep, decode_target_class, plot_gpfa_utils, decode_target_utils

import sys
import math
import gc
import subprocess
from pathlib import Path
from importlib import reload

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# Reset matplotlib to its defaults FIRST: doing this after the animation
# settings (as before) silently discarded them.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Embed animations as interactive JavaScript (this overrode the earlier
# 'html5' setting anyway).
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128
# NOTE(review): this works around duplicate OpenMP runtimes; it may need to be
# set before torch/numpy are imported to take effect — confirm.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
print("done")

# Get data

In [None]:
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0416"

# Build the decoding object for this session, then prepare its behavioral and
# neural data and the x (neural) / y (behavioral) variables.
dec = decode_target_class.DecodeTargetClass(
    raw_data_folder_path=raw_data_folder_path,
    bin_width=0.1,
    window_width=0.25,
)
dec.streamline_making_behav_and_neural_data()
dec.get_x_and_y_var()

# Peek at the first rows of the pursuit data.
dec.pursuit_data.head(3)


# Reduce columns in lags

In [None]:
# Reduce the lagged columns of the x (neural) variables.
dec.reduce_x_var_lags()

In [None]:
# Reduce the lagged columns of the y (behavioral) variables.
dec.reduce_y_var_lags()

## check result of reducing

In [None]:
# Correlations among the x variables without lags: collect pairs above the
# 0.8 correlation threshold and display the top-ranked ones.
high_corr_pair_df, top_n_corr_df = drop_high_corr_vars.get_pairs_of_columns_w_high_corr(
    dec.x_var,
    corr_threshold=0.8,
)
top_n_corr_df

# Linear regression (not yet updated)

Regressing the behavioral variables individually (as y_var) against all neural activity

## put results in df

In [None]:
# Load (or compute and cache) the per-behavioral-variable linear-regression results.
# NOTE(review): `data_item` is only created in the "compared with
# neural_data_modeling" cell near the end of the notebook; run that first or
# this cell raises NameError.
data_item.make_or_retrieve_y_var_lr_result_df(exists_ok=True)
data_item.y_var_lr_result_df.head(5)

In [None]:
    def make_or_retrieve_y_var_lr_result_df(self, exists_ok=True):
        """Load cached per-y_var linear-regression results, or compute and cache them.

        The cache file is ``y_var_lr_result_df.csv`` inside
        ``self.lr_result_df_path``; the result is stored on
        ``self.y_var_lr_result_df``.

        Parameters
        ----------
        exists_ok : bool
            If True and the cache file exists, load it instead of recomputing.
        """
        df_path = os.path.join(self.lr_result_df_path, 'y_var_lr_result_df.csv')
        # `and` short-circuits, so the filesystem is only checked when reuse is
        # allowed. os.path.exists is used because a bare `exists` is not imported.
        if exists_ok and os.path.exists(df_path):
            self.y_var_lr_result_df = pd.read_csv(df_path)
        else:
            self.y_var_lr_result_df = neural_data_modeling.get_y_var_lr_result_df(
                self.x_var, self.y_var)
            # Create the cache directory if needed so to_csv does not fail.
            os.makedirs(os.path.dirname(df_path) or os.curdir, exist_ok=True)
            self.y_var_lr_result_df.to_csv(df_path, index=False)
            print('Made new y_var_lr_result_df')


##  plot all neural clusters vs one behavioral var

In [None]:
# Show every lagged behavioral column whose name contains 'bin'.
list(filter(lambda col: 'bin' in col, dec.y_var_lags_reduced.columns))

In [None]:
# Add a 'bin' column copied from the 'bin_5' column (note: this mutates
# dec.y_var_lags_reduced in place).
dec.y_var_lags_reduced['bin'] = dec.y_var_lags_reduced['bin_5']

In [None]:
# Inspect the 'bin' column.
dec.y_var_lags_reduced['bin']

In [None]:
# Store 'bin' as integers (copied from 'bin_5').
dec.y_var_lags_reduced['bin'] = dec.y_var_lags_reduced['bin_5'].astype(int)

In [None]:
# For each behavioral variable, regress it on all neural features (dec.x_var)
# and plot fits with R^2 >= 0.3.
plt.rcParams["figure.figsize"] = (20, 10)
# Keep an integer 'bin' column on the frame (presumably needed by
# plot_regression — TODO confirm).
# NOTE(review): this mutates dec.y_var_lags_reduced, which is later passed to CCA.
dec.y_var_lags_reduced['bin'] = dec.y_var_lags_reduced['bin_5'].astype(int)
# 'bin' and the 'bin_*' columns are bin indices, not behavioral variables,
# so they are not used as regression targets.
behavioral_columns = [
    col for col in dec.y_var_lags_reduced.columns
    if col != 'bin' and not col.startswith('bin_')
]
for column in behavioral_columns:
    plot_neural_data.plot_regression(dec.y_var_lags_reduced, column, dec.x_var,
                                     bins_to_plot=None, min_r_squared_to_plot=0.3)

##  plot one neural cluster vs one behavioral var

In [None]:
# Scatter one neural cluster's binned activity against one behavioral variable
# over a window of bins, with a least-squares fit line in red.
# NOTE(review): `data_item` is only created in the appendix cell at the end.
cluster_num = 6
behavioral_column = 'monkey_speed'
bins_to_plot = range(1000, 1200)

x_values = data_item.binned_spikes_matrix[bins_to_plot, cluster_num]
y_values = data_item.final_behavioral_data[behavioral_column][bins_to_plot]

# sklearn expects a 2-D design matrix: one column for the single predictor.
x_column = x_values.reshape(-1, 1)
reg = LinearRegression().fit(x_column, y_values)

plt.scatter(x_values, y_values, color='blue', s=1)
plt.plot(x_values, reg.predict(x_column), color='red', linewidth=1)
plt.show()

# GPFA

## elephant example

In [19]:
# For every target, keep the spikes between its last visible time and its
# capture time, tagged with the segment id (the target's row index) and the
# segment's start/stop times.
spike_segments = []
for seg_id, target in dec.single_vis_target_df.iterrows():
    start, stop = target['last_vis_time'], target['ff_caught_time']
    in_window = dec.spike_df.time.between(start, stop)
    spike_segments.append(
        dec.spike_df[in_window].assign(
            segment=seg_id,
            segment_start_time=start,
            segment_stop_time=stop,
        )
    )

spike_segs_df = pd.concat(spike_segments, ignore_index=True)


In [None]:
# Duration of each segment, repeated on each of its spike rows; show the
# longest (used later as the common t_stop of every spike train).
spike_segs_df['t_duration'] = spike_segs_df['segment_stop_time'] - spike_segs_df['segment_start_time']
max(spike_segs_df['t_duration'])

In [None]:
# `quantities` is otherwise first imported in a later cell, so running the
# notebook top to bottom raised NameError here.
import quantities as pq

# GPFA bin width: 20 ms, as a quantities time value.
bin_size = 20 * pq.ms

In [None]:
# Disable IPython's automatic post-mortem debugger on exceptions.
%pdb off

In [70]:
# Build the GPFA input: one list per segment (trial), each holding one
# neo.SpikeTrain per cluster, with spike times relative to the segment start.
import quantities as pq

# Get unique clusters and segments
clusters = spike_segs_df.cluster.unique()
segments = spike_segs_df.segment.unique()

# Every train shares the window [0, longest segment duration] s, so every
# trial yields the same number of GPFA bins.
# NOTE(review): shorter segments are therefore padded with spike-free time
# after their real stop; confirm that this is acceptable for GPFA.
# NOTE(review): segments with no spikes never appear in spike_segs_df and are
# silently dropped here.
common_t_stop = max(spike_segs_df['t_duration'])

spiketrains = []
for seg in segments:
    spike_df_trial = spike_segs_df[spike_segs_df.segment == seg]
    # Every row of a segment carries the same start time.
    seg_start_time = spike_df_trial.segment_start_time.iloc[0]

    seg_spiketrain = []
    for cluster in clusters:
        sub = spike_df_trial[spike_df_trial.cluster == cluster]
        spike_time = sub.time - seg_start_time
        # Give t_stop explicit units rather than relying on neo to attach the
        # spike times' units to a bare float.
        spiketrain = neo.SpikeTrain(
            times=spike_time.values * pq.s,
            t_start=0 * pq.s,
            t_stop=common_t_stop * pq.s,
        )
        seg_spiketrain.append(spiketrain)

    spiketrains.append(seg_spiketrain)

In [None]:
# Number of segments (trials).
print(len(spiketrains))
# Number of cluster spike trains in the first segment.
print(len(spiketrains[0]))
# Number of spikes of the third cluster in the first segment.
print(len(spiketrains[0][2]))

## fit model

In [41]:
import numpy as np
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from mpl_toolkits.mplot3d import Axes3D


In [None]:
latent_dimensionality = 3
# Fit a 3-D GPFA model on all trials and return each trial's latent trajectory.
gpfa_3dim = GPFA(x_dim=latent_dimensionality, bin_size=bin_size)
trajectories = gpfa_3dim.fit_transform(spiketrains)


In [None]:
# 3-D plot of the GPFA trajectories using the uniform-color plotting helper.
plot_gpfa_utils.plot_gpfa_traj_3d_uniform_color(trajectories)


In [None]:
# Render figures inline (static images; this is not an interactive backend)
%matplotlib inline

# Import required modules
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Static 3-D plot of single-trial and trial-averaged GPFA trajectories
fig, ax = plot_gpfa_utils.plot_gpfa_traj_3d(
    trajectories=trajectories,
    figsize=(15, 5),
    linewidth_single_trial=0.75,
    alpha_single_trial=0.3,
    linewidth_trial_average=2,
    title='Latent dynamics extracted by GPFA',
    view_azim=-5,
    view_elev=60
)

plt.show()

In [None]:
# Plotly-based 3-D trajectory plot; the returned figure is kept in `fig`.
# NOTE(review): whether the helper displays the figure itself is not visible here.
fig = plot_gpfa_utils.plot_gpfa_traj_3d_plotly(trajectories)

In [None]:
from numpy import var

# Stack per-trial trajectories; assumed shape (n_trials, n_latent_dims, n_bins)
# — TODO confirm against elephant's GPFA output.
traj_stack = np.stack(trajectories, axis=0)
# Variance of each latent dimension pooled over trials and bins, expressed as
# a share of the summed latent variance. This is not variance of the neural
# data explained by each dimension.
var_by_dim = traj_stack.var(axis=(0, 2))
var_by_dim = var_by_dim / var_by_dim.sum()
print("Variance explained by each latent dimension:", var_by_dim)


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(figsize=(15, 5))

ax.set_title('Latent dynamics extracted by GPFA')
ax.set_xlabel('Time [s]')

time = np.arange(len(average_trajectory[0])) * 0.02  # assuming all trajectories have the same length

for i, x in enumerate(average_trajectory):
    ax.plot(time, x, label=f'Dim {i+1}')

ax.legend()

plt.tight_layout()
plt.show()


## regression

In [None]:
# Display the full behavioral data table.
dec.behav_data_all

In [None]:
import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import cross_val_score, KFold
from sklearn.multioutput import MultiOutputRegressor
import matplotlib.pyplot as plt

def time_resolved_regression(latent_trajectories, behavioral_data, time_step=0.02, cv_folds=5,
                             alphas=None, random_state=42):
    """
    Time-resolved, cross-validated ridge regression of behavior on neural latents.

    At each time point, a separate RidgeCV model (alpha chosen by internal CV
    from ``alphas``) is fit for each behavioral variable, with trials as samples.
    Its out-of-sample R² is averaged over ``cv_folds`` shuffled K-fold splits.

    Parameters:
    - latent_trajectories: np.array (n_trials, n_timepoints, n_latent_dims)
    - behavioral_data: np.array (n_trials, n_timepoints, n_behaviors), continuous behavioral variables
    - time_step: float, time bin size in seconds (only used to build ``times``)
    - cv_folds: int, number of outer cross-validation folds
    - alphas: array-like of ridge penalties to search; defaults to np.logspace(-6, 6, 13)
    - random_state: int, seed for shuffling the folds

    Returns:
    - scores_by_time: np.array (n_timepoints, n_behaviors), mean cross-validated R² scores
    - times: np.array (n_timepoints,), time of each point in seconds

    Raises:
    - ValueError: if an input is not 3-D, or the trial/time dimensions of the two inputs differ
    """
    latent_trajectories = np.asarray(latent_trajectories)
    behavioral_data = np.asarray(behavioral_data)
    if latent_trajectories.ndim != 3 or behavioral_data.ndim != 3:
        raise ValueError("latent_trajectories and behavioral_data must both be 3-D arrays")
    # Mismatched time lengths used to be silently truncated or to fail mid-loop.
    if latent_trajectories.shape[:2] != behavioral_data.shape[:2]:
        raise ValueError(
            f"trial/time dimensions differ: {latent_trajectories.shape[:2]} "
            f"vs {behavioral_data.shape[:2]}"
        )
    if alphas is None:
        alphas = np.logspace(-6, 6, 13)

    n_timepoints = latent_trajectories.shape[1]
    n_behaviors = behavioral_data.shape[2]
    scores_by_time = np.zeros((n_timepoints, n_behaviors))

    kf = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)

    for t in range(n_timepoints):
        X = latent_trajectories[:, t, :]  # (n_trials, n_latent_dims)
        for b in range(n_behaviors):
            y = behavioral_data[:, t, b]  # (n_trials,)
            # RidgeCV tunes alpha internally; cross_val_score gives out-of-sample R².
            model = RidgeCV(alphas=alphas)
            cv_scores = cross_val_score(model, X, y, cv=kf, scoring='r2')
            scores_by_time[t, b] = cv_scores.mean()

    times = np.arange(n_timepoints) * time_step
    return scores_by_time, times

# Smoke test of time_resolved_regression on pure noise; cross-validated R²
# is expected to stay near or below zero.
np.random.seed(0)
n_trials, n_timepoints, n_latent_dims, n_behaviors = 100, 50, 5, 2

latent_trajectories = np.random.randn(n_trials, n_timepoints, n_latent_dims)
behavioral_data = np.random.randn(n_trials, n_timepoints, n_behaviors)  # continuous behaviors

scores, times = time_resolved_regression(latent_trajectories, behavioral_data)

# One R² curve per behavior over time.
plt.figure(figsize=(10, 6))
for b, behavior_scores in enumerate(scores.T):
    plt.plot(times, behavior_scores, label=f'Behavior {b+1}')
plt.xlabel('Time [s]')
plt.ylabel('Cross-validated R²')
plt.title('Time-resolved regression: neural latent -> continuous behavior')
plt.legend()
plt.tight_layout()
plt.show()


# CCA

https://medium.com/@pozdrawiamzuzanna/canonical-correlation-analysis-simple-explanation-and-python-example-a5b8e97648d2

## No lagging

In [102]:
# CCA between neural (X1 = dec.x_var) and reduced behavioral (X2) variables, without lags.
cca_no_lag = cca_class.CCAclass(X1=dec.x_var, X2=dec.y_var_reduced, lagging_included=False)

In [None]:
# Run the no-lag CCA.
cca_no_lag.conduct_cca()

## with lags

In [None]:
# CCA between lagged neural and lagged (reduced) behavioral variables.
# NOTE(review): if the regression-plot cells above were run, y_var_lags_reduced
# also contains an added 'bin' column.
cca_lags = cca_class.CCAclass(X1=dec.x_var_lags, X2=dec.y_var_lags_reduced, lagging_included=True)

In [None]:
# Shape of the lagged behavioral matrix fed to CCA.
dec.y_var_lags_reduced.shape

In [None]:
# Run the lagged CCA.
cca_lags.conduct_cca()

In [None]:
# Force garbage collection after the lagged CCA.
gc.collect()

## compare lag vs no lag

In [105]:
# Canonical correlations of the no-lag and lagged fits side by side, one row per component.
canon_df = pd.DataFrame(cca_no_lag.canon_corr, columns=['no_lag'])
canon_df[f'lag_{dec.max_lag_number}'] = cca_lags.canon_corr
canon_df['component'] = ['CC ' + str(k) for k in range(1, cca_lags.n_components + 1)]
# Long format (component, lag, canon_coeff) for a hue-grouped seaborn barplot.
canon_df_long = canon_df.melt(id_vars=['component'], var_name='lag', value_name='canon_coeff')

In [None]:
# make a sns bar plot on canon_df_long
# Grouped bar plot of canonical correlation per component, no-lag vs lagged.
plt.figure(figsize=(8, 6))
sns.barplot(x='component', y='canon_coeff', data=canon_df_long, hue='lag')
plt.show()

## cca_inst (choose either the lagged or the no-lag model)

In [55]:
# choose lags
# Use the lagged CCA in the cells below.
cca_inst = cca_lags

In [None]:
# choose no lag
# Use the no-lag CCA in the cells below (overrides the previous cell if both are run).
cca_inst = cca_no_lag

## loadings

### neurons

In [None]:
# Ranked signed (not squared) loadings of the neural variables (X1).
cca_inst.plot_ranked_loadings(X1_or_X2='X1', squared=False)

### behavior

In [None]:
# Ranked signed (not squared) loadings of the behavioral variables (X2).
cca_inst.plot_ranked_loadings(X1_or_X2='X2', squared=False)

## squared loadings

### neurons

In [None]:
# Ranked loadings of X1 with the default `squared` setting (presumably squared, per the section heading).
cca_inst.plot_ranked_loadings(X1_or_X2='X1')

### behavior

In [None]:
# Ranked loadings of X2 with the default `squared` setting (presumably squared, per the section heading).
cca_inst.plot_ranked_loadings(X1_or_X2='X2')

## abs weights ranked

### neurons

In [None]:
# Ranked weights with default arguments (presumably X1 and absolute values, per the headings).
cca_inst.plot_ranked_weights()

### behavior

In [None]:
# Ranked weights of the behavioral variables (X2) with the default `abs_value` setting.
cca_inst.plot_ranked_weights(X1_or_X2='X2')

## plot real weights

### neurons

In [None]:
# Ranked signed weights with the default X1_or_X2 setting (neurons, per the heading).
cca_inst.plot_ranked_weights(abs_value=False)

### behavior

In [None]:
# Ranked signed weights of the behavioral variables (X2).
cca_inst.plot_ranked_weights(X1_or_X2='X2', abs_value=False)

In [None]:
stop here!

## distribution of each feature

In [None]:
# Shape of the (presumably standardized) behavioral matrix used by CCA.
cca_inst.X2_sc.shape

In [None]:
# Wrap X2_sc with the original behavioral column names and summarize each column.
X2_sc_df = pd.DataFrame(cca_inst.X2_sc, columns = cca_inst.X2.columns)
X2_sc_df.describe()

In [None]:
# One horizontal box plot per behavioral feature.
for column in X2_sc_df.columns:
    plt.figure(figsize=(8, 2))
    sns.boxplot(X2_sc_df[column], orient='h')
    plt.show()
    
    

## heatmap of weights
Raw canonical coefficients are interpreted analogously to regression coefficients. For example (from a different, unrelated dataset): a one-unit increase in reading leads to a 0.0446 decrease in the first canonical variate of set 2 when all other variables are held constant.

In [None]:
# Behavioral (X2) CCA weights indexed by feature name, without the category column.
weight_df = cca_inst.X2_weight_df.set_index('feature').drop(columns='feature_category')

In [None]:
# Annotated heatmap of the first 20 features x first 10 weight columns
# (presumably canonical components — confirm).
plt.subplots(figsize=(15, 25))
sns.heatmap(weight_df.iloc[:20, :10], cmap='coolwarm', annot=True, linewidths=1)
plt.show()

## train test split

In [None]:
# Fit an unregularized linear CCA (rcca) on 70% of rows and evaluate it on the held-out 30%.
# NOTE(review): X1_sc/X2_sc are presumably scaled on all rows before this split
# (mild leakage), and a random split of time bins mixes neighbouring bins
# across train and test — confirm.
train1, test1, train2, test2 = train_test_split(cca_inst.X1_sc, cca_inst.X2_sc, test_size=0.3, random_state=42)
# use training and testing set
nComponents = 10
cca2 = rcca.CCA(kernelcca = False, reg = 0., numCC = nComponents)
cca2.train([train1, train2])
testcorrs = cca2.validate([test1, test2])
testcorrs

## compute explained variance

In [None]:
# Explained variance on the held-out data, as computed by rcca.
cca2.compute_ev([test1, test2])

## test for p values

In [None]:
# statsmodels significance tests for the canonical correlations, then the raw
# canonical coefficients of each variable set.
stats_cca = CanCorr(cca_inst.X1_sc, cca_inst.X2_sc)
print(stats_cca.corr_test().summary())
neural_data_modeling.print_weights('X', stats_cca.x_cancoef)
neural_data_modeling.print_weights('Z', stats_cca.y_cancoef)

# PGAM (unfinished)

In [None]:
# Split the reduced behavioral columns into temporal variables (listed
# explicitly) and everything else, treated as spatial.
temporal_vars = [
    'time_rel_to_stop',
    'time_when_nxt_ff_first_seen_rel_to_stop',
    'time_when_cur_ff_first_seen_rel_to_stop',
    'time_when_nxt_ff_last_seen_rel_to_stop',
    'time_when_cur_ff_last_seen_rel_to_stop',
]

spatial_vars = [col for col in dec.y_var_reduced.columns if col not in temporal_vars]
spatial_vars

# Inspect data

## sparsity of neural data

In [None]:
# Shape of the binned spike table.
dec.binned_spikes_df.shape

In [None]:
# Sparsity of the neural data: percentage of time bins in which each neuron
# column is non-zero.
bins = dec.binned_spikes_df

# 'bin' is a bin index, not a neuron (the next cell drops it as well), so it
# is excluded here rather than reported as a neuron.
neuron_counts = bins.drop(columns='bin', errors='ignore')
non_zero_percentages = (neuron_counts != 0).mean() * 100

# Most active columns first.
non_zero_df = pd.DataFrame({
    'Column': non_zero_percentages.index,
    'Percent_Non_Zero': non_zero_percentages.values,
}).sort_values('Percent_Non_Zero', ascending=False)

print("Percentage of non-zero values in each column:")
print(non_zero_df)


In [None]:
# Population-average activity per time bin: mean over neuron columns ('bin'
# index column excluded). Units are whatever binned_spikes_df holds
# (presumably spike counts per bin — confirm).
mean_rates = bins.drop(columns='bin').mean(axis=1)
# Summary statistics were previously computed and discarded mid-cell; print them.
print(mean_rates.describe())

# Percentiles 0..100 of the population mean across time bins.
percentiles = np.arange(0, 101, 1)
percentile_values = np.percentile(mean_rates, percentiles)

plt.figure(figsize=(6, 4))
plt.plot(percentiles, percentile_values)
plt.xlabel('Percentile (across time bins)')
plt.ylabel('Mean Firing Rate')
plt.title('Distribution of Population-Mean Firing Rate Across Time Bins')
plt.grid(True)
plt.show()


## multicollinearity

### y var (behavioral)

In [None]:
# Variance inflation factors of the full behavioral matrix.
y_var_vif = drop_high_vif_vars.get_vif_df(dec.y_var)
print(y_var_vif.head(8))

# Correlation heatmap among the first ten features listed in the VIF table
# (NOTE(review): presumably sorted by VIF, descending — confirm).
specific_columns = y_var_vif['feature'].values[:10]
corr_coeff = dec.y_var[specific_columns].corr()

plt.figure(figsize=(8, 6))
sns.heatmap(corr_coeff, vmin=-1, cmap='coolwarm', linewidths=1, annot=True)
plt.show()

In [None]:
# Repeat the VIF / correlation check on the reduced behavioral matrix.
y_var_vif = drop_high_vif_vars.get_vif_df(dec.y_var_reduced)
print(y_var_vif.head(8))

# Correlation heatmap among the first ten features listed in the VIF table.
# Correlations come from dec.y_var_reduced (not dec.y_var), so the heatmap
# describes the same matrix the VIFs were computed on.
specific_columns = y_var_vif.feature.values[:10]
corr_coeff = dec.y_var_reduced[specific_columns].corr()
plt.figure(figsize=(8, 6))
sns.heatmap(corr_coeff, cmap='coolwarm', annot=True, linewidths=1, vmin=-1)
plt.show()

## plot trial segments in pursuit_data

In [None]:
# The trailing comma after `plot_trials` was a SyntaxError.
from visualization.matplotlib_tools import plot_trials
# Prepare the positional arguments later passed to plot_trials.PlotTrials.
dec.make_PlotTrials_args()

In [None]:
plt.rcParams['figure.figsize'] = [10, 10]

max_plot_to_make = 2
plot_counter = 0

# Skip the first two targets, then plot at most `max_plot_to_make` trials,
# each from the target's last visible time to its capture time.
rows_to_plot = dec.single_vis_target_df.iloc[2:].head(max_plot_to_make)
for index, row in rows_to_plot.iterrows():
    duration = [row['last_vis_time'], row['ff_caught_time']]
    returned_info = plot_trials.PlotTrials(
        duration,
        *dec.PlotTrials_args,
        adjust_xy_limits=True,
        minimal_margin=50,
        show_reward_boundary=True,
        show_alive_fireflies=False,
        show_visible_fireflies=True,
        show_in_memory_fireflies=True,
        show_believed_target_positions=True,
    )
    plt.show()
    plot_counter += 1

### check target_rel_x and y
(They look correct after checking.)

In [None]:
# Pursuit rows for target 65, plus the target angle converted to degrees.
pursuit_sub = (
    dec.pursuit_data.loc[dec.pursuit_data['target_index'] == 65]
    .assign(target_angle_deg=lambda df: df['target_angle'] * 180/pi)
)

In [None]:
# Compare target angle/distance with the relative x/y coordinates.
pursuit_sub[['point_index', 'target_angle_deg', 'target_distance', 'target_rel_x', 'target_rel_y']]

# Appendix

## More columns (possibly to add in the future)

We could also get the following (but to be honest, it doesn't make much sense to get them, so let's skip them for now):
'distance traversed since target last visible',
'd angle since target last visible', 'target_at_right',
'time_till_capture', 'time from last visible to capture'

Note that there might be multicollinearity. For example, duration from last visible to capture = time since target last visible + time till capture

Similarly, target angle = target angle last seen frozen - d angle since target last visible

(For distance it's not exactly the same because of the difference between distance and distance traversed, but it's still similar)

Multicollinearity is fine in linear regression (where each feature here is a y var), but it needs to be dealt with in CCA.

## other thoughts

Should I actually align each section as if they were trials?
Maybe I can try both that and continuous time; each can shed light on different behavioral variables.
But aligning trials may require time warping, since trial durations vary.

By the way, what does it mean to stitch data?

Also, what would it look like to use an RNN to model it?
I thought about the paper that Noah presented.


btw.......IME

## Why the ratio of bin/target_index approaches a constant

In [None]:
# Number of non-null 'bin' entries per target, then summary statistics.
trial_lengths = dec.pursuit_data.groupby('target_index')[['bin']].count()
trial_lengths.describe()

In [None]:
# Ratio of bin index to target index for each row.
# .copy() makes `sub` an independent frame, so adding a column no longer
# writes into a slice of dec.y_var_reduced (SettingWithCopyWarning).
# Rows with target_index == 0 give inf.
sub = dec.y_var_reduced[['time', 'bin', 'target_index']].copy()
sub['factor'] = sub['bin'] / sub['target_index']
sub

In [None]:
# Histogram of intervals between consecutive capture times.
plt.hist(np.diff(dec.ff_caught_T_sorted), bins=30)
plt.xlabel('Time difference')
plt.ylabel('Count')
plt.title('Distribution of time differences between caught events')
plt.show()


In [None]:
# k-th capture time divided by k: roughly the mean time per capture so far
# (if the recording clock starts near 0). Counting from 1 avoids dividing the
# first capture time by 0 (inf plus a RuntimeWarning).
dec.ff_caught_T_sorted / np.arange(1, len(dec.ff_caught_T_sorted) + 1)

## compared with neural_data_modeling

In [None]:
# Build the NeuralVsBehavioralClass pipeline on the same recording, for comparison with `dec`.
data_item = neural_vs_behavioral_class.NeuralVsBehavioralClass(raw_data_folder_path=raw_data_folder_path)
data_item.streamline_preparing_neural_and_behavioral_data()

In [None]:
# Behavioral data produced by the comparison pipeline.
data_item.final_behavioral_data

In [None]:
# y variables of the comparison pipeline.
data_item.y_var

In [None]:
# Reduced y variables of `dec`.
dec.y_var_reduced

In [None]:
# Column names of the comparison pipeline's y variables.
data_item.y_var.columns

In [None]:
# Column names of `dec`'s reduced y variables.
dec.y_var_reduced.columns

In [None]:
# Columns present in data_item.y_var but missing from dec.y_var_reduced (original order kept).
[col for col in data_item.y_var.columns if col not in dec.y_var_reduced.columns]

In [None]:
# Columns present in dec.y_var_reduced but missing from data_item.y_var (original order kept).
[col for col in dec.y_var_reduced.columns if col not in data_item.y_var.columns]