# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
# Find the repository root (the directory named 'Multifirefly-Project') at or
# above the current working directory. When it is found, make it the working
# directory and put the project's methods folder first on the import path.
# When it is not found, nothing is changed.
_project_root = next(
    (d for d in (Path.cwd(), *Path.cwd().parents) if d.name == 'Multifirefly-Project'),
    None,
)
if _project_root is not None:
    os.chdir(_project_root)
    sys.path.insert(0, str(_project_root / 'multiff_analysis/multiff_code/methods'))
    
from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils, ml_methods_utils
from planning_analysis.show_planning import nxt_ff_utils, show_planning_utils
from neural_data_analysis.neural_analysis_tools.gpfa_methods import elephant_utils, fit_gpfa_utils, plot_gpfa_utils, gpfa_helper_class
from neural_data_analysis.neural_analysis_tools.align_trials import time_resolved_regression, time_resolved_gpfa_regression,plot_time_resolved_regression
from neural_data_analysis.neural_analysis_tools.align_trials import align_trial_utils
from decision_making_analysis.compare_GUAT_and_TAFT import find_GUAT_or_TAFT_trials

from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_psth import core_stops_psth, get_stops_utils, psth_postprocessing, psth_stats, compare_events, dpca_utils
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm import stop_glm_fit, cv_stop_glm, plot_spikes
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_designs import binning_for_glm, lagged_design, stop_design, cluster_design, design_checks, history_design


import sys
import math
import gc
import subprocess
from pathlib import Path

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr
import statsmodels.api as sm

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
import numpy as np
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util

# --- Global display / plotting configuration ---
# Reset matplotlib to its defaults FIRST. Previously the reset ran after the
# animation settings and silently discarded them.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)

# Animation output format. The earlier "html5" assignment was immediately
# superseded by this call, so only the effective value is kept.
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128

# Environment switches read by native libraries / the debugger.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

# pandas / numpy display options.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)
np.set_printoptions(suppress=True)

print("done")


%load_ext autoreload
%autoreload 2
%matplotlib inline

# retrieve data

In [None]:
#raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0413"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0321"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0329"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0403"

In [None]:
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0312"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0316"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0327"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0328"

In [None]:
# Flags for the data-preparation calls. In this cell only
# planning_data_by_point_exists_ok is consumed; the other two are referenced
# solely by the commented-out get_x_and_y_data_for_modeling call below.
reduce_y_var_lags = False
planning_data_by_point_exists_ok = True
y_data_exists_ok = True

# Build the event-aligned analysis object for the selected session folder and
# run its planning-data preparation step.
pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(raw_data_folder_path=raw_data_folder_path)
pn.prep_data_to_analyze_planning(planning_data_by_point_exists_ok=planning_data_by_point_exists_ok)
# pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
#     pn.planning_data_by_point)
#pn.get_x_and_y_data_for_modeling(exists_ok=y_data_exists_ok, reduce_y_var_lags=reduce_y_var_lags)

# Build the spike table only when `pn` does not already carry one.
if not hasattr(pn, 'spikes_df'):
    pn.retrieve_or_make_monkey_data()
    pn.spikes_df = neural_data_processing.make_spikes_df(pn.raw_data_folder_path, pn.ff_caught_T_sorted,
                                                            sampling_rate=pn.sampling_rate)

# NEXT: try stop end time instead

# Get captures

In [None]:
reload(get_stops_utils)

In [None]:
pn.monkey_information.columns

In [None]:
stops_with_stats

In [None]:
# Split stops into captures and non-captures. The arguments are collected in
# one mapping and forwarded unchanged to the project helper, which returns
# three tables: valid captures, filtered non-capture stops, and per-stop stats.
_capture_split_kwargs = {
    'monkey_information': pn.monkey_information,
    'closest_stop_to_capture_df': pn.closest_stop_to_capture_df,
    'ff_caught_T_new': pn.ff_caught_T_new,
    'capture_match_window': 0.3,
    'distance_thresh': 25.0,
    'distance_col': "distance_from_ff_to_stop",
}
(
    valid_captures_df,
    filtered_no_capture_stops_df,
    stops_with_stats,
) = get_stops_utils.prepare_no_capture_and_captures(**_capture_split_kwargs)


# Get misses

## One-stop misses

In [None]:
# Per-stop columns copied over from pn.monkey_information.
columns_to_add = ["stop_id", "stop_id_duration", "stop_id_start_time", "stop_id_end_time"]

pn.make_one_stop_w_ff_df()

# Keep the columns of interest and rename the "first_*" fields to the generic
# stop_point_index / stop_time names used by the other event tables.
_one_stop_source_cols = [
    'first_stop_point_index',
    'first_stop_time',
    'latest_visible_ff',
    'ff_distance',
    'min_distance_from_adjacent_stops',
]
_one_stop_renames = {
    'first_stop_point_index': 'stop_point_index',
    'first_stop_time': 'stop_time',
}
one_stop_miss_df = pn.one_stop_w_ff_df[_one_stop_source_cols].copy().rename(columns=_one_stop_renames)

# Look the stop up by label in monkey_information and copy its values positionally.
_stop_rows = pn.monkey_information.loc[one_stop_miss_df['stop_point_index'], columns_to_add]
one_stop_miss_df[columns_to_add] = _stop_rows.values

## More than one stop (GUAT / TAFT clusters)

In [None]:
# Prepare the firefly table and the two multi-stop categories on `pn`.
# The next cell reads pn.GUAT_trials_df and pn.TAFT_trials_df, which are
# presumably populated by these calls — verify against the pn class.
pn.make_or_retrieve_ff_dataframe()
pn.get_try_a_few_times_info()       # TAFT = "try a few times"
pn.get_give_up_after_trying_info()  # GUAT = "give up after trying"


In [None]:
# Per-stop columns copied from pn.monkey_information, and the column set that
# every combined event table below shares.
columns_to_add = ["stop_id", "stop_id_duration", "stop_id_start_time", "stop_id_end_time"]
shared_columns = ["stop_point_index", "stop_time"] + columns_to_add


def _concat_shared(*frames):
    """Stack the ``shared_columns`` of the given event tables into one table.

    The result always gets a fresh 0..n-1 index, so no combined table carries
    duplicated row labels inherited from its inputs.
    """
    return pd.concat([frame[shared_columns] for frame in frames], ignore_index=True)


# --- Build expanded + ordered tables for GUAT / TAFT ---
GUAT_expanded = get_stops_utils._expand_trials(pn.GUAT_trials_df, pn.monkey_information)
TAFT_expanded = get_stops_utils._expand_trials(pn.TAFT_trials_df, pn.monkey_information)

# Add the per-stop columns (looked up by stop_point_index label) to both tables.
GUAT_expanded[columns_to_add] = pn.monkey_information.loc[GUAT_expanded['stop_point_index'], columns_to_add].values
TAFT_expanded[columns_to_add] = pn.monkey_information.loc[TAFT_expanded['stop_point_index'], columns_to_add].values

GUAT = get_stops_utils._add_cluster_ordering(GUAT_expanded)
TAFT = get_stops_utils._add_cluster_ordering(TAFT_expanded)

# --- Per-cluster slices (consistent, vectorized) ---
# First stop in each cluster
GUAT_first = GUAT[GUAT["is_first"]].reset_index(drop=True)
TAFT_first = TAFT[TAFT["is_first"]].reset_index(drop=True)

# Last stop in each cluster
giveup_GUAT_last = GUAT[GUAT["is_last"]].reset_index(drop=True)
capture_TAFT_last = TAFT[TAFT["is_last"]].reset_index(drop=True)

# Middle stops (exclude first and last)
GUAT_middle = GUAT[GUAT["is_middle"]].reset_index(drop=True)
TAFT_middle = TAFT[TAFT["is_middle"]].reset_index(drop=True)

# "First several" = all but the last stop in each cluster
persist_GUAT_nonfinal = GUAT[GUAT["order_in_cluster"] < GUAT["cluster_size"] - 1].reset_index(drop=True)
persist_TAFT_nonfinal = TAFT[TAFT["order_in_cluster"] < TAFT["cluster_size"] - 1].reset_index(drop=True)

# Combine the "first several" from both, keep only the shared columns, then
# order the rows by stop_point_index.
both_nonfinal = (
    _concat_shared(persist_GUAT_nonfinal, persist_TAFT_nonfinal)
    .sort_values("stop_point_index")
    .reset_index(drop=True)
)

persist_both_first = _concat_shared(GUAT_first, TAFT_first)

both_middle = _concat_shared(GUAT_middle, TAFT_middle)

# Optional: if you also want "last several" (all but the first), it's symmetrical:
# giveup_GUAT_last_several = GUAT[GUAT["order_in_cluster"] > 0].reset_index(drop=True)
# capture_TAFT_last_several = TAFT[TAFT["order_in_cluster"] > 0].reset_index(drop=True)

giveup_GUAT_last_plus_single_miss = _concat_shared(giveup_GUAT_last, one_stop_miss_df)

all_misses = _concat_shared(one_stop_miss_df, GUAT_expanded, persist_TAFT_nonfinal)

all_first_misses = _concat_shared(one_stop_miss_df, GUAT_first, TAFT_first)

# captures not in TAFT last (assuming capture_TAFT_last is a subset of captures)
captures_minus_TAFT_last = compare_events.diff_by(valid_captures_df, capture_TAFT_last, key='stop_id')

# non-captures excluding those flagged as 'all_misses'
non_captures_minus_all_misses = compare_events.diff_by(filtered_no_capture_stops_df, all_misses, key='stop_id')


# ===COMPARE EVENTS===

In [None]:
# ---------- dataset registry (canonical) ----------
# Name -> source table. Every entry is copied below so the normalisation step
# works on independent objects.
_dataset_sources = {
    'captures': valid_captures_df,
    'no_capture': filtered_no_capture_stops_df,
    'persist_nonfinal': both_nonfinal,
    'persist_middle': both_middle,
    'giveup_GUAT_last': giveup_GUAT_last,
    'capture_TAFT_last': capture_TAFT_last,
    'giveup_single_miss': one_stop_miss_df,
    'persist_both_first': persist_both_first,
    'persist_GUAT_nonfinal': persist_GUAT_nonfinal,
    'persist_TAFT_nonfinal': persist_TAFT_nonfinal,
    'giveup_GUAT_last_plus_single_miss': giveup_GUAT_last_plus_single_miss,
    'captures_minus_TAFT_last': captures_minus_TAFT_last,
    'all_misses': all_misses,
    'non_captures_minus_all_misses': non_captures_minus_all_misses,
    'all_first_misses': all_first_misses,
}
datasets_raw = {_name: _frame.copy() for _name, _frame in _dataset_sources.items()}

# Normalize the schema of each dataset, then dedupe within it.
datasets = {
    _name: compare_events.dedupe_within(compare_events.ensure_event_schema(_raw))
    for _name, _raw in datasets_raw.items()
}

# Pairs to compare. Only the first pair carries an explicit 'key'.
_comparison_pairs = [
    ('persist_nonfinal', 'giveup_GUAT_last'),
    ('persist_middle', 'giveup_GUAT_last'),
    ('persist_nonfinal', 'giveup_GUAT_last_plus_single_miss'),
    ('persist_middle', 'giveup_GUAT_last_plus_single_miss'),
    ('giveup_single_miss', 'giveup_GUAT_last'),
    ('giveup_single_miss', 'persist_both_first'),
    ('persist_GUAT_nonfinal', 'persist_TAFT_nonfinal'),
    ('giveup_GUAT_last', 'capture_TAFT_last'),
    ('captures_minus_TAFT_last', 'capture_TAFT_last'),
    ('captures', 'all_misses'),
    ('non_captures_minus_all_misses', 'all_misses'),
    ('non_captures_minus_all_misses', 'all_first_misses'),
]
_comparison_specs = [{'a': 'captures', 'b': 'no_capture', 'key': 'captures_vs_no_capture'}]
_comparison_specs += [{'a': _a, 'b': _b} for _a, _b in _comparison_pairs]
comparisons = compare_events.build_comparisons(_comparison_specs)

compare_events.validate(datasets, comparisons)


# new_seg_info

In [None]:
# Give every stop a reference time (its start) plus the end of the previous
# row's stop and the start of the next row's stop. shift() works on row
# order, so the first prev_time and the last next_time are NaN.
_stop_starts = stops_with_stats['stop_id_start_time']
_stop_ends = stops_with_stats['stop_id_end_time']
stops_with_stats['stop_time'] = _stop_starts
stops_with_stats['prev_time'] = _stop_ends.shift(1)
stops_with_stats['next_time'] = _stop_starts.shift(-1)

_window_kwargs = dict(pre_s=0.2, post_s=1.0, min_pre_bins=1, min_post_bins=20, bin_dt=0.04)
new_seg_info = binning_for_glm.pick_stop_window(stops_with_stats, **_window_kwargs)

# Make sure the capture table carries stop_id before it is used as a merge key.
if 'stop_id' not in pn.closest_stop_to_capture_df.columns:
    pn.closest_stop_to_capture_df = get_stops_utils.add_stop_id_to_closest_stop_to_capture_df(
        pn.closest_stop_to_capture_df,
        pn.monkey_information,
    )

# Flag captured stops: 1 for stop_ids present in the capture table, 0 otherwise.
if 'captured' not in new_seg_info.columns:
    pn.closest_stop_to_capture_df['captured'] = 1
    _captured_ids = pn.closest_stop_to_capture_df[['stop_id', 'captured']].drop_duplicates()
    new_seg_info = new_seg_info.merge(_captured_ids, on='stop_id', how='left')
    new_seg_info['captured'] = new_seg_info['captured'].fillna(0)
    

# exp

In [None]:
binned_feats

In [None]:
reload(cluster_design)

In [None]:
pn.monkey_information.columns

In [None]:
stops_with_stats

# Get bin info

In [None]:
# 1) Build bins from your stop windows (gaps allowed)
# bin_dt matches the 0.04 passed to pick_stop_window when new_seg_info was built.
bins_2d, meta = binning_for_glm.stops_windows_to_bins2d(new_seg_info, bin_dt=0.04, only_ok=False)

# 2) Get overlap assignments once
# The returned arrays are reused for every feature below so all features share
# the same sample-to-bin mapping.
sample_idx, bin_idx_array, dt_array, n_bins = binning_for_glm.build_bin_assignments(
    pn.monkey_information['time'].to_numpy(),
    bins_2d,
)

# 3) Subselect raw samples once
# sample_idx is used positionally (iloc); .copy() detaches the subset from pn.monkey_information.
monkey_information_sub = pn.monkey_information.iloc[sample_idx].copy()

# 3a) One pass to get exposure and used_bins
# The binned values themselves (_dummy) are discarded; only exposure and used_bins are kept.
_dummy, exposure, used_bins = binning_for_glm.bin_timeseries_weighted(
    monkey_information_sub['time'].to_numpy(),  # any column of same length works
    dt_array, bin_idx_array, how='mean'
)

# 3b) Aggregate features with the SAME assignments
def agg_feat(col):
    """Bin one column of ``monkey_information_sub`` with the shared assignments.

    Uses the notebook-level ``dt_array`` / ``bin_idx_array`` and checks that the
    exposure and used-bin arrays returned for this column agree with the
    notebook-level ``exposure`` / ``used_bins``.

    Returns the binned values produced by ``bin_timeseries_weighted(how='mean')``.
    """
    binned, col_exposure, col_used_bins = binning_for_glm.bin_timeseries_weighted(
        monkey_information_sub[col].to_numpy(), dt_array, bin_idx_array, how='mean'
    )
    # Defensive checks: every feature must land on the same bins and exposure.
    exposure_matches = np.shares_memory(col_exposure, exposure) or np.allclose(col_exposure, exposure)
    assert exposure_matches
    assert np.array_equal(col_used_bins, used_bins)
    return binned

# Per-bin behavioural features, one row per entry of used_bins.
binned_feats = pd.DataFrame({
    'accel':     agg_feat('accel'),
    'speed':     agg_feat('speed'),
    'ang_speed': agg_feat('ang_speed'),
})

# Clean NaNs (optional: choose your policy)
# Infinities are first turned into NaN, then every NaN is replaced by 0.0.
binned_feats = binned_feats.replace([np.inf, -np.inf], np.nan).fillna(0.0)

# 3c) Keep bins with exposure > 0
# pos holds the bin ids that survive the mask; it is used below to select
# matching rows from meta and from the spike-count matrix.
mask_used = exposure > 0
pos = used_bins[mask_used]
binned_feats = binned_feats.iloc[mask_used].reset_index(drop=True)

# Index meta by its 'bin' column so rows can be picked by the bin ids in pos.
meta_by_bin = meta.set_index('bin').sort_index()
meta_used   = meta_by_bin.loc[pos].reset_index()   # rows now match binned_feats


# 4) Bin spikes per cluster across ALL bins, then slice by pos
spike_counts, cluster_ids = binning_for_glm.bin_spikes_by_cluster(
    pn.spikes_df, bins_2d, time_col='time', cluster_col='cluster'
)

# Sanity checks
# NOTE(review): pos.max() raises on an empty pos (no bin with exposure > 0).
assert pos.size == binned_feats.shape[0]
assert spike_counts.shape[0] >= (pos.max() + 1)

binned_spikes = pd.DataFrame(
    spike_counts[pos, :],        # slice rows to align with pos
    columns=cluster_ids,         # cluster IDs as column labels
).reset_index(drop=True)

# Build the stop-aware design block (same helper we wrote earlier)
# Note: the full `meta` is passed together with `pos` (not meta_used).
X_stop_df = stop_design.build_stop_design_from_meta(
    meta=meta,
    pos=pos,
    new_seg_info=new_seg_info,
    speed_used=binned_feats['speed'].values,
    history_mode='gated',  # or 'single' / 'sumdiff'
    include_columns=(
        'basis', 'history_gated', 'prepost', 'prepost*speed',
        'captured', 'basis*captured', #'prepost*captured',
        'time_since_prev_stop_post', 'time_to_next_stop_pre',
    )
)

# Cluster-level features, computed from the per-bin (stop_id, rel_center)
# pairs and the per-stop stats table.
cluster_df = cluster_design.build_cluster_features_workflow(
    meta_used[['stop_id', 'rel_center']], stops_with_stats,
    rel_time_col='rel_center',
    winsor_p=0.5,
    use_midbin_progress=True,
    zscore_progress=False,   # set True if you want progress in SD units
    zscore_rel_time=True
)
# Subset of cluster_df columns that is copied into the design matrix below.
cluster_feats = [
        'is_clustered',
        'stop_is_first_in_cluster', 
        #'stop_is_last_in_cluster',
        'prev_gap_s_z',
        'next_gap_s_z',
        'cluster_duration_s_z',
        'cluster_progress_c', 'cluster_progress_c2',
        #'log_n_stops_in_cluster_z',      # optional
        'cluster_rel_time_s_z',          # optional (bin-level)
    ]


# Append stop-design columns that binned_feats does not already have. Values
# are copied positionally (to_numpy), so row order must already agree.
cols_to_add_from_stop_design = [c for c in X_stop_df.columns if c not in binned_feats.columns]
binned_feats.loc[:, cols_to_add_from_stop_design] = X_stop_df[cols_to_add_from_stop_design].to_numpy()  # equivalent to .values

# Same for the selected cluster features.
cols_to_add_from_cluster_design = [c for c in cluster_feats if c not in binned_feats.columns]
binned_feats.loc[:, cols_to_add_from_cluster_design] = cluster_df[cols_to_add_from_cluster_design].to_numpy()  # equivalent to .values

# Log exposure of the kept bins, clipped below at 1e-12 so the log stays finite.
offset_log = np.log(np.clip(exposure[mask_used], 1e-12, None))

# --- usage ---
# Z-score the columns chosen by selective_zscore, then add a constant column
# (has_constant='add' requests the column even if a constant already exists).
binned_feats_sc, scaled_cols = binning_for_glm.selective_zscore(binned_feats)
binned_feats_sc = sm.add_constant(binned_feats_sc, has_constant='add')
print('Scaled columns:', scaled_cols)



In [None]:
binned_feats_sc[['cluster_progress_c', 'cluster_progress_c2']].describe()

# check df

In [None]:
binned_feats_sc.describe()

In [None]:
# Diagnostics on binned_feats_sc, the scaled design matrix (with constant) built above.
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.design_checks import (
    check_near_constant, pairwise_high_corr, find_duplicate_columns,
    condition_number, rank_deficiency, compute_vif, suggest_columns_to_drop
)

# 0) quick sanity: NaN/inf
# True when at least one entry of the design matrix is not finite.
bad_any = ~np.isfinite(binned_feats_sc.to_numpy(dtype=float)).all()
print('NaN/inf present?', bad_any)

# 1) near-constant columns
near_const_cols, variances = check_near_constant(binned_feats_sc, tol=1e-12)
print('Near-constant:', near_const_cols)

# 2) duplicates
dups = find_duplicate_columns(binned_feats_sc)
print('Duplicate column groups:', dups)

# 3) correlation spikes
hits, corr = pairwise_high_corr(binned_feats_sc, thresh=0.98)
print('High-corr pairs (|r| >= 0.98):', hits)

# 4) condition number
kappa, svals = condition_number(binned_feats_sc)
print('Condition number:', kappa)

# 5) rank deficiency
rank, p, is_deficient = rank_deficiency(binned_feats_sc)
print(f'rank={rank} of {p} columns; deficient? {is_deficient}')

# 6) VIF
vif_report = compute_vif(binned_feats_sc)
print(vif_report.head(10))

# 7) Suggested columns to drop
to_drop = suggest_columns_to_drop(binned_feats_sc, corr_thresh=0.98, vif_thresh=30.0)
print('Suggested drops (order matters):', to_drop)

# Optional: apply the drop list
# errors='ignore' skips names that are not columns of binned_feats_sc.
# NOTE(review): X_pruned is not used by the later GLM cells, which fit on binned_feats_sc.
X_pruned = binned_feats_sc.drop(columns=to_drop, errors='ignore')


# GLM

## regular

In [None]:
reload(stop_glm_fit)

In [None]:
# Inputs for the per-unit GLM report:
#   df_X - design matrix (scaled features plus constant)
#   df_Y - binned spike counts, one column per unit
# The log-exposure offset is passed separately as `offset_log`.

# features = ['ang_speed']
# df_X = binned_feats[features + ['offset_log']]

df_X = binned_feats_sc
df_Y = binned_spikes

_report_options = dict(
    offset_log=offset_log,
    feature_names=None,       # infer automatically from df_X
    cluster_ids=None,         # use df_Y.columns
    alpha=0.05,
    delta_for_rr=1.0,
    forest_term='ang_speed',  # term shown in the forest plot (set explicitly here)
    forest_top_n=30,
    cov_type='HC1',
    show_plots=True,
    save_dir=None,
)
report = stop_glm_fit.glm_mini_report(df_X, df_Y, **_report_options)


In [None]:
# Predicted vs. observed counts for one unit. cluster_id is used both as a key
# into report['results'] and as a column label of binned_spikes.
# NOTE(review): 5 is hard-coded — confirm it is a valid unit label for this session.
cluster_id = 5
cv_stop_glm.plot_pred_vs_obs(report['results'][cluster_id], df_X, binned_spikes[cluster_id], offset_log)

In [None]:
coefs_df    = report['coefs_df']
metrics_df  = report['metrics_df']
pop_tests   = report['population_tests_df']
metrics_df

## control for FDR

In [None]:
import numpy as np
import pandas as pd
from statsmodels.stats.multitest import multipletests

def add_fdr_and_rate_ratio(summary_df, p_col='p_wilcoxon', beta_col='beta_median'):
    """Add Benjamini-Hochberg q-values and a median rate ratio to a summary table.

    Parameters
    ----------
    summary_df : pandas.DataFrame
        Table with one row per term; it is not modified.
    p_col : str
        Column holding the p-values. Rows whose p-value is NaN are excluded
        from the correction and get NaN in both output columns.
    beta_col : str
        Column holding the coefficient that is exponentiated into
        ``rate_ratio_median``.

    Returns
    -------
    pandas.DataFrame
        Copy of ``summary_df`` with ``q_wilcoxon``, ``sig_wilcoxon`` and
        ``rate_ratio_median`` added.
    """
    out = summary_df.copy()
    ok = out[p_col].notna().to_numpy()
    if ok.any():
        pvals = out.loc[ok, p_col].to_numpy()
        rej, qvals, *_ = multipletests(pvals, method='fdr_bh')
        out.loc[ok, 'q_wilcoxon'] = qvals
        out.loc[ok, 'sig_wilcoxon'] = rej
    else:
        # Nothing to correct: do not hand an empty array to multipletests;
        # still emit both columns so the output schema is stable.
        out['q_wilcoxon'] = np.nan
        out['sig_wilcoxon'] = np.nan
    # interpretable magnitude: median rate ratio
    out['rate_ratio_median'] = np.exp(out[beta_col].astype(float))
    return out


In [None]:
add_fdr_and_rate_ratio(pop_tests)

## CV

In [None]:
reload(cv_stop_glm)

In [None]:
# One group label per bin row (the bin's stop_id), handed to the CV scorer
# together with the design matrix, spike counts and log-exposure offset.
# Presumably folds are formed by group so a stop's bins stay together — verify in cv_stop_glm.
groups = meta_used['stop_id']  # or your session/stop-window IDs
scores = cv_stop_glm.cv_score_per_cluster(binned_feats_sc, binned_spikes, offset_log, groups, n_splits=5)
cv_stop_glm.plot_cv_scores(scores)
scores

# Inspect data

In [None]:
new_seg_info

In [None]:
seg_sub = new_seg_info[new_seg_info['captured'] > 0]
sns.histplot(new_seg_info['n_pre_bins'])
sns.histplot(seg_sub['n_pre_bins'])

In [None]:
sns.histplot(new_seg_info['n_post_bins'])
sns.histplot(seg_sub['n_post_bins'])

# Lagged design

## discrete lags

In [None]:
# choose window and binning
# NOTE(review): bin_dt here is 0.02, while binned_feats / binned_spikes were
# built with bin_dt=0.04. If make_integer_lags converts seconds to bins using
# this value, the lags will span twice the intended time window — confirm.
lag_min_s, lag_max_s, bin_dt = -0.30, 0.40, 0.02
lags = lagged_design.make_integer_lags(lag_min_s, lag_max_s, bin_dt)

# NOTE(review): 'stop_id', 'time' and 'session_id' are not among the columns
# this notebook adds to binned_feats (stop_id lives in meta_used) — confirm
# they are present before running this cell.
df_X_design, df_Y_aligned = lagged_design.build_lagged_design_by_group(
    df_X=binned_feats,
    df_Y=binned_spikes,
    group_col='stop_id',
    predictors=['speed', 'accel', 'ang_speed'],
    offset_log=offset_log,
    order_col='time',             # optional: if you have a time column
    lags_bins=lags,
    basis_df=None,
    keep_cols=['session_id']      # optional passthrough
)

# Fit one Poisson GLM per unit on the lagged design.
# NOTE(review): the un-lagged offset_log is reused here; if the lagged design
# has fewer rows than the input, its length will no longer match — confirm.
results, coefs_df, metrics_df = stop_glm_fit.fit_poisson_glm_per_cluster(
    df_X_design, df_Y_aligned,
    offset_log=offset_log,
    cov_type='HC1'
)


## raised-cosine basis

In [None]:
# Raised-cosine lag basis over the same window as the discrete-lag cell.
# NOTE(review): bin_dt=0.02 differs from the 0.04 bin width used to build the
# binned data — confirm which is intended.
basis = lagged_design.make_raised_cosine_basis(
    n_basis=6,
    lag_min_s=-0.30,
    lag_max_s=0.40,
    bin_dt=0.02
)

# df_X / df_Y are the GLM inputs assigned in the "regular" GLM cell
# (binned_feats_sc and binned_spikes).
# NOTE(review): 'stop_id' and 'time' are not visibly added to df_X in this notebook — confirm.
df_X_basis, df_Y_aligned = lagged_design.build_lagged_design_by_group(
    df_X=df_X,
    df_Y=df_Y,
    group_col='stop_id',
    predictors=['speed', 'accel'],
    offset_log=offset_log,
    order_col='time',
    lags_bins=None,
    basis_df=basis,
)

# Fit one Poisson GLM per unit on the basis-projected design. This overwrites
# results / coefs_df / metrics_df from earlier cells.
results, coefs_df, metrics_df = stop_glm_fit.fit_poisson_glm_per_cluster(
    df_X_basis, df_Y_aligned,
    offset_log=offset_log,
    cov_type='HC1'
)


# plot_spaghetti_per_stop

## run func

In [None]:
# Attach per-bin metadata to a copy of the spike counts. Column assignment
# aligns on the row index; both frames were built with a reset index.
cols = ['stop_id', 'rel_center', 't_left', 't_right']
binned_spikes2 = binned_spikes.copy()
binned_spikes2[cols] = meta_used[cols]

In [None]:
binned_spikes

In [None]:
# choose a unit column by name or int (e.g., 3)
# unit_col = 3  # or '3' if your columns are strings
# One figure per unit. After the loop, df_rate holds the table of the LAST
# unit, which the "run independent" cell below reads.
for unit_col in binned_spikes.columns:
    df_rate = plot_spikes.make_rate_df_from_binned(binned_spikes2, unit_col)

    # plot (with gentle smoothing and pre-stop baseline subtraction)
    # Baseline subtraction is currently switched off (baseline_window=None).
    fig, ax, n = plot_spikes.plot_spaghetti_per_stop(
        df_rate,
        smooth_sigma_s=0.08,          # ~80 ms sigma (auto-converted to bins)
        # baseline_window=(-0.5, -0.1), # subtract mean pre-stop activity
        baseline_window=None,
        max_stops=None,               # or an int to limit how many lines
        median_label='median (all stops)',
        title=f'Unit {unit_col}: rate per stop'
    )
    plt.show()
    print(f'Plotted {n} stops.')
    


## run independent

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --- params you can tweak ---
# Column names refer to df_rate, which is left over from the per-unit loop cell.
stop_col = 'stop_id'
time_col = 'rel_center'
rate_col = 'rate_hz'
# If smooth_sigma_bins is None and smooth_sigma_s is set, the code further down
# overwrites smooth_sigma_bins with smooth_sigma_s divided by the bin width.
smooth_sigma_bins = None          # e.g., 1.0 for ~1-bin Gaussian smoothing
smooth_sigma_s = 0.08             # or set in seconds and leave _bins=None
bin_width_s_hint = None           # set if you know your bin width (e.g., 0.04)
baseline_window = (-0.5, -0.1)    # set to None to disable baseline subtract
max_stops = None                  # e.g., 200 to limit lines
alpha = 0.3                       # line transparency for the per-stop traces
lw = 1.2                          # line width for the per-stop traces
show_median = True
median_lw = 2.2
median_label = 'median across stops'
title = 'Firing rate per stop (one line per stop)'
xlabel = 'Time from stop (s)'
ylabel = 'Rate (Hz)'

# --- helper: gaussian smoothing in bins ---
def _gaussian_smooth_1d(y, sigma_bins):
    if sigma_bins is None or sigma_bins <= 0:
        return y
    radius = max(1, int(np.ceil(3 * sigma_bins)))
    x = np.arange(-radius, radius + 1)
    w = np.exp(-(x**2) / (2 * sigma_bins**2))
    w /= w.sum()
    return np.convolve(y, w, mode='same')

# --- infer bin width if user gave smoothing in seconds ---
# Without a hint, the bin width is the median positive within-stop time step,
# falling back to 0.04 when no step is available.
if smooth_sigma_s is not None and smooth_sigma_bins is None:
    if bin_width_s_hint is not None:
        bw = bin_width_s_hint
    else:
        tmp = df_rate.sort_values([stop_col, time_col])
        diffs = tmp.groupby(stop_col)[time_col].diff().dropna().to_numpy()
        diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
        bw = np.median(diffs) if diffs.size else 0.04
    # max(..., 1e-9) guards against division by a zero bin width.
    smooth_sigma_bins = smooth_sigma_s / max(bw, 1e-9)

# --- optionally downselect stops ---
# Keeps the first max_stops stop ids in order of first appearance.
stops = df_rate[stop_col].unique().tolist()
if max_stops is not None and len(stops) > max_stops:
    stops = stops[:max_stops]
g = df_rate[df_rate[stop_col].isin(stops)].copy()

# --- plot each stop ---
fig, ax = plt.subplots(figsize=(8, 5))
lines_plotted = 0

for sid, df_s in g.groupby(stop_col, sort=True):
    y = df_s.sort_values(time_col)
    xvals = y[time_col].to_numpy()
    yvals = y[rate_col].to_numpy()

    # baseline subtract per stop
    # The baseline is the mean rate in [t0, t1); 0.0 if the stop has no sample there.
    if baseline_window is not None:
        t0, t1 = baseline_window
        m = (xvals >= t0) & (xvals < t1)
        base = yvals[m].mean() if m.any() else 0.0
        yvals = yvals - base

    # optional smoothing (in bins)
    if smooth_sigma_bins is not None and smooth_sigma_bins > 0:
        yvals = _gaussian_smooth_1d(yvals, smooth_sigma_bins)

    # safety: enforce equal length
    # Guards against a smoother that returns a different number of samples than it was given.
    n = min(len(xvals), len(yvals))
    if n == 0:
        continue
    ax.plot(xvals[:n], yvals[:n], alpha=alpha, lw=lw)
    lines_plotted += 1

# --- median across stops at each time (pooled by exact time stamps) ---
# Unlike the per-stop traces, the median is taken on the raw rates first and
# then baseline-subtracted and smoothed.
if show_median and not g.empty:
    med = g.groupby(time_col)[rate_col].median().reset_index().sort_values(time_col)
    xmed = med[time_col].to_numpy()
    ymed = med[rate_col].to_numpy()
    if baseline_window is not None:
        t0, t1 = baseline_window
        m = (xmed >= t0) & (xmed < t1)
        base = ymed[m].mean() if m.any() else 0.0
        ymed = ymed - base
    if smooth_sigma_bins is not None and smooth_sigma_bins > 0:
        ymed = _gaussian_smooth_1d(ymed, smooth_sigma_bins)
    ax.plot(xmed, ymed, lw=median_lw, label=median_label)
    ax.legend(frameon=False, loc='best')

# Dashed vertical line at time 0 on the x-axis.
ax.axvline(0.0, ls='--', lw=1.0)
ax.set(title=title, xlabel=xlabel, ylabel=ylabel)
ax.grid(True, alpha=0.25)
plt.tight_layout()

print(f'Plotted {lines_plotted} stops.')


# Appendix

## Debug ff dataframe

In [None]:
pn.make_or_retrieve_ff_dataframe()

In [None]:
pn.ff_dataframe

In [None]:
pn.ff_dataframe.shape

In [None]:
# Location of the cached firefly table for the session loaded in `pn`.
h5_file_pathway = os.path.join(pn.processed_data_folder_path, 'ff_dataframe.h5')

# Debug override: read a fixed session's file instead of the path computed above.
h5_file_pathway = 'all_monkey_data/processed_data/monkey_Schro/data_0413/ff_dataframe.h5'

ff_dataframe = pd.read_hdf(h5_file_pathway, key='ff_dataframe')
print("Retrieved ff_dataframe from", h5_file_pathway)
ff_dataframe

## use concat_new_seg_info

In [None]:
new_seg_info['new_segment'] = np.arange(len(new_seg_info))

In [None]:
# Concatenate monkey_information over the segments described by new_seg_info
# (which needs the 'new_segment' column added in the previous cell).
concat_seg_data = pn_utils.concat_new_seg_info(
    pn.monkey_information, new_seg_info, bin_width=0.04)

# Time elapsed since the start of each row's segment.
concat_seg_data['time_since_start_time'] = concat_seg_data['time'] - concat_seg_data['new_seg_start_time']
# Cap dt at the elapsed time, so a sample cannot carry more time than has passed since the segment began.
concat_seg_data['dt'] = np.minimum(concat_seg_data['time_since_start_time'], concat_seg_data['dt'])
