# Detach Encoder Ablation

**Goal:** Test whether detaching the encoder halfway through training improves performance on precision-sensitive tasks.

**Tasks analyzed:**
- **Humanoid H1 Balance Simple** (from sweep `03_humanoid_bench_baseline`): 1M steps, detach at 500k
- **Acrobot Swingup** (from sweep `54_scaled_architecture_sweep`): 100k steps, detach at 50k

**Metrics:**
- Episode reward (train & eval) — expect increase after detach
- Consistency loss (train, validation-all & validation-recent) — expect decrease after detach

In [1]:
import sys
from pathlib import Path

# Make the local `tools` package importable. When the notebook is launched from a
# `notebooks/` folder the analysis root is its parent; otherwise it is the CWD.
NOTEBOOK_DIR = Path.cwd()
if NOTEBOOK_DIR.name == "notebooks":
    ANALYSIS_ROOT = NOTEBOOK_DIR.parent
else:
    ANALYSIS_ROOT = NOTEBOOK_DIR
if str(ANALYSIS_ROOT) not in sys.path:
    sys.path.insert(0, str(ANALYSIS_ROOT))

import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from tools import wandb_io, paths

## Configuration

In [2]:
# W&B location of the ablation runs.
ENTITY = "thomasevers9"
PROJECT = "tdmpc2-tdmpc2"

# Per-task settings: source sweep, plot label, run length, the step at which the
# encoder was detached, and the moving-average window (in env steps) applied to
# the training-reward curve.
TASKS = {
    "humanoid_h1-balance_simple-v0": dict(
        sweep_id="8lsl2o4p",  # 03_humanoid_bench_baseline
        display_name="Humanoid H1 Balance Simple",
        total_steps=1_000_000,
        detach_step=500_000,
        ma_steps=10_000,
    ),
    "acrobot-swingup": dict(
        sweep_id="nqfufp3t",  # 54_scaled_architecture_sweep
        display_name="Acrobot Swingup",
        total_steps=100_000,
        detach_step=50_000,
        ma_steps=1_000,
    ),
}

# History keys requested from W&B; "_step" is the x-axis for every plot.
REWARD_KEYS = ["eval/episode_reward", "train/episode_reward"]
LOSS_KEYS = [
    "train/consistency_loss",
    "validation_all/consistency_loss",
    "validation_recent/consistency_loss",
]
ALL_KEYS = [*REWARD_KEYS, *LOSS_KEYS, "_step"]

# One fixed color per metric so a metric looks the same in every figure.
COLORS = {
    "eval/episode_reward": "#1f77b4",                 # blue
    "train/episode_reward": "#2ca02c",                # green
    "train/consistency_loss": "#d62728",              # red
    "validation_all/consistency_loss": "#ff7f0e",     # orange
    "validation_recent/consistency_loss": "#9467bd",  # purple
}

## Load Data from W&B

In [3]:

def fetch_task_data(task_name: str, task_config: dict) -> pd.DataFrame:
    """Load one task's runs from its W&B sweep and stack their histories.

    The whole sweep is fetched through ``wandb_io.fetch_sweep_runs`` (cache enabled),
    then only runs whose ``config["task"]`` equals ``task_name`` are kept.

    Args:
        task_name: Value matched against each run's ``config["task"]``.
        task_config: Entry from ``TASKS``; ``sweep_id`` and ``display_name`` are read.

    Returns:
        One row per logged history row of every matching run, plus ``seed``
        (0 when the run config has none) and ``task`` columns. Empty when no run
        matches or no matching run has history rows.
    """
    runs, _manifest, source = wandb_io.fetch_sweep_runs(
        entity=ENTITY,
        project=PROJECT,
        sweep_id=task_config["sweep_id"],
        history_keys=ALL_KEYS,
        use_cache=True,
        force_refresh=False,
    )
    print(f"Loaded {len(runs)} runs for {task_config['display_name']} from {source}")

    matching = [run for run in runs if run.get("config", {}).get("task") == task_name]
    print(f"  -> {len(matching)} runs for task '{task_name}'")
    if not matching:
        return pd.DataFrame()

    frames = []
    for run in matching:
        seed = run.get("config", {}).get("seed", 0)
        history = run.get("history", {})

        # Two history layouts are accepted: {"keys": ..., "rows": [...]} or a bare
        # list of row dicts. Anything else is reported and skipped.
        if isinstance(history, dict) and "rows" in history:
            records = history["rows"]
        elif isinstance(history, list):
            records = history
        else:
            print(f"  Warning: Unknown history format for seed {seed}")
            continue

        if not records:
            continue

        frame = pd.DataFrame(records)
        frame["seed"] = seed
        frame["task"] = task_name
        frames.append(frame)

    if not frames:
        return pd.DataFrame()

    combined = pd.concat(frames, ignore_index=True)
    print(f"  -> {len(combined)} total rows, columns: {list(combined.columns)}")
    return combined


In [4]:
# Fetch every configured task; tasks without any usable rows are left out of task_data.
task_data = {}
for task_name, config in TASKS.items():
    df = fetch_task_data(task_name, config)
    print()  # blank line separating the per-task log blocks
    if df.empty:
        continue
    task_data[task_name] = df

Loaded 6 runs for Humanoid H1 Balance Simple from cache
  -> 2 runs for task 'humanoid_h1-balance_simple-v0'
  -> 24939 total rows, columns: ['_step', 'eval/episode_reward', 'train/consistency_loss', 'train/episode_reward', 'validation_all/consistency_loss', 'validation_recent/consistency_loss', 'seed', 'task']

Loaded 16 runs for Acrobot Swingup from cache
  -> 8 runs for task 'acrobot-swingup'
  -> 4720 total rows, columns: ['_step', 'eval/episode_reward', 'train/consistency_loss', 'train/episode_reward', 'validation_all/consistency_loss', 'validation_recent/consistency_loss', 'seed', 'task']



In [5]:
# Check available keys for each task
for task_name, df in task_data.items():
    print(f"\n{task_name}:")
    available_keys = [k for k in ALL_KEYS if k in df.columns]
    print(f"  Available: {available_keys}")


humanoid_h1-balance_simple-v0:
  Available: ['eval/episode_reward', 'train/episode_reward', 'train/consistency_loss', 'validation_all/consistency_loss', 'validation_recent/consistency_loss', '_step']

acrobot-swingup:
  Available: ['eval/episode_reward', 'train/episode_reward', 'train/consistency_loss', 'validation_all/consistency_loss', 'validation_recent/consistency_loss', '_step']


## Plotting Functions

In [6]:

def hex_to_rgba(hex_color: str, alpha: float) -> str:
    """Convert a hex color to an ``rgba(r,g,b,alpha)`` string usable by Plotly.

    Accepts ``RRGGBB``, shorthand ``RGB`` (expanded to ``RRGGBB``) and ``RRGGBBAA``
    (its alpha byte is ignored in favour of ``alpha``), each with or without a
    leading ``#``.

    Args:
        hex_color: Hex color string.
        alpha: Opacity, inserted verbatim into the result.

    Returns:
        For example ``hex_to_rgba("#1f77b4", 0.2) == "rgba(31,119,180,0.2)"``.

    Raises:
        ValueError: If the color has an unsupported length or non-hex characters.
    """
    digits = hex_color.lstrip("#")
    if len(digits) == 3:
        # Shorthand: "abc" -> "aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    # Validate explicitly: int(..., 16) alone would accept signs/whitespace and
    # silently mis-parse colors of the wrong length.
    if len(digits) not in (6, 8) or any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError(
            f"Expected a #RGB, #RRGGBB or #RRGGBBAA hex color, got {hex_color!r}"
        )
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"


def compute_stats(df: pd.DataFrame, step_col: str, metric_col: str) -> pd.DataFrame:
    """Aggregate one metric across seeds at every distinct logged step.

    Rows missing the step, the metric or the seed are ignored. Steps are matched
    exactly, so a step with a single observation gets ``count == 1`` and std 0.

    Args:
        df: Long-format history containing ``step_col``, ``metric_col`` and ``"seed"``.
        step_col: Column to group on.
        metric_col: Column to aggregate.

    Returns:
        Columns ``[step_col, "mean", "std", "count"]`` ordered by step, or an empty
        DataFrame when no complete rows remain.
    """
    complete = df.loc[:, [step_col, metric_col, "seed"]].dropna()
    if complete.empty:
        return pd.DataFrame()

    per_step = complete.groupby(step_col)[metric_col]
    summary = pd.DataFrame(
        {
            "mean": per_step.mean(),
            "std": per_step.std(),
            "count": per_step.count(),
        }
    ).reset_index()
    # Sample std is NaN for a single observation; draw that as a zero-width band.
    summary["std"] = summary["std"].fillna(0)
    return summary


In [7]:

def compute_ma_window_from_steps(df: pd.DataFrame, step_col: str, ma_steps: int) -> int:
    """Translate a moving-average window from environment steps into data points.

    The typical logging interval is the *median* gap between the distinct non-null
    values of ``step_col``. The window is ``ma_steps`` divided by that gap,
    truncated toward zero and never less than 1.

    Args:
        df: DataFrame holding the step column.
        step_col: Name of the step column.
        ma_steps: Desired window size in environment steps.

    Returns:
        Window length in number of points (>= 1); 1 if fewer than two distinct
        steps are present.
    """
    distinct_steps = np.unique(df[step_col].dropna().to_numpy())
    if distinct_steps.size < 2:
        return 1
    typical_gap = np.median(np.diff(distinct_steps))
    return max(1, int(ma_steps / typical_gap))


def plot_metric_with_detach_line(
    df: pd.DataFrame,
    metrics: list,
    detach_step: int,
    title: str,
    y_title: str,
    colors: dict,
    step_col: str = "_step",
    ma_steps: int = 1000,
    log_scale: bool = False,
) -> go.Figure:
    """Plot the across-seed mean ± std of several metrics and mark the detach step.

    Args:
        df: Long-format history with ``step_col``, ``"seed"`` and the metric columns.
        metrics: Metric column names to plot; names absent from ``df`` are skipped.
        detach_step: Step at which the encoder was detached (red dashed line).
        title: Plot title.
        y_title: Y-axis title.
        colors: Mapping from metric name to hex color; unmapped metrics are grey.
        step_col: Column holding the x-axis step.
        ma_steps: Moving-average window in env steps, applied to training-reward
            curves only.
        log_scale: Use a log y-axis with exponent-style ticks (e.g. 1e-3).

    Returns:
        The assembled figure (not shown).

    Training metrics are drawn first and all others afterwards, so eval curves end
    up on top.
    """
    fig = go.Figure()

    # Train metrics first, everything else after (sorted() is stable).
    metrics_train_first = sorted(metrics, key=lambda m: 0 if "train" in m else 1)

    for metric in metrics_train_first:
        if metric not in df.columns:
            continue

        stats = compute_stats(df, step_col, metric)
        if stats.empty:
            continue

        color = colors.get(metric, "#666666")
        x = stats[step_col]
        y_mean = stats["mean"]
        y_std = stats["std"]

        is_train_reward = "train" in metric and "reward" in metric
        is_eval = "eval" in metric

        if is_train_reward:
            # Convert the step window into a point count using THIS curve's own
            # logging interval. Measuring it on the whole DataFrame would include
            # steps at which only other metrics or seeds were logged, which shrinks
            # the measured interval and makes the window span far more than
            # ``ma_steps`` env steps.
            ma_window = compute_ma_window_from_steps(stats, step_col, ma_steps)
            y_mean = y_mean.rolling(window=ma_window, min_periods=1).mean()
            y_std = y_std.rolling(window=ma_window, min_periods=1).mean()

        # Human-readable legend label.
        display_name = metric.split("/")[-1].replace("_", " ").title()
        if "train" in metric:
            display_name = f"Train {display_name}"
            if is_train_reward:
                display_name += f" (MA {ma_steps:,})"
        elif "eval" in metric:
            display_name = f"Eval {display_name}"
        elif "validation_all" in metric:
            display_name = "Val (All) Consistency"
        elif "validation_recent" in metric:
            display_name = "Val (Recent) Consistency"

        # Eval curves are emphasised with a thicker line.
        line_width = 3 if is_eval else 2

        upper = y_mean + y_std
        lower = y_mean - y_std
        if log_scale:
            # A log axis cannot place values <= 0, which tears the filled band
            # apart; replace non-positive lower bounds with a tenth of the mean.
            lower = lower.where(lower > 0, y_mean / 10)

        # Shaded std region: walk the upper edge forward and the lower edge back.
        fig.add_trace(go.Scatter(
            x=pd.concat([x, x[::-1]]),
            y=pd.concat([upper, lower[::-1]]),
            fill="toself",
            fillcolor=hex_to_rgba(color, 0.2),
            line=dict(color="rgba(0,0,0,0)"),
            showlegend=False,
            hoverinfo="skip",
        ))

        # Mean line
        fig.add_trace(go.Scatter(
            x=x,
            y=y_mean,
            mode="lines",
            name=display_name,
            line=dict(color=color, width=line_width),
            hovertemplate=f"{display_name}<br>Step: %{{x:,}}<br>Value: %{{y:.3f}}<extra></extra>",
        ))

    # Vertical marker at the detach step.
    fig.add_vline(
        x=detach_step,
        line_dash="dash",
        line_color="red",
        line_width=2,
        annotation_text="Encoder Detached",
        annotation_position="top",
        annotation_font_color="red",
    )

    fig.update_layout(
        title=dict(text=f"<b>{title}</b>", x=0.5),
        xaxis_title="Environment Steps",
        yaxis_title=y_title,
        legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.8)"),
        hovermode="x unified",
        template="plotly_white",
        width=800,
        height=500,
    )

    # Log scale for y-axis with scientific notation (e.g., 1e-3)
    if log_scale:
        fig.update_yaxes(type="log", exponentformat="e")

    return fig


## Generate Plots for Each Task

In [8]:

# Build, show and keep two figures per task (episode reward, consistency loss).
# They are stored in `all_figures` under "<task>_reward" / "<task>_loss"; the save
# cell below turns those keys into HTML file names.
all_figures = {}

for task_name, config in TASKS.items():
    # Tasks whose fetch produced no rows were never added to task_data.
    if task_name not in task_data:
        print(f"Skipping {task_name} - no data")
        continue
    
    df = task_data[task_name]
    display_name = config["display_name"]
    detach_step = config["detach_step"]
    # Fall back to a 1,000-step moving-average window when the config omits one.
    ma_steps = config.get("ma_steps", 1000)
    
    print(f"\n=== {display_name} ===")
    print(f"Detach step: {detach_step:,}, MA window: {ma_steps:,} steps")
    
    # Plot 1: Episode Rewards
    reward_fig = plot_metric_with_detach_line(
        df,
        metrics=REWARD_KEYS,
        detach_step=detach_step,
        title=f"{display_name} - Episode Reward",
        y_title="Episode Reward",
        colors=COLORS,
        ma_steps=ma_steps,
        log_scale=False,
    )
    reward_fig.show()
    all_figures[f"{task_name}_reward"] = reward_fig
    
    # Plot 2: Consistency Loss (log scale with scientific notation)
    # ma_steps is passed for uniformity only: the plotting function smooths metrics
    # containing both "train" and "reward", so none of the loss curves are smoothed.
    loss_fig = plot_metric_with_detach_line(
        df,
        metrics=LOSS_KEYS,
        detach_step=detach_step,
        title=f"{display_name} - Consistency Loss",
        y_title="Consistency Loss",
        colors=COLORS,
        ma_steps=ma_steps,
        log_scale=True,
    )
    loss_fig.show()
    all_figures[f"{task_name}_loss"] = loss_fig



=== Humanoid H1 Balance Simple ===
Detach step: 500,000, MA window: 10,000 steps



=== Acrobot Swingup ===
Detach step: 50,000, MA window: 1,000 steps


## Combined Grid View

In [9]:

def _add_band_and_mean(
    fig: go.Figure,
    x: pd.Series,
    y_mean: pd.Series,
    y_std: pd.Series,
    color: str,
    name: str,
    legendgroup: str,
    show_legend: bool,
    row: int,
    col: int,
    log_scale: bool = False,
) -> None:
    """Add a shaded mean ± std band and the mean line for one metric to a subplot."""
    upper = y_mean + y_std
    lower = y_mean - y_std
    if log_scale:
        # A log axis cannot place values <= 0; keep the band's lower edge positive
        # so the filled polygon stays intact.
        lower = lower.where(lower > 0, y_mean / 10)

    # Shaded region: upper edge forward, lower edge backward, closed by "toself".
    fig.add_trace(
        go.Scatter(
            x=pd.concat([x, x[::-1]]),
            y=pd.concat([upper, lower[::-1]]),
            fill="toself",
            fillcolor=hex_to_rgba(color, 0.2),
            line=dict(color="rgba(0,0,0,0)"),
            showlegend=False,
            hoverinfo="skip",
            legendgroup=legendgroup,
        ),
        row=row, col=col,
    )
    # Mean line
    fig.add_trace(
        go.Scatter(
            x=x,
            y=y_mean,
            mode="lines",
            name=name,
            line=dict(color=color, width=2),
            showlegend=show_legend,
            legendgroup=legendgroup,
        ),
        row=row, col=col,
    )


def create_combined_figure(task_data: dict, tasks_config: dict, colors: dict) -> go.Figure:
    """Create an (n_tasks x 2) grid: one row per task, columns = [reward, consistency loss].

    The training reward is smoothed with the task's ``ma_steps`` moving average
    (window measured on that curve's own logging interval). Train curves are drawn
    below eval curves, and the consistency-loss column uses a log y-axis.

    Args:
        task_data: Mapping task name -> long-format history DataFrame with ``_step``,
            ``seed`` and metric columns.
        tasks_config: Mapping task name -> config with ``display_name``,
            ``detach_step`` and optionally ``ma_steps`` (defaults to 1000).
        colors: Mapping metric name -> hex color; unmapped metrics are grey.

    Returns:
        The assembled figure (not shown).

    Raises:
        ValueError: If ``task_data`` is empty.
    """
    task_names = list(task_data.keys())
    n_tasks = len(task_names)
    if n_tasks == 0:
        raise ValueError("task_data is empty; nothing to plot")

    subplot_titles = []
    for task_name in task_names:
        display = tasks_config[task_name]["display_name"]
        subplot_titles.extend([f"{display} - Reward", f"{display} - Consistency Loss"])

    fig = make_subplots(
        rows=n_tasks,
        cols=2,
        subplot_titles=subplot_titles,
        # Plotly rejects vertical_spacing > 1 / (rows - 1); shrink it for tall grids.
        vertical_spacing=min(0.12, 0.5 / max(n_tasks - 1, 1)),
        horizontal_spacing=0.08,
    )

    # Train curves first so eval curves are drawn on top (sorted() is stable).
    reward_keys = sorted(REWARD_KEYS, key=lambda m: 0 if "train" in m else 1)

    # Each metric gets one legend entry even though it appears in every row.
    legend_added = set()

    for row_idx, task_name in enumerate(task_names, 1):
        df = task_data[task_name]
        config = tasks_config[task_name]
        detach_step = config["detach_step"]
        ma_steps = config.get("ma_steps", 1000)

        # Column 1: episode rewards.
        for metric in reward_keys:
            if metric not in df.columns:
                continue
            stats = compute_stats(df, "_step", metric)
            if stats.empty:
                continue

            y_mean = stats["mean"]
            y_std = stats["std"]
            if "eval" in metric:
                name = "Eval Reward"
            else:
                # Smooth the training reward; the window (in points) comes from this
                # curve's own logging interval.
                window = compute_ma_window_from_steps(stats, "_step", ma_steps)
                y_mean = y_mean.rolling(window=window, min_periods=1).mean()
                y_std = y_std.rolling(window=window, min_periods=1).mean()
                name = "Train Reward (MA)"

            _add_band_and_mean(
                fig, stats["_step"], y_mean, y_std,
                color=colors.get(metric, "#666666"),
                name=name,
                legendgroup=metric,
                show_legend=metric not in legend_added,
                row=row_idx, col=1,
            )
            legend_added.add(metric)

        fig.add_vline(
            x=detach_step, row=row_idx, col=1,
            line_dash="dash", line_color="red", line_width=2,
        )

        # Column 2: consistency losses on a log axis.
        for metric in LOSS_KEYS:
            if metric not in df.columns:
                continue
            stats = compute_stats(df, "_step", metric)
            if stats.empty:
                continue

            if "validation_all" in metric:
                label = "Val (All)"
            elif "validation_recent" in metric:
                label = "Val (Recent)"
            else:
                label = "Train"

            _add_band_and_mean(
                fig, stats["_step"], stats["mean"], stats["std"],
                color=colors.get(metric, "#666666"),
                name=f"{label} Consistency",
                legendgroup=metric,
                show_legend=metric not in legend_added,
                row=row_idx, col=2,
                log_scale=True,
            )
            legend_added.add(metric)

        fig.add_vline(
            x=detach_step, row=row_idx, col=2,
            line_dash="dash", line_color="red", line_width=2,
        )

    fig.update_layout(
        height=400 * n_tasks,
        width=1100,
        title_text="<b>Detach Encoder Ablation: Reward & Consistency Loss</b>",
        title_x=0.5,
        showlegend=True,
        legend=dict(x=1.02, y=1, bgcolor="rgba(255,255,255,0.9)"),
        template="plotly_white",
    )

    fig.update_xaxes(title_text="Steps", row=n_tasks, col=1)
    fig.update_xaxes(title_text="Steps", row=n_tasks, col=2)
    fig.update_yaxes(title_text="Reward", col=1)
    fig.update_yaxes(title_text="Loss", type="log", exponentformat="e", col=2)

    # Vertical lines have no legend entry, so explain them with a text annotation.
    fig.add_annotation(
        text="<span style='color:red'>--- Encoder Detached</span>",
        xref="paper", yref="paper",
        x=1.02, y=0.5,
        showarrow=False,
        font=dict(size=12),
    )

    return fig


In [10]:
# Build and show the grid only when at least one task produced data.
if not task_data:
    print("No task data available")
else:
    combined_fig = create_combined_figure(task_data, TASKS, COLORS)
    combined_fig.show()

## Save Results

In [11]:
# Create output directory
output_dir = paths.notebook_results_dir("10_detach_encoder_ablation")
print(f"Output directory: {output_dir}")

# Save combined figure
if task_data:
    combined_fig.write_html(output_dir / "combined_ablation.html")
    print("Saved: combined_ablation.html")

# Save individual figures
for name, fig in all_figures.items():
    safe_name = name.replace("-", "_").replace(" ", "_")
    fig.write_html(output_dir / f"{safe_name}.html")
    print(f"Saved: {safe_name}.html")

Output directory: /gpfs/work4/0/prjs0951/Thomas/Thesis/RL_weather/tdmpc2-with-return-based-auxiliary-tasks/analysis/results/10_detach_encoder_ablation
Saved: combined_ablation.html
Saved: humanoid_h1_balance_simple_v0_reward.html
Saved: humanoid_h1_balance_simple_v0_loss.html
Saved: acrobot_swingup_reward.html
Saved: acrobot_swingup_loss.html
