# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break

from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff import population_analysis_utils, one_ff_data_processing, parameters, one_ff_pipeline, one_ff_glm_design
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff.one_ff_gam import plot_gam_fit
from neural_data_analysis.neural_analysis_tools.glm_tools.tpg import glm_bases
from neural_data_analysis.design_kits.design_by_segment import temporal_feats, spatial_feats
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff.one_ff_gam import one_ff_gam_fit, assemble_one_ff_gam_design, penalty_tuning, backward_elimination

import sys
import math
import gc
import subprocess
from pathlib import Path

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats
import json

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr
import statsmodels.api as sm

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
import numpy as np
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util

# --- Display / runtime configuration -------------------------------------
# Reset matplotlib to its defaults FIRST. rcParams.update(rcParamsDefault)
# overwrites every key, so running it after the animation settings (as this
# cell previously did) silently discarded them.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# 'jshtml' was the last value the original cell assigned to
# 'animation.html' (it superseded an earlier "html5" assignment), so only
# that one is kept.
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128

os.environ['KMP_DUPLICATE_LIB_OK']='True'
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

# pandas / numpy display options
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)

print("done")


%load_ext autoreload
%autoreload 2
%matplotlib inline

pd.set_option('display.max_colwidth', 200)




# Build design

In [None]:
unit_idx = 4

In [None]:
reload(assemble_one_ff_gam_design)

In [None]:
# Covariates to compute for the session. Defined once: the cell previously
# carried two identical copies of this list, which could drift apart.
covariate_names = [
    'v', 'w', 'd', 'phi',
    'r_targ', 'theta_targ',
    'eye_ver', 'eye_hor',
]

prs = parameters.default_prs()
data_obj = one_ff_pipeline.OneFFSessionData(
    mat_path='all_monkey_data/one_ff_data/sessions_python.mat',
    prs=prs,
    session_num=0,
)

# preprocessing
data_obj.compute_covariates(covariate_names)
data_obj.compute_spike_counts()
data_obj.smooth_spikes()
data_obj.compute_events()

# Split of the covariates passed to build_tuning_design below.
linear_vars = [
    'v', 'w', 'd', 'r_targ',
    'eye_ver', 'eye_hor',
]

angular_vars = [
    'phi', 'theta_targ',
]

# Build design (in class)
# build once: these pieces do not take unit_idx and are reused per unit
temporal_df, temporal_meta, specs_meta = assemble_one_ff_gam_design.build_temporal_design_base(data_obj)
X_tuning, tuning_meta = assemble_one_ff_gam_design.build_tuning_design(
    data_obj.data_df,
    linear_vars,
    angular_vars,
    binrange_dict=data_obj.prs.binrange,
)

# per-unit
design_df, groups, all_meta = assemble_one_ff_gam_design.process_unit_design_and_groups(
    unit_idx=unit_idx,
    data_obj=data_obj,
    temporal_df=temporal_df,
    temporal_meta=temporal_meta,
    X_tuning=X_tuning,
    tuning_meta=tuning_meta,
    specs_meta=specs_meta,
    #coupling_units=[1, 3, 7],  # optional
)

# Response for the selected unit (passed as `y` to the fitting cells).
y = assemble_one_ff_gam_design.extract_response(
    unit_idx=unit_idx,
    data_obj=data_obj,
    design_df=design_df,
    temporal_meta=temporal_meta,
)

In [None]:
# pd.set_option('display.max_rows', None)
# design_df.describe().T

# Run model

In [None]:
# Per-neuron output folder; creating the nested 'fit_results' folder with
# parents=True also creates `outdir` itself.
outdir = Path(f'all_monkey_data/one_ff_data/my_gam_results/neuron_{unit_idx}')
fit_results_dir = outdir / 'fit_results'
fit_results_dir.mkdir(parents=True, exist_ok=True)

# File name is derived from the lambda configuration of `groups`.
lam_suffix = one_ff_gam_fit.generate_lambda_suffix(groups)
save_path = fit_results_dir / f'{lam_suffix}.pkl'

# ----------------------------
# 4) Fit the MAP Poisson GAM
# ----------------------------
fit_kwargs = dict(
    design_df=design_df,
    y=y,
    groups=groups,
    l1_groups=[],
    max_iter=200,
    tol=1e-6,
    verbose=True,
    save_path=save_path,
)
fit_res = one_ff_gam_fit.fit_poisson_gam_map(**fit_kwargs)

# Optimizer summary, one labelled line per field.
for label, attr in (
    ('success:', 'success'),
    ('message:', 'message'),
    ('n_iter:', 'n_iter'),
    ('final objective:', 'fun'),
    ('grad_norm:', 'grad_norm'),
):
    print(label, getattr(fit_res, attr))

  # pd.Series indexed by design_df columns


## plot

In [None]:
plot_gam_fit.plot_variables(fit_res.coef, all_meta, plot_gam_fit.plot_var_order)  

In [None]:
stop!

# penalty tuning

In [None]:
# Per-neuron output folder for the penalty search.
outdir = Path(f'all_monkey_data/one_ff_data/my_gam_results/neuron_{unit_idx}')
outdir.mkdir(parents=True, exist_ok=True)

# coupling Laplace prior can go here later
l1_groups = []

# Candidate values for each of the three lambda keys.
lam_grid = dict(
    lam_f=[10, 50, 100, 300],
    lam_g=[1, 5, 10, 30],
    lam_h=[1, 5, 10, 30],
)

# Each lambda key applies to the group names of one section of all_meta.
group_name_map = {
    lam_key: list(all_meta[section]['groups'].keys())
    for lam_key, section in (
        ('lam_f', 'tuning'),
        ('lam_g', 'temporal'),
        ('lam_h', 'hist'),
    )
}

tuning_kwargs = dict(
    design_df=design_df,
    y=y,
    base_groups=groups,
    l1_groups=l1_groups,
    lam_grid=lam_grid,
    group_name_map=group_name_map,
    n_folds=5,
    save_path=outdir / 'penalty_tuning.pkl',
    retrieve_only=True,
)
best_lams, cv_results = penalty_tuning.tune_penalties(**tuning_kwargs)

print('Best lambdas:', best_lams)



## Refit final model with best penalties

In [None]:
# Map every group name to the best lambda found for its lambda key.
# Keys are visited in the order lam_f, lam_g, lam_h, so a name listed under
# several keys ends up with the value of the last one.
lam_by_group = {
    gname: best_lams[lam_key]
    for lam_key in ('lam_f', 'lam_g', 'lam_h')
    for gname in group_name_map[lam_key]
}
final_groups = penalty_tuning.clone_groups_with_lams(groups, lam_by_group)

# File name is derived from the lambda configuration of the tuned groups.
lam_suffix = one_ff_gam_fit.generate_lambda_suffix(final_groups)
save_path = outdir / 'fit_results' / f'{lam_suffix}.pkl'


In [None]:
# Refit with the tuned penalties (`final_groups`).
refit_res = one_ff_gam_fit.fit_poisson_gam_map(
    design_df=design_df,
    y=y,
    groups=final_groups,
    l1_groups=l1_groups,
    max_iter=200,
    tol=1e-6,
    verbose=True,
    save_path=save_path,
)

# Report the REFIT result. The cell previously printed `fit_res`, i.e. the
# summary of the earlier, untuned fit.
print('success:', refit_res.success)
print('message:', refit_res.message)
print('n_iter:', refit_res.n_iter)
print('final objective:', refit_res.fun)
print('grad_norm:', refit_res.grad_norm)


In [None]:
plot_gam_fit.plot_variables(refit_res.coef, all_meta, plot_gam_fit.plot_var_order)    

# backward elimination

In [None]:
# import pickle
# save_path = 'all_monkey_data/one_ff_data/my_gam_results/neuron_2/kept_groups.pkl'
# with open(save_path, 'rb') as f:
#     saved_data = pickle.load(f)


In [None]:
# history_path = 'all_monkey_data/one_ff_data/my_gam_results/neuron_2/history.csv'
# history = pd.read_csv(history_path)

In [None]:
# Setup output directory and paths
outdir = Path(f'all_monkey_data/one_ff_data/my_gam_results/neuron_{unit_idx}')
outdir.mkdir(parents=True, exist_ok=True)

# Generate descriptive filename with lambda configuration
lam_suffix = one_ff_gam_fit.generate_lambda_suffix(groups)
save_path = outdir / 'backward_elimination' / f'{lam_suffix}.pkl'
#save_path = outdir / 'kept_groups.pkl'

# Make sure the 'backward_elimination' sub-folder exists before anything is
# written to save_path. Only `outdir` was created before; this mirrors how
# the model-fitting cell creates 'fit_results'.
save_path.parent.mkdir(parents=True, exist_ok=True)

kept, history = backward_elimination.backward_elimination_gam(
    design_df=design_df,
    y=y,
    groups=groups,
    alpha=0.05,
    n_folds=10,
    verbose=True,
    save_path=str(save_path),
)

print('\nFinal retained variables:')
for g in kept:
    print(' ', g.name)

# see one trial

In [None]:
# Create a fresh session object and load only the raw data.
# NOTE(review): this rebinds `data_obj` and `prs`. Unlike the build-design
# cell, no compute_covariates / compute_events call follows here, while later
# comparison cells read data_obj.events and data_obj.covariates. Presumably
# the build-design cell must be re-run before those cells; confirm.
prs = parameters.default_prs()
data_obj = one_ff_pipeline.OneFFSessionData(
    mat_path='all_monkey_data/one_ff_data/sessions_python.mat',
    prs=prs, 
    session_num=0,
)

data_obj._load_data()

In [None]:
# indices
session_num = 0
trial_num = 0

# params
sessions = data_obj.sessions
prs = parameters.default_prs()

# session / behaviour
session = sessions[session_num]
behaviour = session.behaviour

# trials / stats
all_trials = behaviour.trials
all_stats = behaviour.stats
trial_ids = np.arange(len(all_trials))

trial = all_trials[trial_num]
stats = all_stats[trial_num]
pos_rel = stats.pos_rel

# continuous data
continuous = trial.continuous
print(continuous._fieldnames)

# The per-trial arrays carry a `trial_` prefix on purpose: this cell used to
# bind them to x / y / v / w / t, and `y` overwrote the GAM response vector
# that the fitting and diagnostics cells read under that same name.
trial_x = continuous.xmp
trial_y = continuous.ymp
trial_v = continuous.v
trial_w = continuous.w
trial_t = continuous.ts

# time step: mean spacing of the trial timestamps, rounded to 5 decimals
prs.dt = round(np.mean(np.diff(trial_t)), 5)

## verify compute_all_covariates

In [None]:
covariates = one_ff_data_processing.compute_all_covariates(trial, prs.dt)


In [None]:
import numpy as np

# Check, one covariate at a time, that the recomputed values agree with the
# ones stored on `pos_rel` (NaNs in matching positions count as equal).
for cov_name in ('r_targ', 'theta_targ'):
    are_close = np.allclose(
        getattr(pos_rel, cov_name),
        covariates[cov_name],
        rtol=1e-5,
        atol=1e-8,
        equal_nan=True,
    )
    print(are_close)

In [None]:
covariates.keys()

In [None]:
# One line per covariate: name, array shape, and NaN-ignoring min / max.
for cov_name, cov_values in covariates.items():
    cov_min = np.nanmin(cov_values)
    cov_max = np.nanmax(cov_values)
    print(cov_name, cov_values.shape, cov_min, cov_max)


# try script

In [None]:
# !python multiff_analysis/jobs/one_ff/scripts/one_ff_back_elim_script.py --unit_idx 1

In [None]:
# !python multiff_analysis/jobs/one_ff/scripts/one_ff_pen_tune_script.py --unit_idx 1

# compare with matlab design_df

## load matlab design

In [None]:
import scipy.io as sio
# Load the MATLAB design matrix. struct_as_record=False makes MATLAB structs
# load as objects with attribute access (used below); squeeze_me=True drops
# unit dimensions.
mat = sio.loadmat('all_monkey_data/one_ff_data/design_matrix.mat', struct_as_record=False, squeeze_me=True)
mat_design = mat['design_matrix_data']
# Convenience aliases for the fields of the design struct.
mat_x = mat_design.x
mat_xname = mat_design.xname
mat_yt = mat_design.yt
mat_basis = mat_design.basis
mat_xtype = mat_design.xtype
mat_nprs = mat_design.nprs

In [None]:
import numpy as np
np.set_printoptions(
    threshold=200,   # arrays with more than 200 elements are printed summarized
    linewidth=200,      # allow up to 200 characters per line before wrapping
)

## decide a variable

In [None]:
python_var_name = 't_targ'
matlab_var_name = 'target_OFF'

In [None]:
python_var_name = 't_rew'
matlab_var_name = 'reward'

In [None]:
python_var_name = 't_stop'
matlab_var_name = 'stop'

In [None]:
python_var_name = 'spike_hist'
matlab_var_name = 'spikehist'

In [None]:
python_var_name = 'v'
matlab_var_name = 'v'

In [None]:
python_var_name = 'theta_targ'
matlab_var_name = 'theta_targ'

In [None]:
python_var_name = 'phi'
matlab_var_name = 'phi'

In [None]:
mat_design.xname # MATLAB variable names

In [None]:
tuning_meta['linear_vars']

In [None]:
# Position of the first MATLAB variable whose name equals matlab_var_name
# (raises IndexError if there is no such name).
matlab_var_idx = int(np.where(mat_design.xname == matlab_var_name)[0][0])
# Choose the all_meta section from the Python variable name. These are
# substring tests, so any name that merely contains 't_' or 'hist' is
# routed to that section as well.
if 't_' in python_var_name:
    group_name = 'temporal'
elif 'hist' in python_var_name:
    group_name = 'hist'
else:
    group_name = 'tuning'
# Design columns of the selected variable on the Python side, and the
# matching MATLAB block wrapped in a DataFrame for side-by-side viewing.
python_col_names = all_meta[group_name]['groups'][python_var_name]
python_design_df = design_df[python_col_names]
matlab_design_df = pd.DataFrame(mat_design.x[matlab_var_idx])

## raw data (events only)

In [None]:
# Raw (pre-basis) event comparison; only defined for temporal variables.
if group_name == 'temporal':
    raw_df = {'python': data_obj.events[python_var_name], 'matlab': mat_design.xt[matlab_var_idx]}
    raw_df = pd.DataFrame(raw_df).astype(int)
    raw_df['diff'] = raw_df['python'] - raw_df['matlab']

    # Positions where the MATLAB and Python event vectors disagree.
    diff_idx = np.where(
        mat_design.xt[matlab_var_idx].astype(np.float32)
        != data_obj.events[python_var_name]
    )[0]

    print('len(diff_idx)', len(diff_idx))
    # A bare expression inside an `if` block is not echoed by the notebook,
    # so the first mismatching positions are printed explicitly.
    print(diff_idx[:10])

## compare bases

In [None]:
start_index = 0
num_rows = 20

In [None]:
# Print the MATLAB and Python bases of the selected variable.
if group_name != 'tuning':
    # Temporal / history variables: show a window of rows of each basis.
    matlab_basis_df = pd.DataFrame(mat_design.basis[matlab_var_idx].y)
    # NOTE(review): all_meta[group_name]['groups'][python_var_name] is used as
    # the list of design_df column labels when python_design_df is built, so
    # `[0]` here looks like a single column label rather than a basis matrix.
    # Confirm where the Python-side basis is actually stored.
    python_basis_df = pd.DataFrame(all_meta[group_name]['groups'][python_var_name][0])
    print(matlab_basis_df.iloc[start_index:start_index+num_rows])
    print(python_basis_df.iloc[start_index:start_index+num_rows])
else:
    # Tuning variables: compare MATLAB's basis.x with the Python bin edges.
    matlab_basis = np.array(mat_design.basis[matlab_var_idx].x)
    python_basis = np.array(all_meta[group_name]['bin_edges'][python_var_name])
    print('matlab_basis', matlab_basis)
    print('python_basis', python_basis)

## compare values

In [None]:
# start_index = 3080
start_index = 0
num_rows = 90
python_design_df.iloc[start_index:start_index+num_rows]

In [None]:
matlab_design_df.iloc[start_index:start_index+num_rows]

In [None]:
# Compare both design blocks element-wise at float32 precision.
python_arr = np.array(python_design_df).astype(np.float32)
matlab_arr = np.array(matlab_design_df).astype(np.float32)

# (A discarded `np.equal(...).all()` statement was removed here: its result
# was never shown or stored, and n_diff == 0 carries the same information.)
diff_mask = python_arr != matlab_arr
n_diff = np.sum(diff_mask)

print('Number of mismatched entries:', n_diff)

diff_indices = np.argwhere(diff_mask)
print(diff_indices[:30])  # show the first 30 differences as (row, col) pairs


## debug

In [None]:
start_index = 0
num_rows = 5

In [None]:
data_obj.covariates[python_var_name][start_index:start_index+num_rows]

In [None]:
python_design_df.loc[start_index:start_index+num_rows]

In [None]:
matlab_design_df.loc[start_index:start_index+num_rows]

## checking all tuning vars

In [None]:
stop!

In [None]:
def compare_python_matlab_design(
    design_df,
    mat_design,
    all_meta,
    var_list=None,
    start_index=0,
    num_rows=90,
    verbose=True
):
    """
    Compare Python and MATLAB design matrices variable-by-variable.

    For every variable the Python columns (looked up through ``all_meta``)
    and the MATLAB block (``mat_design.x[i]``) are cast to float32 and
    compared element-wise with ``!=``. NaN entries therefore always count
    as mismatches.

    Parameters
    ----------
    design_df : pandas.DataFrame
        Python design dataframe.
    mat_design : object
        MATLAB design structure with attributes:
            - xname (array of variable names)
            - x (array of design matrices)
    all_meta : dict
        Metadata dictionary; ``all_meta[group]['groups'][var]`` must give the
        ``design_df`` column labels of ``var``.
    var_list : list, optional
        Variables to compare. Defaults to the eight standard covariates.
        The same name is used on the Python and the MATLAB side.
    start_index : int
        Unused; kept for backward compatibility with existing callers.
    num_rows : int
        Unused; kept for backward compatibility with existing callers.
    verbose : bool
        If True, print detailed comparison results.

    Returns
    -------
    results_dict : dict
        One entry per variable with the keys

        - ``n_mismatched_entries``: number of differing entries, or ``None``
          when the two blocks have different shapes,
        - ``diff_indices_first10``: up to ten (row, col) mismatch positions
          (empty when the shapes differ),
        - ``arrays_equal``: True only if shapes match and no entry differs,
        - ``shape_mismatch``, ``python_shape``, ``matlab_shape``.

    Raises
    ------
    IndexError
        If a variable name is not present in ``mat_design.xname``.
    KeyError
        If a variable is missing from the selected ``all_meta`` section.
    """

    if var_list is None:
        var_list = [
            'v', 'w', 'd', 'phi',
            'r_targ', 'theta_targ',
            'eye_ver', 'eye_hor'
        ]

    results_dict = {}

    for var in var_list:

        if verbose:
            print('=' * 100)
            print(var)

        # Find MATLAB variable index (first entry whose name equals var)
        matlab_var_idx = int(
            np.where(mat_design.xname == var)[0][0]
        )

        # Determine group name (substring tests, as in the notebook cells)
        if 't_' in var:
            group_name = 'temporal'
        elif 'hist' in var:
            group_name = 'hist'
        else:
            group_name = 'tuning'

        # Extract both blocks and bring them to a common dtype
        python_col_names = all_meta[group_name]['groups'][var]
        python_arr = np.array(design_df[python_col_names]).astype(np.float32)
        matlab_arr = np.array(
            pd.DataFrame(mat_design.x[matlab_var_idx])
        ).astype(np.float32)

        # An element-wise comparison is only meaningful for equal shapes.
        # Otherwise `!=` would either broadcast silently (bogus counts) or
        # fail inside numpy, so the mismatch is reported explicitly.
        shape_mismatch = python_arr.shape != matlab_arr.shape

        if shape_mismatch:
            n_diff = None
            diff_indices = np.array([], dtype=int).reshape(0, python_arr.ndim)
        else:
            diff_mask = python_arr != matlab_arr
            n_diff = int(np.sum(diff_mask))
            diff_indices = np.argwhere(diff_mask)

        if verbose:
            if shape_mismatch:
                print(
                    'Shape mismatch: python', python_arr.shape,
                    'vs matlab', matlab_arr.shape
                )
            else:
                print('Number of mismatched entries:', n_diff)
                if n_diff > 0:
                    print('First 10 mismatches (row, col):')
                    print(diff_indices[:10])

        results_dict[var] = {
            'n_mismatched_entries': n_diff,
            'diff_indices_first10': diff_indices[:10],
            'arrays_equal': (not shape_mismatch) and n_diff == 0,
            'shape_mismatch': shape_mismatch,
            'python_shape': python_arr.shape,
            'matlab_shape': matlab_arr.shape,
        }

    return results_dict


results = compare_python_matlab_design(
    design_df=design_df,
    mat_design=mat_design,
    all_meta=all_meta
)

# Appendix

## debug non-convergence

In [None]:
# Convert to numpy
X = design_df.to_numpy(dtype=float)
y_array = np.asarray(y, dtype=float).ravel()

# Initialize beta0 the same way the function does
rng = np.random.default_rng(0)
beta0 = 1e-3 * rng.standard_normal(X.shape[1])
if 'const' in design_df.columns:
    # Start the intercept at the log of the mean count (floored at 1e-8 so
    # the log stays finite).
    beta0[design_df.columns.get_loc('const')] = np.log(max(y_array.mean(), 1e-8))

# Run diagnostics
u0 = X @ beta0
print("=" * 80)
print("PRE-FIT DIAGNOSTICS")
print("=" * 80)
print(f"Design matrix shape: {X.shape}")
print(f"X range: [{X.min():.2e}, {X.max():.2e}]")
print(f"y range: [{y_array.min():.2e}, {y_array.max():.2e}]")
print(f"y mean: {y_array.mean():.2e}, y sum: {y_array.sum():.2e}")
# The label used to read "Initial fit_res.coef range", but the values shown
# are those of the initial guess beta0, not of a fitted result.
print(f"Initial beta0 range: [{beta0.min():.2e}, {beta0.max():.2e}]")
print(f"Initial u = X @ beta0 range: [{u0.min():.2e}, {u0.max():.2e}]")
print(f"Initial rate = exp(u) range: [{np.exp(u0.min()):.2e}, {np.exp(u0.max()):.2e}]")
print("=" * 80)


## get weights of a group

In [None]:
def get_group_beta(beta: pd.Series, cols: list) -> np.ndarray:
    """Return the coefficients of one design group as a NumPy array.

    Parameters
    ----------
    beta : pandas.Series
        Coefficients indexed by design-column name (e.g. ``fit_res.coef``).
    cols : list
        Column names belonging to the group of interest.

    Returns
    -------
    numpy.ndarray
        Values of ``beta`` for the entries of ``cols`` that are present in
        ``beta.index``, in the order given by ``cols``. Names missing from
        the index are skipped rather than raising.
    """
    # The parameter was previously spelled `fit_res.coef`, which is not a
    # valid identifier and made the whole cell a SyntaxError.
    cols_present = [c for c in cols if c in beta.index]
    return beta.loc[cols_present].to_numpy()

temporal_groups = temporal_meta['groups']  # {'t_targ': [...], 't_move': [...], 't_rew': [...], 'spike_hist': [...]}
# Example: event kernels (weights live in basis space)
beta_t_targ = get_group_beta(fit_res.coef, temporal_groups['t_targ'])
beta_t_move = get_group_beta(fit_res.coef, temporal_groups['t_move'])
beta_t_rew = get_group_beta(fit_res.coef, temporal_groups['t_rew'])
beta_hist = get_group_beta(fit_res.coef, all_meta['hist']['groups']['spike_hist'])

# Example: tuning weights (boxcar/Fourier weights)
beta_v = get_group_beta(fit_res.coef, tuning_meta['groups']['v'])
beta_phi = get_group_beta(fit_res.coef, tuning_meta['groups']['phi'])