# Mettabook

## 1. Setup

### 1.1 Imports

In [None]:
# Notebook setup: data/plotting imports, IPython autoreload so that edits to
# local modules (e.g. mettabook_widgets) are picked up without restarting the
# kernel, and inline matplotlib rendering.
import pandas as pd

# Enable auto-reload of modules
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display
# Notebook widgets used by the sections below (config, launch, monitoring,
# W&B connection, metrics fetching, replay viewing).
from mettabook_widgets import (
    TrainingConfigurator,
    JobLauncher,
    JobStatusMonitor,
    WandBConnector,
    MetricsFetcher,
    ReplayViewer,
)


# Render matplotlib figures inline in the notebook output.
%matplotlib inline
plt.style.use("default")

print("Setup complete! Auto-reload enabled.")

### 1.2 Initialize Components

In [None]:
# Create the widget objects shared by every later cell. Each later component is
# built from an earlier one: the launcher from the config, the monitor from the
# launcher, and both the metrics fetcher and the replay viewer from the single
# W&B connector. Keep this order; these module-level names are referenced below.
config = TrainingConfigurator()
launcher = JobLauncher(config)
monitor = JobStatusMonitor(launcher)
wandb_conn = WandBConnector()
fetcher = MetricsFetcher(wandb_conn)
replay_viewer = ReplayViewer(wandb_conn)

print("Components initialized!")

### 1.3 Confirm Credential Setup

In [None]:
# Confirm that you are connected to SkyPilot and W&B. If not, run
# `metta install` or follow the prompts. (`!` runs this as a shell command.)
!metta status

## 2. Training

This section allows you to launch and monitor a training run. You can skip to the "Analyze a Run" section if you have an existing run.


### 2.1 Specify Training Job

In [None]:
# Show the training-configuration widget. NOTE(review): assumes config.display()
# returns a displayable widget rather than rendering itself; confirm.
display(config.display())

### 2.2 Launch Training Job

In [None]:
# Display the job-launcher widget, which was built from `config` above, so the
# launched job reflects the settings chosen in section 2.1.
display(launcher.display())

### 2.3 Monitor Training Job

In [None]:
# Show the job-status widget, then start monitoring the job(s) known to
# `launcher`. NOTE(review): whether start_monitoring() polls in the background
# or blocks this cell is not visible here; confirm in mettabook_widgets.
display(monitor.display())
monitor.start_monitoring()

## 3. Analyze a Run

### 3.1 Pick which run to analyze

If you ran the Training section, you can deselect "Use specific run" to automatically select the run you just trained.

In [None]:
# Show the W&B run picker. Tip: choose a run that has already finished training
# if you don't want to wait for the run launched in section 2.
display(wandb_conn.display())

### 3.2 Fetch metrics

In [None]:
# Render the metrics-fetcher widget, then pull metrics only once a run has
# been selected in section 3.1.
display(fetcher.display())
if not wandb_conn.run:
    print("Run section 3.1 first")
else:
    fetcher.auto_fetch()

### 3.3 Analyze

`fetcher.metrics_df` holds a pandas DataFrame with the metrics sampled in section 3.2.

You can analyze them however you like. The boilerplate below plots each numeric `overview/*` metric against the logged step (`_step`).

In [None]:
# Plot each numeric, non-constant metric whose name starts with one of
# `include_prefixes`, one subplot per metric, in a grid of at most `max_cols`
# columns, against the `_step` column (or the row index if it is missing).
include_prefixes = ["overview/"]
max_cols = 3  # max subplots per grid row
row_height_px = 250  # figure height contributed by each subplot row
row_gap_px = 60  # target gap between subplot rows, in pixels


def _strip_prefix(name, prefixes):
    """Return *name* with the first matching prefix in *prefixes* removed."""
    for prefix in prefixes:
        if name.startswith(prefix):
            return name[len(prefix):]
    return name


metrics_df = fetcher.metrics_df
if metrics_df is None or len(metrics_df) == 0:
    print("No metrics data available. Please fetch metrics first.")
else:
    plot_cols = [
        col
        for col in metrics_df.columns
        # Non-string column names cannot match a prefix; skip them rather
        # than raising on `.startswith`.
        if isinstance(col, str)
        and col.startswith(tuple(include_prefixes))
        and pd.api.types.is_numeric_dtype(metrics_df[col])
        # Columns with at most one distinct value have no trend to plot.
        and metrics_df[col].nunique() > 1
    ]

    if not plot_cols:
        print("No plottable metrics found")
    else:
        # Decide the x-axis once instead of per trace: previously a missing
        # `_step` column silently produced a grid of empty subplots.
        has_step = "_step" in metrics_df.columns
        if not has_step:
            print("No `_step` column in metrics; plotting against row index instead.")
        x_label = "Steps" if has_step else "Row index"

        # Calculate grid dimensions
        n_metrics = len(plot_cols)
        n_cols = min(max_cols, n_metrics)
        n_rows = (n_metrics + n_cols - 1) // n_cols  # ceiling division

        fig = make_subplots(
            rows=n_rows,
            cols=n_cols,
            subplot_titles=[_strip_prefix(col, include_prefixes).replace("_", " ") for col in plot_cols],
            # plotly rejects vertical_spacing > 1 / (rows - 1), and because the
            # figure height grows with n_rows a fixed fraction yields ever larger
            # gaps. Keep a roughly constant pixel gap, never above the old 0.08.
            vertical_spacing=min(0.08, row_gap_px / (row_height_px * n_rows)),
            horizontal_spacing=0.1,
        )

        # Color palette
        colors = ["blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive", "cyan"]

        for idx, col in enumerate(plot_cols):
            # Sampled histories may hold NaN where a metric was not logged at a
            # given step; drop those points so mode="lines" draws a continuous line.
            if has_step:
                points = metrics_df[["_step", col]].dropna()
                x_values = points["_step"]
            else:
                points = metrics_df[[col]].dropna()
                x_values = points.index

            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=points[col],
                    mode="lines",
                    name=_strip_prefix(col, include_prefixes),
                    line=dict(color=colors[idx % len(colors)], width=2),
                    showlegend=False,
                ),
                row=idx // n_cols + 1,
                col=idx % n_cols + 1,
            )

        fig.update_layout(height=row_height_px * n_rows, title_text="Overview Metric / Agent Step", showlegend=False)

        # Label the lowest populated subplot in each grid column. The last
        # `n_cols` plotted metrics hit every column exactly once and are the
        # bottom-most cell in it; labelling the literal bottom row would hit
        # empty cells whenever that row is only partially filled.
        for idx in range(n_metrics - n_cols, n_metrics):
            fig.update_xaxes(title_text=x_label, row=idx // n_cols + 1, col=idx % n_cols + 1)

        fig.show()

### 3.4 Replay Viewer

In [None]:
# Render the replay-viewer widget; replays can only be fetched once a W&B run
# has been selected.
display(replay_viewer.display())

if not wandb_conn.run:
    print("Select a run first")
else:
    print("Fetching replays...")
    replay_viewer.auto_fetch()

In [None]:
# Embed the selected replay in an iframe (1000 x 600; presumably pixels —
# confirm in ReplayViewer.display_iframe). Run the cell above first.
replay_viewer.display_iframe(width=1000, height=600)