# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
# Walk upward from the current working directory until the project root folder
# is found; make it the cwd and put the project's `methods` directory first on
# sys.path so the project imports below resolve.
for p in (Path.cwd(), *Path.cwd().parents):
    if p.name != 'Multifirefly-Project':
        continue
    os.chdir(p)
    sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
    break
    
from data_wrangling import general_utils, specific_utils, process_monkey_information
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, ml_decoder_class, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils
from neural_data_analysis.topic_based_neural_analysis.target_decoder import behav_features_to_keep, target_decoder_class, prep_target_decoder, eval_target_decoder, td_seg_aligned_class
from neural_data_analysis.neural_analysis_tools.gpfa_methods import elephant_utils, fit_gpfa_utils, plot_gpfa_utils, gpfa_tuning, gpfa_helper_class
from machine_learning.ml_methods import regression_utils, classification_utils, prep_ml_data_utils
from neural_data_analysis.neural_analysis_tools.align_trials import time_resolved_regression, time_resolved_gpfa_regression,plot_time_resolved_regression

import sys
import math
import gc
import subprocess
from pathlib import Path

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
import numpy as np
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util

# Reset matplotlib to its defaults FIRST. rcParams.update(rcParamsDefault)
# overwrites every rc key, so when it ran after the animation settings (as it
# did before) those settings were silently discarded.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
plt.rcParams["animation.html"] = "html5"
rc('animation', html='jshtml')  # overrides the "html5" value set on the previous line
matplotlib.rcParams['animation.embed_limit'] = 2**128

# Environment variables
os.environ['KMP_DUPLICATE_LIB_OK']='True'
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

# pandas / numpy display options
pd.set_option('display.float_format', lambda x: '%.5f' % x)
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)
np.set_printoptions(suppress=True)

print("done")


%load_ext autoreload
%autoreload 2

# check

In [None]:
dec.monkey_information

In [None]:
dec.behav_data_by_point

# Get data

In [None]:
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0416"
#raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0328"
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0321"
bin_width = 0.1
dec = td_seg_aligned_class.TargetDecoderSegmentAlignedClass(raw_data_folder_path=raw_data_folder_path,
                                                               bin_width=bin_width)

In [None]:
behav_data_exists_ok = True
x_and_y_var_exists_ok = False
dec.streamline_making_behav_and_neural_data(exists_ok=behav_data_exists_ok)
dec.get_x_and_y_var(exists_ok=x_and_y_var_exists_ok)
dec.get_x_and_y_data_for_modeling(exists_ok=x_and_y_var_exists_ok)
# dec._free_up_memory()
print('x_var.shape:', dec.x_var.shape)
print('y_var_reduced.shape:', dec.y_var_reduced.shape)

print('x_var_lags.shape:', dec.x_var_lags.shape)
print('y_var_lags_reduced.shape:', dec.y_var_lags_reduced.shape)

# GPFA

## get data and fit gpfa

In [None]:
dec.prepare_seg_aligned_data(align_at_beginning=False)

In [None]:
dec.get_gpfa_traj(latent_dimensionality=7, exists_ok=False)

In [None]:
print(dec.trajectories.shape) # number of segments
print(dec.trajectories[2].shape) # num_latent_dimensions x num_bins

In [None]:
# Choose the regression input. The flag is now forwarded to the call; before,
# it was assigned but a hard-coded False was passed, so editing it had no effect.
use_raw_spike_data_instead = False
dec.get_concat_data_for_regression(use_raw_spike_data_instead=use_raw_spike_data_instead,
                                    use_lagged_raw_spike_data=False,
                                    apply_pca_on_raw_spike_data=False,
                                    num_pca_components=7)

dec.print_data_dimensions()

## point-wise segment regression

In [None]:
dec.time_resolved_regression_cv()

In [None]:
dec.plot_time_resolved_regression()

In [None]:
dec.plot_trial_counts_by_timepoint()  # 

## concat data regression

In [None]:
# Multivariate linear regression
dec.y_var_lr_df = neural_data_modeling.get_y_var_lr_df(
                dec.concat_neural_trials.drop(columns=['new_segment', 'new_bin'], errors='ignore'),
                dec.concat_behav_trials)

In [None]:
# use_raw_spike_data_instead=True,
# use_lagged_raw_spike_data=True,
# apply_pca_on_raw_spike_data=True,
dec.y_var_lr_df.head(15)

In [None]:
# use_raw_spike_data_instead = True
dec.y_var_lr_df.head(7)

In [None]:
# use_raw_spike_data_instead=True,
# use_lagged_raw_spike_data=True,

dec.y_var_lr_df.head(7)

In [None]:
# use_raw_spike_data_instead = False
dec.y_var_lr_df.head(10)

## train-test split by segments

In [None]:
    # x_var, y_var = pn.get_concat_x_and_y_var_for_lr(test_or_control=test_or_control)
    
    # results_summary = ml_methods_utils.run_segment_split_regression_cv(
    #     x_var, 
    #     y_var, 
    #     columns_of_interest, 
    #     num_folds=5, 
    # )

In [None]:
from machine_learning.ml_methods import ml_methods_utils

In [None]:
columns_of_interest = ['target_distance']

x_var = dec.concat_neural_trials
y_var = dec.concat_behav_trials

results_summary = ml_methods_utils.run_segment_split_regression_cv(
    x_var, 
    y_var, 
    columns_of_interest, 
    num_folds=5, 
    segment_column='new_segment',
)

results_summary

In [None]:
wide_df = ml_methods_utils.convert_results_to_wide_df(results_summary, index_columns=['Target', 'Model'])
wide_df

## plot latent dimensions

In [None]:
plot_gpfa_utils.plot_gpfa_traj_3d_uniform_color(dec.trajectories)


In [None]:
reload(plot_gpfa_utils)

In [None]:
# Use the inline backend (static figures); switch to an interactive backend if the 3-D view needs to be rotated
%matplotlib inline

# Import required modules
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Create the 3-D trajectory plot
fig, ax = plot_gpfa_utils.plot_gpfa_traj_3d(
    num_traj_to_plot=30,
    trajectories=dec.trajectories,
    figsize=(15, 5),
    linewidth_single_trial=0.75,
    alpha_single_trial=0.3,
    linewidth_trial_average=2,
    title='Latent dynamics extracted by GPFA',
    view_azim=-5,
    view_elev=60
)

plt.show()

In [None]:
# fig = plot_gpfa_utils.plot_gpfa_traj_3d_plotly(trajectories)

In [None]:
# Share of the total latent variance carried by each latent dimension.
# NOTE(review): this is each dimension's fraction of the summed latent variance,
# not variance explained of the neural data — confirm that is the intended quantity.
# np.stack needs every trajectory to have the same shape.
traj_stack = np.stack(dec.trajectories, axis=0)  # shape: (n_segments, n_latent_dims, n_bins); GPFA above was fit with latent_dimensionality=7
var_by_dim = np.var(traj_stack, axis=(0, 2))    # variance across segments and time bins, one value per latent dimension
var_by_dim /= var_by_dim.sum()               # normalize so the values sum to 1
print("Variance explained by each latent dimension:", var_by_dim)


In [None]:

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(figsize=(15, 5))

ax.set_title('Latent dynamics extracted by GPFA')
ax.set_xlabel('Time [s]')

average_trajectory = np.mean(dec.trajectories, axis=0)
time = np.arange(len(average_trajectory[0])) * dec.bin_width  # assuming all trajectories have the same length

for i, x in enumerate(average_trajectory):
    ax.plot(time, x, label=f'Dim {i+1}')

ax.legend()

plt.tight_layout()
plt.show()


## why poor performance?

In [None]:
# Diagnostics for the time-resolved regression inputs.
# NOTE(review): this import rebinds `time_resolved_regression`, which the first cell
# imported from neural_analysis_tools.align_trials — confirm the gpfa_methods module
# is the one intended (plot_time_resolved_regression below still comes from align_trials).
import neural_data_analysis.neural_analysis_tools.gpfa_methods.time_resolved_regression as time_resolved_regression

# 1. Print number of trials per timepoint
time_resolved_regression.print_trials_per_timepoint(dec.gpfa_neural_trials)

# 2. Check for NaNs
time_resolved_regression.check_for_nans_in_trials(dec.gpfa_neural_trials, name='latent')
time_resolved_regression.check_for_nans_in_trials(dec.behav_trials, name='behavioral')

# 3. Standardize trials
latent_trials_std = time_resolved_regression.standardize_trials(dec.gpfa_neural_trials)
behav_trials_std = time_resolved_regression.standardize_trials(dec.behav_trials)

# 4. Plot the distribution of the number of points per trial
plot_time_resolved_regression.plot_trial_point_distribution(dec.pursuit_data)

# 5. Plot latent and behavioral variables for a few trials
time_resolved_regression.plot_latents_and_behav_trials(latent_trials_std, behav_trials_std, bin_width=dec.bin_width, n_trials=5)


## hyperparams (still need to debug)

In [None]:
# Deliberately abort here so that "Run All" does not start the unfinished grid
# search. An explicit exception replaces the former `stop!` line, which was a
# syntax error and made the whole cell unparsable for linters and converters.
raise NotImplementedError("this section is not finished yet")

# grid search

import itertools
from joblib import Parallel, delayed, cpu_count
print(f"Detected CPU cores: {cpu_count()}")

# # can add for smoothing:
# # other forms of smoothing like (currently it's only uniform_filter1d)
# from scipy.ndimage import gaussian_filter1d
# # gpfa_neural_trials: list of trials, each trial shape (time_bins, n_neurons)
# smoothed_trials = [
#     gaussian_filter1d(trial, sigma=smooth_sigma, axis=0)
#     for trial in gpfa_neural_trials
# ]


# Define your grid
smoothing_windows = [1, 3]
use_sqrt = [True, False]
gpfa_dims = [3, 5]
bin_widths = [0.02]
ridge_alphas = [0.1, 1]
regression_types = ['ridge']
align_at_beginning_opts = [True]
pca_components = [5, 10]

param_grid_gpfa = list(itertools.product(
    smoothing_windows, use_sqrt, gpfa_dims, bin_widths, ridge_alphas, regression_types, align_at_beginning_opts
))

# Baseline configs
param_grid_raw = list(itertools.product(
    smoothing_windows, use_sqrt, bin_widths, ridge_alphas, regression_types, align_at_beginning_opts
))
param_grid_pca = list(itertools.product(
    smoothing_windows, use_sqrt, bin_widths, ridge_alphas, regression_types, align_at_beginning_opts, pca_components
))

# Run GPFA grid
results_gpfa = Parallel(n_jobs=-1, verbose=10)(
    delayed(gpfa_tuning.run_gpfa_experiment_time_resolved)(
        dec, smoothing, sqrt, gpfa_dim, bin_width, ridge_alpha, regression_type, align_at_beginning, baseline=None
    )
    for (smoothing, sqrt, gpfa_dim, bin_width, ridge_alpha, regression_type, align_at_beginning) in param_grid_gpfa
)

# Run raw baseline grid
results_raw = Parallel(n_jobs=-1, verbose=10)(
    delayed(gpfa_tuning.run_gpfa_experiment_time_resolved)(
        dec, smoothing, sqrt, None, bin_width, ridge_alpha, regression_type, align_at_beginning, baseline='raw'
    )
    for (smoothing, sqrt, bin_width, ridge_alpha, regression_type, align_at_beginning) in param_grid_raw
)

# Run PCA baseline grid
results_pca = Parallel(n_jobs=-1, verbose=10)(
    delayed(gpfa_tuning.run_gpfa_experiment_time_resolved)(
        dec, smoothing, sqrt, None, bin_width, ridge_alpha, regression_type, align_at_beginning, baseline='pca', pca_components=pca_comp
    )
    for (smoothing, sqrt, bin_width, ridge_alpha, regression_type, align_at_beginning, pca_comp) in param_grid_pca
)

# Combine all results
all_results = results_gpfa + results_raw + results_pca
df = pd.DataFrame(all_results)
print(df.sort_values('mean_r2', ascending=False).head(10))

In [None]:
import matplotlib.pyplot as plt
# idxmax() returns an index *label*, so the row must be selected with .loc.
# .iloc only gave the right row while df kept its default 0..n-1 index and
# would pick the wrong row once df is sorted, filtered or re-indexed.
best = df.loc[df['mean_r2'].idxmax()]
plt.plot(best['times'], np.nanmean(np.array(best['r2_by_time']), axis=1))
plt.xlabel('Time (s)')
plt.ylabel('Mean R²')
plt.title(f"Best config: {best['model']} R² by time")
plt.show()

# Compare models
import seaborn as sns
sns.catplot(data=df, x='model', y='mean_r2', kind='bar')

# ML to decode single vars

## decode

In [None]:
# neural_data = dec.x_var_lags
# behavioral_data = dec.y_var_reduced

neural_data = dec.concat_neural_trials
behavioral_data = dec.concat_behav_trials

In [None]:
# General usage for any behavioral variable
decoder = ml_decoder_class.MLBehavioralDecoder()
models_to_use=['rf', 'nn', 'lr']
successful_decodings = {}

for var in ['target_rel_y', 'target_rel_x']:
    result = decoder.decode_variable(neural_data, behavioral_data, var, models_to_use=models_to_use)
    if result is not None:
        successful_decodings[var] = result

best_model, best_results = decoder.get_best_model('target_rel_y', 'test_r2')

# Plot rf results for any variable
decoder.plot_ml_results('target_rel_y', 'rf')

successful_decodings

## compare different Models

Let's compare the performance of different machine learning models.


In [None]:
comparison_df = eval_target_decoder.compare_models(successful_decodings)

## plot feature importance for RF

In [None]:
# Analyze feature importance for Random Forest models
for target_var, model_results in successful_decodings.items():
    if 'rf' not in model_results:
        continue

    print(f"\n{'='*50}")
    print(f"FEATURE IMPORTANCE: {target_var}")
    print('='*50)

    rf_model = model_results['rf']['model']
    if not hasattr(rf_model, 'feature_importances_'):
        continue

    # Label importances with the columns of the frame that was passed to
    # decoder.decode_variable (`neural_data`), not `dec.neural_data`, which is a
    # different table and need not have the same columns.
    # NOTE(review): assumes the decoder fits on all columns of `neural_data` in order — confirm.
    importance_df = regression_utils._get_rf_feature_importances(rf_model, neural_data.columns)

    # Show top 10 most important features
    print(f"Top 10 most important neurons for {target_var}:")
    print(importance_df.head(10))

    # Plot feature importance
    regression_utils.plot_feature_importance(importance_df, target_var)


# Save Results (have yet to try)

Finally, let's save our results for future analysis.


In [None]:
import pickle
import json
import pandas as pd
from typing import Dict, Any

def create_experiment_info(decoder, monkey: str, session: str) -> Dict[str, Any]:
    """Collect the metadata that identifies one decoding run.

    Args:
        decoder: Object exposing ``bin_width``, ``neural_data.shape`` and
            ``target_data.shape``.
        monkey: Monkey label stored verbatim.
        session: Session label stored verbatim.

    Returns:
        Dict with keys ``monkey``, ``session``, ``bin_width``,
        ``neural_data_shape`` and ``target_data_shape``.
    """
    info: Dict[str, Any] = {'monkey': monkey, 'session': session}
    info['bin_width'] = decoder.bin_width
    info['neural_data_shape'] = decoder.neural_data.shape
    info['target_data_shape'] = decoder.target_data.shape
    return info

def create_cca_results(decoder) -> Dict[str, Any]:
    """Summarise the CCA entry of ``decoder.results``.

    Returns:
        ``{'top_3_correlations': ...}`` where the value is the first three
        canonical correlations as a plain list, or ``None`` when
        ``decoder.results`` has no ``'cca'`` key.
    """
    if 'cca' in decoder.results:
        leading = decoder.results['cca']['canonical_correlations'][:3].tolist()
    else:
        leading = None
    return {'top_3_correlations': leading}

def find_best_performances(successful_decodings: Dict) -> Dict[str, Dict[str, Any]]:
    """Find the best-scoring model for each target variable.

    A model's score is its ``test_r2`` if present, else ``test_accuracy``,
    else ``cv_mean``, else ``0``.

    Args:
        successful_decodings: Mapping ``target -> {model_name -> results dict}``.

    Returns:
        Mapping ``target -> {'best_model': name, 'best_score': score}``. When a
        target has no comparable score (no models, or only None/NaN scores),
        ``best_model`` is ``None`` and ``best_score`` is the legacy sentinel ``-1``.
    """
    best_performances = {}
    for target_var, models in successful_decodings.items():
        best_model = None
        # Start below every real score: R^2 is unbounded below, so the former
        # starting value of -1 left best_model as None whenever all models
        # scored <= -1.
        best_score = float('-inf')

        for model_name, results in models.items():
            score = results.get('test_r2', results.get('test_accuracy', results.get('cv_mean', 0)))
            if score is None:
                # A stored None cannot be compared; skip it rather than crash.
                continue
            if score > best_score:
                best_score = score
                best_model = model_name

        if best_model is None:
            # Keep the historical sentinel (it is also JSON-serialisable, unlike -inf).
            best_score = -1

        best_performances[target_var] = {
            'best_model': best_model,
            'best_score': best_score
        }
    return best_performances

def create_summary_report(decoder, successful_decodings: Dict, monkey: str, session: str) -> Dict[str, Any]:
    """Assemble the full summary report for one decoding run.

    Returns:
        Dict with three sections: ``experiment_info``, ``cca_results`` and
        ``ml_results_summary`` (decoded target names plus the best model per target).
    """
    experiment_info = create_experiment_info(decoder, monkey, session)
    cca_results = create_cca_results(decoder)
    ml_results_summary = {
        'successful_targets': list(successful_decodings.keys()),
        'best_performances': find_best_performances(successful_decodings),
    }
    return {
        'experiment_info': experiment_info,
        'cca_results': cca_results,
        'ml_results_summary': ml_results_summary,
    }

def print_summary_report(summary_report: Dict[str, Any]):
    """Print a human-readable version of a summary report to stdout.

    Args:
        summary_report: Dict with the ``experiment_info``, ``cca_results`` and
            ``ml_results_summary`` sections. The CCA line is only printed when
            ``top_3_correlations`` is truthy.
    """
    info = summary_report['experiment_info']
    print("\nEXPERIMENT SUMMARY")
    print("="*50)
    print(f"Neural data shape: {info['neural_data_shape']}")
    print(f"Target data shape: {info['target_data_shape']}")

    top_correlations = summary_report['cca_results']['top_3_correlations']
    if top_correlations:
        print(f"Top 3 CCA correlations: {top_correlations}")

    ml_summary = summary_report['ml_results_summary']
    print(f"Successfully decoded targets: {ml_summary['successful_targets']}")

    print("\nBest model performance for each target:")
    for target, perf in ml_summary['best_performances'].items():
        print(f"  {target}: {perf['best_model']} (score: {perf['best_score']:.4f})")

def save_experiment_results(decoder, successful_decodings: Dict, monkey: str, session: str, 
                          base_filename: str = None):
    """Write the detailed results and a JSON summary for one run.

    The detailed results are written by ``decoder.save_results``; the summary
    is built, printed, and dumped as indented JSON.

    Args:
        decoder: Decoder object; must provide ``save_results(path)``.
        successful_decodings: Mapping ``target -> {model_name -> results}``.
        monkey: Monkey label, used in the default file name.
        session: Session label, used in the default file name.
        base_filename: File stem; defaults to
            ``target_decoding_results_{monkey}_{session}``.

    Returns:
        Tuple ``(pkl_filename, json_filename)``.
    """
    stem = base_filename
    if stem is None:
        stem = f"target_decoding_results_{monkey}_{session}"
    pkl_filename, json_filename = f"{stem}.pkl", f"{stem}_summary.json"

    # Detailed results go through the decoder's own writer.
    print("Saving results...")
    decoder.save_results(pkl_filename)

    # The summary is echoed to stdout and then stored next to the detailed file.
    summary_report = create_summary_report(decoder, successful_decodings, monkey, session)
    print_summary_report(summary_report)
    with open(json_filename, 'w') as json_file:
        json.dump(summary_report, json_file, indent=2)

    print(f"\nResults saved to: {pkl_filename}")
    print(f"Summary saved to: {json_filename}")

    return pkl_filename, json_filename

def load_experiment_results(base_filename: str = None, monkey: str = None, session: str = None):
    """Load the pickled detailed results and the JSON summary of a run.

    Args:
        base_filename: File stem. When omitted it is built from ``monkey``
            and ``session`` as ``target_decoding_results_{monkey}_{session}``.
        monkey: Monkey label (used only when ``base_filename`` is None).
        session: Session label (used only when ``base_filename`` is None).

    Returns:
        ``(decoder_results, summary_report)``, or ``(None, None)`` if either
        file is missing or cannot be read (a message is printed).

    Raises:
        ValueError: If neither ``base_filename`` nor both ``monkey`` and
            ``session`` are given.
    """
    if base_filename is None:
        if not (monkey and session):
            raise ValueError("Must provide either base_filename or both monkey and session")
        base_filename = f"target_decoding_results_{monkey}_{session}"

    pkl_filename, json_filename = f"{base_filename}.pkl", f"{base_filename}_summary.json"

    try:
        # Detailed results (only unpickle files you produced yourself)
        with open(pkl_filename, 'rb') as pkl_file:
            decoder_results = pickle.load(pkl_file)

        # Summary report
        with open(json_filename, 'r') as json_file:
            summary_report = json.load(json_file)
    except FileNotFoundError as e:
        print(f"File not found: {e}")
        return None, None
    except Exception as e:
        print(f"Error loading results: {e}")
        return None, None
    else:
        print(f"Loaded results from: {pkl_filename}")
        print(f"Loaded summary from: {json_filename}")
        return decoder_results, summary_report

# --- Usage Examples ---

# Saving (replaces your original code):
# save_experiment_results(decoder, successful_decodings, 'Bruno', 'data_0328')

# Loading:
# decoder_results, summary_report = load_experiment_results(monkey='Bruno', session='data_0328')
# OR
# decoder_results, summary_report = load_experiment_results(base_filename="target_decoding_results_bruno_0328")

# If you want to print the loaded summary:
# if summary_report:
#     print_summary_report(summary_report)

## save

In [None]:
# Save everything with one function call
# NOTE(review): the labels below say Bruno / data_0328, but raw_data_folder_path in the
# "Get data" cell currently points at monkey_Schro/data_0321 — confirm the labels
# match the session that was actually decoded before saving.
save_experiment_results(decoder, successful_decodings, 'Bruno', 'data_0328')

## retrieve

In [None]:

# Load everything with one function call
decoder_results, summary_report = load_experiment_results(monkey='Bruno', session='data_0328')
## OR
# decoder_results, summary_report = load_experiment_results(base_filename="target_decoding_results_bruno_0328")

# If you want to print the loaded summary:
if summary_report:
    print_summary_report(summary_report)

# Access successful_decodings
if decoder_results and 'successful_decodings' in decoder_results:
    successful_decodings = decoder_results['successful_decodings']
    # Use with your model comparison functions

# Other thoughts

## more columns (possibly get in the future)

Could also get the following (but to be honest, it doesn't make much sense to get them, so let's skip for now):
'distance traversed since target last visible',
'd angle since target last visible', 'target_at_right',
'time_till_capture', 'time from last visible to capture'

Note that there might be multicollinearity. For example, duration from last visible to capture = time since target last visible + time till capture

Similarly, target angle = target angle last seen frozen - d angle since target last visible

(For distance it's not exactly the same because of the difference between distance and distance traversed, but it's still similar)

The multicollinearity is fine in linear regression (where each feature here is a y variable), but it needs to be dealt with in CCA.

## possible things to try

Should I actually align each section, as if the sections were trials?
Maybe I can try both that and continuous time; the two approaches can shed light on different behavioral variables.
Treating sections as trials may require alignment or time warping, though, since trial durations vary.

btw, what does it mean to stitch data?

also, what does it look like to use RNN to model it?
I thought about the paper that Noah presented on RNN


btw.......IME

# Misc

## plot trial segments

In [None]:
## plot trial segments in pursuit_data
# The stray trailing comma after `plot_trials` is removed: without parentheses it
# is a SyntaxError, so this cell could never run.
from visualization.matplotlib_tools import plot_trials
dec.make_PlotTrials_args()
plt.rcParams['figure.figsize'] = [10, 10]

max_plot_to_make = 2
plot_counter = 0

# One plot per target, from the row's last_vis_time to its ff_caught_time;
# the first two rows are skipped and plotting stops after max_plot_to_make figures.
for index, row in dec.single_vis_target_df.iloc[2:].iterrows():

    duration = [row['last_vis_time'], row['ff_caught_time']]

    returned_info = plot_trials.PlotTrials(
                duration,
                *dec.PlotTrials_args,
                adjust_xy_limits=True,
                minimal_margin=50,
                show_reward_boundary=True,
                show_alive_fireflies=False,
                show_visible_fireflies=True,
                show_in_memory_fireflies=True,
                show_believed_target_positions=True,
                )
    plt.show()

    plot_counter += 1
    if plot_counter >= max_plot_to_make:
        break

# exp

In [None]:
## what the 

In [None]:
dec.behav_data_by_bin.head(3)

In [None]:
import numpy as np
import pandas as pd

# Custom function for mode (returns first mode if multiple)
def get_mode(x):
    """Return the first entry of ``x.mode()``, or ``pd.NA`` when it is empty."""
    modes = x.mode()
    if modes.empty:
        return pd.NA
    return modes.iloc[0]

# Define strict median function
def strict_median(series, method='lower'):
    """Return a median that is always one of the observed values.

    Missing values are dropped first. For an odd count the middle value is
    returned; for an even count the lower or upper of the two middle values
    is returned, depending on ``method`` (no averaging).

    Args:
        series: pandas Series to summarise.
        method: ``'lower'`` or ``'upper'``; which middle value to take when
            the count is even.

    Returns:
        The selected value, or ``np.nan`` if nothing remains after dropping NaNs.

    Raises:
        ValueError: If ``method`` is not ``'lower'`` or ``'upper'``.
    """
    # Validate up front: previously a bad `method` was only rejected for
    # even-length input, so a typo went unnoticed on odd-length or empty groups.
    if method not in ('lower', 'upper'):
        raise ValueError("method must be 'lower' or 'upper'")

    sorted_vals = np.sort(series.dropna().values)
    n = len(sorted_vals)
    if n == 0:
        return np.nan
    if n % 2 == 1 or method == 'upper':
        return sorted_vals[n // 2]
    return sorted_vals[n // 2 - 1]

# Define column groups
col_max = ['target_visible_dummy', 'target_cluster_visible_dummy', 'capture_ff',
           'num_visible_ff', 'any_ff_visible', 'catching_ff']
col_strict_median = ['point_index', 'valid_view_point']

# Combine aggregation functions
agg_funcs = {col: 'max' for col in col_max}
agg_funcs.update({col: strict_median for col in col_strict_median})

# Perform groupby
# NOTE(review): `df` is not defined in this section — presumably the per-point
# behavioral table; confirm which frame is meant.
result = df.groupby('bin').agg(agg_funcs).reset_index()

# Make sure the columns that are merged back below are not already present.
# After a dict-based agg, `result` only holds 'bin' plus the aggregated columns,
# so an unconditional drop raised KeyError; errors='ignore' makes this a no-op
# in that case.
result = result.drop(
    columns=[
        'target_index',
        'target_has_disappeared_for_last_time_dummy',
        'target_cluster_has_disappeared_for_last_time_dummy',
    ],
    errors='ignore',
)

# Merge back the columns that are looked up via each bin's representative point_index
result = result.merge(
    df[['point_index', 'target_index', 'target_has_disappeared_for_last_time_dummy',
        'target_cluster_has_disappeared_for_last_time_dummy']],
    on='point_index',
    how='left'
)

In [None]:
# through merge
'target_index', 'target_has_disappeared_for_last_time_dummy', 'target_cluster_has_disappeared_for_last_time_dummy'

In [None]:
# max
col_max = ['target_visible_dummy', 'target_cluster_visible_dummy', 'capture_ff',
           'num_visible_ff', 'any_ff_visible', 'catching_ff']

# strict median
col_strict_median = ['point_index', 'valid_view_point']

# A dict literal cannot hold two comprehensions side by side (that was a
# SyntaxError), so build one dict per column group and merge them. The strict
# median is passed as the function object: the string 'strict_median' is not an
# aggregation name pandas knows.
agg_funcs = {
    **{col: 'max' for col in col_max},
    **{col: strict_median for col in col_strict_median},
}


In [None]:
# Custom function for mode (returns first mode if multiple)
def get_mode(x):
    """Return the first entry of ``x.mode()``, or ``pd.NA`` when it is empty."""
    candidates = x.mode()
    return pd.NA if candidates.empty else candidates.iloc[0]

def strict_median(series, method='lower'):
    """Return a median that is always one of the observed values.

    NaNs are dropped, the remaining values are sorted, and the middle value is
    returned. With an even count, ``method`` picks the lower or the upper of
    the two middle values instead of averaging them.

    Args:
        series: pandas Series to summarise.
        method: ``'lower'`` or ``'upper'``.

    Returns:
        The selected value, or ``np.nan`` when no non-missing values remain.

    Raises:
        ValueError: If ``method`` is not ``'lower'`` or ``'upper'``.
    """
    # Reject a bad `method` immediately; it used to be checked only on the
    # even-length path, so invalid values passed silently for other inputs.
    if method not in ('lower', 'upper'):
        raise ValueError("method must be 'lower' or 'upper'")

    sorted_vals = np.sort(series.dropna().values)
    n = len(sorted_vals)
    if n == 0:
        return np.nan
    # Odd count: single middle element. Even count: n//2 is the upper middle,
    # n//2 - 1 the lower middle.
    middle = n // 2
    if n % 2 == 0 and method == 'lower':
        middle -= 1
    return sorted_vals[middle]
        
# Specify aggregations
agg_funcs = {

}

# Every column without an explicit aggregation falls back to the median.
# The grouping key 'bin' must be left out as well: only 'group' was excluded
# before, so 'bin' ended up in agg_funcs even though it is the groupby key that
# reset_index() turns back into a column.
group_key = 'bin'
excluded_cols = {'group', group_key} | set(agg_funcs)
remaining_cols = [col for col in df.columns if col not in excluded_cols]
agg_funcs.update({col: 'median' for col in remaining_cols})

# Perform groupby
result = df.groupby(group_key).agg(agg_funcs).reset_index()


In [None]:
dec.behav_data_by_bin.groupby('target_index').count()

In [None]:
dec.ff_caught_T_new.shape