# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
# Walk up from the current directory to the project root, make it the working
# directory (the data paths in later cells are relative to it) and put the
# analysis package first on sys.path so the project imports below resolve.
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        methods_dir = str(p / 'multiff_analysis/multiff_code/methods')
        # Re-running this cell used to stack duplicate entries; drop any old
        # copies and re-insert at the front so it still takes precedence.
        while methods_dir in sys.path:
            sys.path.remove(methods_dir)
        sys.path.insert(0, methods_dir)
        break
else:
    # Previously this failed silently and surfaced later as a confusing
    # ModuleNotFoundError on the project imports.
    import warnings
    warnings.warn(
        "Could not find a 'Multifirefly-Project' directory at or above "
        f"{Path.cwd()}; project imports and relative data paths may fail."
    )

from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff import population_analysis_utils, one_ff_data_processing, parameters, one_ff_pipeline, one_ff_glm_design
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff.one_ff_gam import plot_gam_fit
from neural_data_analysis.neural_analysis_tools.glm_tools.tpg import glm_bases
from neural_data_analysis.design_kits.design_by_segment import temporal_feats, spatial_feats
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff.one_ff_gam import one_ff_gam_fit, assemble_one_ff_gam_design, penalty_tuning, backward_elimination
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff.one_ff_gam.one_ff_gam_fit import GroupSpec, fit_poisson_gam_map


import sys
import math
import gc
import subprocess
from pathlib import Path

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats
import json

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr
import statsmodels.api as sm

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
import numpy as np
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util

# Display / plotting configuration.
# Reset matplotlib to its defaults FIRST. The reset used to run after the
# animation settings below, which silently discarded them.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Render animations as embedded JavaScript. An earlier 'html5' setting was
# always overridden by this 'jshtml' one, so only 'jshtml' is kept. The embed
# limit is given in MB; 2**128 effectively removes it.
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128

# NOTE(review): these environment variables are set after torch/numpy and the
# other libraries are already imported; they may need to move above the imports
# to take effect — confirm.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

# pandas / numpy printing: fixed 5-decimal floats, no scientific notation.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)
np.set_printoptions(suppress=True)

print("done")


%load_ext autoreload
%autoreload 2
%matplotlib inline

pd.set_option('display.max_colwidth', 200)




# Build design

In [None]:
# Index of the unit (neuron) modelled in the cells below.
unit_idx: int = 2

In [None]:
# Force a fresh import of the design-assembly module.
# NOTE(review): `%autoreload 2` is already enabled in the import cell, so this
# explicit reload is probably redundant.
reload(assemble_one_ff_gam_design)

In [None]:
# Covariates to compute for the session. (This list used to be defined twice,
# identically; a single definition avoids the two copies drifting apart.)
covariate_names = [
    'v', 'w', 'd', 'phi',
    'r_targ', 'theta_targ',
    'eye_ver', 'eye_hor',
]

# Each tuned covariate gets either a linear or an angular tuning basis.
linear_vars = [
    'v', 'w', 'd', 'r_targ',
    'eye_ver', 'eye_hor',
]
angular_vars = [
    'phi', 'theta_targ',
]

# Fail fast, before the expensive preprocessing, if a tuning variable would
# not have been computed as a covariate.
_missing_tuning_vars = sorted(set(linear_vars + angular_vars) - set(covariate_names))
if _missing_tuning_vars:
    raise ValueError(f'Tuning variables not in covariate_names: {_missing_tuning_vars}')

prs = parameters.default_prs()
data_obj = one_ff_pipeline.OneFFSessionData(
    mat_path='all_monkey_data/one_ff_data/sessions_python.mat',
    prs=prs,
    session_num=0,
)

# preprocessing
data_obj.compute_covariates(covariate_names)
data_obj.compute_spike_counts()
data_obj.smooth_spikes()
data_obj.compute_events()

# Build the unit-independent parts of the design once ...
temporal_df, temporal_meta, specs_meta = assemble_one_ff_gam_design.build_temporal_design_base(data_obj)
X_tuning, tuning_meta = assemble_one_ff_gam_design.build_tuning_design(
    data_obj.data_df, linear_vars, angular_vars,
    binrange_dict=data_obj.prs.binrange,
)

# ... then assemble the full design, response and penalty groups for `unit_idx`.
design_df, y, groups, all_meta = assemble_one_ff_gam_design.process_unit_design_and_groups(
    unit_idx=unit_idx,
    data_obj=data_obj,
    temporal_df=temporal_df,
    temporal_meta=temporal_meta,
    X_tuning=X_tuning,
    tuning_meta=tuning_meta,
    specs_meta=specs_meta,
    # coupling_units=[1, 3, 7],  # optional
)


# Run model

In [None]:
# Results for this unit: <results root>/neuron_<unit_idx>/fit_results/<lambda suffix>.pkl
outdir = Path(f'all_monkey_data/one_ff_data/my_gam_results/neuron_{unit_idx}')
fit_dir = outdir / 'fit_results'
outdir.mkdir(parents=True, exist_ok=True)
fit_dir.mkdir(parents=True, exist_ok=True)

# The file name encodes the penalty (lambda) configuration of `groups`.
lam_suffix = one_ff_gam_fit.generate_lambda_suffix(groups)
save_path = fit_dir / f'{lam_suffix}.pkl'

# ----------------------------
# 4) Fit the MAP Poisson GAM (no L1-penalised groups)
# ----------------------------
fit_res = fit_poisson_gam_map(
    design_df=design_df, y=y, groups=groups, l1_groups=[],
    max_iter=200, tol=1e-6, verbose=True, save_path=save_path,
)

# Optimizer summary as (printed label, attribute of fit_res).
for label, attr in (
    ('success', 'success'),
    ('message', 'message'),
    ('n_iter', 'n_iter'),
    ('final objective', 'fun'),
    ('grad_norm', 'grad_norm'),
):
    print(f'{label}:', getattr(fit_res, attr))

# Fitted coefficients; later cells look them up by design-column name.
beta = fit_res.coef

# ----------------------------
# 5) (Optional) Recover tuning curves / kernels in the same spirit as the original MATLAB code
# ----------------------------
# Helper to get weights for a group
def get_group_beta(
    beta: pd.Series,
    cols: list,
    fill_value: "float | None" = None,
) -> np.ndarray:
    """Return the fitted weights of one design group, in the order of ``cols``.

    Parameters
    ----------
    beta : pd.Series
        Fitted coefficients indexed by design-column name.
    cols : list
        Design-column names belonging to the group.
    fill_value : float or None, optional
        If None (default), columns absent from ``beta`` are skipped, so the
        result can be shorter than ``cols`` and no longer line up with the
        group's basis functions. Otherwise absent columns get this value, which
        keeps the result aligned one-to-one with ``cols`` (e.g. ``0.0`` for
        columns that were dropped from the design).

    Returns
    -------
    np.ndarray
        The selected weights.
    """
    if fill_value is None:
        # Original behaviour: keep only the columns that were actually fitted.
        cols_present = [c for c in cols if c in beta.index]
        return beta.loc[cols_present].to_numpy()
    return beta.reindex(list(cols), fill_value=fill_value).to_numpy()

# Basis-space weights of selected groups: event kernels, spike history, tuning.
temporal_groups = temporal_meta['groups']  # event name -> its design columns

beta_t_targ, beta_t_move, beta_t_rew = (
    get_group_beta(beta, temporal_groups[event])
    for event in ('t_targ', 't_move', 't_rew')
)
# The spike-history columns are listed in the combined per-unit metadata.
beta_hist = get_group_beta(beta, all_meta['hist']['groups']['spike_hist'])

# Tuning weights (boxcar / Fourier) for one linear and one angular covariate.
beta_v, beta_phi = (
    get_group_beta(beta, tuning_meta['groups'][name]) for name in ('v', 'phi')
)

In [None]:
# Optimizer termination message (also printed right after the fit).
fit_res.message

In [None]:
# Number of optimizer iterations used.
fit_res.n_iter 

In [None]:
# Gradient norm reported by the fit (convergence check).
fit_res.grad_norm  

## Debug: pre-fit diagnostics

In [None]:
# Convert to numpy
X = design_df.to_numpy(dtype=float)
y_array = np.asarray(y, dtype=float).ravel()

# Initialize beta0 the same way the function does
rng = np.random.default_rng(0)
beta0 = 1e-3 * rng.standard_normal(X.shape[1])
if 'const' in design_df.columns:
    beta0[design_df.columns.get_loc('const')] = np.log(max(y_array.mean(), 1e-8))

# Run diagnostics
u0 = X @ beta0
print("=" * 80)
print("PRE-FIT DIAGNOSTICS")
print("=" * 80)
print(f"Design matrix shape: {X.shape}")
print(f"X range: [{X.min():.2e}, {X.max():.2e}]")
print(f"y range: [{y_array.min():.2e}, {y_array.max():.2e}]")
print(f"y mean: {y_array.mean():.2e}, y sum: {y_array.sum():.2e}")
print(f"Initial beta range: [{beta0.min():.2e}, {beta0.max():.2e}]")
print(f"Initial u = X @ beta0 range: [{u0.min():.2e}, {u0.max():.2e}]")
print(f"Initial rate = exp(u) range: [{np.exp(u0.min()):.2e}, {np.exp(u0.max()):.2e}]")
print("=" * 80)


## Plot the fitted model

In [None]:
# Bin width from the parameter set; used below to convert the intercept to a rate.
dt = prs.dt

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Intercept weight. Kept under its own name; it previously reused `beta0`,
# overwriting the debug cell's initial coefficient vector with a scalar.
beta_const = beta['const']
# exp(intercept) is the predicted count per bin when all other design columns
# are zero; dividing by the bin width gives a rate (Hz if dt is in seconds).
baseline_rate = np.exp(beta_const) / dt

print(f'Baseline firing rate: {baseline_rate:.2f} Hz')

In [None]:
# One tuning-curve plot per linear covariate.
for linear_var in tuning_meta['linear_vars']:
    plot_gam_fit.plot_linear_tuning(linear_var, beta, tuning_meta)

In [None]:
# One tuning-curve plot per angular covariate.
for angular_var in tuning_meta['angular_vars']:
    plot_gam_fit.plot_angular_tuning(angular_var, beta, tuning_meta)

In [None]:

# Single components: a tuning curve ('v'), an event kernel ('t_move') and the
# spike-history filter ('spike_hist').
for component in ('v', 't_move', 'spike_hist'):
    plot_gam_fit.plot_variable(component, beta, all_meta)

# Everything: all tuning curves, then all temporal filters.
plot_gam_fit.plot_all_tuning_curves(beta, all_meta)
plot_gam_fit.plot_all_temporal_filters(beta, all_meta)

# Backward elimination

In [None]:
# import pickle
# save_path = 'all_monkey_data/one_ff_data/my_gam_results/neuron_2/kept_groups.pkl'
# with open(save_path, 'rb') as f:
#     saved_data = pickle.load(f)


In [None]:
# history_path = 'all_monkey_data/one_ff_data/my_gam_results/neuron_2/history.csv'
# history = pd.read_csv(history_path)

In [None]:
# Results for this unit live under neuron_<unit_idx>/; the elimination run is
# saved in a backward_elimination/ subfolder, named after the lambda configuration.
outdir = Path(f'all_monkey_data/one_ff_data/my_gam_results/neuron_{unit_idx}')
lam_suffix = one_ff_gam_fit.generate_lambda_suffix(groups)
save_path = outdir / 'backward_elimination' / f'{lam_suffix}.pkl'
# Create the subfolder as well (parents=True also creates `outdir`). Previously
# only `outdir` was created, unlike the fit cell, which creates its own subfolder.
save_path.parent.mkdir(parents=True, exist_ok=True)

kept, history = backward_elimination.backward_elimination_gam(
    design_df=design_df,
    y=y,
    groups=groups,
    alpha=0.05,
    n_folds=10,
    verbose=True,
    save_path=str(save_path),
)

print('\nFinal retained variables:')
for g in kept:
    print(' ', g.name)

# Penalty tuning

In [None]:
outdir = Path(f'all_monkey_data/one_ff_data/my_gam_results/neuron_{unit_idx}')
outdir.mkdir(parents=True, exist_ok=True)

# Groups with an L1 (Laplace) prior: none yet; coupling filters could go here.
l1_groups = []

# Penalty family -> design groups it applies to: tuning curves (lam_f),
# event kernels (lam_g) and the spike-history filter (lam_h).
group_name_map = dict(
    lam_f=list(tuning_meta['groups'].keys()),
    lam_g=['t_targ', 't_move', 't_rew'],
    lam_h=['spike_hist'],
)

# Candidate values searched for each penalty family.
lam_grid = dict(
    lam_f=[10, 50, 100, 300],
    lam_g=[1, 5, 10, 30],
    lam_h=[1, 5, 10, 30],
)

# NOTE(review): retrieve_only=True presumably loads earlier results from
# save_path rather than re-running the search; confirm in tune_penalties.
best_lams, cv_results = penalty_tuning.tune_penalties(
    design_df=design_df,
    y=y,
    base_groups=groups,
    l1_groups=l1_groups,
    lam_grid=lam_grid,
    group_name_map=group_name_map,
    n_folds=5,
    save_path=outdir / 'penalty_tuning.pkl',
    retrieve_only=True,
)

print('Best lambdas:', best_lams)



In [None]:
# Inspect the design groups (presumably including their penalty settings) used by the fits.
groups

In [None]:
# Cross-validation results returned by penalty_tuning.tune_penalties.
cv_results

## Refit final model with best penalties

In [None]:


# One penalty per group: every group listed under a family in `group_name_map`
# gets that family's best lambda from the CV search.
best_lam_per_group = {
    gname: best_lams[family]
    for family in ('lam_f', 'lam_g', 'lam_h')
    for gname in group_name_map[family]
}
final_groups = penalty_tuning.clone_groups_with_lams(groups, best_lam_per_group)

# Refit with the selected penalties (not saved to disk).
fit_res = fit_poisson_gam_map(
    design_df=design_df,
    y=y,
    groups=final_groups,
    l1_groups=l1_groups,
    max_iter=200,
    tol=1e-6,
    verbose=True,
)

# Inspect a single trial

In [None]:
# Re-create the session object and load its data for the single-trial checks.
# NOTE(review): this rebinds the notebook-wide `prs` and `data_obj`, replacing
# the preprocessed object from the design cell; re-run that cell before refitting.
prs = parameters.default_prs()
data_obj = one_ff_pipeline.OneFFSessionData(
    session_num=0,
    prs=prs,
    mat_path='all_monkey_data/one_ff_data/sessions_python.mat',
)
data_obj._load_data()

In [None]:
# Pick one trial of one session.
session_num = 0
trial_num = 0

# params
sessions = data_obj.sessions
prs = parameters.default_prs()

# session / behaviour
session = sessions[session_num]
behaviour = session.behaviour

# trials / per-trial stats
all_trials = behaviour.trials
all_stats = behaviour.stats
trial_ids = np.arange(len(all_trials))

trial = all_trials[trial_num]
stats = all_stats[trial_num]
pos_rel = stats.pos_rel

# continuous data
continuous = trial.continuous
print(continuous._fieldnames)

# Prefixed names: the previous bare `x`, `y`, `v`, `w`, `t` overwrote notebook
# globals, notably `y` (the spike counts every model-fitting cell uses).
trial_x = continuous.xmp
trial_y = continuous.ymp
trial_v = continuous.v
trial_w = continuous.w
trial_t = continuous.ts

# Bin width = mean spacing of this trial's timestamps, rounded to 5 decimals.
prs.dt = round(np.mean(np.diff(trial_t)), 5)

## Verify `compute_all_covariates`

In [None]:
# Recompute all covariates for this single trial, using the trial's bin width.
covariates = one_ff_data_processing.compute_all_covariates(trial, prs.dt)


In [None]:
import numpy as np

# Compare the recomputed target covariates with the precomputed `pos_rel`
# values (NaNs in matching positions count as equal). Output is labelled per
# variable; it used to print two bare True/False lines.
for key in ('r_targ', 'theta_targ'):
    expected = np.squeeze(np.asarray(getattr(pos_rel, key)))
    computed = np.squeeze(np.asarray(covariates[key]))
    # allclose broadcasts, so mismatched shapes (e.g. (n,) vs (m, n)) could be
    # compared element-pairwise; require equal shapes after dropping singleton axes.
    if expected.shape != computed.shape:
        are_close = False
        print(f'{key}: shape mismatch {expected.shape} vs {computed.shape}')
        continue
    are_close = np.allclose(expected, computed, rtol=1e-5, atol=1e-8, equal_nan=True)
    print(f'{key}: {are_close}')

In [None]:
# Names of the covariates returned by compute_all_covariates.
covariates.keys()

In [None]:
# Shape and NaN-ignoring min/max of every recomputed covariate.
for name, values in covariates.items():
    print(name, values.shape, np.nanmin(values), np.nanmax(values))


# Try the batch scripts

In [None]:
!python multiff_analysis/jobs/one_ff/scripts/one_ff_back_elim_script.py --unit_idx 1

In [None]:
!python multiff_analysis/jobs/one_ff/scripts/one_ff_pen_tune_script.py --unit_idx 1