# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
# Locate the project root by walking up from the current working directory,
# make it the working directory (the data paths used below are relative to it),
# and put the analysis code on sys.path so the project imports that follow resolve.
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        methods_dir = str(p / 'multiff_analysis/multiff_code/methods')
        # Re-running this cell should not stack duplicate sys.path entries.
        if methods_dir not in sys.path:
            sys.path.insert(0, methods_dir)
        break
else:
    # Without this, a missing root only surfaces later as a confusing
    # ModuleNotFoundError on the project imports.
    import warnings
    warnings.warn(
        "Could not find a 'Multifirefly-Project' directory at or above "
        f"{Path.cwd()}; project imports and relative data paths may fail."
    )

from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, ml_methods_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils
from neural_data_analysis.neural_analysis_tools.glm_tools.prep_predictors import predictor_bases, predictor_utils, other_feats
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event
from neural_data_analysis.neural_analysis_tools.glm_tools.tpg import glm_bases, glm_plotting, glm_plotting2, glm_fit

from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_psth import core_stops_psth, get_stops_utils, psth_postprocessing, psth_stats, compare_events, dpca_utils
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_fit import stop_glm_fit, cv_stop_glm, glm_fit_utils, variance_explained
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_plotting import plot_spikes, plot_glm_fit, plot_tuning_func
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_designs import binning_for_glm, lagged_design, stop_design, cluster_design, design_checks, history_design
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_hyperparams import compare_glm_configs, glm_hyperparams_class



import sys
import math
import gc
import subprocess
from pathlib import Path
from importlib import reload

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
from numpy import pi

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# Reset matplotlib to its defaults FIRST: rcParams.update(rcParamsDefault)
# overwrites every rc setting, so any customisation applied before it is lost.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Render animations inline as JavaScript widgets. The cell previously set
# 'html5' and then 'jshtml'; the later setting ('jshtml') is the one kept.
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128
# NOTE(review): this variable is consulted when the OpenMP runtime initialises;
# setting it after numpy/sklearn are imported may be too late. Consider moving
# it above the first import.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
print("done")

# Retrieve data

## get data

In [None]:
# Session to analyse (monkey Schro). Run only one of the session-selection cells.
raw_data_folder_path = ("all_monkey_data/raw_monkey_data/"
                        "monkey_Schro/data_0416")

In [None]:
# Session to analyse (monkey Bruno). Run only one of the session-selection cells.
raw_data_folder_path = ("all_monkey_data/raw_monkey_data/"
                        "monkey_Bruno/data_0330")

In [None]:
# Re-import pn_helper_class so edits to its source take effect without a kernel restart.
reload(pn_helper_class)

In [None]:
# Build the event-aligned planning/neural dataset for the selected session and
# re-bin it into new segments (alignment chosen by the arguments below).
reduce_y_var_lags = False
planning_data_by_point_exists_ok = True
y_data_exists_ok = True

pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(
    raw_data_folder_path=raw_data_folder_path)
pn.prep_data_to_analyze_planning(
    planning_data_by_point_exists_ok=planning_data_by_point_exists_ok)

# Remove NaN-heavy columns (general_utils.drop_columns_with_many_nans); the
# helper also hands back the dropped column names.
pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
    pn.planning_data_by_point)
# pn.get_x_and_y_data_for_modeling(exists_ok=y_data_exists_ok, reduce_y_var_lags=reduce_y_var_lags)

pn.rebin_data_in_new_segments(
    cur_or_nxt='cur',
    first_or_last='first',
    time_limit_to_count_sighting=2,
    pre_event_window=0,
    post_event_window=1.5,
    rebinned_max_x_lag_number=2,
)

# Collapse the visibility / in-memory columns to 0/1 flags (any positive value -> 1).
for flag_col in ('cur_vis', 'nxt_vis', 'cur_in_memory', 'nxt_in_memory'):
    is_positive = pn.rebinned_y_var[flag_col] > 0
    pn.rebinned_y_var[flag_col] = is_positive.astype(int)

In [None]:
# reduce_y_var_lags = False
# planning_data_by_point_exists_ok = False
# y_data_exists_ok = True

# pn = pn_aligned_by_seg.PlanningAndNeuralSegmentAligned(raw_data_folder_path=raw_data_folder_path)
# pn.prep_data_to_analyze_planning(planning_data_by_point_exists_ok=planning_data_by_point_exists_ok)
# pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
#     pn.planning_data_by_point)
# # pn.get_x_and_y_data_for_modeling(exists_ok=y_data_exists_ok, reduce_y_var_lags=reduce_y_var_lags)
# pn.rebin_data_in_new_segments(segment_duration=2, rebinned_max_x_lag_number=2)

# for col in ['cur_vis', 'nxt_vis', 'cur_in_memory', 'nxt_in_memory']:
#     pn.rebinned_y_var[col] = (pn.rebinned_y_var[col] > 0).astype(int)

In [None]:
# Columns of the re-binned y table (the source for the design matrices built below).
pn.rebinned_y_var.columns.tolist()

# Try the stop-GLM pipeline

In [None]:
# Number of rows in each new segment.
pn.rebinned_y_var.groupby('new_segment').size()

In [None]:
reload(predictor_bases)
reload(other_feats)

# Design matrix from the re-binned y table; each row's segment id is its trial id.
data = pn.rebinned_y_var.copy()
dt = pn.bin_width
trial_ids = data['new_segment']
design_df, meta0, meta = predictor_bases.get_initial_design_df(data, dt, trial_ids)

# Example target neuron, taken from the re-binned x table.
cluster_num = 3
y = pn.rebinned_x_var['cluster_' + str(cluster_num)]

# design_df, meta = predictor_bases.add_spike_history(
#     design_df, y, meta0['trial_ids'], dt,
#     n_basis=4, t_max=0.20, edge='zero',
#     prefix='spk_hist', style='bjk',
#     meta=meta
# )

# Sanity-check the design against its basis metadata; abort on any reported problem.
chk = predictor_utils.check_design_vs_bases(design_df, meta, strict=True)
assert chk['ok'], chk['problems']

reload(glm_fit_utils)
reload(stop_glm_fit)



## check data

In [None]:
# Collinearity & rank diagnostics for the design matrix.
# Named X_design so it is not confused with the count matrix `X` built later.
X_design = design_df.to_numpy(dtype=float)
# Only singular values are needed; skipping U and V^T avoids allocating an
# n_rows x n_cols U matrix for a long recording.
s = np.linalg.svd(X_design, compute_uv=False)
# Largest / smallest singular value; the floor avoids division by zero for an
# exactly rank-deficient design.
cond = s.max() / max(s.min(), 1e-12)
print('cond=', cond)
# Numerical rank using numpy.linalg.matrix_rank's default tolerance.
tol = s.max() * max(X_design.shape) * np.finfo(s.dtype).eps
print('rank=', int((s > tol).sum()), 'of', X_design.shape[1], 'columns')


In [None]:
# Show meta['groups'] (presumably how design columns are grouped by regressor; verify in predictor_bases).
meta['groups']

In [None]:
# design_df.describe().T

## fit

In [None]:
# Predictors: the shared design matrix (same object, not a copy).
df_X = design_df
print('df_X.shape:', df_X.shape)

# Responses: one column per cluster, relabelled from 'cluster_<n>' to the integer <n>.
cluster_cols = [name for name in pn.rebinned_x_var.columns if name.startswith('cluster_')]
df_Y = pn.rebinned_x_var[cluster_cols].rename(
    columns=lambda name: int(name.replace('cluster_', '')))


In [None]:
# Every bin has the same width, so exposure is constant and enters the Poisson
# GLM as a log offset. Size it from the design matrix itself, not from an
# unrelated single-cluster series, so it always matches df_X's rows.
assert len(df_X) == len(df_Y), 'df_X and df_Y must have one row per bin'
exposure = np.full(len(df_X), fill_value=pn.bin_width, dtype=float)
offset_log = np.log(exposure)

report = stop_glm_fit.glm_mini_report(
    df_X=df_X, df_Y=df_Y, offset_log=offset_log,
    cov_type='HC1',            # robust SEs; 'nonrobust' is faster
    fast_mle=True,             # use the fast MLE path
    do_inference=True,         # run FDR/ratios/pop-tests (False skips them)
    make_plots=True,           # create figures (False skips figure creation)
    show_plots=True,           # display the created figures
)


In [None]:
# Pull the main tables out of the report, then display the fit metrics.
coefs_df, metrics_df, pop_tests = (
    report[key] for key in ('coefs_df', 'metrics_df', 'population_tests_df'))
metrics_df

## Variance explained (VE)

In [None]:
reload(variance_explained)

# Predicted values for every cluster, built from the fitted GLMs.
df_Y_pred = variance_explained.build_df_Y_pred_from_results(
    results=report['results'], df_X=df_X, offset_log=offset_log, df_Y=df_Y)

# Sanity checks: same layout as the observations and no NaN/inf predictions.
assert df_Y_pred.shape == df_Y.shape
assert np.isfinite(df_Y_pred.values).all()

In [None]:
# Observed vs predicted matrices (rows = re-binned rows, columns = clusters)
# plus each row's segment id, used by the per-stop breakdown.
X = df_Y.to_numpy()            # observed counts
X_hat = df_Y_pred.to_numpy()   # predicted expected counts
stop_ids = pn.rebinned_x_var['new_segment'].to_numpy()
# Population VE is computed and reported in the next cell; computing it here
# as well only duplicated the work.


In [None]:
# Single-neuron VE: one value per cluster plus the mean; the median is derived here.
ve_per_neuron, ve_mean = variance_explained.single_neuron_temporal_VE(X, X_hat, aggregate='mean')
ve_median = float(np.median(ve_per_neuron))
print(f'Mean single-neuron VE:   {ve_mean:.3f}\n'
      f'Median single-neuron VE: {ve_median:.3f}')
variance_explained.plot_single_neuron_VE_hist(ve_per_neuron)

# Population VE in PC space (k=10 requested; k_eff is the number reported back).
ve_pop, k_eff = variance_explained.population_VE_in_PCspace(X, X_hat, k=10, center='neuron')
print(f'Population VE in {k_eff} PCs: {ve_pop:.3f}')
variance_explained.plot_population_VE_bar(ve_pop)


In [None]:
# Per-neuron VE values in ascending order.
np.sort(ve_per_neuron)

In [None]:
# Per-stop breakdown grouped by stop_ids (good for figures); show the first rows.
per_stop_df = variance_explained.per_stop_breakdown(X, X_hat, stop_ids=stop_ids, k=10)
print(per_stop_df.head())

# Headline VE numbers in one Series.
summary_metrics = dict(
    VE_population_PC=ve_pop,
    VE_single_unit_mean=ve_mean,
    VE_single_unit_median=ve_median,
    PCs_used=k_eff,
)
print(pd.Series(summary_metrics))


In [None]:
# Intentional halt: stops "Run All" at this point. An explicit exception (rather
# than a deliberate syntax error) keeps the notebook parseable by linters and
# script export.
raise RuntimeError("Intentional stop: run the cells below manually.")

## Plot spikes

In [None]:
# Re-import plot_spikes so edits to its source take effect without a kernel restart.
reload(plot_spikes)

In [None]:
cluster_idx = 4

# Time of each row relative to its segment start. It is added to a copy so that
# pn.rebinned_y_var, which later cells copy into design matrices, is not modified.
meta_used = pn.rebinned_y_var.assign(
    rel_time=pn.rebinned_y_var['time'] - pn.rebinned_y_var['new_seg_start_time'])

for seg_id in range(40, 52):
    # Observed vs GLM-predicted spikes for one segment; offset_log is the same
    # log-exposure offset used in the fit.
    plot_spikes.plot_observed_vs_predicted_stop(
        binned_feats_sc=df_X,
        binned_spikes=df_Y,
        meta_used=meta_used,
        offset_log=offset_log,
        model_res=report['results'][cluster_idx],   # fitted GLM for cluster_idx
        cluster_idx=cluster_idx,
        seg_id=seg_id,
        time_col='rel_time',
        seg_col='new_segment',
    )
plt.show()

# Fit a single neuron (with spike history)

## glm data

In [None]:
# Design matrix for one neuron, this time including its own spike history.
dt = pn.bin_width
data = pn.rebinned_y_var.copy()
trial_ids = data['new_segment']
design_df, meta0, meta = predictor_bases.get_initial_design_df(data, dt, trial_ids)

# Target neuron whose counts are modelled.
cluster_num = 3
y = pn.rebinned_x_var['cluster_%d' % cluster_num]

# Append spike-history regressors computed from y (basis set by the options below).
history_opts = dict(n_basis=4, t_max=0.20, edge='zero', prefix='spk_hist', style='bjk')
design_df, meta = predictor_bases.add_spike_history(
    design_df, y, meta0['trial_ids'], dt, meta=meta, **history_opts)

# Sanity-check the design against its basis metadata; abort on any reported problem.
chk = predictor_utils.check_design_vs_bases(design_df, meta, strict=True)
assert chk['ok'], chk['problems']


In [None]:
# (rows, columns) of the design matrix after the spike-history columns were added.
design_df.shape

## fit

In [None]:
# Poisson GLM for the single neuron; cluster_se=True requests standard errors
# clustered by trial_ids.
glm_options = dict(
    dt=dt,
    trial_ids=meta0['trial_ids'],
    add_const=False,   # design_df already contains a 'const' column
    cluster_se=True,
)
res = glm_fit.fit_poisson_glm_trials(design_df, y, **glm_options)


## check columns

In [None]:
# Find aliased design columns with drop_aliased_columns (df2 presumably the
# reduced design — verify) and show which columns it dropped.
df2, info = predictor_utils.drop_aliased_columns(design_df)
info['dropped']

In [None]:
target_col = 'spk_hist:b0:3'   # history column to test for near-aliasing

hist_cols = [name for name in design_df.columns if name.startswith('spk_hist:')]
if not hist_cols:
    raise ValueError("design_df has no 'spk_hist:' columns; add spike history first")

# 1) rank of the history block
H = design_df[hist_cols].to_numpy()
print("hist shape:", H.shape, "rank:", np.linalg.matrix_rank(H))

# 2) condition / singular values
sv = np.linalg.svd(H, compute_uv=False)
print("singular values:", sv)

# 3) is the target column almost in the span of the other history columns?
#    Least-squares regress it on them; a relative residual near 0 means it is
#    (nearly) redundant.
target = design_df[target_col].to_numpy()
others = design_df[[name for name in hist_cols if name != target_col]].to_numpy()
beta, *_ = np.linalg.lstsq(others, target, rcond=None)
resid = target - others @ beta
print("relative residual:", np.linalg.norm(resid) / max(np.linalg.norm(target), 1e-12))


## Check firing rates

In [None]:
# Mean firing rate of each cluster: total spikes / total binned time.
# Use the real bin width instead of a hard-coded 0.04, and do not rebind the
# notebook-wide `dt`, which the kernel plots in later cells rely on.
bin_width_s = pn.bin_width
rates = df_Y.sum(axis=0) / (len(df_Y) * bin_width_s)
# rates is a Series indexed by unit, in Hz
rates

In [None]:
import pandas as pd

def firing_rates_from_df(spikes_df, time_col='time', cluster_col='cluster',
                         t_start=None, t_end=None):
    """Return the mean firing rate of each cluster in a spike table.

    Each row of ``spikes_df`` is one spike. A cluster's rate is its spike
    count divided by the observation window. By default the window runs from
    the earliest to the latest spike in the whole table.

    Parameters
    ----------
    spikes_df : pandas.DataFrame
        One row per spike, with a spike-time column and a cluster-label column.
    time_col, cluster_col : str
        Names of the spike-time and cluster-label columns.
    t_start, t_end : float, optional
        Explicit window bounds, in the same units as ``time_col``. When
        omitted, the earliest / latest spike time is used. That slightly
        underestimates the true recording length.

    Returns
    -------
    pandas.DataFrame
        Columns ``[cluster_col, 'rate_hz']``, one row per cluster. ``rate_hz``
        is spikes per time unit (Hz when times are in seconds).

    Raises
    ------
    ValueError
        If the window length is not positive, for example with no spikes, a
        single spike, or ``t_end <= t_start``. The rate is undefined there
        (previously this silently produced inf/NaN).
    """
    times = spikes_df[time_col]
    start_s = times.min() if t_start is None else t_start
    end_s = times.max() if t_end is None else t_end
    duration = end_s - start_s
    # `not duration > 0` also catches NaN, which min()/max() return for no spikes.
    if not duration > 0:
        raise ValueError(
            f'cannot compute firing rates over a window of length {duration!r}; '
            'pass t_start/t_end explicitly')

    counts = spikes_df.groupby(cluster_col).size()
    rates_hz = counts / duration
    return rates_hz.rename('rate_hz').reset_index()

# Per-cluster firing rates over the span of the spikes recorded in pn.spikes_df.
firing_rates_from_df(pn.spikes_df)

## plot

In [None]:
# Angle-tuning curve for 'cur_ff_angle', drawn once on Cartesian axes and once
# as a polar plot. After the loop, theta/f/std/info hold the polar result.
for use_polar in (False, True):
    theta, f, std, info = glm_plotting.plot_angle_tuning_function(
        res, design_df, meta,
        base_prefix='cur_ff_angle',
        M=None,              # auto-detect harmonics
        polar=use_polar,     # False: Cartesian axes, True: polar axes
        z=1.96,              # 95% CI
    )


In [None]:
# Fitted kernels (z=1.96, i.e. 95% CI) for the visibility events and spike history.
kernel_prefixes = ['cur_vis_on', 'cur_vis_off', 'nxt_vis_on', 'nxt_vis_off', 'spk_hist']
glm_plotting.plot_fitted_kernels(res, design_df, meta, dt, prefixes=kernel_prefixes, z=1.96)

## across neurons

In [None]:
# Cluster columns of the re-binned x table and their integer ids ('cluster_7' -> 7).
cluster_cols = list(filter(lambda label: label.startswith('cluster_'),
                           pn.rebinned_x_var.columns))
cluster_nums = [int(label.split('_')[1]) for label in cluster_cols]

In [None]:
# Fit one spike-history GLM per cluster and collect the population history kernels.
dt = pn.bin_width   # do not rely on a `dt` another cell may have overwritten

results = []
designs = []
B_hist_ref = None

for cluster in cluster_nums:
    # Per-neuron design: task regressors plus this neuron's own spike history
    # (same construction as the single-neuron fit).
    data = pn.rebinned_y_var.copy()
    design_df, meta0, meta = predictor_bases.get_initial_design_df(
        data, dt, data['new_segment'])
    y_fit = pn.rebinned_x_var[f'cluster_{cluster}']
    design_df, meta = predictor_bases.add_spike_history(
        design_df, y_fit, meta0['trial_ids'], dt,
        n_basis=4, t_max=0.20, edge='zero',
        prefix='spk_hist', style='bjk',
        meta=meta,
    )
    # NOTE(review): assumes add_spike_history stores its basis under meta['B_hist']
    # and that it is identical across neurons (same basis arguments); keep the first.
    if B_hist_ref is None:
        B_hist_ref = meta.get('B_hist')

    # Poisson GLM. design_df already contains 'const', so none is added here.
    # cluster_se=False skips trial-clustered SEs (set True for cluster-robust SEs).
    res = glm_fit.fit_poisson_glm_trials(
        design_df, y_fit, dt=dt, trial_ids=meta0['trial_ids'],
        add_const=False, l2=0.0, cluster_se=False,
    )
    results.append(res)
    designs.append(design_df)

hist_meta = {"B_hist": B_hist_ref}

# NOTE(review): collect_history_kernels_across_neurons and
# plot_history_kernels_population are not imported in this notebook; import
# them from the module that defines them before running this cell.
# Collect population history kernels
hist_df = collect_history_kernels_across_neurons(results, designs, hist_meta, dt)

# Plot overlays + heatmap
plot_history_kernels_population(hist_df, overlay_mean=True, heatmap=True, max_overlays=50)

In [None]:
# Population spike-history kernels from the per-cluster fits of the previous
# cell (`results`, `designs` and `B_hist_ref` come from that loop).
hist_meta = {"B_hist": B_hist_ref}   # must contain 'B_hist'
dt = pn.bin_width                    # bin size (s) actually used for the fits

hist_df = collect_history_kernels_across_neurons(
    results, designs, hist_meta, dt, neuron_ids=None  # or e.g. list of unit IDs
)

# Overlay individual kernels (up to max_overlays) + population mean ± 95% CI
plot_history_kernels_population(hist_df, overlay_mean=True, heatmap=False, max_overlays=60)

# Heatmap only (neuron × lag)
plot_history_kernels_population(hist_df, overlay_mean=False, heatmap=True)

# Both
plot_history_kernels_population(hist_df, overlay_mean=True, heatmap=True)


