# Using Constant Marker Size

To use a constant marker size instead of sample-size-based marker scaling in `plot_median_heading_across_monkeys_and_arc_types_with_difference`, pass the `constant_marker_size` argument:

```python
# Example: Use constant marker size of 12
cma.plot_median_heading_across_monkeys_and_arc_types_with_difference(
    x_var_column_list=x_var_column_list,
    fixed_variable_values_to_use=fixed_variable_values_to_use,
    changeable_variables=changeable_variables,
    columns_to_find_unique_combinations_for_color=columns_to_find_unique_combinations_for_color,
    constant_marker_size=12  # Set constant marker size (default is None, which uses sample size-based scaling)
)
```

When `constant_marker_size` is set to a number, all markers will have that size regardless of sample size. When `constant_marker_size` is `None` (default), markers will be sized based on sample size as before.


# Main features
'diff_in_abs_angle_to_nxt_ff', 'diff_in_abs_d_curv', 'dir_from_cur_ff_same_side'

# Import packages

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import os, sys
for p in [Path.cwd()] + list(Path.cwd().parents):
    if p.name == 'Multifirefly-Project':
        os.chdir(p)
        sys.path.insert(0, str(p / 'multiff_analysis/multiff_code/methods'))
        break


from pattern_discovery import monkey_landing_in_ff
from machine_learning.ml_methods import regression_utils, classification_utils, prep_ml_data_utils
from visualization.matplotlib_tools import plot_trials
from planning_analysis.show_planning import examine_null_arcs
from planning_analysis.show_planning.cur_vs_nxt_ff import cvn_from_ref_class
from planning_analysis.only_cur_ff import only_cur_ff_x_sess_class
from planning_analysis.plan_factors import plan_factors_utils, build_factor_comp, plan_factors_class, monkey_plan_factors_x_sess_class
from planning_analysis.agent_analysis import compare_monkey_and_agent_utils, agent_plan_factors_x_sess_class, cmp_monkey_agent_plan_class
from machine_learning.ml_methods import ml_methods_class, prep_ml_data_utils
from planning_analysis.plan_factors import monkey_plan_factors_x_sess_class, monkey_plan_factors_x_sess_class
from planning_analysis.factors_vs_indicators import make_variations_utils, process_variations_utils
from planning_analysis.factors_vs_indicators import make_variations_utils, predict_y_values_class, compare_y_values_class, variations_base_class
from planning_analysis.factors_vs_indicators.plot_plan_indicators import plot_variations_class, plot_variations_utils, parent_assembler

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import rc
import os, sys
from importlib import reload

# Start from Matplotlib's defaults. This reset has to run BEFORE the custom
# settings below: it used to run after them, which silently reverted the
# animation output mode to its default.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)

# Render animations as interactive JavaScript widgets. (The earlier
# plt.rcParams["animation.html"] = "html5" assignment was dropped because this
# 'jshtml' setting overrode it immediately anyway.)
rc('animation', html='jshtml')
# Lift the embed-size cap far beyond any realistic animation.
matplotlib.rcParams['animation.embed_limit'] = 2**128

# KMP_DUPLICATE_LIB_OK is an OpenMP runtime switch; 'True' tolerates more than
# one copy of the runtime being loaded instead of aborting.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Display preferences: 5 decimals for DataFrame floats, no scientific notation
# for NumPy arrays, and up to 101 DataFrame rows shown.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
np.set_printoptions(suppress=True)
pd.options.display.max_rows = 101

In [None]:
# No data item is loaded yet.
data_item = None
# Raw-data folder for one Bruno session (relative path; presumably resolved
# against the project root chosen in the import cell -- confirm).
raw_data_folder_path = "all_monkey_data/raw_monkey_data/monkey_Bruno/data_0328"

# run overnight

## agent

In [None]:
# exists_ok = True
# for opt_arc_type in ['norm_opt_arc', 'opt_arc_stop_closest', 'opt_arc_stop_first_vis_bdry']:
#     pfas = agent_plan_factors_x_sess_class.PlanFactorsAcrossAgentSessions(num_steps_per_dataset=5000,
#                                                                         opt_arc_type = opt_arc_type)
#     agent_all_ref_pooled_median_info = pfas.make_or_retrieve_all_ref_pooled_median_info(exists_ok=exists_ok)
#     agent_all_perc_df = pfas.make_or_retrieve_pooled_perc_info(exists_ok=exists_ok)



## monkey

In [None]:
# Build the pooled median and percentage tables for every monkey / optimal-arc
# combination. With exists_ok=True the make_or_retrieve_* helpers presumably
# reuse stored results when they exist (confirm in PlanAcrossSessions).
exists_ok = True
list_of_curv_traj_window_before_stop = [[-25, 0]]

for monkey_name in ('monkey_Schro', 'monkey_Bruno'):
    for opt_arc_type in ('norm_opt_arc', 'opt_arc_stop_closest', 'opt_arc_stop_first_vis_bdry'):
        # To suppress printed output, wrap the calls below in:
        #     with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
        ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions(
            monkey_name=monkey_name,
            opt_arc_type=opt_arc_type,
        )
        all_ref_pooled_median_info = ps.make_or_retrieve_all_ref_pooled_median_info(
            exists_ok=exists_ok,
            list_of_curv_traj_window_before_stop=list_of_curv_traj_window_before_stop,
        )
        pooled_perc_info = ps.make_or_retrieve_pooled_perc_info(exists_ok=exists_ok)

        # Optional extra tables (enable when needed):
        # all_cur_and_nxt_lr_pred_ff_df = ps.make_or_retrieve_all_cur_and_nxt_lr_pred_ff_df(exists_ok=exists_ok)
        # all_cur_and_nxt_lr_df = ps.make_or_retrieve_all_cur_and_nxt_lr_df(exists_ok=exists_ok)
        # all_cur_and_nxt_clf_df = ps.make_or_retrieve_all_cur_and_nxt_clf_df()

        # Optional: analyses that use the current firefly only.
        # osfxs = only_cur_ff_x_sess_class.OnlyStopFFAcrossSessions(monkey_name=monkey_name)
        # all_only_cur_lr_df = osfxs.make_or_retrieve_all_only_cur_lr_df(exists_ok=exists_ok)
        # all_only_cur_ml_df = osfxs.make_or_retrieve_all_only_cur_ml_df(exists_ok=exists_ok)

# For PPT (superseded by the more polished versions below)

## heading

In [None]:
# One heading figure per monkey, relabelled for slides.
for monkey in ('Bruno', 'Schro'):
    ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions(monkey_name=f'monkey_{monkey}')
    monkey_median_df = ps.make_or_retrieve_all_ref_pooled_median_info(
        exists_ok=True,
        pooled_median_info_exists_ok=True,
    )
    ps.plot_median_heading(monkey_median_df)

    # Override the title and axis labels on the figure stored at ps.fig.
    ps.fig.update_layout(
        title=f"{monkey}: Reference Point vs Absolute Angle Difference (Median ± 2 Bootstrap SE)",
        xaxis_title="Reference Point Value in Distance to Current FF(cm)",
        yaxis_title="Median Angle Difference to Next FF (°)",
    )
    ps.fig.show()

In [None]:
# Heading plot for Schro alone, using the plotting method's default arguments.
ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions(monkey_name='monkey_Schro')
monkey_median_df = ps.make_or_retrieve_all_ref_pooled_median_info(
    exists_ok=True,
    pooled_median_info_exists_ok=True,
)
ps.plot_median_heading(monkey_median_df)



## curv

In [None]:
# Difference in Curvature Across Two Arc Pairs

In [None]:
# One curvature figure per monkey, relabelled for slides.
for monkey in ['Bruno', 'Schro']:
    ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions(monkey_name=f'monkey_{monkey}')
    monkey_median_df = ps.make_or_retrieve_all_ref_pooled_median_info(
        exists_ok=True,
        pooled_median_info_exists_ok=True,
    )
    ps.plot_median_curv(monkey_median_df)

    # The y-axis label used to be copied from the heading figure ("Median Angle
    # Difference to Next FF"); this figure shows curvature, as its title says.
    ps.fig.update_layout(
        title=f"{monkey}: Reference Point vs Difference in Curvature Across Two Arc Pairs (Median ± 2 Bootstrap SE)",
        xaxis_title="Reference Point Value in Distance to Current FF (cm)",
        yaxis_title="Median Difference in Curvature",
    )
    ps.fig.show()

## perc

In [None]:
# One same-side stop-rate figure per monkey, with a centred title and a
# percent-formatted y axis.
for monkey in ('Bruno', 'Schro'):
    ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions(monkey_name=f'monkey_{monkey}')
    monkey_perc_df = ps.make_or_retrieve_pooled_perc_info(exists_ok=True)
    ps.plot_same_side_percentage()

    ps.fig.update_layout(
        title={
            'text': f"{monkey}: Same-Side Stop Rate",
            'x': 0.5,             # centre the title horizontally
            'y': 0.95,            # slightly below the top edge
            'xanchor': "center",
            'yanchor': "top",
            'font': {'size': 18},
        },
        yaxis={
            'ticksuffix': "%",    # append a % sign to the tick labels
            'title': "Percentage (Median ± 2 Bootstrap SE)",
        },
    )
    ps.fig.show()

# single monkey & arc type

## heading

In [None]:
# Heading plot for Schro with is_difference=False (presumably the raw values
# rather than the test-minus-control difference -- confirm in plot_median_heading).
ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions(monkey_name='monkey_Schro')
monkey_median_df = ps.make_or_retrieve_all_ref_pooled_median_info(
    exists_ok=True,
    pooled_median_info_exists_ok=True,
)
ps.plot_median_heading(monkey_median_df, is_difference=False)
ps.fig  # last expression of the cell: displays the figure


## curv

In [None]:
# Difference in Curvature Across Two Arc Pairs

In [None]:
# Curvature plot for Schro.
ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions(monkey_name='monkey_Schro')
monkey_median_df = ps.make_or_retrieve_all_ref_pooled_median_info(
    exists_ok=True,
    pooled_median_info_exists_ok=True,
)
# NOTE(review): unlike the PPT loop, no DataFrame is passed here; presumably the
# method falls back to data stored on `ps` -- confirm.
ps.plot_median_curv()



## perc

In [None]:
# Same-side stop-rate plot for Schro; the pooled percentage table is made or
# retrieved first, then plotted without passing it explicitly.
ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions(monkey_name='monkey_Schro')
monkey_perc_df = ps.make_or_retrieve_pooled_perc_info(exists_ok=True)
ps.plot_same_side_percentage()


## example of changing plotting parameters

In [None]:
ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions(monkey_name='monkey_Schro')
monkey_median_df = ps.make_or_retrieve_all_ref_pooled_median_info(
    exists_ok=True,
    pooled_median_info_exists_ok=True,
)

# Plot configuration.
x_var_column_list = ['curv_traj_window_before_stop']
changeable_variables = ['whether_even_out_dist']
columns_to_find_unique_combinations_for_color = []
columns_to_find_unique_combinations_for_line = []

# Column values held fixed. Other filters that can be added here:
#   'whether_even_out_dist': True / False
#   'max_curv_range': 200
#   'whether_filter_info': True
fixed_variable_values_to_use = {
    'if_test_nxt_ff_group_appear_after_stop': 'flexible',
    'key_for_split': 'ff_seen',
}

# NOTE(review): monkey_median_df is not passed to the plot call below;
# presumably the method uses data stored on `ps` -- confirm.
ps.plot_median_heading(
    x_var_column_list=x_var_column_list,
    fixed_variable_values_to_use=fixed_variable_values_to_use,
    changeable_variables=changeable_variables,
    columns_to_find_unique_combinations_for_color=columns_to_find_unique_combinations_for_color,
    columns_to_find_unique_combinations_for_line=columns_to_find_unique_combinations_for_line,
)

In [None]:
ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions()
monkey_median_df = ps.combine_all_ref_pooled_median_info_across_monkeys_and_opt_arc_types()

# Filters that can also be pinned in fixed_variable_values_to_use:
#   'if_test_nxt_ff_group_appear_after_stop': 'flexible'
#   'key_for_split': 'ff_seen'
ps.plot_median_heading_across_monkeys_and_arc_types_with_difference(
    x_var_column_list=['if_test_nxt_ff_group_appear_after_stop'],
    fixed_variable_values_to_use={
        'whether_even_out_dist': False,
        'curv_traj_window_before_stop': '[-25, 0]',
        'opt_arc_type': 'norm_opt_arc',
        'ref_point_value': -100,
    },
    changeable_variables=['monkey_name'],
    columns_to_find_unique_combinations_for_color=['opt_arc_type'],
    columns_to_find_unique_combinations_for_line=[],
)

# all monkeys: ref distance as x

## heading

In [None]:
# Heading plot across monkeys and arc types, from the combined median table.
ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions()
monkey_median_df = ps.combine_all_ref_pooled_median_info_across_monkeys_and_opt_arc_types()

# Alternatives kept for reference:
#   restrict to one arc type first:
#       monkey_median_df = monkey_median_df[monkey_median_df['opt_arc_type']=='norm_opt_arc'].copy()
#   plot with explicit variables instead of the DataFrame:
#       ps.plot_median_heading_across_monkeys_and_arc_types_with_difference(x_var_column_list=x_var_column_list,changeable_variables=changeable_variables)
ps.plot_median_heading_across_monkeys_and_arc_types_with_difference(
    monkey_median_df,
)


## curv

In [None]:
# Curvature plot across monkeys and arc types.
ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions()
monkey_median_df = ps.combine_all_ref_pooled_median_info_across_monkeys_and_opt_arc_types()
# NOTE(review): the combined table is not passed in; presumably the method reads
# it from `ps` -- confirm.
ps.plot_median_curv_across_monkeys_and_arc_types_with_difference()

## perc

In [None]:
# Same-side stop rate across monkeys, with monkey_name as both the x variable
# and the changeable variable.
ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions()
monkey_perc_df = ps.combine_pooled_perc_info_across_monkeys(pooled_perc_info_exists_ok=True)
ps.plot_same_side_percentage_across_monkeys(x_var_column_list=['monkey_name'], changeable_variables=['monkey_name'])


### polish

In [None]:
# Polish the figure currently stored at ps.fig.

# --- 1) Short subplot titles ---
titles = ['Bruno', 'Schro']
for i, t in enumerate(titles):
    if i < len(ps.fig.layout.annotations):
        annotation = ps.fig.layout.annotations[i]
        annotation.text = t
        annotation.font = dict(size=16, family='Arial', color='#111')
        annotation.yshift = 4  # nudge away from the frame

# --- 2) Slim, balanced canvas ---
ps.fig.update_layout(
    width=720,  # a touch wider for 2 subplots
    height=360,
    margin=dict(l=70, r=20, t=60, b=70),
    template='plotly_white',
    autosize=False,
)

# --- 3) Axes ---
ps.fig.update_xaxes(
    title_font=dict(size=16),
    tickfont=dict(size=13),
    showline=False, linewidth=1, linecolor='#444',
    showgrid=False, gridwidth=1, gridcolor='rgba(0,0,0,0.06)',
    zeroline=False,
)
# Label every y axis here; step 6 strips it from the right-hand subplot again.
ps.fig.update_yaxes(
    title_text='Percentage',
    ticksuffix='%',
)

# --- 4) Default colour cycle (remove if colours are set per trace) ---
ps.fig.update_layout(colorway=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'])

# --- 5) Slightly thicker lines / larger markers for readability ---
for tr in ps.fig.data:
    if hasattr(tr, 'line') and tr.line is not None:
        tr.line.width = 2
    if hasattr(tr, 'marker') and tr.marker is not None and 'size' in tr.marker:
        size = tr.marker.size
        if size is None:
            # Size was never set: keep Plotly's default rather than crash on
            # max(6, None).
            continue
        if np.ndim(size) == 0:
            tr.marker.size = max(6, size)
        else:
            # Per-point sizes (e.g. sample-size scaling): clamp element-wise,
            # since max(6, <array>) is not a valid comparison.
            tr.marker.size = [max(6, s) for s in size]

# --- 6) Right-hand subplot shares the left one's y axis visually ---
ps.fig.update_yaxes(
    showticklabels=False,   # hide tick labels
    title_text=None,        # remove axis title
    showgrid=False,         # also hide grid lines
    row=1, col=2            # target the 2nd subplot
)

ps.fig

## example of changing plotting parameters

In [None]:
# Same heading plot, with the optimal-arc type on the x axis and one panel per
# reference point / monkey combination (as controlled by changeable_variables).
ps = monkey_plan_factors_x_sess_class.PlanAcrossSessions()
monkey_median_df = ps.combine_all_ref_pooled_median_info_across_monkeys_and_opt_arc_types()
ps.plot_median_heading_across_monkeys_and_arc_types_with_difference(
    x_var_column_list=['opt_arc_type'],
    fixed_variable_values_to_use={
        'if_test_nxt_ff_group_appear_after_stop': 'flexible',
        'key_for_split': 'ff_seen',
        'whether_even_out_dist': False,
        'curv_traj_window_before_stop': '[-25, 0]',
    },
    changeable_variables=['ref_point_value', 'monkey_name'],
    columns_to_find_unique_combinations_for_color=[],
    columns_to_find_unique_combinations_for_line=[],
)

# Agent: for PPT

In [None]:
# Build the monkey-vs-agent comparison object for one arc type and load both
# summary tables onto it. To sweep every arc type, loop over:
#for opt_arc_type in ['norm_opt_arc', 'opt_arc_stop_closest', 'opt_arc_stop_first_vis_bdry']:
cma = cmp_monkey_agent_plan_class.CompareMonkeyAgentPlan(opt_arc_type='norm_opt_arc')
cma.get_monkey_and_agent_all_ref_pooled_median_info()
cma.get_monkey_and_agent_pooled_perc_info()

In [None]:
# Distinct monkey_name values present in the heading table.
cma.all_ref_pooled_median_info_heading['monkey_name'].unique()

## heading

In [None]:
# Heading plot for the agent alone, against reference-point value, with a
# constant marker size.
x_var_column_list = ['ref_point_value']
changeable_variables = []
columns_to_find_unique_combinations_for_color = []
# 'whether_even_out_dist': True can also be pinned here.
fixed_variable_values_to_use = {
    'if_test_nxt_ff_group_appear_after_stop': 'flexible',
    'key_for_split': 'ff_seen',
}

# Keep only the agent's rows of the heading table.
agent_median_df = cma.all_ref_pooled_median_info_heading.loc[
    lambda table: table['monkey_name'] == 'agent'
].copy()

cma.plot_median_heading_across_monkeys_and_arc_types_with_difference(
    all_ref_median_info=agent_median_df,
    x_var_column_list=x_var_column_list,
    fixed_variable_values_to_use=fixed_variable_values_to_use,
    changeable_variables=changeable_variables,
    columns_to_find_unique_combinations_for_color=columns_to_find_unique_combinations_for_color,
    constant_marker_size=15,
)

## new perc plot

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

def plot_grouped_percent_bars(
    df: pd.DataFrame,
    *,
    group_order=('control', 'test'),
    monkey_order=None,
    bar_width=0.36,
    capsize=4,
    ylabel='Percentage',
    title=None,
    figsize=(7.5, 3.8),
    palette=None,
    ylim=None,
    show_values=False,
    ax=None,
):
    """
    Grouped bar chart with asymmetric CI error bars (Matplotlib).

    Parameters
    ----------
    df : DataFrame
        Columns: ['perc', 'test_or_control', 'ci_lower', 'ci_upper', 'monkey_name'].
        'perc', 'ci_*' should be in 0–100 units. Rows whose 'test_or_control'
        is not in `group_order` (e.g. 'difference') are ignored, as are rows
        for monkeys not listed in `monkey_order`.
    group_order : tuple
        Order of groups on each x (e.g., ('control', 'test')).
    monkey_order : sequence or None
        Order of x categories. If None, uses first-appearance order in df.
    bar_width : float
        Width of each bar.
    capsize : float
        Error bar cap size (points).
    ylabel : str
        Y-axis label text.
    title : str or None
        Axes title.
    figsize : tuple
        Figure size in inches (only used when a new figure is created).
    palette : dict or None
        Optional mapping {group: color}. If None, Matplotlib defaults are used.
    ylim : tuple or None
        Optional y-limits in percent units, e.g., (45, 80).
    show_values : bool
        If True, annotate bars with their % values.
    ax : matplotlib.axes.Axes or None
        Existing axes to draw on; if None, a new figure/axes is created.

    Returns
    -------
    fig, ax

    Raises
    ------
    ValueError
        If a (monkey_name, test_or_control) combination occurs more than once,
        because a single bar cannot represent several rows.
    """
    group_order = list(group_order)

    # Filter first, then copy: taking the copy before filtering left `d` as a
    # slice of a frame and made later edits raise SettingWithCopyWarning.
    d = df[df['test_or_control'].isin(group_order)].copy()

    # X categories, in caller order or first-appearance order.
    if monkey_order is None:
        monkey_order = d['monkey_name'].drop_duplicates().tolist()
    monkeys = pd.Index(monkey_order)
    d = d[d['monkey_name'].isin(monkeys)]

    if d.duplicated(['monkey_name', 'test_or_control']).any():
        raise ValueError(
            "df has more than one row for some (monkey_name, test_or_control) "
            "combination; filter or aggregate it before plotting"
        )

    x = np.arange(len(monkeys), dtype=float)

    # Horizontal offsets that centre the group of bars on each x position.
    n_groups = len(group_order)
    half_span = bar_width * (n_groups - 1) / 2
    offsets = np.linspace(-half_span, half_span, n_groups)

    # Prepare axes
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    for grp, offset in zip(group_order, offsets):
        # Align this group's rows with the x order; missing monkeys become NaN.
        sub = d[d['test_or_control'] == grp].set_index('monkey_name').reindex(monkeys)

        y = sub['perc'].to_numpy(dtype=float)
        lo = sub['ci_lower'].to_numpy(dtype=float)
        hi = sub['ci_upper'].to_numpy(dtype=float)
        # Matplotlib expects distances below/above the bar, not absolute bounds.
        yerr = np.vstack([y - lo, hi - y])

        bars = ax.bar(
            x + offset,
            y,
            width=bar_width,
            yerr=yerr,
            capsize=capsize,
            label=grp,
            color=None if palette is None else palette.get(grp),
            edgecolor='none',
        )

        if show_values:
            for rect, val in zip(bars, y):
                if np.isnan(val):
                    continue  # no bar drawn for this monkey/group
                ax.annotate(f'{val:.1f}%',
                            xy=(rect.get_x() + rect.get_width()/2, rect.get_height()),
                            xytext=(0, 3), textcoords='offset points',
                            ha='center', va='bottom', fontsize=9)

    # Axes cosmetics
    ax.set_xticks(x)
    ax.set_xticklabels(monkeys.tolist())
    ax.set_ylabel(ylabel)
    ax.yaxis.set_major_formatter(PercentFormatter(xmax=100))  # 0–100 data shown as %
    if ylim is not None:
        ax.set_ylim(*ylim)

    ax.yaxis.grid(True, linestyle='-', alpha=0.15)
    ax.set_axisbelow(True)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.legend(title='Group', frameon=False)

    if title:
        ax.set_title(title)

    if created_fig:
        fig.tight_layout()

    return fig, ax



# NOTE(review): this expects the tidy `pooled_perc_info` (columns perc,
# test_or_control, ci_lower, ci_upper, monkey_name) that the "## perc" cell
# further down builds from `cma`; run that cell first.
palette = {
    'control': '#f28e2b',
    'test': '#4e79a7',
}  # optional: omit to use Matplotlib's default colours

fig, ax = plot_grouped_percent_bars(
    pooled_perc_info,
    title='Same-Side Stop Rate',
    palette=palette,
    ylim=(45, 80),       # adjust, or pass None for automatic limits
    show_values=False,   # True prints the % value above each bar
)
plt.show()


In [None]:
# Stand-alone copy of the bar-chart call, for re-drawing without re-running the
# function definition.
palette = {
    'control': '#f28e2b',
    'test': '#4e79a7',
}  # optional: omit to use Matplotlib's default colours

fig, ax = plot_grouped_percent_bars(
    pooled_perc_info,
    title='Same-Side Stop Rate',
    palette=palette,
    ylim=(45, 80),       # adjust, or pass None for automatic limits
    show_values=False,   # True prints the % value above each bar
)
plt.show()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

def _stars_from_p(p: float) -> str:
    """Map a p-value to a significance label.

    Returns '****' for p < 1e-4, '***' for p < 1e-3, '**' for p < 0.01,
    '*' for p < 0.05, and 'n.s.' otherwise (including when p is None or NaN).
    """
    if p is None or np.isnan(p):
        return 'n.s.'
    # Upper bounds are exclusive and listed from strictest to loosest, so the
    # first match is the strongest label that applies.
    cutoffs = ((1e-4, '****'), (1e-3, '***'), (0.01, '**'), (0.05, '*'))
    return next((label for bound, label in cutoffs if p < bound), 'n.s.')

def _add_sig_bracket(ax, x1, x2, y, h=1.0, text='*', lw=1.2, fontsize=11):
    """
    Draw a significance bracket on `ax` spanning x1..x2.

    The bracket rises from height y to y + h (both in y-axis data units) and
    `text` is centred just above its horizontal bar.
    """
    top = y + h
    xs = [x1, x1, x2, x2]
    ys = [y, top, top, y]
    ax.plot(xs, ys, lw=lw, c='k')
    ax.text((x1 + x2) / 2, top, text, ha='center', va='bottom', fontsize=fontsize)

def plot_grouped_percent_bars(
    df: pd.DataFrame,
    *,
    group_order=('control', 'test'),
    monkey_order=None,
    bar_width=0.36,
    capsize=4,
    ylabel='Percentage',
    title=None,
    figsize=(7.5, 3.8),
    palette=None,
    ylim=None,
    show_values=False,
    pvals=None,              # dict: {monkey_name: p_value}
    bracket_height=1.2,      # vertical size of bracket in % units
    bracket_pad=1.0,         # padding above tallest bar in % units
    ax=None,
    dpi=100
):
    """
    Grouped bar chart (0–100 scale) with asymmetric CIs and optional significance brackets.

    Parameters
    ----------
    df : DataFrame
        Columns: ['perc', 'test_or_control', 'ci_lower', 'ci_upper', 'monkey_name'].
        Rows whose 'test_or_control' is not in `group_order` (e.g. 'difference')
        are ignored, as are rows for monkeys not listed in `monkey_order`.
    group_order : tuple
        Order of groups on each x. Brackets connect the first two groups.
    monkey_order : sequence or None
        Order of x categories. If None, uses first-appearance order in df.
    bar_width, capsize : float
        Bar width (x units) and error-bar cap size (points).
    ylabel, title : str or None
        Y-axis label and axes title.
    figsize, dpi
        Size (inches) and resolution of the figure; only used when `ax` is None.
    palette : dict or None
        Optional mapping {group: color}.
    ylim : tuple or None
        Optional y-limits in percent units. The upper limit is raised when a
        bracket would otherwise be cut off.
    show_values : bool
        If True, annotate bars with their % values.
    pvals : dict or None
        {monkey_name: p_value}. A bracket labelled by `_stars_from_p` is drawn
        for every plotted monkey that has an entry and both bars present.
    bracket_height, bracket_pad : float
        Bracket height and its gap above the taller error bar, in % units.
    ax : matplotlib.axes.Axes or None
        Existing axes to draw on; if None, a new figure/axes is created.

    Returns
    -------
    fig, ax

    Raises
    ------
    ValueError
        If a (monkey_name, test_or_control) combination occurs more than once,
        or if `pvals` is given with fewer than two groups in `group_order`.
    """
    group_order = list(group_order)
    n_groups = len(group_order)
    if pvals and n_groups < 2:
        # Previously surfaced as an IndexError after the bars were drawn.
        raise ValueError("significance brackets need at least two groups in group_order")

    # Filter first, then copy: taking the copy before filtering left `d` as a
    # slice of a frame and made later edits raise SettingWithCopyWarning.
    d = df[df['test_or_control'].isin(group_order)].copy()

    # X categories, in caller order or first-appearance order.
    if monkey_order is None:
        monkey_order = d['monkey_name'].drop_duplicates().tolist()
    monkeys = pd.Index(monkey_order)
    d = d[d['monkey_name'].isin(monkeys)]

    if d.duplicated(['monkey_name', 'test_or_control']).any():
        raise ValueError(
            "df has more than one row for some (monkey_name, test_or_control) "
            "combination; filter or aggregate it before plotting"
        )

    x = np.arange(len(monkeys), dtype=float)

    # Horizontal offsets that centre the group of bars on each x position.
    half_span = bar_width * (n_groups - 1) / 2
    offsets = np.linspace(-half_span, half_span, n_groups)

    # Prepare axes
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    else:
        fig = ax.figure

    # Per group: x centre of each bar and the height a bracket must clear.
    bar_centers = {}
    bar_tops = {}

    for grp, offset in zip(group_order, offsets):
        # Align this group's rows with the x order; missing monkeys become NaN.
        sub = d[d['test_or_control'] == grp].set_index('monkey_name').reindex(monkeys)
        y = sub['perc'].to_numpy(dtype=float)
        lo = sub['ci_lower'].to_numpy(dtype=float)
        hi = sub['ci_upper'].to_numpy(dtype=float)
        # Matplotlib expects distances below/above the bar, not absolute bounds.
        yerr = np.vstack([y - lo, hi - y])

        centers = x + offset
        bars = ax.bar(
            centers,
            y,
            width=bar_width,
            yerr=yerr,
            capsize=capsize,
            label=grp,
            color=None if palette is None else palette.get(grp),
            edgecolor='none',
        )

        bar_centers[grp] = centers
        # Top of the upper error bar; fall back to the bar itself when the
        # upper CI bound is missing.
        tops = y + yerr[1]
        bar_tops[grp] = np.where(np.isnan(tops), y, tops)

        if show_values:
            for rect, val in zip(bars, y):
                if np.isnan(val):
                    continue  # no bar drawn for this monkey/group
                ax.annotate(f'{val:.1f}%',
                            xy=(rect.get_x() + rect.get_width()/2, rect.get_height()),
                            xytext=(0, 3), textcoords='offset points',
                            ha='center', va='bottom', fontsize=9)

    # Axes cosmetics
    ax.set_xticks(x)
    ax.set_xticklabels(monkeys.tolist())
    ax.set_ylabel(ylabel)
    if ylim is not None:
        ax.set_ylim(*ylim)

    ax.yaxis.grid(True, linestyle='-', alpha=0.15)
    # 0–100 data shown as whole-number percentages.
    ax.yaxis.set_major_formatter(PercentFormatter(xmax=100, decimals=0))
    ax.set_axisbelow(True)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.legend(title='Group', frameon=False)

    if title:
        ax.set_title(title)

    # --- Significance brackets (first vs second group, per monkey) ---
    if pvals:
        first, second = group_order[0], group_order[1]
        bracket_tops = []
        for i, name in enumerate(monkeys):
            p = pvals.get(name, None)
            if p is None:
                continue
            pair_tops = (bar_tops[first][i], bar_tops[second][i])
            if np.isnan(pair_tops).any():
                continue  # one of the two bars is missing; nothing to compare
            base = max(pair_tops) + bracket_pad
            _add_sig_bracket(
                ax,
                bar_centers[first][i],
                bar_centers[second][i],
                base,
                h=bracket_height,
                text=_stars_from_p(p),
            )
            bracket_tops.append(base + bracket_height)

        # Raise the upper y-limit only for brackets that were actually drawn.
        if bracket_tops:
            ymin, ymax = ax.get_ylim()
            highest = max(bracket_tops)
            if highest > ymax:
                ax.set_ylim(ymin, highest + 1.0)

    if created_fig:
        fig.tight_layout()

    return fig, ax



palette = {
    'control': '#f28e2b',
    'test': '#4e79a7',
}
# p-values for control vs test, keyed by the names in the monkey_name column.
pvals = {
    'Bruno': 0.0,        # written as 0.00000 in the notes; presumably rounded
    'Schro': 0.00672,
    'agent': 1e-5,
}

fig, ax = plot_grouped_percent_bars(
    pooled_perc_info,
    title='Same-Side Stop Rate',
    palette=palette,
    ylim=(45, 80),
    pvals=pvals,          # draws a bracket with stars per monkey
    bracket_height=1.0,
    bracket_pad=0.8,
    dpi=300,
)
plt.show()


In [None]:
# Same chart without the agent, on a narrower canvas and tighter y-range.
pooled_perc_info_sub = pooled_perc_info.loc[pooled_perc_info['monkey_name'] != 'agent']

fig, ax = plot_grouped_percent_bars(
    pooled_perc_info_sub,
    title='Same-Side Stop Rate',
    palette=palette,
    ylim=(45, 70),
    figsize=(6, 3.8),
    dpi=300,
    pvals=pvals,          # draws a bracket with stars per monkey
    bracket_height=1.0,
    bracket_pad=0.8,
)
plt.show()


## perc

In [None]:
# Build the tidy table consumed by plot_grouped_percent_bars.
pooled_perc_info = cma.pooled_perc_info.copy()

# Keep only the rows that match every fixed variable value.
for key, item in fixed_variable_values_to_use.items():
    pooled_perc_info = pooled_perc_info.loc[pooled_perc_info[key] == item]

# De-duplicate, keep the plotting columns, and renumber the rows.
pooled_perc_info = (
    pooled_perc_info
    .drop_duplicates()
    [['perc', 'test_or_control', 'ci_lower', 'ci_upper', 'monkey_name']]
    .reset_index(drop=True)
)

# Shorten the monkey labels for display.
for raw_name, short_name in (('monkey_Bruno', 'Bruno'), ('monkey_Schro', 'Schro')):
    pooled_perc_info.loc[pooled_perc_info['monkey_name'] == raw_name, 'monkey_name'] = short_name

pooled_perc_info


In [None]:
# Plain-text dump of the tidy table (easy to copy out of the notebook).
print(pooled_perc_info)

# agent vs monkey

In [None]:
# Rebuild the monkey-vs-agent comparison object for one arc type and load both
# summary tables onto it. To sweep every arc type, loop over:
#for opt_arc_type in ['norm_opt_arc', 'opt_arc_stop_closest', 'opt_arc_stop_first_vis_bdry']:
cma = cmp_monkey_agent_plan_class.CompareMonkeyAgentPlan(opt_arc_type='norm_opt_arc')
cma.get_monkey_and_agent_all_ref_pooled_median_info()
cma.get_monkey_and_agent_pooled_perc_info()

## heading

In [None]:
# Heading: monkeys and agent side by side on the x axis, with ref_point_value
# as the changeable variable and a constant marker size.
x_var_column_list = ['monkey_name']
changeable_variables = ['ref_point_value']
columns_to_find_unique_combinations_for_color = []
# 'whether_even_out_dist': True can also be pinned here.
fixed_variable_values_to_use = {
    'if_test_nxt_ff_group_appear_after_stop': 'flexible',
    'key_for_split': 'ff_seen',
}

cma.plot_median_heading_across_monkeys_and_arc_types_with_difference(
    x_var_column_list=x_var_column_list,
    fixed_variable_values_to_use=fixed_variable_values_to_use,
    changeable_variables=changeable_variables,
    columns_to_find_unique_combinations_for_color=columns_to_find_unique_combinations_for_color,
    constant_marker_size=15,
)

## curv

In [None]:
# Curvature: monkeys and agent side by side on the x axis, with
# ref_point_value as the changeable variable.
x_var_column_list = ['monkey_name']
changeable_variables = ['ref_point_value']
columns_to_find_unique_combinations_for_color = []
fixed_variable_values_to_use = {
    'whether_even_out_dist': False,
    'if_test_nxt_ff_group_appear_after_stop': 'flexible',
    'key_for_split': 'ff_seen',
}

cma.plot_median_curv_across_monkeys_and_arc_types_with_difference(
    x_var_column_list=x_var_column_list,
    fixed_variable_values_to_use=fixed_variable_values_to_use,
    changeable_variables=changeable_variables,
    columns_to_find_unique_combinations_for_color=columns_to_find_unique_combinations_for_color,
)

## perc

In [None]:
# Same-side stop rate: monkeys and agent side by side on the x axis.
x_var_column_list = ['monkey_name']
# 'if_test_nxt_ff_group_appear_after_stop' could be listed here instead of
# being pinned below.
changeable_variables = []
columns_to_find_unique_combinations_for_color = []
# 'whether_even_out_dist': False can also be pinned here.
fixed_variable_values_to_use = {
    'if_test_nxt_ff_group_appear_after_stop': 'flexible',
    'key_for_split': 'ff_seen',
}

cma.plot_same_side_percentage_across_monkeys(
    x_var_column_list=x_var_column_list,
    fixed_variable_values_to_use=fixed_variable_values_to_use,
    changeable_variables=changeable_variables,
    columns_to_find_unique_combinations_for_color=columns_to_find_unique_combinations_for_color,
)

# Collect new agent rollouts

In [None]:
# If you want to get new data from the agent

## remove the folder RL_models/SB3_stored_models/all_collected_data
# import shutil
# import os

# folder_path = "multiff_analysis/RL_models/SB3_stored_models/all_collected_data"

# # Check if folder exists before removing
# if os.path.exists(folder_path) and os.path.isdir(folder_path):
#     shutil.rmtree(folder_path)
#     print(f"Removed folder: {folder_path}")
# else:
#     print(f"Folder not found: {folder_path}")

# exists_ok = True
# pfas = agent_plan_factors_x_sess_class.PlanFactorsAcrossAgentSessions(num_steps_per_dataset=100000)
# agent_all_ref_pooled_median_info = pfas.make_or_retrieve_all_ref_pooled_median_info(exists_ok=exists_ok)
# agent_all_perc_df = pfas.make_or_retrieve_pooled_perc_info(exists_ok=exists_ok)
