# Gain Robustness Analysis

This notebook generates:
1. Success Rate vs Gain bar plot
2. Break Rate vs Gain bar plot

In [None]:
# ============================================================
# BLOCK 1: IMPORTS & LOCAL CONFIGURATION
# ============================================================

import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Import shared utilities
from analysis_utils import (
    # Constants
    ENTITY, PROJECT, COLORS,
    TAG_EVAL_PERFORMANCE, TAG_EVAL_GAIN,
    METRIC_SUCCESS, METRIC_BREAKS, METRIC_TOTAL,
    # Data functions
    get_best_checkpoint_per_run,
    download_eval_data,
    # Plotting functions
    plot_rate_figure,
    print_data_summary,
)

# ------------------------------------------------------------
# LOCAL CONFIGURATION -- knobs that belong to this analysis only
# ------------------------------------------------------------

# Display name of each method -> tag handed to the data helpers as `method_tag`.
METHOD_TAGS = {
    "Pose": "pose_task_frag:2026-01-06_00:52",
    "MATCH": "LCLoP_task_frag:2026-01-06_00:27",
    "Hybrid-Basic": "basic-hybrid_task_frag:2026-01-06_00:56",
}

# Display label -> gain string substituted into the metric prefix.
# Both sides are currently the same text, so the mapping is built from one list;
# the order here is the order in which levels are listed for the plots.
GAIN_LEVELS = {
    label: label
    for label in ("0.5x", "0.75x", "1.0x", "1.25x", "1.5x")
}

# Policy selection.
# TOP_N_POLICIES is forwarded as `filter_top_n` to the plotting helper.
TOP_N_POLICIES = None
# Leave as None for no limit, or set an int to cap the checkpoint search
# (e.g., 2000000 to look only at the first 2M steps).
MAX_CHECKPOINT = None

# Axis range and tick positions for the success-rate figure (percent).
SUCCESS_Y_LIM = (0, 100)
SUCCESS_Y_TICKS = list(range(0, 101, 20))

# Axis range and tick positions for the break-rate figure (percent).
BREAK_Y_LIM = (0, 12)
BREAK_Y_TICKS = list(range(0, 13, 2))

# Error bars: "ci" -> 95% confidence interval, "binary_se" -> binary standard error.
ERROR_TYPE = "ci"

In [None]:
# ============================================================
# BLOCK 2: DETERMINE BEST POLICY
# ============================================================
# One W&B API client is created here and reused by the download cell below.

api = wandb.Api()

# method display name -> whatever get_best_checkpoint_per_run returns for its tag
best_checkpoints_by_method = {}

for method_name in METHOD_TAGS:
    method_tag = METHOD_TAGS[method_name]

    # Header line so the helper's own output can be attributed to a method.
    print(f"\n{method_name} ({method_tag}):")

    best_for_method = get_best_checkpoint_per_run(
        api,
        method_tag,
        max_checkpoint=MAX_CHECKPOINT,  # None leaves the search unrestricted
    )
    best_checkpoints_by_method[method_name] = best_for_method

In [None]:
# ============================================================
# BLOCK 3: DOWNLOAD DATA
# ============================================================
# Requires `api` and `best_checkpoints_by_method` from BLOCK 2.

# method display name -> evaluation data returned by download_eval_data
gain_data = {}

for method_name, method_tag in METHOD_TAGS.items():
    print(f"\nDownloading data for {method_name}...")
    gain_data[method_name] = download_eval_data(
        api=api,
        eval_tag=TAG_EVAL_GAIN,
        method_tag=method_tag,
        best_checkpoints=best_checkpoints_by_method[method_name],
        # "{level}" is the placeholder the helper fills from GAIN_LEVELS.
        prefix_template="Gain_Eval({level})_Core",
        level_mapping=GAIN_LEVELS,
        level_col_name="gain_level",
    )

# Quick sanity overview of what was downloaded, per gain level.
print_data_summary(
    title="GAIN DATA SUMMARY (Success Rate)",
    data=gain_data,
    metric="success",
    level_col="gain_level",
    level_labels=list(GAIN_LEVELS),
)

In [None]:
# ============================================================
# BLOCK 4: SUCCESS RATE VS GAIN
# ============================================================
# Requires `gain_data` (BLOCK 3) and `best_checkpoints_by_method` (BLOCK 2).

fig, ax = plot_rate_figure(
    # What to plot
    data=gain_data,
    metric="success",
    level_col="gain_level",
    method_names=list(METHOD_TAGS),
    level_labels=list(GAIN_LEVELS),
    # Which policies to include
    best_checkpoints=best_checkpoints_by_method,
    filter_top_n=TOP_N_POLICIES,
    # How to draw it
    title="Success Rate vs Gain",
    x_label="Gain Multiplier",
    y_label="Success Rate (%)",
    y_lim=SUCCESS_Y_LIM,
    y_ticks=SUCCESS_Y_TICKS,
    error_type=ERROR_TYPE,
)
plt.show()

In [None]:
# ============================================================
# BLOCK 5: BREAK RATE VS GAIN
# ============================================================
# Same figure as BLOCK 4 but for the "breaks" metric, with its own y-axis range.

fig, ax = plot_rate_figure(
    # What to plot
    data=gain_data,
    metric="breaks",
    level_col="gain_level",
    method_names=list(METHOD_TAGS),
    level_labels=list(GAIN_LEVELS),
    # Which policies to include
    best_checkpoints=best_checkpoints_by_method,
    filter_top_n=TOP_N_POLICIES,
    # How to draw it
    title="Break Rate vs Gain",
    x_label="Gain Multiplier",
    y_label="Break Rate (%)",
    y_lim=BREAK_Y_LIM,
    y_ticks=BREAK_Y_TICKS,
    error_type=ERROR_TYPE,
)
plt.show()