# Mettabook

## Setup

In [7]:
# Optional sanity check: confirm this machine can connect to the services used in this notebook.
# The command prints one status row per requested component (core, system, aws, wandb).
#    If the command does not run, run `./install.sh` from your terminal

!metta status --components=core,system,aws,wandb --non-interactive

[0m[0m[34mComponent | Installed  | Connected As              | Expected             | Status[0m[0m
[0m[34m----------------------------------------------------------------------------------[0m[0m
[0m[34mcore     | Yes        | -                         | -                    |[0m[0m[0m[32mOK[0m[0m
[0m[34msystem   | No         | -                         | -                    |[0m[0m[0m[31mNOT INSTALLED[0m[0m
[0m[34maws      | Yes        | 767406518141              | 751442549699         |[0m[0m[0m[33mWRONG ACCOUNT[0m[0m
[0m[34mwandb    | Yes        | metta-research            | metta-research       |[0m[0m[0m[32mOK[0m[0m
[0m[34m----------------------------------------------------------------------------------[0m[0m
[0m[33mSome components are not installed. Run 'metta install' to set them up.[0m[0m
[0m[33mComponents not installed: system[0m[0m
[0m[34mTo fix: metta install system[0m[0m
[0m[0m

In [8]:
%load_ext autoreload
%autoreload 2
import os

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from experiments.notebooks.utils.metrics import fetch_metrics
from experiments.notebooks.utils.monitoring import monitor_training_statuses
from experiments.notebooks.utils.replays import show_replay
from experiments.notebooks.utils.training import launch_training
from datetime import datetime
from metta.common.wandb.wandb_runs import find_training_runs
from metta.rl.trainer_config import TrainerConfig, CheckpointConfig, TorchProfilerConfig, SimulationConfig
from datetime import datetime
from metta.rl.trainer_config import TrainerConfig, CheckpointConfig, TorchProfilerConfig, SimulationConfig
from experiments.notebooks.utils.training import launch_training
from datetime import datetime

%matplotlib inline
plt.style.use("default")

print("Setup complete! Auto-reload enabled.")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Setup complete! Auto-reload enabled.


## Configure the Training Environment

In [None]:
from pprint import pprint

from metta.mettagrid.config import builder

# Base environment: a two-agent arena with combat enabled. This is the config that the
# trainer cell passes as `env=env_cfg`.
env_cfg = builder.arena(num_agents=2, combat=True)

# NOTE(review): `Curriculum` is not imported in this cell or in the setup cell, so this line
# raises NameError on a fresh kernel. Import it from its defining module before running.
curriculum = Curriculum()

# Candidate sweep values; not referenced again in the visible cells.
agents = [1, 2, 4, 8, 16, 24]
combat = [False, True]

# Bucket over the heart inventory reward.
curriculum.add_bucket("game.agent.rewards.inventory.heart", [0, 1])

# Bucket over each altar input resource with the same candidate values, in this order.
altar_input_paths = [
    "game.objects.altar.input_resources.ore_red",
    "game.objects.altar.input_resources.battery_red",
    "game.objects.altar.input_resources.laser_red",
    "game.objects.altar.input_resources.armor_red",
    "game.objects.altar.input_resources.blueprint_red",
    "game.objects.altar.input_resources.ore_blue",
    "game.objects.altar.input_resources.battery_blue",
]
for bucket_path in altar_input_paths:
    curriculum.add_bucket(bucket_path, [0, 1, 2])

# Add two 24-agent arena tasks to the bucketed curriculum: one without combat, one with
# combat split across two teams.
no_combat_task = builder.arena(num_agents=24, combat=False)
curriculum.add_task(no_combat_task)
team_combat_task = builder.arena(num_agents=24, combat=True, num_teams=2)
curriculum.add_task(team_combat_task)

# A second curriculum holding only the same two arena task variants (built fresh, no buckets).
arena_curriculum = Curriculum()
arena_no_combat_task = builder.arena(num_agents=24, combat=False)
arena_curriculum.add_task(arena_no_combat_task)
arena_team_combat_task = builder.arena(num_agents=24, combat=True, num_teams=2)
arena_curriculum.add_task(arena_team_combat_task)

# Curriculum config derived from the base environment.
curriculum_cfg = builder.curriculum.basic(env_cfg)
# map_cfg = builder.maps.arena.basic()

# Show the full base environment config for inspection.
env_cfg_json = env_cfg.model_dump_json(indent=2)
print(env_cfg_json)

{
  "game": {
    "inventory_item_names": [
      "ore_red",
      "ore_blue",
      "ore_green",
      "battery_red",
      "battery_blue",
      "battery_green",
      "heart",
      "armor",
      "laser",
      "blueprint"
    ],
    "num_agents": 2,
    "max_steps": 1000,
    "episode_truncates": false,
    "obs_width": 11,
    "obs_height": 11,
    "num_observation_tokens": 200,
    "agent": {
      "default_resource_limit": 50,
      "resource_limits": {
        "heart": 255
      },
      "freeze_duration": 10,
      "rewards": {
        "inventory": {
          "ore_red": null,
          "ore_blue": null,
          "ore_green": null,
          "ore_red_max": null,
          "ore_blue_max": null,
          "ore_green_max": null,
          "battery_red": null,
          "battery_blue": null,
          "battery_green": null,
          "battery_red_max": null,
          "battery_blue_max": null,
          "battery_green_max": null,
          "heart": 1.0,
          "heart_max": nu

## Configure Trainer

In [23]:
# Timestamped run name; every output directory below is derived from it.
run_name = f"{os.environ.get('USER')}.training-run.{datetime.now().strftime('%Y-%m-%d_%H-%M')}"
print(f"Launching training with run name: {run_name}...")

# Sub-configs are built one at a time so each output location is easy to spot and change.
# The commented S3 paths are the remote alternatives to the local directories.
profiler_cfg = TorchProfilerConfig(
    profile_dir=f"train_dir/{run_name}/profiles",
    # profile_dir="s3://softmax-public/profiles/${run}"
)
checkpoint_cfg = CheckpointConfig(
    checkpoint_dir=f"train_dir/{run_name}/checkpoints",
    # checkpoint_dir="s3://softmax-public/checkpoints/${run}"
)
simulation_cfg = SimulationConfig(
    replay_dir=f"train_dir/{run_name}"
    # replay_dir="s3://softmax-public/replays/${run}"
)

# Assemble the trainer config around the environment built in the previous section.
trainer_cfg = TrainerConfig(
    num_workers=6,
    profiler=profiler_cfg,
    checkpoint=checkpoint_cfg,
    simulation=simulation_cfg,
    env=env_cfg,
)

# Dump the resolved config (including defaults) for inspection.
trainer_cfg_json = trainer_cfg.model_dump_json(indent=2)
print(trainer_cfg_json)

Launching training with run name: daveey.training-run.2025-08-08_23-50...
{
  "total_timesteps": 10000000000,
  "ppo": {
    "clip_coef": 0.1,
    "ent_coef": 0.0021,
    "gae_lambda": 0.916,
    "gamma": 0.977,
    "max_grad_norm": 0.5,
    "vf_clip_coef": 0.1,
    "vf_coef": 0.44,
    "l2_reg_loss_coef": 0.0,
    "l2_init_loss_coef": 0.0,
    "norm_adv": true,
    "clip_vloss": true,
    "target_kl": null
  },
  "optimizer": {
    "type": "adam",
    "learning_rate": 0.000457,
    "beta1": 0.9,
    "beta2": 0.999,
    "eps": 1e-12,
    "weight_decay": 0.0
  },
  "prioritized_experience_replay": {
    "prio_alpha": 0.0,
    "prio_beta0": 0.6
  },
  "vtrace": {
    "vtrace_rho_clip": 1.0,
    "vtrace_c_clip": 1.0
  },
  "zero_copy": true,
  "require_contiguous_env_ids": false,
  "verbose": true,
  "batch_size": 524288,
  "minibatch_size": 16384,
  "bptt_horizon": 64,
  "update_epochs": 1,
  "scale_batches_by_world_size": false,
  "cpu_offload": false,
  "compile": false,
  "compile_mod

## Launch Training

In [24]:
import subprocess

import yaml

# Save trainer_cfg as configs/trainer/notebook.yaml so the `trainer=notebook` override
# passed to train.py below can select it.
#
# The config is dumped as a plain mapping. `model_dump_json()` returns a JSON *string*, and
# passing a string to yaml.dump writes a single quoted scalar instead of a config mapping.
# `model_dump(mode="json")` yields only JSON-compatible builtins, which safe_dump can emit
# without Python-specific YAML tags.
with open("../../configs/trainer/notebook.yaml", "w") as f:
    yaml.safe_dump(trainer_cfg.model_dump(mode="json"), f, indent=2)

# Run the trainer from the repository root. The CompletedProcess is left as the cell's value
# so the return code is displayed.
# NOTE(review): no run name is forwarded to train.py here, so the run it creates may not be
# named `run_name`, while trainer_cfg's directories are derived from `run_name` — confirm
# whether the name should be passed through.
subprocess.run(["tools/train.py", "trainer=notebook"], cwd="../../")

# # View `launch_training` function for all options
# result = launch_training(
#     run_name=run_name,
#     curriculum="env/mettagrid/arena/basic",
#     wandb_tags=[f"{os.environ.get('USER')}-arena-experiment"],
#     additional_args=["--skip-git-check"],
# )

W0808 23:50:48.355000 80147 torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.


[2025-08-08 23:50:49,929][HYDRA] Registered 20 custom resolvers at the start of a run
Setting run name to daveey.basic_easy_shaped.20250808_235049
[2;36m[23:50:49.945492][0m[2;36m [0m[34mINFO    [0m Starting main from                ]8;id=205236;file:///Users/daveey/code/metta/metta/util/metta_script.py\[2mmetta_script.py[0m]8;;\[2m:[0m]8;id=114676;file:///Users/daveey/code/metta/metta/util/metta_script.py#101\[2m101[0m]8;;\
[2;36m                  [0m         [35m/Users/daveey/code/metta/tools/[0m[95mtr[0m [2m                   [0m
[2;36m                  [0m         [95main.py[0m with run_dir:              [2m                   [0m
[2;36m                  [0m         .[35m/train_dir/[0m[95mdaveey.basic_easy_sha[0m [2m                   [0m
[2;36m                  [0m         [95mped.20250808_235049[0m               [2m                   [0m
[2;36m[23:50:49.948246][0m[2;36m [0m[34mINFO    [0m Environment setup completed       ]8;i

wandb: Syncing run daveey.basic_easy_shaped.20250808_235049
wandb: 🚀 View run at https://wandb.ai/metta-research/metta/runs/daveey.basic_easy_shaped.20250808_235049


[2;36m[23:50:50.990741][0m[2;36m [0m[34mINFO    [0m Successfully initialized W&B     ]8;id=519896;file:///Users/daveey/code/metta/common/src/metta/common/wandb/wandb_context.py\[2mwandb_context.py[0m]8;;\[2m:[0m]8;id=579715;file:///Users/daveey/code/metta/common/src/metta/common/wandb/wandb_context.py#120\[2m120[0m]8;;\
[2;36m                  [0m         run:                             [2m                    [0m
[2;36m                  [0m         daveey.basic_easy_shaped.2025080 [2m                    [0m
[2;36m                  [0m         8_235049                         [2m                    [0m
[2;36m                  [0m         [1m([0mdaveey.basic_easy_shaped.202508 [2m                    [0m
[2;36m                  [0m         08_235049[1m)[0m                       [2m                    [0m
[2;36m[23:50:50.991756][0m[2;36m [0m[34mINFO    [0m HEARTBEAT_FILE env var not set.  ]8;id=797911;file:///Users/daveey/code/metta/commo

wandb: uploading config.yaml; uploading output.log
wandb:                                                                                
wandb: 🚀 View run daveey.basic_easy_shaped.20250808_235049 at: https://wandb.ai/metta-research/metta/runs/daveey.basic_easy_shaped.20250808_235049
wandb: ⭐️ View project at: https://wandb.ai/metta-research/metta
wandb: Synced 8 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Error executing job with overrides: ['trainer=notebook']
Traceback (most recent call last):
  File "/Users/daveey/code/metta/metta/util/metta_script.py", line 112, in extended_main
    result = main(cfg)
             ^^^^^^^^^
  File "/Users/daveey/code/metta/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/daveey/code/metta/tools/train.py", line 191, in main
    handle_train(cfg, wandb_run, logger)
  File "/Users/da

CompletedProcess(args=['tools/train.py', 'trainer=notebook'], returncode=1)

## Monitor Training Jobs

In [None]:
# Runs to monitor; the same list is reused when fetching metrics.
run_names = [
    "daveey.navigation.low_reward.baseline.2",
    "daveey.navigation.low_reward.baseline.07-18",
]

# Optional: instead, find all runs that meet some criteria
# run_names = find_training_runs(
#     # wandb_tags=["low_reward"],
#     # state="finished",
#     author=os.getenv("USER"),
#     limit=5,
# )

# Metrics to show next to each run's status.
status_metrics = ["_step", "overview/reward"]
df = monitor_training_statuses(run_names, show_metrics=status_metrics)

## Fetch Metrics

In [None]:
# Fetch metric history for every run in `run_names`. The plotting cell iterates the result
# with `.items()` / `.values()` and reads `.columns` on each value, i.e. it is used as a
# mapping of run name -> DataFrame.
# NOTE(review): `samples=500` presumably limits the number of sampled points per run — confirm
# against `fetch_metrics`.
metrics_dfs = fetch_metrics(run_names, samples=500)

## Analyze Metrics

In [None]:
# Plot selected metrics for all fetched runs: one subplot per metric, one line per run.
if not metrics_dfs:
    print("No metrics data available. Please fetch metrics first.")
else:
    print(f"Plotting metrics for {len(metrics_dfs)} runs")

    # Metrics to plot, in display order.
    columns = ["overview/reward", "losses/explained_variance"]

    # Keep a metric only when at least one run has it as a numeric column with more than one
    # distinct value. Iterating the `columns` list (instead of a set of every column name)
    # keeps the subplot order stable from one kernel session to the next.
    plot_cols = [
        col
        for col in columns
        if any(
            col in run_df.columns
            and pd.api.types.is_numeric_dtype(run_df[col])
            and run_df[col].nunique() > 1
            for run_df in metrics_dfs.values()
        )
    ]

    if not plot_cols:
        print("No plottable metrics found")
    else:
        # Grid dimensions: at most 3 subplots per row, as many rows as needed.
        n_metrics = len(plot_cols)
        n_cols = min(3, n_metrics)
        n_rows = (n_metrics + n_cols - 1) // n_cols

        fig = make_subplots(
            rows=n_rows,
            cols=n_cols,
            subplot_titles=[col.replace("overview/", "").replace("_", " ") for col in plot_cols],
            vertical_spacing=0.08,
            horizontal_spacing=0.1,
        )

        # Color palette for different runs; cycles if there are more runs than colors.
        colors = ["blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive", "cyan"]

        # Runs that already have a legend entry. Each run gets exactly one, on its first trace,
        # so a run that lacks the first metric still appears in the legend.
        runs_in_legend = set()

        for idx, col in enumerate(plot_cols):
            grid_row, grid_col = divmod(idx, n_cols)  # zero-based position in the grid

            # `run_df` (not `df`) so the loop does not overwrite the monitoring cell's `df`.
            for run_idx, (run_name, run_df) in enumerate(metrics_dfs.items()):
                if col not in run_df.columns or "_step" not in run_df.columns:
                    continue

                show_legend = run_name not in runs_in_legend
                runs_in_legend.add(run_name)

                fig.add_trace(
                    go.Scatter(
                        x=run_df["_step"],
                        y=run_df[col],
                        mode="lines",
                        name=run_name,
                        line=dict(color=colors[run_idx % len(colors)], width=2),
                        showlegend=show_legend,
                        legendgroup=run_name,  # Group all traces from same run
                    ),
                    row=grid_row + 1,
                    col=grid_col + 1,
                )

        # Horizontal legend placed above the plotting area.
        fig.update_layout(
            height=250 * n_rows,
            showlegend=True,
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        )

        # Label the x-axes of the bottom row.
        for col_idx in range(1, n_cols + 1):
            fig.update_xaxes(title_text="Steps", row=n_rows, col=col_idx)

        fig.show()

## View Replays

Display the replay viewer for a specific run:

In [None]:
# To list the replays available for a run instead:
# replays = get_available_replays("daveey.lp.16x4.bptt8")
# NOTE(review): `get_available_replays` is not imported in the setup cell — import it before
# uncommenting.

# Show the last replay for the run in a 1000x600 viewer.
show_replay(
    "daveey.lp.16x4.bptt8",
    step="last",
    width=1000,
    height=600,
)