# Sweep Analysis Notebook

This notebook provides comprehensive analysis and visualization of hyperparameter sweep results from WandB.

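## Features
- Summary statistics of sweep parameters and metrics
- Cost vs Score scatter plot with level sets
- Time vs Score plot
- Time vs Cost plot
- Parameter correlation matrix
- Parameter importance analysis (correlation of each numeric parameter with score)
- Top-5 configurations and CSV export of the processed data

> **Note:** run time is not read separately; the `time` column is currently a copy of each observation's `cost`, so the time-based plots carry the same information as cost.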


In [None]:
# Settings Cell - Configure your sweep analysis here

# WandB Configuration
WANDB_ENTITY = "metta-research"  # Replace with your WandB entity
WANDB_PROJECT = "metta"  # Replace with your WandB project
WANDB_SWEEP_NAME = "axel.sky_sweep.v2"  # Replace with your sweep NAME (not ID); it is resolved to a WandB sweep ID when loading
WANDB_SWEEP_ID = None  # Placeholder; the data-loading cell overwrites it with the ID looked up from WANDB_SWEEP_NAME
SWEEP_SERVER_URI = "https://api.observatory.softmax-research.net"  # Sweep server queried to resolve WANDB_SWEEP_NAME

# Analysis Configuration
MAX_OBSERVATIONS = 1000  # Maximum number of observations to load
TOP_K_PARAMETERS = 15  # Number of top parameters to show in importance plot
FIGURE_SIZE = (10, 8)  # Default figure size for plots
COLOR_PALETTE = "viridis"  # Color palette for visualizations


In [7]:
# Import required libraries
import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.interpolate import griddata
from scipy.stats import gaussian_kde
from typing import Dict, List, Any, Tuple
import warnings
warnings.filterwarnings('ignore')

# Add the metta module to path if running from notebooks directory
sys.path.append(os.path.abspath('..'))

# Import sweep utilities
from metta.sweep.wandb_utils import fetch_protein_observations_from_wandb
from cogweb.cogweb_client import CogwebClient
# Global plot look: seaborn's "whitegrid" style plus the palette chosen in the settings cell
sns.set_style(style="whitegrid")
sns.set_palette(palette=COLOR_PALETTE)


In [8]:
# Helper functions for data processing

def flatten_nested_dict(d: Dict[str, Any], parent_key: str = '', sep: str = '.') -> Dict[str, Any]:
    """Collapse a nested dictionary into a single level with joined keys.

    Nested keys are prefixed with their parent's key, joined by ``sep``
    (e.g. ``{"a": {"b": 1}}`` becomes ``{"a.b": 1}``). Non-dict values are
    copied over unchanged; empty nested dicts contribute no keys.

    Args:
        d: The (possibly nested) dictionary to flatten. It is not modified.
        parent_key: Prefix for every key produced; used by the recursion.
        sep: Separator inserted between a parent key and a child key.

    Returns:
        A new flat dictionary.
    """
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        full_key = key if not parent_key else f"{parent_key}{sep}{key}"
        if isinstance(value, dict):
            flat.update(flatten_nested_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat

def extract_observations_to_dataframe(observations: List[Dict[str, Any]]) -> pd.DataFrame:
    """Convert protein observations to a pandas DataFrame.

    Observations whose ``is_failure`` flag is truthy are dropped. Each remaining
    observation becomes one row: its (possibly nested) ``suggestion`` dict is
    flattened into dotted parameter columns, and three metric columns are added:

    - ``score``: the observation's ``objective`` (NaN if the key is missing)
    - ``cost``: the observation's ``cost`` (NaN if the key is missing)
    - ``time``: currently a copy of ``cost``, used as a run-time proxy because
      no separate timing field is read here.

    Args:
        observations: List of observation dicts (e.g. as returned by
            ``fetch_protein_observations_from_wandb``). ``None`` is treated as
            an empty list.

    Returns:
        One row per non-failed observation. The ``score``/``cost``/``time``
        columns are present even when no observations survive filtering.
    """
    metric_cols = ['score', 'cost', 'time']
    data_rows = []

    for obs in observations or []:
        if obs.get('is_failure', False):
            continue

        # `suggestion` may be present but None; treat that like "no parameters"
        flat_params = flatten_nested_dict(obs.get('suggestion') or {})

        # Add metrics (these overwrite any swept parameter with the same name)
        flat_params['score'] = obs.get('objective', np.nan)
        flat_params['cost'] = obs.get('cost', np.nan)
        flat_params['time'] = obs.get('cost', np.nan)  # Using cost as time proxy

        data_rows.append(flat_params)

    if not data_rows:
        # Keep the metric columns so downstream column checks still find them
        return pd.DataFrame(columns=metric_cols)
    return pd.DataFrame(data_rows)


In [10]:
# Load sweep data
# Resolve WANDB_SWEEP_NAME to a WandB sweep ID through the sweep server, then
# fetch protein observations for that sweep from WandB (capped by MAX_OBSERVATIONS).
print("Loading sweep data from WandB...")
print(f"Entity: {WANDB_ENTITY}")
print(f"Project: {WANDB_PROJECT}")
print(f"Sweep name: {WANDB_SWEEP_NAME}")

cogweb_client = CogwebClient.get_client(base_url=SWEEP_SERVER_URI)
sweep_client = cogweb_client.sweep_client()
WANDB_SWEEP_ID = sweep_client.get_sweep(WANDB_SWEEP_NAME).wandb_sweep_id
if not WANDB_SWEEP_ID:
    # Fail fast with a clear message rather than querying WandB with an empty sweep ID
    raise ValueError(
        f"Sweep {WANDB_SWEEP_NAME!r} did not resolve to a WandB sweep ID via {SWEEP_SERVER_URI}"
    )
# Printed only after the lookup; printing it earlier always showed the None placeholder
print(f"Sweep ID: {WANDB_SWEEP_ID}")

observations = fetch_protein_observations_from_wandb(
    wandb_entity=WANDB_ENTITY,
    wandb_project=WANDB_PROJECT,
    wandb_sweep_id=WANDB_SWEEP_ID,
    max_observations=MAX_OBSERVATIONS
)

print(f"\nLoaded {len(observations)} observations")

# Convert to DataFrame (observations flagged `is_failure` are dropped here)
df = extract_observations_to_dataframe(observations)
print(f"Valid observations after filtering: {len(df)}")


Loading sweep data from WandB...
Entity: metta-research
Project: metta
Sweep ID: None


UnsupportedProtocol: Request URL is missing an 'http://' or 'https://' protocol.

In [None]:
# Calculate summary statistics
# Uses `df` (non-failed observations) and `observations` (everything fetched)
# from the data-loading cell.
if not df.empty:
    # Identify parameter columns (exclude metrics)
    metric_cols = ['score', 'cost', 'time']
    param_cols = [col for col in df.columns if col not in metric_cols]

    # Per-parameter summary: range/moments for numeric columns, distinct values otherwise
    param_summary = {}
    for col in param_cols:
        if pd.api.types.is_numeric_dtype(df[col]):
            param_summary[col] = {
                'type': 'numeric',
                'min': df[col].min(),
                'max': df[col].max(),
                'mean': df[col].mean(),
                'std': df[col].std(),
                'unique_values': df[col].nunique()
            }
        else:
            param_summary[col] = {
                'type': 'categorical',
                'unique_values': df[col].nunique(),
                'values': df[col].unique().tolist()[:10]  # Show first 10 unique values
            }

    # Run accounting. `df` already excludes observations flagged `is_failure`,
    # so the total must be taken from the raw `observations` list.
    n_total = len(observations)
    n_failed = n_total - len(df)
    n_scored = int(df['score'].notna().sum())
    n_unscored = len(df) - n_scored

    # Display sweep summary
    print("=" * 80)
    print("SWEEP SUMMARY")
    print("=" * 80)

    print(f"\nTotal Runs: {n_total}")
    print(f"Successful Runs: {n_scored}")
    print(f"Failed Runs: {n_failed}")
    if n_unscored:
        # Not flagged as failures, but no objective was recorded
        print(f"Runs Without Score: {n_unscored}")

    print(f"\nMetrics Summary:")
    print(f"  Average Score: {df['score'].mean():.4f} (±{df['score'].std():.4f})")
    print(f"  Best Score: {df['score'].max():.4f}")
    print(f"  Worst Score: {df['score'].min():.4f}")

    print(f"\n  Average Run Time: {df['time'].mean():.2f} seconds")
    print(f"  Total Run Time: {df['time'].sum():.2f} seconds ({df['time'].sum()/3600:.2f} hours)")
    if df['time'].equals(df['cost']):
        # The loader fills `time` with `cost`, so the "seconds" above are really cost units
        print("  (note: 'time' is currently a copy of 'cost', not a measured run time)")

    print(f"\n  Average Cost: {df['cost'].mean():.2f}")
    print(f"  Total Cost: {df['cost'].sum():.2f}")

    print(f"\nParameters Being Swept ({len(param_cols)} total):")
    for param, info in list(param_summary.items())[:10]:  # Show first 10 parameters
        if info['type'] == 'numeric':
            print(f"  - {param}: {info['min']:.4f} to {info['max']:.4f} ({info['unique_values']} unique values)")
        else:
            print(f"  - {param}: {info['unique_values']} unique values")

    if len(param_cols) > 10:
        print(f"  ... and {len(param_cols) - 10} more parameters")
else:
    print("No valid observations found in the sweep data.")


In [None]:
# Cost vs Score scatter plot with level sets.
# The level sets are contours of a Gaussian KDE over the (cost, score) points, i.e. they
# show where observations concentrate. Interpolating `score` itself over (cost, score)
# coordinates would only reproduce the y-axis as horizontal bands.
if not df.empty and 'cost' in df.columns and 'score' in df.columns:
    fig, ax = plt.subplots(figsize=FIGURE_SIZE)

    # Remove any NaN values
    plot_df = df[['cost', 'score']].dropna()

    if len(plot_df) > 0:
        # Points are drawn above the filled contours so they stay visible
        ax.scatter(plot_df['cost'], plot_df['score'],
                   alpha=0.6, s=50, edgecolors='black', linewidth=0.5, zorder=3)

        # A density estimate needs a handful of points to be meaningful
        if len(plot_df) > 10:
            try:
                points = plot_df[['cost', 'score']].to_numpy(dtype=float).T  # shape (2, n)
                xi, yi = np.meshgrid(
                    np.linspace(points[0].min(), points[0].max(), 100),
                    np.linspace(points[1].min(), points[1].max(), 100),
                )

                density = gaussian_kde(points)
                zi = density(np.vstack([xi.ravel(), yi.ravel()])).reshape(xi.shape)

                # Filled density bands plus thin contour lines
                contourf = ax.contourf(xi, yi, zi, levels=10, alpha=0.3, cmap='viridis')
                ax.contour(xi, yi, zi, levels=10, colors='gray', alpha=0.4, linewidths=0.5)

                cbar = plt.colorbar(contourf, ax=ax)
                cbar.set_label('Observation density', rotation=270, labelpad=20)
            except Exception as exc:
                # e.g. a singular covariance when every point shares the same cost or score
                print(f"Could not generate level sets ({exc}) - using scatter plot only")

        # Highlight best point
        best_idx = plot_df['score'].idxmax()
        best_cost = float(plot_df.loc[best_idx, 'cost'])
        best_score = float(plot_df.loc[best_idx, 'score'])
        ax.scatter(best_cost, best_score,
                   color='red', s=200, marker='*',
                   edgecolors='darkred', linewidth=2, zorder=4,
                   label=f'Best Score: {best_score:.4f}')

        ax.set_xlabel('Cost', fontsize=12)
        ax.set_ylabel('Score', fontsize=12)
        ax.set_title('Cost vs Score Analysis', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
    else:
        print("No valid cost/score data to plot")
else:
    print("Missing cost or score data")


In [None]:
# Time vs Score: scatter with a fitted linear regression line; the title
# reports the Pearson correlation between the two columns.
if df.empty or not {'time', 'score'}.issubset(df.columns):
    print("Missing time or score data")
else:
    fig, ax = plt.subplots(figsize=FIGURE_SIZE)

    # Only rows where both values are present can be plotted
    time_score_df = df[['time', 'score']].dropna()

    if time_score_df.empty:
        print("No valid time/score data to plot")
    else:
        sns.regplot(
            data=time_score_df,
            x='time',
            y='score',
            scatter_kws={'alpha': 0.6, 's': 50},
            line_kws={'color': 'red', 'linewidth': 2},
            ax=ax,
        )

        r_time_score = time_score_df['time'].corr(time_score_df['score'])

        ax.set_xlabel('Time (seconds)', fontsize=12)
        ax.set_ylabel('Score', fontsize=12)
        ax.set_title(
            f'Time vs Score Analysis\n(Correlation: {r_time_score:.3f})',
            fontsize=14,
            fontweight='bold',
        )
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


In [None]:
# Time vs Cost: scatter with a fitted regression line and the Pearson correlation in the title.
if not df.empty and 'time' in df.columns and 'cost' in df.columns:
    # Remove NaN values
    plot_df = df[['time', 'cost']].dropna()

    if len(plot_df) == 0:
        print("No valid time/cost data to plot")
    elif (plot_df['time'] == plot_df['cost']).all():
        # The loader currently fills `time` with `cost`; regressing a column on itself
        # yields a perfect diagonal with correlation 1.0, which tells us nothing.
        print("Skipping Time vs Cost plot: 'time' is identical to 'cost' (cost is used as a time proxy)")
    else:
        # Figure is created only when there is something to draw (no empty figures)
        fig, ax = plt.subplots(figsize=FIGURE_SIZE)

        # Create scatter plot with regression line
        sns.regplot(data=plot_df, x='time', y='cost',
                    scatter_kws={'alpha': 0.6, 's': 50},
                    line_kws={'color': 'red', 'linewidth': 2},
                    ax=ax)

        # Calculate correlation
        correlation = plot_df['time'].corr(plot_df['cost'])

        ax.set_xlabel('Time (seconds)', fontsize=12)
        ax.set_ylabel('Cost', fontsize=12)
        ax.set_title(f'Time vs Cost Analysis\n(Correlation: {correlation:.3f})',
                     fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
else:
    print("Missing time or cost data")


In [None]:
# Correlation heatmap across every numeric column (swept parameters and metrics)
# that takes at least two distinct values.
if df.empty:
    print("No data available for correlation analysis")
else:
    numeric_df = df.select_dtypes(include=[np.number])

    # Constant columns have no defined correlation, so drop them up front
    valid_columns = [col for col in numeric_df.columns if numeric_df[col].nunique() >= 2]

    if len(valid_columns) < 2:
        print("Not enough numeric parameters for correlation analysis")
    else:
        corr_matrix = numeric_df[valid_columns].corr()

        # Keep the heatmap readable: with more than 20 columns, retain the 20
        # with the largest |correlation| against score
        if 'score' in corr_matrix.columns and len(corr_matrix) > 20:
            top_params = (
                corr_matrix['score']
                .abs()
                .sort_values(ascending=False)
                .head(20)
                .index
                .tolist()
            )
            corr_matrix = corr_matrix.loc[top_params, top_params]

        plt.figure(figsize=(12, 10))

        # Hide the redundant upper triangle (including the all-ones diagonal)
        upper_triangle = np.triu(np.ones_like(corr_matrix, dtype=bool))

        sns.heatmap(
            corr_matrix,
            mask=upper_triangle,
            annot=True,
            fmt='.2f',
            cmap='RdBu_r',
            vmin=-1,
            vmax=1,
            center=0,
            square=True,
            linewidths=0.5,
            cbar_kws={"shrink": 0.8},
        )

        plt.title('Parameter Correlation Matrix', fontsize=16, fontweight='bold', pad=20)
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()


In [None]:
# Parameter importance: rank numeric swept parameters by the absolute Pearson
# correlation of each one with `score` (a linear, one-parameter-at-a-time proxy).
if not df.empty and 'score' in df.columns:
    # Select numeric columns excluding metrics
    metric_cols = ['score', 'cost', 'time']
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    param_cols = [col for col in numeric_cols if col not in metric_cols]

    if len(param_cols) > 0:
        # Correlation is undefined for constant columns, so require >= 2 distinct values
        correlations = {}
        for col in param_cols:
            if df[col].nunique() >= 2:
                corr = df[col].corr(df['score'])
                if not np.isnan(corr):
                    correlations[col] = corr

        if correlations:
            # Sort by absolute correlation and keep the configured top-K
            importance_df = (
                pd.DataFrame([
                    {'parameter': k, 'correlation': v, 'abs_correlation': abs(v)}
                    for k, v in correlations.items()
                ])
                .sort_values('abs_correlation', ascending=False)
                .head(TOP_K_PARAMETERS)
            )

            from matplotlib.patches import Patch

            # Create importance plot
            fig, ax = plt.subplots(figsize=FIGURE_SIZE)

            # Bar color encodes the sign of the correlation
            colors = ['darkred' if x < 0 else 'darkblue' for x in importance_df['correlation']]

            # Create horizontal bar plot
            ax.barh(importance_df['parameter'],
                    importance_df['abs_correlation'],
                    color=colors, alpha=0.7)

            # Annotate each bar with its signed correlation; the i-th category sits at y = i
            for pos, (abs_corr, corr) in enumerate(
                zip(importance_df['abs_correlation'], importance_df['correlation'])
            ):
                ax.text(abs_corr + 0.01, pos, f"{corr:.3f}", va='center', fontsize=9)

            # barh draws the first row at the bottom; flip so the most important is on top
            ax.invert_yaxis()

            ax.set_xlabel('Absolute Correlation with Score', fontsize=12)
            ax.set_ylabel('Parameter', fontsize=12)
            ax.set_title(f'Top {len(importance_df)} Most Important Parameters',
                         fontsize=14, fontweight='bold')
            ax.set_xlim(0, 1.1)
            ax.grid(True, axis='x', alpha=0.3)

            # Add legend
            legend_elements = [
                Patch(facecolor='darkblue', alpha=0.7, label='Positive correlation'),
                Patch(facecolor='darkred', alpha=0.7, label='Negative correlation')
            ]
            ax.legend(handles=legend_elements, loc='lower right')

            plt.tight_layout()
            plt.show()

            # Print detailed importance table
            print("\nParameter Importance Details:")
            print("=" * 60)
            for _, row in importance_df.iterrows():
                # `>= 0` matches the bar coloring above (only negative values are red)
                direction = "↑" if row['correlation'] >= 0 else "↓"
                print(f"{row['parameter']:30s} {direction} {row['correlation']:+.4f}")
        else:
            print("No valid parameter correlations found")
    else:
        print("No numeric parameters found for importance analysis")
else:
    print("No score data available for importance analysis")


In [None]:
# Show top performing configurations
# For each of the 5 best runs by score, print its metrics and the numeric
# parameters that differ from the sweep-wide median by more than 1%.
if not df.empty and 'score' in df.columns:
    # Get top 5 configurations
    top_configs = df.nlargest(5, 'score')

    # Loop-invariant: parameter columns and their medians are the same for every ranked row
    param_cols = [col for col in df.columns if col not in ['score', 'cost', 'time']]
    medians = {
        col: df[col].median()
        for col in param_cols
        if pd.api.types.is_numeric_dtype(df[col])
    }

    def _format_metric(value):
        """Render a metric with 4 decimals; None/NaN shows as 'N/A'."""
        if value is None or (isinstance(value, float) and np.isnan(value)):
            return 'N/A'
        if isinstance(value, (int, float, np.number)):
            return f"{value:.4f}"
        return str(value)

    print("\nTop 5 Best Performing Configurations:")
    print("=" * 80)

    for rank, (_, row) in enumerate(top_configs.iterrows(), 1):
        print(f"\nRank {rank}: Score = {row['score']:.4f}, "
              f"Cost = {_format_metric(row.get('cost'))}, Time = {_format_metric(row.get('time'))}")
        print("-" * 40)

        # Show only parameters that differ from the median
        for col, median_val in medians.items():
            if abs(row[col] - median_val) > 0.01 * abs(median_val):  # More than 1% different
                print(f"  {col}: {row[col]:.4f} (median: {median_val:.4f})")


In [None]:
# Optional: Export processed data for further analysis
if not df.empty:
    # Fall back to the sweep name when the ID was never resolved (avoids
    # "sweep_analysis_None.csv"); slashes are replaced so the label cannot
    # introduce a subdirectory into the output path.
    sweep_label = str(WANDB_SWEEP_ID or WANDB_SWEEP_NAME).replace("/", "_")
    output_file = f"sweep_analysis_{sweep_label}.csv"
    df.to_csv(output_file, index=False)
    print(f"\nExported sweep data to: {output_file}")
    print(f"Total rows exported: {len(df)}")
