# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break
    
from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils, ml_methods_utils
from planning_analysis.show_planning import nxt_ff_utils, show_planning_utils
from neural_data_analysis.neural_analysis_tools.gpfa_methods import elephant_utils, fit_gpfa_utils, plot_gpfa_utils, gpfa_helper_class
from neural_data_analysis.neural_analysis_tools.align_trials import time_resolved_regression, time_resolved_gpfa_regression,plot_time_resolved_regression
from neural_data_analysis.neural_analysis_tools.align_trials import align_trial_utils
from decision_making_analysis.event_detection import detect_rsw_and_rcap

from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_psth import core_stops_psth, psth_postprocessing, psth_stats, compare_events, dpca_utils
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.get_stop_events import get_stops_utils, collect_stop_data

from neural_data_analysis.neural_analysis_tools.glm_tools.glm_fit import general_glm_fit, cv_stop_glm, glm_fit_utils, variance_explained, glm_runner
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_plotting import plot_spikes, plot_glm_fit, plot_tuning_func, compare_glm_fit
from neural_data_analysis.design_kits.design_around_event import event_binning, stop_design, cluster_design, design_checks
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_glm.glm_hyperparams import compare_glm_configs, glm_hyperparams_class
from neural_data_analysis.topic_based_neural_analysis.ff_visibility import ff_vis_epochs, vis_design

from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.get_stop_events import prepare_stop_design

# import decoding
from neural_data_analysis.neural_analysis_tools.decoding_tools.event_decoding import decoding_utils, decoding_analysis, plot_decoding, cmp_decode, load_results
from neural_data_analysis.design_kits.design_by_segment import spike_history
from neural_data_analysis.neural_analysis_tools.decoding_tools.general_decoding import cv_decoding
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural.pn_decoding import pn_decoding_model_specs
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.get_stop_events import assemble_stop_design
from neural_data_analysis.topic_based_neural_analysis.ff_visibility import decode_vis

import sys
import math
import gc
import subprocess
from pathlib import Path

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats
import json

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr
import statsmodels.api as sm

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
import numpy as np
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util

plt.rcParams["animation.html"] = "html5"
os.environ['KMP_DUPLICATE_LIB_OK']='True'
rc('animation', html='jshtml')
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
matplotlib.rcParams['animation.embed_limit'] = 2**128
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)

print("done")


%load_ext autoreload
%autoreload 2
%matplotlib inline

pd.set_option('display.max_colwidth', 200)

# Load data

In [None]:
from scipy.io import loadmat

# Load the exported MATLAB sessions. squeeze_me drops singleton dimensions and
# struct_as_record=False yields attribute-accessible struct objects, which is
# what the `sessions[...].behaviour_trials[...]` access below relies on.
data = loadmat(
    'all_monkey_data/one_ff_data/sessions_python.mat',
    struct_as_record=False,
    squeeze_me=True,
)
sessions = data['sessions_out']


## Behavioral data

In [None]:
trial_num = 30
trial_data = sessions[0].behaviour_trials[trial_num].continuous
print(trial_data._fieldnames)

# Continuous signals of this trial. Field meanings are inferred from their
# names (xmp/ymp presumably position, v/w linear/angular velocity, ts
# timestamps) — confirm against the export script.
x, y = trial_data.xmp, trial_data.ymp
v, w = trial_data.v, trial_data.w
t = trial_data.ts


In [None]:
import matplotlib.pyplot as plt

# Plot the trial's path with equal axis scaling so the geometry is not distorted.
fig, ax = plt.subplots()
ax.plot(x, y, 'k-')
ax.set_xlabel('x (forward)')
ax.set_ylabel('y (lateral)')
ax.axis('equal')


## Neural data

In [None]:
# Inspect unit 0's spike times on trial 0.
unit0 = sessions[0].units[0]
trial0 = unit0.trials[0]
# Only the last bare expression of a cell is displayed, so the field list has
# to be printed explicitly to be seen.
print(trial0._fieldnames)
trial0.tspk


# Replicate AnalysePopulation.m (population analysis)

In [None]:

# =========================
# Load data
# =========================

data = loadmat('all_monkey_data/one_ff_data/sessions_python.mat',
               struct_as_record=False, squeeze_me=True)
sessions = data['sessions_out']

# Analyse the first entry of sessions_out only.
session = sessions[0]
trials, units = session.behaviour_trials, session.units
n_trials, n_units = len(trials), len(units)


# =========================
# Build trial index (all trials)
# =========================

# Every trial is included; subset this array to restrict the analysis.
trial_ids = np.arange(n_trials)



In [None]:
"""
Population analysis pipeline
Python replication of AnalysePopulation.m (core functionality)

Author: you + ChatGPT
"""

import numpy as np
from scipy.io import loadmat
from scipy.signal import hilbert
from scipy.signal.windows import gaussian
from scipy.linalg import lstsq
from sklearn.cross_decomposition import CCA
from dataclasses import dataclass
from typing import List, Dict


# =========================
# Parameters (prs struct)
# =========================

from typing import Optional


@dataclass
class Params:
    """Analysis parameters (Python stand-in for the MATLAB ``prs`` struct).

    ``cca_vars`` and ``decode_vars`` default to ``None``, hence ``Optional``;
    they must be set to lists of variable names before the pipeline loops
    over them.
    """
    dt: float = 0.01               # sample period used to scale integrals/rates (units of ts — TODO confirm seconds)
    neural_filtwidth: int = 5      # Gaussian std passed to smooth_signal, in samples
    pretrial: float = 0.5          # padding before min(t_move, t_targ) in full_time_window
    posttrial: float = 0.5         # padding after t_end in full_time_window
    cca_vars: Optional[List[str]] = None     # names understood by get_var ('v', 'w', 'd', 'phi')
    decode_vars: Optional[List[str]] = None  # names understood by get_var ('v', 'w', 'd', 'phi')


prs = Params(
    dt=0.01,
    neural_filtwidth=5,
    cca_vars=['v', 'w'],
    decode_vars=['v', 'w']
)


# =========================
# Utilities
# =========================

def gaussian_kernel(width):
    """Return a unit-sum Gaussian kernel with standard deviation ``width``.

    The kernel is sampled at offsets ``-2*width .. 2*width`` (step 1) and
    normalised to sum to 1. ``width`` must be > 0 (0 divides by zero).
    """
    t = np.arange(-2 * width, 2 * width + 1)
    h = np.exp(-t ** 2 / (2 * width ** 2))
    return h / h.sum()


def smooth_signal(x, width):
    """Convolve ``x`` along axis 0 with ``gaussian_kernel(width)``.

    Each column (for 2-D input) is smoothed independently; edges are
    zero-padded. The output always has the same shape as ``x``, including
    signals shorter than the kernel.

    Parameters
    ----------
    x : array_like
        Signal(s) with time along axis 0.
    width : int or float
        Kernel standard deviation in samples. ``width <= 0`` disables
        smoothing and returns ``x`` itself (same object).

    Returns
    -------
    ndarray of float with the shape of ``x`` (or ``x`` unchanged, see above).
    """
    if width <= 0:
        return x
    x = np.asanyarray(x)
    if x.ndim and x.shape[0] == 0:
        # np.convolve rejects empty input; an empty signal smooths to itself.
        return x.astype(float)

    h = gaussian_kernel(width)
    # Start of the slice of the full convolution that is centred on the input;
    # this is exactly what mode='same' returns when len(signal) >= len(h).
    offset = (len(h) - 1) // 2

    def _smooth_1d(m):
        # mode='same' yields max(len(m), len(h)) samples, which would change the
        # output shape for short signals; slicing 'full' always yields len(m).
        return np.convolve(m, h, mode='full')[offset:offset + len(m)]

    return np.apply_along_axis(_smooth_1d, 0, x)


def bin_spikes(spike_times, ts):
    """Count spikes per sample of the time base ``ts``.

    A spike at time ``s`` goes to the first sample ``i`` with ``ts[i] >= s``,
    so sample ``i > 0`` collects spikes in ``(ts[i-1], ts[i]]`` and sample 0
    only spikes exactly at ``ts[0]``. Spikes outside ``[ts[0], ts[-1]]`` (and
    NaN spike times) are discarded.

    Parameters
    ----------
    spike_times : array_like or scalar
        Spike times in the units of ``ts``. A scalar is accepted because
        ``loadmat(squeeze_me=True)`` may collapse a single spike to a scalar.
    ts : array_like
        Sample timestamps, assumed sorted ascending (np.searchsorted needs it).

    Returns
    -------
    ndarray of float, shape (len(ts),)
    """
    ts = np.asarray(ts)
    counts = np.zeros(len(ts))
    if len(ts) == 0:
        return counts

    spikes = np.atleast_1d(np.asarray(spike_times, dtype=float)).ravel()
    # Filter on the spike times themselves: np.searchsorted never returns a
    # negative index, so an index-based lower-bound check cannot reject spikes
    # earlier than ts[0] — they would all be counted in sample 0.
    spikes = spikes[(spikes >= ts[0]) & (spikes <= ts[-1])]
    idx = np.searchsorted(ts, spikes)
    np.add.at(counts, idx, 1)
    return counts


def concatenate_trials(trials, trial_ids, signal_fn, time_window_fn):
    """Stack the in-window part of a per-trial signal across trials.

    For each id in ``trial_ids`` (in order), ``time_window_fn(trial)`` gives a
    boolean mask and ``signal_fn(trial, tid)`` the signal; the masked signals
    are concatenated.

    Returns
    -------
    (concatenated_signal, per_trial_lengths)
        ``per_trial_lengths`` lets ``deconcatenate`` split the result back.
    """
    pieces = []
    for tid in trial_ids:
        trial = trials[tid]
        keep = time_window_fn(trial)
        pieces.append(signal_fn(trial, tid)[keep])

    lengths = [len(piece) for piece in pieces]
    return np.concatenate(pieces), lengths


def deconcatenate(x, trial_lengths):
    """Split ``x`` into consecutive segments of the given lengths.

    Inverse of the concatenation in ``concatenate_trials``; any samples past
    ``sum(trial_lengths)`` are ignored.
    """
    offsets = [0]
    for n in trial_lengths:
        offsets.append(offsets[-1] + n)
    return [x[start:stop] for start, stop in zip(offsets, offsets[1:])]


def _integrate_from_onset(rate, ts, dt):
    """Running sum of ``rate * dt`` over samples with ``ts > 0``; 0 elsewhere.

    The output is float so integer-valued inputs are not truncated.
    """
    rate = np.asarray(rate)
    out = np.zeros_like(rate, dtype=float)
    valid = np.asarray(ts) > 0
    out[valid] = np.cumsum(rate[valid]) * dt
    return out


def compute_d(v, ts, dt):
    """Integrate ``v`` over samples with ``ts > 0`` (presumably distance travelled)."""
    return _integrate_from_onset(v, ts, dt)


def compute_phi(w, ts, dt):
    """Integrate ``w`` over samples with ``ts > 0`` (presumably heading angle)."""
    return _integrate_from_onset(w, ts, dt)


# =========================
# Time window helper
# =========================

def full_time_window(tr, pretrial=None, posttrial=None):
    """Boolean mask over ``tr.continuous.ts`` selecting the analysis window.

    The window runs from ``pretrial`` before the earlier of
    ``tr.events.t_move`` / ``tr.events.t_targ`` to ``posttrial`` after
    ``tr.events.t_end``, inclusive at both ends.

    Parameters
    ----------
    tr : trial struct with ``events`` and ``continuous.ts``.
    pretrial, posttrial : float, optional
        Padding in the units of ``ts``; default to ``prs.pretrial`` /
        ``prs.posttrial``.
    """
    if pretrial is None:
        pretrial = prs.pretrial
    if posttrial is None:
        posttrial = prs.posttrial

    ev = tr.events
    # np.fmin ignores a NaN operand (as MATLAB's min does). Built-in min() is
    # order-dependent with NaN — min(nan, 1.0) is nan but min(1.0, nan) is
    # 1.0 — so a NaN t_move would have emptied the whole trial's mask.
    t0 = np.fmin(ev.t_move, ev.t_targ) - pretrial
    t1 = ev.t_end + posttrial
    ts = tr.continuous.ts
    return (ts >= t0) & (ts <= t1)


# =========================
# Build stimulus matrix X
# =========================

def get_var(tr, name):
    """Return the per-sample signal ``name`` for trial ``tr``.

    'v' and 'w' are read directly from ``tr.continuous``; 'd' and 'phi' are
    their running integrals (``compute_d`` / ``compute_phi`` with ``prs.dt``).
    Any other name raises ``ValueError(name)``.
    """
    if name in ('v', 'w'):
        return getattr(tr.continuous, name)
    if name in ('d', 'phi'):
        cont = tr.continuous
        integrate, rate = (compute_d, cont.v) if name == 'd' else (compute_phi, cont.w)
        return integrate(rate, cont.ts, prs.dt)
    raise ValueError(name)


def gen_traj(w, v, ts, dt=None):
    """Reconstruct a 2-D trajectory from angular (``w``) and linear (``v``) velocity.

    ``w`` is an angular *velocity*, so the heading is its running integral,
    ``phi = cumsum(w) * dt`` (analogous to ``compute_phi``). Position is then
    advanced by dead reckoning from the origin:

        x[i] = x[i-1] + v[i] * cos(phi[i]) * dt
        y[i] = y[i-1] + v[i] * sin(phi[i]) * dt

    np.cos / np.sin take radians, so ``w`` must be in radians per unit of
    ``dt``. NOTE(review): convert first if ``w`` is in deg/s — confirm units.

    Parameters
    ----------
    w, v : array_like
        Angular and linear velocity, at least ``len(ts)`` samples each.
    ts : array_like
        Only its length is used: the number of output samples.
    dt : float, optional
        Sample period; defaults to ``prs.dt``.

    Returns
    -------
    x, y : ndarray, shape (len(ts),)
    """
    if dt is None:
        dt = prs.dt
    n = len(ts)
    x = np.zeros(n)
    y = np.zeros(n)
    if n < 2:
        return x, y

    w = np.asarray(w, dtype=float)[:n]
    v = np.asarray(v, dtype=float)[:n]
    phi = np.cumsum(w) * dt
    step = v[1:] * dt
    x[1:] = np.cumsum(step * np.cos(phi[1:]))
    y[1:] = np.cumsum(step * np.sin(phi[1:]))
    return x, y




# Run the pipeline

In [None]:
# Stimulus matrix X: one column per CCA variable; rows are the concatenated
# in-window samples of every trial.
X_list = []
trial_lengths = None
total_len_ref = None

for var in prs.cca_vars:
    column, lengths = concatenate_trials(
        trials,
        trial_ids,
        lambda tr, tid, v=var: get_var(tr, v),
        full_time_window,
    )

    if trial_lengths is None:
        # The first variable fixes the reference per-trial segmentation.
        trial_lengths, total_len_ref = lengths, column.shape[0]
    elif column.shape[0] != total_len_ref:
        raise RuntimeError(f'Concatenated length mismatch for var {var}: {column.shape[0]} vs {total_len_ref}')

    X_list.append(column)

X = np.column_stack(X_list)
# Zero-fill missing (NaN) samples before CCA.
X[np.isnan(X)] = 0


In [None]:


# =========================
# Build population activity Y
# =========================

# Population activity Y: binned spike counts per unit, concatenated over trials
# with the same time window as X, so rows line up with X.
Y = np.zeros((X.shape[0], n_units))

for k, unit in enumerate(units):
    Y[:, k], _ = concatenate_trials(
        trials,
        trial_ids,
        lambda tr, tid, u=unit: bin_spikes(u.trials[tid].tspk, tr.continuous.ts),
        full_time_window,
    )

# Smooth along time, then divide by dt to turn counts per sample into a rate
# (per second if dt is in seconds — TODO confirm units of ts).
Y_smooth = smooth_signal(Y, prs.neural_filtwidth) / prs.dt


# =========================
# Canonical Correlation Analysis
# =========================

# Fit CCA with as many components as the smaller view allows, then report the
# correlation of each pair of canonical projections (in-sample).
cca = CCA(n_components=min(X.shape[1], Y_smooth.shape[1]))
Xc, Yc = cca.fit_transform(X, Y_smooth)

cca_corrs = [np.corrcoef(xc, yc)[0, 1] for xc, yc in zip(Xc.T, Yc.T)]

print('CCA correlations:', cca_corrs)


# =========================
# Linear population decoding
# =========================

decode_results = {}

# The smoothed population activity does not depend on the decoded variable,
# so compute it once here rather than once per variable inside the loop.
Yd = smooth_signal(Y, prs.neural_filtwidth)

for var in prs.decode_vars:
    xt, _ = concatenate_trials(
        trials,
        trial_ids,
        lambda tr, tid, v=var: get_var(tr, v),
        full_time_window
    )
    xt[np.isnan(xt)] = 0

    # Ordinary least-squares decoder with no intercept column, fit and
    # evaluated on the same samples, so `corr` is an in-sample correlation.
    # NOTE(review): confirm whether AnalysePopulation.m includes an intercept.
    wts, _, _, _ = lstsq(Yd, xt)
    pred = Yd @ wts

    corr = np.corrcoef(xt, pred)[0, 1]

    decode_results[var] = dict(
        weights=wts,
        corr=corr,
        # Split the concatenated series back into per-trial segments.
        true=deconcatenate(xt, trial_lengths),
        pred=deconcatenate(pred, trial_lengths)
    )

    print(f'Decode {var}: r = {corr:.3f}')


# =========================
# Trajectory reconstruction
# =========================

# Reconstruct one (x, y) trajectory per trial from its decoded v and w traces.
traj_pred = [
    gen_traj(w_hat, v_hat, np.arange(len(v_hat)) * prs.dt)
    for v_hat, w_hat in zip(decode_results['v']['pred'], decode_results['w']['pred'])
]


print('Pipeline complete.')
