# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
# Walk upward from the current directory to the nearest 'Multifirefly-Project'
# folder. If one is found, make it the working directory (so the relative data
# paths below resolve) and prepend the methods package to sys.path (so the
# project imports below resolve). If none is found, nothing is changed.
_candidate_dirs = [Path.cwd(), *Path.cwd().parents]
_project_root = next(
    (d for d in _candidate_dirs if d.name == 'Multifirefly-Project'),
    None,
)
if _project_root is not None:
    os.chdir(_project_root)
    sys.path.insert(0, str(_project_root / 'multiff_analysis/multiff_code/methods'))

from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, ml_methods_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils
from neural_data_analysis.design_kits.design_by_segment import create_design_df, predictor_utils, other_feats
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event, pn_glm_utils
from neural_data_analysis.neural_analysis_tools.glm_tools.tpg import glm_bases, glm_plotting, glm_plotting2, glm_fit

from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_psth import core_stops_psth, psth_postprocessing, psth_stats, compare_events, dpca_utils
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_fit import stop_glm_fit, cv_stop_glm, glm_fit_utils, variance_explained
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_plotting import plot_spikes, plot_glm_fit, plot_tuning_func, compare_glm_fit
from neural_data_analysis.design_kits.design_around_event import event_binning, stop_design, cluster_design, design_checks
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_hyperparams import compare_glm_configs, glm_hyperparams_class
from neural_data_analysis.design_kits.design_by_segment import spike_history, rebin_segments
from neural_data_analysis.topic_based_neural_analysis.full_session import create_full_session_design
from neural_data_analysis.design_kits.design_by_segment import temporal_feats
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.get_stop_events import get_stops_utils
from neural_data_analysis.topic_based_neural_analysis.full_session import selected_raw_data_features, selected_pn_design_features, selected_stop_design_features

import sys
import math
import gc
import subprocess
from pathlib import Path
from importlib import reload
import json

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
from numpy import pi

from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve
# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# Plotting / display settings for this notebook.
# Reset Matplotlib to its defaults FIRST. The reset used to run after the
# animation settings and silently wiped them. rcdefaults() is the documented
# way to do this reset; per its docs it leaves style-blacklisted keys (e.g. the
# backend) untouched, unlike updating rcParams with rcParamsDefault.
matplotlib.rcdefaults()
# Render Matplotlib animations as interactive JavaScript in notebooks. An
# earlier "html5" setting was immediately overridden by this, so it is dropped.
rc('animation', html='jshtml')
# animation.embed_limit is in MB; 2**128 effectively removes the size cap.
matplotlib.rcParams['animation.embed_limit'] = 2**128
# Allow a duplicate Intel OpenMP runtime to be loaded instead of aborting.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
print("done")

# Retrieve data

## Load data

In [None]:
# Session analysed by this notebook: a raw-data folder path relative to the
# working directory (the first cell switches to the project root when found).
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330"

In [None]:
# Option flags. Only planning_data_by_point_exists_ok is used in this cell;
# y_data_exists_ok / reduce_y_var_lags feed the currently disabled step
# pn.get_x_and_y_data_for_modeling(exists_ok=..., reduce_y_var_lags=...).
planning_data_by_point_exists_ok = True
y_data_exists_ok = True
reduce_y_var_lags = False

# Event-aligned planning/neural helper for the chosen session.
pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(
    raw_data_folder_path=raw_data_folder_path,
)
pn.prep_data_to_analyze_planning(
    planning_data_by_point_exists_ok=planning_data_by_point_exists_ok,
)

# Re-bin into new event-aligned segments spanning 0 to 1.5 relative to the
# event (argument semantics are defined in pn_aligned_by_event).
pn.rebin_data_in_new_segments(
    cur_or_nxt='cur',
    first_or_last='first',
    time_limit_to_count_sighting=2,
    start_t_rel_event=0,
    end_t_rel_event=1.5,
    rebinned_max_x_lag_number=2,
)


# Get full-session design DataFrame

In [None]:
# Display the monkey_information table of the loaded session.
pn.monkey_information

In [None]:
# Attach firefly visibility / in-memory info (e.g. num_ff_visible) to the
# monkey data, using the (created or retrieved) firefly dataframe.
pn.make_or_retrieve_ff_dataframe()
pn.monkey_information = pn_utils.add_ff_visible_or_in_memory_info_by_point(
    pn.monkey_information, pn.ff_dataframe
)
dt = pn.bin_width

# One segment (id 0) for the whole session: from 1 time unit before the first
# capture (clipped at 0) up to the last capture.
_first_capture = pn.ff_caught_T_sorted.min()
_seg_start = max(0, _first_capture - 1)
_seg_end = pn.ff_caught_T_sorted.max()
new_seg_info = pd.DataFrame(
    {
        'new_segment': 0,
        'new_seg_start_time': _seg_start,
        'new_seg_end_time': _seg_end,
        'new_seg_duration': _seg_end - _seg_start,
    },
    index=[0],
)

# Re-bin the monkey data over that single segment at the session bin width.
rebinned_monkey_data = rebin_segments.rebin_all_segments_local_bins(
    pn.monkey_information,
    new_seg_info,
    bin_width=dt,
    respect_old_segment=False,
    add_bin_edges=True,
)

# Every bin belongs to trial 0 in the full-session view.
trial_ids = np.full(len(rebinned_monkey_data), 0)
rebinned_monkey_data = temporal_feats.add_stop_and_capture_columns(
    rebinned_monkey_data, trial_ids, pn.ff_caught_T_new
)

fs_design_df, meta0, meta = create_full_session_design.get_initial_full_session_design_df(
    rebinned_monkey_data, dt, trial_ids
)
fs_meta_groups = meta['groups']

# Relevant planning (PN) design features

In [None]:
# (left, right) edges of every full-session bin. This grid is shared by the
# PN, stop and spike blocks so their rows can be joined by bin.
global_bins_2d = rebinned_monkey_data[['bin_left', 'bin_right']].to_numpy()

# Average the per-point planning data into those global bins (option
# semantics are defined in pn_utils.rebin_all_segments_global_bins).
_global_rebin_options = {
    'how': 'mean',
    'respect_old_segment': True,
    'require_full_bin': True,
    'add_bin_edges': True,
    'add_support_duration': True,
}
pn_df = pn_utils.rebin_all_segments_global_bins(
    pn.planning_data_by_point,
    pn.new_seg_info,
    bins_2d=global_bins_2d,
    **_global_rebin_options,
)

In [None]:


# PN design matrix: add stop/capture indicators grouped by each row's segment
# id, then expand into the initial design with its column groups.
dt = pn.bin_width
trial_ids = pn_df['new_segment']
pn_df = temporal_feats.add_stop_and_capture_columns(
    pn_df, trial_ids, pn.ff_caught_T_new
)
pn_design_df, pn_meta0, pn_meta = create_design_df.get_initial_design_df(
    pn_df, dt, trial_ids
)
pn_meta_groups = pn_meta['groups']


# Stop design features

In [None]:
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.get_stop_events import assemble_stop_design, collect_stop_data

In [None]:

# Stop-event design for the SAME session as the designs above.
# NOTE: this cell used to reassign raw_data_folder_path to another session
# ("data_0301"; "data_0327" was commented out). fs_design_df, pn_df and
# global_bins_2d were built from the session chosen under "get data", so the
# stop features (and the spikes binned later from pn.spikes_df) would have been
# joined to that design bin-by-bin across different sessions. Change the
# session only in the "get data" cell.
pn, datasets, comparisons = collect_stop_data.collect_stop_data_func(
    raw_data_folder_path)

# Expose every entry of `datasets` as a notebook global (kept as-is: later
# cells may rely on names injected here).
globals().update(datasets)

captures_df, valid_captures_df, filtered_no_capture_stops_df, stops_with_stats = (
    get_stops_utils.prepare_no_capture_and_captures(
        monkey_information=pn.monkey_information,
        closest_stop_to_capture_df=pn.closest_stop_to_capture_df,
        ff_caught_T_new=pn.ff_caught_T_new,
        distance_col="distance_from_ff_to_stop",
    )
)

# Timing columns: each stop's start, the previous row's end and the next row's
# start (assumes rows are in time order - TODO confirm).
stops_with_stats['stop_time'] = stops_with_stats['stop_id_start_time']
stops_with_stats['prev_time'] = stops_with_stats['stop_id_end_time'].shift(1)
stops_with_stats['next_time'] = stops_with_stats['stop_id_start_time'].shift(-1)

# Kept under its own name. Assigning to `new_seg_info` would overwrite the
# full-session table (one segment, id 0) that the spike-history cell needs,
# because that cell's bin_info places every bin in segment 0.
# NOTE(review): if `datasets` has a 'new_seg_info' key, the globals().update
# above still overwrites the full-session table - check.
stop_new_seg_info = event_binning.make_new_seg_info_for_stop_design(
    stops_with_stats, pn.closest_stop_to_capture_df, pn.monkey_information)

# Generic event table (stop_* columns renamed to event_*) for build_stop_design.
events_with_stats = stops_with_stats[
    ['stop_id', 'stop_cluster_id', 'stop_id_start_time', 'stop_id_end_time']
].rename(columns={
    'stop_id': 'event_id',
    'stop_cluster_id': 'event_cluster_id',
    'stop_id_start_time': 'event_id_start_time',
    'stop_id_end_time': 'event_id_end_time',
})

# Bin the stop features on the shared global grid so rows can be joined with
# the other design blocks by global bin.
(stop_binned_spikes, stop_binned_feats, stop_offset_log,
 stop_meta_used, stop_meta_groups) = assemble_stop_design.build_stop_design(
    stop_new_seg_info, events_with_stats,
    pn.monkey_information,
    pn.spikes_df, pn.ff_dataframe,
    datasets=datasets,
    add_ff_visible_info=True,
    global_bins_2d=global_bins_2d,
)
stop_binned_feats = assemble_stop_design.add_interaction_columns(stop_binned_feats)
stop_binned_feats_sc = assemble_stop_design.scale_binned_feats(stop_binned_feats)

stop_binned_feats_sc['bin'] = stop_meta_used['global_bin']




# Combine design blocks

In [None]:
def merge_design_blocks(fs_df, pn_df, stop_df):
    """Left-join the PN and stop design blocks onto the full-session block by 'bin'.

    Every frame must have a 'bin' column identifying the global time bin of
    each row. The full-session (FS) columns keep their names. A PN column whose
    name clashes with an FS column gets the suffix '_pn'. A stop column whose
    name clashes with any earlier column gets '_stop'.

    Parameters
    ----------
    fs_df : pandas.DataFrame
        Full-session design; defines the output rows.
    pn_df, stop_df : pandas.DataFrame
        Design blocks that may cover only some of the bins.

    Returns
    -------
    pandas.DataFrame
        One row per row of ``fs_df``, sorted by 'bin', with a fresh RangeIndex.
        Every NaN is replaced by 0.0. This covers bins missing from
        ``pn_df``/``stop_df`` and also NaNs already present in ``fs_df``.

    Raises
    ------
    pandas.errors.MergeError
        If 'bin' is not unique within a block. A duplicated bin would otherwise
        silently duplicate design rows and misalign them with the per-bin
        spike counts.
    """
    fs_cols = set(fs_df.columns) - {'bin'}
    pn_cols = set(pn_df.columns) - {'bin'}
    stop_cols = set(stop_df.columns) - {'bin'}

    # Report name clashes between every pair of blocks (clashing columns are
    # kept, but suffixed as described above).
    for label, shared in (
        ('FS–PN', fs_cols & pn_cols),
        ('FS–STOP', fs_cols & stop_cols),
        ('PN–STOP', pn_cols & stop_cols),
    ):
        print(f'Duplicated {label} columns ({len(shared)}):')
        print(sorted(shared))

    return (
        fs_df
            .merge(pn_df, on='bin', how='left', suffixes=('', '_pn'),
                   validate='one_to_one')
            .merge(stop_df, on='bin', how='left', suffixes=('', '_stop'),
                   validate='one_to_one')
            .fillna(0.0)
            .sort_values('bin')
            .reset_index(drop=True)
    )


In [None]:

# Tag each design block with its global bin index, then join the PN and stop
# blocks onto the full-session (FS) block by that index.

# FS block: tagged in place.
fs_design_df['bin'] = rebinned_monkey_data['new_bin']

# PN block: only the selected predictors, plus the bin index.
pn_design_df_sub = pn_design_df[
    selected_pn_design_features.pn_design_predictors
].assign(bin=pn_df['new_bin'])

# Stop block: only the selected (scaled) predictors, plus the bin index.
stop_design_df_sub = stop_binned_feats_sc[
    selected_stop_design_features.stop_design_predictors
].assign(bin=stop_meta_used['global_bin'])

merged_design_df = merge_design_blocks(
    fs_design_df,
    pn_design_df_sub,
    stop_design_df_sub,
)

# Column groups of all blocks; on a name clash, PN overrides stop, which
# overrides FS.
merged_meta_groups = dict(fs_meta_groups)
merged_meta_groups.update(stop_meta_groups)
merged_meta_groups.update(pn_meta_groups)



In [None]:
# Spike counts per cluster on the shared global bin grid; one column per
# cluster id (rows presumably one per global bin).
spike_counts, cluster_ids = event_binning.bin_spikes_by_cluster(
    pn.spikes_df,
    global_bins_2d,
    time_col='time',
    cluster_col='cluster',
)
_spike_count_table = pd.DataFrame(spike_counts, columns=cluster_ids)
binned_spikes = _spike_count_table.reset_index(drop=True)


# Get X_pruned (pruned design matrix)

In [None]:
# The columns kept by design_checks.check_design are cached as JSON per session.
# A missing or unreadable cache, or one naming columns that are no longer in
# the design (KeyError), triggers a fresh selection. Other errors (including
# KeyboardInterrupt) now propagate instead of being swallowed by a bare except.
cols_path = os.path.join(pn.planning_and_neural_folder_path, 'full_session', 'selected_cols.json')

try:
    with open(cols_path, 'r') as f:
        selected_cols = json.load(f)
    X_pruned = merged_design_df[selected_cols].copy()
    print('Loaded selected columns from file')
except (FileNotFoundError, json.JSONDecodeError, KeyError) as err:
    print(f'Re-selecting columns ({type(err).__name__})')
    X_pruned, vif_report = design_checks.check_design(merged_design_df)
    os.makedirs(os.path.dirname(cols_path), exist_ok=True)
    with open(cols_path, 'w') as f:
        json.dump(X_pruned.columns.tolist(), f)
    print('Saved selected columns to file')

# Get spike history

In [None]:
# Bin table for the spike-history builder: the merged design's bin index
# (renamed to new_bin), with every bin placed in segment 0.
bin_info = (
    merged_design_df[['bin']]
    .rename(columns={'bin': 'new_bin'})
    .assign(new_segment=0)
)
bin_info

In [None]:
# Append spike-history regressors to the pruned design X_pruned.
t_max = 0.20  # history length passed to the builder (presumably same units as dt)
dt = pn.bin_width
spikes_df = pn.spikes_df

# NOTE(review): bin_info puts every bin in segment 0, so new_seg_info here must
# be the single full-session segment table from "Get full session design df".
# A later cell may have rebound new_seg_info (the stop-design cell did), so
# confirm which table is in scope before running this.
design_w_history, basis, colnames, merged_meta_groups = (
    spike_history.build_spike_history_design(
        spikes_df=spikes_df,
        new_seg_info=new_seg_info,
        bin_info=bin_info,
        X_pruned=X_pruned,
        meta_groups=merged_meta_groups,
        dt=dt,
        t_max=t_max,
    )
)


# GLM

## Behavioral variables only

In [None]:
# Working copies: full (unpruned) merged design and per-bin spike counts.
df_X, df_Y = merged_design_df.copy(), binned_spikes.copy()

In [None]:
# Every row of df_Y has the same exposure (one bin width), so the GLM
# log-exposure offset is the constant log(bin_width).
exposure = np.full(len(df_Y), pn.bin_width)
offset_log = np.log(exposure)

In [None]:
# Pick up edits to the GLM helper modules (in this order) before fitting.
for _glm_module in (glm_fit_utils, stop_glm_fit):
    reload(_glm_module)

# Behaviour-only GLM on the pruned design.
_glm_report_options = dict(
    cov_type='HC1',
    fast_mle=True,
    do_inference=True,
    make_plots=True,
    show_plots=True,
)
report0 = stop_glm_fit.glm_mini_report(
    df_X=X_pruned,
    df_Y=df_Y,
    offset_log=offset_log,
    **_glm_report_options,
)

In [None]:
# Behaviour-only coefficients for the 'captured' term that pass FDR, most
# significant first.
coefs_df = report0['coefs_df']
_is_captured_term = coefs_df['term'].eq('captured')
_is_fdr_significant = coefs_df['sig_FDR'].eq(True)
coefs_df.loc[_is_captured_term & _is_fdr_significant].sort_values('p')

In [None]:
# Tally the values of the refit_on_support column in the behaviour-only coefficient table.
coefs_df['refit_on_support'].value_counts()

In [None]:
# Print column dtypes and non-null counts of the coefficient table.
coefs_df.info()

## Behavioral variables + spike history

In [None]:
# Cached column selection for the behaviour + spike-history design.
# A missing or unreadable cache, or one naming columns no longer present
# (KeyError), triggers a fresh design_checks.check_design run. Other errors
# propagate instead of being swallowed by a bare except.
cols_path = os.path.join(pn.planning_and_neural_folder_path, 'full_session', 'selected_cols_w_spike_history.json')
try:
    with open(cols_path, 'r') as f:
        selected_cols_w_history = json.load(f)
    X_pruned1 = design_w_history[selected_cols_w_history].copy()
    print(f'Loaded selected columns from {cols_path}')
except (FileNotFoundError, json.JSONDecodeError, KeyError) as err:
    print(f'Re-selecting columns ({type(err).__name__})')
    os.makedirs(os.path.dirname(cols_path), exist_ok=True)
    X_pruned1, vif_report = design_checks.check_design(design_w_history)
    with open(cols_path, 'w') as f:
        json.dump(X_pruned1.columns.tolist(), f)
    print(f'Saved selected columns to {cols_path}')

In [None]:
reload(glm_fit_utils)
reload(stop_glm_fit)

# Behaviour + spike-history GLM.
# meta_groups must describe the columns of design_w_history, from which
# X_pruned1 is selected. Those are the merged_meta_groups returned by
# spike_history.build_spike_history_design. The previous argument,
# `meta_groups`, is not defined by any cell of this notebook.
report1 = stop_glm_fit.glm_mini_report(
    df_X=X_pruned1, df_Y=df_Y,
    offset_log=offset_log,
    cov_type='HC1',
    fast_mle=True,
    do_inference=True,
    make_plots=True,
    show_plots=True,
    meta_groups=merged_meta_groups,
)

## Spike history only

In [None]:
# Spike-history predictors are the 'cluster_*' columns of design_w_history that
# are not behavioural columns of X_pruned (design_w_history was built from
# X_pruned). The exclusion used to reference `binned_feats_sc`, which no cell
# of this notebook defines.
_behavior_cols = set(X_pruned.columns)
all_history_cols = [
    c for c in design_w_history.columns
    if c.startswith('cluster_') and c not in _behavior_cols
]

# Cached column selection for the history-only design. A missing or unreadable
# cache, or one naming columns no longer present (KeyError), triggers a fresh
# selection. Other errors propagate instead of being swallowed by a bare except.
cols_path = os.path.join(pn.planning_and_neural_folder_path, 'full_session', 'selected_spike_history_cols.json')
try:
    with open(cols_path, 'r') as f:
        selected_history_cols = json.load(f)
    X_pruned2 = design_w_history[selected_history_cols].copy()
    print(f'Loaded selected columns from {cols_path}')
except (FileNotFoundError, json.JSONDecodeError, KeyError) as err:
    print(f'Re-selecting columns ({type(err).__name__})')
    os.makedirs(os.path.dirname(cols_path), exist_ok=True)
    X_pruned2, vif_report = design_checks.check_design(design_w_history[all_history_cols])
    with open(cols_path, 'w') as f:
        json.dump(X_pruned2.columns.tolist(), f)
    print(f'Saved selected columns to {cols_path}')

In [None]:
reload(glm_fit_utils)
reload(stop_glm_fit)

# History-only GLM.
# The previous argument, `meta_groups`, is not defined by any cell of this
# notebook. The groups returned by spike_history.build_spike_history_design
# (merged_meta_groups) cover the history columns, but also behavioural groups
# whose columns are absent from X_pruned2.
# NOTE(review): confirm glm_mini_report tolerates groups with no columns in df_X.
report2 = stop_glm_fit.glm_mini_report(
    df_X=X_pruned2, df_Y=df_Y,
    offset_log=offset_log,
    cov_type='HC1',
    fast_mle=True,
    do_inference=True,
    make_plots=True,
    show_plots=True,
    meta_groups=merged_meta_groups,
)

# Compare deviance explained

## In-sample

In [None]:
# In-sample comparison of the three GLM variants (dict order = plot order).
_reports_by_model = {
    'Behavior only': report0,
    'Behavior + history': report1,
    'History only': report2,
}
metrics_by_model = {label: rep['metrics_df'] for label, rep in _reports_by_model.items()}

compare_glm_fit.plot_insample_model_comparison(metrics_by_model)


## CV

In [None]:
# Cross-validated comparison of the three GLM variants (dict order = plot order).
_reports_by_model = {
    'Behavior only': report0,
    'Behavior + history': report1,
    'History only': report2,
}
metrics_by_model = {label: rep['metrics_df'] for label, rep in _reports_by_model.items()}

compare_glm_fit.plot_cv_model_comparison(metrics_by_model)


# Deviance explained

## In-sample

In [None]:
# In-sample diagnostics: behaviour-only model.
_metrics_behavior = report0['metrics_df']
plot_glm_fit.plot_insample_model_diagnostics(_metrics_behavior)

In [None]:
# In-sample diagnostics: behaviour + spike-history model.
_metrics_behavior_history = report1['metrics_df']
plot_glm_fit.plot_insample_model_diagnostics(_metrics_behavior_history)

In [None]:
# In-sample diagnostics: spike-history-only model.
_metrics_history = report2['metrics_df']
plot_glm_fit.plot_insample_model_diagnostics(_metrics_history)

In [None]:
# Cross-validated diagnostics: behaviour-only model.
_cv_diag_options = {'bins': 20, 'show': True}
plot_glm_fit.plot_cv_model_diagnostics(report0['metrics_df'], **_cv_diag_options)

In [None]:
import matplotlib.pyplot as plt

# Six-panel summary of per-neuron in-sample fit quality for the behaviour-only
# model. This cell used to read `report`, which no cell defines (NameError),
# so it now uses report0.
metrics_df = report0['metrics_df'].copy()

# ---- derived quantities ----
# Log-likelihood gain of the fitted model over the null model (llnull), in
# total and per observation.
metrics_df['ll_improvement'] = metrics_df['llf'] - metrics_df['llnull']
metrics_df['ll_improvement_per_obs'] = metrics_df['ll_improvement'] / metrics_df['n_obs']


def _hist_panel(ax, values, ref_x, xlabel, title):
    """Draw a 20-bin histogram of a per-neuron metric with a dashed line at ref_x."""
    ax.hist(values, bins=20)
    ax.axvline(ref_x, linestyle='--')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Number of neurons')
    ax.set_title(title)


def _scatter_panel(ax, x, y, xlabel, ylabel, title):
    """Draw a semi-transparent scatter of two per-neuron metrics."""
    ax.scatter(x, y, alpha=0.7)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)


# ---- figure layout ----
fig, axes = plt.subplots(2, 3, figsize=(14, 8))
axes = axes.ravel()

# 1-2. Distributions of the two goodness-of-fit measures (dashed = median)
_hist_panel(axes[0], metrics_df['deviance_explained'],
            metrics_df['deviance_explained'].median(),
            'Deviance explained', 'Model performance (deviance explained)')
_hist_panel(axes[1], metrics_df['mcfadden_R2'],
            metrics_df['mcfadden_R2'].median(),
            'McFadden $R^2$', 'Pseudo-$R^2$ distribution')

# 3. The two measures against each other
_scatter_panel(axes[2], metrics_df['deviance_explained'], metrics_df['mcfadden_R2'],
               'Deviance explained', 'McFadden $R^2$', 'Consistency check')

# 4. Fit quality vs. null deviance
_scatter_panel(axes[3], metrics_df['null_deviance'], metrics_df['deviance_explained'],
               'Null deviance (rate / variability proxy)', 'Deviance explained',
               'Dependence on firing statistics')

# 5-6. Log-likelihood gain over the null model (dashed = no gain)
_hist_panel(axes[4], metrics_df['ll_improvement'], 0,
            'Log-likelihood improvement', 'Improvement over null model')
_hist_panel(axes[5], metrics_df['ll_improvement_per_obs'], 0,
            'Δ log-likelihood per observation', 'Predictive gain (normalized)')

fig.suptitle('Behavior only')
fig.tight_layout(rect=(0, 0, 1, 0.96))  # leave room for the suptitle
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Six-panel summary of per-neuron in-sample fit quality for the behaviour +
# spike-history model. This cell used to be an exact copy of the previous
# summary cell and read `report`, which no cell defines (NameError), so it now
# uses report1.
metrics_df = report1['metrics_df'].copy()

# ---- derived quantities ----
# Log-likelihood gain of the fitted model over the null model (llnull), in
# total and per observation.
metrics_df['ll_improvement'] = metrics_df['llf'] - metrics_df['llnull']
metrics_df['ll_improvement_per_obs'] = metrics_df['ll_improvement'] / metrics_df['n_obs']


def _hist_panel(ax, values, ref_x, xlabel, title):
    """Draw a 20-bin histogram of a per-neuron metric with a dashed line at ref_x."""
    ax.hist(values, bins=20)
    ax.axvline(ref_x, linestyle='--')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Number of neurons')
    ax.set_title(title)


def _scatter_panel(ax, x, y, xlabel, ylabel, title):
    """Draw a semi-transparent scatter of two per-neuron metrics."""
    ax.scatter(x, y, alpha=0.7)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)


# ---- figure layout ----
fig, axes = plt.subplots(2, 3, figsize=(14, 8))
axes = axes.ravel()

# 1-2. Distributions of the two goodness-of-fit measures (dashed = median)
_hist_panel(axes[0], metrics_df['deviance_explained'],
            metrics_df['deviance_explained'].median(),
            'Deviance explained', 'Model performance (deviance explained)')
_hist_panel(axes[1], metrics_df['mcfadden_R2'],
            metrics_df['mcfadden_R2'].median(),
            'McFadden $R^2$', 'Pseudo-$R^2$ distribution')

# 3. The two measures against each other
_scatter_panel(axes[2], metrics_df['deviance_explained'], metrics_df['mcfadden_R2'],
               'Deviance explained', 'McFadden $R^2$', 'Consistency check')

# 4. Fit quality vs. null deviance
_scatter_panel(axes[3], metrics_df['null_deviance'], metrics_df['deviance_explained'],
               'Null deviance (rate / variability proxy)', 'Deviance explained',
               'Dependence on firing statistics')

# 5-6. Log-likelihood gain over the null model (dashed = no gain)
_hist_panel(axes[4], metrics_df['ll_improvement'], 0,
            'Log-likelihood improvement', 'Improvement over null model')
_hist_panel(axes[5], metrics_df['ll_improvement_per_obs'], 0,
            'Δ log-likelihood per observation', 'Predictive gain (normalized)')

fig.suptitle('Behavior + history')
fig.tight_layout(rect=(0, 0, 1, 0.96))  # leave room for the suptitle
plt.show()


# Appendix

## Select relevant raw features


In [None]:

# Boolean stop flag per bin (stop_id strictly positive), then keep only the
# selected kinematic columns.
# NOTE(review): a stop_id of 0 counts as "no stop" here - confirm ids start at 1.
rebinned_monkey_data['stop'] = rebinned_monkey_data['stop_id'].gt(0)
_kinematic_cols = selected_raw_data_features.selected_kinematics_features
rebinned_monkey_data_sub = rebinned_monkey_data[_kinematic_cols]