# Shape Comparison Analysis

This notebook generates a 2×N grid comparing performance across different peg shapes:
- Top row: Success Rate vs Position Noise
- Bottom row: Break Rate vs Position Noise
- Each column shows the shape name and a cross-sectional icon
- Gold highlight box around the reference shape

In [None]:
# ============================================================
# BLOCK 1: IMPORTS & LOCAL CONFIGURATION
# ============================================================

import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Import shared utilities
from analysis_utils import (
    # Constants
    ENTITY, PROJECT, COLORS,
    TAG_EVAL_PERFORMANCE, TAG_EVAL_NOISE,
    METRIC_SUCCESS, METRIC_BREAKS, METRIC_TOTAL,
    SHAPE_ICONS,
    # Data functions
    get_best_checkpoint_per_run,
    download_eval_data,
    # Plotting functions
    plot_multi_panel_grid,
    print_data_summary,
)

# ============================================================
# LOCAL CONFIGURATION (specific to this analysis)
# ============================================================

# Shape Levels - keys are internal names, each contains method tags
SHAPE_LEVELS = {
    "circle": {
        "Pose(1mm)": "pose_task_frag:2026-01-06_00:52",
        "MATCH(1mm)": "LCLoP_task_frag:2026-01-06_00:27",
        "Hybrid-Basic(1mm)": "basic-hybrid_task_frag:2026-01-06_00:56",
    },
    "hexagon": {
        "Pose(1mm)": "pose_hex:2026-01-21_13:50",
        "MATCH(1mm)": "MATCH_hex:2026-01-21_13:07",
        "Hybrid-Basic(1mm)": "basic-hybrid_hex:2026-01-21_13:07",
    },
}

# Display name mapping
SHAPE_DISPLAY_NAMES = {
    "circle": "Circle",
    "square": "Square",
    "rectangle": "Rectangle",
    "hexagon": "Hexagon",
    "oval": "Oval",
    "arch": "Arch",
}

# Noise Level Mapping: display label -> metric range string
NOISE_LEVELS = {
    "1mm": "0mm-1mm",
    "2.5mm": "1mm-2.5mm",
    "5mm": "2.5mm-5mm",
    "7.5mm": "5mm-7.5mm",
}

# Policy Selection
TOP_N_POLICIES = None
MAX_CHECKPOINT = None  # Set to int to limit checkpoint search (e.g., 2000000 for first 2M steps)

# Highlight Configuration
HIGHLIGHT_SHAPE = "circle"  # Which shape to highlight with gold box, or None

# N/A panels (if any shapes have no break rate data)
NA_SHAPES = []

# Plot Configuration
SUCCESS_Y_LIM = (0, 100)
SUCCESS_Y_TICKS = [0, 20, 40, 60, 80, 100]
BREAK_Y_LIM = (0, 12)
BREAK_Y_TICKS = [0, 2, 4, 6, 8, 10, 12]

# Error type: "ci" for 95% confidence interval, "binary_se" for binary standard error
ERROR_TYPE = "ci"

In [None]:
# ============================================================
# BLOCK 2: DETERMINE BEST POLICY
# ============================================================

api = wandb.Api()
best_checkpoints = defaultdict(dict)  # best_checkpoints[shape][method]

for shape, method_tags in SHAPE_LEVELS.items():
    print(f"\n{'='*60}")
    print(f"Shape: {shape}")
    print(f"{'='*60}")
    for method_name, method_tag in method_tags.items():
        print(f"\n  {method_name} ({method_tag}):")
        best_checkpoints[shape][method_name] = get_best_checkpoint_per_run(
            api, method_tag, max_checkpoint=MAX_CHECKPOINT
        )

In [None]:
# ============================================================
# BLOCK 3: DOWNLOAD DATA
# ============================================================

noise_data = defaultdict(dict)  # noise_data[shape][method]

for shape, method_tags in SHAPE_LEVELS.items():
    print(f"\n{'='*60}")
    print(f"Downloading data for Shape: {shape}")
    print(f"{'='*60}")
    for method_name, method_tag in method_tags.items():
        print(f"\n  {method_name}...")
        noise_data[shape][method_name] = download_eval_data(
            api=api,
            method_tag=method_tag,
            best_checkpoints=best_checkpoints[shape][method_name],
            level_mapping=NOISE_LEVELS,
            prefix_template="Noise_Eval({level})_Core",
            level_col_name="noise_level",
            eval_tag=TAG_EVAL_NOISE,
        )

# Print summary
print("\n" + "="*60)
print("DATA SUMMARY")
print("="*60)
for shape in SHAPE_LEVELS.keys():
    print(f"\n{shape}:")
    for method_name in SHAPE_LEVELS[shape].keys():
        df = noise_data[shape][method_name]
        if not df.empty:
            num_runs = df["run_name"].nunique()
            print(f"  {method_name}: {num_runs} runs")
        else:
            print(f"  {method_name}: No data")

In [None]:
# ============================================================
# BLOCK 4: SUCCESS RATE COMPARISON FIGURE
# ============================================================

# Get all unique method names across all shapes
method_names = []
for shape in SHAPE_LEVELS.keys():
    for method_name in SHAPE_LEVELS[shape].keys():
        if method_name not in method_names:
            method_names.append(method_name)

fig, axes = plot_multi_panel_grid(
    data=dict(noise_data),
    panel_keys=list(SHAPE_LEVELS.keys()),
    panel_display_names=SHAPE_DISPLAY_NAMES,
    method_names=method_names,
    level_labels=list(NOISE_LEVELS.keys()),
    level_col="noise_level",
    metric="success",
    n_cols=2,
    suptitle="Success Rate vs Position Noise Across Peg Shapes",
    x_label="Position Noise",
    y_label="Success Rate (%)",
    y_lim=SUCCESS_Y_LIM,
    y_ticks=SUCCESS_Y_TICKS,
    error_type=ERROR_TYPE,
    highlight_panel=HIGHLIGHT_SHAPE,
    filter_top_n=TOP_N_POLICIES,
    best_checkpoints=dict(best_checkpoints),
    show_shape_icons=True,
    shape_icons_config=SHAPE_ICONS,
)
plt.show()

In [None]:
# ============================================================
# BLOCK 5: BREAK RATE COMPARISON FIGURE
# ============================================================

fig, axes = plot_multi_panel_grid(
    data=dict(noise_data),
    panel_keys=list(SHAPE_LEVELS.keys()),
    panel_display_names=SHAPE_DISPLAY_NAMES,
    method_names=method_names,
    level_labels=list(NOISE_LEVELS.keys()),
    level_col="noise_level",
    metric="breaks",
    n_cols=2,
    suptitle="Break Rate vs Position Noise Across Peg Shapes",
    x_label="Position Noise",
    y_label="Break Rate (%)",
    y_lim=BREAK_Y_LIM,
    y_ticks=BREAK_Y_TICKS,
    error_type=ERROR_TYPE,
    highlight_panel=HIGHLIGHT_SHAPE,
    na_panels=NA_SHAPES,
    filter_top_n=TOP_N_POLICIES,
    best_checkpoints=dict(best_checkpoints),
    show_shape_icons=True,
    shape_icons_config=SHAPE_ICONS,
)
plt.show()