# Professional Interactive Sweep Analysis Dashboard

This notebook creates a WandB-quality interactive dashboard using Plotly Dash.
Features:
- Interactive scatter plots with detailed hover tooltips
- Real-time filtering with range sliders
- Linked visualizations that update together
- Professional styling with Bootstrap components
- Click on any point to see full run details


In [1]:
# Import required libraries
import sys
import os
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import dash
from dash import dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from jupyter_dash import JupyterDash
import warnings
warnings.filterwarnings('ignore')

# Add the metta module to path
sys.path.append(os.path.abspath('../..'))

# Import sweep utilities
from metta.sweep.wandb_utils import get_sweep_runs, deep_clean
import wandb

print("Libraries loaded successfully!")


/Users/axel/Documents/Softmax/metta/.venv/lib/python3.11/site-packages/pydantic/_internal/_config.py:323: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/


Libraries loaded successfully!


In [2]:
# Configuration
WANDB_ENTITY = "metta-research"
WANDB_PROJECT = "metta"
WANDB_SWEEP_NAME = "axel.arena_phased_812.v1"
MAX_OBSERVATIONS = 1000
HOURLY_COST = 4.6  # Dollar cost per hour for instance

print(f"Configuration set for sweep: {WANDB_SWEEP_NAME}")


Configuration set for sweep: axel.arena_phased_812.v1


In [3]:
# Helper functions
def flatten_nested_dict(d, parent_key='', sep='.'):
    """Recursively flatten a nested dictionary structure."""
    items = []
    
    if not isinstance(d, dict):
        return {parent_key: d} if parent_key else {}
    
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    
    return dict(items)

def extract_observations_to_dataframe(observations):
    """Convert protein observations to a pandas DataFrame."""
    all_rows = []
    
    for obs in observations:
        if obs.get('is_failure', False):
            continue
        
        row_data = {}
        
        # Get suggestion
        suggestion = obs.get('suggestion', {})
        
        # Flatten the suggestion dictionary
        if isinstance(suggestion, dict) and suggestion:
            flattened = flatten_nested_dict(suggestion)
            row_data.update(flattened)
        
        # Add metrics
        row_data['score'] = obs.get('objective', np.nan)
        row_data['cost'] = obs.get('cost', np.nan)
        row_data['runtime'] = obs.get('cost', np.nan)
        row_data['timestamp'] = obs.get('timestamp', obs.get('created_at', np.nan))
        row_data['run_name'] = obs.get('run_name', '')
        row_data['run_id'] = obs.get('run_id', '')
        
        all_rows.append(row_data)
    
    df = pd.DataFrame(all_rows)
    
    # Convert timestamp to datetime
    if 'timestamp' in df.columns and not df['timestamp'].isna().all():
        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df = df.sort_values('timestamp').reset_index(drop=True)
    
    # Add dollar cost
    if 'runtime' in df.columns:
        df['dollar_cost'] = (df['runtime'] / 3600.0) * HOURLY_COST
    
    return df

print("Helper functions defined")


Helper functions defined


In [4]:
# Load CSV data (alternative to WandB API)
# If you have the sweep_analysis.csv file from the previous notebook
try:
    df = pd.read_csv("sweep_analysis.csv")
    print(f"Loaded {len(df)} observations from CSV")
    
    # Ensure timestamp is datetime
    if 'timestamp' in df.columns:
        df['timestamp'] = pd.to_datetime(df['timestamp'])
    
    # Calculate dollar cost if not present
    if 'dollar_cost' not in df.columns and 'runtime' in df.columns:
        df['dollar_cost'] = (df['runtime'] / 3600.0) * HOURLY_COST
    
    param_cols = [col for col in df.columns if col not in ['score', 'cost', 'runtime', 'timestamp', 'run_name', 'run_id', 'dollar_cost']]
    print(f"Found {len(param_cols)} hyperparameters")
    
except FileNotFoundError:
    print("CSV file not found. Please run the sweep_analysis notebook first to generate the CSV.")


Loaded 108 observations from CSV
Found 15 hyperparameters


## Interactive Dashboard using Plotly (No Dash Required)

This creates a fully interactive dashboard using just Plotly, which works immediately without additional setup.


In [5]:
# Create comprehensive interactive dashboard using Plotly subplots
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Create figure with subplots
fig = make_subplots(
    rows=3, cols=2,
    subplot_titles=('Cost vs Score (Hover for Details)', 
                   'Score Progression Over Time',
                   'Parameter Importance',
                   'Efficiency Frontier',
                   'Score Distribution',
                   'Cost Distribution'),
    specs=[[{'type': 'scatter'}, {'type': 'scatter'}],
           [{'type': 'bar'}, {'type': 'scatter'}],
           [{'type': 'histogram'}, {'type': 'histogram'}]],
    vertical_spacing=0.12,
    horizontal_spacing=0.15
)

# 1. Cost vs Score with detailed hover
hover_text = []
for idx, row in df.iterrows():
    text = f"<b>Run: {row.get('run_name', idx)}</b><br>"
    text += f"Score: {row['score']:.4f}<br>"
    text += f"Cost: ${row.get('dollar_cost', row.get('cost', 0)):.4f}<br>"
    text += f"Runtime: {row.get('runtime', 0):.1f}s<br>"
    if 'trainer.total_timesteps' in row:
        text += f"Timesteps: {row['trainer.total_timesteps']/1e6:.1f}M<br>"
    if 'trainer.optimizer.learning_rate' in row:
        text += f"LR: {row['trainer.optimizer.learning_rate']:.6f}<br>"
    if 'trainer.ppo.clip_coef' in row:
        text += f"Clip: {row['trainer.ppo.clip_coef']:.4f}"
    hover_text.append(text)

fig.add_trace(
    go.Scatter(
        x=df.get('dollar_cost', df.get('cost', [])),
        y=df['score'],
        mode='markers',
        marker=dict(
            size=8,
            color=df['score'],
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(
                title="Score",
                x=1.15,
                len=0.3,
                y=0.85
            ),
            line=dict(width=0.5, color='white')
        ),
        text=hover_text,
        hovertemplate='%{text}<extra></extra>',
        name='Runs'
    ),
    row=1, col=1
)

# Highlight best point
if not df.empty:
    best_idx = df['score'].idxmax()
    fig.add_trace(
        go.Scatter(
            x=[df.loc[best_idx, 'dollar_cost'] if 'dollar_cost' in df.columns else df.loc[best_idx, 'cost']],
            y=[df.loc[best_idx, 'score']],
            mode='markers',
            marker=dict(
                size=15,
                color='red',
                symbol='star',
                line=dict(width=2, color='darkred')
            ),
            name='Best',
            hovertemplate=f"<b>BEST</b><br>Score: {df.loc[best_idx, 'score']:.4f}<extra></extra>"
        ),
        row=1, col=1
    )

# 2. Score over time
if 'timestamp' in df.columns:
    fig.add_trace(
        go.Scatter(
            x=df['timestamp'],
            y=df['score'],
            mode='markers+lines',
            marker=dict(
                size=6,
                color=range(len(df)),
                colorscale='Plasma',
                showscale=False
            ),
            line=dict(color='rgba(100,100,100,0.3)', width=1),
            name='Score',
            hovertemplate='Score: %{y:.4f}<br>Time: %{x}<extra></extra>'
        ),
        row=1, col=2
    )
    
    # Add moving average
    window = min(5, len(df) // 3)
    if window > 1:
        df_sorted = df.sort_values('timestamp')
        rolling_mean = df_sorted['score'].rolling(window=window, center=True).mean()
        fig.add_trace(
            go.Scatter(
                x=df_sorted['timestamp'],
                y=rolling_mean,
                mode='lines',
                line=dict(color='red', width=2),
                name='Avg',
                hovertemplate='Avg: %{y:.4f}<extra></extra>'
            ),
            row=1, col=2
        )

# 3. Parameter importance
param_cols = [col for col in df.columns if col.startswith('trainer.') and pd.api.types.is_numeric_dtype(df[col])]
correlations = []
for col in param_cols[:15]:  # Top 15 parameters
    if df[col].nunique() >= 2:
        corr = df[col].corr(df['score'])
        if not np.isnan(corr):
            correlations.append({'param': col.replace('trainer.', ''), 'corr': corr})

if correlations:
    corr_df = pd.DataFrame(correlations)
    corr_df = corr_df.nlargest(10, 'corr', keep='all')
    
    colors = ['green' if x > 0 else 'red' for x in corr_df['corr']]
    fig.add_trace(
        go.Bar(
            y=corr_df['param'],
            x=corr_df['corr'],
            orientation='h',
            marker=dict(color=colors, opacity=0.7),
            text=[f"{x:.3f}" for x in corr_df['corr']],
            textposition='outside',
            hovertemplate='%{y}: %{x:.3f}<extra></extra>',
            name='Correlation'
        ),
        row=2, col=1
    )

# 4. Pareto frontier
sorted_df = df.sort_values('dollar_cost' if 'dollar_cost' in df.columns else 'cost')
pareto_front = []
max_score = -np.inf
for _, row in sorted_df.iterrows():
    if row['score'] >= max_score:
        pareto_front.append(row)
        max_score = row['score']

if pareto_front:
    pareto_df = pd.DataFrame(pareto_front)
    
    # All points
    fig.add_trace(
        go.Scatter(
            x=df.get('dollar_cost', df.get('cost', [])),
            y=df['score'],
            mode='markers',
            marker=dict(size=4, color='lightgray'),
            name='All',
            hovertemplate='Cost: $%{x:.2f}<br>Score: %{y:.4f}<extra></extra>'
        ),
        row=2, col=2
    )
    
    # Pareto frontier
    fig.add_trace(
        go.Scatter(
            x=pareto_df.get('dollar_cost', pareto_df.get('cost', [])),
            y=pareto_df['score'],
            mode='lines+markers',
            line=dict(color='red', width=2),
            marker=dict(size=8, color='red'),
            name='Pareto',
            hovertemplate='<b>Efficient</b><br>Cost: $%{x:.2f}<br>Score: %{y:.4f}<extra></extra>'
        ),
        row=2, col=2
    )

# 5. Score distribution
fig.add_trace(
    go.Histogram(
        x=df['score'],
        nbinsx=20,
        marker=dict(color='green', opacity=0.7),
        name='Score',
        hovertemplate='Score: %{x}<br>Count: %{y}<extra></extra>'
    ),
    row=3, col=1
)

# 6. Cost distribution
fig.add_trace(
    go.Histogram(
        x=df.get('dollar_cost', df.get('cost', [])),
        nbinsx=20,
        marker=dict(color='blue', opacity=0.7),
        name='Cost',
        hovertemplate='Cost: $%{x:.2f}<br>Count: %{y}<extra></extra>'
    ),
    row=3, col=2
)

# Update layout
fig.update_layout(
    title_text=f"<b>Interactive Sweep Analysis Dashboard</b><br><sub>Sweep: {WANDB_SWEEP_NAME}</sub>",
    height=1200,
    showlegend=False,
    template='plotly_white',
    hovermode='closest'
)

# Update axes labels
fig.update_xaxes(title_text="Cost ($)", row=1, col=1)
fig.update_yaxes(title_text="Score", row=1, col=1)

fig.update_xaxes(title_text="Timestamp", row=1, col=2)
fig.update_yaxes(title_text="Score", row=1, col=2)

fig.update_xaxes(title_text="Correlation", row=2, col=1)
fig.update_yaxes(title_text="Parameter", row=2, col=1)

fig.update_xaxes(title_text="Cost ($)", row=2, col=2)
fig.update_yaxes(title_text="Score", row=2, col=2)

fig.update_xaxes(title_text="Score", row=3, col=1)
fig.update_yaxes(title_text="Count", row=3, col=1)

fig.update_xaxes(title_text="Cost ($)", row=3, col=2)
fig.update_yaxes(title_text="Count", row=3, col=2)

# Display the interactive dashboard
fig.show()

print(f"\n✅ Interactive dashboard created successfully!")
print(f"📊 Total runs visualized: {len(df)}")
print(f"⭐ Best score: {df['score'].max():.4f}")
print(f"💰 Total cost: ${df.get('dollar_cost', df.get('cost', pd.Series([0]))).sum():.2f}")


ValueError: 
    Invalid value of type 'builtins.range' received for the 'color' property of scatter.marker
        Received value: range(0, 108)

    The 'color' property is a color and may be specified as:
      - A hex string (e.g. '#ff0000')
      - An rgb/rgba string (e.g. 'rgb(255,0,0)')
      - An hsl/hsla string (e.g. 'hsl(0,100%,50%)')
      - An hsv/hsva string (e.g. 'hsv(0,100%,100%)')
      - A named CSS color: see https://plotly.com/python/css-colors/ for a list
      - A number that will be interpreted as a color
        according to scatter.marker.colorscale
      - A list or array of any of the above

## Full Dash Implementation (Optional)

Run the cell below to create a fully interactive Dash dashboard with filters and click interactions.
This provides the most professional experience with real-time filtering and detailed run information on click.
