# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
# Walk from the current working directory up through its parents until a
# directory named 'Multifirefly-Project' is found, then make it the working
# directory and put the project's `methods` folder at the front of sys.path
# so the project-local imports below resolve.
# NOTE(review): if no such ancestor exists the loop ends silently and the
# imports below will presumably fail with ModuleNotFoundError.
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break
    
from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils, ml_methods_utils
from planning_analysis.show_planning import nxt_ff_utils, show_planning_utils
from neural_data_analysis.neural_analysis_tools.gpfa_methods import elephant_utils, fit_gpfa_utils, plot_gpfa_utils, gpfa_helper_class
from neural_data_analysis.neural_analysis_tools.align_trials import time_resolved_regression, time_resolved_gpfa_regression,plot_time_resolved_regression
from neural_data_analysis.neural_analysis_tools.align_trials import align_trial_utils
from decision_making_analysis.compare_add_features_GUAT_and_TAFT import find_GUAT_or_TAFT_trials

from neural_data_analysis.topic_based_neural_analysis.stop_event_analysis import stops_psth, extract_stops_utils, psth_postprocessing

import sys
import math
import gc
import subprocess
from pathlib import Path

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
import numpy as np
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util

# --- matplotlib ---------------------------------------------------------
# Reset to matplotlib's defaults FIRST. Previously this reset ran after the
# animation settings and silently reverted them.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# 'animation.html' used to be set twice ("html5", then 'jshtml'); only the
# later value could ever take effect, so keep just that one.
rc('animation', html='jshtml')
matplotlib.rcParams['animation.embed_limit'] = 2**128

# --- environment flags --------------------------------------------------
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

# --- pandas / numpy display --------------------------------------------
pd.set_option('display.float_format', lambda x: '%.5f' % x)
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)
np.set_printoptions(suppress=True)

print("done")


%load_ext autoreload
%autoreload 2

# retrieve data

In [None]:
# Session selection for monkey Schro: exactly one path should be active.
# NOTE: the following cell assigns `raw_data_folder_path` again (monkey Bruno),
# so when both cells are run in order that later value is the one used.
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0416"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0321"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0329"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0403"

In [None]:
# Session selection for monkey Bruno: exactly one path should be active.
# NOTE: running this cell overrides the Schro path chosen in the previous cell.
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0312"
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0316"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0327"
# raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0328"

In [None]:
# Flags for the data-preparation calls below.
# NOTE: `reduce_y_var_lags` and `y_data_exists_ok` are only referenced by the
# commented-out `get_x_and_y_data_for_modeling` call further down.
reduce_y_var_lags = False
planning_data_by_point_exists_ok = True
y_data_exists_ok = True

# Build the event-aligned planning/neural object for the selected session.
pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(raw_data_folder_path=raw_data_folder_path)
pn.prep_data_to_analyze_planning(planning_data_by_point_exists_ok=planning_data_by_point_exists_ok)
# pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
#     pn.planning_data_by_point)
#pn.get_x_and_y_data_for_modeling(exists_ok=y_data_exists_ok, reduce_y_var_lags=reduce_y_var_lags)

# Only build the spikes table if the object does not already carry one.
if not hasattr(pn, 'spikes_df'):
    pn.retrieve_or_make_monkey_data()
    pn.spikes_df = neural_data_processing.make_spikes_df(pn.raw_data_folder_path, pn.ff_caught_T_sorted,
                                                            sampling_rate=pn.sampling_rate)

# Get captures

In [None]:
# Split stops into capture stops and no-capture stops using the thresholds
# passed below; the call returns three tables (captures, filtered no-capture
# stops, and a stops table with stats).
# NOTE(review): units of the duration/window/distance thresholds are not
# visible here — presumably seconds and cm; confirm in extract_stops_utils.
valid_captures_df, filtered_no_capture_stops_df, stops_with_stats = extract_stops_utils.prepare_no_capture_and_captures(
    monkey_information=pn.monkey_information,
    closest_stop_to_capture_df=pn.closest_stop_to_capture_df,
    ff_caught_T_new=pn.ff_caught_T_new,
    min_stop_duration=0.02,
    max_stop_duration=1.0,
    capture_match_window=0.3,
    distance_thresh=25.0,
    distance_col="distance_from_ff_to_stop",
)


In [None]:
valid_captures_df

In [None]:
filtered_no_capture_stops_df

# Get misses

## One-stop misses

In [None]:
# Build the one-stop table on `pn`, then keep only the columns needed
# downstream, exposing the stop index/time under the generic names
# 'stop_point_index' and 'stop_time'.
pn.make_one_stop_w_ff_df()
one_stop_miss_df = (
    pn.one_stop_w_ff_df[
        [
            'first_stop_point_index',
            'first_stop_time',
            'latest_visible_ff',
            'ff_distance',
            'min_distance_from_adjacent_stops',
        ]
    ]
    .copy()
    .rename(
        columns={
            'first_stop_point_index': 'stop_point_index',
            'first_stop_time': 'stop_time',
        }
    )
)

## Multi-stop misses (GUAT / TAFT)

In [None]:
# Run the "try a few times" (TAFT) and "give up after trying" (GUAT) steps on
# `pn`; the cells below read `pn.TAFT_trials_df` and `pn.GUAT_trials_df`,
# which are presumably populated by these two calls.
pn.get_try_a_few_times_info()
pn.get_give_up_after_trying_info()


In [None]:
import pandas as pd


# Turn the per-trial tables (one row per trial, with a list-like
# 'stop_indices' column) into per-stop tables: one row per stop.
GUAT_expanded = pn.GUAT_trials_df.explode("stop_indices").reset_index(drop=True)

# Use the generic column name and look up each stop's time by label in
# pn.monkey_information (assumes its index labels are point indices — TODO confirm).
GUAT_expanded = GUAT_expanded.rename(columns={"stop_indices": "stop_point_index"})
GUAT_expanded['stop_time'] = pn.monkey_information['time'].loc[GUAT_expanded['stop_point_index']].values

TAFT_expanded = pn.TAFT_trials_df.explode("stop_indices").reset_index(drop=True)

# Same renaming and time lookup for TAFT.
TAFT_expanded = TAFT_expanded.rename(columns={"stop_indices": "stop_point_index"})
TAFT_expanded['stop_time'] = pn.monkey_information['time'].loc[TAFT_expanded['stop_point_index']].values


# Group TAFT_expanded by stop_cluster_id and drop the last row of each group.
# The in-place sort makes "last" mean the stop with the largest point index.
# NOTE(review): GUAT_expanded is not sorted here, so its within-cluster order
# is whatever order 'stop_indices' had in each trial row.
TAFT_expanded.sort_values('stop_point_index', inplace=True)
TAFT_expanded2 = TAFT_expanded.groupby('stop_cluster_id').apply(lambda x: x.iloc[:-1]).reset_index(drop=True)


In [None]:
# Only preserve the first stop of each stop cluster (one row per stop_cluster_id).
# NOTE(review): GroupBy.first() takes the first NON-NULL value of each column
# independently, so a returned row can mix values from different stops if any
# column contains NaN; confirm no NaNs are present or switch to a row-wise
# selection such as .head(1).
GUAT_expanded3 = GUAT_expanded.groupby('stop_cluster_id').first().reset_index()
TAFT_expanded3 = TAFT_expanded.groupby('stop_cluster_id').first().reset_index()



In [None]:
# For GUAT, separate the last stop of each stop cluster from the earlier ones:
#   *_first_several -> every row of a cluster except its last row
#   GUAT_expanded4_last -> one row per cluster built with GroupBy.last()
# NOTE(review): GroupBy.last() picks the last NON-NULL value per column, and
# "last" follows the current row order of GUAT_expanded (not sorted in this
# notebook, unlike TAFT_expanded) — confirm both are acceptable.
GUAT_expanded4_first_several = GUAT_expanded.groupby('stop_cluster_id').apply(lambda x: x.iloc[:-1]).reset_index(drop=True)
GUAT_expanded4_last = GUAT_expanded.groupby('stop_cluster_id').last().reset_index()

# Pool the non-final stops of GUAT and TAFT clusters into one table holding
# only the stop index and time, ordered by stop index.
TAFT_expanded4_first_several = TAFT_expanded.groupby('stop_cluster_id').apply(lambda x: x.iloc[:-1]).reset_index(drop=True)
both_first_several = pd.concat([GUAT_expanded4_first_several[['stop_point_index', 'stop_time']], TAFT_expanded4_first_several[['stop_point_index', 'stop_time']]])
both_first_several.sort_values('stop_point_index', inplace=True)
both_first_several.reset_index(drop=True, inplace=True)
both_first_several


# run class

In [None]:
reload(stops_psth)

In [None]:
# Alternative configuration with a wider (±1.0) window, kept for reference.
# cfg = stops_psth.PSTHConfig(
#     pre_window=1.0,
#     post_window=1.0,
#     bin_width=0.02,
#     smoothing_sigma=0.05,
#     min_trials=5,
#     normalize="zscore",            # try: None, "sub", or "div"
# )

# Active configuration: ±0.5 around each stop.
cfg = stops_psth.PSTHConfig(
    pre_window=0.5,
    post_window=0.5,
    bin_width=0.05,
    smoothing_sigma=0.1,
    min_trials=5,
    normalize="zscore",            # try: None, "sub", or "div"
)

# Alternative comparisons (kept for reference): choose which two stop tables
# are passed as `captures_df` / `no_capture_stops_df`.
# an = stops_psth.create_psth_around_stops(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg,
#                                                  captures_df=valid_captures_df,
#                                                  #no_capture_stops_df=filtered_no_capture_stops_df,
#                                                  no_capture_stops_df=one_stop_miss_df
#                                                  )

# an = stops_psth.create_psth_around_stops(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg,
#                                                  captures_df=TAFT_expanded2,
#                                                  no_capture_stops_df=GUAT_expanded
#                                                  )

# an = stops_psth.create_psth_around_stops(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg,
#                                                  captures_df=TAFT_expanded2,
#                                                  no_capture_stops_df=one_stop_miss_df
#                                                  )

# Active comparison: non-final stops of GUAT/TAFT clusters vs. the final stop
# of GUAT clusters.
# NOTE(review): these are passed through the `captures_df` /
# `no_capture_stops_df` slots, so any "capture"/"miss" wording in downstream
# plots and tables presumably refers to these two groups, not real captures.
an = stops_psth.create_psth_around_stops(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg,
                                                 captures_df=both_first_several,
                                                 no_capture_stops_df=GUAT_expanded4_last
                                                 )

# Per-cluster plots with bands
fig1 = an.plot_psth(cluster_idx=None, show_individual=False)

# Overlay comparison
fig2 = an.plot_comparison(cluster_idx=0)

plt.show()

# Stats in early post-stop window
stats_ = an.statistical_comparison(time_window=(0.0, 0.5))


# Export the PSTHs as DataFrames.
df = psth_postprocessing.export_psth_to_df(an)              # all clusters
df_c0 = psth_postprocessing.export_psth_to_df(an, [0])      # just the first cluster


# Named analysis windows as (start, end) relative to the stop.
# NOTE(review): "late_rebound" ends at 0.8 but the active cfg uses
# post_window=0.5, so this window extends past the PSTH range — confirm how
# compare_windows handles that, or widen the config / shorten the window.
windows = {
    "pre_bump(-0.3–0.0)": (-0.3, 0.0),
    "early_dip(0.0–0.3)": (0.0, 0.3),
    "late_rebound(0.3–0.8)": (0.3, 0.8),
}
summary = psth_postprocessing.compare_windows(an, windows, alpha=0.05)
# Show the 12 leading rows after ordering by window, then p-value.
summary.sort_values(["window","p"]).head(12)




In [None]:
# Keep only rows where sig_FDR is True (the column must be boolean for this
# mask to work).
sig_rows = summary[summary["sig_FDR"]]

# Plot effect sizes by epoch: one bar per (window, cluster) pair.
import seaborn as sns
import matplotlib.pyplot as plt

plt.figure(figsize=(8,5))
sns.barplot(data=sig_rows, x="window", y="cohens_d", hue="cluster", dodge=True)
plt.axhline(0, color="k", lw=1)  # zero-effect reference line
plt.ylabel("Cohen's d (capture − miss)")
plt.title("Significant neurons across epochs")
# Put the legend outside the axes, to the right of the plot.
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
plt.show()


# More plots

## Quickly plot PSTHs for the top significant neurons

In [None]:
import numpy as np

def plot_top_psths(analyzer, summary: pd.DataFrame, epoch: str, top_k=6):
    """Plot comparison PSTHs for the strongest significant clusters of one epoch.

    Parameters
    ----------
    analyzer
        Object exposing `clusters` (iterable of cluster ids) and
        `plot_comparison(cluster_idx=...)`.
    summary : DataFrame
        Must contain the columns 'window', 'sig_FDR' (boolean), 'cohens_d'
        and 'cluster'.
    epoch : str
        Value of 'window' to select.
    top_k : int
        Maximum number of clusters to plot, ranked by |cohens_d|.

    Returns
    -------
    None. Prints a message when nothing is significant or nothing matched.
    """
    is_selected = (summary["window"] == epoch) & (summary["sig_FDR"])
    ranked = summary[is_selected].copy()
    if ranked.empty:
        print(f"No significant clusters for {epoch}.")
        return
    ranked = ranked.sort_values(
        "cohens_d", key=lambda col: col.abs(), ascending=False
    ).head(top_k)

    # Cluster ids may be numeric on one side and strings on the other, so both
    # sides are compared as strings.
    analyzer_labels = np.array([str(cluster_id) for cluster_id in analyzer.clusters])

    n_plotted = 0
    for label in ranked["cluster"]:
        hits = np.flatnonzero(analyzer_labels == str(label))
        if hits.size:
            # the first matching position is the analyzer's cluster index
            analyzer.plot_comparison(cluster_idx=int(hits[0]))
            n_plotted += 1

    if not n_plotted:
        print("Nothing plotted (no matches).")

# usage
plot_top_psths(an, summary, "early_dip(0.0–0.3)", top_k=7)


## Heatmap of effect sizes

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_sig_heatmap(summary: pd.DataFrame, title="Significant effects (Cohen's d)"):
    """Draw a clusters x windows heatmap of Cohen's d for FDR-significant rows.

    `summary` must contain 'sig_FDR' (boolean), 'cluster', 'window' and
    'cohens_d'. Rows are ordered by their largest absolute effect and the
    colour scale is symmetric around zero. Prints a message and returns
    without plotting when no row is significant.
    """
    significant = summary[summary["sig_FDR"]].copy()
    if significant.empty:
        print("No significant results to plot.")
        return

    # clusters as rows, windows as columns, mean d per cell
    effects = significant.pivot_table(
        index="cluster", columns="window", values="cohens_d", aggfunc="mean"
    )

    # strongest absolute effect first
    peak_effect = effects.abs().max(axis=1).values
    effects = effects.iloc[np.argsort(-peak_effect)]

    n_rows, n_cols = effects.shape
    colour_limit = np.nanmax(abs(effects.values))

    fig, ax = plt.subplots(figsize=(8, max(3, 0.35 * n_rows)))
    image = ax.imshow(
        effects.values,
        aspect="auto",
        cmap="coolwarm",
        vmin=-colour_limit,
        vmax=colour_limit,
    )
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(effects.columns, rotation=30, ha="right")
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(effects.index)
    ax.set_title(title)
    colourbar = plt.colorbar(image, ax=ax)
    colourbar.set_label("Cohen's d (capture − miss)")
    plt.tight_layout()
    plt.show()

# usage
plot_sig_heatmap(summary)


## Bar chart of significant effects per epoch (one bar per cluster)

In [None]:
def plot_sig_bars(summary: pd.DataFrame, epoch: str):
    """Horizontal bar chart of Cohen's d for the significant clusters of one epoch.

    `summary` must contain 'window', 'sig_FDR' (boolean), 'cohens_d' and
    'cluster'. Bars are ordered by |cohens_d|, largest first. Prints a message
    and returns without plotting when the epoch has no significant cluster.
    """
    keep = (summary["window"] == epoch) & (summary["sig_FDR"])
    selected = summary[keep].copy()
    if selected.empty:
        print(f"No significant clusters for {epoch}.")
        return
    selected = selected.sort_values(
        "cohens_d", key=lambda col: col.abs(), ascending=False
    )

    # figure height grows with the number of bars
    fig_height = max(3, 0.35 * len(selected))
    fig, ax = plt.subplots(figsize=(10, fig_height))
    ax.barh(selected["cluster"], selected["cohens_d"])
    ax.axvline(0, color="k", lw=1, alpha=0.5)
    ax.set_xlabel("Cohen's d (capture − miss)")
    ax.set_ylabel("Cluster")
    ax.set_title(f"Significant clusters in {epoch}")
    plt.tight_layout()
    plt.show()

# usage: one bar chart per analysis window, in chronological order
for epoch_label in (
    "pre_bump(-0.3–0.0)",
    "early_dip(0.0–0.3)",
    "late_rebound(0.3–0.8)",
):
    plot_sig_bars(summary, epoch_label)


# Appendix

## Validate near-miss single stops

In [None]:
# Check whether "near-miss" stops (one_stop_w_ff_df) are truly not part of a stop cluster.
# A stop cluster is defined as ≥ 2 stops where each consecutive stop is within 50 cm (cumulative distance).
# Expected outcome if the one-stop table is clean: the last expression shows an empty frame.

# --- Step 1: Create one-stop dataframe and assign cluster IDs
pn.make_one_stop_w_ff_df()
pn.monkey_information = find_GUAT_or_TAFT_trials.add_stop_cluster_id(pn.monkey_information)

# --- Step 2: Build stop-cluster summary
# One row per distinct stop (rows flagged by 'whether_new_distinct_stop'),
# annotated with how many such stops share its cluster id.
stop_cluster_df = (
    pn.monkey_information.loc[pn.monkey_information['whether_new_distinct_stop'], ['point_index', 'stop_cluster_id']]
    .copy()
)
stop_cluster_df['num_stops_in_cluster'] = (
    stop_cluster_df.groupby('stop_cluster_id')['point_index'].transform('count')
)

# --- Step 3: Merge cluster info into one-stop dataframe (if not already present)
# NOTE(review): if 'stop_cluster_id' is already a column the merge is skipped,
# and Step 4 then relies on 'num_stops_in_cluster' also being present already.
if 'stop_cluster_id' not in pn.one_stop_w_ff_df.columns:
    pn.one_stop_w_ff_df = pn.one_stop_w_ff_df.merge(
        stop_cluster_df.rename(columns={'point_index': 'first_stop_point_index'}),
        on='first_stop_point_index',
        how='left'
    )

# --- Step 4: Inspect any one-stop rows that actually fall in a multi-stop cluster
pn.one_stop_w_ff_df[pn.one_stop_w_ff_df['num_stops_in_cluster'] > 1]


## Check dt between stops in clusters

In [None]:
import pandas as pd


# Rebuild the per-stop GUAT/TAFT tables (same steps as the "Get misses" cell)
# so this appendix section can be run on its own.
# Each trial row carries a list-like 'stop_indices'; explode gives one row per stop.
GUAT_expanded = pn.GUAT_trials_df.explode("stop_indices").reset_index(drop=True)

# Use the generic column name and look up each stop's time by label in
# pn.monkey_information (assumes its index labels are point indices — TODO confirm).
GUAT_expanded = GUAT_expanded.rename(columns={"stop_indices": "stop_point_index"})
GUAT_expanded['stop_time'] = pn.monkey_information['time'].loc[GUAT_expanded['stop_point_index']].values

TAFT_expanded = pn.TAFT_trials_df.explode("stop_indices").reset_index(drop=True)

# Same renaming and time lookup for TAFT.
TAFT_expanded = TAFT_expanded.rename(columns={"stop_indices": "stop_point_index"})
TAFT_expanded['stop_time'] = pn.monkey_information['time'].loc[TAFT_expanded['stop_point_index']].values


# Group TAFT_expanded by stop_cluster_id and drop the last row of each group.
# NOTE(review): only TAFT_expanded is sorted; GUAT_expanded keeps the order
# produced by explode.
TAFT_expanded.sort_values('stop_point_index', inplace=True)
TAFT_expanded2 = TAFT_expanded.groupby('stop_cluster_id').apply(lambda x: x.iloc[:-1]).reset_index(drop=True)


In [None]:
# Time gap between consecutive stops WITHIN the same stop cluster.
# A plain `.diff()` over the whole table also measures the gap between the
# last stop of one cluster and the first stop of the next, and on a table that
# is not time-ordered it can produce negative gaps that pass the `< 0.5`
# filter below. Sorting by time and differencing per cluster avoids both; the
# result is aligned back to each table by index, so row order is untouched.
# The first stop of every cluster gets NaN and is excluded by the filter.
GUAT_expanded['dt'] = (
    GUAT_expanded.sort_values('stop_time')
    .groupby('stop_cluster_id')['stop_time']
    .diff()
)
TAFT_expanded['dt'] = (
    TAFT_expanded.sort_values('stop_time')
    .groupby('stop_cluster_id')['stop_time']
    .diff()
)

# Keep only gaps shorter than 0.5 (same units as 'stop_time').
GUAT_sub = GUAT_expanded[GUAT_expanded['dt'] < 0.5]
TAFT_sub = TAFT_expanded[TAFT_expanded['dt'] < 0.5]
GUAT_sub


In [None]:
sns.histplot(GUAT_sub[['dt']])
plt.show()

In [None]:
sns.histplot(TAFT_sub[['dt']])
plt.show()

In [None]:
GUAT_sub[['dt']].describe()

In [None]:
TAFT_sub[['dt']].describe()

## check inter-stop intervals

In [None]:
# One table of unique stops; its 'time' column is used as the stop-onset times.
unique_stops_df = extract_stops_utils.extract_unique_stops(pn.monkey_information)
onsets = unique_stops_df['time'].to_numpy()

In [None]:
# Plot the inter-stop intervals; the helper returns a dict from which the
# axes stored under 'ax1' is taken for further annotation.
returned = extract_stops_utils.plot_inter_stop_intervals(onsets)
ax1 = returned['ax1']

# add an additional vertical line to the linear plot
additional_vline = 0.2
ax1.axvline(additional_vline, linestyle="--", color="b", alpha=0.8, label=f"x = {additional_vline}s")
ax1.legend()  # re-draw the legend so the new line's label is included
plt.show()

## make stop clusters

In [None]:
import numpy as np
import pandas as pd

def make_stop_clusters(
    stops_df: pd.DataFrame,
    ff_caught_T_new: np.ndarray | None = None,
    capture_match_window: float = 0.3,
    isi_threshold: float = 0.5,
    time_col: str = "stop_time",
    event_type_col: str | None = "event_type",  # set to None if not available
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Group temporally close stops into clusters and label cluster outcome.

    Parameters
    ----------
    stops_df : DataFrame
        Must contain column `time_col` (seconds). Optionally `event_type_col`
        with values {"capture","miss"} at the stop level.
    ff_caught_T_new : array-like or None
        Sorted capture times (s). Required if `event_type_col` is None.
    capture_match_window : float
        A stop is 'capture' if within this |Δt| to a capture time (when inferring).
    isi_threshold : float
        Two consecutive stops belong to the same cluster if their time gap <= threshold.
    time_col : str
        Column name for stop times.
    event_type_col : str | None
        Column name with stop labels. If None, labels will be inferred.

    Returns
    -------
    clusters_df : DataFrame
        One row per cluster:
          ['cluster_id','start','end','duration','n_stops','outcome',
           'stop_indices','stop_times']
        outcome ∈ {"success","giveup"} where "success" = any capture stop in cluster.
    stops_with_cluster : DataFrame
        Original stops annotated with:
          ['cluster_id','pos_in_cluster','is_capture','cluster_outcome']
    """
    if time_col not in stops_df.columns:
        raise ValueError(f"`stops_df` must contain '{time_col}'")

    df = stops_df.sort_values(time_col).reset_index(drop=False).rename(columns={"index":"orig_index"}).copy()
    t = df[time_col].to_numpy().astype(float)
    n = len(df)

    # --- determine per-stop capture flag ---
    if event_type_col is not None and event_type_col in df.columns:
        is_capture = (df[event_type_col].to_numpy() == "capture")
    else:
        if ff_caught_T_new is None:
            raise ValueError("Provide ff_caught_T_new or an event_type_col with 'capture'/'miss'.")
        cap = np.asarray(ff_caught_T_new, dtype=float)
        # vectorized nearest-neighbor distance in time
        idx = np.searchsorted(cap, t, side="left")
        left_dt  = np.where(idx > 0,            np.abs(t - cap[np.clip(idx-1,0,len(cap)-1)]), np.inf)
        right_dt = np.where(idx < cap.size,     np.abs(cap[idx] - t),                          np.inf)
        min_dt = np.minimum(left_dt, right_dt)
        is_capture = (min_dt <= capture_match_window)

    # --- build clusters by ISI threshold ---
    cluster_id = np.empty(n, dtype=int)
    pos_in_cluster = np.empty(n, dtype=int)
    cid = 0
    start_idx = 0
    for i in range(n):
        if i == 0:
            cluster_id[i] = cid
            pos_in_cluster[i] = 1
            continue
        gap = t[i] - t[i-1]
        if gap <= isi_threshold:
            # same cluster
            cluster_id[i] = cid
            pos_in_cluster[i] = pos_in_cluster[i-1] + 1
        else:
            # new cluster
            cid += 1
            cluster_id[i] = cid
            pos_in_cluster[i] = 1

    df["cluster_id"] = cluster_id
    df["pos_in_cluster"] = pos_in_cluster
    df["is_capture"] = is_capture

    # --- cluster-level aggregation ---
    grp = df.groupby("cluster_id", sort=True)
    start = grp[time_col].min()
    end = grp[time_col].max()
    duration = end - start
    n_stops = grp.size()
    any_capture = grp["is_capture"].any()
    outcome = np.where(any_capture, "success", "giveup")

    # collect member indices/times for reference
    members_idx = grp["orig_index"].apply(lambda x: x.to_list())
    members_times = grp[time_col].apply(lambda x: x.to_list())

    clusters_df = pd.DataFrame({
        "cluster_id": start.index,
        "start": start.values.astype(float),
        "end": end.values.astype(float),
        "duration": duration.values.astype(float),
        "n_stops": n_stops.values.astype(int),
        "outcome": outcome.astype(str),
        "stop_indices": members_idx.values,
        "stop_times": members_times.values,
    }).sort_values("start").reset_index(drop=True)

    # annotate stops with cluster outcome
    outcome_map = clusters_df.set_index("cluster_id")["outcome"].to_dict()
    df["cluster_outcome"] = df["cluster_id"].map(outcome_map)

    # return stops with original order preserved (plus annotations)
    stops_with_cluster = df.sort_values("orig_index").reset_index(drop=True)

    return clusters_df, stops_with_cluster


In [None]:
# The original call used the placeholders `my_stops_df` / `ff_caught_T_new`,
# which are not defined anywhere in this notebook. Use the unique-stops table
# built in the "check inter-stop intervals" section (stop onsets are in its
# 'time' column) and the capture times stored on `pn`.
clusters_df, stops_annot = make_stop_clusters(
    stops_df=unique_stops_df,              # one row per stop; times in 'time'
    ff_caught_T_new=pn.ff_caught_T_new,    # used to infer capture vs. miss
    capture_match_window=0.3,
    isi_threshold=0.5,                      # tune: 0.3–0.6s works well
    time_col="time",
    event_type_col=None                     # or "event_type" if present
)

print(clusters_df.head())
# cluster_id | start | end | duration | n_stops | outcome | stop_indices | stop_times

print(stops_annot.head())
# ... original columns ... + cluster_id | pos_in_cluster | is_capture | cluster_outcome


## censor_mask_for_event

def censor_mask_for_event(t0, all_stops, time_axis, pad=0.15):
    """Return a boolean mask of the time bins to keep around one event.

    A bin at absolute time ``t0 + time_axis[i]`` is kept when it is at least
    `pad` away from every time in `all_stops`, or when it coincides with `t0`
    itself.

    Parameters
    ----------
    t0 : float
        Absolute time of the aligned event.
    all_stops : array-like
        Absolute times of all stops; may be empty (then every bin is kept).
    time_axis : array-like
        1-D bin times relative to `t0`.
    pad : float
        Minimum allowed distance between a kept bin and any stop.

    Returns
    -------
    keep : ndarray of bool, shape (len(time_axis),)
    """
    abs_times = t0 + np.asarray(time_axis, dtype=float)
    stop_times = np.asarray(all_stops, dtype=float).ravel()
    if stop_times.size == 0:
        # nothing to censor against; np.min over an empty axis would raise
        return np.ones(abs_times.shape, dtype=bool)

    # distance from every bin to its nearest stop
    dmin = np.min(np.abs(abs_times[:, None] - stop_times[None, :]), axis=1)
    keep = (dmin >= pad) | np.isclose(abs_times, t0)
    return keep  # bool, shape (len(time_axis),)