# Shape Comparison Analysis

This notebook generates a 2×N grid (one column per peg shape) comparing performance across different peg shapes:
- Top row: Success Rate vs Position Noise
- Bottom row: Break Rate vs Position Noise
- Each column shows the shape name and a cross-sectional icon
- Gold highlight box around the reference shape

In [None]:
# ============================================================
# BLOCK 1: IMPORTS & CONSTANTS
# ============================================================

import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import RegularPolygon, Rectangle, Circle, Ellipse, FancyBboxPatch
from collections import defaultdict

# WandB Configuration
ENTITY = "hur"
PROJECT = "SG_Exps"

# Shape Levels - keys are internal shape names; each maps a method display
# name to the WandB tag used to find that method's runs for this shape.
SHAPE_LEVELS = {
    "circle": {
        "Pose": "PLACEHOLDER_POSE_CIRCLE",
        "Hybrid-Basic": "PLACEHOLDER_HYBRID_CIRCLE",
        "LCLoP": "PLACEHOLDER_LCLOP_CIRCLE",
    },
    "square": {
        "Pose": "PLACEHOLDER_POSE_SQUARE",
        "Hybrid-Basic": "PLACEHOLDER_HYBRID_SQUARE",
        "LCLoP": "PLACEHOLDER_LCLOP_SQUARE",
    },
    "rectangle": {
        "Pose": "PLACEHOLDER_POSE_RECTANGLE",
        "Hybrid-Basic": "PLACEHOLDER_HYBRID_RECTANGLE",
        "LCLoP": "PLACEHOLDER_LCLOP_RECTANGLE",
    },
    "hexagon": {
        "Pose": "PLACEHOLDER_POSE_HEXAGON",
        "Hybrid-Basic": "PLACEHOLDER_HYBRID_HEXAGON",
        "LCLoP": "PLACEHOLDER_LCLOP_HEXAGON",
    },
    "oval": {
        "Pose": "PLACEHOLDER_POSE_OVAL",
        "Hybrid-Basic": "PLACEHOLDER_HYBRID_OVAL",
        "LCLoP": "PLACEHOLDER_LCLOP_OVAL",
    },
}

# Display name mapping (column titles in the figure)
SHAPE_DISPLAY_NAMES = {
    "circle": "Circle",
    "square": "Square",
    "rectangle": "Rectangle",
    "hexagon": "Hexagon",
    "oval": "Oval",
}

# Shape drawing configuration - maps shape key to drawing function parameters.
# Each entry: (shape_type, kwargs for the patch).
# shape_type values handled by draw_shape_icon: 'circle', 'square',
# 'rectangle', 'oval', 'polygon' (hexagons are drawn as a 'polygon' with
# num_sides=6).
SHAPE_ICONS = {
    "circle": ("circle", {}),
    "square": ("square", {}),
    "rectangle": ("rectangle", {"aspect": 0.5}),  # height/width of the drawn icon (wider than tall)
    "hexagon": ("polygon", {"num_sides": 6}),
    "oval": ("oval", {"aspect": 0.6}),  # height/width of the drawn icon (wider than tall)
}

# Highlight box - which shape to highlight with gold box (None for no highlight).
# NOTE: the plotting cell reads its own HIGHLIGHT_SHAPE_PLOT setting, not this
# value; keep the two in sync.
HIGHLIGHT_SHAPE = "circle"

# Evaluation Tags - runs are selected by carrying both a method tag and one of these
TAG_EVAL_PERFORMANCE = "eval_performance"
TAG_EVAL_NOISE = "eval_noise"

# Noise Level Mapping: display label -> metric range string used in the
# "Noise_Eval(<range>)_Core/..." metric names
NOISE_LEVELS = {
    "1mm": "0mm-1mm",
    "2.5mm": "1mm-2.5mm",
    "5mm": "2.5mm-5mm",
    "7.5mm": "5mm-7.5mm",
}

# Metrics (suffixes appended to the Eval_Core / Noise_Eval metric prefixes)
METRIC_SUCCESS = "num_successful_completions"
METRIC_BREAKS = "num_breaks"
METRIC_TOTAL = "total_episodes"

In [None]:
# ============================================================
# BLOCK 2: DETERMINE BEST POLICY
# ============================================================

def get_best_checkpoint_per_run(api, method_tag):
    """Find the best checkpoint for each eval-performance run of a method.

    Queries every run in ``ENTITY/PROJECT`` tagged with both ``method_tag`` and
    ``TAG_EVAL_PERFORMANCE`` and, for each run, picks the history row that
    maximizes ``Eval_Core/<successes> - Eval_Core/<breaks>``.

    Args:
        api: A ``wandb.Api`` instance.
        method_tag: WandB tag identifying the method/shape combination.

    Returns:
        dict mapping run id -> {"run_name", "best_step", "score"}. Runs with
        no usable history (empty, missing metric/step columns, or no row where
        both the score and ``total_steps`` are logged) are skipped with a
        warning instead of aborting the whole query.
    """
    success_col = f"Eval_Core/{METRIC_SUCCESS}"
    breaks_col = f"Eval_Core/{METRIC_BREAKS}"
    required_cols = [success_col, breaks_col, "total_steps"]

    runs = api.runs(
        f"{ENTITY}/{PROJECT}",
        filters={"$and": [{"tags": method_tag}, {"tags": TAG_EVAL_PERFORMANCE}]}
    )

    best_checkpoints = {}
    for run in runs:
        history = run.history()
        if history.empty:
            print(f"Warning: Run {run.name} has no history data")
            continue

        missing = [col for col in required_cols if col not in history.columns]
        if missing:
            print(f"Warning: Run {run.name} is missing columns {missing}")
            continue

        # Score: successes - breaks. Rows that logged other keys carry NaN here;
        # drop them so idxmax never sees an all-NaN series and the chosen row
        # always has a real step to convert with int().
        history = history.assign(score=history[success_col] - history[breaks_col])
        valid = history.dropna(subset=["score", "total_steps"])
        if valid.empty:
            print(f"Warning: Run {run.name} has no rows with eval metrics")
            continue

        best_idx = valid["score"].idxmax()
        best_step = int(valid.loc[best_idx, "total_steps"])
        best_score = valid.loc[best_idx, "score"]

        best_checkpoints[run.id] = {
            "run_name": run.name,
            "best_step": best_step,
            "score": best_score,
        }
        print(f"    {run.name}: best checkpoint at step {best_step} (score: {best_score:.0f})")

    return best_checkpoints

# Query WandB and record, for every shape and method, the best checkpoint of
# each of its eval-performance runs: best_checkpoints[shape][method][run_id].
api = wandb.Api()
best_checkpoints = defaultdict(dict)

for shape in SHAPE_LEVELS:
    method_tags = SHAPE_LEVELS[shape]
    rule = "=" * 60
    print(f"\n{rule}\nShape: {shape}\n{rule}")
    for method_name in method_tags:
        method_tag = method_tags[method_name]
        print(f"\n  {method_name} ({method_tag}):")
        best_checkpoints[shape][method_name] = get_best_checkpoint_per_run(api, method_tag)

In [None]:
# ============================================================
# BLOCK 3: DOWNLOAD DATA
# ============================================================

def _first_logged_value(rows, column):
    """Return the first non-NaN value of ``column`` in ``rows``, or None if absent."""
    if column not in rows.columns:
        return None
    logged = rows[column].dropna()
    return None if logged.empty else logged.iloc[0]


def download_eval_noise_data(api, method_tag, best_checkpoints):
    """Download eval_noise metrics at each agent's best checkpoint.

    Finds every run tagged with ``method_tag`` and ``TAG_EVAL_NOISE``. Each run
    is paired with the eval-performance run of the same agent; agents are
    matched by the suffix after the last ``_`` in the run name. For that
    agent's best ``total_steps``, it reads the success / break / total-episode
    counts for every noise range in ``NOISE_LEVELS``.

    Args:
        api: A ``wandb.Api`` instance.
        method_tag: WandB tag identifying the method/shape combination.
        best_checkpoints: Output of ``get_best_checkpoint_per_run`` for the
            same method (run id -> {"run_name", "best_step", ...}).

    Returns:
        DataFrame with one row per (run, noise level) and columns run_id,
        run_name, checkpoint, noise_level, success, breaks, total. Runs with
        no matching agent, no history, or no row at the best step are skipped
        with a warning, as are noise levels whose metrics were not logged.
    """
    runs = api.runs(
        f"{ENTITY}/{PROJECT}",
        filters={"$and": [{"tags": method_tag}, {"tags": TAG_EVAL_NOISE}]}
    )

    # Agent suffix -> best checkpoint step chosen from the performance runs.
    checkpoint_by_agent = {
        info["run_name"].rsplit("_", 1)[-1]: info["best_step"]
        for info in best_checkpoints.values()
    }

    data = []
    for run in runs:
        agent_num = run.name.rsplit("_", 1)[-1]
        if agent_num not in checkpoint_by_agent:
            print(f"Warning: No matching performance run for agent {agent_num} ({run.name})")
            continue

        best_step = checkpoint_by_agent[agent_num]
        history = run.history()
        if history.empty or "total_steps" not in history.columns:
            print(f"Warning: Run {run.name} has no history data")
            continue

        # Several history rows may share the same total_steps, each carrying a
        # subset of the logged keys, so keep all of them rather than only the first.
        matches = history[history["total_steps"] == best_step]
        if matches.empty:
            print(f"Warning: Checkpoint {best_step} not found in {run.name}")
            continue

        for noise_label, noise_range in NOISE_LEVELS.items():
            prefix = f"Noise_Eval({noise_range})_Core"
            success = _first_logged_value(matches, f"{prefix}/{METRIC_SUCCESS}")
            breaks = _first_logged_value(matches, f"{prefix}/{METRIC_BREAKS}")
            total = _first_logged_value(matches, f"{prefix}/{METRIC_TOTAL}")
            if success is None or breaks is None or total is None:
                print(f"Warning: Missing {prefix} metrics at step {best_step} in {run.name}")
                continue
            data.append({
                "run_id": run.id,
                "run_name": run.name,
                "checkpoint": best_step,
                "noise_level": noise_label,
                "success": success,
                "breaks": breaks,
                "total": total,
            })

    return pd.DataFrame(data)

# Fetch eval_noise results for every (shape, method) pair:
# noise_data[shape][method] -> DataFrame with one row per (run, noise level).
noise_data = defaultdict(dict)
separator = "=" * 60

for shape, method_tags in SHAPE_LEVELS.items():
    print(f"\n{separator}\nDownloading data for Shape: {shape}\n{separator}")
    for method_name, method_tag in method_tags.items():
        print(f"\n  {method_name}...")
        method_checkpoints = best_checkpoints[shape][method_name]
        noise_data[shape][method_name] = download_eval_noise_data(api, method_tag, method_checkpoints)

# Report how many runs contributed data for each shape/method.
print(f"\n{separator}\nDATA SUMMARY\n{separator}")
for shape, method_tags in SHAPE_LEVELS.items():
    print(f"\n{shape}:")
    for method_name in method_tags:
        df = noise_data[shape][method_name]
        summary = "No data" if df.empty else f"{df['run_name'].nunique()} runs"
        print(f"  {method_name}: {summary}")

In [None]:
# ============================================================
# BLOCK 4: SHAPE COMPARISON FIGURE
# ============================================================

# --- Policy selection ------------------------------------------------------
# An integer N keeps only the N best-scoring policies per method; None keeps all.
TOP_N_POLICIES = None

# --- Highlight box ----------------------------------------------------------
HIGHLIGHT_SHAPE_PLOT = "circle"  # shape whose column gets the box; None disables it
HIGHLIGHT_COLOR = "gold"
HIGHLIGHT_LINEWIDTH = 3

# --- N/A panels -------------------------------------------------------------
# Shape keys whose break-rate panel shows NA_TEXT instead of bars.
NA_SHAPES = []
NA_TEXT = "N/A"

# --- Figure geometry --------------------------------------------------------
FIGSIZE = (16, 7)  # (width, height) in inches; extra height leaves room for icons
DPI = 150
BAR_WIDTH = 0.25

# --- Shape icons ------------------------------------------------------------
SHAPE_ICON_SIZE = 0.12  # side of the icon box, in figure coordinates
SHAPE_ICON_COLOR = "#333333"  # dark gray fill
SHAPE_ICON_EDGE_COLOR = "black"
SHAPE_ICON_LINEWIDTH = 1.5

# --- Method colors (green / orange / blue) ----------------------------------
COLORS = {"Pose": "#2ca02c", "Hybrid-Basic": "#ff7f0e", "LCLoP": "#1f77b4"}

# --- Font sizes (suptitle, column title, axis label, ticks, legend, N/A) -----
FONT_SUPTITLE, FONT_TITLE, FONT_AXIS_LABEL = 16, 11, 10
FONT_TICK, FONT_LEGEND, FONT_NA = 9, 9, 12

# --- Axis ranges and ticks --------------------------------------------------
SUCCESS_Y_LIM = (0, 100)
SUCCESS_Y_TICKS = list(range(0, 101, 20))
BREAK_Y_LIM = (0, 25)
BREAK_Y_TICKS = list(range(0, 26, 5))

# --- Text labels ------------------------------------------------------------
SUPTITLE = "Performance vs Position Noise Across Peg Shapes"
X_LABEL = "Position Noise"
SUCCESS_Y_LABEL = "Success Rate (%)"
BREAK_Y_LABEL = "Break Rate (%)"

# ============================================================

def draw_shape_icon(ax, shape_key, x, y, size, shape_icons_config):
    """Draw a filled cross-section icon for ``shape_key`` centred at (x, y).

    ``x``, ``y`` and ``size`` are in *figure* coordinates: the icon lives in an
    inset axes placed with ``ax.figure.transFigure``, not in ``ax``'s data or
    axes coordinates.

    Args:
        ax: Axes that will own the inset axes holding the icon.
        shape_key: Key looked up in ``shape_icons_config``.
        x, y: Centre of the icon box in figure coordinates.
        size: Side length of the square icon box in figure coordinates.
        shape_icons_config: Mapping ``shape_key -> (shape_type, kwargs)``.
            Supported shape_type values:
            "circle", "square",
            "rectangle" (kwarg ``aspect`` = height/width, default 0.5),
            "oval" (``aspect`` = height/width, default 0.6),
            "polygon" (``num_sides``, default 6),
            and "hexagon" (a 6-sided polygon).

    Returns:
        The inset axes containing the icon, or None if ``shape_key`` is not
        configured or its shape_type is unsupported. In that case nothing is
        added to the figure.
    """
    if shape_key not in shape_icons_config:
        return None

    shape_type, kwargs = shape_icons_config[shape_key]
    style = {
        "facecolor": SHAPE_ICON_COLOR,
        "edgecolor": SHAPE_ICON_EDGE_COLOR,
        "linewidth": SHAPE_ICON_LINEWIDTH,
    }

    # Build the patch first so an unsupported type returns before any inset
    # axes is created (otherwise an empty, invisible axes would be left behind).
    if shape_type == "circle":
        patch = Circle((0, 0), 1, **style)
    elif shape_type == "square":
        patch = Rectangle((-0.9, -0.9), 1.8, 1.8, **style)
    elif shape_type == "rectangle":
        aspect = kwargs.get("aspect", 0.5)
        patch = Rectangle((-0.9, -0.9 * aspect), 1.8, 1.8 * aspect, **style)
    elif shape_type == "oval":
        aspect = kwargs.get("aspect", 0.6)
        patch = Ellipse((0, 0), 2, 2 * aspect, **style)
    elif shape_type in ("polygon", "hexagon"):
        num_sides = 6 if shape_type == "hexagon" else kwargs.get("num_sides", 6)
        patch = RegularPolygon((0, 0), num_sides, radius=1, **style)
    else:
        return None

    # Inset axes positioned in figure coordinates; equal aspect keeps the
    # icon undistorted regardless of the figure's width/height ratio.
    inset_ax = ax.inset_axes([x - size / 2, y - size / 2, size, size], transform=ax.figure.transFigure)
    inset_ax.set_xlim(-1.2, 1.2)
    inset_ax.set_ylim(-1.2, 1.2)
    inset_ax.set_aspect('equal')
    inset_ax.axis('off')

    inset_ax.add_patch(patch)
    return inset_ax

def filter_top_n_runs(df, method_best_checkpoints, top_n):
    """Keep only the rows of ``df`` that belong to the ``top_n`` best-scoring agents.

    Agents are ranked by the "score" stored in ``method_best_checkpoints``.
    Rows are matched to agents by the suffix after the last ``_`` in the run
    name, because performance runs and noise runs share that suffix.

    ``df`` is returned unchanged when it is empty, when ``top_n`` is None, or
    when there are no more than ``top_n`` agents to choose from.
    """
    if df.empty or top_n is None or len(method_best_checkpoints) <= top_n:
        return df

    ranked = sorted(method_best_checkpoints.values(), key=lambda info: info["score"], reverse=True)
    kept_agents = {info["run_name"].rsplit("_", 1)[-1] for info in ranked[:top_n]}

    in_top_n = df["run_name"].map(lambda name: name.rsplit("_", 1)[-1] in kept_agents)
    return df[in_top_n]

def compute_rates(df, noise_labels, metric="success"):
    """Compute the percentage rate of ``metric`` for each noise level.

    For every label in ``noise_labels`` the rate is
    ``100 * sum(df[metric]) / sum(df["total"])`` over the rows whose
    ``noise_level`` equals that label.

    Args:
        df: DataFrame with ``noise_level``, ``total`` and ``metric`` columns.
        noise_labels: Noise-level labels, in the order rates are returned.
        metric: Column to aggregate, e.g. "success" or "breaks".

    Returns:
        List with one rate (percent) per label. These cases get 0 instead of
        the NaN/inf a division by zero would produce:
        - a label with no rows;
        - a label whose episode total is zero or missing.
        An empty ``df`` yields all zeros.
    """
    if df.empty:
        return [0] * len(noise_labels)

    rates = []
    for noise_label in noise_labels:
        subset = df[df["noise_level"] == noise_label]
        total = subset["total"].sum()
        # `not total > 0` also catches a NaN total (NaN > 0 is False).
        if subset.empty or not total > 0:
            rates.append(0)
        else:
            rates.append(100 * subset[metric].sum() / total)
    return rates

# Setup
shape_keys = list(SHAPE_LEVELS.keys())
# Method order is taken from the first shape; every shape is expected to
# define the same method names (noise_data lookups below would KeyError otherwise).
method_names = list(SHAPE_LEVELS[shape_keys[0]].keys())
noise_labels = list(NOISE_LEVELS.keys())
n_shapes = len(shape_keys)
x = np.arange(len(noise_labels))

# 2 x n_shapes grid: success rate on top, break rate below. squeeze=False
# keeps `axes` 2-D even when only one shape is configured, so axes[row, col]
# indexing always works.
fig, axes = plt.subplots(2, n_shapes, figsize=FIGSIZE, dpi=DPI, squeeze=False)
fig.suptitle(SUPTITLE, fontsize=FONT_SUPTITLE, y=0.98)

# Plot each shape
for col, shape in enumerate(shape_keys):
    ax_success = axes[0, col]
    ax_break = axes[1, col]
    is_na_shape = shape in NA_SHAPES

    # Grouped bars: one bar per method at every noise level
    for i, method_name in enumerate(method_names):
        df = noise_data[shape][method_name]
        df = filter_top_n_runs(df, best_checkpoints[shape][method_name], TOP_N_POLICIES)

        # Shift each method's bar so the group is centred on the tick.
        offset = (i - len(method_names) / 2 + 0.5) * BAR_WIDTH

        success_rates = compute_rates(df, noise_labels, "success")
        ax_success.bar(x + offset, success_rates, BAR_WIDTH,
                       label=method_name, color=COLORS[method_name])

        # Only plot break rates if not an N/A shape
        if not is_na_shape:
            break_rates = compute_rates(df, noise_labels, "breaks")
            ax_break.bar(x + offset, break_rates, BAR_WIDTH,
                         label=method_name, color=COLORS[method_name])

    # Configure success rate subplot (the shape title/icon is added after layout)
    ax_success.set_xticks(x)
    ax_success.set_xticklabels(noise_labels, fontsize=FONT_TICK)
    ax_success.set_ylim(SUCCESS_Y_LIM)
    ax_success.set_yticks(SUCCESS_Y_TICKS)
    ax_success.tick_params(axis='y', labelsize=FONT_TICK)

    # Configure break rate subplot
    if is_na_shape:
        # Clear the ticks and show the N/A placeholder text instead of bars
        ax_break.set_xticks([])
        ax_break.set_yticks([])
        ax_break.text(0.5, 0.5, NA_TEXT, transform=ax_break.transAxes,
                      fontsize=FONT_NA, ha='center', va='center',
                      style='italic', color='gray')
        # Keep the spines for visual consistency
        for spine in ax_break.spines.values():
            spine.set_visible(True)
    else:
        ax_break.set_xlabel(X_LABEL, fontsize=FONT_AXIS_LABEL)
        ax_break.set_xticks(x)
        ax_break.set_xticklabels(noise_labels, fontsize=FONT_TICK)
        ax_break.set_ylim(BREAK_Y_LIM)
        ax_break.set_yticks(BREAK_Y_TICKS)
        ax_break.tick_params(axis='y', labelsize=FONT_TICK)

    # Y-axis labels and the legend only on the leftmost column
    if col == 0:
        ax_success.set_ylabel(SUCCESS_Y_LABEL, fontsize=FONT_AXIS_LABEL)
        ax_break.set_ylabel(BREAK_Y_LABEL, fontsize=FONT_AXIS_LABEL)
        ax_success.legend(fontsize=FONT_LEGEND, loc='upper left')

    # Gold highlight on all spines of both panels of the highlighted shape
    if HIGHLIGHT_SHAPE_PLOT is not None and shape == HIGHLIGHT_SHAPE_PLOT:
        for spine in ['top', 'left', 'right', 'bottom']:
            ax_success.spines[spine].set_color(HIGHLIGHT_COLOR)
            ax_success.spines[spine].set_linewidth(HIGHLIGHT_LINEWIDTH)
            ax_break.spines[spine].set_color(HIGHLIGHT_COLOR)
            ax_break.spines[spine].set_linewidth(HIGHLIGHT_LINEWIDTH)

# Reserve a header strip above the top row for each column's name and icon.
# All sizes are figure fractions. Going up from the top edge of a
# success-rate axes, the header stacks:
#   1. a small pad;
#   2. the bold shape name;
#   3. a gap;
#   4. the icon, which is SHAPE_ICON_SIZE tall.
# Above that, room is kept for the suptitle, which is drawn top-aligned at y=0.98.
fig_height_in = fig.get_figheight()
TITLE_PAD = 0.008                                       # axes top -> shape name
title_height = 1.3 * FONT_TITLE / 72 / fig_height_in    # approx. one line of FONT_TITLE text
ICON_PAD = 0.005                                        # shape name -> icon
suptitle_space = (1 - 0.98) + 1.3 * FONT_SUPTITLE / 72 / fig_height_in + 0.01
header_height = TITLE_PAD + title_height + ICON_PAD + SHAPE_ICON_SIZE

# Lay out the subplots below the reserved header strip
plt.tight_layout(rect=[0, 0, 1, 1 - suptitle_space - header_height])

# Force a draw so get_position() reflects the tight_layout result
fig.canvas.draw()

# Add shape names and icons above each column
for col, shape in enumerate(shape_keys):
    ax_success = axes[0, col]
    display_name = SHAPE_DISPLAY_NAMES.get(shape, shape)

    # Position of the success-rate subplot in figure coordinates
    bbox = ax_success.get_position()
    center_x = bbox.x0 + bbox.width / 2

    # Shape name sits just above the axes (va='bottom', so it grows upward
    # instead of hanging into the plot area).
    text_y = bbox.y1 + TITLE_PAD
    fig.text(center_x, text_y, display_name, ha='center', va='bottom',
             fontsize=FONT_TITLE, fontweight='bold')

    # Icon is centred above the name; with SHAPE_ICON_SIZE as its height it
    # ends at bbox.y1 + header_height, below the suptitle area.
    icon_y = text_y + title_height + ICON_PAD + SHAPE_ICON_SIZE / 2
    draw_shape_icon(ax_success, shape, center_x, icon_y, SHAPE_ICON_SIZE, SHAPE_ICONS)

plt.show()