# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
# Find the nearest directory named 'Multifirefly-Project', starting at the
# current working directory and walking up through its parents. If one is
# found, make it the working directory and put the project's methods folder
# first on the import path; otherwise leave everything untouched.
_search_dirs = (Path.cwd(), *Path.cwd().parents)
_project_root = next(
    (d for d in _search_dirs if d.name == 'Multifirefly-Project'),
    None,
)
if _project_root is not None:
    os.chdir(_project_root)
    sys.path.insert(0, str(_project_root / 'multiff_analysis/multiff_code/methods'))
    

from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff.one_ff_gam import one_ff_pgam_design, compute_tuning
from neural_data_analysis.neural_analysis_tools.pgam_tools import pgam_class
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff import population_analysis_utils, one_ff_data_processing, one_ff_pipeline, one_ff_parameters
 

import sys
import math
import gc
import subprocess
from pathlib import Path

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats
import json

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr
import statsmodels.api as sm

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
import numpy as np
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util

# --- Notebook-wide environment and display configuration ---

# Environment switches.
# NOTE(review): KMP_DUPLICATE_LIB_OK is set after numpy/torch were imported
# above; confirm it still takes effect or move it ahead of those imports.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

# Reset matplotlib to its defaults FIRST. Previously this reset ran after the
# animation settings and silently discarded them, so animations never used the
# requested HTML representation.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Animation display: 'jshtml' was the last value requested (it overrode an
# earlier 'html5' assignment), so it is the one kept here.
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128

# pandas / numpy printing: fixed 5-decimal floats, no scientific notation,
# and at most 50 rows / 50 columns shown per DataFrame.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)

print("done")


%load_ext autoreload
%autoreload 2
%matplotlib inline

pd.set_option('display.max_colwidth', 200)



import sys

# Directory holding the external PGAM sources. It is appended to the import
# path only if it is not already there, so re-running the cell adds nothing.
# NOTE(review): absolute, machine-specific path — adjust on other machines.
pgam_path = '/Users/dusiyi/Documents/Multifirefly-Project/multiff_analysis/external/pgam/src/'
if pgam_path not in sys.path:
    sys.path.append(pgam_path)
    
import numpy as np
import sys
from PGAM.GAM_library import *
import PGAM.gam_data_handlers as gdh
import matplotlib.pylab as plt
import pandas as pd
from post_processing import postprocess_results
from scipy.io import savemat


# Use class

In [None]:
# Build the session data object for session 0 using the default parameters.
# The .mat path is relative, so it is resolved against the current working
# directory (set to the project root by the setup cell when the root is found).
prs = one_ff_parameters.default_prs()
data_obj = one_ff_pipeline.OneFFSessionData(
    mat_path='all_monkey_data/one_ff_data/sessions_python.mat',
    prs=prs, 
    session_num=0,
)

# Covariates to compute; the same list is reused below when building the PGAM
# design (both as model covariates and as tuning covariates).
covariate_names = [
    'v', 'w', 'd', 'phi',
    'r_targ', 'theta_targ',
    'eye_ver', 'eye_hor',
]

# Preprocessing: compute the covariates, the spike counts (Y), a smoothed
# version of the spikes (Y_smooth), and the events.
# NOTE(review): these calls look order-dependent (each populates state on
# data_obj) — keep the order unless the pipeline class says otherwise.
data_obj.compute_covariates(covariate_names)
Y = data_obj.compute_spike_counts()
Y_smooth = data_obj.smooth_spikes()
data_obj.compute_events()




In [None]:
# PGAM design for one unit: build the smooth handler for unit `unit_idx`,
# passing the full covariate list both as covariates and as tuning covariates.
unit_idx = 0
sm_handler = one_ff_pgam_design.build_smooth_handler(
    data_obj=data_obj,
    unit_idx=unit_idx,
    covariate_names=covariate_names,
    tuning_covariates=covariate_names,
    # empty set: no covariate is flagged as cyclic
    use_cyclic=set(),
    # NOTE(review): presumably the spline order — confirm in build_smooth_handler
    order=4,
)

# Show the names of the smooth terms registered on the handler.
sm_handler.smooths_var

In [None]:
import pandas as pd
import numpy as np

# Binned spike table handed to the runner as `x_var`; later cells index it by
# column (one column per unit) via `runner.x_var`.
binned_spikes_df = data_obj.get_binned_spikes_df()

# PGAM runner
# `save_dir` is a relative path, resolved against the current working directory.
runner = pgam_class.PGAMclass( 
    x_var=binned_spikes_df,
    bin_width=data_obj.prs.dt,
    save_dir='all_monkey_data/one_ff_data/pgam_results'
)

# attach design + trial structure
runner.sm_handler = sm_handler
runner.trial_ids = data_obj.covariate_trial_ids
# Boolean mask over samples: trials whose id satisfies id % 3 == 1 are held
# out; all other trials are used for training.
runner.train_trials = runner.trial_ids % 3 != 1
# NOTE(review): presumably the spike-history kernel length; units not visible
# here — confirm in PGAMclass.
runner.kernel_h_length = 100

In [None]:
runner.sm_handler.smooths_var

In [None]:
# run model
# runner.run_pgam(neural_cluster_number=unit_idx)
# runner.post_processing_results()
# runner.save_results()

# Only to load PGAM results

In [None]:
# import pandas as pd
# import numpy as np


# # PGAM runner
# runner = pgam_class.PGAMclass( 
#     bin_width=prs.dt,
#     save_dir='all_monkey_data/one_ff_data/pgam_results'
# )


# load results

In [None]:
neural_cluster_number = 1

In [None]:
# Load the previously saved PGAM results for the selected unit, then plot
# them using the variable order defined in one_ff_parameters.
runner.load_pgam_results(neural_cluster_number)
runner.plot_results(plot_var_order=one_ff_parameters.plot_var_order)

# Variance explained (new)

In [None]:
# Run a 5-split cross-validated PGAM fit for every column (unit) of
# runner.x_var and keep the 'mean_r2_eval' entry of each result.
# NOTE: `n` and `out` stay defined at notebook scope after the loop, so any
# later cell that reads a bare `n` silently picks up the last unit index.
all_mean_r2 = []
num_neurons = runner.x_var.shape[1]
for n in range(num_neurons):
    # NOTE(review): `filtwidth` presumably sets a smoothing width used when
    # scoring — confirm in PGAMclass.run_pgam_cv.
    out = runner.run_pgam_cv(n, n_splits=5, filtwidth=2)
    all_mean_r2.append(out['mean_r2_eval'])

In [None]:
all_mean_r2

In [None]:
def plot_variance_explained_cdf(all_mean_r2,
                                alpha=0.05,
                                label='Uncoupled model',
                                figsize=(6, 6),
                                xlabel='Variance explained (5-fold CV)'):
    """
    Plot the empirical CDF of variance explained with a DKW confidence band.

    Non-finite entries (NaN, +/-inf) and ``None`` entries are dropped before
    plotting. The figure is shown with ``plt.show()`` and the median of the
    remaining values is printed; nothing is returned.

    Parameters
    ----------
    all_mean_r2 : array-like
        One variance explained value per neuron (e.g., CV mean R2).
    alpha : float
        Significance level for the DKW band, strictly between 0 and 1
        (default=0.05 for a 95% band).
    label : str
        Legend label for the CDF curve.
    figsize : tuple
        Figure size passed to ``plt.figure``.
    xlabel : str
        Label of the x axis.

    Raises
    ------
    ValueError
        If ``alpha`` is not in (0, 1), or if no finite value remains after
        filtering.
    """

    import numpy as np

    # Outside (0, 1) the band is meaningless; alpha <= 0 or alpha > 2 would
    # even turn epsilon into inf/NaN without any error.
    if not 0 < alpha < 1:
        raise ValueError(f"alpha must be in (0, 1), got {alpha!r}.")

    # Coerce to a flat float array so that None entries become NaN instead of
    # producing an object array that np.isfinite cannot handle.
    r2 = np.asarray(all_mean_r2, dtype=float).ravel()
    r2 = r2[np.isfinite(r2)]

    if r2.size == 0:
        raise ValueError("No valid variance explained values to plot.")

    # Imported only after validation so that bad input is reported without
    # needing a plotting backend.
    import matplotlib.pyplot as plt

    # Sorted values and their empirical CDF (i/n at the i-th smallest value)
    r2_sorted = np.sort(r2)
    n = r2_sorted.size
    cdf = np.arange(1, n + 1) / n

    # DKW half-width: sqrt(log(2 / alpha) / (2 n)), clipped to [0, 1] below
    epsilon = np.sqrt(np.log(2 / alpha) / (2 * n))
    lower = np.maximum(cdf - epsilon, 0)
    upper = np.minimum(cdf + epsilon, 1)

    # Plot
    plt.figure(figsize=figsize)

    plt.plot(r2_sorted, cdf, linewidth=2, label=label)
    plt.fill_between(r2_sorted, lower, upper, alpha=0.25)

    # Horizontal guide at the median level of the CDF
    plt.axhline(0.5, linestyle='--', color='gray')

    plt.xlabel(xlabel)
    plt.ylabel('Cumulative fraction of neurons')

    plt.ylim([0, 1])
    # Always include 0 on the x axis. Skip the explicit limits when they would
    # coincide (all values equal and <= 0) and let matplotlib autoscale.
    x_low, x_high = min(0, r2_sorted[0]), r2_sorted[-1]
    if x_low < x_high:
        plt.xlim([x_low, x_high])

    plt.tight_layout()
    plt.legend()
    plt.show()

    print('Median variance explained:', np.median(r2))

In [None]:
plot_variance_explained_cdf(all_mean_r2)

# Variance explained (old)

In [None]:
# Run-all barrier: halt here with an explicit, self-explaining error instead
# of a deliberate syntax error. The cells below belong to the "old" variance
# explained section and are run by hand only.
raise RuntimeError('Intentional stop: the cells below are the old variance-explained code; run them manually.')

In [None]:
runner.res[neural_cluster_number]

In [None]:
# Inspect the loaded results on the runner. `self` does not exist at notebook
# scope (the line was written as if inside the class), so use `runner`.
# NOTE(review): assumes `reduced_vars` and `meta` are attributes populated on
# the runner alongside `res` — confirm in PGAMclass.
runner.res, runner.reduced_vars, runner.meta

In [None]:
# Point the runner at one unit: take that unit's column of runner.x_var as the
# spike counts and its column label as the cluster name.
runner.spk_counts = runner.x_var.iloc[:, neural_cluster_number].values
runner.cluster_name = runner.x_var.columns[neural_cluster_number]

In [None]:
# Variance explained for the currently selected unit, computed twice.
# NOTE(review): `use_train` presumably selects training vs. held-out samples
# (see runner.train_trials) — confirm in compute_variance_explained.
r2_train = runner.compute_variance_explained(use_train=True)
r2_eval  = runner.compute_variance_explained(use_train=False)

print('MATLAB-style R2 (train):', r2_train)
print('MATLAB-style R2 (eval):', r2_eval)

In [None]:
# For every unit: load its saved PGAM results and compute variance explained
# with use_train=False, then gather the values into one array.
r2_list = []
# NOTE(review): other cells read the spike table as `runner.x_var`; confirm
# that `runner.binned_spikes_df` actually exists on the runner.
num_neurons = runner.binned_spikes_df.shape[1]

for idx in range(num_neurons):
    runner.load_pgam_results(neural_cluster_number=idx)
    
    # NOTE(review): runner.spk_counts / runner.cluster_name are not updated in
    # this loop; check whether compute_variance_explained depends on them.
    r2 = runner.compute_variance_explained(use_train=False)  # or True
    r2_list.append(r2)

r2_array = np.array(r2_list)

In [None]:
# Empirical CDF of the per-unit variance explained in r2_array, plus a DKW
# confidence band. `r2_sorted`, `n` and `cdf` are computed here: previously
# `cdf` was never defined at notebook scope and `n` was whatever an earlier
# loop had left behind.
r2_finite = np.asarray(r2_array, dtype=float).ravel()
r2_finite = r2_finite[np.isfinite(r2_finite)]  # drop NaN / inf entries

r2_sorted = np.sort(r2_finite)
n = len(r2_sorted)
cdf = np.arange(1, n + 1) / n

# DKW half-width for a (1 - alpha) band, clipped to the valid CDF range [0, 1]
alpha = 0.05
epsilon = np.sqrt(np.log(2 / alpha) / (2 * n))

lower = np.maximum(cdf - epsilon, 0)
upper = np.minimum(cdf + epsilon, 1)

In [None]:
import matplotlib.pyplot as plt

# Plot the empirical CDF with its confidence band.
# `lower` / `upper` come from the preceding cell; `r2_sorted` and `cdf` must
# also be defined at notebook scope before this cell runs.
plt.figure(figsize=(6,6))

# CDF curve
plt.plot(r2_sorted, cdf, linewidth=2, label='Uncoupled model')

# Shaded confidence band around the curve
plt.fill_between(r2_sorted, lower, upper, alpha=0.25)

# Horizontal guide at a cumulative fraction of 0.5
plt.axhline(0.5, linestyle='--', color='gray')

plt.xlabel('Variance explained')
plt.ylabel('Cumulative fraction of neurons')

# Fixed axes: values outside [0, 1] (e.g. negative R2) are not shown.
plt.xlim([0, 1])
plt.ylim([0, 1])

plt.legend()
plt.tight_layout()
plt.show()

# Fraction of neurons tuned (across populations)

In [None]:
# Load saved PGAM results for units 0, 1, 2, ... until loading fails, and
# collect runner.res for each unit that loaded.
result_list = []
idx = 0

while True:
    try:
        runner.load_pgam_results(neural_cluster_number=idx)
    except Exception as e:
        # Any failure is treated as "no more saved units" and ends the loop.
        # NOTE(review): this also hides genuine errors (corrupt file, loader
        # bug); consider catching only the missing-file exception.
        print(f'Error occurred while processing neuron {idx}: {e}')
        print('Number of neurons processed:', idx)
        break

    # NOTE(review): assumes load_pgam_results rebinds runner.res to a new
    # object on each call; if it mutates in place, all entries would alias the
    # last result — confirm.
    result_list.append(runner.res)
    idx += 1

In [None]:
# Compute tuning statistics from the per-unit results collected above, then
# plot them.
tuning_stats = compute_tuning.calculate_tuning_fraction(result_list)
compute_tuning.plot_fraction_tuned(tuning_stats)

In [None]:
runner.res['variable']

In [None]:
runner.res['reduced_pval']

# Appendix

## default bin range (from Kaushik's code)

In [None]:

% Reference only: default bin ranges copied from the original MATLAB code.
% This cell is MATLAB syntax and is not meant to be executed in a Python kernel.
% Each entry is [min ; max]; the trailing comment is the unit given in the source.
% NOTE(review): several unit comments look copy-pasted (e.g. theta_targ is
% labelled cm, the eye-velocity entries deg) — verify against the original.
prs.binrange.v = [0 ; 200]; %cm/s
prs.binrange.w = [-90 ; 90]; %deg/s
prs.binrange.a = [-0.36 ; 0.36]; %cm/s
prs.binrange.alpha = [-0.36 ; 0.36]; %deg/s
prs.binrange.r_targ = [0 ; 400]; %cm
prs.binrange.theta_targ = [-60 ; 60]; %cm
prs.binrange.d = [0 ; 400]; %cm
prs.binrange.phi = [-90 ; 90]; %deg
prs.binrange.h1 = [-0.36 ; 0.36]; %s
prs.binrange.h2 = [-0.36 ; 0.36]; %s
prs.binrange.eye_ver = [-25 ; 0]; %deg
prs.binrange.eye_hor = [-40 ; 40]; %deg
prs.binrange.veye_vel = [-15 ; 5]; %deg
prs.binrange.heye_vel = [-30 ; 30]; %deg
prs.binrange.phase = [-pi ; pi]; %rad
prs.binrange.target_ON = [-0.24 ; 0.48]; %s
prs.binrange.target_OFF = [-0.36 ; 0.36]; %s
prs.binrange.move = [-0.36 ; 0.36]; %s
prs.binrange.stop = [-0.36 ; 0.36]; %s
prs.binrange.reward = [-0.36 ; 0.36]; %s

In [None]:
# Run-all barrier: halt here with an explicit, self-explaining error instead
# of a deliberate syntax error. The appendix cells below are run by hand only.
raise RuntimeError('Intentional stop: the appendix cells below are meant to be run manually.')

## check pgam var range

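Note: for event variables, the units of the range appear to depend on whether the event is treated as causal or not (milliseconds in one case, seconds in the other). This is surprising, so keep it in mind when reading the ranges printed below.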

In [None]:
# For every smooth term, print its name, the min/max of its raw input values
# (`_x`), and the xmin/xmax stored on the smooth, followed by a blank line.
for var in sm_handler.smooths_var:
    print(var)
    smooth = sm_handler.smooths_dict[var]
    raw_min = smooth._x.min()
    raw_max = smooth._x.max()
    print('Original range:', raw_min, raw_max)
    print('Binned range:', smooth.xmin, smooth.xmax)
    print('')

## see one trial

In [None]:
# Pull the raw data of a single trial out of the loaded session structure so
# it can be inspected and compared against the pipeline's own outputs.

# indices
session_num = 0
trial_num = 0

# params
sessions = data_obj.sessions
# Rebinds `prs` to a fresh default parameter object; `data_obj` was built with
# the earlier object, so changes made to `prs` below are made on this new one.
prs = one_ff_parameters.default_prs()

# session / behaviour
session = sessions[session_num]
behaviour = session.behaviour

# trials / stats
all_trials = behaviour.trials
all_stats = behaviour.stats
trial_ids = np.arange(len(all_trials))

trial = all_trials[trial_num]
stats = all_stats[trial_num]
# Precomputed target-relative position stored with the trial statistics; used
# further down as the reference when checking the recomputed covariates.
pos_rel = stats.pos_rel

# continuous data
continuous = trial.continuous
# List the fields available on the continuous-data struct
print(continuous._fieldnames)

# NOTE: these short names (x, y, v, w, t) live at notebook scope and are easy
# to overwrite from later cells.
x = continuous.xmp
y = continuous.ymp
v = continuous.v
w = continuous.w
t = continuous.ts

# time step
# Mean spacing between consecutive timestamps of this trial, rounded to 5
# decimals, stored as the bin width on the new `prs`.
prs.dt = round(np.mean(np.diff(t)), 5)

In [None]:
# Fetch trial 0 and its spike times through the data object's accessors.
trial_data = data_obj.get_trial(trial_num=0)
trial_spikes = data_obj.get_trial_spike_times(trial_num=0)



In [None]:
import matplotlib.pyplot as plt

# Trajectory of the trial: y against x as a solid black line.
plt.plot(trial_data['x'], trial_data['y'], 'k-')
plt.xlabel('x (forward)')
plt.ylabel('y (lateral)')
# Same scale on both axes so the path is not distorted
plt.axis('equal')


## verify compute_all_covariates

In [None]:
# Recompute the covariates for the single trial extracted above, using the
# time step estimated from that trial's timestamps.
covariates = one_ff_data_processing.compute_all_covariates(trial, prs.dt)


In [None]:
import numpy as np

# Compare the recomputed target covariates against the stored reference
# values in pos_rel; NaNs at matching positions count as equal. Prints one
# True/False per covariate, r_targ first, then theta_targ.
for _cov_name in ('r_targ', 'theta_targ'):
    are_close = np.allclose(
        getattr(pos_rel, _cov_name),
        covariates[_cov_name],
        rtol=1e-5,
        atol=1e-8,
        equal_nan=True,
    )
    print(are_close)

In [None]:
covariates.keys()

In [None]:
# Print name, shape and NaN-ignoring min/max for every covariate.
# The loop names are deliberately not `k, v`: `v` is already bound to
# `continuous.v` at notebook scope and the loop would overwrite it.
for cov_name, cov_values in covariates.items():
    print(cov_name, cov_values.shape, np.nanmin(cov_values), np.nanmax(cov_values))
