# Mettabook

## 1. Setup

### 1.1 Imports

In [None]:
import pandas as pd

# Enable auto-reload of modules
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import plotly.graph_objects as go
from IPython.display import display
from mettabook_widgets import (
    JobLauncher,
    MetricsFetcher,
    ReplayViewer,
    TrainingConfigurator,
    WandBConnector,
)
from plotly.subplots import make_subplots

%matplotlib inline
plt.style.use("default")

print("Setup complete! Auto-reload enabled.")

### 1.2 Initialize Components

In [None]:
# NOTE(review): this import lives outside the notebook's main import cell (section 1.1);
# consider moving it there so every dependency is visible in one place.
from run_store import get_runstore

# Handle returned by get_runstore(). The "View All Runs" cell calls get_runstore() again
# instead of reusing this variable.
run_store = get_runstore()

# Build one widget bundle per component. Later cells pass these bundles back to the
# matching <Component>.display(...) call, so the variable names below are part of the
# notebook's cross-cell contract.
config_widgets = TrainingConfigurator.create_widgets()
launcher_widgets = JobLauncher.create_widgets(config_widgets)  # built from the configurator's widgets
wandb_widgets = WandBConnector.create_widgets()
fetcher_widgets = MetricsFetcher.create_widgets()
replay_widgets = ReplayViewer.create_widgets(fetcher_widgets)  # built from the fetcher's widgets

## 2. Training

### 2.1 Configure and Launch Training Job

In [None]:
# Show the training configurator first, then the job launcher; launcher_widgets was
# created from the same config_widgets in the component-initialization cell.
display(TrainingConfigurator.display(config_widgets))
display(JobLauncher.display(launcher_widgets))

### 2.2 View All Runs

The RunStore provides unified tracking of all your training runs across SkyPilot and W&B. It persists data locally and provides a single view of all runs.

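**Important**: Auto-reload is enabled in Section 1.1, so most code changes are picked up automatically. If a change is still not reflected, restart the kernel (Kernel → Restart) and re-run the cells above (cells 1-5) to reload the modules.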

In [None]:
# Display RunStore table with all your runs.
# get_runstore() is called again here rather than reusing the `run_store` variable from
# the initialization cell; to_widget() is the last expression, so its return value is
# rendered as this cell's output.
rs = get_runstore()
rs.to_widget()

### 2.3 Monitor Job Status (Optional)

If you just launched a job, you can monitor its status here. (**TODO**: no monitoring cell is included in this section yet.)

## 3. Analyze Runs

You can analyze metrics from one or multiple runs. Copy run names from the RunStore table above.

### 3.1 Fetch Metrics from W&B

Enter one or more run names (one per line) to fetch their metrics. This will also automatically fetch available replay URLs:

In [None]:
# Show the metrics-fetcher UI. The plotting cell in section 3.2 reads its results from
# fetcher_widgets["state"]["metrics_dfs"], so fetch metrics here before running it.
display(MetricsFetcher.display(fetcher_widgets))

### 3.2 Analyze Metrics

The fetched metrics are stored in `fetcher_widgets['state']['metrics_dfs']` as a dictionary mapping run names to pandas DataFrames.

In [None]:
# Plot every "overview/" metric for all fetched runs on a grid of Plotly subplots.
#
# Input: the run-name -> DataFrame mapping in fetcher_widgets["state"]["metrics_dfs"].
# A metric gets a subplot when at least one run has it as a numeric column with more
# than one distinct value; a run contributes a trace to a subplot only when its frame
# has both that metric column and a "_step" column.
metrics_dfs = fetcher_widgets["state"]["metrics_dfs"]
if not metrics_dfs:
    print("No metrics data available. Please fetch metrics first.")
else:
    print(f"Plotting metrics for {len(metrics_dfs)} runs")

    # Union of the column names over all runs.
    all_columns = set()
    for df in metrics_dfs.values():
        all_columns.update(df.columns)

    # Filter for overview metrics.
    include_prefixes = ("overview/",)

    # Iterate in sorted order: a set of strings has no stable iteration order across
    # kernel restarts, which would otherwise shuffle the subplots between runs.
    plot_cols = [
        col
        for col in sorted(all_columns)
        if col.startswith(include_prefixes)
        and any(
            col in df.columns and pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() > 1
            for df in metrics_dfs.values()
        )
    ]

    if not plot_cols:
        print("No plottable metrics found")
    else:
        # Grid dimensions: at most 3 columns, as many rows as needed.
        n_metrics = len(plot_cols)
        n_cols = min(3, n_metrics)
        n_rows = (n_metrics + n_cols - 1) // n_cols

        # make_subplots rejects vertical_spacing > 1 / (rows - 1), so a fixed 0.08 fails
        # once there are 14+ rows. Cap the gaps at half of the figure height in total.
        vertical_spacing = 0.08 if n_rows == 1 else min(0.08, 0.5 / (n_rows - 1))

        fig = make_subplots(
            rows=n_rows,
            cols=n_cols,
            subplot_titles=[col.replace("overview/", "").replace("_", " ") for col in plot_cols],
            vertical_spacing=vertical_spacing,
            horizontal_spacing=0.1,
        )

        # Color palette for different runs (cycled when there are more runs than colors).
        colors = ["blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive", "cyan"]

        # Runs that already own a legend entry. Tracking this per run (instead of tying
        # the legend to the first subplot) keeps runs that lack the first metric visible
        # in the legend.
        runs_in_legend = set()

        for idx, col in enumerate(plot_cols):
            row = (idx // n_cols) + 1
            col_idx = (idx % n_cols) + 1

            for run_idx, (run_name, df) in enumerate(metrics_dfs.items()):
                if col not in df.columns or "_step" not in df.columns:
                    continue

                # One legend entry per run; legendgroup ties the run's other traces to it.
                show_legend = run_name not in runs_in_legend
                runs_in_legend.add(run_name)

                fig.add_trace(
                    go.Scatter(
                        x=df["_step"],
                        y=df[col],
                        mode="lines",
                        name=run_name,
                        line=dict(color=colors[run_idx % len(colors)], width=2),
                        showlegend=show_legend,
                        legendgroup=run_name,
                    ),
                    row=row,
                    col=col_idx,
                )

        # Update layout
        runs_text = "run" if len(metrics_dfs) == 1 else "runs"
        fig.update_layout(
            height=250 * n_rows,
            title_text=f"Overview Metrics ({len(metrics_dfs)} {runs_text})",
            showlegend=True,
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        )

        # Label the x-axis of the lowest populated subplot in each column. When the last
        # row is only partly filled, that subplot sits one row above the grid's bottom.
        for col_idx in range(1, n_cols + 1):
            bottom_row = (n_metrics - col_idx) // n_cols + 1
            fig.update_xaxes(title_text="Steps", row=bottom_row, col=col_idx)

        fig.show()

In [None]:
# Show the replay viewer UI for replay_widgets, which were created from fetcher_widgets
# in the component-initialization cell.
display(ReplayViewer.display(replay_widgets))

In [None]:
# Call display_iframe with a 1000x600 size. Presumably this embeds the currently selected
# replay in an iframe — confirm against mettabook_widgets.
ReplayViewer.display_iframe(replay_widgets, width=1000, height=600)