# Interactive Heatmap Widget Example

This notebook demonstrates how to use the `HeatmapWidget` - an anywidget-based implementation of the observatory heatmap component for Jupyter notebooks.

The widget provides interactive policy evaluation heatmaps with:
- Hover effects showing detailed information
- Double-click to open replay URLs
- Dynamic control over number of policies displayed
- Automatic organization by evaluation categories


## Installation

First, make sure you have the required dependencies:
`pip install anywidget traitlets`


## Import and Basic Setup


In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from experiments.notebooks.utils.metrics import fetch_metrics
from experiments.notebooks.utils.monitoring import monitor_training_statuses
from experiments.notebooks.utils.replays import show_replay
from experiments.notebooks.utils.training import launch_training
from experiments.notebooks.utils.metrics import find_training_jobs

%matplotlib inline
plt.style.use("default")

# Add utils directory to path
sys.path.append(os.path.join(os.getcwd(), 'utils'))

%load_ext anywidget

print("Setup complete! Auto-reload enabled.")


## Example 1: Demo Heatmap with Sample Data

Let's start with a simple demo that includes sample data:


In [None]:
from experiments.notebooks.utils.heatmap_widget import HeatmapWidget, create_demo_heatmap, create_heatmap_widget

# Create a demo heatmap with sample data
demo_widget = create_demo_heatmap()

# Display the widget
demo_widget


In [17]:
import pandas as pd
from pathlib import Path
from typing import Dict, List, Optional, Set
import logging

# Add metta paths
import sys
sys.path.append('/Users/zfogg/src/github.com/metta-ai/metta')

from metta.eval.eval_stats_db import EvalStatsDB
from metta.agent.policy_record import PolicyRecord
from experiments.notebooks.utils.heatmap_widget import HeatmapWidget, create_heatmap_widget

def fetch_real_heatmap_data(
    policy_names: List[str], 
    metrics: List[str],
    eval_db_uri: str = "wandb://stats/navigation_db",
    eval_filter: Optional[str] = None,
    max_policies: int = 20
) -> HeatmapWidget:
    """
    Fetch real evaluation data from metta's database and create a heatmap widget.
    
    Args:
        policy_names: List of policy names/URIs to include (e.g., ["my_policy:v123", "other_policy:v456"])
        metrics: List of metrics to fetch (e.g., ["reward", "heart.get", "action.move.success"])
        eval_db_uri: URI to the evaluation database (default: navigation_db)
        eval_filter: Optional SQL filter for evaluations (e.g., "sim_name LIKE '%maze%'")
        max_policies: Maximum number of policies to display
        
    Returns:
        HeatmapWidget with real data from the database
    """
    print(f"🔍 Fetching real evaluation data from: {eval_db_uri}")
    print(f"📋 Policy names: {policy_names}")
    print(f"📊 Metrics: {metrics}")
    
    # Connect to the evaluation database
    with EvalStatsDB.from_uri(eval_db_uri) as stats_db:
        # Store all data for all metrics
        all_metric_data = {}
        all_eval_names = set()
        valid_policy_names = set()
        
        # Fetch data for each metric
        for metric in metrics:
            print(f"📈 Fetching metric: {metric}")
            
            # Get all policy-eval combinations for this metric
            df = stats_db.metric_by_policy_eval(metric, policy_record=None)
            
            if df.empty:
                print(f"⚠️  No data found for metric: {metric}")
                continue
                
            # Filter to requested policies if specified
            if policy_names:
                df = df[df['policy_uri'].isin(policy_names)]
                
            if df.empty:
                print(f"⚠️  No data found for requested policies in metric: {metric}")
                continue
                
            # Apply evaluation filter if specified
            if eval_filter:
                # We need to query the database again with the filter
                policy_clause = "1=1"
                if policy_names:
                    policy_uris_str = "', '".join(policy_names)
                    policy_clause = f"(policy_key || ':v' || policy_version) IN ('{policy_uris_str}')"
                    
                sql = f"""
                WITH potential AS (
                    SELECT policy_key, policy_version, sim_env, COUNT(*) AS potential_cnt
                      FROM policy_simulation_agent_samples
                     WHERE {policy_clause} AND ({eval_filter})
                     GROUP BY policy_key, policy_version, sim_env
                ),
                recorded AS (
                    SELECT policy_key,
                           policy_version,
                           sim_env,
                           SUM(value) AS recorded_sum
                      FROM policy_simulation_agent_metrics
                     WHERE metric = '{metric}' AND {policy_clause}
                     GROUP BY policy_key, policy_version, sim_env
                )
                SELECT
                    potential.policy_key || ':v' || potential.policy_version AS policy_uri,
                    potential.sim_env AS eval_name,
                    COALESCE(recorded.recorded_sum, 0) * 1.0 / potential.potential_cnt AS value
                FROM potential
                LEFT JOIN recorded USING (policy_key, policy_version, sim_env)
                ORDER BY policy_uri, eval_name
                """
                df = stats_db.query(sql)
                
            # Store the data for this metric
            all_metric_data[metric] = df
            all_eval_names.update(pd.Series(df['eval_name']).unique())
            valid_policy_names.update(pd.Series(df['policy_uri']).unique())
            
            print(f"   Found {len(df)} policy-eval combinations")
        
        if not all_metric_data:
            print("❌ No data found for any metrics!")
            return create_heatmap_widget()
            
        # Convert to lists and sort
        eval_names = sorted(list(all_eval_names))
        policy_names_list = sorted(list(valid_policy_names))
        
        # Limit number of policies if requested
        if len(policy_names_list) > max_policies:
            print(f"🔢 Limiting to {max_policies} policies (found {len(policy_names_list)})")
            # Calculate average scores to pick the best policies
            first_metric = next(iter(all_metric_data.keys()))
            avg_scores = all_metric_data[first_metric].groupby('policy_uri')['value'].mean().sort_values(ascending=False)
            policy_names_list = avg_scores.head(max_policies).index.tolist()
        
        print(f"📊 Final dataset: {len(policy_names_list)} policies × {len(eval_names)} evaluations")
        
        # Build the cells data structure for the widget
        cells = {}
        
        for policy_name in policy_names_list:
            cells[policy_name] = {}
            
            for eval_name in eval_names:
                # Create the metrics dict for this cell
                cell_metrics = {}
                
                for metric in metrics:
                    if metric in all_metric_data:
                        df = all_metric_data[metric]
                        # Find the value for this policy-eval combination
                        match = df[(df['policy_uri'] == policy_name) & (df['eval_name'] == eval_name)]
                        if not match.empty:
                            cell_metrics[metric] = float(match['value'].iloc[0])
                        else:
                            cell_metrics[metric] = 0.0
                    else:
                        cell_metrics[metric] = 0.0
                
                # Create the cell with metrics and metadata
                cells[policy_name][eval_name] = {
                    'metrics': cell_metrics,
                    'replayUrl': f"https://example.com/replay/{policy_name}/{eval_name}.json",  # Placeholder
                    'evalName': eval_name
                }
        
        # Create and configure the widget
        widget = create_heatmap_widget()
        
        widget.set_multi_metric_data(
            cells=cells,
            eval_names=eval_names,
            policy_names=policy_names_list,
            metrics=metrics,
            selected_metric=metrics[0] if metrics else "reward"
        )
        
        print("✅ Successfully created heatmap widget with real data!")
        return widget


def get_available_policy_names(eval_db_uri: str = "wandb://stats/navigation_db", limit: int = 50) -> List[str]:
    """
    Get a list of available policy names from the database.
    
    Args:
        eval_db_uri: URI to the evaluation database
        limit: Maximum number of policy names to return
        
    Returns:
        List of policy URI strings
    """
    print(f"🔍 Fetching available policy names from: {eval_db_uri}")
    
    with EvalStatsDB.from_uri(eval_db_uri) as stats_db:
        sql = """
        SELECT DISTINCT policy_key || ':v' || policy_version AS policy_uri
        FROM policy_simulation_agent_samples
        ORDER BY policy_uri
        LIMIT ?
        """
        df = stats_db.con.execute(sql, [limit]).fetchdf()
        policy_names = df['policy_uri'].tolist()
        
    print(f"📋 Found {len(policy_names)} policies")
    return policy_names


def get_available_metrics(eval_db_uri: str = "wandb://stats/navigation_db", limit: int = 50) -> List[str]:
    """
    Get a list of available metrics from the database.
    
    Args:
        eval_db_uri: URI to the evaluation database
        limit: Maximum number of metrics to return
        
    Returns:
        List of metric names
    """
    print(f"🔍 Fetching available metrics from: {eval_db_uri}")
    
    with EvalStatsDB.from_uri(eval_db_uri) as stats_db:
        sql = """
        SELECT DISTINCT metric
        FROM policy_simulation_agent_metrics
        ORDER BY metric
        LIMIT ?
        """
        df = stats_db.con.execute(sql, [limit]).fetchdf()
        metric_names = df['metric'].tolist()
        
    print(f"📊 Found {len(metric_names)} metrics")
    return metric_names


def get_available_evaluations(eval_db_uri: str = "wandb://stats/navigation_db", limit: int = 100) -> List[str]:
    """
    Get a list of available evaluation names from the database.
    
    Args:
        eval_db_uri: URI to the evaluation database
        limit: Maximum number of evaluation names to return
        
    Returns:
        List of evaluation names
    """
    print(f"🔍 Fetching available evaluation names from: {eval_db_uri}")
    
    with EvalStatsDB.from_uri(eval_db_uri) as stats_db:
        sql = """
        SELECT DISTINCT sim_env AS eval_name
        FROM policy_simulation_agent_samples
        ORDER BY eval_name
        LIMIT ?
        """
        df = stats_db.con.execute(sql, [limit]).fetchdf()
        eval_names = df['eval_name'].tolist()
        
    print(f"🏃 Found {len(eval_names)} evaluations")
    return eval_names


print("🚀 Real data fetching functions loaded!")
print("📋 Available functions:")
print("   - fetch_real_heatmap_data(policy_names, metrics, eval_db_uri)")
print("   - get_available_policy_names(eval_db_uri)")  
print("   - get_available_metrics(eval_db_uri)")
print("   - get_available_evaluations(eval_db_uri)")


🚀 Real data fetching functions loaded!
📋 Available functions:
   - fetch_real_heatmap_data(policy_names, metrics, eval_db_uri)
   - get_available_policy_names(eval_db_uri)
   - get_available_metrics(eval_db_uri)
   - get_available_evaluations(eval_db_uri)


## Example: Using Real Data

Now let's explore what's available in the database and create a heatmap with real data:


In [52]:
# Let's first explore what's available in the database
# Note: This will download the database file from wandb, which may take a moment

# Uncomment these lines to explore available data:
# available_policies = get_available_policy_names(limit=10)
# print("Sample policies:", available_policies[:5])

# available_metrics = get_available_metrics(limit=20) 
# print("Sample metrics:", available_metrics[:10])

# available_evals = get_available_evaluations(limit=20)
# print("Sample evaluations:", available_evals[:10])


def select_best_policies_from_runs(
    training_runs: List[str], 
    eval_db_uri: str = "wandb://stats/navigation_db",
    metric: str = "reward",
    selector: str = "best"  # "best" or "latest"
) -> List[str]:
    """
    Select the best or latest policy from each training run using the same logic as heatmap_routes.py
    
    Args:
        training_runs: List of training run prefixes (e.g., ["run1", "run2"])
        eval_db_uri: URI to the evaluation database
        metric: Metric to use for "best" selection (only used if selector="best")
        selector: "best" (highest average score) or "latest" (highest epoch/version)
        
    Returns:
        List of selected policy URIs
    """
    print(f"🔍 Selecting {selector} policies from {len(training_runs)} training runs...")
    
    with EvalStatsDB.from_uri(eval_db_uri) as stats_db:
        # Get all available policies
        all_policies = get_available_policy_names(eval_db_uri, limit=1000)
        print("all_policies", len(all_policies), all_policies[:10])
        
        # Group policies by training run
        run_policies = {}
        for run in training_runs:
            run_policies[run] = [p for p in all_policies if p.startswith(run)]
            
        selected_policies = []
        
        for run, policies in run_policies.items():
            if not policies:
                print(f"⚠️  No policies found for run: {run}")
                continue
                
            if selector == "latest":
                # Select latest by version/epoch number
                def extract_version(policy_uri):
                    if ":" in policy_uri:
                        try:
                            return int(policy_uri.split(":")[-1].replace("v", ""))
                        except:
                            return 0
                    return 0
                
                best_policy = max(policies, key=extract_version)
                selected_policies.append(best_policy)
                print(f"📈 Latest for {run}: {best_policy}")
                
            elif selector == "best":
                # Select best by average performance - same logic as heatmap_routes.py
                policy_scores = {}
                policy_versions = {}
                
                # Get evaluation data for all policies in this run
                for policy_uri in policies:
                    df = stats_db.metric_by_policy_eval(metric, policy_record=None)
                    policy_data = df[df['policy_uri'] == policy_uri]
                    
                    if not policy_data.empty:
                        # Calculate average score across all evaluations
                        avg_score = policy_data['value'].mean()
                        policy_scores[policy_uri] = avg_score
                        
                        # Extract version for tie-breaking
                        try:
                            version = int(policy_uri.split(":")[-1].replace("v", "")) if ":" in policy_uri else 0
                        except:
                            version = 0
                        policy_versions[policy_uri] = version
                
                if policy_scores:
                    # Find best policy (highest score, ties broken by latest version)
                    best_policy = max(policy_scores.keys(), 
                                    key=lambda p: (policy_scores[p], policy_versions[p]))
                    selected_policies.append(best_policy)
                    print(f"🏆 Best for {run}: {best_policy} (score: {policy_scores[best_policy]:.3f})")
                else:
                    print(f"⚠️  No evaluation data found for policies in run: {run}")
        
        print(f"✅ Selected {len(selected_policies)} policies")
        return selected_policies


# For now, let's try with some common metrics and see what we find:
real_heatmap = None
try:
    # Example: Create a heatmap with some policies and common metrics
    # You can customize these based on what you find in your database
    
    # If you know specific policy names, specify them:
    specific_policies = []  # Empty list = fetch all available policies
    # specific_runs = [
    #     "daveey.arena.rnd.16x4.2",
    #     "relh.skypilot.fff.j20.666",
    #     "bullm.navigation.low_reward.baseline",
    #     "bullm.navigation.low_reward.baseline.07-17", 
    #     "bullm.navigation.low_reward.baseline.07-23",
    #     "relh.multigpu.fff.1",
    #     "relh.skypilot.fff.j21.2",
    # ]
    
    all_policies = get_available_policy_names(eval_db_uri="wandb://stats/navigation_db")
    all_runs = list(dict.fromkeys([p.split(":")[0] for p in all_policies]))
    print("all_runs", len(all_runs), all_runs[:10])
    # Select the best policies using the same logic as heatmap_routes.py
    print("🎯 Selecting best policies from training runs...")
    specific_policies = select_best_policies_from_runs(
        training_runs=all_runs,
        eval_db_uri="wandb://stats/navigation_db",
        metric="reward",  # Metric to use for "best" selection
        selector="best"   # "best" or "latest"
    )
    
    # Common metrics that are likely to exist:
    metrics_to_fetch = ["reward", "heart.get", "ore_red.get", "action.move.success"]
    
    # Optional: filter to specific evaluations  
    eval_filter = None  # e.g., "sim_env LIKE '%navigation%'" to only include navigation tasks
    
    print("🎯 Creating heatmap with real data...")
    real_heatmap = fetch_real_heatmap_data(
        policy_names=specific_policies,
        metrics=metrics_to_fetch,
        eval_db_uri="wandb://stats/navigation_db",  # You may need to change this URI
        eval_filter=eval_filter,
        max_policies=15  # Limit display to keep it manageable
    )
    
except Exception as e:
    print(f"❌ Error fetching real data: {e}")
    print("💡 This might happen if:")
    print("   - The database URI is incorrect")
    print("   - You're not authenticated with wandb")
    print("   - The specified metrics don't exist in the database")
    print("   - You don't have access to the database")
    print("\n🔄 Falling back to demo data...")
    
    # Fall back to demo data if real data fails
    from experiments.notebooks.utils.heatmap_widget import create_demo_heatmap
    demo_fallback = create_demo_heatmap()
    demo_fallback
    exit(1)

real_heatmap

🔍 Fetching available policy names from: wandb://stats/navigation_db


[34m[1mwandb[0m: Downloading large artifact navigation_db:latest, 124.26MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4 (325.8MB/s)


📋 Found 50 policies
all_runs 40 ['alex_obs_cat_02', 'alex_obs_latent_attn_add_tokens_01', 'alex_obs_latent_fourier_01', 'alex_obs_robust_cross_01', 'alexv_lfourier8_01', 'alexv_resMLP_enc_256x64_01', 'alexv_resMLP_enc_512x32_01', 'alexv_resMLP_enc_value_512x32_01', 'alexv_resMLP_value_256x64_01', 'alexv_resMLP_value_512x32_01']
🎯 Selecting best policies from training runs...
🔍 Selecting best policies from 40 training runs...


[34m[1mwandb[0m: Downloading large artifact navigation_db:latest, 124.26MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.3 (359.8MB/s)


🔍 Fetching available policy names from: wandb://stats/navigation_db


[34m[1mwandb[0m: Downloading large artifact navigation_db:latest, 124.26MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4 (326.4MB/s)


📋 Found 271 policies
all_policies 271 ['alex_obs_cat_02:v18', 'alex_obs_latent_attn_add_tokens_01:v30', 'alex_obs_latent_attn_add_tokens_01:v4', 'alex_obs_latent_fourier_01:v19', 'alex_obs_robust_cross_01:v28', 'alexv_lfourier8_01:v29', 'alexv_resMLP_enc_256x64_01:v1', 'alexv_resMLP_enc_512x32_01:v12', 'alexv_resMLP_enc_value_512x32_01:v1', 'alexv_resMLP_value_256x64_01:v12']
🏆 Best for alex_obs_cat_02: alex_obs_cat_02:v18 (score: 0.799)
🏆 Best for alex_obs_latent_attn_add_tokens_01: alex_obs_latent_attn_add_tokens_01:v4 (score: 0.813)
🏆 Best for alex_obs_latent_fourier_01: alex_obs_latent_fourier_01:v19 (score: 0.866)
🏆 Best for alex_obs_robust_cross_01: alex_obs_robust_cross_01:v28 (score: 0.799)
🏆 Best for alexv_lfourier8_01: alexv_lfourier8_01:v29 (score: 0.919)
🏆 Best for alexv_resMLP_enc_256x64_01: alexv_resMLP_enc_256x64_01:v1 (score: 0.200)
🏆 Best for alexv_resMLP_enc_512x32_01: alexv_resMLP_enc_512x32_01:v12 (score: 0.067)
🏆 Best for alexv_resMLP_enc_value_512x32_01: alexv_res

[34m[1mwandb[0m: Downloading large artifact navigation_db:latest, 124.26MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4 (321.5MB/s)


📈 Fetching metric: reward
   Found 950 policy-eval combinations
📈 Fetching metric: heart.get
   Found 950 policy-eval combinations
📈 Fetching metric: ore_red.get
   Found 950 policy-eval combinations
📈 Fetching metric: action.move.success
   Found 950 policy-eval combinations
🔢 Limiting to 15 policies (found 38)
📊 Final dataset: 15 policies × 25 evaluations
🚀 HeatmapWidget initialized successfully!
📊 Multi-metric data set with 15 policies and 25 evaluations
📈 Available metrics: reward, heart.get, ore_red.get, action.move.success
📈 Selected metric: reward
✅ Successfully created heatmap widget with real data!


HeatmapWidget(heatmap_data={'cells': {'relh.nav.new.42:v26': {'env/mettagrid/navigation/evals/corridors': {'me…

In [15]:
# Compare "best" vs "latest" policy selection for the same runs
sample_runs = [
    "daveey.arena.rnd.16x4.2",
    "bullm.navigation.low_reward.baseline",
    "relh.skypilot.fff.j20.666"
]

print("🔍 Comparing 'best' vs 'latest' policy selection strategies:")
print("=" * 60)

try:
    # Select using "latest" strategy
    print("\\n📈 LATEST strategy (highest version/epoch):")
    latest_policies = select_best_policies_from_runs(
        training_runs=sample_runs,
        selector="latest"
    )
    
    print("\\n🏆 BEST strategy (highest average reward):")
    best_policies = select_best_policies_from_runs(
        training_runs=sample_runs, 
        metric="reward",
        selector="best"
    )
    
    print("\\n📊 COMPARISON:")
    print(f"{'Run':<35} {'Latest':<25} {'Best':<25}")
    print("-" * 85)
    
    # Create lookup dictionaries
    latest_lookup = {}
    best_lookup = {}
    
    for policy in latest_policies:
        for run in sample_runs:
            if policy.startswith(run):
                latest_lookup[run] = policy
                break
                
    for policy in best_policies:
        for run in sample_runs:
            if policy.startswith(run):
                best_lookup[run] = policy
                break
    
    for run in sample_runs:
        latest = latest_lookup.get(run, "None")
        best = best_lookup.get(run, "None")
        same = "✅" if latest == best else "❌"
        print(f"{run:<35} {latest:<25} {best:<25} {same}")
        
    print("\\n💡 Key differences:")
    print("   - 'Latest' picks the most recent version (highest epoch/version number)")
    print("   - 'Best' picks the version with highest average performance across evaluations")
    print("   - They may differ when a later version performs worse than an earlier one")
    
except Exception as e:
    print(f"❌ Error comparing strategies: {e}")
    print("💡 Make sure you have access to the evaluation database")


🔍 Comparing 'best' vs 'latest' policy selection strategies:
\n📈 LATEST strategy (highest version/epoch):
🔍 Selecting latest policies from 3 training runs...


[34m[1mwandb[0m: Downloading large artifact navigation_db:latest, 124.26MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4 (279.8MB/s)


🔍 Fetching available policy names from: wandb://stats/navigation_db


[34m[1mwandb[0m: Downloading large artifact navigation_db:latest, 124.26MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4 (280.1MB/s)


📋 Found 271 policies
all_policies ['alex_obs_cat_02:v18', 'alex_obs_latent_attn_add_tokens_01:v30', 'alex_obs_latent_attn_add_tokens_01:v4', 'alex_obs_latent_fourier_01:v19', 'alex_obs_robust_cross_01:v28', 'alexv_lfourier8_01:v29', 'alexv_resMLP_enc_256x64_01:v1', 'alexv_resMLP_enc_512x32_01:v12', 'alexv_resMLP_enc_value_512x32_01:v1', 'alexv_resMLP_value_256x64_01:v12', 'alexv_resMLP_value_512x32_01:v13', 'b.daphne.test_nav_bucketedcurriculum:v15', 'b.daphne.test_nav_curriculum:v9', 'b.daphne.test_nav_curriculum_full:v12', 'daphne.lighter_nav:v10', 'daphne.navbucketedopt_devbox:v47', 'daphne.navnoterrain:v6', 'daphne.navopt_devbox:v12', 'daphne.optimize_nav:v164', 'daphne.optimize_nav_aws:v46', 'daphne.optimize_nav_tokenized:v42', 'daphne_nav_bucketed:v19', 'daphne_navigation_bucketed:v18', 'dd.nav_optimized:v11', 'dd.nav_optimized:v31', 'dd.nav_optimized_bucket:v11', 'dd.navbucketed_2:v1', 'dd.navbucketed_sparser:v7', 'dd_curriculum_navigation_tokenized:v23', 'dd_navigation_curricul

[34m[1mwandb[0m: Downloading large artifact navigation_db:latest, 124.26MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4 (316.7MB/s)


🔍 Fetching available policy names from: wandb://stats/navigation_db


[34m[1mwandb[0m: Downloading large artifact navigation_db:latest, 124.26MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5 (240.7MB/s)


📋 Found 271 policies
all_policies ['alex_obs_cat_02:v18', 'alex_obs_latent_attn_add_tokens_01:v30', 'alex_obs_latent_attn_add_tokens_01:v4', 'alex_obs_latent_fourier_01:v19', 'alex_obs_robust_cross_01:v28', 'alexv_lfourier8_01:v29', 'alexv_resMLP_enc_256x64_01:v1', 'alexv_resMLP_enc_512x32_01:v12', 'alexv_resMLP_enc_value_512x32_01:v1', 'alexv_resMLP_value_256x64_01:v12', 'alexv_resMLP_value_512x32_01:v13', 'b.daphne.test_nav_bucketedcurriculum:v15', 'b.daphne.test_nav_curriculum:v9', 'b.daphne.test_nav_curriculum_full:v12', 'daphne.lighter_nav:v10', 'daphne.navbucketedopt_devbox:v47', 'daphne.navnoterrain:v6', 'daphne.navopt_devbox:v12', 'daphne.optimize_nav:v164', 'daphne.optimize_nav_aws:v46', 'daphne.optimize_nav_tokenized:v42', 'daphne_nav_bucketed:v19', 'daphne_navigation_bucketed:v18', 'dd.nav_optimized:v11', 'dd.nav_optimized:v31', 'dd.nav_optimized_bucket:v11', 'dd.navbucketed_2:v1', 'dd.navbucketed_sparser:v7', 'dd_curriculum_navigation_tokenized:v23', 'dd_navigation_curricul

In [None]:
# Example: Create a custom heatmap with specific training runs and metrics

# Option 1: Use training run names (recommended - uses smart selection)
my_training_runs = [
    # Add your training run names here, for example:
    # "my_experiment_batch_1",
    # "my_experiment_batch_2", 
    # "baseline_run_v1"
]

# Option 2: Use exact policy URIs (if you know exactly which ones you want)
my_specific_policies = [
    # Add exact policy URIs here, for example:
    # "my_policy_name:v123",
    # "baseline_experiment:v456",
    # "new_approach:v789"
]

# Step 2: Define metrics you want to compare
my_metrics = [
    "reward",
    "heart.get",           # Example game-specific metric
    "action.move.success", # Example action success rate
    # Add more metrics as needed
]

# Step 3: Optional - filter to specific evaluations
# eval_filter = "sim_env LIKE '%maze%'"  # Only maze environments
# eval_filter = "sim_env LIKE '%combat%'"  # Only combat environments  
eval_filter = None  # No filter - include all evaluations

# Step 4: Create the heatmap
if my_training_runs:  # Use smart policy selection from training runs
    print("🎯 Creating custom heatmap with best policies from training runs...")
    
    # Select best policies from training runs
    selected_policies = select_best_policies_from_runs(
        training_runs=my_training_runs,
        eval_db_uri="wandb://stats/navigation_db",
        metric="reward",  # Metric to optimize for when selecting "best"
        selector="best"   # or "latest"
    )
    
    custom_heatmap = fetch_real_heatmap_data(
        policy_names=selected_policies,
        metrics=my_metrics,
        eval_db_uri="wandb://stats/navigation_db",
        eval_filter=eval_filter,
        max_policies=20
    )
    
    print("📊 Custom heatmap created! Try:")
    print("   - Hovering over cells to see detailed values")
    print("   - Changing metrics with: custom_heatmap.update_metric('heart.get')")
    print("   - Adjusting policies shown: custom_heatmap.set_num_policies(15)")
    
    custom_heatmap
    
elif my_specific_policies:  # Use exact policy URIs
    print("🎯 Creating custom heatmap with specific policies...")
    custom_heatmap = fetch_real_heatmap_data(
        policy_names=my_specific_policies,
        metrics=my_metrics,
        eval_db_uri="wandb://stats/navigation_db",
        eval_filter=eval_filter,
        max_policies=20
    )
    
    custom_heatmap
    
else:
    print("📝 To use this example:")
    print("\\n🚀 RECOMMENDED: Use training run names (Option 1)")
    print("1. Add your training run names to 'my_training_runs' list above")
    print("2. The system will automatically select the best policy from each run")
    print("3. Customize the 'my_metrics' list with metrics you're interested in")
    print("4. Run this cell again")
    print("\\n💡 Example training run names:")
    print("   - 'my_experiment_batch_1'")
    print("   - 'baseline_run_v2'")
    print("   - 'new_approach_test'")
    print("\\n⚙️  ALTERNATIVE: Use exact policy URIs (Option 2)")
    print("1. Add exact policy URIs to 'my_specific_policies' list")
    print("2. Example: 'my_policy_name:v123', 'baseline:v456'")
    print("\\n🔍 TIP: Run the exploration code in previous cells to see available options")


In [None]:
# 🗃️ AVAILABLE EVALUATION DATABASES

# Domain-specific databases (most commonly used)
available_databases = {
    "navigation_db": "wandb://stats/navigation_db",      # Navigation tasks
    "memory_db": "wandb://stats/memory_db",              # Memory tasks  
    "objectuse_db": "wandb://stats/objectuse_db",        # Object use tasks
    "nav_sequence_db": "wandb://stats/nav_sequence_db",  # Navigation sequence tasks
    # User-specific databases
    "jack_db": "wandb://stats/jack_db",                  # Jack's personal database
}

print("🗄️  Available Evaluation Databases:")
print("=" * 50)
for name, uri in available_databases.items():
    print(f"📊 {name:<20} → {uri}")

print("\n💡 Usage examples:")
print("   # For navigation analysis:")
print("   fetch_real_heatmap_data(..., eval_db_uri='wandb://stats/navigation_db')")
print("   # For memory analysis:")  
print("   fetch_real_heatmap_data(..., eval_db_uri='wandb://stats/memory_db')")

print("\n🔍 You can also use:")
print("   • Local files: './path/to/my_stats.db'")
print("   • S3 buckets: 's3://bucket/path/stats.db'")

# Quick function to check what's in each database
def explore_database(db_name: str, db_uri: str, limit: int = 5):
    """Quickly explore what's available in a database"""
    print(f"\n🔍 Exploring {db_name} ({db_uri}):")
    print("-" * 40)
    
    try:
        # Get a small sample of data
        policies = get_available_policy_names(eval_db_uri=db_uri, limit=limit)
        metrics = get_available_metrics(eval_db_uri=db_uri, limit=10)
        evals = get_available_evaluations(eval_db_uri=db_uri, limit=10)
        
        print(f"📋 Sample policies ({len(policies)}): {policies[:3]}...")
        print(f"📊 Sample metrics ({len(metrics)}): {metrics[:5]}...")  
        print(f"🏃 Sample evaluations ({len(evals)}): {evals[:5]}...")
        
    except Exception as e:
        print(f"❌ Error accessing {db_name}: {e}")
        if "wandb" in str(e).lower():
            print("💡 You may need to authenticate with wandb or check permissions")

# Uncomment to explore different databases:
# explore_database("Navigation DB", "wandb://stats/navigation_db")
# explore_database("Memory DB", "wandb://stats/memory_db")  
# explore_database("Object Use DB", "wandb://stats/objectuse_db")

print("\n📝 To explore a database, uncomment the explore_database() calls above!")
print("\n🚀 Quick start: Most users will want 'wandb://stats/navigation_db'")


In [None]:
# 🔍 Example: Comparing Policies Across Different Task Categories

# Here's how to create heatmaps from different evaluation databases:

# Example 1: Navigation tasks
print("🧭 Creating navigation heatmap...")
try:
    navigation_runs = ["daveey.arena.rnd.16x4.2", "bullm.navigation.low_reward.baseline"]
    
    nav_policies = select_best_policies_from_runs(
        training_runs=navigation_runs,
        eval_db_uri="wandb://stats/navigation_db",  # Navigation database
        selector="best",
        metric="reward"
    )
    
    nav_heatmap = fetch_real_heatmap_data(
        policy_names=nav_policies,
        metrics=["reward", "heart.get", "action.move.success"],
        eval_db_uri="wandb://stats/navigation_db",
        max_policies=5
    )
    
    print("✅ Navigation heatmap created successfully!")
    # nav_heatmap  # Uncomment to display
    
except Exception as e:
    print(f"❌ Navigation heatmap failed: {e}")

# Example 2: Memory tasks (if available)
print("\n🧠 Memory tasks would use:")
print("   eval_db_uri='wandb://stats/memory_db'")
print("   # Likely different metrics like memory.recall, sequence.accuracy, etc.")

# Example 3: Object use tasks (if available)  
print("\n🔧 Object use tasks would use:")
print("   eval_db_uri='wandb://stats/objectuse_db'")
print("   # Likely different metrics like tool.use.success, manipulation.accuracy, etc.")

print("\n💡 Pro tip: Each database specializes in different task types:")
print("   🧭 navigation_db    → spatial reasoning, pathfinding")
print("   🧠 memory_db        → recall, sequence learning") 
print("   🔧 objectuse_db     → manipulation, tool use")
print("   📚 nav_sequence_db  → sequential navigation tasks")

print("\n📊 To switch databases, just change the eval_db_uri parameter!")
print("   Example: eval_db_uri='wandb://stats/memory_db'")


In [None]:
# Example: Create a custom heatmap with specific policies and metrics

# Step 1: Define your policies of interest
my_policies = [
    # Add your policy names here, for example:
    # "my_experiment_1:v100",
    # "my_experiment_2:v200", 
    # "baseline_policy:v50"
]

# Step 2: Define metrics you want to compare
my_metrics = [
    "reward",
    "heart.get",           # Example game-specific metric
    "action.move.success", # Example action success rate
    # Add more metrics as needed
]

# Step 3: Optional - filter to specific evaluations
# eval_filter = "sim_env LIKE '%maze%'"  # Only maze environments
# eval_filter = "sim_env LIKE '%combat%'"  # Only combat environments  
eval_filter = None  # No filter - include all evaluations

# Step 4: Create the heatmap
if my_policies:  # Only run if you've specified policies
    print("🎯 Creating custom heatmap...")
    custom_heatmap = fetch_real_heatmap_data(
        policy_names=my_policies,
        metrics=my_metrics,
        eval_db_uri="wandb://stats/navigation_db",  # Adjust as needed
        eval_filter=eval_filter,
        max_policies=20
    )
    
    # Step 5: Display and interact
    print("📊 Custom heatmap created! Try:")
    print("   - Hovering over cells to see detailed values")
    print("   - Changing metrics with: custom_heatmap.update_metric('heart.get')")
    print("   - Adjusting policies shown: custom_heatmap.set_num_policies(15)")
    
    custom_heatmap
else:
    print("📝 To use this example:")
    print("1. Uncomment the exploration code in the previous cell to see available policies")
    print("2. Add your policy names to the 'my_policies' list above") 
    print("3. Customize the 'my_metrics' list with metrics you're interested in")
    print("4. Run this cell again")
    print("\n💡 Example policy names might look like:")
    print("   - 'my_policy_name:v123'")
    print("   - 'baseline_experiment:v456'") 
    print("   - 'new_approach:v789'")


In [None]:
import ipywidgets as widgets
from IPython.display import display

from experiments.notebooks.utils.heatmap_widget import HeatmapWidget, create_demo_heatmap, create_heatmap_widget

# Create a demo heatmap with sample data
demo_widget = create_demo_heatmap()
display(demo_widget)

w = widgets.Button(description="Click me", style=dict(width="200px", height="50px"))
display(w)
# Display the widget
print(demo_widget)

**Try interacting with the heatmap above:**
- Hover over cells to see detailed information
- Click on a row's left policy title label to "open" that policy's Wandb URL in a new tab
- Adjust the "Policies to show" input to change how many policies are displayed
- Click on policy names to open WandB links (in demo, these won't work)


## Example 2: Creating Your Own Heatmap Data

Here's how to create a heatmap with your own data:


## Example 4: Multiple Metrics with Working selectedMetric

Now let's see the `selectedMetric` functionality working properly! This example shows a heatmap where changing the metric actually changes the displayed values:


In [None]:
# Create a multi-metric heatmap widget
from experiments.notebooks.utils.heatmap_widget import create_multi_metric_demo

multi_metric_widget = create_multi_metric_demo()

# Display the widget
multi_metric_widget


In [None]:
# Now try changing the metric to see the values actually change!
print("🔄 Changing metric to 'episode_length'...")
multi_metric_widget.update_metric('episode_length')

# NOTE: Notice how the values in the heatmap widget change as you switch
# metrics?  Do not display the widget again and try to change that. That ends up
# creating a seperate copy of the widget in a new output cell.  Instead just
# reference the one you originally rendered, call its functions, and watch it
# change in its Juypter notebook cell. Like we just did. Let's do it again in
# the next cell too.


In [None]:
# One more time. Run this cell then scroll back up again to see the change.
print("\n🔄 Changing metric to 'success_rate'...")
multi_metric_widget.update_metric('success_rate')


In [None]:
# Last one. Scroll up again to see the change.
print("\n🔄 Changing metric to 'success_rate'...")
multi_metric_widget.update_metric('success_rate')


# Custom metrics

We can really define our cells to have any metric data we want. This is useful because we plan to have all sorts of metrics. Let's look at an example of using any old metric we decide:

In [None]:
# Create a new heatmap widget
custom_widget = create_heatmap_widget()

# Define your data structure
# This should match the format expected by the observatory dashboard
cells_data = {
    'my_policy_v1': {
        'task_a/level1': {
            'metrics': {
                'custom_score': 85.2,
            },
            'replayUrl': 'https://example.com/replay1.json', 
            'evalName': 'task_a/level1'
        },
        'task_a/level2': {
            'metrics': {
                'custom_score': 87.5,
            },
            'replayUrl': 'https://example.com/replay2.json', 
            'evalName': 'task_a/level2'
        },
        'task_b/challenge1': {
            'metrics': {
                'custom_score': 92.5,
            },
            'replayUrl': 'https://example.com/replay3.json', 
            'evalName': 'task_b/challenge1'
        },
    },
    'my_policy_v2': {
        'task_a/level1': {
            'metrics': {
                'custom_score': 22.5,
            },
            'replayUrl': 'https://example.com/replay4.json', 
            'evalName': 'task_a/level1'
        },
        'task_a/level2': {
            'metrics': {
                'custom_score': 42.5,
            },
            'replayUrl': 'https://example.com/replay5.json', 
            'evalName': 'task_a/level2'
        },
        'task_b/challenge1': {
            'metrics': {
                'custom_score': 62.5,
            },
            'replayUrl': 'https://example.com/replay6.json', 
            'evalName': 'task_b/challenge1'
        },
    },
}

eval_names = ['task_a/level1', 'task_a/level2', 'task_b/challenge1']
policy_names = ['my_policy_v1', 'my_policy_v2']
policy_averages = {
    'my_policy_v1': 91.6,
    'my_policy_v2': 89.6,
}

# Set the data
custom_widget.set_data(
    cells=cells_data,
    eval_names=eval_names,
    policy_names=policy_names,
    policy_average_scores=policy_averages,
    selected_metric="custom_score"
)

# Display the widget
custom_widget


In [None]:
# NOTE: these callbacks do not work with print(), and that's really just how
# Jupyter widgets work.  Once the Jupyter python cell finishes running and
# outputs a widget, that widget won't be able to affect the output of the cell
# anymore. The only way to to print() from a python widget callback is to write
# to a file (or use a thread maybe). I give an example below.

# Create another widget for callback demonstration
callback_widget = create_heatmap_widget()

# Set up the same data as before
callback_widget.set_data(
    cells=cells_data,
    eval_names=eval_names,
    policy_names=policy_names,
    policy_average_scores=policy_averages,
    selected_metric="Interactive Score (%)"
)

# Define callback functions
def handle_cell_selection(cell_info):
    """Called when user hovers over a cell (not 'overall' column)."""
    with open("output_cell_selection.txt", "w") as f:
        f.write(f"📍 Cell selected: {cell_info['policyUri']} on evaluation '{cell_info['evalName']}'")

def handle_replay_opened(replay_info):
    """Called when user clicks to open a replay."""
    with open("output_replay_opened.txt", "w") as f:
        f.write(f"🎬 Replay opened: {replay_info['replayUrl']}")
        f.write(f"   Policy: {replay_info['policyUri']}")
        f.write(f"   Evaluation: {replay_info['evalName']}")

# Register the callbacks
callback_widget.on_cell_selected(handle_cell_selection)
callback_widget.on_replay_opened(handle_replay_opened)

# Display the widget
callback_widget


In [None]:
# Delete the files created by the callbacks in the previous cell, if they exist
import os

for fname in ["output_cell_selection.txt", "output_replay_opened.txt"]:
    try:
        with open(fname, "r") as f:
            print(f.read())
        os.remove(fname)
        print(f"File {fname} deleted")
    except FileNotFoundError:
        pass


**Try interacting with the heatmap above to see the callback messages printed to
*output files!**

## Data Format Reference

The heatmap widget expects data in a specific format that matches the
observatory dashboard:

```python
cells = {
    'policy_name': {
        'eval_name': {
            'metrics': {
                'reward': 50,
                'heart.get': 98,
                'action.move.success': 5,
                'ore_red.get': 24.2,
                # ... more metrics
            },
            'replayUrl': str,         # URL to replay file
            'evalName': str,          # Should match the key
        },
        # ... more evaluations
    },
    # ... more policies
}
```

**Important notes:**
- Evaluation names with "/" will be grouped by category (the part before "/")
- The heatmap shows policies sorted by average score (worst to best, bottom to top)
- Policy names that contain ":v" will have WandB URLs generated automatically
- Replay URLs should be accessible URLs or file paths

This widget provides the same interactive functionality as the observatory dashboard but in a python environment, making it perfect for exploratory analysis and sharing results via Jupyter notebooks!
