# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break
    

# Standard library (os, sys and Path are imported in the path-setup block above)
import math
import gc
import subprocess

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats
import json

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV, KFold
from sklearn.decomposition import PCA
# NOTE: the module is `statsmodels.stats`; `statsmodels.all_stats` does not exist.
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr
import statsmodels.api as sm

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util

plt.rcParams["animation.html"] = "html5"
os.environ['KMP_DUPLICATE_LIB_OK']='True'
rc('animation', html='jshtml')
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
matplotlib.rcParams['animation.embed_limit'] = 2**128
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)

print("done")


%load_ext autoreload
%autoreload 2
%matplotlib inline

pd.set_option('display.max_colwidth', 200)

# load data

In [None]:
# Load the session export. squeeze_me drops singleton MATLAB dimensions and
# struct_as_record=False yields objects with attribute access (e.g. `.behaviour`).
from scipy.io import loadmat

sessions_mat_path = 'all_monkey_data/one_ff_data/sessions_python.mat'
data = loadmat(sessions_mat_path, struct_as_record=False, squeeze_me=True)
sessions = data['sessions_out']


## behavioral data

In [None]:
# Pick one session and unwrap its behaviour struct.
session_num = 0
sessions = data['sessions_out']
session = sessions[session_num]
behaviour = getattr(session, 'behaviour')

In [None]:
# Per-trial records for the session; `stats` entries carry derived quantities such as pos_rel.
all_trials, all_stats = behaviour.trials, behaviour.stats
trial_num = 0
pos_rel = all_stats[trial_num].pos_rel



In [None]:
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff import parameters as one_ff_params

# Fresh analysis-parameter struct (dt is overwritten from the data further down).
default_prs = one_ff_params.default_prs
prs = default_prs()


In [None]:
# Raw continuous signals and stats for the selected trial.
trial = all_trials[trial_num]
continuous = trial.continuous
stats = all_stats[trial_num]
print(continuous._fieldnames)

x = continuous.xmp
y = continuous.ymp
v = continuous.v
w = continuous.w
t = continuous.ts

# Session sampling interval, estimated from trial 0's timestamps. (Previously referenced
# `trials`, an undefined name left over from a rename to `all_trials`.) Rounded to
# 5 decimals to remove floating-point jitter in the mean spacing.
dt = np.mean(np.diff(all_trials[0].continuous.ts))
dt = round(dt, 5)
prs.dt = dt


In [None]:
# Inspect the selected trial's continuous-signal struct.
continuous

In [None]:
# Event record of the currently selected session/trial.
selected_trial = sessions[session_num].behaviour.trials[trial_num]
events = selected_trial.events

In [None]:
# Show both target-relative traces; a standalone first line would be evaluated and discarded,
# because only a cell's last expression is rendered.
pos_rel.r_targ, pos_rel.theta_targ

In [None]:
import matplotlib.pyplot as plt

# Trajectory of the selected trial in the (forward, lateral) plane.
fig, ax = plt.subplots()
ax.plot(x, y, 'k-')
ax.set_xlabel('x (forward)')
ax.set_ylabel('y (lateral)')
ax.set_title(f'Trajectory: session {session_num}, trial {trial_num}')
ax.axis('equal');  # trailing ';' suppresses the returned-limits repr


## neural data

# functions to process data

In [None]:
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff import nav_covariates
# Recompute the navigation covariates for the selected trial using sampling interval `dt`.
covariates = nav_covariates.compute_all_covariates(trial, dt)


In [None]:
# Cross-check the recomputed target-relative covariates against the values stored in the .mat file.
for name, stored in (('r_targ', pos_rel.r_targ), ('theta_targ', pos_rel.theta_targ)):
    are_close = np.allclose(stored, covariates[name], rtol=1e-5, atol=1e-8, equal_nan=True)
    print(name, are_close)

In [None]:
# Names of the recomputed covariates.
covariates.keys()

In [None]:
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff import population_analysis_utils


In [None]:
# Shape and NaN-ignoring range of every covariate. Loop names are chosen so they do not
# clobber the trial's `v` trace or the notebook-level `k`.
for cov_name, cov_values in covariates.items():
    print(cov_name, cov_values.shape, np.nanmin(cov_values), np.nanmax(cov_values))


# concat data

In [None]:
# One integer id per trial in the session, used to label concatenated time bins.
n_trials = len(all_trials)
trial_ids = np.arange(n_trials)

In [None]:
# Spike times of EVERY unit for the selected trial (`trial_num` only), keyed by unit index.
session_units = sessions[session_num].units
trial_neural_data = {
    unit_id: session_units[unit_id].trials[trial_num].tspk
    for unit_id in range(len(session_units))
}



In [None]:
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff import nav_covariates, population_analysis_utils

# Covariates concatenated across all trials; `trial_id_vec` labels each time bin with its trial.
covariate_names = ['v', 'w', 'd', 'phi', 'r_targ', 'theta_targ', 'eye_ver', 'eye_hor', 'move']


def _covariates_for_trial(tr):
    """Compute all navigation covariates for one trial using sampling interval `dt`."""
    return nav_covariates.compute_all_covariates(tr, dt)


covariates_concat, trial_id_vec = population_analysis_utils.concatenate_covariates_with_trial_id(
    trials=all_trials,
    trial_indices=trial_ids,
    covariate_fn=_covariates_for_trial,
    time_window_fn=population_analysis_utils.full_time_window,
    covariate_names=covariate_names,
)


In [None]:
# Shape of the all-trial concatenated 'v' covariate (one entry per time bin).
covariates_concat['v'].shape

In [None]:
# Recorded units of the selected session and their count.
units = getattr(sessions[session_num], 'units')
n_units = len(units)

In [None]:
# Sampling interval computed from the trial timestamps (rounded to 5 decimals).
dt

In [None]:
# Re-create the parameter struct (e.g. to pick up edits via autoreload). Import through the same
# package path used elsewhere in this notebook so only one copy of the module is loaded.
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff.parameters import default_prs

prs = default_prs()
# A fresh struct does not carry the data-derived sampling interval assigned earlier; restore it.
prs.dt = dt

print('dt:', prs.dt)
print('neural_filtwidth:', prs.neural_filtwidth)
print('GAM_varname:', prs.GAM_varname)
print(prs.GAM_varname)


In [None]:
# Binned spike counts for every unit, concatenated across ALL trials so the rows line up with
# `trial_id_vec` from the covariate concatenation. Previously every trial was binned with the
# spike times of `trial_num` only (`trial_neural_data[k]`), which is wrong for trials != trial_num.
Y = np.zeros((len(trial_id_vec), n_units))


def _unit_trial_counts(unit_idx):
    """Return a per-trial callback that bins unit `unit_idx`'s spikes on that trial's time grid."""
    def _bin(tr, tid):
        # NOTE(review): assumes `tid` indexes `units[unit_idx].trials` like `all_trials`
        # (trial_ids is np.arange(len(all_trials))) — confirm in concatenate_trials_with_trial_id.
        return population_analysis_utils.bin_spikes(units[unit_idx].trials[tid].tspk, tr.continuous.ts)
    return _bin


for k in range(n_units):
    spk_counts, trial_id_vec_spk = population_analysis_utils.concatenate_trials_with_trial_id(
        all_trials,
        trial_ids,
        _unit_trial_counts(k),
        population_analysis_utils.full_time_window,
    )
    Y[:, k] = spk_counts

# Smoothed counts divided by the bin width -> firing rate per unit of `dt`.
Y_smooth = population_analysis_utils.smooth_signal(Y, prs.neural_filtwidth) / dt

In [None]:
 Y.shape

In [None]:
# Bin-level trial labels returned by the spike concatenation (from the last unit's call in the loop).
trial_id_vec_spk.shape

In [None]:
def _event_trial_fn(event_name):
    """Return a per-trial callback producing the impulse train for `event_name`."""
    return lambda tr, tid: population_analysis_utils.event_impulse(tr, tid, event_name)


# Impulse trains for each task event, concatenated across trials and keyed by event name.
# After the loop, `events_concat` / `trial_id_vec_evt` hold the LAST event's result only.
all_events = {}
for event in ('t_targ', 't_move', 't_rew'):
    events_concat, trial_id_vec_evt = population_analysis_utils.concatenate_trials_with_trial_id(
        all_trials, trial_ids, _event_trial_fn(event), population_analysis_utils.full_time_window
    )
    all_events[event] = events_concat

In [None]:
# Shape of the concatenated impulse train for the LAST event of the loop ('t_rew') only.
events_concat.shape

# PGAM

In [None]:
# Import the PGAM library. Its source directory is resolved relative to the project root (the
# working directory set in the first cell) instead of a machine-specific absolute path; set the
# PGAM_SRC environment variable to override.
import os
import sys
from pathlib import Path

pgam_path = os.environ.get(
    'PGAM_SRC',
    str(Path.cwd() / 'multiff_analysis' / 'external' / 'pgam' / 'src'),
)
if pgam_path not in sys.path:
    sys.path.append(pgam_path)

import numpy as np
# NOTE(review): star import kept in case later cells use GAM_library names unqualified;
# consider replacing it with explicit names.
from PGAM.GAM_library import *
import PGAM.gam_data_handlers as gdh
# pyplot, not the discouraged pylab, so `plt` stays the same module used earlier in the notebook.
import matplotlib.pyplot as plt
import pandas as pd
from post_processing import postprocess_results
from scipy.io import savemat

## Temporal covariates (event filters)

Temporal filters $g$ were parameterized using a basis of ten raised-cosine functions spanning 600 ms. The filter associated with target onset was causal ([0, 600] ms), while the remaining filters were acausal ([-300, 300] ms). Both the spike-history filter $h$ and the coupling filter $p$ were expressed using a basis of ten causal raised-cosine functions on a logarithmic time scale. Spike-history filters spanned 350 ms, while coupling filters spanned 1.375 s.

In [None]:
from neural_data_analysis.topic_based_neural_analysis.replicate_one_ff import prepare_pgam_design


In [None]:
# Choose the unit explicitly. Previously this used `k` left over from the spike-binning loop,
# which silently selected the last unit and broke if any other cell rebound `k`.
unit_idx = 0  # <-- choose the unit you want (0 .. n_units - 1)

sm_handler = prepare_pgam_design.build_smooth_handler_for_unit(
    unit_idx=unit_idx,
    covariates_concat=covariates_concat,
    covariate_names=covariate_names,
    trial_id_vec=trial_id_vec,
    Y_binned=Y,
    all_events=all_events,
    dt=dt,
    tuning_covariates=covariate_names,  # or a subset
    use_cyclic=set(),                  # e.g., {'heading_angle'}
    order=4,
)


In [None]:
# Inspect the smooth handler built above for the chosen unit.
sm_handler

In [None]:
# Temporal-kernel basis settings. The kernel length is given in bins, so the 600 ms span is
# converted with the data's actual bin width `dt` instead of assuming 1 ms bins.
# NOTE(review): assumes `dt` is in seconds (derived from the `ts` timestamps) — confirm.
order = 4                          # cubic B-splines
num_filters = 10
knots_num = num_filters - order
dt_ms = dt * 1000.0                # bin width in ms
kernel_ms = 600                    # total temporal span of the kernel
kernel_h_length = int(round(kernel_ms / dt_ms))


In [None]:
# NOTE(review): scratch experiment — everything in this cell is commented out and does not run.
# It builds one smooth handler per kernel_direction (0, 1, -1) for `events_concat`, assuming
# 1 ms bins. Delete it or move it to a scratch notebook once no longer needed.
# tot_tp = 10**3

# # # trial ids
# # trial_ids = np.zeros(tot_tp)
# # trial_ids[400:] = 1

# # # event markers
# # event = np.zeros(tot_tp)
# # event[[100, 200, 600, 900]] = 1

# # kernel parameters
# dt_ms = 1.0                 # ms per time bin
# kernel_ms = 600             # total temporal span
# kernel_h_length = int(kernel_ms / dt_ms)

# order = 4                   # cubic B-splines
# num_filters = 10
# num_int_knots = num_filters - order

# dict_kernel = {
#     0: 'Acausal',
#     1: 'Direction %d' % 1,
#     -1: 'Direction %d' % (-1)
# }

# for kernel_direction in [0, 1, -1]:
#     sm_handler = gdh.smooths_handler()

#     sm_handler.add_smooth(
#         'this_event',
#         [events_concat],
#         is_temporal_kernel=True,
#         ord=order,
#         knots_num=num_int_knots,
#         trial_idx=trial_id_vec_evt,
#         kernel_length=kernel_h_length,
#         kernel_direction=kernel_direction
#     )


In [None]:
# NOTE(review): these hold only the last event from the impulse loop ('t_rew'); use all_events[...] for others.
events_concat, trial_id_vec_evt