# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break
    
from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils, ml_methods_utils
from planning_analysis.show_planning import nxt_ff_utils, show_planning_utils
from neural_data_analysis.neural_analysis_tools.gpfa_methods import elephant_utils, fit_gpfa_utils, plot_gpfa_utils, gpfa_helper_class
from neural_data_analysis.neural_analysis_tools.align_trials import time_resolved_regression, time_resolved_gpfa_regression,plot_time_resolved_regression
from neural_data_analysis.neural_analysis_tools.align_trials import align_trial_utils
from decision_making_analysis.compare_GUAT_and_TAFT import find_GUAT_or_TAFT_trials

from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.stop_psth import core_stops_psth, psth_postprocessing, psth_stats, compare_events, dpca_utils, prep_stop_psth_data
from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis.get_stop_events import get_stops_utils, collect_stop_data

import sys
import math
import gc
import subprocess
from pathlib import Path

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
import numpy as np
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util

# --- Notebook-wide display / runtime configuration ---

# Reset matplotlib to its built-in defaults FIRST. Previously this reset ran
# after the animation settings below, so it silently overwrote
# 'animation.html' with the default value and the setting never took effect.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)

# Render animations as interactive JavaScript widgets. The earlier "html5"
# assignment was dropped: it was overridden by this call on the very next
# statement, so "jshtml" was always the value that won.
rc('animation', html='jshtml')
# Very large embed limit so long animations are not truncated when embedded.
matplotlib.rcParams['animation.embed_limit'] = 2**128

# NOTE(review): the two environment variables below are set after the heavy
# imports above have already executed; if they are meant to influence library
# start-up they may need to move before those imports -- confirm.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

# Tabular / array display preferences.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)
np.set_printoptions(suppress=True)

print("done")


%load_ext autoreload
%autoreload 2

# Retrieve data

In [None]:
# Session folder to analyse. The import cell chdir's into the project root when
# it finds it, so this relative path is presumably resolved from there.
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0327"

# Collect the stop-event data for the session.
pn, datasets, comparisons = collect_stop_data.collect_stop_data_func(raw_data_folder_path)

# Publish every entry of `datasets` as a notebook-level variable. Later cells
# rely on names injected here (e.g. GUAT_last, TAFT_first, one_stop_miss).
globals().update(datasets)

# Build the capture / no-capture stop tables from the session object.
(
    captures_df,
    valid_captures_df,
    filtered_no_capture_stops_df,
    stops_with_stats,
) = get_stops_utils.prepare_no_capture_and_captures(
    monkey_information=pn.monkey_information,
    closest_stop_to_capture_df=pn.closest_stop_to_capture_df,
    ff_caught_T_new=pn.ff_caught_T_new,
    distance_col="distance_from_ff_to_stop",
)

# Dwell time on stops

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Stack every stop table into one long frame, tagging each row with a
# human-readable group label.
#
# Each source frame is listed exactly once. `valid_captures_df` used to appear
# twice, which duplicated every 'Capture' row and inflated that group's sample
# in the plots below.
#
# The list order is kept as-is: with plain string labels the groups are
# expected to appear on the x-axis in order of first occurrence.
# `ignore_index=True` gives the stacked frame a fresh unique index (the source
# frames each bring their own), matching how `df_all` is built further down.
df_compare = pd.concat(
    [
        GUAT_last.assign(group='retry miss Last'),
        TAFT_last.assign(group='retry capture Last'),
        valid_captures_df.assign(group='Capture'),

        GUAT_middle.assign(group='retry miss Middle'),
        TAFT_middle.assign(group='retry capture Middle'),

        # Optional groups, currently disabled:
        # GUAT_nonfinal.assign(group='retry miss Non-final'),
        # TAFT_nonfinal.assign(group='retry capture Non-final'),

        GUAT_first.assign(group='retry miss First'),
        TAFT_first.assign(group='retry capture First'),

        # filtered_no_capture_stops_df.assign(group='Non-capture'),
        one_stop_miss.assign(group='One-stop Miss'),
    ],
    ignore_index=True,
)

# One violin per group.
# NOTE(review): newer seaborn releases rename `scale` to `density_norm`;
# confirm the installed version before switching the keyword.
plt.figure(figsize=(10, 6))
sns.violinplot(
    data=df_compare,
    x='group',
    y='stop_id_duration',
    inner='quartile',
    cut=0,  # do not extend the density beyond the observed durations
    scale='width'
)

# Cosmetics
plt.xticks(rotation=30, ha='right')
plt.xlabel('')
plt.ylabel('Stop Duration (s)')
#plt.ylim(0, 10)
plt.title('Stop Duration Distribution Across Trial Types')
plt.tight_layout()
plt.show()


In [None]:
# Box plot per group (outlier markers hidden) with every individual stop
# overlaid as a small jittered black dot.
fig, ax = plt.subplots(figsize=(10, 6))

sns.boxplot(
    data=df_compare, x='group', y='stop_id_duration',
    width=0.6, showfliers=False, ax=ax,
)
sns.stripplot(
    data=df_compare, x='group', y='stop_id_duration',
    size=2, jitter=0.3, alpha=0.4, color='black', ax=ax,
)

# Slanted, right-aligned tick labels so the long group names do not overlap.
plt.setp(ax.get_xticklabels(), rotation=30, ha='right')
ax.set_ylabel('Stop Duration (s)')
ax.set_title('Stop Duration Distribution Across Trial Types')
fig.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Restrict to the six retry groups: {miss, capture} x {First, Middle, Last}.
retry_groups = [
    f'retry {outcome} {phase}'
    for outcome in ('miss', 'capture')
    for phase in ('First', 'Middle', 'Last')
]
df_retry = df_compare.loc[df_compare['group'].isin(retry_groups)].copy()

# Derive two helper columns from the group label:
#   phase -> which stop within the retry sequence (First / Middle / Last)
#   type  -> 'Capture' if the label mentions "capture", otherwise 'Miss'
df_retry['phase'] = df_retry['group'].str.extract(r'(First|Middle|Last)', expand=False)
df_retry['type'] = np.where(
    df_retry['group'].str.contains('capture', case=False), 'Capture', 'Miss'
)

# Fixed left-to-right order of the phases on the x-axis.
phase_order = ['First', 'Middle', 'Last']

# One violin per phase, split down the middle by outcome type.
plt.figure(figsize=(8, 6))
ax = sns.violinplot(
    data=df_retry,
    x='phase', y='stop_id_duration', hue='type',
    order=phase_order,
    split=True, inner='quartile', cut=0,
)
ax.set_ylabel('Stop Duration (s)')
ax.set_xlabel('Retry Phase')
ax.set_title('Retry Miss vs Capture Stop Durations by Phase')
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

def _violin_by_phase(frame, order, color, title):
    """Draw one single-colour violin plot of stop duration per phase.

    frame: rows to plot; must contain 'phase' and 'stop_id_duration' columns.
    order: left-to-right order of the phase categories.
    color: fill colour for all violins.
    title: figure title.
    """
    plt.figure(figsize=(8, 5))
    sns.violinplot(
        data=frame,
        x='phase',
        y='stop_id_duration',
        order=order,
        inner='quartile',
        scale='width',
        cut=0,
        color=color,
    )
    plt.title(title)
    plt.xlabel('Phase')
    plt.ylabel('Stop Duration (s)')
    plt.tight_layout()
    plt.show()


# Retry stops only, selected by their explicit group labels.
retry_groups = [
    'retry miss First',
    'retry miss Middle',
    'retry miss Last',
    'retry capture First',
    'retry capture Middle',
    'retry capture Last',
]
df_retry = df_compare.loc[df_compare['group'].isin(retry_groups)].copy()

# Phase (First/Middle/Last) and outcome (Success/Fail) parsed from the label.
df_retry['phase'] = df_retry['group'].str.extract(r'(First|Middle|Last)', expand=False)
df_retry['outcome'] = (
    df_retry['group']
    .str.contains('capture', case=False)
    .map({True: 'Success', False: 'Fail'})
)

# Regular captures become an extra 'Capture' phase on the success side ...
df_capture = valid_captures_df.assign(phase='Capture', outcome='Success')
# ... and one-stop misses an extra 'First Miss' phase on the fail side.
df_first_miss = one_stop_miss.assign(phase='First Miss', outcome='Fail')

# Single long table holding all three sources.
df_all = pd.concat([df_retry, df_capture, df_first_miss], ignore_index=True)

# x-axis orders: the three retry phases followed by the extra category.
phase_order_success = ['First', 'Middle', 'Last', 'Capture']
phase_order_fail = ['First', 'Middle', 'Last', 'First Miss']

# Split by outcome so each gets its own figure.
df_success = df_all.loc[df_all['outcome'].eq('Success')]
df_fail = df_all.loc[df_all['outcome'].eq('Fail')]

# Successes: retry captures plus regular captures.
_violin_by_phase(
    df_success,
    phase_order_success,
    'skyblue',
    'Retry + Success (Capture) Stop Durations by Phase',
)

# Failures: retry misses plus one-stop misses.
_violin_by_phase(
    df_fail,
    phase_order_fail,
    'salmon',
    'Retry + Fail (Miss) Stop Durations by Phase',
)


## Compare 2: give-up last stop vs. non-final stops

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Two-group comparison: last give-up stop vs. non-final stops.
# NOTE: this rebinds `df_compare`, replacing the multi-group frame built in the
# earlier cell; the stats cells further down read whichever frame was assigned
# most recently.
df_compare = pd.concat([
    GUAT_last.assign(group='Give-up Last Stop'),
    GUAT_nonfinal.assign(group='Non-final Stop'),
]).assign(category='Stop Duration')  # single shared x value -> one split violin

plt.figure(figsize=(6, 6))
ax = sns.violinplot(
    data=df_compare,
    x='category', y='stop_id_duration', hue='group',
    split=True, inner='quartile', width=0.8,
)

# Cosmetics
ax.set_xlabel('')
ax.set_ylabel('Stop Duration (s)')
ax.set_title('Stop Duration Distribution: Give-up Last vs Non-final Stops')
ax.legend(title='')
plt.tight_layout()
plt.show()


# Stats

In [None]:
from itertools import combinations
from scipy.stats import mannwhitneyu

# Pairwise two-sided Mann-Whitney U tests between every pair of groups in
# `df_compare` (whichever frame that name currently refers to).
#
# Changes relative to the first version of this cell:
#   * missing durations are dropped before testing, so a single NaN cannot
#     turn the whole comparison into NaN;
#   * the U statistic and sample sizes are reported, not just p;
#   * a Bonferroni-adjusted p-value is shown because several pairs are tested;
#   * p-values use general format so very small values do not print as 0.0000.
groups = df_compare['group'].unique()
pairs = list(combinations(groups, 2))
n_tests = len(pairs)

for g1, g2 in pairs:
    x1 = df_compare.loc[df_compare['group'] == g1, 'stop_id_duration'].dropna()
    x2 = df_compare.loc[df_compare['group'] == g2, 'stop_id_duration'].dropna()
    if x1.empty or x2.empty:
        # Nothing to compare; report it rather than letting the test fail.
        print(f'{g1} vs {g2}: skipped (n = {len(x1)}, {len(x2)})')
        continue
    stat, p = mannwhitneyu(x1, x2, alternative='two-sided')
    p_adj = min(p * n_tests, 1.0)  # Bonferroni correction, capped at 1
    print(
        f'{g1} vs {g2}: U = {stat:.1f}, p = {p:.4g}, '
        f'Bonferroni p = {p_adj:.4g} (n = {len(x1)}, {len(x2)})'
    )


In [None]:
from scipy.stats import kruskal

# Omnibus Kruskal-Wallis test across all groups currently in `df_compare`.
# Missing durations are dropped per group so NaNs cannot propagate into the
# statistic.
samples = [
    df_compare.loc[df_compare['group'] == g, 'stop_id_duration'].dropna()
    for g in df_compare['group'].unique()
]
kw_stat, kw_p = kruskal(*samples)
# The label previously contained a mis-encoded en dash ("â€“"); restored here.
print(f'Overall Kruskal–Wallis: H={kw_stat:.3f}, p={kw_p:.2e}')


In [None]:
# Median stop duration per group, smallest first (displayed as the cell output).
df_compare['stop_id_duration'].groupby(df_compare['group']).median().sort_values()
