# Multi-Task Performance Comparison (UTD=4)

This notebook compares our method, run with a fixed update-to-data (UTD) ratio of 4, against three baselines (DreamerV3, SAC, and the original TD-MPC2) on 5 tasks:
- `walker-run`
- `hopper-hop`
- `quadruped-walk`
- `cartpole-swingup-sparse`
- `acrobot-swingup`

We analyze:
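1.  **Per-Task Performance**: "Ours (Planning)" (`eval/episode_reward`) and "Ours (Policy)" (`policy_eval/episode_reward`) compared against the baselines on each task.
2.  **Aggregated Performance**: the mean episode reward across all 5 tasks; each task's seed-averaged curve is first interpolated onto a common step grid.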


In [1]:
import sys
from pathlib import Path

# Make the repository root importable so that `analysis.tools` resolves.
# The working directory (assumed to be this notebook's folder) sits three
# directory levels below the repository root.
NOTEBOOK_DIR = Path().resolve()
REPO_ROOT = NOTEBOOK_DIR.parent.parent.parent

_repo_root_str = str(REPO_ROOT)
if _repo_root_str not in sys.path:
    # Prepend so the in-repo package shadows any installed copy.
    sys.path.insert(0, _repo_root_str)

print(f"Repository Root: {REPO_ROOT}")

Repository Root: /home/thomas/projects/Research/Masters Thesis/tdmpc2-with-return-based-auxiliary-tasks


In [2]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go

from analysis.tools import wandb_io, plotting, baselines, encodings, aggregations, paths

# --- Configuration -------------------------------------------------------
# W&B location of the sweep analysed in this notebook.
ENTITY = "thomasevers9"
PROJECT = "tdmpc2-tdmpc2"
SWEEP_ID = "l7cnylog"

# Tasks compared; this order is also the order in which per-task plots are made.
TASKS = [
    "walker-run", "hopper-hop", "quadruped-walk",
    "cartpole-swingup-sparse", "acrobot-swingup",
]

# History keys holding the evaluation reward of each evaluation mode.
METRIC_PLANNING = "eval/episode_reward"
METRIC_POLICY = "policy_eval/episode_reward"

# Candidate step columns handed to the history loader
# (presumably tried in this order — TODO confirm in aggregations).
STEP_KEYS = ["total_env_steps", "global_step", "step", "eval/step", "_step"]

# Inclusive cap on environment steps kept for our runs and for the baselines.
MAX_STEPS = 100 * 1_000  # raise (e.g. to 500_000) for longer runs

In [3]:
# 1. Fetch Sweep Data
if SWEEP_ID == "TODO":
    raise ValueError("Please set the SWEEP_ID in the configuration cell above.")

print(f"Fetching sweep {SWEEP_ID}...")
runs, manifest, source = wandb_io.fetch_sweep_runs(
    entity=ENTITY,
    project=PROJECT,
    sweep_id=SWEEP_ID,
    history_keys=[METRIC_PLANNING, METRIC_POLICY] + STEP_KEYS,
    use_cache=True,
    force_refresh=False,
)


def _metric_frame(metric_key, variant):
    """Return the per-run history of ``metric_key`` as a long-format frame.

    The metric column is renamed to ``reward`` and a constant ``variant``
    column is appended, so frames for different evaluation modes can be
    stacked. Any ``ValueError`` from ``aggregations.runs_history_to_frame``
    propagates; this notebook treats it as "metric absent from the runs".
    """
    return (
        aggregations.runs_history_to_frame(
            runs,
            metric_key=metric_key,
            step_keys=STEP_KEYS,
            config_to_columns={"task": "task", "seed": "seed"},
        )
        .rename(columns={metric_key: "reward"})
        .assign(variant=variant)
    )


# Planning evaluation is required: let any failure surface.
print("Processing Planning Metric...")
df_planning = _metric_frame(METRIC_PLANNING, "Ours (Planning, UTD=4)")

# Policy evaluation is optional: runs may not log it.
print("Processing Policy Metric...")
try:
    df_policy = _metric_frame(METRIC_POLICY, "Ours (Policy, UTD=4)")
except ValueError:
    print("Warning: No policy evaluation metrics found. Skipping policy plots.")
    df_policy = pd.DataFrame()

# Stack only non-empty frames rather than relying on pandas' special-casing
# of empty, column-less inputs; fall back to the planning frame if both are empty.
frames_to_stack = [f for f in (df_planning, df_policy) if not f.empty] or [df_planning]
sweep_frame = pd.concat(frames_to_stack, ignore_index=True)

# Keep only steps up to the shared (inclusive) cap; renumber the index afterwards.
sweep_frame = sweep_frame.loc[sweep_frame["step"] <= MAX_STEPS].reset_index(drop=True)

print(f"Loaded {len(sweep_frame)} rows.")
sweep_frame.head()

Fetching sweep l7cnylog...
Processing Planning Metric...
Processing Policy Metric...
Loaded 400 rows.


Unnamed: 0,task,seed,run_id,step,reward,variant
0,cartpole-swingup-sparse,202,1t2tjaec,10000,3.0,"Ours (Planning, UTD=4)"
1,cartpole-swingup-sparse,202,1t2tjaec,20000,69.800003,"Ours (Planning, UTD=4)"
2,cartpole-swingup-sparse,202,1t2tjaec,30000,114.599998,"Ours (Planning, UTD=4)"
3,cartpole-swingup-sparse,202,1t2tjaec,40000,783.0,"Ours (Planning, UTD=4)"
4,cartpole-swingup-sparse,202,1t2tjaec,50000,846.799988,"Ours (Planning, UTD=4)"


In [4]:
# 2. Load Baselines
print("Loading baselines...")

# Legend label -> on-disk root of that baseline's results. Dict order sets the
# insertion order of each task's baseline mapping (hence the plotting order).
baseline_roots = {
    "DreamerV3": baselines.DREAMERV3_BASELINE_ROOT,
    "SAC": baselines.SAC_BASELINE_ROOT,
    "TD-MPC2": baselines.STATE_BASELINE_ROOT,  # original TD-MPC2
}

# baselines_data[task][label] -> baseline rows with step <= MAX_STEPS.
# A task always gets an entry, possibly an empty dict.
baselines_data = {}
for task in TASKS:
    task_baselines = {}
    for label, root in baseline_roots.items():
        if not baselines.has_task(task, root=root):
            continue
        loaded = baselines.load_task_baseline(task, root=root)
        task_baselines[label] = loaded[loaded["step"] <= MAX_STEPS]
    baselines_data[task] = task_baselines

print("Baselines loaded.")

Loading baselines...
Baselines loaded.


In [5]:
# 3. Per-Task Plots

# Fixed colours so each baseline looks the same in every figure (dashed lines).
baseline_colors = {
    "DreamerV3": "#8c564b",  # Brown
    "SAC": "#7f7f7f",        # Gray
    "TD-MPC2": "#000000",    # Black
}

output_dir = paths.notebook_results_dir("multitask_comparison")

for task in TASKS:
    print(f"Plotting {task}...")

    # Our runs (all evaluation variants) for this task.
    task_df = sweep_frame[sweep_frame["task"] == task].copy()
    if task_df.empty:
        print(f"  No data found for {task} in sweep.")
        continue

    # An empty baseline frame is passed because baselines are overlaid
    # manually below (presumably so the helper draws none itself).
    fig = plotting.sample_efficiency_figure(
        frame=task_df,
        metric_key="reward",
        variant_column="variant",
        task_name=f"{task} (UTD=4)",
        baseline_frame=pd.DataFrame(),
        baseline_label="",
        baseline_step_cap=MAX_STEPS,
    )

    # Overlay each baseline's seed-averaged curve.
    for name, df in baselines_data.get(task, {}).items():
        if df.empty:
            # No rows at or below MAX_STEPS: skip instead of adding an empty legend entry.
            print(f"  {name}: no baseline data up to step {MAX_STEPS}; skipped.")
            continue
        summary = (
            df.groupby("step", as_index=False)
            .agg(mean_reward=("reward", "mean"))
            .sort_values("step")
        )
        fig.add_trace(
            go.Scatter(
                x=summary["step"],
                y=summary["mean_reward"],
                mode="lines",
                name=name,
                line=dict(color=baseline_colors.get(name, "gray"), dash="dash", width=2),
                legendgroup="Baselines",
            )
        )

    # Save as PNG; fall back to HTML if static export fails. This is
    # best-effort so one failed export does not abort the remaining tasks.
    output_path = output_dir / f"{task}_comparison.png"
    try:
        plotting.write_png(fig, output_path=output_path)
        print(f"  Saved to {output_path}")
    except Exception as e:
        print(f"  Failed to save PNG: {e}")
        html_path = output_path.with_suffix(".html")
        fig.write_html(str(html_path))
        print(f"  Saved as HTML instead: {html_path}")

    # Display every task's figure inline.
    fig.show()

Plotting walker-run...


Plotting hopper-hop...


Plotting quadruped-walk...


Plotting cartpole-swingup-sparse...


Plotting acrobot-swingup...


In [6]:
# 4. Aggregated Performance (Average across tasks)

# Shared step grid: every per-task curve is interpolated onto it, so curves
# logged at different steps can be averaged point-wise.
common_steps = np.linspace(start=0, stop=MAX_STEPS, num=100)  # 100 evenly spaced points, both ends included

def interpolate_and_average(frames_dict, label, steps=None):
    """Average per-task learning curves on a common step grid.

    For each task, ``reward`` is averaged over seeds at every logged ``step``
    and then linearly interpolated onto ``steps``. The per-task curves are
    then averaged point-wise across tasks.

    Args:
        frames_dict: Mapping ``{task_name: DataFrame}``. Each frame needs
            ``step`` and ``reward`` columns. Empty (or ``None``) frames count
            as missing tasks.
        label: Series name, used in the warning printed when tasks are missing.
        steps: Increasing 1-D grid of environment steps to evaluate on.
            Defaults to the notebook-level ``common_steps``.

    Returns:
        A float ``np.ndarray`` with one value per grid point, holding the
        cross-task mean reward. A point is ``NaN`` if it lies outside the
        logged step range of any contributing task. Returns ``None`` when no
        task has data.
    """
    grid = common_steps if steps is None else np.asarray(steps, dtype=float)

    interpolated_rewards = []
    missing_tasks = []
    for task, df in frames_dict.items():
        if df is None or df.empty:
            missing_tasks.append(task)
            continue

        # Seed-average first; groupby sorts steps ascending, as np.interp requires.
        task_summary = df.groupby("step")["reward"].mean()

        # left/right=NaN: by default np.interp repeats the first/last value
        # outside the logged range (e.g. before the first evaluation), which
        # would fabricate rewards that were never measured.
        interpolated_rewards.append(
            np.interp(
                grid,
                task_summary.index.to_numpy(dtype=float),
                task_summary.to_numpy(dtype=float),
                left=np.nan,
                right=np.nan,
            )
        )

    if not interpolated_rewards:
        return None

    if missing_tasks:
        # Otherwise the mean would silently cover fewer tasks than the plot claims.
        print(
            f"Warning: {label} has no data for {missing_tasks}; "
            f"its average covers {len(interpolated_rewards)}/{len(frames_dict)} tasks."
        )

    # Average across tasks; stays NaN wherever any contributing task lacks coverage.
    return np.mean(interpolated_rewards, axis=0)

# --- Prepare Data Structures ---

def _ours_frames(mode):
    """Return ``{task: sweep rows for that task}`` restricted to one evaluation mode.

    ``mode`` is matched as a plain substring of the ``variant`` label
    (e.g. "Planning" matches "Ours (Planning, UTD=4)").
    """
    in_mode = sweep_frame["variant"].str.contains(mode, regex=False)
    return {t: sweep_frame[(sweep_frame["task"] == t) & in_mode] for t in TASKS}


# Ours (Planning)
ours_planning_frames = _ours_frames("Planning")
avg_ours_planning = interpolate_and_average(ours_planning_frames, "Ours (Planning)")

# Ours (Policy)
ours_policy_frames = _ours_frames("Policy")
avg_ours_policy = interpolate_and_average(ours_policy_frames, "Ours (Policy)")

# Baselines: a task lacking a baseline is passed as an empty frame, which
# interpolate_and_average skips.
avg_baselines = {}
for base_name in ["DreamerV3", "SAC", "TD-MPC2"]:
    base_frames = {t: baselines_data[t].get(base_name, pd.DataFrame()) for t in TASKS}
    avg_baselines[base_name] = interpolate_and_average(base_frames, base_name)

# --- Plotting Aggregate ---

fig_agg = go.Figure()

# Our two evaluation modes: solid lines in plotly's default colours.
for series, trace_name in (
    (avg_ours_planning, "Ours (Planning, UTD=4)"),
    (avg_ours_policy, "Ours (Policy, UTD=4)"),
):
    if series is not None:
        fig_agg.add_trace(go.Scatter(x=common_steps, y=series, mode="lines", name=trace_name))

# Baselines: dashed, coloured as in the per-task plots.
for name, data in avg_baselines.items():
    if data is not None:
        fig_agg.add_trace(go.Scatter(
            x=common_steps, y=data, mode="lines",
            name=name,
            line=dict(color=baseline_colors.get(name, "gray"), dash="dash", width=2),
        ))

fig_agg.update_layout(
    # Derive the task count from TASKS so the title cannot drift from the config.
    title=f"Average Performance ({len(TASKS)} Tasks)",
    xaxis_title="Environment Steps",
    yaxis_title="Average Episode Reward",
    xaxis_range=[0, MAX_STEPS],
)

fig_agg.show()

# Save as PNG; fall back to HTML if static export fails.
output_path_agg = output_dir / "average_performance.png"
try:
    plotting.write_png(fig_agg, output_path=output_path_agg)
    print(f"Saved aggregate plot to {output_path_agg}")
except Exception as e:
    print(f"Failed to save aggregate PNG: {e}")
    html_path = output_path_agg.with_suffix(".html")
    fig_agg.write_html(str(html_path))
    print(f"Saved aggregate plot as HTML instead: {html_path}")

Saved aggregate plot to /home/thomas/projects/Research/Masters Thesis/tdmpc2-with-return-based-auxiliary-tasks/analysis/results/multitask_comparison/average_performance.png
