# Compare Sweeps to Baselines

General-purpose notebook: fill in one or more sweep IDs and compare against all available baselines.

Outputs:
- Per-task learning curves (grid view)
- Individual task plots
- Normalized aggregate performance
- Summary table at configurable eval step

In [105]:
import sys
from pathlib import Path

# Add analysis tools to path
NOTEBOOK_DIR = Path.cwd()
# Walk up until we find the analysis root (parent of 'notebooks' or 'notebooks/general')
ANALYSIS_ROOT = NOTEBOOK_DIR
while ANALYSIS_ROOT.name in ("notebooks", "general") or (ANALYSIS_ROOT / "tools").exists() is False:
    ANALYSIS_ROOT = ANALYSIS_ROOT.parent
    if ANALYSIS_ROOT == ANALYSIS_ROOT.parent:
        raise RuntimeError("Could not find analysis root")
if str(ANALYSIS_ROOT) not in sys.path:
    sys.path.insert(0, str(ANALYSIS_ROOT))

import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from tools import wandb_io, config, paths
from tools.aggregations import runs_history_to_frame

## Configuration

Edit the cells below to configure your comparison.

In [106]:
# ============================================================
# SWEEP CONFIG — edit these
# ============================================================
SWEEP_IDS: list[str] = [
    # "abc12345",  # Add your sweep IDs here
    "aesenloq"

]
ENTITY = "thomasevers9"
PROJECT = "tdmpc2-tdmpc2"

# ============================================================
# PLOT CONFIG
# ============================================================
MAX_STEP = 200_000            # Max env step to plot
EVAL_STEP = 200_000           # Step at which to evaluate for summary table
METRIC_KEY = "eval/episode_reward"
OUR_LABEL = "Ours"            # Legend label for sweep results

# ============================================================
# TASK NAME MAPPING: sweep task name -> baseline CSV name
# Add entries when the sweep uses a different name than the
# baseline CSVs.  Keys not present here are used as-is.
# ============================================================
TASK_NAME_MAP: dict[str, str] = {
    # DM-Control
    "ball_in_cup-catch": "cup-catch",
    "finger-turn_easy": "finger-turn-easy",
    "finger-turn_hard": "finger-turn-hard",
    "cartpole-balance_sparse": "cartpole-balance-sparse",
    "cartpole-swingup_sparse": "cartpole-swingup-sparse",
    # Humanoid-Bench
    "humanoid_h1-slide-v0": "h1-slide-v0",
    "humanoid_h1-walk-v0": "h1-walk-v0",
    "humanoid_h1-balance_simple-v0": "h1-balance-simple-v0",
    "humanoid_h1-balance_hard-v0": "h1-balance-hard-v0",
    "humanoid_h1-run-v0": "h1-run-v0",
    "humanoid_h1-stand-v0": "h1-stand-v0",
}

# BMPC uses h1hand-* naming for humanoidbench
BMPC_TASK_MAP: dict[str, str] = {
    "h1-slide-v0": "h1hand-slide-v0",
    "h1-walk-v0": "h1hand-walk-v0",
    "h1-balance-simple-v0": "h1hand-balance_simple-v0",
    "h1-balance-hard-v0": "h1hand-balance_hard-v0",
    "h1-run-v0": "h1hand-run-v0",
    "h1-stand-v0": "h1hand-stand-v0",
}

# ============================================================
# TASK CATEGORIES — used for per-category aggregate plots.
# A task is auto-classified based on prefix matching.
# Only tasks present in the sweep are included.
# ============================================================
EASY_DMC_PREFIXES = [
    "acrobot", "cartpole", "cup", "cheetah", "finger", "fish",
    "hopper", "pendulum", "quadruped", "reacher", "walker",
]
HARD_DMC_PREFIXES = ["dog", "humanoid-"]   # humanoid-run/stand/walk (DMC)
HUMANOID_BENCH_PREFIXES = ["h1-"]           # h1-slide-v0, h1-walk-v0, etc.


def classify_task(task: str) -> str:
    """Return category name for a task based on prefix matching."""
    for prefix in HUMANOID_BENCH_PREFIXES:
        if task.startswith(prefix):
            return "Humanoid-Bench"
    for prefix in HARD_DMC_PREFIXES:
        if task.startswith(prefix):
            return "Hard DMC"
    for prefix in EASY_DMC_PREFIXES:
        if task.startswith(prefix):
            return "Easy DMC"
    return "Other"


# Methods to EXCLUDE from the total aggregate plot
AGGREGATE_EXCLUDE_METHODS: set[str] = {"EZ2"}

# ============================================================
# BASELINE ROOTS — each maps method name -> directory of CSVs
# Files inside must be named <task>.csv with columns:
#   step, reward, seed
# ============================================================
BASELINE_ROOTS: dict[str, Path] = {
    "TDMPC2":    paths.BASELINE_TDMPC2,
    "DreamerV3":  paths.BASELINE_DREAMERV3,
    "SAC":        paths.BASELINE_SAC,
    "EZ2":        paths.PROJECT_ROOT / "results" / "ez2_parsed",
    "SimbaV2":    paths.PROJECT_ROOT / "results" / "simbav2_parsed",
    "BMPC":       paths.PROJECT_ROOT / "results" / "BMPC_parsed" / "dmcontrol",
    "BMPC_humanoid": paths.PROJECT_ROOT / "results" / "BMPC_parsed" / "humanoidbench",
}

# ============================================================
# COLORS
# ============================================================
COLORS: dict[str, str] = {
    OUR_LABEL:  "#1f77b4",   # Blue
    "TDMPC2":   "#2ca02c",  # Green
    "DreamerV3": "#ff7f0e", # Orange
    "EZ2":      "#d62728",  # Red
    "SimbaV2":  "#9467bd",  # Purple
    "SAC":      "#8c564b",  # Brown
    "BMPC":     "#17becf",  # Cyan
}

assert len(SWEEP_IDS) > 0, "Add at least one sweep ID to SWEEP_IDS above."

## Load Our Results from W&B

In [107]:
# Fetch runs from all sweeps (with caching)
all_runs: list = []
for sweep_id in SWEEP_IDS:
    runs, manifest, source = wandb_io.fetch_sweep_runs(
        entity=ENTITY,
        project=PROJECT,
        sweep_id=sweep_id,
        history_keys=[METRIC_KEY, "_step"],
        use_cache=True,
        force_refresh=False,
    )
    print(f"Sweep {sweep_id}: {len(runs)} runs from {source} "
          f"(fetched: {manifest.get('fetched_at', 'N/A')})")
    all_runs.extend(runs)

print(f"\nTotal: {len(all_runs)} runs from {len(SWEEP_IDS)} sweep(s)")

Sweep aesenloq: 41 runs from cache (fetched: 2026-02-09T19:10:35.624354+00:00)

Total: 41 runs from 1 sweep(s)


In [108]:
# Convert to DataFrame
ours_df = runs_history_to_frame(
    all_runs,
    metric_key=METRIC_KEY,
    step_keys=["_step"],
    config_to_columns={"task": "task", "seed": "seed"},
)

# Normalize task names to match baseline naming
ours_df["task"] = ours_df["task"].replace(TASK_NAME_MAP)
ours_df = ours_df.rename(columns={METRIC_KEY: "reward"})
ours_df["method"] = OUR_LABEL

print(f"Our results: {len(ours_df)} rows")
print(f"Tasks: {sorted(ours_df['task'].unique().tolist())}")
print(f"Seeds: {sorted(ours_df['seed'].unique().tolist())}")
print(f"Step range: {ours_df['step'].min()} – {ours_df['step'].max()}")
ours_df.head()

Our results: 379 rows
Tasks: ['acrobot-swingup', 'cartpole-balance', 'cartpole-balance-sparse', 'cartpole-swingup', 'cartpole-swingup-sparse', 'cheetah-run', 'dog-run', 'dog-stand', 'dog-trot', 'dog-walk', 'finger-spin', 'finger-turn-easy', 'finger-turn-hard', 'fish-swim', 'h1-balance-hard-v0', 'h1-balance-simple-v0', 'h1-run-v0', 'h1-slide-v0', 'h1-stand-v0', 'hopper-hop', 'hopper-stand', 'humanoid-run', 'humanoid-stand', 'humanoid-walk', 'humanoid_h1-crawl-v0', 'humanoid_h1-hurdle-v0', 'humanoid_h1-maze-v0', 'humanoid_h1-pole-v0', 'humanoid_h1-reach-v0', 'humanoid_h1-sit_hard-v0', 'humanoid_h1-sit_simple-v0', 'humanoid_h1-stair-v0', 'pendulum-swingup', 'quadruped-run', 'quadruped-walk', 'reacher-easy', 'reacher-hard', 'walker-run', 'walker-stand', 'walker-walk']
Seeds: [1]
Step range: 10000 – 100000


Unnamed: 0,task,seed,run_id,step,reward,method
0,h1-balance-hard-v0,1,0efyfvek,10012,17.919926,Ours
1,h1-balance-hard-v0,1,0efyfvek,20013,17.932884,Ours
2,h1-balance-hard-v0,1,0efyfvek,30005,27.441183,Ours
3,h1-balance-hard-v0,1,0efyfvek,40029,32.208141,Ours
4,h1-balance-hard-v0,1,0efyfvek,50017,36.767048,Ours


## Load Baseline Results

In [109]:
def load_baseline_for_task(task: str, method: str, root: Path) -> pd.DataFrame:
    """Load baseline CSV for a single task.

    Args:
        task: Normalized task name (e.g. 'cup-catch', 'h1-slide-v0').
        method: Internal method key (may differ from display name).
        root: Directory containing per-task CSVs.

    Returns:
        DataFrame [step, reward, seed, task, method] or empty.
    """
    # BMPC_humanoid uses a different naming scheme
    if method == "BMPC_humanoid":
        csv_task = BMPC_TASK_MAP.get(task, task)
        display_method = "BMPC"
    else:
        csv_task = task
        display_method = method

    csv_path = root / f"{csv_task}.csv"
    if csv_path.exists():
        df = pd.read_csv(csv_path)
        df["task"] = task
        df["method"] = display_method
        return df
    return pd.DataFrame()


def load_all_baselines(tasks: list[str], max_step: int) -> pd.DataFrame:
    """Load all baseline results for *tasks*, up to *max_step*.

    Args:
        tasks: Normalized task names.
        max_step: Maximum environment step to include.

    Returns:
        Combined DataFrame of all baselines.
    """
    frames: list[pd.DataFrame] = []
    missing: dict[str, list[str]] = {m: [] for m in BASELINE_ROOTS}

    for method, root in BASELINE_ROOTS.items():
        if not root.exists():
            print(f"Warning: {method} root not found: {root}")
            continue
        for task in tasks:
            df = load_baseline_for_task(task, method, root)
            if df.empty:
                missing[method].append(task)
            else:
                frames.append(df[df["step"] <= max_step])

    # Report missing (skip BMPC_humanoid — merged into BMPC)
    for method, tasks_missing in missing.items():
        if method == "BMPC_humanoid":
            continue
        if tasks_missing:
            label = (f"{method}: missing {len(tasks_missing)} tasks: {tasks_missing[:5]}..."
                     if len(tasks_missing) > 5
                     else f"{method}: missing {tasks_missing}")
            print(label)

    if not frames:
        return pd.DataFrame()
    result = pd.concat(frames, ignore_index=True)
    result = result.drop_duplicates(subset=["method", "task", "step", "seed"])
    return result

In [110]:
our_tasks = sorted(ours_df["task"].unique().tolist())
print(f"Our tasks ({len(our_tasks)}): {our_tasks}")

baselines_df = load_all_baselines(our_tasks, MAX_STEP)
print(f"\nBaseline rows: {len(baselines_df)}")

if not baselines_df.empty:
    coverage = baselines_df.groupby("method")["task"].nunique()
    print("\nTask coverage per method:")
    for method, count in coverage.items():
        print(f"  {method}: {count}/{len(our_tasks)} tasks")

Our tasks (40): ['acrobot-swingup', 'cartpole-balance', 'cartpole-balance-sparse', 'cartpole-swingup', 'cartpole-swingup-sparse', 'cheetah-run', 'dog-run', 'dog-stand', 'dog-trot', 'dog-walk', 'finger-spin', 'finger-turn-easy', 'finger-turn-hard', 'fish-swim', 'h1-balance-hard-v0', 'h1-balance-simple-v0', 'h1-run-v0', 'h1-slide-v0', 'h1-stand-v0', 'hopper-hop', 'hopper-stand', 'humanoid-run', 'humanoid-stand', 'humanoid-walk', 'humanoid_h1-crawl-v0', 'humanoid_h1-hurdle-v0', 'humanoid_h1-maze-v0', 'humanoid_h1-pole-v0', 'humanoid_h1-reach-v0', 'humanoid_h1-sit_hard-v0', 'humanoid_h1-sit_simple-v0', 'humanoid_h1-stair-v0', 'pendulum-swingup', 'quadruped-run', 'quadruped-walk', 'reacher-easy', 'reacher-hard', 'walker-run', 'walker-stand', 'walker-walk']
TDMPC2: missing 13 tasks: ['h1-balance-hard-v0', 'h1-balance-simple-v0', 'h1-run-v0', 'h1-slide-v0', 'h1-stand-v0']...
DreamerV3: missing 13 tasks: ['h1-balance-hard-v0', 'h1-balance-simple-v0', 'h1-run-v0', 'h1-slide-v0', 'h1-stand-v0'].

## Combine Data

In [111]:
# Filter our results to max step and combine with baselines
ours_filtered = ours_df[ours_df["step"] <= MAX_STEP].copy()
all_data = pd.concat([ours_filtered, baselines_df], ignore_index=True)
print(f"Total rows: {len(all_data)}")
print(f"Methods: {sorted(all_data['method'].unique().tolist())}")
print(f"Tasks: {sorted(all_data['task'].unique().tolist())}")

Total rows: 3521
Methods: ['BMPC', 'DreamerV3', 'EZ2', 'Ours', 'SAC', 'SimbaV2', 'TDMPC2']
Tasks: ['acrobot-swingup', 'cartpole-balance', 'cartpole-balance-sparse', 'cartpole-swingup', 'cartpole-swingup-sparse', 'cheetah-run', 'dog-run', 'dog-stand', 'dog-trot', 'dog-walk', 'finger-spin', 'finger-turn-easy', 'finger-turn-hard', 'fish-swim', 'h1-balance-hard-v0', 'h1-balance-simple-v0', 'h1-run-v0', 'h1-slide-v0', 'h1-stand-v0', 'hopper-hop', 'hopper-stand', 'humanoid-run', 'humanoid-stand', 'humanoid-walk', 'humanoid_h1-crawl-v0', 'humanoid_h1-hurdle-v0', 'humanoid_h1-maze-v0', 'humanoid_h1-pole-v0', 'humanoid_h1-reach-v0', 'humanoid_h1-sit_hard-v0', 'humanoid_h1-sit_simple-v0', 'humanoid_h1-stair-v0', 'pendulum-swingup', 'quadruped-run', 'quadruped-walk', 'reacher-easy', 'reacher-hard', 'walker-run', 'walker-stand', 'walker-walk']


## Plotting Helpers

In [112]:
def hex_to_rgba(hex_color: str, alpha: float) -> str:
    """Convert hex color to rgba string."""
    h = hex_color.lstrip("#")
    r, g, b = (int(h[i:i+2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"


def compute_stats(df: pd.DataFrame, group_cols: list[str]) -> pd.DataFrame:
    """Compute mean / std / count of reward across seeds."""
    return (
        df.groupby(group_cols)
        .agg(mean_reward=("reward", "mean"),
             std_reward=("reward", "std"),
             n_seeds=("reward", "count"))
        .reset_index()
    )


def _ordered_methods(methods_present: list[str]) -> list[str]:
    """Return methods ordered with OUR_LABEL first, rest sorted."""
    return [OUR_LABEL] + sorted(m for m in methods_present if m != OUR_LABEL)


def _add_method_traces(
    fig: go.Figure,
    stats: pd.DataFrame,
    colors: dict[str, str],
    *,
    show_legend: bool = True,
    legend_added: set | None = None,
    row: int | None = None,
    col: int | None = None,
    fill_alpha: float = 0.2,
    line_width: float = 2.0,
) -> None:
    """Add mean +/- std traces for each method to *fig*."""
    methods = _ordered_methods(stats["method"].unique().tolist())
    kwargs = {} if row is None else {"row": row, "col": col}

    for method in methods:
        ms = stats[stats["method"] == method].sort_values("step")
        if ms.empty:
            continue

        color = colors.get(method, "#666666")
        x = ms["step"].values
        y_mean = ms["mean_reward"].values
        y_std = np.nan_to_num(ms["std_reward"].values, nan=0)

        # Shaded std region
        x_fill = np.concatenate([x, x[::-1]])
        y_fill = np.concatenate([y_mean + y_std, (y_mean - y_std)[::-1]])

        should_show = show_legend and (legend_added is None or method not in legend_added)
        if legend_added is not None:
            legend_added.add(method)

        fig.add_trace(
            go.Scatter(
                x=x_fill, y=y_fill,
                fill="toself",
                fillcolor=hex_to_rgba(color, fill_alpha),
                line=dict(color="rgba(0,0,0,0)"),
                showlegend=False,
                hoverinfo="skip",
                legendgroup=method,
            ),
            **kwargs,
        )

        line_style = dict(color=color, width=line_width)
        if method != OUR_LABEL:
            line_style["dash"] = "dash"

        fig.add_trace(
            go.Scatter(
                x=x, y=y_mean,
                mode="lines",
                name=method,
                line=line_style,
                showlegend=should_show,
                legendgroup=method,
                hovertemplate=f"{method}<br>Step: %{{x:,}}<br>Reward: %{{y:.1f}}<extra></extra>",
            ),
            **kwargs,
        )

## Individual Task Plots

In [124]:
def plot_task(data: pd.DataFrame, task: str, colors: dict[str, str]) -> go.Figure | None:
    """Learning-curve comparison for a single task."""
    td = data[data["task"] == task]
    if td.empty:
        return None
    stats = compute_stats(td, ["method", "step"])
    fig = go.Figure()
    _add_method_traces(fig, stats, colors)
    fig.update_layout(
        title=dict(text=f"<b>{task}</b>", x=0.5),
        xaxis_title="Environment Steps",
        yaxis_title="Episode Reward",
        legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.8)"),
        hovermode="x unified",
        template="plotly_white",
        width=700, height=450,
    )
    return fig


figures: dict[str, go.Figure] = {}
for task in our_tasks:
    fig = plot_task(all_data, task, COLORS)
    if fig is not None:
        figures[task] = fig

print(f"Generated {len(figures)} task plots")

Generated 40 task plots


In [114]:
# Show all individual plots
for task, fig in figures.items():
    fig.show()

## Grid View (All Tasks)

In [125]:
def create_grid_figure(
    data: pd.DataFrame,
    tasks: list[str],
    colors: dict[str, str],
    cols: int = 3,
) -> go.Figure:
    """Grid of per-task learning curves.

    Args:
        data: Combined DataFrame [step, reward, method, task].
        tasks: Task names to include.
        colors: Method -> hex color.
        cols: Number of grid columns.

    Returns:
        Plotly Figure with subplots.
    """
    n_tasks = len(tasks)
    rows = (n_tasks + cols - 1) // cols

    # Keep spacing small so subplots get most of the room
    v_spacing = min(0.03, 0.9 / max(rows - 1, 1))
    h_spacing = min(0.04, 0.9 / max(cols - 1, 1))

    fig = make_subplots(
        rows=rows, cols=cols,
        subplot_titles=tasks,
        vertical_spacing=v_spacing,
        horizontal_spacing=h_spacing,
    )

    legend_added: set[str] = set()

    for idx, task in enumerate(tasks):
        r = idx // cols + 1
        c = idx % cols + 1
        task_data = data[data["task"] == task]
        if task_data.empty:
            continue
        stats = compute_stats(task_data, ["method", "step"])
        _add_method_traces(
            fig, stats, colors,
            legend_added=legend_added,
            row=r, col=c,
            fill_alpha=0.15, line_width=1.5,
        )

    # Give each row enough height for readable plots
    row_height = 300
    fig.update_layout(
        height=row_height * rows + 80,  # +80 for title/legend
        width=1100,
        title_text=f"<b>{OUR_LABEL} vs Baselines</b>",
        title_x=0.5,
        showlegend=True,
        legend=dict(x=1.01, y=1, bgcolor="rgba(255,255,255,0.9)"),
        template="plotly_white",
        margin=dict(t=60, b=40, l=50, r=120),
    )

    # Shrink subplot title font so it doesn't overlap
    for ann in fig.layout.annotations:
        ann.font.size = 11

    return fig


grid_fig = create_grid_figure(all_data, our_tasks, COLORS, cols=3)
grid_fig.show()

## Normalized Aggregate Performance

Per-category aggregate plots. For fairness, each category only includes tasks
that **every** method (in that category) has data for. Methods in `AGGREGATE_EXCLUDE_METHODS` are excluded from the total aggregate.

In [116]:
def _interpolate_to_common_steps(
    data: pd.DataFrame,
    reward_col: str,
    common_steps: list[int],
) -> pd.DataFrame:
    """Interpolate per-(method, task) to common steps. No extrapolation.

    Args:
        data: DataFrame with columns [method, task, step, <reward_col>].
        reward_col: Name of the reward column to interpolate.
        common_steps: Target step grid.

    Returns:
        DataFrame [method, task, step, <reward_col>].
    """
    rows: list[dict] = []
    for (method, task), grp in data.groupby(["method", "task"]):
        seed_avg = (
            grp.groupby("step")[reward_col]
            .mean()
            .reset_index()
            .sort_values("step")
        )
        xs = seed_avg["step"].values
        ys = seed_avg[reward_col].values
        if len(xs) == 0:
            continue
        x_min, x_max = xs.min(), xs.max()
        for ts in common_steps:
            if ts < x_min or ts > x_max:
                continue
            val = float(np.interp(ts, xs, ys))
            rows.append({"method": method, "task": task, "step": ts, reward_col: val})
    return pd.DataFrame(rows)


def compute_normalized_aggregate(
    data: pd.DataFrame,
    eval_step: int,
    common_steps: list[int] | None = None,
) -> pd.DataFrame:
    """Compute normalized aggregate performance across tasks.

    For each task the reward is divided by the best method's mean at *eval_step*.
    No extrapolation beyond each method's actual data range.

    Args:
        data: Combined DataFrame [step, reward, method, task].
        eval_step: Step for computing normalisation constant.
        common_steps: Interpolation grid. Default: 0, 10k, …, eval_step.

    Returns:
        DataFrame [method, step, mean_normalized, std_normalized, n_tasks].
    """
    if common_steps is None:
        common_steps = list(range(0, eval_step + 1, 10_000))

    # Per-task normalisation constant
    task_max: dict[str, float] = {}
    for task in data["task"].unique():
        td = data[data["task"] == task]
        best = 0.0
        for method in td["method"].unique():
            md = td[td["method"] == method]
            valid = md["step"].unique()
            valid = valid[valid <= eval_step]
            if len(valid) > 0:
                best = max(best, md[md["step"] == valid.max()]["reward"].mean())
        task_max[task] = best if best > 0 else 1.0

    data_norm = data.copy()
    data_norm["norm_reward"] = data_norm.apply(
        lambda r: r["reward"] / task_max[r["task"]], axis=1
    )

    interp_df = _interpolate_to_common_steps(data_norm, "norm_reward", common_steps)
    agg = (
        interp_df.groupby(["method", "step"])
        .agg(mean_normalized=("norm_reward", "mean"),
             std_normalized=("norm_reward", "std"),
             n_tasks=("norm_reward", "count"))
        .reset_index()
    )
    return agg


def compute_unnormalized_aggregate(
    data: pd.DataFrame,
    eval_step: int,
    common_steps: list[int] | None = None,
) -> pd.DataFrame:
    """Compute unnormalized (raw reward) aggregate across tasks.

    Same interpolation/no-extrapolation logic as normalized variant
    but without dividing by the per-task maximum.

    Args:
        data: Combined DataFrame [step, reward, method, task].
        eval_step: Used only for default common_steps grid.
        common_steps: Interpolation grid. Default: 0, 10k, …, eval_step.

    Returns:
        DataFrame [method, step, mean_reward, std_reward, n_tasks].
    """
    if common_steps is None:
        common_steps = list(range(0, eval_step + 1, 10_000))

    interp_df = _interpolate_to_common_steps(data, "reward", common_steps)
    agg = (
        interp_df.groupby(["method", "step"])
        .agg(mean_reward=("reward", "mean"),
             std_reward=("reward", "std"),
             n_tasks=("reward", "count"))
        .reset_index()
    )
    return agg


def get_common_tasks_for_category(
    data: pd.DataFrame,
    category_tasks: list[str],
    exclude_methods: set[str] | None = None,
) -> tuple[list[str], list[str]]:
    """Return tasks common to all relevant methods, and those methods.

    Only methods that have data for at least one task in this category
    are considered (methods with zero overlap are ignored).

    Args:
        data: Combined DataFrame.
        category_tasks: Tasks belonging to this category.
        exclude_methods: Methods to ignore entirely.

    Returns:
        (common_tasks, relevant_methods): Tasks present in every relevant
        method, and the list of methods considered.
    """
    exclude_methods = exclude_methods or set()
    cat_set = set(category_tasks)
    all_methods = [m for m in data["method"].unique() if m not in exclude_methods]

    # Only keep methods that have data for ≥1 task in this category
    relevant_methods: list[str] = []
    for method in all_methods:
        method_tasks = set(data[data["method"] == method]["task"].unique())
        if method_tasks & cat_set:
            relevant_methods.append(method)

    if not relevant_methods:
        return sorted(category_tasks), []

    common = set(category_tasks)
    for method in relevant_methods:
        method_tasks = set(data[data["method"] == method]["task"].unique())
        common &= method_tasks
    return sorted(common), sorted(relevant_methods)

In [123]:
def _build_footer(
    methods: list[str] | None = None,
    tasks: list[str] | None = None,
    extra: str | None = None,
) -> str:
    """Build a compact description string for the bottom of a figure."""
    parts: list[str] = []
    if methods:
        parts.append(f"Methods: {', '.join(methods)}")
    if tasks:
        parts.append(f"Tasks: {', '.join(tasks)}")
    if extra:
        parts.append(extra)
    return "  |  ".join(parts)


def _apply_footer(
    fig: go.Figure,
    methods: list[str] | None = None,
    tasks: list[str] | None = None,
    extra: str | None = None,
) -> None:
    """Add a description annotation at the bottom of the figure."""
    text = _build_footer(methods, tasks, extra)
    if not text:
        return
    fig.add_annotation(
        text=text,
        xref="paper", yref="paper",
        x=0.5, y=-0.18,
        showarrow=False,
        font=dict(size=10, color="gray"),
        xanchor="center",
    )
    # Ensure enough bottom margin for the footer
    current_b = fig.layout.margin.b if fig.layout.margin and fig.layout.margin.b else 80
    fig.update_layout(margin=dict(b=max(current_b, 100)))


def plot_normalized_aggregate(
    agg: pd.DataFrame,
    colors: dict[str, str],
    title: str = "Normalized Aggregate Performance",
    subtitle_tasks: list[str] | None = None,
    subtitle_methods: list[str] | None = None,
    subtitle_extra: str | None = None,
) -> go.Figure:
    """Plot normalized aggregate performance."""
    fig = go.Figure()
    stats = agg.rename(columns={"mean_normalized": "mean_reward",
                                 "std_normalized": "std_reward"})
    _add_method_traces(fig, stats, colors, line_width=2.5)
    fig.update_layout(
        title=dict(text=f"<b>{title}</b>", x=0.5),
        xaxis_title="Environment Steps",
        yaxis_title="Normalized Reward (fraction of max)",
        legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.8)"),
        hovermode="x unified",
        template="plotly_white",
        width=800, height=520,
        yaxis=dict(range=[0, 1.05]),
        margin=dict(t=60, b=100),
    )
    _apply_footer(fig, subtitle_methods, subtitle_tasks, subtitle_extra)
    return fig


def plot_unnormalized_aggregate(
    agg: pd.DataFrame,
    colors: dict[str, str],
    title: str = "Unnormalized Aggregate Performance",
    subtitle_tasks: list[str] | None = None,
    subtitle_methods: list[str] | None = None,
    subtitle_extra: str | None = None,
) -> go.Figure:
    """Plot unnormalized (raw reward) aggregate performance."""
    fig = go.Figure()
    _add_method_traces(fig, agg, colors, line_width=2.5)
    fig.update_layout(
        title=dict(text=f"<b>{title}</b>", x=0.5),
        xaxis_title="Environment Steps",
        yaxis_title="Mean Episode Reward",
        legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.8)"),
        hovermode="x unified",
        template="plotly_white",
        width=800, height=520,
        margin=dict(t=60, b=100),
    )
    _apply_footer(fig, subtitle_methods, subtitle_tasks, subtitle_extra)
    return fig


# ------------------------------------------------------------------
# Build per-category aggregates (normalized + unnormalized)
# ------------------------------------------------------------------

# Classify tasks into categories
task_categories: dict[str, list[str]] = {}
for t in our_tasks:
    cat = classify_task(t)
    task_categories.setdefault(cat, []).append(t)

print("Task categories detected:")
for cat, tasks in sorted(task_categories.items()):
    print(f"  {cat} ({len(tasks)}): {tasks}")

# Results dicts: category -> (df, fig)
aggregate_results: dict[str, tuple[pd.DataFrame, go.Figure]] = {}
unnorm_aggregate_results: dict[str, tuple[pd.DataFrame, go.Figure]] = {}

for category, cat_tasks in sorted(task_categories.items()):
    common, relevant_methods = get_common_tasks_for_category(
        all_data, cat_tasks, exclude_methods=None,
    )
    if not common:
        print(f"\n{category}: no common tasks across methods — skipping")
        continue

    dropped = set(cat_tasks) - set(common)
    if dropped:
        print(f"\n{category}: dropped {dropped} (not in all methods)")
    print(f"\n{category}: {len(common)} tasks, methods: {relevant_methods}")

    cat_data = all_data[
        (all_data["task"].isin(common)) & (all_data["method"].isin(relevant_methods))
    ]
    extra = (f"Only tasks common to all methods shown. Dropped: {sorted(dropped)}"
             if dropped else "All category tasks included.")

    # Normalized
    agg = compute_normalized_aggregate(cat_data, eval_step=EVAL_STEP)
    fig = plot_normalized_aggregate(
        agg, COLORS,
        title=f"{category} — Normalized Aggregate ({len(common)} tasks)",
        subtitle_tasks=common, subtitle_methods=relevant_methods,
        subtitle_extra=extra,
    )
    aggregate_results[category] = (agg, fig)
    fig.show()

    # Unnormalized
    uagg = compute_unnormalized_aggregate(cat_data, eval_step=EVAL_STEP)
    ufig = plot_unnormalized_aggregate(
        uagg, COLORS,
        title=f"{category} — Unnormalized Aggregate ({len(common)} tasks)",
        subtitle_tasks=common, subtitle_methods=relevant_methods,
        subtitle_extra=extra,
    )
    unnorm_aggregate_results[category] = (uagg, ufig)
    ufig.show()

# ------------------------------------------------------------------
# Total aggregate (excluding AGGREGATE_EXCLUDE_METHODS)
# ------------------------------------------------------------------
all_cat_tasks = [t for tasks in task_categories.values() for t in tasks]
common_all, total_methods = get_common_tasks_for_category(
    all_data, all_cat_tasks, AGGREGATE_EXCLUDE_METHODS,
)
if common_all:
    mask_all = (
        (all_data["task"].isin(common_all))
        & (all_data["method"].isin(total_methods))
    )
    total_data = all_data[mask_all]
    excluded_str = ", ".join(sorted(AGGREGATE_EXCLUDE_METHODS))
    total_extra = f"Excluded from total: {excluded_str}. Only common tasks across remaining methods."

    # Normalized total
    total_agg = compute_normalized_aggregate(total_data, eval_step=EVAL_STEP)
    total_fig = plot_normalized_aggregate(
        total_agg, COLORS,
        title=f"Total Aggregate ({len(common_all)} tasks, excl. {excluded_str})",
        subtitle_tasks=common_all, subtitle_methods=total_methods,
        subtitle_extra=total_extra,
    )
    aggregate_results["_total"] = (total_agg, total_fig)
    total_fig.show()

    # Unnormalized total
    total_uagg = compute_unnormalized_aggregate(total_data, eval_step=EVAL_STEP)
    total_ufig = plot_unnormalized_aggregate(
        total_uagg, COLORS,
        title=f"Total Aggregate Unnormalized ({len(common_all)} tasks, excl. {excluded_str})",
        subtitle_tasks=common_all, subtitle_methods=total_methods,
        subtitle_extra=total_extra,
    )
    unnorm_aggregate_results["_total"] = (total_uagg, total_ufig)
    total_ufig.show()

    print(f"\nTotal: {len(common_all)} tasks, methods: {total_methods}")
else:
    print("No tasks common to all methods for total aggregate.")

Task categories detected:
  Easy DMC (20): ['acrobot-swingup', 'cartpole-balance', 'cartpole-balance-sparse', 'cartpole-swingup', 'cartpole-swingup-sparse', 'cheetah-run', 'finger-spin', 'finger-turn-easy', 'finger-turn-hard', 'fish-swim', 'hopper-hop', 'hopper-stand', 'pendulum-swingup', 'quadruped-run', 'quadruped-walk', 'reacher-easy', 'reacher-hard', 'walker-run', 'walker-stand', 'walker-walk']
  Hard DMC (7): ['dog-run', 'dog-stand', 'dog-trot', 'dog-walk', 'humanoid-run', 'humanoid-stand', 'humanoid-walk']
  Humanoid-Bench (5): ['h1-balance-hard-v0', 'h1-balance-simple-v0', 'h1-run-v0', 'h1-slide-v0', 'h1-stand-v0']
  Other (8): ['humanoid_h1-crawl-v0', 'humanoid_h1-hurdle-v0', 'humanoid_h1-maze-v0', 'humanoid_h1-pole-v0', 'humanoid_h1-reach-v0', 'humanoid_h1-sit_hard-v0', 'humanoid_h1-sit_simple-v0', 'humanoid_h1-stair-v0']

Easy DMC: dropped {'fish-swim'} (not in all methods)

Easy DMC: 19 tasks, methods: ['BMPC', 'DreamerV3', 'EZ2', 'Ours', 'SAC', 'SimbaV2', 'TDMPC2']



Hard DMC: 7 tasks, methods: ['BMPC', 'DreamerV3', 'Ours', 'SAC', 'SimbaV2', 'TDMPC2']



Humanoid-Bench: 5 tasks, methods: ['BMPC', 'Ours', 'SimbaV2']



Other: 8 tasks, methods: ['Ours']



Total: 27 tasks, methods: ['BMPC', 'DreamerV3', 'Ours', 'SAC', 'SimbaV2', 'TDMPC2']


## Summary Table

In [118]:
def create_summary_table(
    data: pd.DataFrame,
    eval_step: int | None = None,
) -> pd.DataFrame:
    """Final mean +/- std per method per task.

    Args:
        data: Combined DataFrame.
        eval_step: Step at which to evaluate. None = max per method.

    Returns:
        Wide DataFrame: one row per task, columns <method>_mean / <method>_std.
    """
    results: list[dict] = []
    methods = sorted(data["method"].unique())

    for task in sorted(data["task"].unique()):
        td = data[data["task"] == task]
        row: dict = {"task": task}
        for method in methods:
            md = td[td["method"] == method]
            if md.empty:
                row[f"{method}_mean"] = np.nan
                row[f"{method}_std"] = np.nan
                continue
            if eval_step is not None:
                valid = md["step"].unique()
                valid = valid[valid <= eval_step]
                step = valid.max() if len(valid) else md["step"].max()
            else:
                step = md["step"].max()
            final = md[md["step"] == step]
            row[f"{method}_mean"] = final["reward"].mean()
            row[f"{method}_std"] = final["reward"].std()
        results.append(row)

    return pd.DataFrame(results)


def format_summary_pretty(df: pd.DataFrame) -> pd.DataFrame:
    """Format summary as 'mean +/- std' columns."""
    methods = sorted({c.replace('_mean', '') for c in df.columns if c.endswith('_mean')})
    formatted = df[["task"]].copy()
    for method in methods:
        mc, sc = f"{method}_mean", f"{method}_std"
        if mc in df.columns:
            formatted[method] = df.apply(
                lambda r, _mc=mc, _sc=sc: (
                    f"{r[_mc]:.0f} \u00b1 {r[_sc]:.0f}" if pd.notna(r[_mc]) else "\u2014"
                ),
                axis=1,
            )
    return formatted

In [119]:
summary = create_summary_table(all_data, eval_step=EVAL_STEP)
pretty_summary = format_summary_pretty(summary)

print(f"Summary at {EVAL_STEP:,} steps (or max available):")
pretty_summary

Summary at 200,000 steps (or max available):


Unnamed: 0,task,BMPC,DreamerV3,EZ2,Ours,SAC,SimbaV2,TDMPC2
0,acrobot-swingup,242 ± 86,178 ± 67,286 ± 100,373 ± nan,3 ± 3,193 ± 45,296 ± 111
1,cartpole-balance,993 ± 11,948 ± 18,939 ± 9,999 ± nan,998 ± 1,1000 ± 0,996 ± 5
2,cartpole-balance-sparse,1000 ± 0,958 ± 23,1000 ± nan,1000 ± nan,807 ± 335,1000 ± 0,1000 ± 0
3,cartpole-swingup,875 ± 3,796 ± 19,796 ± 30,881 ± nan,842 ± 3,880 ± 1,872 ± 11
4,cartpole-swingup-sparse,371 ± 356,615 ± 309,790 ± nan,848 ± 0,55 ± 64,779 ± 84,308 ± 416
5,cheetah-run,611 ± 107,353 ± 108,663 ± 27,804 ± nan,574 ± 0,624 ± 241,614 ± 94
6,dog-run,175 ± 18,6 ± 2,—,165 ± nan,9 ± 2,166 ± 38,67 ± 47
7,dog-stand,780 ± 89,27 ± nan,—,752 ± nan,31 ± 13,827 ± 74,526 ± 47
8,dog-trot,131 ± 72,6 ± 1,—,352 ± nan,10 ± 2,303 ± 82,71 ± 39
9,dog-walk,379 ± 140,4 ± nan,—,261 ± nan,11 ± 3,507 ± 154,208 ± 88


## Save Results

In [126]:
import os
import tempfile
import plotly.io as pio

# Fix stale TMPDIR — kaleido/choreographer calls tempfile.TemporaryDirectory()
# with no explicit dir=, so it relies on tempfile.gettempdir() which caches
# the first result.  Force both the env var AND the cached value to /tmp
# unconditionally so re-runs also work.
os.environ["TMPDIR"] = "/tmp"
os.environ.pop("TEMP", None)
os.environ.pop("TMP", None)
tempfile.tempdir = "/tmp"

# Output directory: compare_to_baselines/<sweep_id1,sweep_id2,...>
sweep_tag = ",".join(SWEEP_IDS)
output_dir = paths.RESULTS_ROOT / "compare_to_baselines" / sweep_tag
output_dir.mkdir(parents=True, exist_ok=True)

# Sub-folders for organisation
summary_dir = output_dir / "summary"
grid_dir = output_dir / "grid"
aggregate_dir = output_dir / "aggregates" / "normalized"
unnorm_dir = output_dir / "aggregates" / "unnormalized"
tasks_dir = output_dir / "tasks"
for d in (summary_dir, grid_dir, aggregate_dir, unnorm_dir, tasks_dir):
    d.mkdir(parents=True, exist_ok=True)

print(f"Output directory: {output_dir}")


def _save_fig(fig: go.Figure, directory: Path, name: str) -> None:
    """Save a plotly figure as HTML and PNG, preserving the original style."""
    fig.write_html(str(directory / f"{name}.html"))
    png_bytes = pio.to_image(fig, format="png", scale=2)
    (directory / f"{name}.png").write_bytes(png_bytes)


# ── Summary tables ────────────────────────────────────────────
summary.to_csv(summary_dir / "summary.csv", index=False)
pretty_summary.to_csv(summary_dir / "summary_pretty.csv", index=False)

# ── Grid figure ───────────────────────────────────────────────
_save_fig(grid_fig, grid_dir, "grid_comparison")

# ── Normalized aggregate figures ──────────────────────────────
for cat_name, (agg_df, agg_fig) in aggregate_results.items():
    safe_cat = cat_name.replace(" ", "_").replace("-", "_").lower()
    agg_df.to_csv(aggregate_dir / f"{safe_cat}.csv", index=False)
    _save_fig(agg_fig, aggregate_dir, safe_cat)

# ── Unnormalized aggregate figures ────────────────────────────
for cat_name, (uagg_df, uagg_fig) in unnorm_aggregate_results.items():
    safe_cat = cat_name.replace(" ", "_").replace("-", "_").lower()
    uagg_df.to_csv(unnorm_dir / f"{safe_cat}.csv", index=False)
    _save_fig(uagg_fig, unnorm_dir, safe_cat)

# ── Individual task figures ───────────────────────────────────
for task, fig in figures.items():
    safe_name = task.replace("-", "_")
    _save_fig(fig, tasks_dir, safe_name)

n_agg = len(aggregate_results) + len(unnorm_aggregate_results)
print(f"\nSaved to {output_dir}:")
print(f"  summary/                  — summary.csv, summary_pretty.csv")
print(f"  grid/                     — grid_comparison (.html + .png)")
print(f"  aggregates/normalized/    — {len(aggregate_results)} normalized plots (.csv + .html + .png)")
print(f"  aggregates/unnormalized/  — {len(unnorm_aggregate_results)} unnormalized plots (.csv + .html + .png)")
print(f"  tasks/                    — {len(figures)} individual task plots (.html + .png)")

Output directory: /gpfs/work4/0/prjs0951/Thomas/Thesis/RL_weather/tdmpc2-with-return-based-auxiliary-tasks.worktrees/main/analysis/results/compare_to_baselines/aesenloq

Saved to /gpfs/work4/0/prjs0951/Thomas/Thesis/RL_weather/tdmpc2-with-return-based-auxiliary-tasks.worktrees/main/analysis/results/compare_to_baselines/aesenloq:
  summary/                  — summary.csv, summary_pretty.csv
  grid/                     — grid_comparison (.html + .png)
  aggregates/normalized/    — 5 normalized plots (.csv + .html + .png)
  aggregates/unnormalized/  — 5 unnormalized plots (.csv + .html + .png)
  tasks/                    — 40 individual task plots (.html + .png)
