# Mettabook

Launch, monitor, and analyze Metta training runs, then view their replays.

## Setup

In [None]:
# Optional: confirm you're set up to connect to the services used in this notebook
#    If the command does not run, run `./install.sh` from your terminal

# !metta status --components=core,system,aws,wandb --non-interactive

In [None]:
%load_ext autoreload
%autoreload 2
import os

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from experiments.notebooks.utils.metrics import fetch_metrics
from experiments.notebooks.utils.monitoring import monitor_training_statuses
from experiments.notebooks.utils.replays import show_replay
from experiments.notebooks.utils.training import launch_training
from datetime import datetime
from metta.common.wandb.wandb_runs import find_training_runs

%matplotlib inline
plt.style.use("default")

print("Setup complete! Auto-reload enabled.")

## Launch Training

In [None]:
# Example: launch a training run. Uncomment the lines below to use it.
# The run name is built from $USER plus a minute-resolution timestamp.

# run_name = f"{os.environ.get('USER')}.training-run.{datetime.now().strftime('%Y-%m-%d_%H-%M')}"
# print(f"Launching training with run name: {run_name}...")

# # View `launch_training` function for all options
# result = launch_training(
#     run_name=run_name,
#     curriculum="env/mettagrid/arena/basic",
#     wandb_tags=[f"{os.environ.get('USER')}-arena-experiment"],
#     additional_args=["--skip-git-check"],
# )

## Monitor Training Jobs

In [None]:
# Training runs to monitor, listed explicitly by name.
run_names = [
    "daveey.navigation.low_reward.baseline.2",
    "daveey.navigation.low_reward.baseline.07-18",
]

# Alternative: select runs by criteria instead of by name
# (uncomment and adjust the filters as needed).
# run_names = find_training_runs(
#     # wandb_tags=["low_reward"],
#     # state="finished",
#     author=os.getenv("USER"),
#     limit=5,
# )

# Status report for the selected runs, including the `_step` and
# `overview/reward` metrics.
df = monitor_training_statuses(run_names, show_metrics=["_step", "overview/reward"])

## Fetch Metrics

In [None]:
# Fetch metric histories for every run selected above.
# NOTE(review): `samples` presumably bounds the number of history rows fetched
# per run — confirm against `fetch_metrics`.
metrics_dfs = fetch_metrics(
    run_names,
    samples=500,
)

## Analyze Metrics

In [None]:
# Plot overview metrics for all fetched runs.
# The work is wrapped in functions so loop variables (`df`, `run_name`, `col`, ...)
# do not leak into the notebook namespace and clobber globals from earlier cells.


def _select_plot_columns(metrics_dfs, wanted):
    """Return the entries of `wanted` that are plottable in at least one run.

    A column is plottable when some run's DataFrame has it with a numeric dtype
    and more than one distinct value. The result keeps the order of `wanted`
    (duplicates removed), so the subplot layout is stable across kernel
    restarts. Iterating a `set` of column names would not give a stable order.
    """
    return [
        col
        for col in dict.fromkeys(wanted)
        if any(
            col in run_df.columns
            and pd.api.types.is_numeric_dtype(run_df[col])
            and run_df[col].nunique() > 1
            for run_df in metrics_dfs.values()
        )
    ]


def plot_overview_metrics(
    metrics_dfs,
    metrics=("overview/reward", "losses/explained_variance"),
    max_cols=3,
):
    """Plot the selected metrics against `_step` for every fetched run.

    Args:
        metrics_dfs: mapping of run name -> DataFrame of metrics, as produced by
            `fetch_metrics`. A run is drawn in a subplot only if its frame has
            both that metric column and a `_step` column.
        metrics: metric column names to plot, one subplot each, in this order.
        max_cols: maximum number of subplot columns in the grid.

    Returns:
        The plotly Figure that was shown, or None when there was nothing to plot.
    """
    if not metrics_dfs:
        print("No metrics data available. Please fetch metrics first.")
        return None

    runs_text = "run" if len(metrics_dfs) == 1 else "runs"
    print(f"Plotting metrics for {len(metrics_dfs)} {runs_text}")

    plot_cols = _select_plot_columns(metrics_dfs, metrics)
    if not plot_cols:
        print("No plottable metrics found")
        return None

    # Grid dimensions: at most `max_cols` columns, ceiling-divided rows
    n_metrics = len(plot_cols)
    n_cols = min(max_cols, n_metrics)
    n_rows = (n_metrics + n_cols - 1) // n_cols

    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[col.replace("overview/", "").replace("_", " ") for col in plot_cols],
        vertical_spacing=0.08,
        horizontal_spacing=0.1,
    )

    # Color palette indexed by run position, so a run keeps one color across subplots
    colors = ["blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive", "cyan"]
    legend_runs = set()  # runs that already have a legend entry

    for idx, col in enumerate(plot_cols):
        grid_row, grid_col = divmod(idx, n_cols)
        row, col_idx = grid_row + 1, grid_col + 1  # plotly subplot indices are 1-based

        for run_idx, (run_name, run_df) in enumerate(metrics_dfs.items()):
            if col not in run_df.columns or "_step" not in run_df.columns:
                continue
            fig.add_trace(
                go.Scatter(
                    x=run_df["_step"],
                    y=run_df[col],
                    mode="lines",
                    name=run_name,
                    line=dict(color=colors[run_idx % len(colors)], width=2),
                    # Show one legend entry per run, on the first subplot where the run
                    # appears. Keying on the first subplot would drop runs that lack that metric.
                    showlegend=run_name not in legend_runs,
                    legendgroup=run_name,  # Group all traces from same run
                ),
                row=row,
                col=col_idx,
            )
            legend_runs.add(run_name)

        # Label the x-axis of the lowest subplot in each grid column. The last row
        # may be partially filled, so a column's bottom plot is not always in row n_rows.
        if idx + n_cols >= n_metrics:
            fig.update_xaxes(title_text="Steps", row=row, col=col_idx)

    fig.update_layout(
        title_text=f"Training metrics ({len(metrics_dfs)} {runs_text})",
        height=250 * n_rows,
        showlegend=True,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    fig.show()
    return fig


# Assigning the result stops Jupyter from displaying the figure a second time
fig = plot_overview_metrics(metrics_dfs)

## View Replays

Display the replay viewer for a specific run:

In [None]:
# List the replays available for a run (uncomment to use).
# NOTE(review): `get_available_replays` is not imported in the setup cell; import it
# before uncommenting (presumably from experiments.notebooks.utils.replays — confirm).
# replays = get_available_replays("daveey.lp.16x4.bptt8")

# Open the viewer on the run's last replay
show_replay(
    "daveey.lp.16x4.bptt8",
    step="last",
    width=1000,
    height=600,
)