# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
# Locate the project root by walking up from the current working directory,
# make it the cwd (later cells use repo-relative data paths such as
# "all_monkey_data/..."), and put the analysis code on sys.path so the
# project imports below resolve.
for p in [Path.cwd(), *Path.cwd().parents]:
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break
else:
    # Without this, a missing root only shows up later as a confusing
    # ModuleNotFoundError on the first project import.
    raise FileNotFoundError(
        f"Could not find a 'Multifirefly-Project' directory at or above {Path.cwd()}; "
        "start the notebook from inside the project checkout."
    )
    
from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural.pn_decoding import pn_decoding_utils, plot_pn_decoding, pn_decoding_model_specs
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural.pn_decoding.interactions import add_interactions, discrete_decoders, conditional_decoding, interaction_decoding, interaction_plots
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils, ml_methods_utils
from planning_analysis.show_planning import nxt_ff_utils, show_planning_utils
from neural_data_analysis.neural_analysis_tools.gpfa_methods import elephant_utils, fit_gpfa_utils, plot_gpfa_utils, gpfa_helper_class
from neural_data_analysis.neural_analysis_tools.align_trials import time_resolved_regression, time_resolved_gpfa_regression,plot_time_resolved_regression
from neural_data_analysis.neural_analysis_tools.align_trials import align_trial_utils

import sys
import math
import gc
import subprocess
from pathlib import Path
from importlib import reload

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
from catboost import CatBoostRegressor

from sklearn.linear_model import RidgeCV
from sklearn.model_selection import cross_val_score

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca
import quantities as pq

# Display / plotting configuration.
# Reset matplotlib to its defaults FIRST. Previously the reset ran after the
# animation settings and silently discarded them ("html5" was also immediately
# overridden by "jshtml"), so animations fell back to the default renderer.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Render animations inline as interactive JavaScript widgets.
rc('animation', html='jshtml')
# Max size (in MB) of an embedded animation; effectively unlimited.
matplotlib.rcParams['animation.embed_limit'] = 2**128
# NOTE(review): KMP_DUPLICATE_LIB_OK is read when the OpenMP runtime loads;
# setting it after torch/numpy have been imported (above) may be too late --
# consider moving it before those imports.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
print("done")

%load_ext autoreload
%autoreload 2

# retrieve data

In [None]:
# Session to analyse (monkey Schro); keep exactly one assignment uncommented.
# NOTE: the next cell reassigns raw_data_folder_path, so when both cells run
# (e.g. "Run All") the monkey_Bruno session selected there is the one used.
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0416"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0321"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0329"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0403"

In [None]:
# Session to analyse (monkey Bruno); running this cell overrides the Schro
# selection made in the previous cell.
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0312"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330"
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0316"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0327"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0328"

In [None]:
# Run configuration. The *_exists_ok flags are passed to the pn methods below
# (presumably to reuse cached intermediate results instead of recomputing).
reduce_y_var_lags = False
planning_data_by_point_exists_ok = True
y_data_exists_ok = True
bin_width = 0.1

pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(raw_data_folder_path=raw_data_folder_path, bin_width=bin_width)
pn.prep_data_to_analyze_planning(planning_data_by_point_exists_ok=planning_data_by_point_exists_ok)
# Drop columns with many NaNs. cols_to_drop (presumably the dropped names) is
# not used anywhere below.
pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
    pn.planning_data_by_point)
#pn.get_x_and_y_data_for_modeling(exists_ok=y_data_exists_ok, reduce_y_var_lags=reduce_y_var_lags)

In [None]:
# NOTE: reduce_y_var_lags is passed as a literal False here, so the
# reduce_y_var_lags variable defined above currently has no effect.
pn.get_x_and_y_data_for_modeling(exists_ok=y_data_exists_ok, reduce_y_var_lags=False)

# get planning_data by segment

## get data and fit gpfa

In [None]:
# Segment-aligned data from -0.25 to 1.25 relative to each event (units
# presumably seconds), then a 7-dimensional GPFA fit (loaded if it exists).
pn.prepare_seg_aligned_data(start_t_rel_event=-0.25, end_t_rel_event=1.25, end_at_stop_time=False)
pn.get_gpfa_traj(latent_dimensionality=7, exists_ok=True)

# for regression later
use_raw_spike_data_instead = False
use_lagged_rebinned_behav_data = False
pn.get_concat_data_for_regression(use_raw_spike_data_instead=use_raw_spike_data_instead,
                                  use_lagged_rebinned_behav_data=use_lagged_rebinned_behav_data,
                                  apply_pca_on_raw_spike_data=True,
                                  use_lagged_raw_spike_data=False,) 


pn.print_data_dimensions()

## gpfa DURING train test split (also point-wise)

In [None]:
# NOTE(review): re-runs segment alignment with default arguments and switches
# to raw spike data; presumably this replaces the data prepared in the cell above.
pn.prepare_seg_aligned_data()
pn.get_concat_data_for_regression(use_raw_spike_data_instead=True) 

# Get data (try only a few features right now)

In [None]:

# Behavioral columns kept for decoding. Names missing from
# pn.concat_behav_trials are filtered out in the prep cell further down.
key_features = [
    # bookkeeping columns (presumably bin / segment identifiers and the
    # test-vs-control flag)
    'new_bin', 'new_segment', 'whether_test',
    'cur_eye_hor_l', 'cur_eye_ver_l', 'cur_eye_hor_r', 'cur_eye_ver_r',
    'nxt_eye_hor_l', 'nxt_eye_ver_l', 'nxt_eye_hor_r', 'nxt_eye_ver_r',
    'LDz', 'RDz', 'LDx', 'RDx',
    'gaze_mky_view_x', 'gaze_mky_view_y', 'gaze_mky_view_angle',
    'cur_opt_arc_dheading',
    'cur_ff_distance',
    'cur_ff_rel_x',
    'cur_ff_rel_y',
    'nxt_ff_rel_x',
    'nxt_ff_rel_y',
    'nxt_ff_distance',
    'num_ff_visible',
    'num_ff_in_memory',
    'cur_ff_distance_at_ref',
    'cur_ff_angle_boundary_at_ref',
    'nxt_ff_distance_at_ref',
    'ang_speed',
    'speed',
    'accel',
    'ang_accel',
    'monkey_speeddummy',
    'curv_of_traj',
    'angle_from_cur_ff_to_nxt_ff',
    'time_since_last_capture',
    'bin_mid_time_rel_to_event',
    'time', 
    'target_index',
    # categorical modeling for the below:
    'cur_vis',
    'nxt_vis',
    'nxt_in_memory',
    'any_ff_visible',
    # 'cur_in_memory', # not used: these two are presumably always 1, so uninformative
    # 'any_ff_in_memory',
    ]

In [None]:
from collections import Counter

# Report repeated feature names: duplicates in a column-selection list would
# produce duplicated DataFrame columns.
feature_counts = Counter(key_features)
has_duplicates = len(key_features) != len(feature_counts)
print(has_duplicates)

# One O(n) counting pass instead of list.count() per element (O(n^2)).
dupes = {name for name, count in feature_counts.items() if count > 1}
print(dupes)

In [None]:
# pn.prepare_seg_aligned_data()

In [None]:
# Rebuild the concatenated neural/behavioral regression tables. With
# use_raw_spike_data_instead=False the neural features are presumably the GPFA
# latents fitted above.
# NOTE(review): confirm how use_lagged_raw_spike_data=True interacts with
# use_raw_spike_data_instead=False.
pn.get_concat_data_for_regression(use_raw_spike_data_instead=False,
                                    apply_pca_on_raw_spike_data=False,
                                    use_lagged_raw_spike_data=True) 

In [None]:
pn.concat_behav_trials


In [None]:
import numpy as np
import pandas as pd

# Add derived behavioral columns (their names come back as added_cols) to both
# the concatenated and the rebinned behavioral tables.
pn.concat_behav_trials, added_cols = pn_decoding_utils.prep_behav(pn.concat_behav_trials)
pn.rebinned_behav_data, _ = pn_decoding_utils.prep_behav(pn.rebinned_behav_data)

# Merge and de-duplicate while keeping first-seen order. The previous
# list(set(...)) produced a column order that changes between Python sessions
# (string hashing is randomized), so feature order -- and every model/plot
# indexed by it -- was not reproducible.
key_features = list(dict.fromkeys(key_features + list(added_cols)))
key_features = [f for f in key_features if f in pn.concat_behav_trials.columns]
pn.concat_behav_trials = pn.concat_behav_trials[key_features].copy()
pn.rebinned_behav_data = pn.rebinned_behav_data[key_features].copy()

# (Optional) peek at shapes
print('concat_behav_trials:', pn.concat_behav_trials.shape)
print('rebinned_behav_data:', pn.rebinned_behav_data.shape)

In [None]:
# Keep only bins after the event (bin_mid_time_rel_to_event > 0), in both the
# behavioral and the neural tables.
# NOTE(review): indexing concat_neural_trials with this boolean Series aligns
# on row labels, so it assumes both tables share the same index (pandas raises
# if it cannot align). pn.rebinned_behav_data is not filtered here.
mask = pn.concat_behav_trials['bin_mid_time_rel_to_event'] > 0
pn.concat_behav_trials = pn.concat_behav_trials[mask]
pn.concat_neural_trials = pn.concat_neural_trials[mask]

### Add interaction

In [None]:
# Add interaction terms / derived features; cols_added holds the new column
# names returned by the helper.
pn.concat_behav_trials, cols_added = pn_decoding_utils.add_interaction_terms_and_features(pn.concat_behav_trials)
# Features plotted first in the regression bar plots further down.
key_features2 = (
    ['cur_ff_distance', 'log1p_cur_ff_distance', 'speed',
        'accel', 'time_since_last_capture']
    + cols_added
)

# Sanity check on the speed distribution.
print("Speed quantiles:")
for q in [0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 1]:
    val = pn.concat_behav_trials['speed'].quantile(q)
    print(f"  {q:.2f}: {val:.5f}")



# Build interaction labels

## find ranges of vars

In [None]:
# Inspect the behavioral columns available at this point.
pn.concat_behav_trials.columns

In [None]:
# TODO: should 'cur_ff_angle' be used too?

# TODO: interactions could also be based on the stop state.

In [None]:
# Print selected quantiles of each variable to inspect its range.
quantile_vars = [
    'speed', 'ang_speed', 'accel', 'ang_accel',
    'cur_ff_distance', 'nxt_ff_distance',
    'cur_ff_angle', 'nxt_ff_angle',
    'cur_ff_rel_x', 'cur_ff_rel_y',
    'cur_ff_distance_at_ref',
]
quantile_levels = [0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 1]

for var_name in quantile_vars:
    print(f"Quantiles for {var_name}:")
    column = pn.planning_data_by_point[var_name]
    for level in quantile_levels:
        print(f"  {level:.2f}: {column.quantile(level):.5f}")
    print("\n")




## add bands

In [None]:
# Rebuild the concatenated regression inputs with raw spike data as the neural
# features, and take a working copy of the behavioral table.
# NOTE(review): if this call rebuilds pn.concat_behav_trials / concat_neural_trials,
# the post-event mask and key_features selection from earlier cells no longer
# apply -- confirm.
pn.get_concat_data_for_regression(use_raw_spike_data_instead=True) 
df = pn.concat_behav_trials.copy()

In [None]:
# Add discretized "band" columns (e.g. speed_band, cur_ff_angle_band used below).
df = add_interactions.add_behavior_bands(df)

In [None]:
# Joint speed x angle state label built from the two band columns.
df = add_interactions.add_pairwise_interaction(
    df=df,
    var_a='speed_band',
    var_b='cur_ff_angle_band',
    new_col='speed_angle_state',
)

# Drop joint states with fewer than min_count rows from the behavioral (y) and
# neural (x) tables together (presumably keeping them row-aligned).
y_var, x_var = add_interactions.prune_rare_states_two_dfs(
    df,
    pn.concat_neural_trials,
    label_col='speed_angle_state',
    min_count=200
)

# Decode the joint state from neural features with several model families.
results_df = discrete_decoders.sweep_decoders_xy(
    x_df=x_var,
    y_df=y_var,
    label_col='speed_angle_state',
    model_types=['logreg', 'svm', 'ridge'],
)

# Mean balanced accuracy per model.
summary_df = (
    results_df
    .groupby('model', as_index=False)
    .agg(mean_bal_acc=('balanced_accuracy', 'mean'))
)

print(summary_df)

## iterate

In [None]:
# (var_a, var_b, new_col) triples: each pairs two band columns into a joint
# state column named new_col; consumed by the interaction-analysis loop below.
PAIRWISE_INTERACTIONS = [

    # =========================================================
    # 1. Core movement control (highest priority)
    # =========================================================

    # Speed × geometry: classic steering regimes
    ('speed_band', 'cur_ff_angle_band', 'speed_angle_state'),
    ('speed_band', 'cur_ff_dist_band', 'speed_distance_state'),

    # Speed × motor output
    ('speed_band', 'ang_speed_band', 'speed_turnrate_state'),
    ('speed_band', 'accel_band', 'speed_accel_state'),

    # =========================================================
    # 2. Geometry × distance (navigation state)
    # =========================================================

    # How far + how misaligned am I from current target?
    ('cur_ff_angle_band', 'cur_ff_dist_band', 'angle_distance_state'),

    # Lateral geometry vs forward progress (optional but clean)
    ('cur_ff_rel_x_band', 'cur_ff_dist_band', 'lateral_distance_state'),

    # =========================================================
    # 3. Planning / lookahead geometry
    # =========================================================

    # Current vs next target geometry (planning competition)
    ('cur_ff_angle_band', 'nxt_ff_angle_band', 'cur_next_angle_state'),

    # Commitment stage × next-target relevance
    ('cur_ff_dist_band', 'nxt_ff_dist_band', 'curdist_nextdist_state'),

    # =========================================================
    # 4. Control dynamics (policy change / replanning)
    # =========================================================

    # Acceleration conditioned on geometry
    ('accel_band', 'cur_ff_angle_band', 'accel_angle_state'),
    ('accel_band', 'cur_ff_dist_band', 'accel_distance_state'),

    # Turn acceleration vs current geometry (replanning signal)
    ('ang_accel_band', 'cur_ff_angle_band', 'angaccel_angle_state'),

    # =========================================================
    # 5. Commitment / learning (late-stage, optional)
    # =========================================================

    # Early vs late commitment interacting with geometry
    ('cur_ff_dist_ref_band', 'cur_ff_angle_band', 'commit_angle_state'),

    # Commitment timing × speed (hesitation vs execution)
    ('cur_ff_dist_ref_band', 'speed_band', 'commit_speed_state'),
]


In [None]:
# Run and plot the pairwise interaction analysis for each pair, stopping after
# at most max_to_plot pairs. The previous counter checked the limit only after
# plotting, so it always produced max_to_plot + 1 plots.
max_to_plot = 10
for counter, (var_a, var_b, new_col) in enumerate(PAIRWISE_INTERACTIONS):
    if counter >= max_to_plot:
        break

    # x = neural features, y = behavioral table holding the band columns.
    # `out` is kept after the loop (the next cell inspects the last one).
    out = interaction_decoding.run_pairwise_interaction_analysis(
        x_df=pn.concat_neural_trials,
        y_df=df,
        var_a=var_a,
        var_b=var_b,
        interaction_col=new_col,
    )

    interaction_plots.plot_pairwise_interaction_analysis(
        analysis_out=out,
        interaction_name=new_col,
        var_a=var_a,
        var_b=var_b,
    )
    plt.show()


In [None]:
# Summary from the LAST pair processed by the loop above (`out` is overwritten
# on every iteration).
out['cond_var_a_summary']

## decode

In [None]:
# NOTE: repeats the speed x angle joint-state decoding from the '## add bands'
# section; it (re)defines y_var / x_var used by the cells below. Unlike that
# cell, summary_df is not printed here.
df = add_interactions.add_pairwise_interaction(
    df=df,
    var_a='speed_band',
    var_b='cur_ff_angle_band',
    new_col='speed_angle_state',
)

y_var, x_var = add_interactions.prune_rare_states_two_dfs(
    df,
    pn.concat_neural_trials,
    label_col='speed_angle_state',
    min_count=200
)

results_df = discrete_decoders.sweep_decoders_xy(
    x_df=x_var,
    y_df=y_var,
    label_col='speed_angle_state',
    model_types=['logreg', 'svm', 'ridge'],
)

summary_df = (
    results_df
    .groupby('model', as_index=False)
    .agg(mean_bal_acc=('balanced_accuracy', 'mean'))
)

## conditioned

In [None]:
# Pick up edits to discrete_decoders without restarting the kernel.
reload(discrete_decoders)

In [None]:
# Decode cur_ff_angle_band within each speed_band vs. globally (per the helper's
# name), then average balanced accuracy per context.
results_df = conditional_decoding.compare_component_conditioned_vs_global(
    x_df=x_var,
    y_df=y_var,
    target_col='cur_ff_angle_band',
    condition_col='speed_band',
    model_type='logreg',
)

summary_df = (
    results_df
    .groupby('context', as_index=False)
    .agg(mean_bal_acc=('balanced_accuracy', 'mean'))
)


In [None]:
summary_df

## hier

In [None]:
# Hierarchical decoding of speed and angle bands (see the helper for details).
hier_df = discrete_decoders.hierarchical_decode_speed_angle(
    x_df=x_var,
    y_df=y_var,
    speed_col='speed_band',
    angle_col='cur_ff_angle_band',
    model_type='logreg',
)

hier_df

## cross_condition_decode

In [None]:
# Train the angle decoder on some speed bands and test on others; results that
# come back as None are dropped when building cross_df.
rows = []

rows.append(discrete_decoders.cross_condition_decode(
    x_var, y_var,
    target_col='cur_ff_angle_band',
    condition_col='speed_band',
    train_conditions=['FAST', 'SLOW'],
    test_conditions=['CRUISE'],
))

rows.append(discrete_decoders.cross_condition_decode(
    x_var, y_var,
    target_col='cur_ff_angle_band',
    condition_col='speed_band',
    train_conditions=['CRUISE'],
    test_conditions=['FAST'],
))

cross_df = pd.DataFrame([r for r in rows if r is not None])
cross_df

## hyperparam tuning

In [None]:

# NOTE: reassigns results_df; the summary cell further down groups it by
# 'label' and 'model'.
results_df = discrete_decoders.decode_with_param_sweep_xy(
    x_df=x_var,                         # neural features
    y_df=y_var,                         # behavior labels
    label_col='speed_angle_state',     # what you decode
    model_type='logreg',               # ONE model family
    n_splits=5,
)


## chance level

In [None]:
def add_chance_level(results_df, df, label_col, score_col='accuracy',
                     chance_method='prior'):
    """Annotate decoding results with a chance level and the margin above it.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Decoding results; must contain ``score_col``.
    df : pandas.DataFrame
        Table holding the true labels in ``label_col``. Missing labels are ignored.
    label_col : str
        Name of the label column in ``df``.
    score_col : str, default 'accuracy'
        Column of ``results_df`` compared against chance (e.g.
        ``'balanced_accuracy'``, which the sweeps in this notebook report).
    chance_method : {'prior', 'uniform'}, default 'prior'
        ``'prior'``: expected accuracy of a guesser drawing labels according to
        the empirical class frequencies, ``sum_k p_k ** 2`` (original behavior).
        ``'uniform'``: ``1 / n_classes``, the chance level for balanced accuracy.

    Returns
    -------
    pandas.DataFrame
        A copy of ``results_df`` with added ``chance`` and ``above_chance`` columns.

    Raises
    ------
    ValueError
        If ``df[label_col]`` has no non-missing labels (chance would be
        undefined; previously it silently came out as 0), or if
        ``chance_method`` is unknown.
    """
    probs = df[label_col].value_counts(normalize=True)
    if probs.empty:
        raise ValueError(
            f"No non-missing labels in column {label_col!r}; chance level is undefined."
        )

    if chance_method == 'prior':
        chance = float(np.sum(probs.to_numpy() ** 2))
    elif chance_method == 'uniform':
        chance = 1.0 / len(probs)
    else:
        raise ValueError(
            f"chance_method must be 'prior' or 'uniform', got {chance_method!r}"
        )

    results_df = results_df.copy()  # never mutate the caller's frame
    results_df['chance'] = chance
    results_df['above_chance'] = results_df[score_col] - chance

    return results_df


In [None]:
# Mean accuracy and its standard error (sample std / sqrt(n)) per (label, model).
# NOTE(review): assumes results_df -- from decode_with_param_sweep_xy above --
# has 'label', 'model' and 'accuracy' columns; confirm.
summary_df = (
    results_df
    .groupby(['label', 'model'], as_index=False)
    .agg(
        mean_acc=('accuracy', 'mean'),
        sem_acc=('accuracy', lambda x: x.std() / np.sqrt(len(x)))
    )
)


In [None]:
# Number of behavioral columns currently in the concatenated table.
len(pn.concat_behav_trials.columns)

In [None]:
pn.planning_data_by_bin.columns.to_list()

In [None]:
# Fresh working copy of the behavioral table (rebinds df).
df = pn.concat_behav_trials.copy()

In [None]:
df

In [None]:
stop!

# Regression on all data

## regular methods

In [None]:
# Row counts of the neural and behavioral tables (presumably these should match).
print(pn.concat_neural_trials.shape)
print(pn.concat_behav_trials.shape)

# Presumably splits test vs. control trials for get_concat_x_and_y_var_for_lr.
pn.separate_test_and_control_data()
# columns_of_interest = ['whether_test', 'cur_ff_distance', 'cur_ff_angle', 'nxt_ff_distance', 'nxt_ff_rel_y', 'nxt_opt_arc_dheading', 'nxt_ff_rel_x', 'nxt_ff_angle', 'nxt_ff_angle_at_ref']
columns_of_interest = pn.concat_behav_trials.columns
all_results = []
# Cross-validated decoding of every behavioral column for each subset.
# NOTE: x_var / y_var keep the values from the last iteration ('both');
# later cells reuse them.
#for test_or_control in ['both']: #['test', 'control', 'both']:
for test_or_control in ['test', 'control', 'both']:
    x_var, y_var = pn.get_concat_x_and_y_var_for_lr(test_or_control=test_or_control)
    
    results_summary = ml_methods_utils.run_segment_split_regression_cv(
        x_var, 
        y_var, 
        columns_of_interest, 
        num_folds=5, 
    )
    results_summary['test_or_control'] = test_or_control
    all_results.append(results_summary)

all_results = pd.concat(all_results)
# NOTE: not the last expression in the cell, so Jupyter does not display it.
all_results.head()




# Split the summary by model type (linear vs. logistic regression results).
reg_results = all_results[all_results['Model'] == 'Linear Regression']
class_results = all_results[all_results['Model'] == 'Logistic Regression']



# First plot only the key_features2 subset (cur_ff_distance etc.).
for metric in ['test_r2']:
    ml_methods_utils.make_barplot_to_compare_results(
            reg_results, 
            metric=metric, 
            features=key_features2,
        )
    print('='*100)
    print('='*100)



rest_of_features = [c for c in reg_results['Feature'].unique() if c not in key_features2]

# regression results
for metric in ['test_r2']:
    ml_methods_utils.make_barplot_to_compare_results(
            reg_results, 
            metric=metric, 
            features=rest_of_features,
        )
    print('='*100)
    print('='*100)
    
# classification results
for metric in ['test_roc_auc', 'test_f1']:
    ml_methods_utils.make_barplot_to_compare_results(
        class_results, 
        metric=metric, 
    )

## CatBoostRegressor

In [None]:
# Cross-validated decoding of key_features2 from x_var / y_var.
# NOTE: x_var / y_var are whatever the regression cell above left behind (its
# last loop iteration, 'both'). Folds are grouped by new_segment (presumably
# so bins from one segment never appear in both train and test).
results_df = pn_decoding_utils.run_cv_decoding(
    X=x_var,
    y_df=y_var,
    behav_features=key_features2,
    groups=y_var['new_segment'].values,
    n_splits=5,
    config=pn_decoding_utils.DecodingRunConfig(
        fast_mode=False,
        make_plots=True,
        n_jobs=-1,
        use_early_stopping=False,  # matches original behavior
    ),
    context_label='pooled',
)


# Regression on Subsets

# Num visible ff (Regression on Subsets)

(Take out subsets of data based on num_ff_visible)

In [None]:
# pn.planning_data_by_point['num_ff_visible'].unique()

In [None]:
# Pooled test + control data; add interaction terms and derived behavioral
# features to the behavioral table (this rebinds cols_added and added_cols).
x_var, y_var = pn.get_concat_x_and_y_var_for_lr(test_or_control='both')
y_var, cols_added = pn_decoding_utils.add_interaction_terms_and_features(y_var)
y_var, added_cols = pn_decoding_utils.prep_behav(y_var)

## Try one model

In [None]:
# Column whose levels the decoding is conditioned on.
ff_visibility_col = 'num_ff_visible'

config = pn_decoding_utils.DecodingRunConfig(
    fast_mode=False,
    make_plots=False,
    n_jobs=-1,
    use_early_stopping=True,
)

# Pass the conditioning column explicitly (as the model-iteration cell does).
# Previously it was omitted, so the decoder used its own default while the
# display cell indexes the results by ff_visibility_col; changing this
# variable had no effect on the decoding itself.
results_df = pn_decoding_utils.decode_by_num_ff_visible_or_in_memory(
    x_var,
    y_var,
    key_features2,
    config=config,
    ff_visibility_col=ff_visibility_col,
)


In [None]:
from IPython.display import display

# Per-visibility-level CV scores for the FIRST behavioral feature only: the
# `break` at the end stops after one table (remove it to show all).
for behav_feature in results_df['behav_feature'].unique():
    display(
        results_df[results_df['behav_feature'] == behav_feature][[ff_visibility_col, 'r2_cv', 'r_cv', 'rmse_cv', 'n_samples']]
        .sort_values(by='r2_cv', ascending=False)
        .reset_index(drop=True)
        .style
        .format(precision=3)
        .set_caption(behav_feature)
    )
    
    break


## Iterate through models

In [None]:
# Column to condition on (alternative kept commented out).
# ff_visibility_col = 'num_ff_in_memory'
ff_visibility_col = 'num_ff_visible'

In [None]:
# Repeat the visibility-conditioned decoding for every model in MODEL_SPECS and
# stack the per-model results (save_path is handed to the helper).
# NOTE: `config` is rebound on each iteration, so after the loop it holds the
# LAST model's config -- the 'Cur ff visible only' cell below reuses it.
all_results = []

save_path = os.path.join(pn.planning_and_neural_folder_path, 'pn_decoding', 'conditioned_on_ff_visibility')
for model_name, spec in pn_decoding_model_specs.MODEL_SPECS.items():
    
    config = pn_decoding_utils.DecodingRunConfig(
        model_class=spec['model_class'],
        model_kwargs=spec['model_kwargs'],
        use_early_stopping=False,
        make_plots=False,
    )

    results_df = pn_decoding_utils.decode_by_num_ff_visible_or_in_memory(
        x_var,
        y_var,
        key_features2,
        config=config,
        save_path=save_path,
        ff_visibility_col=ff_visibility_col,
        
    )

    results_df['model_name'] = model_name
    all_results.append(results_df)

all_results_df = pd.concat(all_results, ignore_index=True)


### heatmap

In [None]:
# Library heatmap of the CV scores (model x visibility level).
plot_pn_decoding.plot_decoding_heatmaps_with_n(
    all_results_df,
    ff_visibility_col)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# One heatmap per behavioral feature: rows = visibility level, columns = model,
# cells = mean CV R^2. The table is named r2_table (it used to be `df`), so
# this cell no longer clobbers the behavioral `df` used by the decoding cells.
for behav_feature in all_results_df['behav_feature'].unique():
    sub = all_results_df.query('behav_feature == @behav_feature')

    r2_table = (
        sub
        .pivot_table(
            index=ff_visibility_col,
            columns='model_name',
            values='r2_cv',
            aggfunc='mean'
        )
        .sort_index()
    )

    # ---- compute n_samples per visibility level ----
    n_per_row = (
        sub
        .groupby(ff_visibility_col)['n_samples']
        .mean()   # or .first() if guaranteed identical
        .loc[r2_table.index]
        .astype(int)
    )

    # ---- build ytick labels with sample size ----
    yticklabels = [f'{idx} ({n})' for idx, n in zip(r2_table.index, n_per_row)]

    plt.figure(figsize=(1.2 * r2_table.shape[1], 1.2 * r2_table.shape[0]))
    # Colour scale starts at 0 (negative R^2 saturates); top is at least 0.1.
    vmax = max(0.1, r2_table.max().max())

    # Labels go straight into heatmap: with seaborn's default 'auto' labelling
    # the ticks may be thinned, and overwriting them afterwards with
    # set_yticklabels would then attach labels to the wrong rows.
    ax = sns.heatmap(
        r2_table,
        annot=True,
        fmt='.3f',
        cmap='viridis',
        vmin=0,
        vmax=vmax,
        yticklabels=yticklabels,
        cbar_kws={'label': 'CV $R^2$'}
    )
    ax.tick_params(axis='y', labelrotation=0)

    plt.title(behav_feature)
    plt.ylabel(f'{ff_visibility_col} (with sample size)')
    plt.xlabel('Model')
    plt.tight_layout()
    plt.show()


### print df

In [None]:
# from IPython.display import display

# for behav_feature in all_results_df['behav_feature'].unique():
#     display(
#         all_results_df[all_results_df['behav_feature'] == behav_feature][[ff_visibility_col, 'model_name', 'r2_cv', 'r_cv', 'rmse_cv', 'n_samples']]
#         .sort_values(by='r2_cv', ascending=False)
#         .reset_index(drop=True)
#         .style
#         .format(precision=3)
#         .set_caption(behav_feature)
#     )
    
#     # break


## Cur ff visible only

In [None]:

# Decoding restricted to the current-ff-only condition (per the helper's name),
# conditioned on ff_visibility_col.
# NOTE(review): `config` is whatever was assigned last -- after the model loop
# that is the final MODEL_SPECS entry's config, not the one from 'Try one
# model'. Confirm that is intended.
results_cur_only = pn_decoding_utils.decode_cur_ff_only(
    x_var,
    y_var,
    key_features2,
    ff_visibility_col=ff_visibility_col,
    config=config,
)


In [None]:
# Stack with the per-model results so both appear in the same heatmap.
all_results_df2 = pd.concat(
    [all_results_df, results_cur_only],
    ignore_index=True,
)


In [None]:
ff_visibility_col

In [None]:
plot_pn_decoding.plot_decoding_heatmaps_with_n(
    all_results_df2,
    ff_visibility_col)

# Point-wise regressions

## point-wise segment regressions

In [None]:
# Time-resolved CV scores from the GPFA pipeline; exists_ok=False presumably
# forces recomputation instead of loading saved scores.
pn.retrieve_or_make_time_resolved_cv_scores_gpfa(latent_dimensionality=7, exists_ok=False)

In [None]:
pn.plot_time_resolved_regression(time_resolved_cv_scores = pn.time_resolved_cv_scores_gpfa)

In [None]:
stop!

In [None]:
# Expose the training-trial count under the 'trial_count' name that the
# trial-count plot below reads.
pn.time_resolved_cv_scores_gpfa['trial_count'] = pn.time_resolved_cv_scores_gpfa['train_trial_count'].astype(int)
# features_to_plot = None
features_to_plot=['time_rel_to_stop', 'cur_ff_distance', 'cur_ff_distance_at_ref', 'time_since_last_capture']
pn.plot_time_resolved_regression(time_resolved_cv_scores = pn.time_resolved_cv_scores_gpfa, features_to_plot=features_to_plot)

In [None]:
plot_time_resolved_regression.plot_trial_counts_by_timepoint(
            pn.time_resolved_cv_scores_gpfa, 'trial_count')

In [None]:
# Pairwise correlations among the plotted behavioral features.
pn.concat_behav_trials[features_to_plot].corr()

## point-wise segment regression (for ppt)

In [None]:
pn.make_time_resolved_cv_scores()

In [None]:

features_to_plot = [
'time', 'time_rel_to_stop',
'target_distance',
'target_angle',
'target_rel_x',
'target_rel_y',
'speed',
'stop']

# Display 'monkey_speeddummy' under the shorter name 'stop'.
pn.time_resolved_cv_scores.loc[pn.time_resolved_cv_scores['feature'] == 'monkey_speeddummy', 'feature'] = 'stop'
pn.plot_time_resolved_regression(features_to_plot=features_to_plot, n_behaviors_per_plot=8)



In [None]:
# Same rename as above (a no-op once applied), so this cell can run on its own.
pn.time_resolved_cv_scores.loc[pn.time_resolved_cv_scores['feature'] == 'monkey_speeddummy', 'feature'] = 'stop'
# Plot related features in pairs.
for features in [['target_distance', 'target_rel_y'],
                 ['target_rel_x', 'target_angle'],
                 ['time', 'time_rel_to_stop'],
                 ['speed', 'stop']]:
    
    pn.plot_time_resolved_regression(features_to_plot=features)



In [None]:
pn.plot_trial_counts_by_timepoint()  # trials contributing at each time point

## point-wise segment regression

In [None]:
# Segment-aligned data plus cached 7-D GPFA latents as the neural features for
# the time-resolved regression below (presumably the "GPFA outside the CV
# loop" variant compared later in this section).
pn.prepare_seg_aligned_data()
pn.get_gpfa_traj(latent_dimensionality=7, exists_ok=True)

use_raw_spike_data_instead = False
use_lagged_rebinned_behav_data = False
pn.get_concat_data_for_regression(use_raw_spike_data_instead=use_raw_spike_data_instead,
                                  use_lagged_rebinned_behav_data=use_lagged_rebinned_behav_data,
                                  apply_pca_on_raw_spike_data=True,
                                  use_lagged_raw_spike_data=False,) 

In [None]:
pn.retrieve_or_make_time_resolved_cv_scores()

### plot some

In [None]:
pn.plot_time_resolved_regression(features_to_plot=['time_rel_to_stop', 'cur_ff_distance', 'cur_ff_distance_at_ref', 'time_since_last_capture'])

### plot all

In [None]:
pn.plot_time_resolved_regression()

## compare cv score: GPFA Inside vs Outside CV Loop


### one feature

In [None]:
# Put one feature's scores side by side: pn.time_resolved_cv_scores vs.
# pn.time_resolved_cv_scores_gpfa (suffix _cv_w_gpfa). pd.concat(axis=1) joins
# on the row index, so both tables must share the same index.
feature = 'event_time'
new_cv_scores = pd.concat([pn.time_resolved_cv_scores[['new_bin', 'bin_mid_time', 'trial_count', feature]], 
                           pn.time_resolved_cv_scores_gpfa[[feature]].rename(columns={feature: f'{feature}_cv_w_gpfa'})], axis=1)
pn.plot_time_resolved_regression(time_resolved_cv_scores=new_cv_scores, score_threshold_to_plot=None,
                                 rank_by_max_score=False)

### all features

In [None]:
# Repeat the comparison for every feature, ordered by peak score over time.
# NOTE(review): this treats pn.time_resolved_cv_scores as wide (one column per
# feature), whereas the ppt cell above indexes a 'feature' column (long
# format) -- confirm which layout the current code produces.
ranked_features = pn.time_resolved_cv_scores.max().sort_values(ascending=False).index.values
features_not_to_plot = ['new_bin', 'new_seg_duration', 'trial_count', 'bin_mid_time']
ranked_features = [feature for feature in ranked_features if feature not in features_not_to_plot]
for feature in ranked_features:
    print(feature)
    print('='*100)
    new_cv_scores = pd.concat([pn.time_resolved_cv_scores[['new_bin', 'bin_mid_time', 'trial_count', feature]], 
                            pn.time_resolved_cv_scores_gpfa[[feature]].rename(columns={feature: f'{feature}_cv_w_gpfa'})], axis=1)
    pn.plot_time_resolved_regression(time_resolved_cv_scores=new_cv_scores, score_threshold_to_plot=None,
                                        rank_by_max_score=False)

# Others

## trial count per time point

In [None]:
pn.plot_trial_counts_by_timepoint()

## plot latent dimensions

In [None]:
raw_data_folder_path

In [None]:
plot_gpfa_utils.plot_gpfa_traj_3d_timecolored_average(pn.trajectories)

In [None]:
plot_gpfa_utils.plot_gpfa_traj_3d_uniform_color(pn.trajectories)

In [None]:
# Use the static inline backend. Note: `%matplotlib inline` is NOT interactive;
# a rotatable 3-D view would need a widget backend (e.g. ipympl).
%matplotlib inline

# Import required modules (the Axes3D import registers the '3d' projection on
# older matplotlib versions; cm is not used in this cell)
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# 3-D plot of single-trial and trial-averaged GPFA latent trajectories
fig, ax = plot_gpfa_utils.plot_gpfa_traj_3d(
    trajectories=pn.trajectories,
    figsize=(15, 5),
    linewidth_single_trial=0.75,
    alpha_single_trial=0.3,
    linewidth_trial_average=2,
    title='Latent dynamics extracted by GPFA',
    view_azim=-5,
    view_elev=60
)

plt.show()

In [None]:
# fig = plot_gpfa_utils.plot_gpfa_traj_3d_plotly(trajectories)

In [None]:
# Share of total latent variance carried by each GPFA dimension.
# Each trajectory is (n_latent_dims, n_bins) -- the old comment said 3 dims,
# but the fits above use latent_dimensionality=7. Concatenating along the time
# axis pools every trial's bins per dimension; for equal-length trials this
# equals the variance over (trials, time) of a stacked array, but unlike
# np.stack it also works when trials have different numbers of bins.
pooled = np.concatenate([np.asarray(traj) for traj in pn.trajectories], axis=1)
var_by_dim = np.var(pooled, axis=1)
var_by_dim /= var_by_dim.sum()  # normalize to fractions of the total
print("Variance explained by each latent dimension:", var_by_dim)

In [None]:

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

# Plot each latent dimension of the trial-averaged GPFA trajectory over time.
fig, ax = plt.subplots(figsize=(15, 5))

ax.set_title('Latent dynamics extracted by GPFA')
ax.set_xlabel('Time [s]')

# np.mean over the list of trajectories only works if every trial has the
# same number of bins. The time axis starts at 0 at the first bin (it is not
# shifted to the event).
average_trajectory = np.mean(pn.trajectories, axis=0)
time = np.arange(len(average_trajectory[0])) * pn.bin_width  # assuming all trajectories have the same length

for i, x in enumerate(average_trajectory):
    ax.plot(time, x, label=f'Dim {i+1}')

ax.legend()

plt.tight_layout()
plt.show()

## check corr between vars

In [None]:
# How strongly do two bookkeeping columns co-vary (segment index vs. target
# index)? r is printed to 16 decimals; p is computed but not shown.
feature1 = 'new_segment'
feature2 = 'target_index'

from scipy.stats import pearsonr

x = pn.concat_behav_trials[feature1].values
y = pn.concat_behav_trials[feature2].values

r, p = pearsonr(x, y)
print(f"Pearson r: {r:.16f}")

In [None]:
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
import numpy as np

# Scatter feature2 against feature1 (both set in the previous cell) with a
# least-squares line. Labels now come from the feature names; they were
# hard-coded as "Behavioral Time" / "Neural Event Time", which did not match
# the columns actually plotted (new_segment vs. target_index).
x = pn.concat_behav_trials[feature1].values.reshape(-1, 1)  # sklearn expects 2-D X
y = pn.concat_behav_trials[feature2].values                 # y can remain 1-D

# Fit the linear regression model
model = LinearRegression()
model.fit(x, y)

# Predict y values for the regression line
y_pred = model.predict(x)

# Plot scatter and regression line
plt.figure(figsize=(8, 6))
plt.scatter(x, y, color='blue', label='Data points', alpha=0.6)
plt.plot(x, y_pred, color='red', linewidth=2, label='Regression line')
plt.xlabel(feature1)
plt.ylabel(feature2)
plt.title(f'Linear regression of {feature2} on {feature1}')
plt.legend()
plt.grid(True)
plt.show()

## why poor performance?

In [None]:
stop!

In [None]:
# Diagnostics for poor decoding: trials per time point, NaN checks,
# standardization, and a visual check of a few trials.
# NOTE(review): this rebinds `time_resolved_regression` to the gpfa_methods
# module, shadowing the align_trials module of the same name imported at the
# top of the notebook; later uses of that name get this module.
import neural_data_analysis.neural_analysis_tools.gpfa_methods.time_resolved_regression as time_resolved_regression

# 1. Print number of trials per timepoint
time_resolved_regression.print_trials_per_timepoint(pn.gpfa_neural_trials)

# 2. Check for NaNs
time_resolved_regression.check_for_nans_in_trials(pn.gpfa_neural_trials, name='latent')
time_resolved_regression.check_for_nans_in_trials(pn.behav_trials, name='behavioral')

# 3. Standardize trials
latent_trials_std = time_resolved_regression.standardize_trials(pn.gpfa_neural_trials)
behav_trials_std = time_resolved_regression.standardize_trials(pn.behav_trials)

# 4. Plot latent and behavioral variables for a few trials
time_resolved_regression.plot_latents_and_behav_trials(latent_trials_std, behav_trials_std, pn.bin_width, n_trials=5)

# why time prediction is good

# Appendix

## see rel_cur_ff_first_seen_time distribution

In [None]:
# Distributions of when the current ff was first / last seen relative to the
# stop. Rows are de-duplicated on the (first, last) pair, presumably so each
# segment's values count once. Both histograms are drawn opaque on the same
# axes, so the second can hide the first where they overlap.
rel_seen_time_df = pn.planning_data_by_point[['rel_cur_ff_first_seen_time_bbas', 'rel_cur_ff_last_seen_time_bbas']].drop_duplicates().reset_index(drop=True)
sns.histplot(rel_seen_time_df['rel_cur_ff_first_seen_time_bbas'], bins=50, label='cur ff first seen')
sns.histplot(rel_seen_time_df['rel_cur_ff_last_seen_time_bbas'], bins=50, label='cur ff last seen')
plt.xlabel('Time relative to stop (s)')
plt.ylabel('Count')
plt.title('Time relative to stop')
plt.legend()
plt.show()

## debug inconsistent number of new_bins

In [None]:
# example trajectories: print each trial's GPFA trajectory shape to spot
# trials with a different number of bins
for traj in pn.trajectories[:]:
    print(traj.shape)

In [None]:
# Same check for the behavioral trials.
for traj in pn.behav_trials[:]:
    print(traj.shape)

In [None]:
# Segments with fewer than min_bins_per_segment rebinned rows (to debug the
# inconsistent number of new_bins). Per-segment sizes are computed once;
# previously the same groupby ran twice.
min_bins_per_segment = 10
seg_sizes = pn.rebinned_behav_data.groupby('new_segment').size()
segments = seg_sizes[seg_sizes < min_bins_per_segment].index

In [None]:
# Inspect the rows belonging to those short segments.
pn.rebinned_behav_data[pn.rebinned_behav_data['new_segment'].isin(segments)]

## exp: to match (new_segment, new_bin) tuples

In [None]:
# Reload edited modules (autoreload is also enabled at the top of the notebook).
reload(pn_aligned_by_event)
reload(gpfa_helper_class)

In [None]:
# Experimental narrow window: 1 to 1.25 after the event (units presumably seconds).
pn.prepare_seg_aligned_data(cur_or_nxt='cur', first_or_last='last', time_limit_to_count_sighting=2,
                              start_t_rel_event=1, end_t_rel_event=1.25, rebinned_max_x_lag_number=2)

In [None]:
pn.rebinned_y_var.columns

In [None]:
pn.rebinned_y_var['bin_mid_time_rel_to_event'].unique()

In [None]:
# Largest per-segment minimum new_bin in y (this cell) and x (next cell); if
# the two differ, the tables start their segments at different bins.
pn.rebinned_y_var.groupby('new_segment').min()['new_bin'].max()

In [None]:
pn.rebinned_x_var.groupby('new_segment').min()['new_bin'].max()

In [None]:
pn.rebinned_y_var.groupby('new_segment').size()

In [None]:
pn.rebinned_y_var

In [None]:
# example trajectories
for traj in pn.trajectories[:]:
    print(traj.shape)

In [None]:
pn.gpfa_neural_trials[0].shape

In [None]:
pn.get_gpfa_traj(latent_dimensionality=7, exists_ok=False)

In [None]:
# for regression later
# NOTE: this variable is unused; the call below passes literals. With
# apply_pca_on_raw_spike_data=False, num_pca_components presumably has no effect.
use_raw_spike_data_instead = False

pn.get_concat_data_for_regression(use_raw_spike_data_instead=False,
                                    use_lagged_raw_spike_data=False,
                                    apply_pca_on_raw_spike_data=False,
                                    num_pca_components=7)


pn.print_data_dimensions()

In [None]:
import pandas as pd
import numpy as np

# Compare the (new_segment, new_bin) keys of the behavioural and neural tables.
# Assumes pn.concat_behav_trials and pn.concat_neural_trials are DataFrames.
key_columns = ['new_segment', 'new_bin']
behav_set = {tuple(row) for row in pn.concat_behav_trials[key_columns].values}
neural_set = {tuple(row) for row in pn.concat_neural_trials[key_columns].values}

# Keys present only in the behavioural data (diff) / only in the neural data (diff2)
diff = behav_set.difference(neural_set)
diff2 = neural_set.difference(behav_set)

In [None]:
# (new_segment, new_bin) pairs present in the neural data but not in the behavioural data
diff2

In [None]:
# Just the segment/bin key columns of the rebinned behavioural data
df = pn.rebinned_behav_data.loc[:, ['new_segment', 'new_bin']]
df

In [None]:
# Display the concatenated neural trials table
pn.concat_neural_trials

In [None]:
# Display the concatenated behavioural trials table
pn.concat_behav_trials

In [None]:
# Shape of the unique (new_segment, new_bin) pairs in the neural data;
# compare its row count with the full table's to spot duplicated keys
pn.concat_neural_trials[['new_segment', 'new_bin']].drop_duplicates().shape

In [None]:
# example trajectories: shapes of the first five entries only
preview_trajs = pn.trajectories[:5]
for traj in preview_trajs:
    print(traj.shape)

In [None]:
# Shapes of the first five behavioural trials
preview_behav_trials = pn.behav_trials[:5]
for traj in preview_behav_trials:
    print(traj.shape)

In [None]:
# Rows of the rebinned behavioural data that belong to new_segment 45
in_segment_45 = pn.rebinned_behav_data['new_segment'].isin([45])
df = pn.rebinned_behav_data[in_segment_45]
df

In [None]:
# Inspect rows by index label 700 through 730 (.loc slicing includes both endpoints)
pn.rebinned_behav_data.loc[700:730]

## Point-wise (per-bin) regression on one variable

In [None]:
from contextlib import contextmanager
import joblib
from tqdm import tqdm
from joblib import Parallel, delayed
import sys
from data_wrangling import process_monkey_information, specific_utils, further_processing_class, specific_utils, general_utils
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from pattern_discovery import pattern_by_trials, pattern_by_points, make_ff_dataframe, ff_dataframe_utils, pattern_by_trials, pattern_by_points, cluster_analysis, organize_patterns_and_features, category_class
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from null_behaviors import curvature_utils, curv_of_traj_utils
import warnings
import os
import sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import math
import seaborn as sns
import logging
from matplotlib import rc
from os.path import exists
from statsmodels.stats.outliers_influence import variance_inflation_factor
from elephant.gpfa import GPFA

import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import KFold, cross_val_score
from sklearn.preprocessing import StandardScaler


In [None]:
# x_cols = [col for col in pn.concat_neural_trials.columns if col.startswith('dim_')]
# x_df = pn.concat_neural_trials[x_cols].copy()

In [None]:
import statsmodels.api as sm
from sklearn.model_selection import train_test_split

# Point-wise regression: for one new_bin, regress a behavioural column on the neural
# 'dim_*' columns of the rows that share the same (new_segment, new_bin) key.
key_cols = ['new_segment', 'new_bin']
target_col = 'segment'  # NOTE(review): confirm this is the intended behavioural target
neural_dim_cols = [col for col in pn.concat_neural_trials.columns if col.startswith('dim_')]

for new_bin in pn.concat_neural_trials['new_bin'].unique():
    neural_bin = pn.concat_neural_trials.loc[
        pn.concat_neural_trials['new_bin'] == new_bin, key_cols + neural_dim_cols]
    behav_bin = pn.concat_behav_trials.loc[
        pn.concat_behav_trials['new_bin'] == new_bin, key_cols + [target_col]]

    # Pair neural and behavioural rows by key, not by position: the two tables are
    # filtered independently and need not hold the same keys in the same order, and
    # statsmodels rejects pandas endog/exog whose indices are not aligned.
    merged = neural_bin.merge(behav_bin, on=key_cols, how='inner')
    if merged.empty:
        print(f"new_bin {new_bin}: no matching (new_segment, new_bin) rows; skipped")
        continue

    # Add intercept
    x_df2 = sm.add_constant(merged[neural_dim_cols])
    y_df2 = merged[[target_col]].copy()

    # Train-test split (x_df2 and y_df2 share one index, so the split stays paired)
    x_train, x_test, y_train, y_test = train_test_split(
        x_df2, y_df2, test_size=0.2, random_state=42
    )

    # Fit OLS
    model = sm.OLS(y_train, x_train)
    results = model.fit()

    print(results.summary())

    # Evaluate on test set; y_test and y_pred carry the same index, so arithmetic aligns.
    # Select the column explicitly (squeeze() would collapse a one-row test set to a scalar).
    y_pred = results.predict(x_test)
    y_test_flat = y_test[target_col]
    r2_test = 1 - ((y_test_flat - y_pred) ** 2).sum() / ((y_test_flat - y_test_flat.mean()) ** 2).sum()
    print(f"Test R² score: {r2_test:.4f}")

    # Create a comparison DataFrame
    comparison_df = pd.DataFrame({
        'y_test': y_test_flat.to_numpy(),  # Ground truth
        'y_pred': y_pred.to_numpy(),       # Model predictions
    })

    print(comparison_df.head(10))  # Show the first 10 rows

    # Debugging aid: only the first bin with matching rows is inspected
    break

In [None]:
import numpy as np
from scipy.stats import pearsonr

# Compare the held-out targets (y_test) with the predictions (y_pred) left by the
# regression cell. Convert both to plain 1-D float arrays so the arithmetic is
# positional (no pandas index alignment) and a one-row test set still works.
y_test_flat = np.asarray(y_test, dtype=float).ravel()
y_pred_flat = np.asarray(y_pred, dtype=float).ravel()

# Coefficient of determination; undefined (NaN) when y_test has zero variance
ss_res = ((y_test_flat - y_pred_flat) ** 2).sum()
ss_tot = ((y_test_flat - y_test_flat.mean()) ** 2).sum()
r2_test = 1 - ss_res / ss_tot if ss_tot > 0 else np.nan

# Compute Pearson correlation coefficient (R)
if len(np.unique(y_test_flat)) > 1:
    r_test = np.corrcoef(y_test_flat, y_pred_flat)[0, 1]
    # Or alternatively: r_test, _ = pearsonr(y_test_flat, y_pred_flat)
else:
    r_test = np.nan  # Correlation is undefined when y is constant

print(f"Test R² score: {r2_test:.4f}")
print(f"Test R (Pearson correlation): {r_test:.4f}")

In [None]:
# Ridge regression with built-in alpha search, scored by K-fold cross-validated R²
# on the x_df2 / y_df2 left by the point-wise regression cell.
alphas = np.logspace(-6, 6, 13)
# cv_folds is otherwise only assigned in a later cell; fall back to 5 if it is not defined yet
cv_folds = globals().get('cv_folds', 5)
kf = KFold(n_splits=cv_folds, shuffle=True, random_state=42)


model = RidgeCV(alphas=alphas, fit_intercept=True)
try:
    score = cross_val_score(
        model, x_df2, y_df2.values.ravel(), cv=kf, scoring='r2', n_jobs=1)
except Exception as err:  # best-effort diagnostic cell: report instead of silently ignoring
    print(f"Ridge cross-validation failed: {type(err).__name__}: {err}")
else:
    print(score.mean())

## Linear regression (no cross-validation)

In [None]:
# Multivariate linear regression of the behavioural variables on the neural data;
# the resulting table is stored on pn as y_var_lr_df
neural_trials_df, behav_trials_df = pn.concat_neural_trials, pn.concat_behav_trials
pn.y_var_lr_df = neural_data_modeling.get_y_var_lr_df(neural_trials_df, behav_trials_df)

In [None]:
# Preview the first 5 rows of pn.y_var_lr_df
pn.y_var_lr_df.head(5)

## Linear regression on individual variables (one target column at a time)

In [None]:
# Only the 'both' condition is analysed here
test_or_control = 'both'
x_var, y_var = pn.get_concat_x_and_y_var_for_lr(test_or_control=test_or_control)

# Keep just the two behavioural targets of interest
target_columns = ['time_rel_to_stop', 'time_since_last_capture']
y_var = y_var.loc[:, target_columns].copy()

# Single 80/20 train-test split shared by every target column
X_train, X_test, y_train, y_test = train_test_split(x_var, y_var, test_size=0.2, random_state=42)

# Fit and report a separate linear regression for each target column
for y_var_column in y_var.columns:
    lr_outputs = regression_utils.use_linear_regression(
        X_train, X_test, y_train[y_var_column], y_test[y_var_column],
        show_plot=True, y_var_name=y_var_column)
    summary_df, y_pred, results, r2_test = lr_outputs
    for item in (summary_df, y_pred, results, r2_test):
        print(item)

## Manually save the time-resolved CV scores (scores_by_time_full_cv) to CSV

In [None]:
# Build the output path for the time-resolved CV scores and save them as CSV.
latent_dimensionality = 7
cv_folds = 5
# Filename-safe bin width, e.g. 0.02 -> "0.0200" -> "0.02" -> "0p02"
bin_width_str = f"{pn.bin_width:.4f}".rstrip(
    '0').rstrip('.').replace('.', 'p')
file_name = f'scores_bin{bin_width_str}_{pn.cur_or_nxt}_{pn.first_or_last}_d{latent_dimensionality}_cv{cv_folds}.csv'


time_resolved_cv_scores_gpfa_folder_path = os.path.join(
                    pn.gpfa_data_folder_path, "time_resolved_cv_scores_gpfa")
os.makedirs(time_resolved_cv_scores_gpfa_folder_path, exist_ok=True)
time_resolved_cv_scores_path = os.path.join(
            time_resolved_cv_scores_gpfa_folder_path, file_name)

# time_resolved_cv_scores_gpfa is otherwise only assigned in a later cell (from
# time_resolved_cv_scores); do that here if it has not happened yet.
if 'time_resolved_cv_scores_gpfa' not in globals():
    time_resolved_cv_scores_gpfa = time_resolved_cv_scores

time_resolved_cv_scores_gpfa.to_csv(time_resolved_cv_scores_path)

In [None]:
# Full path of the scores CSV written by the save cell
time_resolved_cv_scores_path

In [None]:
# Expose the CV scores under the name the save cell reads.
# NOTE(review): the save cell earlier in this notebook uses time_resolved_cv_scores_gpfa,
# so this assignment has to have run before it.
time_resolved_cv_scores_gpfa = time_resolved_cv_scores

In [None]:
# Generated CSV file name (encodes bin width, cur/nxt, first/last, latent dims and CV folds)
file_name