# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break

from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, ml_methods_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils
from neural_data_analysis.neural_analysis_tools.glm_tools import glm_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event
import sys
import math
import gc
import subprocess
from pathlib import Path
from importlib import reload

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
from numpy import pi

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# --- Notebook display / runtime configuration ---
# Reset matplotlib to its defaults FIRST. rcParams.update(rcParamsDefault)
# overwrites every rcParam, so running it after the animation settings (as
# this cell previously did) silently discarded them.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
plt.rcParams["animation.html"] = "html5"
rc('animation', html='jshtml')  # applied last, so 'jshtml' is the effective animation.html value
matplotlib.rcParams['animation.embed_limit'] = 2**128
# NOTE(review): presumably a workaround for the duplicate-OpenMP-runtime abort; confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# Show floats with 5 decimals in pandas and suppress scientific notation in numpy printing.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
print("done")

%load_ext autoreload
%autoreload 2

# Retrieve data

## get data

In [None]:
# Session to analyze: monkey Bruno, data_0330 (relative path).
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330"

In [None]:
# Alternative session: monkey Schro, data_0416 (relative path).
# NOTE: this assigns the same variable as the Bruno cell; whichever cell ran last wins.
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0416"

In [None]:
# Re-import pn_helper_class so that edits made to that module are picked up.
reload(pn_helper_class)

In [None]:
# Flags forwarded to the data-preparation calls below.
reduce_y_var_lags = False
planning_data_by_point_exists_ok = True
y_data_exists_ok = True

# Build the segment-aligned planning/neural object for the selected session.
pn = pn_aligned_by_seg.PlanningAndNeuralSegmentAligned(raw_data_folder_path=raw_data_folder_path)
pn.prep_data_to_analyze_planning(planning_data_by_point_exists_ok=planning_data_by_point_exists_ok)
# Replace planning_data_by_point with the filtered frame; cols_to_drop is the helper's second return value.
pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
    pn.planning_data_by_point)
pn.get_x_and_y_data_for_modeling(exists_ok=y_data_exists_ok, reduce_y_var_lags=reduce_y_var_lags)

# get planning_data by segment

In [None]:
# Prepare segment-aligned data (segment_duration=2, rebinned_max_x_lag_number=2).
# NOTE(review): units of segment_duration are not visible here - presumably seconds; confirm.
pn.prepare_seg_aligned_data(segment_duration=2, rebinned_max_x_lag_number=2)

In [None]:
# Display the rebinned x variables for inspection.
pn.rebinned_x_var

In [None]:
# Assemble the concatenated data used for regression later.
# NOTE(review): `use_raw_spike_data_instead` is set to False here but is never
# passed on - the call below hard-codes use_raw_spike_data_instead=True.
# Confirm which value is intended.
use_raw_spike_data_instead = False

pn.get_concat_data_for_regression(use_raw_spike_data_instead=True,
                                    use_lagged_raw_spike_data=False,
                                    apply_pca_on_raw_spike_data=False,
                                    num_pca_components=7)
pn.print_data_dimensions()

# get train and test trials

In [None]:
# NOTE(review): `train_trials` is not defined anywhere in the visible cells;
# this raises NameError unless an earlier (unseen) cell defines it.
train_trials

# GLM most recent

## synthetic data

In [None]:
# Synthetic-data demo: simulate multiFF trials, fit the Poisson GLM, print
# in-sample metrics, then run trial-wise cross-validation.
# The functions used here are defined in the "functions" cell, which must be run first.
n_trials=8
trial_len=300
dt=0.01
seed=7
data = simulate_multiff_trials(n_trials=n_trials, trial_len=trial_len, dt=dt, seed=seed)
# Unregularized fit (l2=0.0) with trial fixed effects.
res, design_df, metrics, meta = fit_multiff_glm(
    dt=data["dt"], trial_ids=data["trial_ids"],
    cur_vis=data["cur_vis"], nxt_vis=data["nxt_vis"],
    cur_dist=data["cur_dist"], nxt_dist=data["nxt_dist"],
    cur_angle=data["cur_angle"], nxt_angle=data["nxt_angle"],
    heading=data["heading"], speed=data["speed"], curvature=data["curvature"],
    spike_counts=data["spike_counts"], l2=0.0, use_trial_FE=True, cluster_se=False,
)
print(res.summary())
print("Overall deviance:", metrics["deviance"])  
print("Pseudo-R^2:", metrics["pseudo_R2"])  
print("Per-trial deviance (head):", metrics["per_trial_deviance"].head())
# 5-fold cross-validation in which whole trials are held out.
cv = fit_and_score_cv(design_df, data["spike_counts"], data["dt"], data["trial_ids"], n_splits=5, l2=0.0, cluster_se=False)
print("Trial-wise CV:", cv)


In [None]:
# Plot the fitted stimulus and spike-history kernels from the synthetic-data fit.
plot_fitted_kernels(res, design_df, meta, dt)

In [None]:
# NOTE(review): `debug_all_kernels_flat` is not defined in the visible part of the file;
# presumably it lives further down in the "functions" cell - confirm.
debug_all_kernels_flat(res, design_df, data["spike_counts"], data["trial_ids"], meta, dt)

## functions

In [None]:
"""
Trial-aware Poisson GLM for multiFF neural spike data (statsmodels)
------------------------------------------------------------------
End-to-end, runnable module:
- Raised-cosine bases
- Trial-aware stimulus & spike-history design (no cross-trial leakage)
- Optional trial fixed effects or cluster-robust SEs by trial
- Poisson fitting (statsmodels.GLM), prediction, metrics, trial-wise CV
- MultiFF adapter with visibility/distance/angle/heading features
- Stable simulators using capped Poisson intensities
- Demos: run `python this_file.py` or call `run_multiff_demo()`
"""
from __future__ import annotations

import numpy as np
import pandas as pd
from typing import Dict, Iterable, List, Optional, Tuple
from scipy import signal
from scipy.linalg import toeplitz
import statsmodels.api as sm

# =====================
# Helpers & bases
# =====================

def _unique_trials(trial_ids: np.ndarray) -> np.ndarray:
    """Return the sorted, de-duplicated trial labels found in `trial_ids`."""
    labels = np.asarray(trial_ids)
    return np.unique(labels)


def raised_cosine_basis(n_basis: int, t_max: float, dt: float, *, t_min: float = 0.0,
                        log_spaced: bool = True, eps: float = 1e-3) -> Tuple[np.ndarray, np.ndarray]:
    """Build a causal cosine-bump basis whose centers tile [t_min, t_max].

    Centers are evenly spaced on the (optionally log-warped) time axis and each
    bump is the positive part of a cosine. Every column is zeroed for lags
    below `t_min` and, when its area is positive, scaled so sum * dt = 1.

    Parameters
    ----------
    n_basis : number of basis functions (columns).
    t_max : largest lag covered; lags run from 0 to t_max in steps of dt.
    dt : bin width.
    t_min : lags below this value get zero weight.
    log_spaced : if True, space the centers on log(lag + eps); otherwise linearly.
    eps : offset inside the log to keep it finite at lag 0.

    Returns
    -------
    lags : (L,) array of lag times.
    B : (L, n_basis) basis matrix ((L, 0) when n_basis is 0).
    """
    n_funcs = int(n_basis)
    lags = np.arange(0.0, t_max + 1e-12, dt)

    # Work on a warped time axis so the bumps can be denser at short lags.
    if log_spaced:
        warped_lags = np.log(lags + eps)
        lo, hi = np.log(t_min + eps), np.log(t_max + eps)
    else:
        warped_lags = lags
        lo, hi = t_min, t_max

    centers = np.linspace(lo, hi, n_funcs)
    if n_funcs > 1:
        spacing = centers[1] - centers[0]
    else:
        spacing = hi - lo + 1e-12
    width = spacing * 1.5

    before_support = lags < t_min

    def bump(center):
        phase = (warped_lags - center) / width
        col = np.cos(np.clip(phase, -np.pi, np.pi))
        col[np.abs(phase) > np.pi] = 0.0
        col = np.maximum(col, 0.0)
        col[before_support] = 0.0
        area = col.sum() * dt
        return col / area if area > 0 else col

    columns = [bump(c) for c in centers]
    if not columns:
        return lags, np.zeros((len(lags), 0))
    return lags, np.column_stack(columns)


def safe_poisson_lambda(eta: float | np.ndarray, dt: float, *, max_rate_hz: float = 200.0) -> np.ndarray:
    """Turn a log firing rate `eta` (per second) into an expected count per bin.

    `eta` is clipped to [log(1e-6), log(max_rate_hz)] before exponentiating, so
    the rate stays within [1e-6, max_rate_hz] Hz; the result is rate * dt.
    """
    lower, upper = np.log(1e-6), np.log(max_rate_hz)
    bounded_eta = np.clip(eta, lower, upper)
    return np.exp(bounded_eta) * dt


def wrap_angle(theta: np.ndarray) -> np.ndarray:
    """Wrap angles (radians) into the half-open interval [-pi, pi)."""
    full_turn = 2 * np.pi
    shifted = theta + np.pi
    return shifted % full_turn - np.pi


def angle_sin_cos(theta: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return (sin, cos) of `theta` after wrapping it with `wrap_angle`."""
    wrapped = wrap_angle(theta)
    sin_part = np.sin(wrapped)
    cos_part = np.cos(wrapped)
    return sin_part, cos_part


def onset_from_mask_trials(mask: np.ndarray, trial_ids: np.ndarray) -> np.ndarray:
    """Flag the first bin of every "on" run of `mask`, separately within each trial.

    The mask is binarized with `mask > 0`. Each trial is treated as starting
    from an "off" state, so a trial that begins with the mask on gets an onset
    in its first bin. Returns a float array of 0/1 with the same shape as `mask`.
    """
    binary = (mask > 0).astype(int)
    onsets = np.zeros_like(binary, dtype=float)
    for trial in _unique_trials(trial_ids):
        rows = np.nonzero(trial_ids == trial)[0]
        # Prepend a 0 so a run that starts in the trial's first bin is detected.
        padded = np.concatenate(([0], binary[rows]))
        rising = np.diff(padded) == 1
        onsets[rows] = rising.astype(float)
    return onsets


def offset_from_mask_trials(mask: np.ndarray, trial_ids: np.ndarray) -> np.ndarray:
    """Flag the last bin of every "on" run of `mask`, separately within each trial.

    The mask is binarized with `mask > 0`. Each trial is treated as ending in
    an "off" state, so a run that is still on in a trial's final bin gets its
    offset in that bin.

    Fix: the previous version tested `d == 1` on the forward difference, which
    marks the bin *before* each onset rather than the offset; an on->off
    transition is `d == -1`.

    Parameters
    ----------
    mask : array-like, per-bin indicator (values > 0 count as "on").
    trial_ids : array-like, per-bin trial label, same length as `mask`.

    Returns
    -------
    Float array of 0/1 with the same shape as `mask`.
    """
    mask = (np.asarray(mask) > 0).astype(int)
    trial_ids = np.asarray(trial_ids)
    off = np.zeros_like(mask, dtype=float)
    for tr in np.unique(trial_ids):
        idx = np.where(trial_ids == tr)[0]
        m = mask[idx]
        # Append a 0 so a run lasting until the trial's final bin still ends.
        d = np.diff(np.r_[m, 0])
        off[idx] = (d == -1).astype(float)
    return off

# =====================
# Trial-aware design builders
# =====================

def lagged_design_from_signal_trials(x: np.ndarray, basis: np.ndarray, trial_ids: np.ndarray) -> np.ndarray:
    """Convolve signal `x` with every basis column, one trial at a time.

    Each trial's segment is convolved causally (full convolution truncated to
    the trial length), so no trial's signal influences another trial's rows.
    Returns a (len(x), n_basis) matrix.
    """
    n_bins = len(x)
    _, n_basis = basis.shape
    design = np.zeros((n_bins, n_basis))
    for trial in _unique_trials(trial_ids):
        rows = np.where(trial_ids == trial)[0]
        segment = x[rows]
        n_rows = len(rows)
        for col, kernel in enumerate(basis.T):
            design[rows, col] = signal.convolve(segment, kernel, mode="full")[:n_rows]
    return design


def spike_history_design_with_trials(y_counts: np.ndarray, basis: np.ndarray, trial_ids: np.ndarray) -> np.ndarray:
    """Project each trial's strictly-past spike counts onto the history basis.

    For every trial the counts are shifted by one bin (so bin t only sees
    counts before t), arranged into a lower-triangular Toeplitz lag matrix, and
    multiplied by `basis`. Basis row j therefore weights the count j + 1 bins
    in the past. Returns a (len(y_counts), n_basis) matrix.
    """
    n_bins = len(y_counts)
    n_lags, n_basis = basis.shape
    design = np.zeros((n_bins, n_basis))
    zero_row = np.zeros(n_lags)  # first row of the Toeplitz matrix: nothing from the future
    for trial in _unique_trials(trial_ids):
        rows = np.where(trial_ids == trial)[0]
        counts = y_counts[rows]
        shifted = np.concatenate(([0], counts[:-1]))  # strictly past
        lag_matrix = toeplitz(shifted, zero_row)
        design[rows, :] = lag_matrix @ basis
    return design


def build_glm_design_with_trials(
    dt: float,
    trial_ids: np.ndarray,
    stimulus_dict: Optional[Dict[str, np.ndarray]] = None,
    stimulus_basis_dict: Optional[Dict[str, np.ndarray]] = None,
    spike_counts: Optional[np.ndarray] = None,
    history_basis: Optional[np.ndarray] = None,
    extra_covariates: Optional[Dict[str, np.ndarray]] = None,
    use_trial_FE: bool = True,
) -> Tuple[pd.DataFrame, Optional[np.ndarray]]:
    """Assemble the GLM design matrix with named columns.

    Column order is: stimulus columns (in `stimulus_dict` order), spike-history
    columns, extra covariates, then trial fixed-effect dummies.

    Parameters
    ----------
    dt : bin width (accepted for interface symmetry; not used here).
    trial_ids : per-bin trial label.
    stimulus_dict : name -> per-bin signal. A signal with an entry in
        `stimulus_basis_dict` is expanded into `<name>_rc1..K` lagged columns;
        otherwise it enters as a single raw column.
    stimulus_basis_dict : name -> (L, K) basis matrix.
    spike_counts, history_basis : when both are given, `hist_rc1..K` columns
        are added from the strictly-past counts.
    extra_covariates : name -> per-bin covariate, added as-is.
    use_trial_FE : add `trial_<id>` dummy columns (first trial dropped as the
        reference level).

    Returns
    -------
    design_df : DataFrame with one row per bin; every column is float.
    y : `spike_counts` unchanged (None when not supplied).
    """
    trial_ids = np.asarray(trial_ids)
    T = len(trial_ids)
    cols: List[np.ndarray] = []
    names: List[str] = []

    if stimulus_dict is not None:
        for name, x in stimulus_dict.items():
            if stimulus_basis_dict is not None and name in stimulus_basis_dict:
                B = stimulus_basis_dict[name]
                Xk = lagged_design_from_signal_trials(x, B, trial_ids)
                for k in range(Xk.shape[1]):
                    cols.append(Xk[:, k])
                    names.append(f"{name}_rc{k+1}")
            else:
                cols.append(x)
                names.append(name)

    if spike_counts is not None and history_basis is not None:
        Xh = spike_history_design_with_trials(spike_counts, history_basis, trial_ids)
        for k in range(Xh.shape[1]):
            cols.append(Xh[:, k])
            names.append(f"hist_rc{k+1}")

    if extra_covariates is not None:
        for n, v in extra_covariates.items():
            cols.append(v)
            names.append(n)

    X = np.column_stack(cols) if cols else np.zeros((T, 0))
    design_df = pd.DataFrame(X, columns=names)

    if use_trial_FE:
        # dtype=float: recent pandas returns bool dummies by default, which
        # makes the combined frame mixed-dtype (object once converted to a
        # numpy array) and breaks the downstream GLM fit / matrix products.
        trial_FE = pd.get_dummies(trial_ids, prefix="trial", drop_first=True, dtype=float)
        design_df = pd.concat([design_df, trial_FE], axis=1)

    return design_df, spike_counts

# =====================
# Fitting & metrics
# =====================

def add_intercept(X: np.ndarray) -> np.ndarray:
    """Return `X` with a leading column of ones (one per row of `X`)."""
    ones = np.ones(len(X))
    return np.column_stack((ones, X))



def fit_poisson_glm_trials(
    design_df: pd.DataFrame,
    y: np.ndarray,
    dt: float,
    trial_ids: np.ndarray,
    *,
    add_const: bool = True,
    l2: float = 0.0,
    cluster_se: bool = False,
):
    """Fit a Poisson GLM on a named design (DataFrame) so kernels map back to columns.

    Every bin gets exposure `dt`.

    Parameters
    ----------
    design_df : design matrix with named columns (copied, not modified).
    y : per-bin spike counts.
    dt : bin width used as the exposure of every bin.
    trial_ids : per-bin trial label; only used as the cluster grouping.
    add_const : prepend a 'const' intercept column.
    l2 : if > 0, fit with `fit_regularized` using a pure L2 penalty
        (`L1_wt=0.0`); `cluster_se` is ignored in that case.
    cluster_se : if True (and l2 == 0), request cluster-robust standard errors
        grouped by trial; if False, use the default covariance.

    Returns
    -------
    The statsmodels results object of the chosen fit.
    """
    X_df = design_df.copy()
    if add_const:
        X_df = sm.add_constant(X_df, has_constant='add')  # preserves 'const' and column names
    exposure = np.full_like(y, fill_value=dt, dtype=float)

    model = sm.GLM(y, X_df, family=sm.families.Poisson(), exposure=exposure)
    if l2 > 0:
        return model.fit_regularized(alpha=l2, L1_wt=0.0, maxiter=1000)
    if cluster_se:
        return model.fit(cov_type="cluster", cov_kwds={"groups": trial_ids})
    return model.fit()
        
def predict_mu(result, design_df: pd.DataFrame, dt: float, add_const: bool = True) -> np.ndarray:
    """Predict expected counts per bin from a fitted result, using exposure `dt`.

    `design_df` is copied and, when `add_const` is True, given a 'const' column
    before being passed to `result.predict`. Returns a float array.
    """
    exog = design_df.copy()
    if add_const:
        exog = sm.add_constant(exog, has_constant='add')
    bin_exposure = np.full(len(exog), dt)
    predicted = result.predict(exog, exposure=bin_exposure)
    return np.asarray(predicted, dtype=float)


def poisson_deviance(y: np.ndarray, mu: np.ndarray) -> float:
    """Return the total Poisson deviance 2 * sum(y*log(y/mu) - (y - mu)).

    Bins with y == 0 contribute no log term; a tiny epsilon guards the ratio.
    """
    eps = 1e-12
    counts = np.asarray(y, dtype=float)
    means = np.asarray(mu, dtype=float)
    log_ratio = np.log((counts + eps) / (means + eps))
    log_part = np.where(counts > 0, counts * log_ratio, 0.0)
    residual = counts - means
    return float(2.0 * np.sum(log_part - residual))


def pseudo_R2(y: np.ndarray, mu_full: np.ndarray, mu_null: np.ndarray) -> float:
    """Return 1 - ll_full / ll_null using Poisson log-likelihoods.

    The log-likelihoods omit the log(y!) term: sum(y * log(mu) - mu).
    """
    eps = 1e-12

    def loglik(mu):
        return np.sum(y * np.log(mu + eps) - mu)

    return float(1.0 - loglik(mu_full) / loglik(mu_null))


def per_trial_deviance(y: np.ndarray, mu: np.ndarray, trial_ids: np.ndarray) -> pd.DataFrame:
    """Return the Poisson deviance summed within each trial.

    The result has one row per trial with columns 'trial' and 'trial_deviance'.
    """
    eps = 1e-12
    table = pd.DataFrame({"trial": trial_ids, "y": y, "mu": mu})
    log_part = np.where(
        table["y"] > 0,
        table["y"] * np.log((table["y"] + eps) / (table["mu"] + eps)),
        0.0,
    )
    table["bin_dev"] = log_part - (table["y"] - table["mu"])
    summed = table.groupby("trial", as_index=False)["bin_dev"].sum()
    summed = summed.rename(columns={"bin_dev": "trial_deviance"})
    summed["trial_deviance"] = summed["trial_deviance"] * 2.0
    return summed

# =====================
# CV utilities
# =====================

def trialwise_folds(trial_ids: np.ndarray, n_splits: int, *, shuffle: bool = False, random_state: Optional[int] = None) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Split bins into `n_splits` folds such that whole trials are held out together.

    Returns a list of (train_mask, val_mask) boolean arrays over bins. The
    unique trials are optionally shuffled (seeded by `random_state`) and then
    divided into `n_splits` nearly equal groups.
    """
    rng = np.random.default_rng(random_state)
    trials = _unique_trials(trial_ids)
    if shuffle:
        trials = trials.copy()
        rng.shuffle(trials)

    masks = []
    for chunk in np.array_split(trials, n_splits):
        held_out = list(set(chunk))
        val_mask = np.isin(trial_ids, held_out)
        masks.append((~val_mask, val_mask))
    return masks


def fit_and_score_cv(
    design_df: pd.DataFrame,
    y: np.ndarray,
    dt: float,
    trial_ids: np.ndarray,
    *,
    n_splits: int = 5,
    l2: float = 0.0,
    use_trial_FE: bool = True,
    cluster_se: bool = False,
    random_state: Optional[int] = 0,
) -> pd.DataFrame:
    """Trial-wise cross-validation of the Poisson GLM.

    Trials are shuffled (seeded by `random_state`) and split into `n_splits`
    folds; each fold is fit on the remaining trials and scored on the held-out
    ones. The null model for the pseudo-R2 predicts the training-set mean
    count in every validation bin.

    Note: `use_trial_FE` and `cluster_se` are accepted but not used; fits are
    always run with `cluster_se=False`.

    Returns one row per fold: 'fold', 'val_deviance', 'val_pseudo_R2',
    'n_val_bins'.
    """
    splits = trialwise_folds(trial_ids, n_splits, shuffle=True, random_state=random_state)
    records = []
    for fold_number, (fit_rows, held_out_rows) in enumerate(splits, start=1):
        y_fit = y[fit_rows]
        y_held_out = y[held_out_rows]

        model = fit_poisson_glm_trials(
            design_df.iloc[fit_rows], y_fit, dt, trial_ids[fit_rows],
            add_const=True, l2=l2, cluster_se=False,
        )
        mu_held_out = predict_mu(model, design_df.iloc[held_out_rows], dt)
        mu_baseline = np.full_like(y_held_out, y_fit.mean(), dtype=float)

        records.append({
            "fold": fold_number,
            "val_deviance": poisson_deviance(y_held_out, mu_held_out),
            "val_pseudo_R2": pseudo_R2(y_held_out, mu_held_out, mu_baseline),
            "n_val_bins": int(held_out_rows.sum()),
        })
    return pd.DataFrame(records)

# =====================
# MultiFF adapter (features & end-to-end)
# =====================

def build_multiff_design(
    *,
    dt: float,
    trial_ids: np.ndarray,
    cur_vis: np.ndarray,
    nxt_vis: np.ndarray,
    cur_dist: np.ndarray,
    nxt_dist: np.ndarray,
    cur_angle: np.ndarray,
    nxt_angle: np.ndarray,
    heading: Optional[np.ndarray] = None,
    speed: Optional[np.ndarray] = None,
    curvature: Optional[np.ndarray] = None,
    spike_counts: Optional[np.ndarray] = None,
    use_trial_FE: bool = True,
) -> Tuple[pd.DataFrame, Optional[np.ndarray], Dict[str, np.ndarray]]:
    """Build the multiFF GLM design from per-bin behavioral signals.

    Features (each expanded with a raised-cosine basis):
    - visibility onsets of the current / next target (6 bases over 0.60 s),
    - distance and sin/cos of angle to each target, zeroed while that target
      is not visible (5 bases over 0.30 s),
    - if `heading` is given, cos(heading - target angle), also zeroed while
      the target is not visible (5 bases over 0.30 s).
    `speed` and `curvature`, when given, enter as raw (un-lagged) covariates.
    Spike-history columns (5 bases over 0.20 s, starting at lag dt) are added
    by `build_glm_design_with_trials` only when `spike_counts` is provided.

    Returns
    -------
    design_df : the design matrix with named columns.
    y : `spike_counts` as returned by `build_glm_design_with_trials`.
    meta : dict holding the three basis matrices under 'B_event', 'B_short'
        and 'B_hist'.
    """
    T = len(trial_ids)  # not used below

    # Bases
    _, B_event = raised_cosine_basis(n_basis=6, t_max=0.60, dt=dt, t_min=0.0, log_spaced=True)
    _, B_short = raised_cosine_basis(n_basis=5, t_max=0.30, dt=dt, t_min=0.0, log_spaced=True)
    _, B_hist  = raised_cosine_basis(n_basis=5, t_max=0.20, dt=dt, t_min=dt,  log_spaced=True)

    # Onsets and gated features
    cur_on = onset_from_mask_trials(cur_vis, trial_ids)
    nxt_on = onset_from_mask_trials(nxt_vis, trial_ids)

    # Distances only count while the corresponding target is visible.
    cur_dist_g = cur_dist * (cur_vis > 0)
    nxt_dist_g = nxt_dist * (nxt_vis > 0)

    # Angles are encoded as sin/cos and gated by visibility in place
    # (the arrays are fresh outputs of angle_sin_cos, not the caller's inputs).
    cur_sin, cur_cos = angle_sin_cos(cur_angle)
    nxt_sin, nxt_cos = angle_sin_cos(nxt_angle)
    cur_sin *= (cur_vis > 0); cur_cos *= (cur_vis > 0)
    nxt_sin *= (nxt_vis > 0); nxt_cos *= (nxt_vis > 0)

    # Insertion order here fixes the column order of the design matrix.
    stimulus_dict: Dict[str, np.ndarray] = {
        "cur_on": cur_on,
        "nxt_on": nxt_on,
        "cur_dist": cur_dist_g,
        "nxt_dist": nxt_dist_g,
        "cur_angle_sin": cur_sin,
        "cur_angle_cos": cur_cos,
        "nxt_angle_sin": nxt_sin,
        "nxt_angle_cos": nxt_cos,
    }

    extra_covariates: Dict[str, np.ndarray] = {}
    if heading is not None:
        # Alignment between heading and target direction, gated by visibility.
        cur_align = np.cos(wrap_angle(heading - cur_angle)) * (cur_vis > 0)
        nxt_align = np.cos(wrap_angle(heading - nxt_angle)) * (nxt_vis > 0)
        stimulus_dict["cur_align"] = cur_align
        stimulus_dict["nxt_align"] = nxt_align

    if speed is not None:   extra_covariates["speed"] = speed
    if curvature is not None: extra_covariates["curvature"] = curvature

    # Which basis expands which stimulus; names absent here would enter un-lagged.
    stimulus_basis_dict: Dict[str, np.ndarray] = {
        "cur_on": B_event, "nxt_on": B_event,
        "cur_dist": B_short, "nxt_dist": B_short,
        "cur_angle_sin": B_short, "cur_angle_cos": B_short,
        "nxt_angle_sin": B_short, "nxt_angle_cos": B_short,
        **({"cur_align": B_short, "nxt_align": B_short} if heading is not None else {}),
    }

    design_df, y = build_glm_design_with_trials(
        dt=dt,
        trial_ids=trial_ids,
        stimulus_dict=stimulus_dict,
        stimulus_basis_dict=stimulus_basis_dict,
        spike_counts=spike_counts,
        history_basis=B_hist,
        extra_covariates=extra_covariates,
        use_trial_FE=use_trial_FE,
    )

    meta = {"B_event": B_event, "B_short": B_short, "B_hist": B_hist}
    return design_df, y, meta


def fit_multiff_glm(
    *,
    dt: float,
    trial_ids: np.ndarray,
    cur_vis: np.ndarray,
    nxt_vis: np.ndarray,
    cur_dist: np.ndarray,
    nxt_dist: np.ndarray,
    cur_angle: np.ndarray,
    nxt_angle: np.ndarray,
    heading: Optional[np.ndarray] = None,
    speed: Optional[np.ndarray] = None,
    curvature: Optional[np.ndarray] = None,
    spike_counts: np.ndarray,
    l2: float = 0.0,
    use_trial_FE: bool = True,
    cluster_se: bool = False,
):
    """Build the multiFF design, fit the Poisson GLM and compute in-sample metrics.

    Fix: `cluster_se` is now forwarded to the fit; it was previously accepted
    but overridden by a hard-coded `cluster_se=False`. The default (False)
    behaves exactly as before.

    Parameters
    ----------
    dt, trial_ids, cur_vis ... curvature, spike_counts, use_trial_FE :
        forwarded to `build_multiff_design`.
    l2 : L2 penalty strength passed to the fit (0.0 for an unpenalized fit).
    cluster_se : request cluster-robust SEs grouped by trial (only effective
        when l2 == 0).

    Returns
    -------
    res : fitted statsmodels result.
    design_df : the design matrix used for the fit.
    metrics : dict with 'deviance', 'pseudo_R2' (null model = mean count per
        bin) and 'per_trial_deviance' (DataFrame).
    meta : basis matrices returned by `build_multiff_design`.
    """
    design_df, y, meta = build_multiff_design(
        dt=dt, trial_ids=trial_ids,
        cur_vis=cur_vis, nxt_vis=nxt_vis,
        cur_dist=cur_dist, nxt_dist=nxt_dist,
        cur_angle=cur_angle, nxt_angle=nxt_angle,
        heading=heading, speed=speed, curvature=curvature,
        spike_counts=spike_counts, use_trial_FE=use_trial_FE,
    )
    res = fit_poisson_glm_trials(design_df, y, dt, trial_ids, add_const=True, l2=l2, cluster_se=cluster_se)
    mu = predict_mu(res, design_df, dt)
    mu_null = np.full_like(y, y.mean(), dtype=float)
    metrics = {
        "deviance": poisson_deviance(y, mu),
        "pseudo_R2": pseudo_R2(y, mu, mu_null),
        "per_trial_deviance": per_trial_deviance(y, mu, trial_ids),
    }
    return res, design_df, metrics, meta

# =====================
# Stable simulators
# =====================

def simulate_spikes_with_trials(
    *, n_trials: int = 20, trial_len: int = 300, dt: float = 0.01, seed: int = 1,
    max_rate_hz: float = 200.0
) -> Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray], Dict[str, np.ndarray], np.ndarray]:
    """Simulate spike counts driven by an event train, a speed signal and spike history.

    Fixes relative to the previous version:
    - spikes are drawn from the seeded generator (`rng.poisson`) instead of the
      global `np.random` state, so `seed` now reproduces the whole data set;
    - the history kernel is applied un-reversed: `past[j]` holds the count
      j + 1 bins ago and is weighted by `h_hist[j]`, the same convention used
      by `spike_history_design_with_trials`. The old `h_hist[::-1]` gave the
      most recent spike the longest-lag weight.

    Parameters
    ----------
    n_trials, trial_len : number of trials and bins per trial.
    dt : bin width.
    seed : seed of the random generator used for every random draw.
    max_rate_hz : cap applied to the firing rate before sampling.

    Returns
    -------
    trial_ids, y, stimulus_dict, stimulus_basis_dict, B_hist
    """
    rng = np.random.default_rng(seed)
    T = n_trials * trial_len
    trial_ids = np.repeat(np.arange(n_trials), trial_len)

    _, B_stim = raised_cosine_basis(n_basis=5, t_max=0.300, dt=dt, t_min=0.0, log_spaced=True)
    _, B_hist = raised_cosine_basis(n_basis=4, t_max=0.200, dt=dt, t_min=dt, log_spaced=True)

    # Sparse random events and an AR(1)-filtered, z-scored speed signal, generated per trial.
    event = np.concatenate([(rng.random(trial_len) < 0.02).astype(float) for _ in range(n_trials)])
    speed = np.concatenate([signal.lfilter([1.0], [1.0, -0.7], rng.normal(0.0, 1.0, size=trial_len)) for _ in range(n_trials)])
    speed = (speed - speed.mean()) / (speed.std() + 1e-12)

    beta_event = np.array([0.8, 0.5, 0.2, -0.1, -0.2])
    beta_speed = np.array([0.4, 0.2, 0.0, -0.1, -0.2])
    beta_hist  = np.array([-1.2, -0.8, -0.5, -0.2])

    X_event = lagged_design_from_signal_trials(event, B_stim, trial_ids)
    X_speed = lagged_design_from_signal_trials(speed, B_stim, trial_ids)

    h_hist = B_hist @ beta_hist
    Lh = len(h_hist)

    baseline_per_trial = -2.8 + 0.4 * rng.standard_normal(n_trials)

    y = np.zeros(T, dtype=int)
    for tr in _unique_trials(trial_ids):
        idx = np.where(trial_ids == tr)[0]
        past = np.zeros(Lh)  # newest-first buffer of previous counts; reset every trial
        b0 = baseline_per_trial[tr]
        for t in idx:
            stim_drive = X_event[t] @ beta_event + X_speed[t] @ beta_speed
            hist_drive = np.dot(past, h_hist)
            eta = b0 + stim_drive + hist_drive
            lam = safe_poisson_lambda(eta, dt, max_rate_hz=max_rate_hz)
            y[t] = rng.poisson(lam)
            past = np.roll(past, 1)
            past[0] = y[t]

    stimulus_dict = {"event": event, "speed": speed}
    stimulus_basis_dict = {"event": B_stim, "speed": B_stim}
    return trial_ids, y, stimulus_dict, stimulus_basis_dict, B_hist


def simulate_multiff_trials(
    *, n_trials: int = 12, trial_len: int = 400, dt: float = 0.01, seed: int = 3,
    max_rate_hz: float = 200.0
):
    """Simulate multiFF-style behavioral signals and spike counts.

    Random visibility masks, distances, angles, heading, speed and curvature
    are generated per trial; a stimulus drive is computed from the design built
    by `build_multiff_design` with hand-set coefficients, and spikes are
    sampled bin by bin with a suppressive spike-history kernel and a per-trial
    baseline.

    Fixes relative to the previous version:
    - spikes are drawn from the seeded generator (`rng.poisson`) instead of the
      global `np.random` state, so `seed` now reproduces the whole data set;
    - the history kernel is applied un-reversed: `past[j]` holds the count
      j + 1 bins ago and is weighted by `h_hist[j]`, the same convention used
      by `spike_history_design_with_trials`;
    - the design matrix is cast to float before the matrix product, so dummy
      columns of any dtype cannot turn it into an object array.

    Returns
    -------
    dict with keys 'trial_ids', 'cur_vis', 'nxt_vis', 'cur_dist', 'nxt_dist',
    'cur_angle', 'nxt_angle', 'heading', 'speed', 'curvature', 'spike_counts'
    and 'dt'.
    """
    rng = np.random.default_rng(seed)
    T = n_trials * trial_len
    trial_ids = np.repeat(np.arange(n_trials), trial_len)

    def random_vis_mask():
        # Alternate invisible gaps (20-79 bins) and visible runs (30-119 bins).
        m = np.zeros(trial_len, dtype=int)
        t = 0
        while t < trial_len:
            off = rng.integers(20, 80)
            on = rng.integers(30, 120)
            t += off
            if t >= trial_len:
                break
            m[t: min(trial_len, t + on)] = 1
            t += on
        return m

    cur_vis = np.concatenate([random_vis_mask() for _ in range(n_trials)])
    # The next target's mask is delayed by a random number of leading zeros.
    nxt_vis = np.concatenate([np.r_[np.zeros(rng.integers(40, 120)), random_vis_mask()][:trial_len] for _ in range(n_trials)])

    def distance_from_vis(m):
        # Distance starts at 1.0 when a visible run begins and shrinks by 0.02 per bin.
        d = np.zeros_like(m, dtype=float)
        run = 0.0
        for i, v in enumerate(m):
            if v:
                run = 1.0 if run == 0 else max(0.0, run - 0.02)
                d[i] = run
            else:
                run = 0.0
                d[i] = 0.0
        return d

    cur_dist = np.concatenate([distance_from_vis(cur_vis[i*trial_len:(i+1)*trial_len]) for i in range(n_trials)])
    nxt_dist = np.concatenate([distance_from_vis(nxt_vis[i*trial_len:(i+1)*trial_len]) for i in range(n_trials)])

    def rand_angle_series(L):
        # Random walk wrapped to [-pi, pi).
        a = np.cumsum(rng.normal(0, 0.05, size=L))
        return ((a + np.pi) % (2*np.pi)) - np.pi

    cur_angle = np.concatenate([rand_angle_series(trial_len) for _ in range(n_trials)])
    nxt_angle = np.concatenate([rand_angle_series(trial_len) for _ in range(n_trials)])
    heading   = np.concatenate([rand_angle_series(trial_len) for _ in range(n_trials)])

    speed = np.concatenate([signal.lfilter([1.0], [1.0, -0.8], rng.normal(0, 1.0, size=trial_len)) for _ in range(n_trials)])
    speed = (speed - speed.mean()) / (speed.std() + 1e-12)
    curvature = np.concatenate([rng.normal(0.0, 0.4, size=trial_len) for _ in range(n_trials)])

    # Build design (no history yet) to compute stim_drive for simulation
    design_nohist, _, meta = build_multiff_design(
        dt=dt, trial_ids=trial_ids,
        cur_vis=cur_vis, nxt_vis=nxt_vis,
        cur_dist=cur_dist, nxt_dist=nxt_dist,
        cur_angle=cur_angle, nxt_angle=nxt_angle,
        heading=heading, speed=speed, curvature=curvature,
        spike_counts=None, use_trial_FE=True,
    )

    cols = list(design_nohist.columns)

    def idxs(prefix):
        return [i for i, c in enumerate(cols) if c.startswith(prefix)]

    # Ground-truth coefficients; columns not set here (e.g. trial dummies) stay at 0.
    beta = np.zeros(design_nohist.shape[1])
    for p, gain in [("cur_on", 0.9), ("nxt_on", 0.5)]:
        j = idxs(p)
        if j:
            beta[j] = np.linspace(gain, 0.0, num=len(j))
    for p, gain in [("cur_dist", -0.7), ("nxt_dist", -0.5)]:
        j = idxs(p)
        if j:
            beta[j] = np.linspace(gain, 0.0, num=len(j))
    for p, gain in [("cur_angle_sin", 0.15), ("cur_angle_cos", 0.15), ("nxt_angle_sin", 0.10), ("nxt_angle_cos", 0.10)]:
        j = idxs(p)
        if j:
            beta[j] = gain / max(1, len(j))
    for p, gain in [("cur_align", 0.35), ("nxt_align", 0.20)]:
        j = idxs(p)
        if j:
            beta[j] = np.linspace(gain, 0.0, num=len(j))
    if "speed" in cols:
        beta[cols.index("speed")] = 0.2
    if "curvature" in cols:
        beta[cols.index("curvature")] = -0.25

    stim_drive = design_nohist.to_numpy(dtype=float) @ beta

    B_hist = meta["B_hist"]
    h_hist = B_hist @ np.array([-1.2, -0.8, -0.5, -0.3, -0.2])
    Lh = len(h_hist)

    baseline = -3.0 + 0.3 * rng.standard_normal(n_trials)

    y = np.zeros(T, dtype=int)
    for tr in _unique_trials(trial_ids):
        idx = np.where(trial_ids == tr)[0]
        past = np.zeros(Lh)  # newest-first buffer of previous counts; reset every trial
        b0 = baseline[tr]
        for t in idx:
            eta = b0 + stim_drive[t] + np.dot(past, h_hist)
            lam = safe_poisson_lambda(eta, dt, max_rate_hz=max_rate_hz)
            y[t] = rng.poisson(lam)
            past = np.roll(past, 1)
            past[0] = y[t]

    return {
        "trial_ids": trial_ids,
        "cur_vis": cur_vis,
        "nxt_vis": nxt_vis,
        "cur_dist": cur_dist,
        "nxt_dist": nxt_dist,
        "cur_angle": cur_angle,
        "nxt_angle": nxt_angle,
        "heading": heading,
        "speed": speed,
        "curvature": curvature,
        "spike_counts": y,
        "dt": dt,
    }

# =====================
# Demos
# =====================

# ---------- Plotting helpers ----------

def _coef_series(result, design_df):
    """Return the fitted coefficients as a Series indexed by column name.

    Names come from `result.exog_names`, else `result.model.exog_names`, else
    the columns of `design_df`. A leading 'const' entry is dropped when the
    names and parameters have equal length. A result without `params` yields
    an all-zero Series over the design columns.
    """
    fallback_index = list(design_df.columns)
    if not hasattr(result, 'params'):
        return pd.Series(np.zeros(design_df.shape[1]), index=fallback_index)

    values = np.asarray(result.params).ravel()
    labels = getattr(result, 'exog_names', None)
    if labels is None and hasattr(result.model, 'exog_names'):
        labels = result.model.exog_names

    has_const = bool(labels) and labels[0] == 'const' and len(values) == len(labels)
    if has_const:
        values, labels = values[1:], labels[1:]

    index = labels if labels else fallback_index
    return pd.Series(values, index=index)


def reconstruct_kernel(prefix: str, basis: np.ndarray, coef_s: pd.Series) -> Tuple[np.ndarray, np.ndarray]:
    """Rebuild a lag kernel from the fitted weights of one basis-expanded feature.

    Coefficients named `<prefix>_rc<k>` are ordered by their numeric suffix k
    and combined as `basis @ weights`.

    Parameters
    ----------
    prefix : feature name used when the design columns were created.
    basis : (L, K) basis matrix used for that feature.
    coef_s : fitted coefficients indexed by column name.

    Returns
    -------
    t : lag indices 0..L-1 (in bins).
    k : kernel values, all zeros when no matching coefficient exists.
    """
    n_lags = basis.shape[0]
    t = np.arange(n_lags)
    marker = prefix + "_rc"
    cols = [c for c in coef_s.index if c.startswith(marker)]
    if not cols:
        return t, np.zeros(n_lags)

    def rc_idx(c):
        # Only a non-numeric suffix can fail here; catch that specifically
        # instead of swallowing every exception with a bare `except`.
        try:
            return int(c.split('_rc')[-1])
        except ValueError:
            return 0

    cols.sort(key=rc_idx)
    w = coef_s.loc[cols].values
    return t, basis @ w


def reconstruct_history_kernel(B_hist: np.ndarray, coef_s: pd.Series) -> Tuple[np.ndarray, np.ndarray]:
    """Rebuild the spike-history kernel from the fitted `hist_rc<k>` weights.

    Coefficients are ordered by their numeric suffix k and combined as
    `B_hist @ weights`.

    Parameters
    ----------
    B_hist : (L, K) history basis matrix.
    coef_s : fitted coefficients indexed by column name.

    Returns
    -------
    t : lag indices 0..L-1 (in bins).
    k : kernel values, all zeros when no `hist_rc*` coefficient exists.
    """
    n_lags = B_hist.shape[0]
    t = np.arange(n_lags)
    cols = [c for c in coef_s.index if c.startswith('hist_rc')]
    if not cols:
        return t, np.zeros(n_lags)

    def rc_idx(c):
        # Only a non-numeric suffix can fail here; catch that specifically
        # instead of swallowing every exception with a bare `except`.
        try:
            return int(c.split('_rc')[-1])
        except ValueError:
            return 0

    cols.sort(key=rc_idx)
    w = coef_s.loc[cols].values
    return t, B_hist @ w


def plot_fitted_kernels(result, design_df, meta, dt, *, prefixes=None):
    """Plot the fitted kernel of each stimulus feature and the spike-history kernel.

    One figure is shown per prefix, followed by one for the history kernel.
    Lag indices are converted to seconds with `dt`.

    Parameters
    ----------
    result : fitted GLM result.
    design_df : design matrix used for the fit (for coefficient names).
    meta : dict with the basis matrices 'B_event', 'B_short' and 'B_hist'.
    dt : bin width in seconds.
    prefixes : feature names to plot; defaults to the onset, distance and
        angle features of the current and next target.

    Cleanup: the coefficient series is computed once, matplotlib is imported
    once, and the basis choice is a single expression (onset features use
    'B_event', everything else 'B_short' - as before).
    """
    import matplotlib.pyplot as plt
    if prefixes is None:
        prefixes = ['cur_on', 'nxt_on', 'cur_dist', 'nxt_dist', 'cur_angle_sin', 'cur_angle_cos', 'nxt_angle_sin', 'nxt_angle_cos']
    coef_s = _coef_series(result, design_df)
    B_event, B_short, B_hist = meta['B_event'], meta['B_short'], meta['B_hist']
    event_prefixes = ('cur_on', 'nxt_on')

    for p in prefixes:
        B = B_event if p in event_prefixes else B_short
        t, k = reconstruct_kernel(p, B, coef_s)
        plt.figure()
        plt.plot(t * dt, k)
        plt.xlabel('Time lag (s)')
        plt.ylabel('Kernel weight')
        plt.title(f'{p} kernel')
        plt.show()

    t_h, k_h = reconstruct_history_kernel(B_hist, coef_s)
    plt.figure()
    plt.plot(t_h * dt, k_h)
    plt.xlabel('Time lag (s)')
    plt.ylabel('History weight')
    plt.title('Spike history kernel')
    plt.show()

# ---------- Demos ----------

# =====================
# Debugging utilities
# =====================

def _column_blocks(design_df: pd.DataFrame) -> Dict[str, List[str]]:
    """Map each feature prefix to its design columns.

    The prefix is the part of the column name before the first '_rc'; columns
    without '_rc' form their own single-column block.
    """
    grouped: Dict[str, List[str]] = {}
    for name in design_df.columns:
        key = name.split('_rc')[0] if '_rc' in name else name
        grouped.setdefault(key, []).append(name)
    return grouped


def design_summary(design_df: pd.DataFrame, y: np.ndarray, *, topk: int = 10) -> pd.DataFrame:
    """Per-column diagnostics of the design, ranked by |correlation with y|.

    For every column: variance, fraction of non-zero entries and correlation
    with `y` (NaN when the column or `y` has zero spread, or `y` is None).
    Returns the `topk` rows with the largest absolute correlation.
    """
    X = design_df.values
    n_cols = X.shape[1]
    variances = X.var(axis=0)
    nonzero_frac = (np.abs(X) > 0).mean(axis=0)

    corr = np.full(n_cols, np.nan)
    if y is not None:
        centered = y - y.mean()
        # Correlations are undefined for constant vectors, so those stay NaN.
        if centered.std() > 0:
            for j, column in enumerate(X.T):
                if column.std() > 0:
                    corr[j] = np.corrcoef(column, centered)[0, 1]

    table = pd.DataFrame({"col": design_df.columns, "var": variances, "nonzero_frac": nonzero_frac, "corr_y": corr})
    table["abs_corr_y"] = np.abs(table["corr_y"])
    ranked = table.sort_values("abs_corr_y", ascending=False)
    return ranked.head(topk)


def block_summary(design_df: pd.DataFrame, y: np.ndarray) -> pd.DataFrame:
    """Summarize the design per feature block.

    For each block (see `_column_blocks`): number of columns, mean column
    variance, count of zero-variance columns and the largest |correlation|
    between any of its non-constant columns and `y`. Rows are sorted by that
    correlation, largest first.
    """
    y_is_informative = y is not None and y.std() > 0
    records = []
    for block_name, members in _column_blocks(design_df).items():
        sub = design_df[members]
        col_var = sub.var()

        max_abs_corr = np.nan
        if y_is_informative:
            corrs = [
                np.corrcoef(sub[c].values, y)[0, 1]
                for c in members
                if sub[c].values.std() > 0
            ]
            if len(corrs):
                max_abs_corr = float(np.nanmax(np.abs(corrs)))

        records.append({
            "block": block_name,
            "ncols": len(members),
            "mean_var": col_var.mean(),
            "zero_var": int((col_var == 0).sum()),
            "max_abs_corr_y": max_abs_corr,
        })
    return pd.DataFrame(records).sort_values("max_abs_corr_y", ascending=False)


def constant_or_near_constant_columns(design_df: pd.DataFrame, tol: float = 1e-12) -> List[str]:
    """Return the names of columns whose variance is at most `tol`."""
    variances = design_df.var()
    return [name for name, value in variances.items() if value <= tol]


def svd_report(design_df: pd.DataFrame, *, k: int = 20) -> Dict[str, object]:
    """Conditioning diagnostics from the SVD of the column-centered design.

    Returns a dict with the numerical rank (singular values > 1e-10), the
    matrix shape, the condition number (inf when the smallest singular value
    is not positive) and the `k` largest singular values.
    """
    X = design_df.values
    centered = X - X.mean(axis=0, keepdims=True)
    _, sing, _ = np.linalg.svd(centered, full_matrices=False)
    smallest = sing[-1]
    cond = float(sing[0] / smallest) if smallest > 0 else np.inf
    return {
        "rank": int((sing > 1e-10).sum()),
        "ncols": X.shape[1],
        "nrows": X.shape[0],
        "cond_number": cond,
        "singular_values_top": sing[:k],
    }


def check_param_mapping(result, design_df: pd.DataFrame) -> pd.DataFrame:
    """Pair every fitted parameter with a coefficient name.

    Names are taken from `result.exog_names`, then `result.model.exog_names`,
    and finally assumed to be an intercept followed by the design columns.
    When the number of names and parameters disagree, the table is labelled
    as 'const' plus the leading design columns.
    """
    coef_names = getattr(result, 'exog_names', None)
    if coef_names is None and hasattr(result.model, 'exog_names'):
        coef_names = result.model.exog_names
    if coef_names is None:
        # fallback: assume intercept + design columns
        coef_names = ['const'] + list(design_df.columns)

    values = np.asarray(result.params).ravel()
    if len(values) != len(coef_names):
        # name/param count mismatch: relabel as intercept + leading columns
        coef_names = ['const'] + list(design_df.columns[:len(values) - 1])
    return pd.DataFrame({'name': coef_names, 'param': values})


def single_block_fit(prefix: str, design_df: pd.DataFrame, y: np.ndarray, dt: float, trial_ids: np.ndarray):
    """Fit a Poisson GLM using only the columns of one block and score it.

    The block is every column named ``<prefix>_rc*``; if there is none, the
    single column called `prefix` is used instead.

    Returns a dict with the block's deviance, the deviance of a
    constant-rate (null) model and the deviance-based pseudo-R2
    (``nan`` when the null deviance is zero).
    """
    cols = [c for c in design_df.columns if c.startswith(prefix + '_rc')] or [prefix]
    X = design_df[cols]
    res = fit_poisson_glm_trials(X, y, dt, trial_ids, add_const=True, l2=0.0, cluster_se=False)
    mu = predict_mu(res, X, dt)
    dev = poisson_deviance(y, mu)
    # The null prediction must be float: with integer spike counts,
    # `np.full_like(y, y.mean())` would truncate the mean count (typically to 0).
    mu_null = np.full(np.shape(y), np.mean(y), dtype=float)
    null = poisson_deviance(y, mu_null)
    r2 = 1 - dev / null if null > 0 else np.nan
    return {"prefix": prefix, "deviance": dev, "null_dev": null, "pseudo_R2": r2}


def peth_from_onsets(onsets: np.ndarray, y: np.ndarray, trial_ids: np.ndarray, *, window_bins: int = 40) -> Tuple[np.ndarray, np.ndarray]:
    """Average spike counts in a window around each onset, trial by trial.

    Only onsets whose full +/- `window_bins` window lies inside their own
    trial contribute. Returns (lags, mean_counts); the mean is all zeros
    when no onset qualifies.
    """
    lags = np.arange(-window_bins, window_bins + 1)
    segments = []
    for trial in _unique_trials(trial_ids):
        rows = np.where(trial_ids == trial)[0]
        n_bins = len(rows)
        # onset positions are relative to the start of this trial
        for onset in np.where(onsets[rows] > 0)[0]:
            start = onset - window_bins
            stop = onset + window_bins
            if start < 0 or stop >= n_bins:
                continue  # window would spill outside the trial
            segments.append(y[rows][start:stop + 1])
    if len(segments) == 0:
        return lags, np.zeros_like(lags, dtype=float)
    stacked = np.vstack(segments)
    return lags, stacked.mean(axis=0)




def _kernel_with_ci(result, design_df, prefix: str, basis: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Time-domain kernel for one prefix with its delta-method standard deviation.

    Returns (t_idx, mean, std), each of length ``basis.shape[0]``. `std` is a
    standard deviation, not an interval: callers build a 95% band as
    ``mean +/- 1.96 * std``. When the fit has no ``<prefix>_rc*`` columns the
    kernel and its std are all zeros.

    Raises
    ------
    ValueError
        If the fitted model exposes no coefficient names.
    """
    model = getattr(result, 'model', None)
    coef_names = getattr(model, 'exog_names', None)
    if coef_names is None:
        raise ValueError('No coefficient names found; fit must use a DataFrame design.')

    n_lags = basis.shape[0]
    cols = _prefix_cols(prefix, coef_names)
    if not cols:
        return np.arange(n_lags), np.zeros(n_lags), np.zeros(n_lags)

    # Select the weights by label, in the same (rc-index) order as the
    # covariance block below; a boolean `isin` mask would keep the params
    # order instead and could silently misalign the two.
    params = result.params
    if hasattr(params, 'index') and all(c in params.index for c in cols):
        w = params.loc[cols]
    else:
        # unnamed params: recover names through the coefficient series
        w = _coef_series(result, design_df).loc[cols]
    cov = result.cov_params().loc[cols, cols].values

    # kernel mean and variance across lags: var_l = b_l^T cov b_l
    weights = np.asarray(w, dtype=float)
    mean = basis @ weights
    var = np.einsum('li,ij,lj->l', basis, cov, basis)
    std = np.sqrt(np.maximum(var, 0.0))  # clip tiny negative round-off
    return np.arange(n_lags), mean, std


def plot_angle_kernels_with_ci(result, design_df, meta, dt, base_prefix: str = 'cur_angle'):
    """Plot sin, cos, and amplitude (with 95% CIs) for angle tuning over time.

    The bands use whatever covariance `result.cov_params()` returns, so they
    are cluster/robust if the fit was. The amplitude ``sqrt(sin^2 + cos^2)``
    gets a delta-method variance that includes the sin/cos cross-covariance.
    If `meta` carries a ``'B_hist'`` basis, the spike-history kernel is
    plotted as well.

    Parameters
    ----------
    result : fitted model exposing ``model.exog_names`` and a labelled ``cov_params()``.
    design_df : design matrix used for the fit.
    meta : dict of bases; must contain ``'B_short'``, may contain ``'B_hist'``.
    dt : bin width, used to convert lag indices to seconds.
    base_prefix : column prefix of the angle regressors (e.g. 'cur_angle').
    """
    import matplotlib.pyplot as plt

    B = meta['B_short']
    sin_prefix = f'{base_prefix}_sin'
    cos_prefix = f'{base_prefix}_cos'
    t_idx, ksin, std_s = _kernel_with_ci(result, design_df, sin_prefix, B)
    _, kcos, std_c = _kernel_with_ci(result, design_df, cos_prefix, B)

    # Cross-covariance between sin and cos weights to get Var(amplitude)
    coef_names = result.model.exog_names
    sin_cols = _prefix_cols(sin_prefix, coef_names)
    cos_cols = _prefix_cols(cos_prefix, coef_names)
    cov_full = result.cov_params()

    def _quad_form(rows, cols):
        # b_l^T C b_l for every lag l; zero when a block is missing from the fit
        if not rows or not cols:
            return np.zeros(B.shape[0])
        block = cov_full.loc[rows, cols].values
        return np.einsum('li,ij,lj->l', B, block, B)

    var_s = _quad_form(sin_cols, sin_cols)
    var_c = _quad_form(cos_cols, cos_cols)
    cov_sc = _quad_form(sin_cols, cos_cols)

    # Delta method: A = sqrt(s^2 + c^2), gradient g = (s/A, c/A)
    A = np.sqrt(ksin**2 + kcos**2)
    denom = np.maximum(A, 1e-12)  # avoid division by zero at A == 0
    g_s = ksin / denom
    g_c = kcos / denom
    var_A = g_s**2 * var_s + 2.0 * g_s * g_c * cov_sc + g_c**2 * var_c
    std_A = np.sqrt(np.maximum(var_A, 0.0))

    # --- plots ---
    t = t_idx * dt
    # sin
    plt.figure()
    plt.plot(t, ksin, label='sin')
    plt.fill_between(t, ksin - 1.96*std_s, ksin + 1.96*std_s, alpha=0.2)
    plt.xlabel('Time lag (s)'); plt.ylabel('Kernel weight'); plt.title(f'{base_prefix}_sin kernel (95% CI)')
    plt.legend(); plt.show()

    # cos
    plt.figure()
    plt.plot(t, kcos, label='cos')
    plt.fill_between(t, kcos - 1.96*std_c, kcos + 1.96*std_c, alpha=0.2)
    plt.xlabel('Time lag (s)'); plt.ylabel('Kernel weight'); plt.title(f'{base_prefix}_cos kernel (95% CI)')
    plt.legend(); plt.show()

    # amplitude (lower band clipped at 0 because an amplitude cannot be negative)
    plt.figure()
    plt.plot(t, A, label='amplitude')
    plt.fill_between(t, np.maximum(0.0, A - 1.96*std_A), A + 1.96*std_A, alpha=0.2)
    plt.xlabel('Time lag (s)'); plt.ylabel('Amplitude'); plt.title(f'{base_prefix} amplitude (95% CI)')
    plt.legend(); plt.show()

    # spike-history kernel; the basis and coefficients are looked up here
    # (they were previously referenced without ever being defined)
    B_hist = meta.get('B_hist')
    if B_hist is not None:
        coef_s = _coef_series(result, design_df)
        t_h, k_h = reconstruct_history_kernel(B_hist, coef_s)
        plt.figure()
        plt.plot(t_h * dt, k_h)
        plt.xlabel('Time lag (s)'); plt.ylabel('History weight')
        plt.title('Spike history kernel')
        plt.show()


def _prefix_cols(prefix: str, names: List[str]) -> List[str]:
    cols = [n for n in names if n.startswith(prefix + '_rc')]
    def rc_idx(c):
        try:
            return int(c.split('_rc')[-1])
        except:
            return 0
    return sorted(cols, key=rc_idx)


def angle_tuning_vs_time(result, design_df, meta, base_prefix: str = 'cur_angle'):
    """Reconstruct the sin/cos angle kernels and derive amplitude and phase per lag.

    `base_prefix` should be 'cur_angle' or 'nxt_angle'.
    Returns (t_idx, k_sin, k_cos, amplitude, preferred_angle).
    """
    basis = meta['B_short']
    coefs = _coef_series(result, design_df)
    lag_idx, sin_kernel = reconstruct_kernel(f'{base_prefix}_sin', basis, coefs)
    _, cos_kernel = reconstruct_kernel(f'{base_prefix}_cos', basis, coefs)
    amplitude = np.sqrt(sin_kernel**2 + cos_kernel**2)
    preferred = np.arctan2(sin_kernel, cos_kernel)
    return lag_idx, sin_kernel, cos_kernel, amplitude, preferred


def plot_angle_tuning(result, design_df, meta, dt, base_prefix: str = 'cur_angle'):
    """Plot angle kernels, tuning amplitude and preferred angle against lag.

    Parameters
    ----------
    result, design_df, meta : passed through to `angle_tuning_vs_time`.
    dt : bin width, used to convert lag indices to seconds.
    base_prefix : which angle regressors to plot ('cur_angle' by default,
        or e.g. 'nxt_angle'); previously fixed to the default.
    """
    import matplotlib.pyplot as plt
    t_idx, k_sin, k_cos, A, phi = angle_tuning_vs_time(result, design_df, meta, base_prefix=base_prefix)
    t = t_idx * dt

    # kernels
    plt.figure()
    plt.plot(t, k_cos, label="cos")
    plt.plot(t, k_sin, label="sin")
    plt.xlabel("Time lag (s)"); plt.ylabel("Kernel weight"); plt.title("Angle kernels")
    plt.legend(); plt.show()

    # amplitude
    plt.figure()
    plt.plot(t, A)
    plt.xlabel("Time lag (s)"); plt.ylabel("Amplitude"); plt.title("Directional tuning amplitude vs lag")
    plt.show()

    # preferred angle
    plt.figure()
    plt.plot(t, phi)
    plt.xlabel("Time lag (s)"); plt.ylabel("Preferred angle (rad)"); plt.title("Preferred angle vs lag")
    plt.show()



def debug_all_kernels_flat(res, design_df, y, trial_ids, meta, dt):
    """Print a battery of design/fit diagnostics for a fitted GLM.

    Covers near-constant columns, per-column and per-block correlation with
    `y`, an SVD report, the parameter/name mapping and one single-block fit
    per regressor prefix. If a raw 'cur_on' column exists, a PETH around its
    onsets is plotted as well.
    """
    columns = design_df.columns

    print("[DEBUG] Constant/near-constant columns:")
    flat_cols = constant_or_near_constant_columns(design_df)
    print(flat_cols[:20])

    print("[DEBUG] Top-10 columns by |corr(y)|:")
    print(design_summary(design_df, y, topk=10))

    print("[DEBUG] Block summary (max |corr(y)| per block):")
    print(block_summary(design_df, y))

    print("[DEBUG] SVD report:")
    print(svd_report(design_df))

    print("[DEBUG] Param mapping (first 20):")
    print(check_param_mapping(res, design_df).head(20))

    # Quick single-block tests: one fit per distinct prefix (text before '_rc')
    block_names = sorted({name.split('_rc')[0] for name in columns})
    fits = [single_block_fit(name, design_df, y, dt, trial_ids) for name in block_names]
    ranked = pd.DataFrame(fits).sort_values('pseudo_R2', ascending=False)
    print("[DEBUG] Single-block fits (sorted by pseudo_R2):")
    print(ranked)

    # Example PETH: needs the raw onset vector, not its basis expansion
    if 'cur_on' in columns:
        lags, mean_counts = peth_from_onsets(design_df['cur_on'].values, y, trial_ids, window_bins=40)
        import matplotlib.pyplot as plt
        plt.figure()
        plt.plot(lags * dt, mean_counts)
        plt.axvline(0, linestyle='--')
        plt.xlabel('Time (s)'); plt.ylabel('Mean spikes/bin')
        plt.title('PETH around cur_on onsets')
        plt.show()
    elif 'cur_on_rc1' in columns:
        print("Note: cur_on onset vector not present raw; consider passing it to peth_from_onsets directly from your data.")


In [None]:
# Plot the sin/cos angle kernels, the tuning amplitude and the preferred angle vs lag.
# NOTE(review): relies on `res`, `design_df`, `meta` and `dt` already existing in the session.
plot_angle_tuning(res, design_df, meta, dt)


In [None]:
# Run after fitting: needs `res`, `design_df`, `meta` and `dt` from the fit.


# Sin, cos and amplitude kernels of the current-target angle, with 95% CIs
plot_angle_kernels_with_ci(res, design_df, meta, dt, base_prefix='cur_angle')

# Same plots for the next-target angle (only meaningful if `nxt_angle_*` regressors were modelled)
plot_angle_kernels_with_ci(res, design_df, meta, dt, base_prefix='nxt_angle')


# GLM 1

## synthetic data

In [None]:
# Display `data`; NOTE(review): it must already exist in the session (the demo cell below assigns it).
data

In [None]:


# run_multiff_demo
# Simulate a small multiFF session, fit the trial-aware Poisson GLM on it,
# report in-sample and trial-wise cross-validated fit quality, then plot kernels.
data = simulate_multiff_trials(n_trials=8, trial_len=300, dt=0.01, seed=7)
# Bind the bin width used by the simulation so the plotting calls below
# (and in later cells) do not depend on a stale or undefined global `dt`.
dt = data["dt"]
res, design_df, metrics, meta = fit_multiff_glm(
    dt=data["dt"], trial_ids=data["trial_ids"],
    cur_vis=data["cur_vis"], nxt_vis=data["nxt_vis"],
    cur_dist=data["cur_dist"], nxt_dist=data["nxt_dist"],
    cur_angle=data["cur_angle"], nxt_angle=data["nxt_angle"],
    heading=data["heading"], speed=data["speed"], curvature=data["curvature"],
    spike_counts=data["spike_counts"], l2=0.0, use_trial_FE=True, cluster_se=False,
)
print(res.summary())
print("Overall deviance:", metrics["deviance"])
print("Pseudo-R^2:", metrics["pseudo_R2"])
print("Per-trial deviance (head):", metrics["per_trial_deviance"].head())
cv = fit_and_score_cv(design_df, data["spike_counts"], data["dt"], data["trial_ids"], n_splits=5, l2=0.0, cluster_se=False)
print("Trial-wise CV:", cv)

plot_fitted_kernels(res, design_df, meta, dt)

In [None]:
# NOTE(review): neither demo function is defined in the part of the notebook
# shown here — presumably they come from another cell; confirm before running.
run_spike_only_demo()
run_multiff_demo()

In [None]:
# Plot the fitted kernels again.
# NOTE(review): relies on `res`, `design_df`, `meta` and `dt` left in the session by an earlier cell.
plot_fitted_kernels(res, design_df, meta, dt)

## functions (early)

In [None]:
"""
Trial-aware Poisson GLM for multiFF neural spike data (statsmodels)
------------------------------------------------------------------
End-to-end, runnable module:
- Raised-cosine bases
- Trial-aware stimulus & spike-history design (no cross-trial leakage)
- Optional trial fixed effects or cluster-robust SEs by trial
- Poisson fitting (statsmodels.GLM), prediction, metrics, trial-wise CV
- MultiFF adapter with visibility/distance/angle/heading features
- Stable simulators using capped Poisson intensities
- Demos: run `python this_file.py` or call `run_multiff_demo()`
"""
from __future__ import annotations

import numpy as np
import pandas as pd
from typing import Dict, Iterable, List, Optional, Tuple
from scipy import signal
from scipy.linalg import toeplitz
import statsmodels.api as sm

# =====================
# Helpers & bases
# =====================

def _unique_trials(trial_ids: np.ndarray) -> np.ndarray:
    return np.unique(np.asarray(trial_ids))


def raised_cosine_basis(n_basis: int, t_max: float, dt: float, *, t_min: float = 0.0,
                        log_spaced: bool = True, eps: float = 1e-3) -> Tuple[np.ndarray, np.ndarray]:
    """Build a causal cosine-bump basis whose centers tile [t_min, t_max].

    Centers are equally spaced on the (optionally log-warped) time axis.
    Returns the lags (L,) and a matrix B (L x K); every column with
    non-zero mass is scaled to unit area (sum * dt = 1), and all columns
    are zero for lags below `t_min`.
    """
    def _warp(t):
        if log_spaced:
            return np.log(t + eps)
        return t

    lags = np.arange(0.0, t_max + 1e-12, dt)
    n_funcs = int(n_basis)

    warped_lags = _warp(lags)
    lo, hi = _warp(t_min), _warp(t_max)
    centers = np.linspace(lo, hi, n_funcs)
    if n_funcs > 1:
        spacing = centers[1] - centers[0]
    else:
        spacing = hi - lo + 1e-12
    scale = spacing * 1.5

    too_early = lags < t_min
    columns = []
    for center in centers:
        phase = (warped_lags - center) / scale
        bump = np.cos(np.clip(phase, -np.pi, np.pi))
        bump[np.abs(phase) > np.pi] = 0.0
        bump = np.maximum(bump, 0.0)  # keep only the positive lobe
        bump[too_early] = 0.0
        mass = bump.sum() * dt
        if mass > 0:
            bump = bump / mass
        columns.append(bump)

    if not columns:
        return lags, np.zeros((len(lags), 0))
    return lags, np.column_stack(columns)


def safe_poisson_lambda(eta: float | np.ndarray, dt: float, *, max_rate_hz: float = 200.0) -> np.ndarray:
    """Turn a per-second log-rate into an expected count per bin.

    The log-rate is clipped to [log(1e-6), log(max_rate_hz)] before
    exponentiating, so the rate stays within 1e-6 .. `max_rate_hz` Hz.
    """
    lower, upper = np.log(1e-6), np.log(max_rate_hz)
    bounded_eta = np.clip(eta, lower, upper)
    return np.exp(bounded_eta) * dt


def wrap_angle(theta: np.ndarray) -> np.ndarray:
    """Wrap angles in radians into the interval [-pi, pi)."""
    full_turn = 2 * np.pi
    shifted = theta + np.pi
    return shifted % full_turn - np.pi


def angle_sin_cos(theta: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return (sin, cos) of `theta` after wrapping it with `wrap_angle`."""
    wrapped = wrap_angle(theta)
    sin_part = np.sin(wrapped)
    cos_part = np.cos(wrapped)
    return sin_part, cos_part


def onset_from_mask_trials(mask: np.ndarray, trial_ids: np.ndarray) -> np.ndarray:
    """Mark the first bin of every run of positive `mask` values, per trial.

    Each trial is treated as starting from "off", so a mask that is already
    positive in a trial's first bin counts as an onset there. Returns a
    float array of 0/1 with the same length as `mask`.
    """
    active = (mask > 0).astype(int)
    onsets = np.zeros_like(active, dtype=float)
    for trial in _unique_trials(trial_ids):
        rows = np.where(trial_ids == trial)[0]
        # rising edges: 0 -> 1 transitions, with an implicit 0 before the trial
        steps = np.diff(active[rows], prepend=0)
        onsets[rows] = (steps == 1).astype(float)
    return onsets


def offset_from_mask_trials(mask: np.ndarray, trial_ids: np.ndarray) -> np.ndarray:
    """Mark the last bin of every run of positive `mask` values, per trial.

    Each trial is treated as ending in "off", so a mask that is still
    positive in a trial's final bin counts as an offset there. Returns a
    float array of 0/1 with the same length as `mask`.
    """
    mask = (np.asarray(mask) > 0).astype(int)
    trial_ids = np.asarray(trial_ids)
    off = np.zeros_like(mask, dtype=float)
    for tr in np.unique(trial_ids):
        idx = np.where(trial_ids == tr)[0]
        m = mask[idx]
        # d[i] = m[i+1] - m[i], with an implicit 0 after the trial
        d = np.diff(np.r_[m, 0])
        # A falling edge (1 -> 0) is -1; testing for +1 here would flag the
        # bin *before* each onset instead of the offsets.
        off[idx] = (d == -1).astype(float)
    return off

# =====================
# Trial-aware design builders
# =====================

def lagged_design_from_signal_trials(x: np.ndarray, basis: np.ndarray, trial_ids: np.ndarray) -> np.ndarray:
    """Causally convolve `x` with every basis column, one trial at a time.

    The convolution restarts at each trial, so nothing leaks across trial
    boundaries. Returns a (len(x), K) array for a basis of shape (L, K).
    """
    _, n_basis = basis.shape
    design = np.zeros((len(x), n_basis))
    for trial in _unique_trials(trial_ids):
        rows = np.where(trial_ids == trial)[0]
        n_rows = len(rows)
        x_trial = x[rows]
        for col in range(n_basis):
            # keep the first n_rows samples of the full convolution (causal part)
            filtered = signal.convolve(x_trial, basis[:, col], mode="full")
            design[rows, col] = filtered[:n_rows]
    return design


def spike_history_design_with_trials(y_counts: np.ndarray, basis: np.ndarray, trial_ids: np.ndarray) -> np.ndarray:
    """Project each trial's strictly-past spike counts onto the history basis.

    Within a trial, row t combines counts ``y[t-1], y[t-2], ...`` with basis
    rows ``0, 1, ...``; counts from other trials never enter. Returns a
    (len(y_counts), K) array for a basis of shape (L, K).
    """
    n_lags, n_basis = basis.shape
    history = np.zeros((len(y_counts), n_basis))
    no_future = np.zeros(n_lags)  # first Toeplitz row: nothing from the future
    for trial in _unique_trials(trial_ids):
        rows = np.where(trial_ids == trial)[0]
        counts = y_counts[rows]
        delayed = np.r_[0, counts[:-1]]  # strictly past
        lag_matrix = toeplitz(delayed, no_future)
        history[rows, :] = lag_matrix @ basis
    return history


def build_glm_design_with_trials(
    dt: float,
    trial_ids: np.ndarray,
    stimulus_dict: Optional[Dict[str, np.ndarray]] = None,
    stimulus_basis_dict: Optional[Dict[str, np.ndarray]] = None,
    spike_counts: Optional[np.ndarray] = None,
    history_basis: Optional[np.ndarray] = None,
    extra_covariates: Optional[Dict[str, np.ndarray]] = None,
    use_trial_FE: bool = True,
) -> Tuple[pd.DataFrame, Optional[np.ndarray]]:
    """Assemble the GLM design matrix without leaking across trials.

    Column groups, in order:
    - stimuli: ``<name>_rc<k>`` basis-convolved columns when `name` has an
      entry in `stimulus_basis_dict`, otherwise the raw signal as ``<name>``;
    - spike history: ``hist_rc<k>`` (needs both `spike_counts` and `history_basis`);
    - extra covariates, used as given;
    - trial fixed effects ``trial_<id>`` (first trial dropped) if `use_trial_FE`.

    `dt` is accepted for interface symmetry and is not used here.

    Returns
    -------
    (design_df, y) where `y` is `spike_counts` (or None). All columns of
    `design_df` are float.
    """
    T = len(trial_ids)
    cols: List[np.ndarray] = []
    names: List[str] = []

    if stimulus_dict is not None:
        for name, x in stimulus_dict.items():
            if stimulus_basis_dict is not None and name in stimulus_basis_dict:
                B = stimulus_basis_dict[name]
                Xk = lagged_design_from_signal_trials(x, B, trial_ids)
                for k in range(Xk.shape[1]):
                    cols.append(Xk[:, k])
                    names.append(f"{name}_rc{k+1}")
            else:
                cols.append(x)
                names.append(name)

    if spike_counts is not None and history_basis is not None:
        Xh = spike_history_design_with_trials(spike_counts, history_basis, trial_ids)
        for k in range(Xh.shape[1]):
            cols.append(Xh[:, k])
            names.append(f"hist_rc{k+1}")

    if extra_covariates is not None:
        for n, v in extra_covariates.items():
            cols.append(v)
            names.append(n)

    X = np.column_stack(cols) if cols else np.zeros((T, 0))
    design_df = pd.DataFrame(X, columns=names)

    if use_trial_FE:
        # Ask for float dummies explicitly: recent pandas returns bool columns,
        # and a frame mixing bool and float yields an object array from
        # `.values`, which breaks the numeric code that consumes the design.
        trial_FE = pd.get_dummies(trial_ids, prefix="trial", drop_first=True, dtype=float)
        design_df = pd.concat([design_df, trial_FE], axis=1)

    y = spike_counts if spike_counts is not None else None
    return design_df, y

# =====================
# Fitting & metrics
# =====================

def add_intercept(X: np.ndarray) -> np.ndarray:
    """Prepend a column of ones to `X`."""
    ones = np.ones(len(X))
    return np.column_stack((ones, X))


def fit_poisson_glm_trials(
    design_df: pd.DataFrame,
    y: np.ndarray,
    dt: float,
    trial_ids: np.ndarray,
    *,
    add_const: bool = True,
    l2: float = 0.0,
    cluster_se: bool = False,
):
    """Fit a Poisson GLM with bin width `dt` as exposure.

    Parameters
    ----------
    design_df : design matrix; its column names become the coefficient names.
    y : spike counts per bin.
    dt : bin width, used as a constant exposure.
    trial_ids : trial label per bin; used as the cluster variable when
        `cluster_se` is set.
    add_const : prepend an intercept column named ``'const'``.
    l2 : ridge penalty; when > 0 the model is fit with `fit_regularized`
        (no standard errors) and `exog_names` is attached to the result.
    cluster_se : use cluster-robust covariance by trial (unpenalised fits only).

    Returns
    -------
    The statsmodels results object. Because the model is fit on a named
    design, `params` / `cov_params()` are labelled and
    `result.model.exog_names` carries the real column names.
    """
    names = list(design_df.columns)
    # Force float: a design mixing bool dummies with float columns would
    # otherwise reach statsmodels as an object array.
    values = np.asarray(design_df, dtype=float)
    if add_const:
        values = np.column_stack([np.ones(len(values)), values])
        names = ["const"] + names
    # Fit on a named frame: with a bare ndarray statsmodels labels the
    # coefficients x1, x2, ..., and every name-based kernel lookup then
    # silently finds nothing.
    exog = pd.DataFrame(values, columns=names)
    endog = np.asarray(y)
    exposure = np.full(len(endog), dt, dtype=float)

    model = sm.GLM(endog, exog, family=sm.families.Poisson(), exposure=exposure)
    if l2 > 0:
        res = model.fit_regularized(alpha=l2, L1_wt=0.0, maxiter=1000)
        # attach names for convenience
        res.exog_names = names
        return res
    if cluster_se:
        return model.fit(cov_type="cluster", cov_kwds={"groups": trial_ids})
    return model.fit()


def predict_mu(result, design_df: pd.DataFrame, dt: float, add_const: bool = True) -> np.ndarray:
    """Predict expected spike counts per bin for the rows of `design_df`.

    Parameters
    ----------
    result : fitted model exposing ``predict(exog, exposure=...)``.
    design_df : design matrix with the same columns used for fitting.
    dt : bin width, passed as a constant exposure.
    add_const : prepend a column of ones, matching a fit with an intercept.
    """
    # Force float: `.values` on a frame mixing bool dummies and float columns
    # is an object array, which the prediction arithmetic cannot handle.
    X = np.asarray(design_df, dtype=float)
    if add_const:
        X = np.column_stack([np.ones(len(X)), X])
    mu = result.predict(X, exposure=np.full(len(X), dt))
    return np.asarray(mu, dtype=float)


def poisson_deviance(y: np.ndarray, mu: np.ndarray) -> float:
    """Poisson deviance 2 * sum(y * log(y / mu) - (y - mu)).

    Bins with y == 0 contribute only their ``mu`` term; a small epsilon
    keeps the logarithm finite.
    """
    tiny = 1e-12
    counts = np.asarray(y, dtype=float)
    expected = np.asarray(mu, dtype=float)
    log_part = np.where(counts > 0, counts * np.log((counts + tiny) / (expected + tiny)), 0.0)
    total = np.sum(log_part - (counts - expected))
    return float(2.0 * total)


def pseudo_R2(y: np.ndarray, mu_full: np.ndarray, mu_null: np.ndarray) -> float:
    """Likelihood-ratio pseudo-R2: 1 - ll(full) / ll(null).

    Log-likelihoods are Poisson up to the y-only constant term.
    """
    tiny = 1e-12

    def _loglik(mu):
        return np.sum(y * np.log(mu + tiny) - mu)

    ratio = _loglik(mu_full) / _loglik(mu_null)
    return float(1.0 - ratio)


def per_trial_deviance(y: np.ndarray, mu: np.ndarray, trial_ids: np.ndarray) -> pd.DataFrame:
    """Sum the Poisson deviance of the bins belonging to each trial.

    Returns a frame with columns ``trial`` and ``trial_deviance``, one row
    per trial.
    """
    tiny = 1e-12
    bins = pd.DataFrame({"trial": trial_ids, "y": y, "mu": mu})
    log_ratio = np.log((bins["y"] + tiny) / (bins["mu"] + tiny))
    log_part = np.where(bins["y"] > 0, bins["y"] * log_ratio, 0.0)
    bins["bin_dev"] = log_part - (bins["y"] - bins["mu"])

    totals = bins.groupby("trial", as_index=False)["bin_dev"].sum()
    totals = totals.rename(columns={"bin_dev": "trial_deviance"})
    totals["trial_deviance"] = totals["trial_deviance"] * 2.0
    return totals

# =====================
# CV utilities
# =====================

def trialwise_folds(trial_ids: np.ndarray, n_splits: int, *, shuffle: bool = False, random_state: Optional[int] = None) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Split bins into folds so that a trial is never divided between folds.

    Returns `n_splits` pairs of boolean masks (train_mask, val_mask) over
    the bins; each pair holds out one group of whole trials.
    """
    rng = np.random.default_rng(random_state)
    trials = _unique_trials(trial_ids)
    if shuffle:
        trials = trials.copy()
        rng.shuffle(trials)

    splits = []
    for held_out in np.array_split(trials, n_splits):
        val_mask = np.isin(trial_ids, held_out)
        splits.append((~val_mask, val_mask))
    return splits


def fit_and_score_cv(
    design_df: pd.DataFrame,
    y: np.ndarray,
    dt: float,
    trial_ids: np.ndarray,
    *,
    n_splits: int = 5,
    l2: float = 0.0,
    use_trial_FE: bool = True,
    cluster_se: bool = False,
    random_state: Optional[int] = 0,
) -> pd.DataFrame:
    """Trial-wise cross-validation of the Poisson GLM.

    For every fold the model is fit on the training trials and scored on
    the held-out ones; the null model for pseudo-R2 is the training mean
    count. `use_trial_FE` and `cluster_se` are accepted but not used: each
    fold is fit without cluster-robust errors.

    Returns one row per fold: fold number, validation deviance, validation
    pseudo-R2 and number of validation bins.
    """
    splits = trialwise_folds(trial_ids, n_splits, shuffle=True, random_state=random_state)
    records = []
    for fold_no, (fit_mask, held_mask) in enumerate(splits, start=1):
        y_fit = y[fit_mask]
        y_held = y[held_mask]

        fold_model = fit_poisson_glm_trials(
            design_df.iloc[fit_mask], y_fit, dt, trial_ids[fit_mask],
            add_const=True, l2=l2, cluster_se=False,
        )
        mu_held = predict_mu(fold_model, design_df.iloc[held_mask], dt)
        mu_baseline = np.full_like(y_held, y_fit.mean(), dtype=float)

        records.append({
            "fold": fold_no,
            "val_deviance": poisson_deviance(y_held, mu_held),
            "val_pseudo_R2": pseudo_R2(y_held, mu_held, mu_baseline),
            "n_val_bins": int(held_mask.sum()),
        })
    return pd.DataFrame(records)

# =====================
# MultiFF adapter (features & end-to-end)
# =====================

def build_multiff_design(
    *,
    dt: float,
    trial_ids: np.ndarray,
    cur_vis: np.ndarray,
    nxt_vis: np.ndarray,
    cur_dist: np.ndarray,
    nxt_dist: np.ndarray,
    cur_angle: np.ndarray,
    nxt_angle: np.ndarray,
    heading: Optional[np.ndarray] = None,
    speed: Optional[np.ndarray] = None,
    curvature: Optional[np.ndarray] = None,
    spike_counts: Optional[np.ndarray] = None,
    use_trial_FE: bool = True,
) -> Tuple[pd.DataFrame, Optional[np.ndarray], Dict[str, np.ndarray]]:
    """Turn multiFF behavioural signals into a trial-aware GLM design.

    Visibility onsets use a 6-function basis over 0.6 s; distance, sin/cos
    of angle and (when `heading` is given) heading alignment use a
    5-function basis over 0.3 s and are zeroed while the target is not
    visible; spike history uses a 5-function basis over 0.2 s starting one
    bin back. `speed` and `curvature` enter as raw covariates.

    Returns (design_df, y, meta) where meta holds the three bases under
    'B_event', 'B_short' and 'B_hist'.
    """
    # Bases
    _, B_event = raised_cosine_basis(n_basis=6, t_max=0.60, dt=dt, t_min=0.0, log_spaced=True)
    _, B_short = raised_cosine_basis(n_basis=5, t_max=0.30, dt=dt, t_min=0.0, log_spaced=True)
    _, B_hist = raised_cosine_basis(n_basis=5, t_max=0.20, dt=dt, t_min=dt, log_spaced=True)

    # Visibility gates: features only count while their target is visible
    cur_seen = cur_vis > 0
    nxt_seen = nxt_vis > 0

    cur_sin, cur_cos = angle_sin_cos(cur_angle)
    nxt_sin, nxt_cos = angle_sin_cos(nxt_angle)

    stimuli: Dict[str, np.ndarray] = {
        "cur_on": onset_from_mask_trials(cur_vis, trial_ids),
        "nxt_on": onset_from_mask_trials(nxt_vis, trial_ids),
        "cur_dist": cur_dist * cur_seen,
        "nxt_dist": nxt_dist * nxt_seen,
        "cur_angle_sin": cur_sin * cur_seen,
        "cur_angle_cos": cur_cos * cur_seen,
        "nxt_angle_sin": nxt_sin * nxt_seen,
        "nxt_angle_cos": nxt_cos * nxt_seen,
    }
    if heading is not None:
        stimuli["cur_align"] = np.cos(wrap_angle(heading - cur_angle)) * cur_seen
        stimuli["nxt_align"] = np.cos(wrap_angle(heading - nxt_angle)) * nxt_seen

    covariates: Dict[str, np.ndarray] = {}
    if speed is not None:
        covariates["speed"] = speed
    if curvature is not None:
        covariates["curvature"] = curvature

    # Onsets get the longer event basis; every other stimulus the short one
    bases: Dict[str, np.ndarray] = {
        name: (B_event if name in ("cur_on", "nxt_on") else B_short)
        for name in stimuli
    }

    design_df, y = build_glm_design_with_trials(
        dt=dt,
        trial_ids=trial_ids,
        stimulus_dict=stimuli,
        stimulus_basis_dict=bases,
        spike_counts=spike_counts,
        history_basis=B_hist,
        extra_covariates=covariates,
        use_trial_FE=use_trial_FE,
    )

    meta = {"B_event": B_event, "B_short": B_short, "B_hist": B_hist}
    return design_df, y, meta


def fit_multiff_glm(
    *,
    dt: float,
    trial_ids: np.ndarray,
    cur_vis: np.ndarray,
    nxt_vis: np.ndarray,
    cur_dist: np.ndarray,
    nxt_dist: np.ndarray,
    cur_angle: np.ndarray,
    nxt_angle: np.ndarray,
    heading: Optional[np.ndarray] = None,
    speed: Optional[np.ndarray] = None,
    curvature: Optional[np.ndarray] = None,
    spike_counts: np.ndarray,
    l2: float = 0.0,
    use_trial_FE: bool = True,
    cluster_se: bool = False,
):
    """Build the multiFF design, fit the Poisson GLM and score it in-sample.

    Parameters
    ----------
    dt, trial_ids, cur_*/nxt_*, heading, speed, curvature, spike_counts,
    use_trial_FE : forwarded to `build_multiff_design`.
    l2 : ridge penalty for the fit (0 = unpenalised).
    cluster_se : request cluster-robust standard errors by trial.

    Returns
    -------
    (res, design_df, metrics, meta); `metrics` holds the overall
    'deviance', the 'pseudo_R2' against a constant-rate model and the
    'per_trial_deviance' table.
    """
    design_df, y, meta = build_multiff_design(
        dt=dt, trial_ids=trial_ids,
        cur_vis=cur_vis, nxt_vis=nxt_vis,
        cur_dist=cur_dist, nxt_dist=nxt_dist,
        cur_angle=cur_angle, nxt_angle=nxt_angle,
        heading=heading, speed=speed, curvature=curvature,
        spike_counts=spike_counts, use_trial_FE=use_trial_FE,
    )
    # Forward the caller's choice; it used to be overridden with False, so
    # cluster-robust errors could never be requested through this function.
    res = fit_poisson_glm_trials(design_df, y, dt, trial_ids, add_const=True, l2=l2, cluster_se=cluster_se)
    mu = predict_mu(res, design_df, dt)
    mu_null = np.full_like(y, y.mean(), dtype=float)
    metrics = {
        "deviance": poisson_deviance(y, mu),
        "pseudo_R2": pseudo_R2(y, mu, mu_null),
        "per_trial_deviance": per_trial_deviance(y, mu, trial_ids),
    }
    return res, design_df, metrics, meta

# =====================
# Stable simulators
# =====================

def simulate_spikes_with_trials(
    *, n_trials: int = 20, trial_len: int = 300, dt: float = 0.01, seed: int = 1,
    max_rate_hz: float = 200.0
) -> Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray], Dict[str, np.ndarray], np.ndarray]:
    """Simulate spike counts driven by an event train, a speed signal and spike history.

    All randomness, including the spike draws, comes from a generator
    seeded with `seed`, so the output is reproducible. The history kernel
    is applied with the same lag convention as
    `spike_history_design_with_trials` (kernel sample l weights the count
    l + 1 bins back), and history is reset at every trial start.

    Returns
    -------
    (trial_ids, y, stimulus_dict, stimulus_basis_dict, B_hist)
    """
    rng = np.random.default_rng(seed)
    T = n_trials * trial_len
    trial_ids = np.repeat(np.arange(n_trials), trial_len)

    _, B_stim = raised_cosine_basis(n_basis=5, t_max=0.300, dt=dt, t_min=0.0, log_spaced=True)
    _, B_hist = raised_cosine_basis(n_basis=4, t_max=0.200, dt=dt, t_min=dt, log_spaced=True)

    event = np.concatenate([(rng.random(trial_len) < 0.02).astype(float) for _ in range(n_trials)])
    speed = np.concatenate([signal.lfilter([1.0], [1.0, -0.7], rng.normal(0.0, 1.0, size=trial_len)) for _ in range(n_trials)])
    speed = (speed - speed.mean()) / (speed.std() + 1e-12)

    beta_event = np.array([0.8, 0.5, 0.2, -0.1, -0.2])
    beta_speed = np.array([0.4, 0.2, 0.0, -0.1, -0.2])
    beta_hist = np.array([-1.2, -0.8, -0.5, -0.2])

    X_event = lagged_design_from_signal_trials(event, B_stim, trial_ids)
    X_speed = lagged_design_from_signal_trials(speed, B_stim, trial_ids)
    stim_drive = X_event @ beta_event + X_speed @ beta_speed

    h_hist = B_hist @ beta_hist
    Lh = len(h_hist)

    baseline_per_trial = -2.8 + 0.4 * rng.standard_normal(n_trials)

    y = np.zeros(T, dtype=int)
    for tr in _unique_trials(trial_ids):
        idx = np.where(trial_ids == tr)[0]
        past = np.zeros(Lh)  # past[i] = count (i + 1) bins ago; reset per trial
        b0 = baseline_per_trial[tr]
        for t in idx:
            # past[0] is the most recent bin and pairs with h_hist[0], as in
            # the design builder; reversing the kernel here would give the
            # newest spike the longest-lag weight.
            hist_drive = np.dot(past, h_hist)
            eta = b0 + stim_drive[t] + hist_drive
            lam = safe_poisson_lambda(eta, dt, max_rate_hz=max_rate_hz)
            # Draw from the seeded generator, not the global NumPy state.
            y[t] = rng.poisson(lam)
            past = np.roll(past, 1)
            past[0] = y[t]

    stimulus_dict = {"event": event, "speed": speed}
    stimulus_basis_dict = {"event": B_stim, "speed": B_stim}
    return trial_ids, y, stimulus_dict, stimulus_basis_dict, B_hist


def simulate_multiff_trials(
    *, n_trials: int = 12, trial_len: int = 400, dt: float = 0.01, seed: int = 3,
    max_rate_hz: float = 200.0
):
    """Simulate a multiFF session: behavioural signals plus spike counts.

    Visibility masks, distances, angles, heading, speed and curvature are
    generated per trial; spikes follow a Poisson GLM built on
    `build_multiff_design` with hand-set weights and a suppressive history
    kernel. All randomness, including the spike draws, comes from a
    generator seeded with `seed`, so the output is reproducible. The
    history kernel uses the same lag convention as
    `spike_history_design_with_trials`.

    Returns
    -------
    dict with keys 'trial_ids', 'cur_vis', 'nxt_vis', 'cur_dist',
    'nxt_dist', 'cur_angle', 'nxt_angle', 'heading', 'speed', 'curvature',
    'spike_counts' and 'dt'.
    """
    rng = np.random.default_rng(seed)
    T = n_trials * trial_len
    trial_ids = np.repeat(np.arange(n_trials), trial_len)

    def random_vis_mask():
        # alternate random-length "off" and "on" stretches within one trial
        m = np.zeros(trial_len, dtype=int)
        t = 0
        while t < trial_len:
            off = rng.integers(20, 80)
            on = rng.integers(30, 120)
            t += off
            if t >= trial_len:
                break
            m[t: min(trial_len, t + on)] = 1
            t += on
        return m

    cur_vis = np.concatenate([random_vis_mask() for _ in range(n_trials)])
    # the next target is the same kind of mask, delayed by a random lead-in
    nxt_vis = np.concatenate([np.r_[np.zeros(rng.integers(40, 120)), random_vis_mask()][:trial_len] for _ in range(n_trials)])

    def distance_from_vis(m):
        # distance starts at 1 when the target appears and shrinks by 0.02 per bin
        d = np.zeros_like(m, dtype=float)
        run = 0.0
        for i, v in enumerate(m):
            if v:
                run = 1.0 if run == 0 else max(0.0, run - 0.02)
                d[i] = run
            else:
                run = 0.0
                d[i] = 0.0
        return d

    cur_dist = np.concatenate([distance_from_vis(cur_vis[i*trial_len:(i+1)*trial_len]) for i in range(n_trials)])
    nxt_dist = np.concatenate([distance_from_vis(nxt_vis[i*trial_len:(i+1)*trial_len]) for i in range(n_trials)])

    def rand_angle_series(L):
        # random walk wrapped to [-pi, pi)
        a = np.cumsum(rng.normal(0, 0.05, size=L))
        return ((a + np.pi) % (2*np.pi)) - np.pi

    cur_angle = np.concatenate([rand_angle_series(trial_len) for _ in range(n_trials)])
    nxt_angle = np.concatenate([rand_angle_series(trial_len) for _ in range(n_trials)])
    heading = np.concatenate([rand_angle_series(trial_len) for _ in range(n_trials)])

    speed = np.concatenate([signal.lfilter([1.0], [1.0, -0.8], rng.normal(0, 1.0, size=trial_len)) for _ in range(n_trials)])
    speed = (speed - speed.mean()) / (speed.std() + 1e-12)
    curvature = np.concatenate([rng.normal(0.0, 0.4, size=trial_len) for _ in range(n_trials)])

    # Build design (no history yet) to compute stim_drive for simulation
    design_nohist, _, meta = build_multiff_design(
        dt=dt, trial_ids=trial_ids,
        cur_vis=cur_vis, nxt_vis=nxt_vis,
        cur_dist=cur_dist, nxt_dist=nxt_dist,
        cur_angle=cur_angle, nxt_angle=nxt_angle,
        heading=heading, speed=speed, curvature=curvature,
        spike_counts=None, use_trial_FE=True,
    )

    cols = list(design_nohist.columns)

    def idxs(prefix):
        return [i for i, c in enumerate(cols) if c.startswith(prefix)]

    # Ground-truth weights; columns not listed here (e.g. trial dummies) stay at 0.
    beta = np.zeros(design_nohist.shape[1])
    ramped = [("cur_on", 0.9), ("nxt_on", 0.5),
              ("cur_dist", -0.7), ("nxt_dist", -0.5),
              ("cur_align", 0.35), ("nxt_align", 0.20)]
    for p, gain in ramped:
        j = idxs(p)
        if j:
            beta[j] = np.linspace(gain, 0.0, num=len(j))
    flat = [("cur_angle_sin", 0.15), ("cur_angle_cos", 0.15),
            ("nxt_angle_sin", 0.10), ("nxt_angle_cos", 0.10)]
    for p, gain in flat:
        j = idxs(p)
        if j:
            beta[j] = gain / max(1, len(j))
    if "speed" in cols:
        beta[cols.index("speed")] = 0.2
    if "curvature" in cols:
        beta[cols.index("curvature")] = -0.25

    # Explicit float conversion: the design may mix dummy and float columns.
    stim_drive = design_nohist.to_numpy(dtype=float) @ beta

    B_hist = meta["B_hist"]
    h_hist = B_hist @ np.array([-1.2, -0.8, -0.5, -0.3, -0.2])
    Lh = len(h_hist)

    baseline = -3.0 + 0.3 * rng.standard_normal(n_trials)

    y = np.zeros(T, dtype=int)
    for tr in _unique_trials(trial_ids):
        idx = np.where(trial_ids == tr)[0]
        past = np.zeros(Lh)  # past[i] = count (i + 1) bins ago; reset per trial
        b0 = baseline[tr]
        for t in idx:
            # past[0] is the most recent bin and pairs with h_hist[0], as in
            # the design builder; reversing the kernel here would give the
            # newest spike the longest-lag weight.
            eta = b0 + stim_drive[t] + np.dot(past, h_hist)
            lam = safe_poisson_lambda(eta, dt, max_rate_hz=max_rate_hz)
            # Draw from the seeded generator, not the global NumPy state.
            y[t] = rng.poisson(lam)
            past = np.roll(past, 1)
            past[0] = y[t]

    return {
        "trial_ids": trial_ids,
        "cur_vis": cur_vis,
        "nxt_vis": nxt_vis,
        "cur_dist": cur_dist,
        "nxt_dist": nxt_dist,
        "cur_angle": cur_angle,
        "nxt_angle": nxt_angle,
        "heading": heading,
        "speed": speed,
        "curvature": curvature,
        "spike_counts": y,
        "dt": dt,
    }



# =====================
# plotting helpers

def _coef_series(result, design_df):
    if hasattr(result, 'params'):
        params = np.asarray(result.params).ravel()
        names = getattr(result, 'exog_names', None)
        if names is None and hasattr(result.model, 'exog_names'):
            names = result.model.exog_names
        if names and names[0] == 'const' and len(params) == len(names):
            params = params[1:]
            names = names[1:]
        return pd.Series(params, index=names if names else list(design_df.columns))
    return pd.Series(np.zeros(design_df.shape[1]), index=list(design_df.columns))


def reconstruct_kernel(prefix: str, basis: np.ndarray, coef_s: pd.Series) -> Tuple[np.ndarray, np.ndarray]:
    """Rebuild a kernel from its basis and the fitted basis weights.

    Coefficients whose label starts with ``f"{prefix}_rc"`` are selected,
    ordered by the integer that follows ``_rc`` and multiplied into ``basis``.

    Parameters:
      prefix: regressor name preceding the ``_rc<i>`` suffix.
      basis: basis matrix; rows index lags, columns index basis functions.
      coef_s: coefficients indexed by column name.

    Returns:
      (t, k): ``t`` is ``np.arange(basis.shape[0])`` (lag index, in bins) and
      ``k`` is ``basis @ w``. When no coefficient matches, ``k`` is all zeros.
    """
    tag = prefix + "_rc"
    n_lags = basis.shape[0]
    cols = [c for c in coef_s.index if c.startswith(tag)]
    if not cols:
        return np.arange(n_lags), np.zeros(n_lags)

    def rc_idx(c):
        # Sort numerically so that e.g. rc10 comes after rc2; labels with a
        # non-numeric suffix get key 0. Only int() can fail here, so catch
        # ValueError instead of using a bare except.
        try:
            return int(c.split('_rc')[-1])
        except ValueError:
            return 0

    cols.sort(key=rc_idx)
    w = coef_s.loc[cols].values
    return np.arange(n_lags), basis @ w


def reconstruct_history_kernel(B_hist: np.ndarray, coef_s: pd.Series) -> Tuple[np.ndarray, np.ndarray]:
    """Rebuild the spike-history kernel from its basis and fitted weights.

    Coefficients whose label starts with ``'hist_rc'`` are selected, ordered
    by the integer that follows ``_rc`` and multiplied into ``B_hist``.

    Parameters:
      B_hist: history basis matrix; rows index lags, columns index basis
        functions.
      coef_s: coefficients indexed by column name.

    Returns:
      (t, k): ``t`` is ``np.arange(B_hist.shape[0])`` (lag index, in bins) and
      ``k`` is ``B_hist @ w``. When no ``hist_rc*`` coefficient exists, ``k``
      is all zeros.
    """
    n_lags = B_hist.shape[0]
    cols = [c for c in coef_s.index if c.startswith('hist_rc')]
    if not cols:
        return np.arange(n_lags), np.zeros(n_lags)

    def rc_idx(c):
        # Numeric ordering of the basis index; a non-numeric suffix gets key 0.
        # int() is the only call that can fail, so catch ValueError rather
        # than using a bare except.
        try:
            return int(c.split('_rc')[-1])
        except ValueError:
            return 0

    cols.sort(key=rc_idx)
    w = coef_s.loc[cols].values
    return np.arange(n_lags), B_hist @ w


def plot_fitted_kernels(result, design_df, meta, dt, *, prefixes=None):
    """Plot the reconstructed kernel for each regressor and for spike history.

    Parameters:
      result: fit result passed to ``_coef_series`` to obtain coefficients.
      design_df: design matrix (DataFrame) passed to ``_coef_series``.
      meta: mapping that must contain the basis matrices ``'B_event'``,
        ``'B_short'`` and ``'B_hist'``.
      dt: bin width; lag indices are multiplied by it for the x axis, which
        is labelled in seconds.
      prefixes: regressor prefixes to plot. Defaults to the cur/nxt onset,
        distance and angle sin/cos regressors.

    One figure is created and shown per prefix, followed by one figure for
    the spike-history kernel. Returns None.
    """
    # Imported once here; the original re-imported it on every loop iteration.
    import matplotlib.pyplot as plt

    if prefixes is None:
        prefixes = ['cur_on', 'nxt_on', 'cur_dist', 'nxt_dist', 'cur_angle_sin', 'cur_angle_cos', 'nxt_angle_sin', 'nxt_angle_cos']
    coef_s = _coef_series(result, design_df)
    B_event, B_short, B_hist = meta['B_event'], meta['B_short'], meta['B_hist']

    # Onset regressors use the event basis; every other prefix uses the short
    # basis (the original's separate elif/else branches both returned B_short).
    event_prefixes = ('cur_on', 'nxt_on')

    for p in prefixes:
        B = B_event if p in event_prefixes else B_short
        t, k = reconstruct_kernel(p, B, coef_s)
        plt.figure()
        plt.plot(t * dt, k)
        plt.xlabel('Time lag (s)')
        plt.ylabel('Kernel weight')
        plt.title(f'{p} kernel')
        plt.show()

    t_h, k_h = reconstruct_history_kernel(B_hist, coef_s)
    plt.figure()
    plt.plot(t_h * dt, k_h)
    plt.xlabel('Time lag (s)')
    plt.ylabel('History weight')
    plt.title('Spike history kernel')
    plt.show()
# =====================






# GLM 2

In [None]:
# import glm_models
from neural_data_analysis.neural_analysis_tools.glm_tools import glm_models

In [None]:
# Fit the cross-validated GLM on the training trials, then score it on the
# held-out test trials and print the non-array metrics.
# NOTE(review): `train_trials` / `test_trials` are assigned in the
# "get trials" cell further down, so that cell has to be run first.
dt = 0.01  # 10 ms bins for modeling

# Per-feature basis configuration, keyed by the trial field it applies to.
bases_cfg = {
    'speed':   {'kind':'cont',  't_max':0.30, 'K':6, 'log':True},   # 0–300 ms filter
    'go':      {'kind':'event', 't_max':0.50, 'K':6, 'log':True},   # 0–500 ms post-event
    'hist':    {'kind':'hist',  't_max_short':0.025, 'K_short':3,   # 0–25 ms (fine)
                              't_max_long':0.250,  'K_long':5},     # 25–250 ms (coarse)
    'heading': {'kind':'angle'}                                      # instantaneous sin/cos
}

# Search the 8 log-spaced alphas with 5 shuffled splits.
model_bundle = glm_models.fit_glm_cv(train_trials, dt, bases_cfg,
                          alphas=np.logspace(-4, 1, 8),
                          n_splits=5, shuffle_trials=True, random_state=0)
print("Best alpha:", model_bundle['best_alpha'])


# NOTE(review): `evaluate_on_test` is called unqualified, but its in-notebook
# definition (under "model functions") is commented out. Presumably this
# should be `glm_models.evaluate_on_test` like the fit call — confirm.
results = evaluate_on_test(train_trials, test_trials, dt, bases_cfg,
                           model_bundle, smooth_sigma_ms=40)  # optional PSTH smoothing
# Print only the non-array entries; ndarray values are skipped.
for k, v in results.items():
    if not isinstance(v, np.ndarray):
        print(k, ":", v)


### model functions

In [None]:
# import numpy as np
# from scipy.signal import fftconvolve
# from scipy.special import gammaln
# from sklearn.preprocessing import StandardScaler
# from sklearn.linear_model import PoissonRegressor
# from sklearn.model_selection import KFold

# # -------------------- Bases & design helpers --------------------

# def raised_cosine_basis(n_basis, t_max, dt, t_min=0.0, log_spaced=True, eps=1e-3):
#     """
#     Causal raised-cosine basis functions that tile [t_min, t_max].
#     Columns are normalized to unit area (sum * dt = 1).
#     Returns: lags (L,), B (L x K)
#     """
#     lags = np.arange(0, t_max + dt, dt)

#     def warp(x):  # denser near 0 if log_spaced
#         return np.log(x + eps) if log_spaced else x

#     W = warp(lags)
#     c = np.linspace(warp(t_min + eps), warp(t_max), n_basis)
#     w = (c[1] - c[0]) if n_basis > 1 else (warp(t_max) - warp(t_min) + 1e-6)

#     B = []
#     for ci in c:
#         arg = (W - ci) * np.pi / w
#         b = 0.5 * (1 + np.cos(np.clip(arg, -np.pi, np.pi)))
#         b[(W < ci - w) | (W > ci + w)] = 0.0
#         B.append(b)
#     B = np.stack(B, axis=1)
#     B /= (B.sum(axis=0, keepdims=True) * dt + 1e-12)  # unit area
#     return lags, B

# def convolve_causal(x, k):
#     """
#     Causal convolution of time series x with kernel k (defined on nonnegative lags).
#     Output length matches x (truncate 'full' at T).
#     """
#     return fftconvolve(x, k, mode='full')[:len(x)]

# def build_design_for_trial(trial, dt, bases_cfg):
#     """
#     trial: dict with keys:
#       - 'y': counts per bin (T,)
#       - optional continuous series: e.g., 'speed', 'dist', ... each (T,)
#       - optional events as binary series: e.g., 'go', 'flash', ... each (T,)
#       - optional angle series in radians: e.g., 'heading' (T,)
#     bases_cfg: dict describing which features get which bases.
#       Example:
#         bases_cfg = {
#            'speed':  {'kind':'cont',  't_max':0.30, 'K':6, 'log':True},
#            'go':     {'kind':'event', 't_max':0.50, 'K':6, 'log':True},
#            'hist':   {'kind':'hist',  't_max_short':0.025, 'K_short':3,
#                                     't_max_long':0.250, 'K_long':5},
#            'heading':{'kind':'angle'}  # instant sin/cos (no lags)
#         }
#     Returns: X_trial (T x P), colnames (list of str)
#     """
#     y = trial['y']
#     T = len(y)
#     X_cols = []
#     names = []

#     # 1) Continuous covariates with temporal filters (project onto basis)
#     for key, cfg in bases_cfg.items():
#         if cfg.get('kind') == 'cont' and key in trial:
#             lags, B = raised_cosine_basis(cfg['K'], cfg['t_max'], dt,
#                                           t_min=0.0, log_spaced=cfg.get('log', True))
#             for k in range(B.shape[1]):
#                 X_cols.append(convolve_causal(trial[key], B[:, k]))
#                 names.append(f"{key}_rc{k+1}")
#     # 2) Event covariates (binary impulses convolved with basis)
#     for key, cfg in bases_cfg.items():
#         if cfg.get('kind') == 'event' and key in trial:
#             lags, B = raised_cosine_basis(cfg['K'], cfg['t_max'], dt,
#                                           t_min=0.0, log_spaced=cfg.get('log', True))
#             for k in range(B.shape[1]):
#                 X_cols.append(convolve_causal(trial[key], B[:, k]))
#                 names.append(f"{key}_rc{k+1}")

#     # 3) Spike history: fine 0–t_max_short and coarse t_max_short–t_max_long
#     if 'hist' in bases_cfg:
#         cfg = bases_cfg['hist']

#         # Full lag grid up to the long window (for padding)
#         l_full = np.arange(0, cfg['t_max_long'] + dt, dt)
#         L_full = len(l_full)

#         # Short window bases (0 .. t_max_short), linear spacing
#         l1, B1 = raised_cosine_basis(
#             n_basis=cfg['K_short'],
#             t_max=cfg['t_max_short'],
#             dt=dt,
#             t_min=0.0,
#             log_spaced=False
#         )

#         # Long window bases (t_max_short .. t_max_long), log spacing
#         l2, B2 = raised_cosine_basis(
#             n_basis=cfg['K_long'],
#             t_max=cfg['t_max_long'],
#             dt=dt,
#             t_min=cfg['t_max_short'],
#             log_spaced=True
#         )
#         # B2 is already defined on [0 .. t_max_long] (zeros before t_min),
#         # so its number of rows should be L_full. B1 has fewer rows; pad it.

#         # Zero-pad B1 to the full length
#         pad_B1 = np.zeros((L_full, B1.shape[1]))
#         pad_B1[:len(l1), :] = B1

#         # Make sure B2 has the same number of rows (guard small rounding diffs)
#         if len(l2) != L_full:
#             pad_B2 = np.zeros((L_full, B2.shape[1]))
#             L = min(L_full, len(l2))
#             pad_B2[:L, :] = B2[:L, :]
#             Bhist = np.hstack([pad_B1, pad_B2])  # columns = K_short + K_long
#         else:
#             Bhist = np.hstack([pad_B1, B2])

#         # Convolve trial spikes with each history basis column (causal)
#         for k in range(Bhist.shape[1]):
#             X_cols.append(convolve_causal(y, Bhist[:, k]))
#             names.append(f"hist_rc{k+1}")


#     # 4) Instantaneous angle features (no lags): sin/cos for circular variables
#     for key, cfg in bases_cfg.items():
#         if cfg.get('kind') == 'angle' and key in trial:
#             ang = trial[key]
#             X_cols.append(np.sin(ang)); names.append(f"{key}_sin")
#             X_cols.append(np.cos(ang)); names.append(f"{key}_cos")

#     if not X_cols:
#         X = np.zeros((T, 0))
#     else:
#         X = np.column_stack(X_cols)
#     return X, names

# # -------------------- Metrics --------------------

# def loglik_poisson(y, mu):
#     """Sum log-likelihood for Poisson counts with mean mu (counts/bin)."""
#     mu = np.clip(mu, 1e-12, None)
#     return float(np.sum(y * np.log(mu) - mu - gammaln(y + 1)))

# def bits_per_spike(y, mu, mu0=None):
#     """
#     Bits/spike relative to a baseline (mu0 = homogeneous mean if None).
#     y, mu, mu0 are in counts/bin.
#     """
#     if mu0 is None:
#         mu0 = np.full_like(y, y.sum() / len(y))
#     LLm = loglik_poisson(y, mu)
#     LL0 = loglik_poisson(y, mu0)
#     return (LLm - LL0) / (y.sum() * np.log(2) + 1e-12)

# def psth_from_trials(trials_counts, dt, smooth_kernel=None):
#     """
#     Average counts across trials, optional smoothing on counts,
#     then convert to Hz (divide by dt).
#     """
#     M = np.mean(np.stack(trials_counts), axis=0)  # counts/bin
#     if smooth_kernel is not None:
#         M = np.convolve(M, smooth_kernel, mode='same')
#     return M / dt  # Hz

# def split_half_sb(test_trials_counts, dt, n_splits=200, rng=0, center=True, smooth_kernel=None):
#     rng = np.random.default_rng(rng)
#     counts = list(test_trials_counts)
#     rsb = []
#     for _ in range(n_splits):
#         idx = rng.permutation(len(counts))
#         A = [counts[i] for i in idx[::2]]
#         B = [counts[i] for i in idx[1::2]]
#         if len(A) == 0 or len(B) == 0:
#             continue
#         pA = psth_from_trials(A, dt, smooth_kernel)  # << same smoothing as eval
#         pB = psth_from_trials(B, dt, smooth_kernel)
#         if center:
#             pA -= pA.mean(); pB -= pB.mean()
#         r = np.dot(pA, pB) / (np.linalg.norm(pA)*np.linalg.norm(pB) + 1e-12)
#         rsb.append(2*r / (1 + r))  # Spearman–Brown
#     return float(np.mean(rsb)) if rsb else np.nan

# # -------------------- Cross-validated GLM fit --------------------

# def fit_glm_cv(trials, dt, bases_cfg, alphas=np.logspace(-4, 1, 8),
#                n_splits=5, shuffle_trials=True, random_state=0):
#     """
#     trials: list of per-trial dicts. Each must contain:
#         'y'  (counts, shape T,)
#       Optional per-trial arrays of same length T:
#         e.g., 'speed', 'heading', 'go', ...
#     Returns dict with model, scalers, column names, and CV metrics.
#     """
#     # Build per-trial design matrices
#     X_trials, y_trials = [], []
#     for tr in trials:
#         Xtr, _names = build_design_for_trial(tr, dt, bases_cfg)
#         X_trials.append(Xtr); y_trials.append(tr['y'])

#     # CV split by trials (NOT by time-bins)
#     kf = KFold(n_splits=n_splits, shuffle=shuffle_trials, random_state=random_state)

#     # Hyperparameter search (ridge strength)
#     best_alpha, best_ll = None, -np.inf
#     for a in alphas:
#         ll_sum = 0.0
#         for tr_idx, te_idx in kf.split(X_trials):
#             # Assemble train arrays
#             Xtr = np.vstack([X_trials[i] for i in tr_idx])
#             ytr = np.concatenate([y_trials[i] for i in tr_idx])

#             # Fit scaler on TRAIN ONLY
#             scaler = StandardScaler(with_mean=True, with_std=True)
#             Xtrz = scaler.fit_transform(Xtr)

#             # Fit Poisson GLM (ridge)
#             model = PoissonRegressor(alpha=a, max_iter=5000, fit_intercept=True)
#             model.fit(Xtrz, ytr)

#             # Evaluate on TEST trials
#             ll_fold = 0.0
#             for i in te_idx:
#                 Xte = X_trials[i]; yte = y_trials[i]
#                 Xtez = scaler.transform(Xte)
#                 mu = model.predict(Xtez)  # mean counts/bin
#                 ll_fold += loglik_poisson(yte, mu)
#             ll_sum += ll_fold
#         if ll_sum > best_ll:
#             best_ll, best_alpha = ll_sum, a

#     # Refit on ALL trials with best alpha
#     X_all = np.vstack(X_trials)
#     y_all = np.concatenate(y_trials)
#     scaler = StandardScaler(with_mean=True, with_std=True)
#     X_all_z = scaler.fit_transform(X_all)
#     final_model = PoissonRegressor(alpha=best_alpha, max_iter=5000, fit_intercept=True)
#     final_model.fit(X_all_z, y_all)

#     return {
#         "model": final_model,
#         "scaler": scaler,
#         "X_trials": X_trials,
#         "y_trials": y_trials,
#         "colnames": _names,  # from the last built trial (same across trials)
#         "dt": dt,
#         "best_alpha": best_alpha
#     }

# # -------------------- Evaluation on held-out trials --------------------

# def evaluate_on_test(train_trials, test_trials, dt, bases_cfg, model_bundle,
#                      smooth_sigma_ms=None):
#     """
#     Build PSTHs & metrics on held-out TEST trials.
#     Optionally smooth counts (both data & model) with Gaussian kernel (σ in ms).
#     """
#     # Gaussian kernel (on counts) that preserves units (sum*dt = 1)
#     smooth_kernel = None
#     if smooth_sigma_ms is not None:
#         sigma_s = smooth_sigma_ms / 1000.0
#         sigma_bins = max(1, int(round(sigma_s / dt)))
#         n = np.arange(-5*sigma_bins, 5*sigma_bins + 1)
#         g = np.exp(-0.5 * (n / sigma_bins)**2)
#         g /= (g.sum() * dt)
#         smooth_kernel = g

#     # Build per-trial designs for TEST
#     X_te, y_te = [], []
#     for tr in test_trials:
#         X, _ = build_design_for_trial(tr, dt, bases_cfg)
#         X_te.append(X); y_te.append(tr['y'])

#     # Predict per test trial
#     model = model_bundle['model']
#     scaler = model_bundle['scaler']
#     mu_te = []
#     for X, y in zip(X_te, y_te):
#         mu = model.predict(scaler.transform(X))  # counts/bin
#         mu_te.append(mu)

#     # Metrics: held-out LL and bits/spike (concatenated across test trials)
#     y_cat  = np.concatenate(y_te)
#     mu_cat = np.concatenate(mu_te)
#     bps = bits_per_spike(y_cat, mu_cat)  # vs homogeneous baseline on test
#     LL  = loglik_poisson(y_cat, mu_cat)

#     # PSTH correlation and ceiling-normalized score on TEST
#     psth_data  = psth_from_trials(y_te, dt, smooth_kernel)
#     psth_model = psth_from_trials(mu_te, dt, smooth_kernel)
#     # mean-center before correlation if you care about *shape*
#     r_model = np.corrcoef(psth_model - psth_model.mean(),
#                           psth_data  - psth_data.mean())[0, 1]
#     r_sb = split_half_sb(y_te, dt, n_splits=200, rng=1, center=True)
#     r_ceiling = np.sqrt(max(r_sb, 0.0))
#     r_norm = np.clip(r_model / (r_ceiling + 1e-12), 0, 1) if np.isfinite(r_ceiling) else np.nan

#     return {
#         "LL_test": LL,
#         "bits_per_spike_test": bps,
#         "psth_corr_test": r_model,
#         "psth_ceiling_rsb": r_sb,
#         "psth_ceiling_sqrt": r_ceiling,
#         "psth_corr_normalized": r_norm,
#         "psth_model_Hz": psth_model,
#         "psth_data_Hz": psth_data
#     }


### data functions

In [None]:
# import numpy as np

# def _exp_kernel(tau_s, dt, L_mult=5, unit_area=True):
#     """One-sided exponential kernel k(t) ~ exp(-t/tau), t>=0."""
#     L = max(1, int(round(L_mult * tau_s / dt)))
#     k = np.exp(-np.arange(L) * dt / tau_s)
#     if unit_area:
#         # Normalize so sum(k)*dt = 1 → unit area (stable across dt)
#         k /= (k.sum() * dt)
#     return k

# def _causal_conv(x, k):
#     """Causal convolution (truncate to len(x))."""
#     from scipy.signal import fftconvolve
#     return fftconvolve(x, k, mode='full')[:len(x)]

# def make_synthetic_trials(
#     n_conditions=3,
#     train_repeats_per_cond=10,
#     test_repeats_Cstar=4,
#     C_star=0,
#     T_s=2.0,       # trial length (s)
#     dt=0.01,       # bin size (s) → 10 ms
#     seed=0,
#     # safety knobs:
#     target_peak_hz=30.0,   # try to keep peak expected rate (no-history) ≤ this
#     clip_peak_hz=80.0      # hard cap on expected rate during sampling
# ):
#     """
#     Returns:
#       train_trials: list of dicts (keys: 'y','speed','heading','go','condition','cond_angle')
#       test_trials : same keys (held-out repeats of condition C_star only)
#     """
#     rng = np.random.default_rng(seed)
#     T = int(round(T_s / dt))
#     t = np.arange(T) * dt

#     # --- true underlying filters/weights (used to simulate spikes) ---
#     # Kernels (unit-area so weights are gains)
#     k_evt   = _exp_kernel(tau_s=0.12, dt=dt)   # 120 ms event kernel (unit area)
#     k_speed = _exp_kernel(tau_s=0.15, dt=dt)   # 150 ms speed kernel (unit area)
#     k_hist  = _exp_kernel(tau_s=0.02, dt=dt)   # history kernel (unit area)

#     # Coefficients (log-link)
#     # log μ = b0 + β_evt*(go⊛k_evt) + β_spd*(speed⊛k_speed)
#     #         + β_sin*sin(heading) + β_cos*cos(heading) + hist_gain*(y⊛k_hist)
#     b0         = np.log(0.03)   # ~3 Hz baseline at dt=0.01
#     beta_evt   = 0.12           # event adds ≈ 1.0 to log-rate at onset (unit-area kernel)
#     beta_speed = 0.01           # modest effect of smoothed speed
#     beta_sin   = 0.20
#     beta_cos   = -0.10
#     hist_gain  = -1.0           # inhibitory history

#     # Condition angles (e.g., target directions)
#     cond_angles = np.linspace(0, 2*np.pi, n_conditions, endpoint=False)

#     train_trials, test_trials = [], []

#     def simulate_one_trial(cond_idx):
#         phi = cond_angles[cond_idx]  # condition angle

#         # Event binary series (aligned near t=0.2 s)
#         go = np.zeros(T)
#         go[int(0.2 / dt)] = 1.0

#         # Speed: baseline + event-driven bump + colored noise
#         speed_base = 10.0 + 2.0 * rng.normal()
#         bump = 6.0 * _causal_conv(go, _exp_kernel(0.4, dt))               # slower bump
#         noise = _causal_conv(rng.normal(0, 0.20, size=T), _exp_kernel(0.05, dt))
#         speed = np.maximum(0.0, speed_base + bump + noise)                 # cm/s

#         # Heading: around the condition angle with slow jitter (wrap to [-π, π])
#         heading = phi + _causal_conv(rng.normal(0, 0.10, size=T), _exp_kernel(0.2, dt))
#         heading = (heading + np.pi) % (2*np.pi) - np.pi

#         # Covariate-driven part of the linear predictor (no history)
#         x_evt   = _causal_conv(go,    k_evt)
#         x_speed = _causal_conv(speed, k_speed)
#         lin_nohist = (b0
#                       + beta_evt   * x_evt
#                       + beta_speed * x_speed
#                       + beta_sin   * np.sin(heading)
#                       + beta_cos   * np.cos(heading))

#         # ---------- AUTO-CALIBRATE PEAK so expected counts/bin don't explode ----------
#         # Target peak expected counts/bin (no-history)
#         mu_target = max(1e-6, target_peak_hz * dt)   # e.g., 30 Hz @ 10ms → 0.30
#         mu_peak_nh = float(np.exp(np.max(lin_nohist)))
#         if mu_peak_nh > mu_target:
#             shift = np.log(mu_target) - np.log(mu_peak_nh)
#             lin_nohist = lin_nohist + shift   # equivalent to lowering b0 for this trial

#         # ---------- Simulate spikes with inhibitory history ----------
#         y = np.zeros(T, dtype=int)
#         mu_clip = max(1e-6, clip_peak_hz * dt)       # final backstop, e.g., 80 Hz @ 10ms → 0.80

#         for tt in range(T):
#             # history term from past spikes only (exclude current bin)
#             Lh = min(tt, len(k_hist) - 1)
#             if Lh > 0:
#                 hist_term = hist_gain * np.dot(y[tt-Lh:tt][::-1], k_hist[1:Lh+1])
#             else:
#                 hist_term = 0.0

#             eta = lin_nohist[tt] + hist_term         # log(mean counts per bin)
#             mu  = np.exp(eta)
#             # Hard cap as a safety backstop (should rarely activate after the shift)
#             if mu > mu_clip:
#                 mu = mu_clip
#             y[tt] = rng.poisson(mu)

#         return {
#             'y': y,
#             'speed': speed,
#             'heading': heading,
#             'go': go,
#             'condition': int(cond_idx),
#             'cond_angle': float(phi)
#         }

#     # Build trials for each condition
#     for c in range(n_conditions):
#         for _ in range(train_repeats_per_cond):
#             train_trials.append(simulate_one_trial(c))
#         if c == C_star:
#             for _ in range(test_repeats_Cstar):
#                 test_trials.append(simulate_one_trial(c))

#     rng.shuffle(train_trials)
#     rng.shuffle(test_trials)
#     return train_trials, test_trials


### get trials

In [None]:
# Create synthetic data: 3 conditions; hold out 4 repeats of condition 0 for Test.
# NOTE(review): `make_synthetic_trials` is called unqualified, but its
# definition in the "data functions" cell is commented out — confirm where it
# is expected to come from (imported module or earlier kernel state).
train_trials, test_trials = make_synthetic_trials(
    n_conditions=3, train_repeats_per_cond=12, test_repeats_Cstar=4,
    C_star=0, T_s=2.0, dt=0.01, seed=1
)

# Inspect shapes quickly
# T is the number of bins in the first training trial's 'y' array.
T = len(train_trials[0]['y'])
print(f"Train trials: {len(train_trials)} | Test trials: {len(test_trials)} | T bins per trial: {T}")



In [None]:
# Display the first training trial for a quick look at its fields.
train_trials[0]

### fit model

In [None]:
# Fit the GLM with default CV settings and report the scalar test metrics.
# This rebinds `bases_cfg` and `results`; the configuration repeats the one
# used in the "GLM 2" cell.
bases_cfg = {
    'speed':   {'kind':'cont',  't_max':0.30, 'K':6, 'log':True},
    'go':      {'kind':'event', 't_max':0.50, 'K':6, 'log':True},
    'hist':    {'kind':'hist',  't_max_short':0.025, 'K_short':3,
                                't_max_long':0.250,  'K_long':5},
    'heading': {'kind':'angle'}
}
# NOTE(review): `fit_glm_cv` and `evaluate_on_test` are called unqualified
# here, whereas the "GLM 2" cell uses `glm_models.fit_glm_cv`, and the
# in-notebook definitions are commented out. Presumably both should be
# accessed through `glm_models` — confirm.
bundle = fit_glm_cv(train_trials, dt=0.01, bases_cfg=bases_cfg)
results = evaluate_on_test(train_trials, test_trials, 0.01, bases_cfg, bundle, smooth_sigma_ms=40)
# Keep only entries without a length (i.e. drop arrays and other sized values).
print({k:v for k,v in results.items() if not hasattr(v, '__len__')})

In [None]:
# Sanity-check the held-out set: trial count, total of 'y' over all test
# trials, and the number of bins in the first test trial.
print("Test trials:", len(test_trials))
print("Test spikes (sum y):", int(np.sum(np.concatenate([tr['y'] for tr in test_trials]))))
print("Bins per trial:", len(test_trials[0]['y']))


# GLM functions

## individual steps

### categorize variables

In [None]:
# Build the PGAM wrapper from fields of `pn`; this walkthrough passes
# pn.y_var_reduced as the second argument.
# NOTE(review): `pgam_class` and `pn` are not defined in the visible part of
# the notebook — confirm they are imported/created in an earlier cell.
pgam_inst = pgam_class.PGAMclass(pn.x_var, pn.y_var_reduced, pn.bin_width, pn.processed_neural_data_folder_path)

In [None]:
# Preparation step, called with num_total_trials=10
# (presumably limits the data to a small number of trials — confirm).
pgam_inst.prepare_for_pgam(num_total_trials=10)

### temporal kernel

Modified from PGAM_Tutorial.ipynb.

In [None]:
# Calls a private method directly so the temporal-feature step can be run on
# its own (presumably it registers the temporal kernels on the model — confirm).
pgam_inst._add_temporal_features_to_model()

In [None]:
# Print the built-in help for add_smooth.
# NOTE(review): `gdh` must already be imported in the kernel — confirm.
help(gdh.smooths_handler.add_smooth)

### spatial variable

In [None]:
# Calls a private method directly so the spatial-feature step can be run on
# its own (presumably it registers the spatial variables on the model — confirm).
pgam_inst._add_spatial_features_to_model()

### run

In [None]:
# Run PGAM with neural_cluster_number=10.
pgam_inst.run_pgam(neural_cluster_number=10)

### post-processing

In [None]:
# Post-process the outcome of the run above (details live in PGAMclass).
pgam_inst.post_processing()

### save results

In [None]:
# Demo: compare a cubic B-spline whose first knot appears once with one whose
# first knot is repeated k + 1 times, using all-ones coefficients.
# Note: this cell rebinds the notebook-level names `k`, `coeffs` and `x`.
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import BSpline

# Degree of spline
k = 3

# Example 1: Not enough start knots
knots1 = [0, 2, 4, 6, 8, 10, 10, 10, 10]
# One coefficient per basis function: len(knots) - k - 1.
coeffs = np.ones(len(knots1) - k - 1)
spline1 = BSpline(knots1, coeffs, k)

# Example 2: Proper repeated start knots
knots2 = [0, 0, 0, 0, 2, 4, 6, 8, 10, 10, 10, 10]
coeffs2 = np.ones(len(knots2) - k - 1)
spline2 = BSpline(knots2, coeffs2, k)

# Evaluate slightly beyond the knot range [0, 10] on both sides.
x = np.linspace(-1, 11, 400)

plt.plot(x, spline1(x), label="Start knot once")
plt.plot(x, spline2(x), label="Start knot repeated")
plt.axvline(0, color='gray', linestyle='--', label="x = 0")
plt.legend()
plt.title("Effect of Repeating Start Knot")
plt.xlabel("x")
plt.ylabel("Spline Value")
plt.grid(True)
plt.show()


In [None]:
# Persist the results held by pgam_inst (destination is decided inside PGAMclass).
pgam_inst.save_results()

In [None]:
# Deliberate halt so "Run All" does not continue into the all-neurons loop
# below. The original `stop!` relied on a SyntaxError, which also makes any
# exported script unparseable; an explicit exception halts just the same.
raise RuntimeError("stop! (intentional halt before iterating through all neurons)")

## iterate through all neurons

In [None]:
# Fresh PGAM wrapper for the all-neurons run; unlike the step-by-step
# walkthrough earlier, this passes pn.y_var rather than pn.y_var_reduced.
pgam_inst = pgam_class.PGAMclass(pn.x_var, pn.y_var, pn.bin_width, pn.processed_neural_data_folder_path)

In [None]:
# Run the streamlined PGAM pipeline once per column of pn.x_var, passing the
# column index as neural_cluster_number and printing progress before each run.
for i in range(pn.x_var.shape[1]):
    print(f'neural_cluster_number: {i} out of {pn.x_var.shape[1]}')
    pgam_inst.streamline_pgam(neural_cluster_number=i, num_total_trials=10)