# 🚀 Sweep Runner Dashboard

Launch and monitor hyperparameter sweeps on SkyPilot with real-time visualization.

## Features
- **Parallel Execution**: Launch multiple workers with configurable GPUs and nodes
- **Live Monitoring**: Real-time cost, performance, and health tracking
- **Interactive Visualizations**: Score trends and parameter importance analysis
- **Automatic Management**: Heartbeat monitoring and failure recovery


## Configuration


In [13]:
# Sweep Configuration
SWEEP_NAME = "axel.late_night_parallel"  # Your sweep name (passed to every worker as run=<name>)

# SkyPilot Configuration
NUM_PARALLEL_WORKERS = 2  # Number of parallel sweep workers
GPUS_PER_WORKER = 1  # GPUs per worker
NODES_PER_WORKER = 1  # Nodes per worker

# WandB Configuration
WANDB_ENTITY = "metta-research"
WANDB_PROJECT = "metta"

# Display Configuration
UPDATE_INTERVAL_SECONDS = 120  # Dashboard refresh interval in seconds (also the `watch` interval)
MAX_DISPLAY_RUNS = 100  # Maximum runs fetched from WandB and shown in plots
TOP_K_PARAMETERS = 10  # Number of parameters shown in the parameter-importance plot


## Setup


In [14]:
import asyncio
import subprocess
import sys
import os
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
import threading
import json

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import wandb

# Add metta to path.
# NOTE(review): assumes the notebook's working directory is three levels below
# the repo root — TODO confirm if the notebook is moved.
sys.path.append(os.path.abspath('../../..'))

from metta.sweep.wandb_utils import get_sweep_runs, deep_clean
from metta.common.util.fs import get_repo_root, cd_repo_root
from metta.common.util.git import get_current_commit

# Set working directory to repo root.
# Later cells depend on this: SweepLauncher runs "./devops/skypilot/launch.py"
# through a path relative to the current directory.
cd_repo_root()

print("✅ Setup complete")


✅ Setup complete


## Helper Functions


In [15]:
def flatten_nested_dict(d: Dict[str, Any], parent_key: str = '', sep: str = '.') -> Dict[str, Any]:
    """Collapse nested dicts into one level whose keys are joined with ``sep``.

    Example: ``{"a": {"b": 1}, "c": 2}`` becomes ``{"a.b": 1, "c": 2}``.
    A non-dict ``d`` yields ``{parent_key: d}`` when ``parent_key`` is set and
    ``{}`` otherwise. Empty nested dicts contribute no keys.
    """
    if isinstance(d, dict):
        flat: Dict[str, Any] = {}
        for key, value in d.items():
            full_key = key if not parent_key else f"{parent_key}{sep}{key}"
            if isinstance(value, dict):
                # Recurse and merge; later keys overwrite earlier duplicates.
                flat.update(flatten_nested_dict(value, full_key, sep=sep))
            else:
                flat[full_key] = value
        return flat
    return {parent_key: d} if parent_key else {}

def _parse_sky_job_status(queue_output: str, job_name: str) -> str:
    """Find ``job_name`` in ``sky jobs queue`` text output and return its status.

    The table is parsed by column position: the header row (starts with ``ID``
    and contains ``NAME`` and ``STATUS``) gives the offsets of the NAME and
    STATUS columns. Returns ``"UNKNOWN"`` when no header is found or the
    matching row has nothing under STATUS, and ``"NOT_FOUND"`` when no row has
    that name.
    """
    lines = queue_output.strip().split("\n")

    header_idx = next(
        (i for i, line in enumerate(lines)
         if line.startswith("ID") and "NAME" in line and "STATUS" in line),
        None,
    )
    if header_idx is None:
        return "UNKNOWN"

    # Get column positions from header
    header_line = lines[header_idx]
    name_start = header_line.find("NAME")
    name_end = header_line.find("RESOURCES")
    status_start = header_line.find("STATUS")

    for line in lines[header_idx + 1:]:
        if not line.strip() or line.startswith(("No ", "Fetching")):
            continue

        name_cell = line[name_start:name_end] if name_end > 0 else line[name_start:]
        if name_cell.strip() != job_name:
            continue

        # Statuses are single tokens (RUNNING, FAILED_SETUP, ...). Take only the
        # first token so columns to the right of STATUS are not swallowed.
        status_tokens = line[status_start:].split()
        return status_tokens[0] if status_tokens else "UNKNOWN"

    return "NOT_FOUND"


def get_sky_job_status(
    job_name: str,
    command: Tuple[str, ...] = ("sky", "jobs", "queue"),
    timeout: Optional[float] = None,
) -> Optional[str]:
    """Get the status of a specific SkyPilot managed job.

    Args:
        job_name: Exact value of the NAME column to look up.
        command: Command whose stdout is the jobs table; defaults to
            ``sky jobs queue``.
        timeout: Optional limit in seconds for the command; ``None`` waits
            indefinitely.

    Returns:
        The job's status token (e.g. ``"RUNNING"``), ``"NOT_FOUND"`` if no row
        matches, ``"UNKNOWN"`` if the output has no recognizable header,
        ``"ERROR"`` if the command exits non-zero, or ``"ERROR: <message>"`` if
        running it raises (e.g. CLI not installed, timeout expired).
    """
    try:
        result = subprocess.run(
            list(command),
            capture_output=True,
            text=True,
            check=False,
            timeout=timeout,
        )
        if result.returncode != 0:
            return "ERROR"
        return _parse_sky_job_status(result.stdout, job_name)
    except Exception as e:
        # Best-effort status probe: report the problem instead of raising.
        return f"ERROR: {str(e)}"

def estimate_cost(runtime_seconds: float, gpus: int = 1, gpu_hour_rate: float = 0.50) -> float:
    """Estimate the dollar cost of a run from its wall-clock runtime.

    Args:
        runtime_seconds: Run duration in seconds.
        gpus: Number of GPUs billed for the whole duration.
        gpu_hour_rate: Dollars per GPU-hour. The 0.50 default is a rough
            placeholder for A10G spot instances, not a quoted price.

    Returns:
        ``runtime_hours * gpus * gpu_hour_rate``.
    """
    runtime_hours = runtime_seconds / 3600
    return runtime_hours * gpus * gpu_hour_rate

def format_duration(seconds: float) -> str:
    """Return ``seconds`` as a short human-readable string.

    Under a minute: whole seconds (``"45s"``); under an hour: minutes with one
    decimal (``"2.5m"``); anything else: hours with one decimal (``"1.2h"``).
    """
    # (exclusive upper bound, divisor, format spec, unit suffix), checked in order
    scales = ((60, 1, ".0f", "s"), (3600, 60, ".1f", "m"))
    for upper_bound, divisor, spec, suffix in scales:
        if seconds < upper_bound:
            return f"{seconds / divisor:{spec}}{suffix}"
    return f"{seconds / 3600:.1f}h"


## Sweep Launcher


In [16]:
class SweepLauncher:
    """Launches sweep workers on SkyPilot and cancels them again."""

    def __init__(self, sweep_name: str):
        self.sweep_name = sweep_name
        # Job names launched by this instance, in launch order. While all
        # workers share the sweep name, the same name can appear repeatedly.
        self.worker_jobs: List[str] = []
        # Time of the most recent launch_workers() call.
        self.launch_time: Optional[datetime] = None

    def launch_workers(
        self,
        num_workers: int,
        gpus_per_worker: int,
        nodes_per_worker: int
    ) -> List[str]:
        """Launch ``num_workers`` sweep workers via ``./devops/skypilot/launch.py``.

        Each worker runs ``launch.py ... sweep run=<sweep_name>``, so all of them
        join the same sweep. The script path is relative, so the working
        directory must be the repo root.

        Args:
            num_workers: Number of launch commands to run.
            gpus_per_worker: Value passed as ``--gpus``.
            nodes_per_worker: Value passed as ``--nodes``.

        Returns:
            Names of the workers whose launch command exited with status 0.
            They are also appended to ``self.worker_jobs``; names from earlier
            calls are kept, so ``stop_all_workers`` still cancels them.
        """
        job_names = []
        commit_hash = get_current_commit()

        print(f"🚀 Launching {num_workers} sweep workers...")
        print(f"   Sweep: {self.sweep_name}")
        print(f"   GPUs per worker: {gpus_per_worker}")
        print(f"   Nodes per worker: {nodes_per_worker}")
        # Guard against a missing commit hash instead of crashing on slicing None.
        print(f"   Git commit: {(commit_hash or 'unknown')[:8]}")

        for i in range(num_workers):
            # TODO: give each worker its own name (e.g. f"{sweep_name}_worker_{i+1}").
            # NOTE(review): the command below passes no job name, so presumably
            # launch.py derives the SkyPilot job name from run=; confirm first.
            worker_name = self.sweep_name

            # Build the launch command
            cmd = [
                "./devops/skypilot/launch.py",
                f"--gpus={gpus_per_worker}",
                f"--nodes={nodes_per_worker}",
                "--skip-git-check",  # Assuming we want to skip for notebook usage
                "sweep",
                f"run={self.sweep_name}",
            ]

            print(f"   Launching worker {i+1}/{num_workers}: {worker_name}")
            print(f"   Command: {' '.join(cmd)}")
            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    check=False
                )
            except Exception as e:
                print(f"   ❌ Error launching worker {i+1}: {str(e)}")
                continue

            if result.returncode == 0:
                job_names.append(worker_name)
                print(f"   ✅ Worker {i+1} launched successfully")
            else:
                print(f"   ❌ Failed to launch worker {i+1}: {result.stderr}")

        self.worker_jobs.extend(job_names)
        self.launch_time = datetime.now()

        print(f"\n✅ Launched {len(job_names)}/{num_workers} workers successfully")
        return job_names

    def stop_all_workers(self):
        """Cancel this launcher's SkyPilot managed jobs by name.

        Runs ``sky jobs cancel -y -n <name>`` once per distinct name. The
        positional argument of ``sky jobs cancel`` is a numeric job ID, so a
        name has to go through ``-n``. ``-y`` skips the confirmation prompt,
        which cannot be answered from a notebook subprocess. Names whose
        cancellation fails stay in ``self.worker_jobs`` so the call can be
        retried.

        NOTE(review): when several running jobs share one name, SkyPilot may
        refuse to cancel by name; cancel those by job ID instead.
        """
        unique_jobs = list(dict.fromkeys(self.worker_jobs))
        print(f"🛑 Stopping {len(unique_jobs)} workers...")

        failed = []
        for job_name in unique_jobs:
            try:
                result = subprocess.run(
                    ["sky", "jobs", "cancel", "-y", "-n", job_name],
                    capture_output=True,
                    text=True,
                    check=False
                )
            except Exception as e:
                print(f"   ❌ Failed to stop {job_name}: {str(e)}")
                failed.append(job_name)
                continue

            if result.returncode == 0:
                print(f"   ✅ Stopped {job_name}")
            else:
                detail = (result.stderr or result.stdout).strip()
                print(f"   ❌ Failed to stop {job_name}: {detail}")
                failed.append(job_name)

        self.worker_jobs = failed
        if failed:
            print(f"⚠️ {len(failed)} worker(s) could not be stopped")
        else:
            print("✅ All workers stopped")


## Live Dashboard


In [None]:
# Analytics helpers used by EnhancedSweepDashboard (defined further below):

def calculate_efficiency_metrics(df):
    """Calculate summary efficiency metrics for a sweep DataFrame.

    Expects a ``score`` column plus a cost column. It uses ``dollar_cost`` (what
    ``SweepDashboard.fetch_sweep_data`` produces) or, failing that, ``cost``.
    Rows are assumed to be in chronological order for ``improvement_rate`` and
    ``convergence_ratio``.

    Returns:
        ``{}`` for an empty DataFrame. Otherwise a dict with:

        - ``avg_score_per_dollar`` / ``best_score_per_dollar``: mean / max
          score divided by total cost. These are 0 when the total cost is not
          positive or no cost column exists.
        - ``improvement_rate``: (last score - first score) / number of rows.
        - ``convergence_ratio``: std of the last 10 scores (``inf`` with <= 10 rows).
        - ``success_rate``: percentage of rows scoring strictly above the median.
    """
    if df.empty:
        return {}

    scores = df['score']

    # The dashboard frame stores dollars in 'dollar_cost' and has no 'cost'
    # column; 'cost' is still accepted for callers that provide it.
    if 'dollar_cost' in df.columns:
        total_cost = df['dollar_cost'].sum()
    elif 'cost' in df.columns:
        total_cost = df['cost'].sum()
    else:
        total_cost = 0

    metrics = {
        'avg_score_per_dollar': scores.mean() / total_cost if total_cost > 0 else 0,
        'best_score_per_dollar': scores.max() / total_cost if total_cost > 0 else 0,
        'improvement_rate': (scores.iloc[-1] - scores.iloc[0]) / len(df) if len(df) > 1 else 0,
        'convergence_ratio': scores.rolling(10).std().iloc[-1] if len(df) > 10 else float('inf'),
        'success_rate': (scores > scores.median()).mean() * 100,
    }
    return metrics

def create_convergence_plot(df):
    """Create a convergence plot showing best score over time.

    All traces are plotted against ``timestamp``:

    - each run's score, as markers;
    - the running best score, as a line;
    - a 5-run rolling mean, when there are more than five runs.

    Returns:
        A ``go.Figure``. The figure is empty when ``df`` is empty or lacks a
        ``timestamp`` or ``score`` column.
    """
    # Guard the columns the plot needs, as SweepDashboard.update_score_plot does.
    if df.empty or 'timestamp' not in df.columns or 'score' not in df.columns:
        return go.Figure()

    df_sorted = df.sort_values('timestamp')
    # cummax leaves NaN wherever a score is missing. Carry the previous best
    # forward so the best-score line does not break at failed runs.
    best_score = df_sorted['score'].cummax().ffill()

    fig = go.Figure()

    # Individual runs
    fig.add_trace(go.Scatter(
        x=df_sorted['timestamp'],
        y=df_sorted['score'],
        mode='markers',
        name='Individual Runs',
        marker=dict(size=6, opacity=0.5),
    ))

    # Best score line
    fig.add_trace(go.Scatter(
        x=df_sorted['timestamp'],
        y=best_score,
        mode='lines',
        name='Best Score',
        line=dict(color='red', width=3),
    ))

    # Rolling average
    if len(df_sorted) > 5:
        rolling_avg = df_sorted['score'].rolling(5, min_periods=1).mean()
        fig.add_trace(go.Scatter(
            x=df_sorted['timestamp'],
            y=rolling_avg,
            mode='lines',
            name='5-Run Average',
            line=dict(color='blue', dash='dash'),
        ))

    fig.update_layout(
        title='Convergence Analysis',
        xaxis_title='Time',
        yaxis_title='Score',
        hovermode='x unified',
        height=400
    )

    return fig

def create_parameter_heatmap(df, param1, param2):
    """Build a heatmap of mean score for every (param1, param2) value pair.

    Rows are the distinct values of ``param1`` and columns those of ``param2``.
    Each cell is annotated with its mean ``score`` rounded to 4 decimals.
    Returns an empty figure when ``df`` is empty or either column is missing.
    """
    if df.empty or not (param1 in df.columns and param2 in df.columns):
        return go.Figure()

    grid = df.pivot_table(index=param1, columns=param2, values='score', aggfunc='mean')
    cell_values = grid.values

    heatmap = go.Heatmap(
        z=cell_values,
        x=grid.columns,
        y=grid.index,
        colorscale='Viridis',
        text=cell_values.round(4),
        texttemplate='%{text}',
        textfont={"size": 10},
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title=f'Parameter Interaction: {param1} vs {param2}',
        xaxis_title=param2,
        yaxis_title=param1,
        height=400,
    )
    return fig


In [None]:
# Note: EnhancedSweepDashboard class will be defined after SweepDashboard
# to avoid forward reference issues


In [17]:
class LiveSkyJobsMonitor:
    """Print a `watch` command for following the SkyPilot jobs queue.

    Nothing is executed from the notebook, because `watch` needs an
    interactive terminal. ``start``/``display`` print the command for the
    user to run, and ``stop`` prints how to end it.
    """

    def __init__(self, sweep_name: str, refresh_seconds: int = UPDATE_INTERVAL_SECONDS):
        self.sweep_name = sweep_name
        # Seconds between refreshes. The default is read from the config cell
        # at the time this class is defined.
        self.refresh_seconds = refresh_seconds
        # This class never spawns a process. The attribute is kept so existing
        # code that reads it keeps working.
        self.process = None

    def command(self) -> str:
        """Return the `watch` command line that polls ``sky jobs queue``."""
        # -n: refresh interval in seconds; -t: hide watch's title header.
        # NOTE(review): `-s` is presumably `--skip-finished`; confirm with
        # `sky jobs queue --help`. The queue is not filtered by sweep name.
        return f"watch -n{self.refresh_seconds} -t 'sky jobs queue -s'"

    def start(self):
        """Print instructions for running the live queue monitor in a terminal."""
        cmd = self.command()
        print(f"🔍 Live monitoring for sweep: {self.sweep_name}")
        print(f"   Refreshing every {self.refresh_seconds} seconds")
        print("   Press Ctrl+C in the terminal to stop monitoring")
        print("=" * 80)
        print("Nothing is started from this notebook. Run this in a terminal:")
        print(f"  {cmd}")

    def stop(self):
        """Print how to stop the watch command (it runs outside the notebook)."""
        print("\n✅ To stop monitoring, press Ctrl+C in the terminal where watch is running")

    def display(self):
        """Display instructions for monitoring."""
        self.start()

class SweepDashboard:
    """Interactive dashboard for monitoring sweep progress.

    Pulls the sweep's runs from WandB and shows:

    - four status cards: total cost, total runs, best score, dashboard uptime;
    - a score-over-time plot with the running best;
    - a parameter-importance bar chart.

    ``display()`` renders the widgets, performs one update, and starts a
    background refresh thread.
    """

    def __init__(self, sweep_name: str):
        self.sweep_name = sweep_name
        self.update_thread = None
        # Public stop flag kept for compatibility; setting it also ends the loop.
        self.stop_updates = False
        # Set by stop_auto_update so the refresh thread wakes immediately
        # instead of finishing its current (possibly minutes-long) wait.
        self._stop_event = threading.Event()
        self.start_time = datetime.now()

        # Create dashboard widgets
        self.create_widgets()

    def create_widgets(self):
        """Create the dashboard UI components."""

        # Status cards (content filled in by update_status_cards)
        self.status_html = widgets.HTML(
            value="<h3>Loading...</h3>",
            layout=widgets.Layout(width='100%')
        )

        # Plots
        self.score_plot = go.FigureWidget(
            layout=go.Layout(
                title="Score Over Time",
                xaxis_title="Time",
                yaxis_title="Score",
                height=400,
                template="plotly_white"
            )
        )

        self.importance_plot = go.FigureWidget(
            layout=go.Layout(
                title="Parameter Importance",
                xaxis_title="Importance",
                yaxis_title="Parameter",
                height=400,
                template="plotly_white"
            )
        )

        # Control buttons
        self.refresh_button = widgets.Button(
            description="Refresh Now",
            button_style='primary',
            icon='refresh'
        )
        self.refresh_button.on_click(lambda _: self.update_dashboard())

        # Layout
        self.dashboard = widgets.VBox([
            widgets.HTML(value="<h2>📊 Sweep Metrics Dashboard</h2>"),
            self.status_html,
            self.refresh_button,
            widgets.HBox([
                self.score_plot,
                self.importance_plot
            ])
        ])

    @staticmethod
    def _extract_dollar_cost(summary, runtime_seconds: float) -> float:
        """Return the run's recorded dollar cost, or an estimate from its runtime.

        Looks for ``cost`` at the top level of the summary, then under
        ``_wandb`` and ``system``. If none is present, falls back to
        ``estimate_cost`` (placeholder rate, one GPU).
        """
        if 'cost' in summary:
            return summary['cost']
        for section in ('_wandb', 'system'):
            if section in summary and 'cost' in summary[section]:
                return summary[section]['cost']
        return estimate_cost(runtime_seconds)

    @staticmethod
    def _observations_to_frame(observations: List[Dict]) -> pd.DataFrame:
        """Flatten observation dicts into one DataFrame row each, sorted by time."""
        if not observations:
            return pd.DataFrame()

        df_rows = []
        for obs in observations:
            row = {}
            if 'suggestion' in obs and isinstance(obs['suggestion'], dict):
                row.update(flatten_nested_dict(obs['suggestion']))
            row['score'] = obs.get('objective', np.nan)
            row['runtime'] = obs.get('runtime', obs.get('cost', 0))
            row['timestamp'] = obs.get('timestamp')
            row['dollar_cost'] = obs.get('dollar_cost', 0)
            df_rows.append(row)

        df = pd.DataFrame(df_rows)
        if 'timestamp' in df.columns:
            df['timestamp'] = pd.to_datetime(df['timestamp'])
            df = df.sort_values('timestamp')
        return df

    def fetch_sweep_data(self) -> Tuple[pd.DataFrame, List[Dict]]:
        """Fetch current sweep data from WandB.

        Returns:
            ``(df, run_metadata)``.

            ``run_metadata`` has one dict per fetched run (at most
            ``MAX_DISPLAY_RUNS``) with id, name, timestamp, runtime, state,
            score and ``dollar_cost``.

            ``df`` has one row per run that reported a ``protein_observation``
            or ``protein_suggestion``. Its columns are the flattened suggestion
            parameters plus ``score``, ``runtime``, ``timestamp`` and
            ``dollar_cost``, sorted by timestamp.

            On any error the message is printed and
            ``(empty DataFrame, [])`` is returned.
        """
        try:
            runs = get_sweep_runs(
                sweep_name=self.sweep_name,
                entity=WANDB_ENTITY,
                project=WANDB_PROJECT
            )

            observations = []
            run_metadata = []

            for run in runs[:MAX_DISPLAY_RUNS]:
                summary = run.summary
                protein_obs = summary.get("protein_observation")
                protein_suggestion = summary.get("protein_suggestion")
                # `or 0` also covers a runtime key that is present but None.
                runtime_seconds = summary.get('_wandb', {}).get('runtime', 0) or 0
                dollar_cost = self._extract_dollar_cost(summary, runtime_seconds)

                # Collect run metadata for cost analysis
                run_metadata.append({
                    'run_id': run.id,
                    'run_name': run.name,
                    'timestamp': run.created_at,
                    'runtime_seconds': runtime_seconds,
                    'state': run.state,
                    'score': summary.get('score', summary.get('protein.objective', 0)),
                    'dollar_cost': dollar_cost,
                })

                if protein_obs:
                    obs = deep_clean(protein_obs)
                    if 'suggestion' not in obs and protein_suggestion:
                        obs['suggestion'] = deep_clean(protein_suggestion)
                    obs['timestamp'] = run.created_at
                    obs['runtime'] = runtime_seconds
                    obs['dollar_cost'] = dollar_cost
                    observations.append(obs)
                elif protein_suggestion:
                    # Suggestion without an observation: use the run's own score.
                    observations.append({
                        'suggestion': deep_clean(protein_suggestion),
                        'objective': summary.get('score', np.nan),
                        'cost': runtime_seconds,
                        'timestamp': run.created_at,
                        'dollar_cost': dollar_cost
                    })

            return self._observations_to_frame(observations), run_metadata

        except Exception as e:
            # Best-effort: a failed fetch must not break the widget or refresh loop.
            print(f"Error fetching data: {str(e)}")
            return pd.DataFrame(), []

    def update_dashboard(self):
        """Fetch fresh data and redraw the status cards and plots."""
        df, run_metadata = self.fetch_sweep_data()

        # Update status cards
        self.update_status_cards(df, run_metadata)

        # Update plots
        if not df.empty:
            self.update_score_plot(df)
            self.update_importance_plot(df)

    def update_status_cards(self, df: pd.DataFrame, run_metadata: List):
        """Update the status cards with current metrics.

        Total cost and total runs come from ``run_metadata`` (all fetched runs,
        including ones without results yet), so the numbers do not jump when
        the first observation arrives. When ``run_metadata`` is empty they fall
        back to ``df``.

        Best score is the maximum ``score`` in ``df``, or "N/A" if there is
        none. Runtime is the time since this dashboard was created.
        """
        if run_metadata:
            total_runs = len(run_metadata)
            total_cost = sum(m.get('dollar_cost', 0) for m in run_metadata)
        else:
            total_runs = len(df)
            total_cost = df['dollar_cost'].sum() if 'dollar_cost' in df.columns else 0

        best = df['score'].max() if 'score' in df.columns else np.nan
        best_score = f"{best:.4f}" if pd.notna(best) else "N/A"

        runtime = format_duration((datetime.now() - self.start_time).total_seconds())

        html = f"""
        <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 15px; margin: 20px 0;">
            <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; color: white;">
                <div style="font-size: 14px; opacity: 0.9;">💰 Total Cost</div>
                <div style="font-size: 28px; font-weight: bold;">${total_cost:.2f}</div>
            </div>
            <div style="background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); padding: 20px; border-radius: 10px; color: white;">
                <div style="font-size: 14px; opacity: 0.9;">📊 Total Runs</div>
                <div style="font-size: 28px; font-weight: bold;">{total_runs}</div>
            </div>
            <div style="background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%); padding: 20px; border-radius: 10px; color: white;">
                <div style="font-size: 14px; opacity: 0.9;">🏆 Best Score</div>
                <div style="font-size: 28px; font-weight: bold;">{best_score}</div>
            </div>
            <div style="background: linear-gradient(135deg, #43e97b 0%, #38f9d7 100%); padding: 20px; border-radius: 10px; color: white;">
                <div style="font-size: 14px; opacity: 0.9;">⏱️ Runtime</div>
                <div style="font-size: 28px; font-weight: bold;">{runtime}</div>
            </div>
        </div>
        """

        self.status_html.value = html

    def update_score_plot(self, df: pd.DataFrame):
        """Redraw score-vs-time markers and the running-best line."""

        if 'timestamp' not in df.columns or 'score' not in df.columns:
            return

        # Clear existing traces
        self.score_plot.data = []

        # Add scatter plot of all scores
        self.score_plot.add_trace(
            go.Scatter(
                x=df['timestamp'],
                y=df['score'],
                mode='markers',
                name='Scores',
                marker=dict(
                    size=8,
                    color=df['score'],
                    colorscale='Viridis',
                    showscale=True,
                    colorbar=dict(title="Score")
                )
            )
        )

        # Add best score line
        df_sorted = df.sort_values('timestamp')
        best_score = df_sorted['score'].cummax()

        self.score_plot.add_trace(
            go.Scatter(
                x=df_sorted['timestamp'],
                y=best_score,
                mode='lines',
                name='Best Score',
                line=dict(color='red', width=2, dash='dash')
            )
        )

    def update_importance_plot(self, df: pd.DataFrame, top_k: Optional[int] = None):
        """Update the parameter-importance bar chart.

        Importance is the Pearson correlation of each numeric parameter column
        with ``score``. Bars show its absolute value, coloured dark blue for
        positive and dark red for negative correlation; labels show the signed
        value. Parameters with fewer than two distinct values, or with a NaN
        correlation, are skipped.

        Args:
            df: Frame produced by ``fetch_sweep_data``.
            top_k: Number of parameters to show. Defaults to the notebook-level
                ``TOP_K_PARAMETERS`` if defined, else 10.
        """

        if 'score' not in df.columns:
            return

        if top_k is None:
            # Look the setting up lazily so a missing TOP_K_PARAMETERS config
            # entry cannot raise NameError inside the refresh loop.
            top_k = globals().get('TOP_K_PARAMETERS', 10)

        # Select numeric columns excluding metrics
        metric_cols = ['score', 'cost', 'runtime', 'timestamp', 'dollar_cost']
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        param_cols = [col for col in numeric_cols if col not in metric_cols]

        if len(param_cols) == 0:
            return

        # Calculate correlations with score (keeping sign, not absolute)
        correlations = {}
        for col in param_cols:
            # Only include if we have enough unique values
            if df[col].nunique() >= 2:
                corr = df[col].corr(df['score'])
                if not np.isnan(corr):
                    correlations[col] = corr

        if not correlations:
            return

        # Sort by absolute correlation and take top parameters
        importance_data = pd.DataFrame([
            {'parameter': k, 'correlation': v, 'abs_correlation': abs(v)}
            for k, v in correlations.items()
        ]).sort_values('abs_correlation', ascending=False).head(top_k)

        # Clear and update plot
        self.importance_plot.data = []

        # Create color based on correlation direction (red for negative, blue for positive)
        colors = ['darkred' if x < 0 else 'darkblue' for x in importance_data['correlation']]

        # Add bar trace showing absolute correlation
        self.importance_plot.add_trace(
            go.Bar(
                x=importance_data['abs_correlation'],
                y=importance_data['parameter'],
                orientation='h',
                marker=dict(
                    color=colors,
                    opacity=0.7
                ),
                text=[f"{corr:.3f}" for corr in importance_data['correlation']],
                textposition='outside',
                textfont=dict(size=9),
                hovertemplate='%{y}<br>Correlation: %{text}<extra></extra>'
            )
        )

        self.importance_plot.update_layout(
            title=f"Top {len(importance_data)} Most Important Parameters",
            xaxis_title="Absolute Correlation with Score",
            yaxis_title="Parameter",
            xaxis=dict(range=[0, 1.1]),
            showlegend=False,
            height=400
        )

    def start_auto_update(self, interval_seconds: int = 30):
        """Refresh the dashboard every ``interval_seconds`` on a daemon thread.

        Any previously started refresh thread is stopped first, so calling this
        twice does not leave two loops running. An exception during a refresh
        is printed and the loop continues, so one failed WandB request does not
        end auto-refresh.
        """
        self.stop_auto_update()
        self.stop_updates = False
        # Each loop gets its own event; an older thread that outlives the join
        # timeout still sees its own (set) event and exits.
        stop_event = threading.Event()
        self._stop_event = stop_event

        def update_loop():
            while not stop_event.is_set() and not self.stop_updates:
                try:
                    self.update_dashboard()
                except Exception as e:
                    print(f"Dashboard update failed: {e}")
                # Returns early as soon as stop_auto_update sets the event.
                stop_event.wait(interval_seconds)

        self.update_thread = threading.Thread(target=update_loop, daemon=True)
        self.update_thread.start()

    def stop_auto_update(self):
        """Stop automatic updates, waiting up to 5 seconds for the thread to exit."""
        self.stop_updates = True
        self._stop_event.set()
        if self.update_thread is not None:
            self.update_thread.join(timeout=5)
            self.update_thread = None

    def display(self):
        """Display the dashboard, update it once, and start auto-refresh."""
        display(self.dashboard)
        self.update_dashboard()
        self.start_auto_update(UPDATE_INTERVAL_SECONDS)


In [None]:
class EnhancedSweepDashboard(SweepDashboard):
    """Enhanced dashboard with additional analytics and visualizations.

    Adds three output areas below the base dashboard:

    - efficiency metric cards;
    - a convergence plot;
    - a heatmap for the two parameters most correlated with score.
    """

    # Metric/bookkeeping columns that are never heatmap parameter candidates
    # (the same set SweepDashboard.update_importance_plot excludes).
    _METRIC_COLUMNS = ('score', 'cost', 'runtime', 'timestamp', 'dollar_cost')

    def __init__(self, sweep_name: str):
        super().__init__(sweep_name)
        self.convergence_plot = widgets.Output()
        self.efficiency_metrics = widgets.Output()
        self.parameter_heatmap = widgets.Output()

    def display(self):
        """Display the base dashboard followed by the analytics sections.

        ``SweepDashboard.display`` already runs the first update and starts
        the auto-refresh thread, so no second thread is started here.
        """
        # Original dashboard
        super().display()

        # Additional sections
        print("\n📊 Advanced Analytics")

        # Efficiency metrics
        display(HTML("<h3>💰 Efficiency Metrics</h3>"))
        display(self.efficiency_metrics)

        # Convergence plot
        display(HTML("<h3>📈 Convergence Analysis</h3>"))
        display(self.convergence_plot)

        # Parameter heatmap placeholder
        display(HTML("<h3>🗺️ Parameter Interactions</h3>"))
        display(self.parameter_heatmap)

    def update_dashboard(self):
        """Refresh the base dashboard and the enhanced sections from one WandB fetch."""
        df, run_metadata = self.fetch_sweep_data()

        # Same steps as SweepDashboard.update_dashboard, but reusing this fetch
        # instead of querying WandB a second time.
        self.update_status_cards(df, run_metadata)
        if df.empty:
            return
        self.update_score_plot(df)
        self.update_importance_plot(df)

        self._update_efficiency_metrics(df)
        self._update_convergence_plot(df)
        self._update_parameter_heatmap(df)

    def _update_efficiency_metrics(self, df: pd.DataFrame):
        """Render calculate_efficiency_metrics(df) as a row of small cards."""
        with self.efficiency_metrics:
            clear_output(wait=True)
            metrics = calculate_efficiency_metrics(df)

            # Display as cards
            html = '<div style="display: flex; gap: 10px; flex-wrap: wrap;">'

            if metrics:
                html += f'''
                    <div style="background: #f0f8ff; padding: 10px; border-radius: 5px; flex: 1;">
                        <b>Avg Score/$</b><br>{metrics["avg_score_per_dollar"]:.6f}
                    </div>
                    <div style="background: #f0fff0; padding: 10px; border-radius: 5px; flex: 1;">
                        <b>Best Score/$</b><br>{metrics["best_score_per_dollar"]:.6f}
                    </div>
                    <div style="background: #fff0f5; padding: 10px; border-radius: 5px; flex: 1;">
                        <b>Success Rate</b><br>{metrics["success_rate"]:.1f}%
                    </div>
                    <div style="background: #fffaf0; padding: 10px; border-radius: 5px; flex: 1;">
                        <b>Improvement/Run</b><br>{metrics["improvement_rate"]:.6f}
                    </div>
                    '''

            html += '</div>'
            display(HTML(html))

    def _update_convergence_plot(self, df: pd.DataFrame):
        """Redraw the convergence plot in its output area."""
        with self.convergence_plot:
            clear_output(wait=True)
            fig = create_convergence_plot(df)
            fig.show()

    def _update_parameter_heatmap(self, df: pd.DataFrame):
        """Show a heatmap for the two parameters with the largest |corr| with score."""
        with self.parameter_heatmap:
            clear_output(wait=True)

            # Numeric parameter columns only; metric columns such as
            # 'dollar_cost' must not be treated as swept parameters.
            numeric_cols = [
                col for col in df.columns
                if col not in self._METRIC_COLUMNS
                and df[col].dtype in ['float64', 'int64']
            ]

            if len(numeric_cols) < 2 or 'score' not in df.columns:
                display(HTML("<i>Waiting for more data...</i>"))
                return

            correlations = {}
            for col in numeric_cols:
                if df[col].nunique() > 1:
                    corr = df[col].corr(df['score'])
                    # NaN (e.g. constant scores) would make the ranking unreliable.
                    if pd.notna(corr):
                        correlations[col] = abs(corr)

            if len(correlations) < 2:
                display(HTML("<i>Insufficient parameters for interaction analysis</i>"))
                return

            top_params = sorted(correlations.items(), key=lambda x: x[1], reverse=True)[:2]
            param1, param2 = top_params[0][0], top_params[1][0]

            # Only create heatmap if we have enough unique values
            if df[param1].nunique() > 2 and df[param2].nunique() > 2:
                fig = create_parameter_heatmap(df, param1, param2)
                fig.show()
            else:
                display(HTML("<i>Not enough parameter variation for heatmap</i>"))


## 🚀 Launch Sweep


In [18]:
# Initialize the launcher for this sweep
launcher = SweepLauncher(SWEEP_NAME)

# Launch workers on SkyPilot. Every worker is started with run=SWEEP_NAME, so
# they all contribute to the same sweep. Returns the names of the workers
# whose launch command succeeded.
# NOTE: re-running this cell launches another batch of workers and replaces
# `launcher`, so the stop cell at the end only knows about the newest batch.
worker_jobs = launcher.launch_workers(
    num_workers=NUM_PARALLEL_WORKERS,
    gpus_per_worker=GPUS_PER_WORKER,
    nodes_per_worker=NODES_PER_WORKER
)


🚀 Launching 2 sweep workers...
   Sweep: axel.late_night_parallel
   GPUs per worker: 1
   Nodes per worker: 1
   Git commit: c1a6e966
   Launching worker 1/2: axel.late_night_parallel
['./devops/skypilot/launch.py', '--gpus=1', '--nodes=1', '--skip-git-check', 'sweep', 'run=axel.late_night_parallel']
   ✅ Worker 1 launched successfully
   Launching worker 2/2: axel.late_night_parallel
['./devops/skypilot/launch.py', '--gpus=1', '--nodes=1', '--skip-git-check', 'sweep', 'run=axel.late_night_parallel']
   ✅ Worker 2 launched successfully

✅ Launched 2/2 workers successfully


## 📊 Live Worker Monitoring

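This prints a `watch` command that shows the live status of all SkyPilot jobs. Run it in a terminal; it refreshes every `UPDATE_INTERVAL_SECONDS` seconds (set in the configuration cell):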


In [19]:
# Print a `watch` command for following the SkyPilot jobs queue.
# The refresh interval defaults to UPDATE_INTERVAL_SECONDS from the configuration.
sky_monitor = LiveSkyJobsMonitor(SWEEP_NAME)
sky_monitor.display()

# Nothing is started from the notebook: copy the printed command into a
# terminal to get the auto-updating queue view.


🔍 Starting live monitoring for sweep: axel.late_night_parallel
   Refreshing every 120 seconds
   Press Ctrl+C in the terminal to stop monitoring
Running: watch -n120 -t 'sky jobs queue -s'
Note: This will run in your terminal. Check your terminal for the live output.

Alternatively, run this in a terminal:
  watch -n120 -t 'sky jobs queue -s'


## 📈 Sweep Metrics Dashboard

Monitor the sweep progress with WandB metrics:


In [20]:
# Create and display the metrics dashboard (status cards, score plot,
# parameter importance). display() runs one update immediately, then starts a
# background thread that refreshes every UPDATE_INTERVAL_SECONDS.
metrics_dashboard = SweepDashboard(SWEEP_NAME)
metrics_dashboard.display()

# Use the "Refresh Now" button to update manually; call
# metrics_dashboard.stop_auto_update() to end the background refresh thread.


VBox(children=(HTML(value='<h2>📊 Sweep Metrics Dashboard</h2>'), HTML(value='<h3>Loading...</h3>', layout=Layo…

## 🔍 Detailed Analysis

Fetch and analyze the full sweep data:


In [21]:
# Fetch full sweep data for detailed analysis
df, run_metadata = metrics_dashboard.fetch_sweep_data()

if not df.empty:
    print("📊 Sweep Statistics:")
    # run_metadata covers every fetched run; df only runs that reported results.
    print(f"   Total runs: {len(run_metadata)}")
    print(f"   Runs with results: {len(df)}")
    print(f"   Best score: {df['score'].max():.4f}")
    print(f"   Average score: {df['score'].mean():.4f}")
    print(f"   Total runtime: {df['runtime'].sum()/3600:.1f} hours")

    # Total cost over all fetched runs, not only those with results
    total_cost = sum(m.get('dollar_cost', 0) for m in run_metadata)
    print(f"   Total cost: ${total_cost:.2f}")
    print("\n📈 Top 5 runs:")

    # Show top 5 runs
    top_runs = df.nlargest(5, 'score')[['score', 'runtime', 'timestamp']]
    display(top_runs)

    # Parameter columns: everything except metric/bookkeeping columns.
    # 'dollar_cost' must be excluded too, or it is reported as a parameter.
    metric_cols = {'score', 'runtime', 'timestamp', 'cost', 'dollar_cost'}
    param_cols = [col for col in df.columns if col not in metric_cols]
    # Only numeric, non-boolean parameters have a meaningful min/max range.
    numeric_params = [
        col for col in param_cols
        if pd.api.types.is_numeric_dtype(df[col]) and not pd.api.types.is_bool_dtype(df[col])
    ]

    if numeric_params:
        print("\n🔧 Parameter ranges explored:")
        for col in numeric_params[:10]:  # Show first 10 numeric parameters
            print(f"   {col}: [{df[col].min():.9f}, {df[col].max():.9f}]")
else:
    print("No data available yet. Wait for workers to start producing results.")


No data available yet. Wait for workers to start producing results.


In [None]:
# Stop monitoring. sky_monitor.stop() only prints a reminder: the `watch`
# command runs in your own terminal and must be stopped there with Ctrl+C.
sky_monitor.stop()
# End the dashboard's background refresh thread.
metrics_dashboard.stop_auto_update()

# Stop all workers recorded by this launcher (see launcher.stop_all_workers)
print(f"🛑 Stopping all workers for sweep: {SWEEP_NAME}")
launcher.stop_all_workers()

# Printed unconditionally; check the per-worker lines above for failures.
print("\n✅ Sweep stopped successfully!")