# Dynamics Robustness Analysis

This notebook generates:
1. Success Rate heatmaps (friction x mass) per method
2. Break Rate heatmaps (friction x mass) per method

In [1]:
# ============================================================
# BLOCK 1: IMPORTS & CONSTANTS
# ============================================================

import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from collections import defaultdict

# WandB Configuration
ENTITY = "hur"
PROJECT = "SG_Exps"

# Method Tags
METHOD_TAGS = {
    "Pose": "pose_task_frag:2026-01-06_00:52",
    #"Pose" : "pose_25mm-noise:2026-01-19_07:50_15N",
    #"Pose-7.5mm":"pose_75mm-noise:2026-01-17_19:18",
    "MATCH": "LCLoP_task_frag:2026-01-06_00:27",
    "Hybrid-Basic": "basic-hybrid_task_frag:2026-01-06_00:56",
}

# Method Colors - set colors for each method here
""" THE COLORS KEY SHOULD NOT BE CHANGED FOR ANY REASON WITHOUT USER PERMISSION DO NOT OVERWRITE!!!"""
COLORS = {
    "Pose": "#2ca02c",        # Green
    "Hybrid-Basic": "#ff7f0e", # Orange
    "MATCH": "#1f77b4",       # Blue
}

# Evaluation Tags
TAG_EVAL_PERFORMANCE = "eval_performance"
TAG_EVAL_DYNAMICS = "eval_dynamics"

# Dynamics Parameters
FRIC_VALUES = ["0.5x", "0.75x", "1.0x"]
MASS_VALUES = ["0.5x", "1.0x", "1.5x", "2.0x"]

# Metrics
METRIC_SUCCESS = "num_successful_completions"
METRIC_BREAKS = "num_breaks"
METRIC_TOTAL = "total_episodes"

In [2]:
# ============================================================
# BLOCK 2: DETERMINE BEST POLICY
# ============================================================

def get_best_checkpoint_per_run(api, method_tag):
    """Find the best checkpoint for each run with the given method tag."""
    runs = api.runs(
        f"{ENTITY}/{PROJECT}",
        filters={"$and": [{"tags": method_tag}, {"tags": TAG_EVAL_PERFORMANCE}]}
    )
    
    best_checkpoints = {}
    for run in runs:
        history = run.history()
        if history.empty:
            print(f"Warning: Run {run.name} has no history data")
            continue
        
        # Calculate score: successes - breaks
        history["score"] = history[f"Eval_Core/{METRIC_SUCCESS}"] - history[f"Eval_Core/{METRIC_BREAKS}"]
        best_idx = history["score"].idxmax()
        best_step = int(history.loc[best_idx, "total_steps"])
        
        best_checkpoints[run.id] = {
            "run_name": run.name,
            "best_step": best_step,
            "score": history.loc[best_idx, "score"],
        }
        print(f"  {run.name}: best checkpoint at step {best_step} (score: {history.loc[best_idx, 'score']:.0f})")
    
    return best_checkpoints

# Get best checkpoints for each method
api = wandb.Api()
best_checkpoints_by_method = {}

for method_name, method_tag in METHOD_TAGS.items():
    print(f"\n{method_name} ({method_tag}):")
    best_checkpoints_by_method[method_name] = get_best_checkpoint_per_run(api, method_tag)

[34m[1mwandb[0m: Currently logged in as: [33mrobonuke[0m ([33mhur[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



Pose (pose_task_frag:2026-01-06_00:52):
  Eval_performance_pose_task_frag_f(15)_0: best checkpoint at step 1152000 (score: 100)
  Eval_performance_pose_task_frag_f(15)_1: best checkpoint at step 2688000 (score: 98)
  Eval_performance_pose_task_frag_f(15)_2: best checkpoint at step 2649600 (score: 100)
  Eval_performance_pose_task_frag_f(15)_3: best checkpoint at step 1651200 (score: 98)
  Eval_performance_pose_task_frag_f(15)_4: best checkpoint at step 2995200 (score: 99)

MATCH (LCLoP_task_frag:2026-01-06_00:27):
  Eval_performance_LCLoP_task_frag_f(15)_0: best checkpoint at step 2380800 (score: 97)
  Eval_performance_LCLoP_task_frag_f(15)_1: best checkpoint at step 2880000 (score: 95)
  Eval_performance_LCLoP_task_frag_f(15)_3: best checkpoint at step 2995200 (score: 92)
  Eval_performance_LCLoP_task_frag_f(15)_2: best checkpoint at step 2841600 (score: 100)
  Eval_performance_LCLoP_task_frag_f(15)_4: best checkpoint at step 2611200 (score: 98)

Hybrid-Basic (basic-hybrid_task_frag:

In [3]:
# ============================================================
# BLOCK 3: DOWNLOAD DATA
# ============================================================

def download_eval_dynamics_data(api, method_tag, best_checkpoints):
    """Download eval_dynamics data for best checkpoints across all fric x mass combinations."""
    runs = api.runs(
        f"{ENTITY}/{PROJECT}",
        filters={"$and": [{"tags": method_tag}, {"tags": TAG_EVAL_DYNAMICS}]}
    )

    # Build lookup by agent number from best_checkpoints
    checkpoint_by_agent = {}
    for run_id, info in best_checkpoints.items():
        agent_num = info["run_name"].rsplit("_", 1)[-1]
        checkpoint_by_agent[agent_num] = info["best_step"]

    data = []
    for run in runs:
        # Extract agent number from run name
        agent_num = run.name.rsplit("_", 1)[-1]

        if agent_num not in checkpoint_by_agent:
            print(f"Warning: No matching performance run for agent {agent_num} ({run.name})")
            continue

        best_step = checkpoint_by_agent[agent_num]
        history = run.history()
        
        if best_step not in history["total_steps"].values:
            print(f"Warning: Checkpoint {best_step} not found in {run.name}")
            continue
        
        row = history[history["total_steps"] == best_step].iloc[0]
        
        for fric in FRIC_VALUES:
            for mass in MASS_VALUES:
                prefix = f"Dyn_Eval(fric={fric},mass={mass})_Core"
                data.append({
                    "run_id": run.id,
                    "run_name": run.name,
                    "checkpoint": best_step,
                    "fric_level": fric,
                    "mass_level": mass,
                    "success": row[f"{prefix}/{METRIC_SUCCESS}"],
                    "breaks": row[f"{prefix}/{METRIC_BREAKS}"],
                    "total": row[f"{prefix}/{METRIC_TOTAL}"],
                })
    
    return pd.DataFrame(data)

# Download all data
dynamics_data = {}

for method_name, method_tag in METHOD_TAGS.items():
    print(f"\nDownloading data for {method_name}...")
    dynamics_data[method_name] = download_eval_dynamics_data(
        api, method_tag, best_checkpoints_by_method[method_name]
    )

# Print summary
print("\n" + "="*60)
print("DYNAMICS DATA SUMMARY")
print("="*60)
for method_name, df in dynamics_data.items():
    print(f"\n{method_name}:")
    for fric in FRIC_VALUES:
        for mass in MASS_VALUES:
            subset = df[(df["fric_level"] == fric) & (df["mass_level"] == mass)]
            if not subset.empty:
                total = subset["total"].sum()
                success_rate = 100 * subset["success"].sum() / total
                break_rate = 100 * subset["breaks"].sum() / total
                print(f"  fric={fric}, mass={mass}: Success={success_rate:.1f}%, Break={break_rate:.1f}%")


Downloading data for Pose...

Downloading data for MATCH...

Downloading data for Hybrid-Basic...

DYNAMICS DATA SUMMARY

Pose:


KeyError: 'fric_level'

In [None]:
# ============================================================
# BLOCK 4: SUCCESS RATE HEATMAP (FRICTION x MASS)
# ============================================================

# Policy Selection
TOP_N_POLICIES = None  # Set to integer (e.g., 3) to use top N policies, or None for all

# Plot Constants
FIGSIZE_PER_METHOD = (5, 4)  # Size per individual heatmap
DPI = 150

# Font sizes
FONT_TITLE = 14
FONT_SUPTITLE = 16
FONT_AXIS_LABEL = 12
FONT_TICK = 10
FONT_CELL = 11
FONT_COLORBAR = 10

# Colormap for success rate: red (low) -> green (high)
# Alternative colormaps to try:
#   CMAP = "viridis"       # Perceptually uniform: purple -> yellow
#   CMAP = "plasma"        # Perceptually uniform: purple -> yellow (warmer)
#   CMAP = "YlGn"          # Sequential: yellow -> green
#   CMAP = "coolwarm"      # Diverging: blue -> red
#   CMAP = "RdYlBu"        # Diverging: red -> blue
CMAP = "RdYlGn"

# Value range for color normalization
VMIN = 0
VMAX = 100

# Labels
SUPTITLE = "Success Rate vs Dynamics Parameters"
X_LABEL = "Friction"
Y_LABEL = "Mass"

# ============================================================

def filter_top_n_runs(df, best_checkpoints, top_n):
    """Filter dataframe to only include top N runs by score."""
    if top_n is None or len(best_checkpoints) <= top_n:
        return df
    sorted_runs = sorted(best_checkpoints.items(), key=lambda x: x[1]["score"], reverse=True)
    top_run_names = {info["run_name"] for _, info in sorted_runs[:top_n]}
    top_agent_nums = {name.rsplit("_", 1)[-1] for name in top_run_names}
    return df[df["run_name"].apply(lambda x: x.rsplit("_", 1)[-1] in top_agent_nums)]

def build_heatmap_grid(df, fric_values, mass_values, metric, total_col="total"):
    """Build a 2D array of mean rates for the heatmap.
    
    Returns grid with shape (len(mass_values), len(fric_values)).
    Rows = mass (y-axis), Columns = fric (x-axis).
    """
    grid = np.zeros((len(mass_values), len(fric_values)))
    for mi, mass in enumerate(mass_values):
        for fi, fric in enumerate(fric_values):
            subset = df[(df["fric_level"] == fric) & (df["mass_level"] == mass)]
            if not subset.empty:
                subset = subset.copy()
                subset["rate"] = 100 * subset[metric] / subset[total_col]
                grid[mi, fi] = subset["rate"].mean()
    return grid

method_names = list(METHOD_TAGS.keys())
n_methods = len(method_names)

fig, axes = plt.subplots(1, n_methods, figsize=(FIGSIZE_PER_METHOD[0] * n_methods, FIGSIZE_PER_METHOD[1]), dpi=DPI)
if n_methods == 1:
    axes = [axes]

norm = mcolors.Normalize(vmin=VMIN, vmax=VMAX)

for ax, method_name in zip(axes, method_names):
    df = dynamics_data[method_name]
    df = filter_top_n_runs(df, best_checkpoints_by_method[method_name], TOP_N_POLICIES)
    
    grid = build_heatmap_grid(df, FRIC_VALUES, MASS_VALUES, "success")
    
    im = ax.imshow(grid, cmap=CMAP, norm=norm, aspect="auto", origin="lower")
    
    # Add text annotations
    for mi in range(len(MASS_VALUES)):
        for fi in range(len(FRIC_VALUES)):
            val = grid[mi, fi]
            # Use black text on light backgrounds, white on dark
            text_color = "white" if val < 40 else "black"
            ax.text(fi, mi, f"{val:.1f}", ha="center", va="center",
                    fontsize=FONT_CELL, fontweight="bold", color=text_color)
    
    ax.set_xticks(range(len(FRIC_VALUES)))
    ax.set_xticklabels(FRIC_VALUES, fontsize=FONT_TICK)
    ax.set_yticks(range(len(MASS_VALUES)))
    ax.set_yticklabels(MASS_VALUES, fontsize=FONT_TICK)
    ax.set_xlabel(X_LABEL, fontsize=FONT_AXIS_LABEL)
    ax.set_ylabel(Y_LABEL, fontsize=FONT_AXIS_LABEL)
    ax.set_title(method_name, fontsize=FONT_TITLE)

fig.suptitle(SUPTITLE, fontsize=FONT_SUPTITLE, y=1.02)
cbar = fig.colorbar(im, ax=axes, shrink=0.8, pad=0.04)
cbar.set_label("Success Rate (%)", fontsize=FONT_COLORBAR)

plt.tight_layout()
plt.show()

In [None]:
# ============================================================
# BLOCK 5: BREAK RATE HEATMAP (FRICTION x MASS)
# ============================================================

# Policy Selection
TOP_N_POLICIES = None  # Set to integer (e.g., 3) to use top N policies, or None for all

# Plot Constants
FIGSIZE_PER_METHOD = (5, 4)  # Size per individual heatmap
DPI = 150

# Font sizes
FONT_TITLE = 14
FONT_SUPTITLE = 16
FONT_AXIS_LABEL = 12
FONT_TICK = 10
FONT_CELL = 11
FONT_COLORBAR = 10

# Colormap for break rate: green (low) -> red (high)
# Alternative colormaps to try:
#   CMAP = "viridis"       # Perceptually uniform: purple -> yellow
#   CMAP = "plasma"        # Perceptually uniform: purple -> yellow (warmer)
#   CMAP = "YlOrRd"        # Sequential: yellow -> red
#   CMAP = "coolwarm"      # Diverging: blue -> red
#   CMAP = "RdYlBu_r"      # Diverging: blue -> red (reversed)
CMAP = "RdYlGn_r"

# Value range for color normalization
VMIN = 0
VMAX = 100

# Labels
SUPTITLE = "Break Rate vs Dynamics Parameters"
X_LABEL = "Friction"
Y_LABEL = "Mass"

# ============================================================

def filter_top_n_runs(df, best_checkpoints, top_n):
    """Filter dataframe to only include top N runs by score."""
    if top_n is None or len(best_checkpoints) <= top_n:
        return df
    sorted_runs = sorted(best_checkpoints.items(), key=lambda x: x[1]["score"], reverse=True)
    top_run_names = {info["run_name"] for _, info in sorted_runs[:top_n]}
    top_agent_nums = {name.rsplit("_", 1)[-1] for name in top_run_names}
    return df[df["run_name"].apply(lambda x: x.rsplit("_", 1)[-1] in top_agent_nums)]

def build_heatmap_grid(df, fric_values, mass_values, metric, total_col="total"):
    """Build a 2D array of mean rates for the heatmap.
    
    Returns grid with shape (len(mass_values), len(fric_values)).
    Rows = mass (y-axis), Columns = fric (x-axis).
    """
    grid = np.zeros((len(mass_values), len(fric_values)))
    for mi, mass in enumerate(mass_values):
        for fi, fric in enumerate(fric_values):
            subset = df[(df["fric_level"] == fric) & (df["mass_level"] == mass)]
            if not subset.empty:
                subset = subset.copy()
                subset["rate"] = 100 * subset[metric] / subset[total_col]
                grid[mi, fi] = subset["rate"].mean()
    return grid

method_names = list(METHOD_TAGS.keys())
n_methods = len(method_names)

fig, axes = plt.subplots(1, n_methods, figsize=(FIGSIZE_PER_METHOD[0] * n_methods, FIGSIZE_PER_METHOD[1]), dpi=DPI)
if n_methods == 1:
    axes = [axes]

norm = mcolors.Normalize(vmin=VMIN, vmax=VMAX)

for ax, method_name in zip(axes, method_names):
    df = dynamics_data[method_name]
    df = filter_top_n_runs(df, best_checkpoints_by_method[method_name], TOP_N_POLICIES)
    
    grid = build_heatmap_grid(df, FRIC_VALUES, MASS_VALUES, "breaks")
    
    im = ax.imshow(grid, cmap=CMAP, norm=norm, aspect="auto", origin="lower")
    
    # Add text annotations
    for mi in range(len(MASS_VALUES)):
        for fi in range(len(FRIC_VALUES)):
            val = grid[mi, fi]
            # Use black text on light backgrounds, white on dark
            text_color = "white" if val > 60 else "black"
            ax.text(fi, mi, f"{val:.1f}", ha="center", va="center",
                    fontsize=FONT_CELL, fontweight="bold", color=text_color)
    
    ax.set_xticks(range(len(FRIC_VALUES)))
    ax.set_xticklabels(FRIC_VALUES, fontsize=FONT_TICK)
    ax.set_yticks(range(len(MASS_VALUES)))
    ax.set_yticklabels(MASS_VALUES, fontsize=FONT_TICK)
    ax.set_xlabel(X_LABEL, fontsize=FONT_AXIS_LABEL)
    ax.set_ylabel(Y_LABEL, fontsize=FONT_AXIS_LABEL)
    ax.set_title(method_name, fontsize=FONT_TITLE)

fig.suptitle(SUPTITLE, fontsize=FONT_SUPTITLE, y=1.02)
cbar = fig.colorbar(im, ax=axes, shrink=0.8, pad=0.04)
cbar.set_label("Break Rate (%)", fontsize=FONT_COLORBAR)

plt.tight_layout()
plt.show()