# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

import os, sys, sys
from pathlib import Path
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break
    
from data_wrangling import specific_utils, process_monkey_information, general_utils
from pattern_discovery import pattern_by_trials, pattern_by_trials, cluster_analysis, organize_patterns_and_features
from visualization.matplotlib_tools import plot_behaviors_utils
from neural_data_analysis.neural_analysis_tools.get_neural_data import neural_data_processing
from neural_data_analysis.neural_analysis_tools.visualize_neural_data import plot_neural_data, plot_modeling_result
from neural_data_analysis.neural_analysis_tools.model_neural_data import transform_vars, neural_data_modeling, drop_high_corr_vars, drop_high_vif_vars
from neural_data_analysis.topic_based_neural_analysis.neural_vs_behavioral import prep_monkey_data, prep_target_data, neural_vs_behavioral_class
from neural_data_analysis.topic_based_neural_analysis.planning_and_neural import planning_and_neural_class, pn_utils, pn_helper_class, pn_aligned_by_seg, pn_aligned_by_event
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class
from neural_data_analysis.neural_analysis_tools.cca_methods import cca_class, cca_utils, cca_cv_utils
from neural_data_analysis.neural_analysis_tools.cca_methods.cca_plotting import cca_plotting, cca_plot_lag_vs_no_lag, cca_plot_cv
from machine_learning.ml_methods import regression_utils, regz_regression_utils, ml_methods_class, classification_utils, ml_plotting_utils, ml_methods_utils
from planning_analysis.show_planning import nxt_ff_utils, show_planning_utils
from neural_data_analysis.neural_analysis_tools.gpfa_methods import elephant_utils, fit_gpfa_utils, plot_gpfa_utils, gpfa_helper_class
from neural_data_analysis.neural_analysis_tools.align_trials import time_resolved_regression, time_resolved_gpfa_regression,plot_time_resolved_regression
from neural_data_analysis.neural_analysis_tools.align_trials import align_trial_utils

from neural_data_analysis.topic_based_neural_analysis.around_stops import psth_around_stops, stop_analysis_utils

import sys
import math
import gc
import subprocess
from pathlib import Path

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
from scipy import linalg, interpolate
from scipy.signal import fftconvolve
from scipy.io import loadmat
from scipy import sparse
import torch
from numpy import pi
import cProfile
import pstats

# Machine Learning imports
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.cross_decomposition import CCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.multivariate.cancorr import CanCorr

# Neuroscience specific imports
import neo
import rcca

# To fit gpfa
import numpy as np
from importlib import reload
from scipy.integrate import odeint
import quantities as pq
import neo
from elephant.spike_train_generation import inhomogeneous_poisson_process
from elephant.gpfa import GPFA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from elephant.gpfa import gpfa_core, gpfa_util

# Reset matplotlib to its defaults FIRST: the reset overwrites every rcParam,
# so running it after the customizations below would silently discard them.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Render animations inline as JavaScript HTML. (Setting "html5" right before
# this was immediately overwritten, so only the effective value is kept.)
rc('animation', html='jshtml')
# Allow very large embedded animations.
matplotlib.rcParams['animation.embed_limit'] = 2**128

# NOTE(review): OpenMP reads KMP_DUPLICATE_LIB_OK when its runtime loads;
# numpy/torch are already imported above, so this may be too late to take
# effect — consider setting it before those imports.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

# Display options for pandas / numpy output.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)
np.set_printoptions(suppress=True)

print("done")
# autoreload is already loaded and configured at the top of this cell.

# retrieve data

In [None]:
# Session for monkey Bruno. Other Bruno sessions used here previously:
# data_0312, data_0316, data_0327, data_0328 (edit the folder name to switch).
# NOTE: the next cell reassigns raw_data_folder_path to a Schro session, so
# run only one of the two cells.
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0330"

In [None]:
# Session for monkey Schro; overrides any Bruno path set in the previous cell.
# Other Schro sessions used here previously: data_0321, data_0329, data_0403.
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Schro/data_0416"

In [None]:
# Pipeline flags. Only planning_data_by_point_exists_ok is used in this cell;
# reduce_y_var_lags and y_data_exists_ok only feed the commented-out
# get_x_and_y_data_for_modeling call below. "exists_ok" presumably means
# "reuse previously saved data if present" — TODO confirm.
reduce_y_var_lags = False
planning_data_by_point_exists_ok = True
y_data_exists_ok = True

# Build the event-aligned planning/neural analysis object for the selected session.
pn = pn_aligned_by_event.PlanningAndNeuralEventAligned(raw_data_folder_path=raw_data_folder_path)
pn.prep_data_to_analyze_planning(planning_data_by_point_exists_ok=planning_data_by_point_exists_ok)
# pn.planning_data_by_point, cols_to_drop = general_utils.drop_columns_with_many_nans(
#     pn.planning_data_by_point)
#pn.get_x_and_y_data_for_modeling(exists_ok=y_data_exists_ok, reduce_y_var_lags=reduce_y_var_lags)

# Build the spike table only if it is not already attached to pn.
# retrieve_or_make_monkey_data runs first, presumably to populate
# pn.ff_caught_T_sorted, which make_spikes_df needs — TODO confirm.
if not hasattr(pn, 'spikes_df'):
    pn.retrieve_or_make_monkey_data()
    pn.spikes_df = neural_data_processing.make_spikes_df(pn.raw_data_folder_path, pn.ff_caught_T_sorted,
                                                            sampling_rate=pn.sampling_rate)

# Get misses

In [None]:
def filter_stops_based_on_distance_to_ff_capture(filtered_stops_df, monkey_information, ff_caught_T_new, min_cum_distance_to_ff_capture):
    """Drop stops that lie too close (in cumulative travelled distance) to a ff capture.

    Each capture time is mapped to a row of ``monkey_information`` with
    ``np.searchsorted`` on its 'time' column: the first sample at or after the
    capture time. Captures after the last sample use the last sample. That
    row's 'cum_distance' is taken as the capture's position along the path.

    Parameters
    ----------
    filtered_stops_df : pd.DataFrame
        Stops to filter. Must have a 'cum_distance' column. It is not modified.
    monkey_information : pd.DataFrame
        Must have 'time' and 'cum_distance' columns. 'time' must be sorted
        ascending, as ``searchsorted`` requires.
    ff_caught_T_new : array-like of float
        Capture times, in the same units as ``monkey_information['time']``.
    min_cum_distance_to_ff_capture : float
        Stops whose distance to the nearest capture is <= this value are removed.

    Returns
    -------
    pd.DataFrame
        A copy of the kept stops with an added 'distance_to_next_ff_capture'
        column. Despite its (kept for compatibility) name, this is the absolute
        distance to the NEAREST capture, before or after the stop. Capture
        positions that are NaN are ignored. With no usable capture, every stop
        is kept with distance inf.
    """
    # Copy so the caller's DataFrame is not modified. This also avoids a
    # SettingWithCopyWarning when a slice is passed in.
    filtered_stops_df = filtered_stops_df.copy()

    # Map each capture time to a sample index (first sample with time >= capture time).
    time_values = monkey_information['time'].to_numpy()
    capture_rows = np.searchsorted(time_values, np.asarray(ff_caught_T_new, dtype=float))
    # searchsorted returns len(time_values) for captures after the last sample.
    # That index is out of range, so fall back to the last sample.
    capture_rows = np.clip(capture_rows, 0, len(time_values) - 1)
    capture_cum = np.asarray(monkey_information['cum_distance'].to_numpy()[capture_rows], dtype=float)

    # Sorted, NaN-free capture positions allow a binary search for the nearest one.
    capture_cum = np.sort(capture_cum[~np.isnan(capture_cum)])
    stop_cum = filtered_stops_df['cum_distance'].to_numpy(dtype=float)

    if capture_cum.size == 0:
        # No usable capture: every stop is infinitely far from a capture.
        nearest = np.full(stop_cum.shape, np.inf)
    else:
        # The nearest capture is either the one just below or just above each stop.
        pos = np.searchsorted(capture_cum, stop_cum)
        last = capture_cum.size - 1
        below = capture_cum[np.clip(pos - 1, 0, last)]
        above = capture_cum[np.clip(pos, 0, last)]
        nearest = np.minimum(np.abs(stop_cum - below), np.abs(stop_cum - above))

    filtered_stops_df['distance_to_next_ff_capture'] = nearest
    # Keep only stops strictly farther than the threshold. NaN distances (stops
    # with NaN cum_distance) compare False and are dropped.
    keep = filtered_stops_df['distance_to_next_ff_capture'] > min_cum_distance_to_ff_capture
    return filtered_stops_df[keep].copy()

In [None]:
# Run the project's pattern detection on pn (results are presumably stored as
# attributes of pn — TODO confirm which ones).
pn.find_patterns()

In [None]:
# Build the "give up after trying" info on pn (project method).
pn.get_give_up_after_trying_info()

In [None]:
# NOTE: There can be multiple stop_point_index values for the same ff_index. This can happen when ... (TODO: finish this explanation.)

In [None]:
# Call the method. Without the parentheses the cell only displayed the bound
# method, unlike the sibling get_give_up_after_trying_info() cell.
pn.get_try_a_few_times_info()

In [None]:
# Build pn.one_stop_w_ff_df (project method; it presumably also relies on or
# creates pn.one_stop_df — TODO confirm).
pn.make_one_stop_w_ff_df()

In [None]:
# Inspect the single-stop table.
pn.one_stop_df

In [None]:
# (Same display as the previous cell.)
pn.one_stop_df

In [None]:
# Number of distinct point_index values in one_stop_df; compare with
# len(pn.one_stop_df) to see whether any point_index is repeated.
pn.one_stop_df[['point_index']].drop_duplicates().shape

In [None]:
# Inspect the single-stop-with-firefly table.
pn.one_stop_w_ff_df

In [None]:
# (Same display as the previous cell.)
pn.one_stop_w_ff_df

# Get data

In [None]:
# PSTH configuration: window before/after each stop, bin width and smoothing
# width (units presumably seconds — TODO confirm in PSTHConfig), minimum
# number of trials, and per-cluster normalization.
cfg = psth_around_stops.PSTHConfig(
    pre_window=1.0,
    post_window=1.0,
    bin_width=0.02,
    smoothing_sigma=0.05,
    min_trials=5,
    normalize="zscore",            # try: None, "sub", or "div"
)

# Analyzer using the module's own stop detection. No captures_df /
# no_capture_stops_df are passed here, unlike the "my method" cells below.
an = psth_around_stops.create_psth_around_stops(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg)

# Per-cluster plots with bands
fig1 = an.plot_psth(cluster_idx=None, show_individual=False)

# Overlay comparison
fig2 = an.plot_comparison(cluster_idx=0)

# Stats in early post-stop window
stats_ = an.statistical_comparison(time_window=(0.0, 0.5))


df = psth_around_stops.export_psth_to_df(an)              # all clusters
df_c0 = psth_around_stops.export_psth_to_df(an, [0])      # just the first cluster


# Named time windows relative to the stop (presumably seconds); the keys are
# display labels used in the summary's "window" column.
windows = {
    "pre_bump(-0.3–0.0)": (-0.3, 0.0),
    "early_dip(0.0–0.3)": (0.0, 0.3),
    "late_rebound(0.3–0.8)": (0.3, 0.8),
}
summary = psth_around_stops.compare_windows(an, windows, alpha=0.05)
# First 12 rows after sorting by window, then p. This is NOT the top 12 per window.
summary.sort_values(["window","p"]).head(12)


In [None]:
# Uses `an` built in the previous cell. Statistics for the same three windows
# as compare_windows; each result is indexed below by str(cluster id).
res_pre  = an.statistical_comparison(time_window=(-0.3, 0.0))
res_early= an.statistical_comparison(time_window=(0.0, 0.3))
res_late = an.statistical_comparison(time_window=(0.3, 0.8))

# Access Cluster 0 by its original ID or inspect keys:
print(list(res_pre.keys())[:5])     # cluster-id strings (e.g., '0','1','7',...)
print(res_pre[str(an.clusters[0])])


In [None]:
# Per-cluster flag: True if the cluster is FDR-significant in ANY window.
sig_any = summary.groupby("cluster")["sig_FDR"].any()  # index = cluster
clusters_with_signal = sig_any[sig_any].index          # clusters with any True

# Keep every row (all windows) of clusters that are significant in at least
# one window. Carry the per-cluster flag as a column, which is always True here.
summary_any = summary[summary["cluster"].isin(clusters_with_signal)].copy()
summary_any["sig_any"] = True

summary_any

In [None]:
# keep only rows where sig_FDR is True
sig_rows = summary[summary["sig_FDR"]]

# plot effect sizes by epoch
import seaborn as sns
import matplotlib.pyplot as plt

if sig_rows.empty:
    # Nothing to draw; an empty frame would give an empty (or failing) barplot/legend.
    print("No significant results to plot.")
else:
    plt.figure(figsize=(10,6))
    sns.barplot(data=sig_rows, x="window", y="cohens_d", hue="cluster", dodge=True)
    plt.axhline(0, color="k", lw=1)
    plt.ylabel("Cohen's d (capture − miss)")
    plt.title("Significant neurons across epochs")
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()
    plt.show()


## my method

In [None]:
# Duration bounds for no-capture stops (same units as monkey_information['time'],
# presumably seconds).
min_stop_duration = 0.02
max_stop_duration = 1
# Captures whose closest stop is farther than this from the caught firefly are
# dropped (same units as distance_from_ff_to_stop).
max_capture_distance_from_ff = 25

# get stop_id for pn.closest_stop_to_capture_df
# NOTE(review): .loc assumes monkey_information's index equals point_index — confirm.
pn.closest_stop_to_capture_df['stop_id'] = pn.monkey_information.loc[pn.closest_stop_to_capture_df['point_index'], 'stop_id'].values
captures_df = pn.closest_stop_to_capture_df[['cur_ff_index', 'stop_id', 'time', 'point_index', 'stop_time', 'distance_from_ff_to_stop']].copy()

stops_df = pn.monkey_information[pn.monkey_information['stop_id'].notna()].copy()

# Start/end time and duration of each stop_id.
stop_stats = stops_df.groupby("stop_id")["time"].agg(
    stop_id_start_time="min",
    stop_id_end_time="max"
)
stop_stats["stop_id_duration"] = (
    stop_stats["stop_id_end_time"] - stop_stats["stop_id_start_time"]
)

stops_df = stops_df.merge(stop_stats, on="stop_id", how="left").sort_values('point_index', ascending=True)

# One row per stop: its earliest sample. drop_duplicates keeps that whole row.
# groupby().first() would instead take the first NON-NULL value of each column
# independently, silently mixing values from different samples.
# NOTE: rows stay in point_index order rather than being re-sorted by stop_id.
stops_df = stops_df.drop_duplicates('stop_id', keep='first').reset_index(drop=True)
no_capture_stops_df = stops_df[~stops_df['stop_id'].isin(captures_df['stop_id'])].reset_index(drop=True)

# keep stops with min_stop_duration <= duration <= max_stop_duration
no_capture_stops_df = no_capture_stops_df[
    no_capture_stops_df['stop_id_duration'].between(min_stop_duration, max_stop_duration)
].reset_index(drop=True)

# drop captures whose closest stop is too far from the firefly
captures_df_cleaned = captures_df[captures_df['distance_from_ff_to_stop'] <= max_capture_distance_from_ff].copy()


In [None]:
# Window passed with the capture times to filter_no_capture_stops_vectorized.
# Presumably, no-capture stops within this many seconds of a capture are
# removed — TODO confirm against stop_analysis_utils.
capture_match_window = 0.3
no_capture_stops_df_filtered = stop_analysis_utils.filter_no_capture_stops_vectorized(no_capture_stops_df, pn.ff_caught_T_new, capture_match_window)
no_capture_stops_df_filtered

## run class

In [None]:
# Same PSTH configuration as the first run.
cfg = psth_around_stops.PSTHConfig(
    pre_window=1.0,
    post_window=1.0,
    bin_width=0.02,
    smoothing_sigma=0.05,
    min_trials=5,
    normalize="zscore",            # try: None, "sub", or "div"
)


# Analyzer fed with the capture / no-capture stop tables built in the cells above.
an = psth_around_stops.create_psth_around_stops(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg,
                                                 captures_df=captures_df_cleaned,
                                                 no_capture_stops_df=no_capture_stops_df_filtered)

# NOTE(review): called explicitly here but not in the first run; confirm
# whether create_psth_around_stops already does this.
an.identify_stop_events()

# Per-cluster plots with bands
fig1 = an.plot_psth(cluster_idx=None, show_individual=False)

# Overlay comparison
fig2 = an.plot_comparison(cluster_idx=0)

plt.show()

# Stats in early post-stop window
stats_ = an.statistical_comparison(time_window=(0.0, 0.5))


df = psth_around_stops.export_psth_to_df(an)              # all clusters
df_c0 = psth_around_stops.export_psth_to_df(an, [0])      # just the first cluster


# Named comparison windows (keys are display labels).
windows = {
    "pre_bump(-0.3–0.0)": (-0.3, 0.0),
    "early_dip(0.0–0.3)": (0.0, 0.3),
    "late_rebound(0.3–0.8)": (0.3, 0.8),
}
summary = psth_around_stops.compare_windows(an, windows, alpha=0.05)
# First 12 rows after sorting by window, then p (not the top 12 per window).
summary.sort_values(["window","p"]).head(12)


In [None]:
# Show any figures still pending from earlier cells.
plt.show()

In [None]:
# Recompute the window summary for the current `an` (same call as the previous run).
summary = psth_around_stops.compare_windows(an, windows, alpha=0.05)
summary.sort_values(["window","p"]).head(12)

## try 2

In [None]:
# Same PSTH configuration as the earlier runs.
cfg = psth_around_stops.PSTHConfig(
    pre_window=1.0,
    post_window=1.0,
    bin_width=0.02,
    smoothing_sigma=0.05,
    min_trials=5,
    normalize="zscore",            # try: None, "sub", or "div"
)

# Make sure your monkey_information has the stop_id system
from data_wrangling.process_monkey_information import add_more_columns_to_monkey_information

# Add stop_id system to your monkey_information.
# NOTE(review): captures_df_cleaned / no_capture_stops_df_filtered were built
# from pn.monkey_information BEFORE this call. If it changes stop_id, they may
# be stale. Also confirm this call is safe to repeat.
pn.monkey_information = add_more_columns_to_monkey_information(pn.monkey_information)

# # Now use PSTH analysis - it will use your stop_id system
# an = psth_around_stops.PSTHAnalyzer(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg,
#                                           captures_df=captures_df_cleaned,
#                                           no_capture_stops_df=no_capture_stops_df_filtered)


an = psth_around_stops.create_psth_around_stops(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg,
                                                 captures_df=captures_df_cleaned,
                                                 no_capture_stops_df=no_capture_stops_df_filtered)

an.identify_stop_events()

# Per-cluster plots with bands
fig1 = an.plot_psth(cluster_idx=None, show_individual=False)

# Overlay comparison
fig2 = an.plot_comparison(cluster_idx=0)

plt.show()

# Stats in early post-stop window
stats_ = an.statistical_comparison(time_window=(0.0, 0.5))


df = psth_around_stops.export_psth_to_df(an)              # all clusters
df_c0 = psth_around_stops.export_psth_to_df(an, [0])      # just the first cluster


# Named comparison windows (keys are display labels).
windows = {
    "pre_bump(-0.3–0.0)": (-0.3, 0.0),
    "early_dip(0.0–0.3)": (0.0, 0.3),
    "late_rebound(0.3–0.8)": (0.3, 0.8),
}
summary = psth_around_stops.compare_windows(an, windows, alpha=0.05)
# First 12 rows after sorting by window, then p (not the top 12 per window).
summary.sort_values(["window","p"]).head(12)


# chart

In [None]:
# Source file for the chart helpers used below: multiff_code/notebooks/neural_data_analysis/selection_comparison_chart.py

In [None]:
# Imports the module only. The chart cells below call its functions by bare
# name, which become available only after a `from ... import` cell runs.
import neural_data_analysis.selection_comparison_chart

In [None]:
# Show the current working directory (the first cell chdirs to the project root when found).
!pwd

In [None]:
    fig = create_selection_comparison_chart()
    plt.show()



In [None]:

    # Create and display the comparison table
    comparison_df = create_comparison_table()
    print("\nDetailed Selection Criteria Comparison:")
    print("=" * 80)
    print(comparison_df.to_string(index=False))


In [None]:
# Import the chart functions.
# NOTE(review): this uses the `multiff_analysis.` package prefix, while other
# cells import the same module as `neural_data_analysis.selection_comparison_chart`.
# Only one of these is likely importable given the sys.path set in the first
# cell — confirm and keep one.
from multiff_analysis.selection_comparison_chart import (
    create_selection_comparison_chart, create_comparison_table
)

# Create and display the chart
fig = create_selection_comparison_chart()
plt.show()

# Create and display the comparison table
comparison_df = create_comparison_table()
print(comparison_df.to_string(index=False))

In [None]:
# Import the functions (same module path as the plain `import` cell above).
from neural_data_analysis.selection_comparison_chart import (
    create_selection_comparison_chart, create_comparison_table
)

# Create and display the visual chart
fig = create_selection_comparison_chart()
plt.show()

# Create and display the comparison table
comparison_df = create_comparison_table()
print(comparison_df.to_string(index=False))

# debug

In [None]:
# After running your PSTH analysis:
# build a separate analyzer with its own stop detection (no captures_df /
# no_capture_stops_df passed) for diagnostics. `cfg` comes from the cells above.
analyzer = psth_around_stops.PSTHAnalyzer(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg)
analyzer.identify_stop_events()

# Print detailed diagnostic report
analyzer.print_capture_diagnostic()

In [None]:
# Call the analyzer's private speed-based stop finder directly, for inspection.
analyzer._find_stops_from_speed_fallback()

In [None]:
# Hard-coded cur_ff_index values copied from an earlier run (presumably the
# captures missing from the analyzer's matching — TODO confirm). They are
# session-specific and will be wrong for other sessions.
missing = [9, 18, 19, 27, 30, 33, 36, 45, 61, 64, 72, 87, 109, 125, 126, 131, 141, 199, 201, 204, 209, 227, 238, 264, 287, 317, 322, 336, 357, 361, 372, 376, 395, 406, 414, 416, 430, 432, 447, 454, 459, 465, 475, 494, 509, 514, 544, 547, 550, 572, 577, 585, 588, 599, 629, 631, 651, 657, 675, 680, 699, 700, 702, 712, 715, 741, 743, 757, 760, 768, 776, 784, 798, 837, 840, 845, 847, 853, 856, 874, 876, 888, 896, 901, 915, 919, 956, 961, 963, 967, 969, 975, 983, 986, 987, 1003, 1012, 1024, 1032, 1036, 1045, 1061, 1069, 1106, 1116, 1131, 1153, 1170, 1206, 1212, 1214, 1217, 1231, 1242, 1243, 1247, 1252, 1258, 1259, 1268, 1275, 1277, 1281, 1289, 1291, 1295, 1308, 1336, 1337]
pn.closest_stop_to_capture_df[pn.closest_stop_to_capture_df['cur_ff_index'].isin(missing)]

In [None]:
# Rebuild or reload pn.closest_stop_to_capture_df (project method).
pn.make_or_retrieve_closest_stop_to_capture_df()

In [None]:
# How many captures have their closest stop more than 26 away from the firefly.
# NOTE: the cleaning step above uses <= 25 and a later cell uses > 30.
pn.closest_stop_to_capture_df[pn.closest_stop_to_capture_df['distance_from_ff_to_stop']> 26].shape

In [None]:
# Distribution of stop-to-firefly distance for captures farther than 25.
# Needs pn.closest_stop_to_capture_df2, which is created by a cell further below.
sns.histplot(pn.closest_stop_to_capture_df2.loc[pn.closest_stop_to_capture_df2['distance_from_ff_to_stop']> 25, 'distance_from_ff_to_stop'], bins=100)

In [None]:
# Distribution of diff_from_caught_time over all captures.
sns.histplot(pn.closest_stop_to_capture_df['diff_from_caught_time'], bins=100)
# Render the figure (this line was a stray `pl`, which raised NameError).
plt.show()

In [None]:
# Closest stop for every caught firefly (cur_ff_index 0..N-1), stored
# separately from pn.closest_stop_to_capture_df for comparison.
pn.closest_stop_to_capture_df2 = nxt_ff_utils.get_closest_stop_to_all_capture_position(pn.ff_caught_T_sorted, pn.monkey_information, pn.ff_real_position_sorted,
                                                                                       cur_ff_index_array=np.arange(len(pn.ff_caught_T_sorted)))


In [None]:
# Number of captures.
pn.ff_caught_T_sorted.shape

In [None]:
pn.closest_stop_to_capture_df['diff_from_caught_time'].describe()

In [None]:
# Captures whose closest stop is more than 30 from the firefly.
pn.closest_stop_to_capture_df['distance_from_ff_to_stop'] > 30).sum()

In [None]:
# Stop events from a specific (hard-coded) time onwards.
analyzer.stop_events[analyzer.stop_events['stop_time'] >= 58.7]

In [None]:
# Rows that start a new distinct stop, and the time gap to the previous one.
mdf = pn.monkey_information.copy()
m_sub = mdf[mdf['whether_new_distinct_stop']].copy()
m_sub['dt'] = m_sub['time'].diff()
len(m_sub[m_sub['dt'] < 0.2])

In [None]:
# Hard-coded ratio from an earlier run (presumably count of dt < 0.2 over the
# number of distinct stops — TODO confirm). It does not update with the data.
133/2838

In [None]:
m_sub.shape

In [None]:
mdf.shape

In [None]:
# Distinct stops that start less than 0.05 after the previous one.
m_sub[m_sub['dt'] < 0.05]

# 2nd try

In [None]:
# Make sure your monkey_information has the stop_id system
from data_wrangling.process_monkey_information import add_more_columns_to_monkey_information

# Add stop_id system to your monkey_information
pn.monkey_information = add_more_columns_to_monkey_information(pn.monkey_information)

# Now use PSTH analysis - it will use your stop_id system.
# Pass pn.monkey_information: a bare `monkey_information` is not defined at
# notebook level and raised NameError.
analyzer = psth_around_stops.PSTHAnalyzer(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg)
analyzer.identify_stop_events()

In [None]:
# If monkey_information lacks the stop_id columns, identify_stop_events
# presumably falls back to speed-based stop detection.
# NOTE(review): if the previous cell ran, pn.monkey_information already has
# those columns, so the fallback is likely NOT exercised here — confirm.
# Uses pn.monkey_information (a bare `monkey_information` raised NameError).
analyzer = psth_around_stops.PSTHAnalyzer(pn.spikes_df, pn.monkey_information, pn.ff_caught_T_new, cfg)
analyzer.identify_stop_events()

# More plots

## Heatmap of effect sizes

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_sig_heatmap(summary: pd.DataFrame, title="Significant effects (Cohen's d)"):
    """Plot a cluster × window heatmap of Cohen's d for FDR-significant rows.

    Parameters
    ----------
    summary : pd.DataFrame
        Needs columns 'cluster', 'window', 'cohens_d' and 'sig_FDR'.
        Rows whose 'sig_FDR' is missing or NaN are treated as not significant.
    title : str
        Axes title.

    Returns
    -------
    None
        Prints a message and returns early when there is nothing to plot.
    """
    # `.eq(True)` treats missing flags as "not significant". Boolean indexing
    # with a NaN-containing mask would raise instead.
    sig = summary[summary["sig_FDR"].eq(True)].copy()
    if sig.empty:
        print("No significant results to plot.")
        return

    # clusters (rows) × windows (columns); non-significant cells are NaN
    pivot = sig.pivot_table(index="cluster", columns="window", values="cohens_d", aggfunc="mean")
    if pivot.empty:
        # e.g. every significant row has NaN cohens_d
        print("No significant results to plot.")
        return

    # sort clusters by strongest absolute effect
    order = np.argsort(-pivot.abs().max(axis=1).values)
    pivot = pivot.iloc[order]

    # Symmetric colour limits around 0. Fall back to ±1 when there is no finite
    # non-zero value, which would otherwise give NaN or a degenerate 0..0 range.
    values = pivot.values
    finite_abs = np.abs(values[np.isfinite(values)])
    vmax = finite_abs.max() if finite_abs.size and finite_abs.max() > 0 else 1.0

    fig, ax = plt.subplots(figsize=(8, max(3, 0.35 * len(pivot))))
    im = ax.imshow(values, aspect="auto", cmap="coolwarm", vmin=-vmax, vmax=vmax)
    ax.set_xticks(range(pivot.shape[1])); ax.set_xticklabels(pivot.columns, rotation=30, ha="right")
    ax.set_yticks(range(pivot.shape[0])); ax.set_yticklabels(pivot.index)
    ax.set_title(title)
    cbar = plt.colorbar(im, ax=ax); cbar.set_label("Cohen's d (capture − miss)")
    plt.tight_layout()
    plt.show()

# usage: heatmap of the current `summary` table (from the most recent compare_windows run)
plot_sig_heatmap(summary)


## Bar chart of significant effects per epoch (one bar per cluster)

In [None]:
def plot_sig_bars(summary: pd.DataFrame, epoch: str):
    """Horizontal bar chart of Cohen's d for clusters FDR-significant in one window.

    Parameters
    ----------
    summary : pd.DataFrame
        Needs columns 'cluster', 'window', 'cohens_d' and 'sig_FDR'.
        Missing or NaN 'sig_FDR' values count as not significant.
    epoch : str
        Value of the 'window' column to plot.

    Returns
    -------
    None
        Prints a message and returns early when no cluster is significant.
    """
    g = summary[(summary["window"] == epoch) & summary["sig_FDR"].eq(True)].copy()
    if g.empty:
        print(f"No significant clusters for {epoch}.")
        return
    g = g.sort_values("cohens_d", key=lambda s: s.abs(), ascending=False)

    # Use string labels so matplotlib treats clusters as categories. Numeric ids
    # would be placed at their numeric y positions, ignoring the |d| sort.
    labels = g["cluster"].astype(str)

    fig, ax = plt.subplots(figsize=(10, max(3, 0.35 * len(g))))
    ax.barh(labels, g["cohens_d"])
    # barh draws the first category at the bottom; flip so the largest |d| is on top.
    ax.invert_yaxis()
    ax.axvline(0, color="k", lw=1, alpha=0.5)
    ax.set_xlabel("Cohen's d (capture − miss)")
    ax.set_ylabel("Cluster")
    ax.set_title(f"Significant clusters in {epoch}")
    plt.tight_layout()
    plt.show()

# usage: one bar chart per window label defined in the `windows` dict above
plot_sig_bars(summary, "pre_bump(-0.3–0.0)")
plot_sig_bars(summary, "early_dip(0.0–0.3)")
plot_sig_bars(summary, "late_rebound(0.3–0.8)")


## Quickly plot PSTHs for the top significant neurons

In [None]:
import numpy as np

def plot_top_psths(analyzer, summary: pd.DataFrame, epoch: str, top_k=6):
    """Plot capture-vs-miss PSTH comparisons for the top significant clusters of a window.

    Parameters
    ----------
    analyzer : object
        Must expose ``clusters`` (iterable of cluster ids) and
        ``plot_comparison(cluster_idx=...)``.
    summary : pd.DataFrame
        Needs columns 'cluster', 'window', 'cohens_d' and 'sig_FDR'.
    epoch : str
        Value of the 'window' column to select.
    top_k : int
        Maximum number of significant rows (ranked by |cohens_d|) to consider.

    Returns
    -------
    None
        Prints a message when nothing is significant or nothing matched.
    """
    # pick significant clusters in the epoch, ranked by |d|
    g = summary[(summary["window"] == epoch) & (summary["sig_FDR"])].copy()
    if g.empty:
        print(f"No significant clusters for {epoch}.")
        return
    g = g.sort_values("cohens_d", key=lambda s: s.abs(), ascending=False).head(top_k)

    # Build the str(cluster id) -> position lookup once. String keys make
    # numeric and string ids match each other; setdefault keeps the FIRST
    # position when an id repeats.
    index_of = {}
    for i, cl in enumerate(analyzer.clusters):
        index_of.setdefault(str(cl), i)

    plotted = 0
    for cl_str in g["cluster"]:
        ci = index_of.get(str(cl_str))
        if ci is None:
            continue
        analyzer.plot_comparison(cluster_idx=ci)
        plotted += 1
    if plotted == 0:
        print("Nothing plotted (no matches).")

# usage: plot up to 7 strongest significant clusters in the early post-stop window, using `an`
plot_top_psths(an, summary, "early_dip(0.0–0.3)", top_k=7)


## heatmap of effect sizes (Cohen’s d)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_sig_heatmap(summary: pd.DataFrame, title="Significant effects (Cohen's d)"):
    """
    Plot a heatmap of Cohen's d for significant cluster×epoch combinations.

    Parameters
    ----------
    summary : pd.DataFrame
        Output from summarize_epochs() / compare_windows().
        Must include columns: ['cluster','window','cohens_d','sig_FDR'].
        Rows whose 'sig_FDR' is missing/NaN are treated as not significant.
    title : str
        Title for the plot.

    Returns
    -------
    None
        Prints a message and returns early when there is nothing to plot.

    Notes
    -----
    This redefines the earlier ``plot_sig_heatmap`` in this notebook. Figure
    height here is 0.4 per cluster row instead of 0.35.
    """
    # keep only significant rows; `.eq(True)` maps missing flags to False
    # instead of raising on a NaN-containing boolean mask
    sig = summary[summary["sig_FDR"].eq(True)].copy()
    if sig.empty:
        print("No significant results to plot.")
        return

    # pivot into matrix: clusters (rows) × windows (columns)
    pivot = sig.pivot_table(
        index="cluster", columns="window", values="cohens_d", aggfunc="mean"
    )
    if pivot.empty:
        # e.g. every significant row has NaN cohens_d
        print("No significant results to plot.")
        return

    # sort clusters by strongest absolute effect
    order = np.argsort(-pivot.abs().max(axis=1).values)
    pivot = pivot.iloc[order]

    # Symmetric colour limits around 0. Fall back to ±1 when there is no finite
    # non-zero value (otherwise vmin/vmax would be NaN or 0..0).
    values = pivot.values
    finite_abs = np.abs(values[np.isfinite(values)])
    vmax = finite_abs.max() if finite_abs.size and finite_abs.max() > 0 else 1.0

    # plot
    fig, ax = plt.subplots(figsize=(8, max(3, 0.4 * len(pivot))))
    im = ax.imshow(values, aspect="auto", cmap="coolwarm",
                   vmin=-vmax, vmax=vmax)

    ax.set_xticks(range(pivot.shape[1]))
    ax.set_xticklabels(pivot.columns, rotation=30, ha="right")
    ax.set_yticks(range(pivot.shape[0]))
    ax.set_yticklabels(pivot.index)

    ax.set_title(title)
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label("Cohen's d (capture − miss)")
    plt.tight_layout()
    plt.show()


# Rebuild `summary` with summarize_epochs (this overwrites the compare_windows
# result), then plot its significant effects.
summary = psth_around_stops.summarize_epochs(an, alpha=0.05)
plot_sig_heatmap(summary)


## heatmaps including all clusters

In [None]:
# Heatmaps that include all clusters, not only significant ones.
summary = psth_around_stops.summarize_epochs(an, alpha=0.05)
psth_around_stops.plot_effect_heatmap_all(summary)                       # sort by strongest effect
# or: (NOTE: both calls run as written, so two heatmaps are produced)
psth_around_stops.plot_effect_heatmap_all(summary, order="cluster")      # keep cluster order
