# Interactive Heatmap Widget Example

This notebook demonstrates how to use the `HeatmapWidget`, an anywidget-based implementation of the observatory heatmap component for Jupyter notebooks.

The widget provides interactive policy evaluation heatmaps with:
- Hover effects showing detailed information
- Double-click to open replay URLs
- Dynamic control over number of policies displayed
- Automatic organization by evaluation categories


## Installation

First, install the required dependencies (`httpx` is used by the API client cells later in this notebook):
`pip install anywidget traitlets httpx`


## Import and Basic Setup


In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from experiments.notebooks.utils.metrics import fetch_metrics, find_training_jobs
from experiments.notebooks.utils.monitoring import monitor_training_statuses
from experiments.notebooks.utils.replays import show_replay
from experiments.notebooks.utils.training import launch_training

%matplotlib inline
plt.style.use("default")

# Add utils directory to path
sys.path.append(os.path.join(os.getcwd(), 'utils'))

%load_ext anywidget

print("Setup complete! Auto-reload enabled.")


## Example 1: Demo Heatmap with Sample Data

Let's start with a simple demo that includes sample data:


In [None]:
from experiments.notebooks.utils.heatmap_widget import HeatmapWidget, create_demo_heatmap, create_heatmap_widget

# Create a demo heatmap with sample data
demo_widget = create_demo_heatmap()

# Display the widget
demo_widget


In [3]:
api_base_url = "https://api.observatory.softmax-research.net"
auth_token = "WTGipvHPU5RtZsN3S033H4QnWLKkOq3LtPmhes6iFQk"

In [4]:
# 🔧 Fixed API Client Implementation
import asyncio
import httpx
from typing import Dict, List, Optional, Any
import logging

class MettaAPIClientFixed:
    """Fixed client that properly handles authentication and response parsing."""
    
    def __init__(self, base_url: str, auth_token: Optional[str] = None):
        self.base_url = base_url.rstrip("/")
        self.headers = {"Content-Type": "application/json"}
        if auth_token:
            # Use X-Auth-Token header format like extract_training_rewards.py
            self.headers["X-Auth-Token"] = auth_token

    async def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
        """Make an HTTP request to the API."""
        url = f"{self.base_url}{endpoint}"
        print(f"🔍 Making {method} request to: {url}")
        print(f"🔑 Headers: {self.headers}")
        if 'json' in kwargs:
            print(f"📦 Payload: {kwargs['json']}")
            
        async with httpx.AsyncClient() as client:
            response = await client.request(method, url, headers=self.headers, timeout=30.0, **kwargs)
            print(f"📨 Response status: {response.status_code}")
            if response.status_code >= 400:
                print(f"❌ Response body: {response.text}")
            response.raise_for_status()
            return response.json()
    
    async def get_policies(self, search_text: Optional[str] = None, page_size: int = 50):
        """Get available policies and training runs."""
        url = "/heatmap/policies"
        payload = {
            "search_text": search_text,  # Use None instead of empty string
            "pagination": {"page": 1, "page_size": page_size}
        }
        return await self._make_request("POST", url, json=payload)
    
    async def get_eval_names(self, training_run_ids: List[str], run_free_policy_ids: Optional[List[str]] = None):
        """Get evaluation names for selected policies."""
        url = "/heatmap/evals"
        payload = {
            "training_run_ids": training_run_ids,
            # Build a fresh list per call; a mutable default argument would be shared across calls
            "run_free_policy_ids": run_free_policy_ids or []
        }
        return await self._make_request("POST", url, json=payload)
    
    async def get_available_metrics(self, training_run_ids: List[str], run_free_policy_ids: List[str], eval_names: List[str]):
        """Get available metrics for selected policies and evaluations."""
        url = "/heatmap/metrics"
        payload = {
            "training_run_ids": training_run_ids,
            "run_free_policy_ids": run_free_policy_ids,
            "eval_names": eval_names
        }
        return await self._make_request("POST", url, json=payload)
    
    async def generate_heatmap(self, training_run_ids: List[str], run_free_policy_ids: List[str], 
                        eval_names: List[str], metric: str, policy_selector: str = "best"):
        """Generate heatmap data."""
        url = "/heatmap/heatmap"
        payload = {
            "training_run_ids": training_run_ids,
            "run_free_policy_ids": run_free_policy_ids,
            "eval_names": eval_names,
            "metric": metric,
            "training_run_policy_selector": policy_selector
        }
        return await self._make_request("POST", url, json=payload)

# Test the fixed client with better debugging
async def test_api_connection_fixed(api_base_url: str, auth_token: str):
    """Test API connection with the fixed client."""
    print("🧪 Testing fixed API client...")
    client = MettaAPIClientFixed(api_base_url, auth_token)
    
    try:
        # Test getting policies
        print("\n📋 Testing /heatmap/policies endpoint...")
        policies_response = await client.get_policies(page_size=5)
        print(f"✅ Success! Got response: {type(policies_response)}")
        if isinstance(policies_response, dict) and 'policies' in policies_response:
            print(f"📊 Found {len(policies_response['policies'])} policies")
            if policies_response['policies']:
                print(f"🔍 Sample policy: {policies_response['policies'][0]}")
        else:
            print(f"⚠️  Unexpected response structure: {policies_response}")
        return True
        
    except Exception as e:
        print(f"❌ Test failed: {e}")
        return False

# Run the test
print("🚀 Testing the fixed API client...")
test_result = await test_api_connection_fixed(
    api_base_url=api_base_url,
    auth_token=auth_token
)


🚀 Testing the fixed API client...
🧪 Testing fixed API client...

📋 Testing /heatmap/policies endpoint...
🔍 Making POST request to: https://api.observatory.softmax-research.net/heatmap/policies
🔑 Headers: {'Content-Type': 'application/json', 'X-Auth-Token': 'WTGipvHPU5RtZsN3S033H4QnWLKkOq3LtPmhes6iFQk'}
📦 Payload: {'search_text': None, 'pagination': {'page': 1, 'page_size': 5}}
📨 Response status: 200
✅ Success! Got response: <class 'dict'>
📊 Found 5 policies
🔍 Sample policy: {'id': 'a23fae54-0001-499a-8c86-a9083ad48c35', 'type': 'training_run', 'name': 'bullm.navigation.low_reward.with_context.07-24', 'user_id': 'matthew@stem.ai', 'created_at': '2025-07-25 07:01:19.783193', 'tags': ['user:unknown']}
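
The response above suggests that each entry in `policies` carries at least `id`, `type`, and `name` fields. A minimal sketch of how training run names could be mapped to IDs from such a response, assuming only those three fields; `match_training_run_ids` is a hypothetical helper, not part of the notebook utilities, and the second sample entry is made up:

```python
from typing import Any, Dict, List


def match_training_run_ids(policies: List[Dict[str, Any]], run_names: List[str]) -> List[str]:
    """Return IDs of training-run entries whose name contains any requested run name."""
    return [
        policy["id"]
        for policy in policies
        if policy.get("type") == "training_run"
        and any(run_name in policy.get("name", "") for run_name in run_names)
    ]


sample_policies = [
    # Entry copied from the response printed above
    {
        "id": "a23fae54-0001-499a-8c86-a9083ad48c35",
        "type": "training_run",
        "name": "bullm.navigation.low_reward.with_context.07-24",
    },
    # Made-up entry with an assumed non-training-run type, to show the type filter
    {"id": "hypothetical-policy-id", "type": "policy", "name": "bullm.navigation.low_reward.v1"},
]

matched_ids = match_training_run_ids(sample_policies, ["bullm.navigation.low_reward"])
print(matched_ids)
```

Because matching is by substring, a short run name selects every training run whose name contains it.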


In [48]:
import logging
from typing import Any, Dict, List, Optional

import httpx
import requests


from experiments.notebooks.utils.heatmap_widget import HeatmapWidget, create_heatmap_widget

# Simple HTTP client approach - mimics repo.ts exactly
class MettaAPIClient:
    """Simple HTTP client that mimics the TypeScript repo.ts functionality."""
    
    def __init__(self, base_url: str, auth_token: Optional[str] = None):
        """
        Initialize the API client.

        Args:
            base_url: Base URL of the API server
            auth_token: Optional authentication token
        """
        if not base_url:
            base_url = "http://localhost:8000"
        self.base_url = base_url.rstrip("/")
        if not auth_token:
            raise ValueError("auth_token is required")
        self.headers = {}
        self.headers["X-Auth-Token"] = auth_token

    async def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
        """Make an HTTP request to the API."""
        url = f"{self.base_url}{endpoint}"
        async with httpx.AsyncClient() as client:
            response = await client.request(method, url, headers=self.headers, timeout=30.0, **kwargs)
            response.raise_for_status()
            return response.json()
    
    async def get_policies(self, search_text: str = "", page_size: int = 50):
        """Get available policies and training runs."""
        url = "/heatmap/policies"
        payload = {
            "search_text": search_text,
            "pagination": {"page": 1, "page_size": page_size}
        }
        data = await self._make_request("POST", url, json=payload)
        return data
    
    async def get_eval_names(self, training_run_ids: List[str], run_free_policy_ids: Optional[List[str]] = None):
        """Get evaluation names for selected policies."""
        url = "/heatmap/evals"
        payload = {
            "training_run_ids": training_run_ids,
            # Build a fresh list per call; a mutable default argument would be shared across calls
            "run_free_policy_ids": run_free_policy_ids or []
        }
        data = await self._make_request("POST", url, json=payload)
        return data
    
    async def get_available_metrics(self, training_run_ids: List[str], run_free_policy_ids: List[str], eval_names: List[str]):
        """Get available metrics for selected policies and evaluations."""
        url = "/heatmap/metrics"
        payload = {
            "training_run_ids": training_run_ids,
            "run_free_policy_ids": run_free_policy_ids,
            "eval_names": eval_names
        }
        data = await self._make_request("POST", url, json=payload)
        return data.get("metrics", [])
    
    async def generate_heatmap(self, training_run_ids: List[str], run_free_policy_ids: List[str], 
                        eval_names: List[str], metric: str, policy_selector: str = "best"):
        """Generate heatmap data."""
        url = "/heatmap/heatmap"
        payload = {
            "training_run_ids": training_run_ids,
            "run_free_policy_ids": run_free_policy_ids,
            "eval_names": eval_names,
            "metric": metric,
            "training_run_policy_selector": policy_selector
        }
        data = await self._make_request("POST", url, json=payload)
        return data.get("heatmap", {})

async def fetch_real_heatmap_data(
    training_run_names: List[str],
    metrics: List[str],
    policy_selector: str = "best",
    api_base_url: str = "http://localhost:8000",
    auth_token: Optional[str] = None,
    max_policies: int = 20
) -> HeatmapWidget:
    """
    Fetch real evaluation data using the metta HTTP API (same as repo.ts).
    
    Args:
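        training_run_names: List of training run names (e.g., ["daveey.arena.rnd.16x4.2"]);
            a run is selected if any of these names is a substring of its name
        metrics: Candidate metrics (e.g., ["reward", "heart.get"]); only the first one
            that is actually available is plotted
        policy_selector: "best" or "latest" policy selection strategy
        api_base_url: Base URL for the metta API
        auth_token: Authentication token, sent in the X-Auth-Token header
        max_policies: Maximum number of policies to display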
        
    Returns:
        HeatmapWidget with real data
    """
    print(f"🔍 Fetching real heatmap data using HTTP API: {api_base_url}")
    print(f"📋 Training runs: {training_run_names}")
    print(f"📊 Metrics: {metrics}")
    print(f"🎯 Policy selector: {policy_selector}")
    
    try:
        client = MettaAPIClient(api_base_url, auth_token)
        
        # Step 1: Get available policies to find training run IDs
        print("🔍 Getting available policies...")
        policies_data = await client.get_policies(page_size=1000)
        
        # Find training run IDs that match our training run names
        training_run_ids = []
        for policy in policies_data["policies"]:
            if policy["type"] == "training_run" and any(run_name in policy["name"] for run_name in training_run_names):
                training_run_ids.append(policy["id"])
        
        if not training_run_ids:
            print(f"❌ No training runs found matching: {training_run_names}")
            return create_heatmap_widget()
        
        print(f"✅ Found {len(training_run_ids)} matching training runs")
        
        # Step 2: Get available evaluations for these training runs
        print("🔍 Getting evaluation names...")
        eval_names = await client.get_eval_names(training_run_ids, [])
        if not eval_names:
            print("❌ No evaluations found for selected training runs")
            return create_heatmap_widget()
        
        print(f"✅ Found {len(eval_names)} evaluations")
        
        # Step 3: Get available metrics 
        print("🔍 Getting available metrics...")
        available_metrics = await client.get_available_metrics(training_run_ids, [], eval_names)
        
        # Filter to requested metrics that actually exist
        valid_metrics = [m for m in metrics if m in available_metrics]
        if not valid_metrics:
            print(f"❌ None of the requested metrics {metrics} are available")
            print(f"💡 Available metrics: {available_metrics[:10]}...")
            return create_heatmap_widget()
        
        print(f"✅ Using metrics: {valid_metrics}")
        
        # Step 4: Generate heatmap for the first metric
        primary_metric = valid_metrics[0]
        print(f"🔍 Generating heatmap for metric: {primary_metric}")
        keys = list(eval_names.keys())
        heatmap_data = await client.generate_heatmap(
            training_run_ids, [], keys, primary_metric, policy_selector
        )
        
        # generate_heatmap() may return an empty dict, so avoid a KeyError here
        if not heatmap_data.get("policyNames"):
            print("❌ No heatmap data generated")
            return create_heatmap_widget()
        
        # Limit policies if requested
        policy_names = heatmap_data["policyNames"]
        if len(policy_names) > max_policies:
            print(f"🔢 Limiting to {max_policies} policies (found {len(policy_names)})")
            # Sort by average score and take top N
            avg_scores = heatmap_data["policyAverageScores"]
            top_policies = sorted(avg_scores.keys(), key=lambda p: avg_scores[p], reverse=True)[:max_policies]
            
            # Filter the data
            filtered_cells = {p: heatmap_data["cells"][p] for p in top_policies if p in heatmap_data["cells"]}
            heatmap_data["policyNames"] = top_policies
            heatmap_data["cells"] = filtered_cells
            heatmap_data["policyAverageScores"] = {p: avg_scores[p] for p in top_policies if p in avg_scores}
        
        print(f"📊 Final dataset: {len(heatmap_data['policyNames'])} policies × {len(heatmap_data['evalNames'])} evaluations")
        
        # Step 5: Convert to widget format
        cells = {}
        for policy_name in heatmap_data["policyNames"]:
            cells[policy_name] = {}
            for eval_name in heatmap_data["evalNames"]:
                cell = heatmap_data["cells"].get(policy_name, {}).get(eval_name, {})
                cells[policy_name][eval_name] = {
                    'metrics': {primary_metric: cell.get("value", 0.0)},
                    'replayUrl': cell.get("replayUrl"),
                    'evalName': eval_name
                }
        
        # Create widget
        widget = create_heatmap_widget()
        widget.set_multi_metric_data(
            cells=cells,
            eval_names=heatmap_data["evalNames"], 
            policy_names=heatmap_data["policyNames"],
            metrics=[primary_metric],
            selected_metric=primary_metric
        )
        
        print("✅ Successfully created heatmap widget with real data!")
        return widget
        
    except httpx.ConnectError:
        # The client uses httpx, so connection failures surface as httpx.ConnectError
        print("❌ Could not connect to metta API server")
        print(f"💡 Make sure the server is reachable at {api_base_url}")
        print("💡 A local app_backend can be started with: cd app_backend && uv run python server.py")
        return create_heatmap_widget()
    except Exception as e:
        print(f"❌ Error fetching real data: {e}")
        return create_heatmap_widget()

async def get_available_policies(api_base_url: str = "http://localhost:8000", limit: int = 50):
    """Get available policies and training runs."""
    try:
        client = MettaAPIClient(api_base_url, auth_token)
        return await client.get_policies(page_size=limit)
    except Exception as e:
        print(f"❌ Error fetching policies: {e}")
        return {"policies": []}

async def get_available_eval_names(training_run_ids: List[str], api_base_url: str = "http://localhost:8000"):
    """Get available evaluation names."""
    try:
        client = MettaAPIClient(api_base_url, auth_token)
        return await client.get_eval_names(training_run_ids, [])
    except Exception as e:
        print(f"❌ Error fetching eval names: {e}")
        return []

async def get_available_metrics(training_run_ids: List[str], eval_names: List[str], api_base_url: str = "http://localhost:8000"):
    """Get available metrics."""
    try:
        client = MettaAPIClient(api_base_url, auth_token)
        return await client.get_available_metrics(training_run_ids, [], eval_names)
    except Exception as e:
        print(f"❌ Error fetching metrics: {e}")
        return []

print("🚀 Metta HTTP API client loaded!")
print("📋 Available functions:")
print("   - fetch_real_heatmap_data(training_run_names, metrics, policy_selector)")
print("   - get_available_policies(api_base_url)")  
print("   - get_available_eval_names(training_run_ids, api_base_url)")
print("   - get_available_metrics(training_run_ids, eval_names, api_base_url)")
print("")
print("💡 This uses HTTP API calls exactly like repo.ts!")
print("🔗 Requires app_backend server running on http://localhost:8000")
print("🚀 Start with: cd app_backend && uv run python server.py")


🚀 Metta HTTP API client loaded!
📋 Available functions:
   - fetch_real_heatmap_data(training_run_names, metrics, policy_selector)
   - get_available_policies(api_base_url)
   - get_available_eval_names(training_run_ids, api_base_url)
   - get_available_metrics(training_run_ids, eval_names, api_base_url)

💡 This uses HTTP API calls exactly like repo.ts!
🔗 Requires app_backend server running on http://localhost:8000
🚀 Start with: cd app_backend && uv run python server.py
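
The policy-limiting step inside `fetch_real_heatmap_data` can be looked at in isolation. The sketch below assumes the heatmap payload has the keys that function reads (`policyNames`, `policyAverageScores`, `cells`, `evalNames`); `keep_top_policies` is a hypothetical helper and the toy data is made up:

```python
from typing import Any, Dict


def keep_top_policies(heatmap: Dict[str, Any], max_policies: int) -> Dict[str, Any]:
    """Return a copy of the heatmap restricted to the highest-scoring policies."""
    avg_scores = heatmap["policyAverageScores"]
    # Sort policy names by average score, best first, and keep the first max_policies
    top = sorted(avg_scores, key=lambda name: avg_scores[name], reverse=True)[:max_policies]
    return {
        **heatmap,
        "policyNames": top,
        "cells": {name: heatmap["cells"][name] for name in top if name in heatmap["cells"]},
        "policyAverageScores": {name: avg_scores[name] for name in top},
    }


# Made-up payload with three hypothetical policies and one hypothetical evaluation
toy_heatmap = {
    "evalNames": ["navigation/maze"],
    "policyNames": ["a:v1", "b:v1", "c:v1"],
    "policyAverageScores": {"a:v1": 0.2, "b:v1": 0.9, "c:v1": 0.5},
    "cells": {
        "a:v1": {"navigation/maze": {"value": 0.2, "replayUrl": None}},
        "b:v1": {"navigation/maze": {"value": 0.9, "replayUrl": None}},
        "c:v1": {"navigation/maze": {"value": 0.5, "replayUrl": None}},
    },
}

trimmed = keep_top_policies(toy_heatmap, max_policies=2)
print(trimmed["policyNames"])
```

Returning a new dict rather than mutating the input keeps the original payload available, for example to re-trim with a different limit.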


## Example: Using Real Data

Now let's explore what's available in the database and create a heatmap with real data:


In [51]:
# For now, let's try with some common metrics and see what we find:
try:
    # specific_runs = [
    #     "daveey.arena.rnd.16x4.2",
    #     "relh.skypilot.fff.j20.666",
    #     "bullm.navigation.low_reward.baseline",
    #     "bullm.navigation.low_reward.baseline.07-17", 
    #     "bullm.navigation.low_reward.baseline.07-23",
    #     "relh.multigpu.fff.1",
    #     "relh.skypilot.fff.j21.2",
    # ]
    
    # Common metrics that are likely to exist:
    metrics_to_fetch = ["reward", "heart.get", "ore_red.get", "action.move.success"]
    
    print("🎯 Creating heatmap with real data...")
    real_heatmap = await fetch_real_heatmap_data(
        api_base_url=api_base_url,
        auth_token=auth_token,
        training_run_names=[
            "daveey.arena.rnd.16x4.2",
            "relh.skypilot.fff.j20.666",
            "bullm.navigation.low_reward.baseline",
            "bullm.navigation.low_reward.baseline.07-17", 
            "bullm.navigation.low_reward.baseline.07-23",
            "relh.multigpu.fff.1",
            "relh.skypilot.fff.j21.2",
        ],
        metrics=metrics_to_fetch,
        max_policies=15  # Limit display to keep it manageable
    )
    
    # Display the widget explicitly; a bare expression inside try/except is not rendered
    display(real_heatmap)
    
except Exception as e:
    print(f"❌ Error fetching real data: {e}")
    print("💡 This might happen if:")
    print("   - The API base URL is incorrect or the server is unreachable")
    print("   - The auth token is missing or invalid")
    print("   - The specified metrics don't exist in the database")
    print("   - You don't have access to the database")
    print("\n🔄 Falling back to demo data...")
    
    # Fall back to demo data if real data fails
    from experiments.notebooks.utils.heatmap_widget import create_demo_heatmap
    demo_fallback = create_demo_heatmap()
    display(demo_fallback)


🎯 Creating heatmap with real data...
🔍 Fetching real heatmap data using HTTP API: https://api.observatory.softmax-research.net
📋 Training runs: ['daveey.arena.rnd.16x4.2', 'relh.skypilot.fff.j20.666', 'bullm.navigation.low_reward.baseline', 'bullm.navigation.low_reward.baseline.07-17', 'bullm.navigation.low_reward.baseline.07-23', 'relh.multigpu.fff.1', 'relh.skypilot.fff.j21.2']
📊 Metrics: ['reward', 'heart.get', 'ore_red.get', 'action.move.success']
🎯 Policy selector: best
🔍 Getting available policies...
❌ Error fetching real data: Client error '422 Unprocessable Entity' for url 'https://api.observatory.softmax-research.net/heatmap/policies'
For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/422
🚀 HeatmapWidget initialized successfully!
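
The 422 above comes from `/heatmap/policies`, the same endpoint that returned 200 in the earlier connection test. The two requests differ in `search_text` (`None` in the test, `""` here) and in `page_size` (5 in the test, 1000 here); which of these the server rejects is not confirmed in this notebook. A sketch of a payload builder that mirrors the request that succeeded, assuming blank search text should be sent as `None`; `build_policies_payload` is a hypothetical helper:

```python
from typing import Any, Dict, Optional


def build_policies_payload(
    search_text: Optional[str] = None, page: int = 1, page_size: int = 50
) -> Dict[str, Any]:
    """Build the JSON body for POST /heatmap/policies, sending blank search text as None."""
    normalized = search_text.strip() if search_text else None
    return {
        "search_text": normalized or None,
        "pagination": {"page": page, "page_size": page_size},
    }


payload = build_policies_payload("", page_size=5)
print(payload)
```

If the error persists with this payload, printing `response.text` before `raise_for_status()`, as the debugging client above does, should show which field failed validation.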


In [None]:
# Compare "best" vs "latest" policy selection for the same runs
sample_runs = [
    "daveey.arena.rnd.16x4.2",
    "bullm.navigation.low_reward.baseline",
    "relh.skypilot.fff.j20.666"
]

print("🔍 Comparing 'best' vs 'latest' policy selection strategies:")
print("=" * 60)

try:
    # Select using "latest" strategy
    print("\n📈 LATEST strategy (highest version/epoch):")
    latest_policies = select_best_policies_from_runs(
        training_runs=sample_runs,
        selector="latest"
    )
    
    print("\n🏆 BEST strategy (highest average reward):")
    best_policies = select_best_policies_from_runs(
        training_runs=sample_runs, 
        metric="reward",
        selector="best"
    )
    
    print("\n📊 COMPARISON:")
    print(f"{'Run':<35} {'Latest':<25} {'Best':<25}")
    print("-" * 85)
    
    # Create lookup dictionaries
    latest_lookup = {}
    best_lookup = {}
    
    for policy in latest_policies:
        for run in sample_runs:
            if policy.startswith(run):
                latest_lookup[run] = policy
                break
                
    for policy in best_policies:
        for run in sample_runs:
            if policy.startswith(run):
                best_lookup[run] = policy
                break
    
    for run in sample_runs:
        latest = latest_lookup.get(run, "None")
        best = best_lookup.get(run, "None")
        same = "✅" if latest == best else "❌"
        print(f"{run:<35} {latest:<25} {best:<25} {same}")
        
    print("\n💡 Key differences:")
    print("   - 'Latest' picks the most recent version (highest epoch/version number)")
    print("   - 'Best' picks the version with highest average performance across evaluations")
    print("   - They may differ when a later version performs worse than an earlier one")
    
except Exception as e:
    print(f"❌ Error comparing strategies: {e}")
    print("💡 Make sure you have access to the evaluation database")
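
The difference between the two strategies can be illustrated without any database access. This is a minimal sketch of the idea only, not the server's implementation of the selector; the `epoch` and `avg_score` field names and the sample versions are made up for illustration:

```python
from typing import Any, Dict, List


def select_policy(versions: List[Dict[str, Any]], selector: str = "best") -> str:
    """Pick one policy version from a training run using the given strategy."""
    if selector == "latest":
        # Most recent version: highest epoch
        chosen = max(versions, key=lambda v: v["epoch"])
    elif selector == "best":
        # Best-performing version: highest average score across evaluations
        chosen = max(versions, key=lambda v: v["avg_score"])
    else:
        raise ValueError(f"Unknown selector: {selector!r}")
    return chosen["name"]


# Hypothetical run where the later version performs worse than the earlier one
versions = [
    {"name": "run_a:v10", "epoch": 10, "avg_score": 0.81},
    {"name": "run_a:v20", "epoch": 20, "avg_score": 0.74},
]

latest_name = select_policy(versions, "latest")
best_name = select_policy(versions, "best")
print(latest_name, best_name)
```

Here the two strategies disagree because the score dropped between epochs 10 and 20, which is the case the comparison table is meant to surface.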


In [None]:
# Example: Create a custom heatmap with specific training runs and metrics

# Option 1: Use training run names (recommended - uses smart selection)
my_training_runs = [
    # Add your training run names here, for example:
    # "my_experiment_batch_1",
    # "my_experiment_batch_2", 
    # "baseline_run_v1"
]

# Option 2: Use exact policy URIs (if you know exactly which ones you want)
my_specific_policies = [
    # Add exact policy URIs here, for example:
    # "my_policy_name:v123",
    # "baseline_experiment:v456",
    # "new_approach:v789"
]

# Step 2: Define metrics you want to compare
my_metrics = [
    "reward",
    "heart.get",           # Example game-specific metric
    "action.move.success", # Example action success rate
    # Add more metrics as needed
]

# Step 3: Optional - filter to specific evaluations
# eval_filter = "sim_env LIKE '%maze%'"  # Only maze environments
# eval_filter = "sim_env LIKE '%combat%'"  # Only combat environments  
eval_filter = None  # No filter - include all evaluations

# Step 4: Create the heatmap
if my_training_runs:  # Use smart policy selection from training runs
    print("🎯 Creating custom heatmap with best policies from training runs...")
    
    # Select best policies from training runs
    selected_policies = select_best_policies_from_runs(
        training_runs=my_training_runs,
        eval_db_uri="wandb://stats/navigation_db",
        metric="reward",  # Metric to optimize for when selecting "best"
        selector="best"   # or "latest"
    )
    
    custom_heatmap = fetch_real_heatmap_data(
        policy_names=selected_policies,
        metrics=my_metrics,
        eval_db_uri="wandb://stats/navigation_db",
        eval_filter=eval_filter,
        max_policies=20
    )
    
    print("📊 Custom heatmap created! Try:")
    print("   - Hovering over cells to see detailed values")
    print("   - Changing metrics with: custom_heatmap.update_metric('heart.get')")
    print("   - Adjusting policies shown: custom_heatmap.set_num_policies(15)")
    
    display(custom_heatmap)
    
elif my_specific_policies:  # Use exact policy URIs
    print("🎯 Creating custom heatmap with specific policies...")
    custom_heatmap = fetch_real_heatmap_data(
        policy_names=my_specific_policies,
        metrics=my_metrics,
        eval_db_uri="wandb://stats/navigation_db",
        eval_filter=eval_filter,
        max_policies=20
    )
    
    display(custom_heatmap)
    
else:
    print("📝 To use this example:")
    print("\n🚀 RECOMMENDED: Use training run names (Option 1)")
    print("1. Add your training run names to 'my_training_runs' list above")
    print("2. The system will automatically select the best policy from each run")
    print("3. Customize the 'my_metrics' list with metrics you're interested in")
    print("4. Run this cell again")
    print("\n💡 Example training run names:")
    print("   - 'my_experiment_batch_1'")
    print("   - 'baseline_run_v2'")
    print("   - 'new_approach_test'")
    print("\n⚙️  ALTERNATIVE: Use exact policy URIs (Option 2)")
    print("1. Add exact policy URIs to 'my_specific_policies' list")
    print("2. Example: 'my_policy_name:v123', 'baseline:v456'")
    print("\n🔍 TIP: Run the exploration code in previous cells to see available options")


In [None]:
# 🗃️ AVAILABLE EVALUATION DATABASES

# Domain-specific databases (most commonly used)
available_databases = {
    "navigation_db": "wandb://stats/navigation_db",      # Navigation tasks
    "memory_db": "wandb://stats/memory_db",              # Memory tasks  
    "objectuse_db": "wandb://stats/objectuse_db",        # Object use tasks
    "nav_sequence_db": "wandb://stats/nav_sequence_db",  # Navigation sequence tasks
    # User-specific databases
    "jack_db": "wandb://stats/jack_db",                  # Jack's personal database
}

print("🗄️  Available Evaluation Databases:")
print("=" * 50)
for name, uri in available_databases.items():
    print(f"📊 {name:<20} → {uri}")

print("\n💡 Usage examples:")
print("   # For navigation analysis:")
print("   fetch_real_heatmap_data(..., eval_db_uri='wandb://stats/navigation_db')")
print("   # For memory analysis:")  
print("   fetch_real_heatmap_data(..., eval_db_uri='wandb://stats/memory_db')")

print("\n🔍 You can also use:")
print("   • Local files: './path/to/my_stats.db'")
print("   • S3 buckets: 's3://bucket/path/stats.db'")

# Quick function to check what's in each database
def explore_database(db_name: str, db_uri: str, limit: int = 5):
    """Quickly explore what's available in a database"""
    print(f"\n🔍 Exploring {db_name} ({db_uri}):")
    print("-" * 40)
    
    try:
        # Get a small sample of data
        policies = get_available_policy_names(eval_db_uri=db_uri, limit=limit)
        metrics = get_available_metrics(eval_db_uri=db_uri, limit=10)
        evals = get_available_evaluations(eval_db_uri=db_uri, limit=10)
        
        print(f"📋 Sample policies ({len(policies)}): {policies[:3]}...")
        print(f"📊 Sample metrics ({len(metrics)}): {metrics[:5]}...")  
        print(f"🏃 Sample evaluations ({len(evals)}): {evals[:5]}...")
        
    except Exception as e:
        print(f"❌ Error accessing {db_name}: {e}")
        if "wandb" in str(e).lower():
            print("💡 You may need to authenticate with wandb or check permissions")

# Uncomment to explore different databases:
# explore_database("Navigation DB", "wandb://stats/navigation_db")
# explore_database("Memory DB", "wandb://stats/memory_db")  
# explore_database("Object Use DB", "wandb://stats/objectuse_db")

print("\n📝 To explore a database, uncomment the explore_database() calls above!")
print("\n🚀 Quick start: Most users will want 'wandb://stats/navigation_db'")


In [None]:
# 🔍 Example: Comparing Policies Across Different Task Categories

# Here's how to create heatmaps from different evaluation databases:

# Example 1: Navigation tasks
print("🧭 Creating navigation heatmap...")
try:
    navigation_runs = ["daveey.arena.rnd.16x4.2", "bullm.navigation.low_reward.baseline"]
    
    nav_policies = select_best_policies_from_runs(
        training_runs=navigation_runs,
        eval_db_uri="wandb://stats/navigation_db",  # Navigation database
        selector="best",
        metric="reward"
    )
    
    nav_heatmap = fetch_real_heatmap_data(
        policy_names=nav_policies,
        metrics=["reward", "heart.get", "action.move.success"],
        eval_db_uri="wandb://stats/navigation_db",
        max_policies=5
    )
    
    print("✅ Navigation heatmap created successfully!")
    # nav_heatmap  # Uncomment to display
    
except Exception as e:
    print(f"❌ Navigation heatmap failed: {e}")

# Example 2: Memory tasks (if available)
print("\n🧠 Memory tasks would use:")
print("   eval_db_uri='wandb://stats/memory_db'")
print("   # Likely different metrics like memory.recall, sequence.accuracy, etc.")

# Example 3: Object use tasks (if available)  
print("\n🔧 Object use tasks would use:")
print("   eval_db_uri='wandb://stats/objectuse_db'")
print("   # Likely different metrics like tool.use.success, manipulation.accuracy, etc.")

print("\n💡 Pro tip: Each database specializes in different task types:")
print("   🧭 navigation_db    → spatial reasoning, pathfinding")
print("   🧠 memory_db        → recall, sequence learning") 
print("   🔧 objectuse_db     → manipulation, tool use")
print("   📚 nav_sequence_db  → sequential navigation tasks")

print("\n📊 To switch databases, just change the eval_db_uri parameter!")
print("   Example: eval_db_uri='wandb://stats/memory_db'")


In [None]:
# Example: Create a custom heatmap with specific policies and metrics

# Step 1: Define your policies of interest
my_policies = [
    # Add your policy names here, for example:
    # "my_experiment_1:v100",
    # "my_experiment_2:v200", 
    # "baseline_policy:v50"
]

# Step 2: Define metrics you want to compare
my_metrics = [
    "reward",
    "heart.get",           # Example game-specific metric
    "action.move.success", # Example action success rate
    # Add more metrics as needed
]

# Step 3: Optional - filter to specific evaluations
# eval_filter = "sim_env LIKE '%maze%'"  # Only maze environments
# eval_filter = "sim_env LIKE '%combat%'"  # Only combat environments  
eval_filter = None  # No filter - include all evaluations

# Step 4: Create the heatmap
if my_policies:  # Only run if you've specified policies
    print("🎯 Creating custom heatmap...")
    custom_heatmap = fetch_real_heatmap_data(
        policy_names=my_policies,
        metrics=my_metrics,
        eval_db_uri="wandb://stats/navigation_db",  # Adjust as needed
        eval_filter=eval_filter,
        max_policies=20
    )
    
    # Step 5: Display and interact
    print("📊 Custom heatmap created! Try:")
    print("   - Hovering over cells to see detailed values")
    print("   - Changing metrics with: custom_heatmap.update_metric('heart.get')")
    print("   - Adjusting policies shown: custom_heatmap.set_num_policies(15)")
    
    display(custom_heatmap)
else:
    print("📝 To use this example:")
    print("1. Uncomment the exploration code in the previous cell to see available policies")
    print("2. Add your policy names to the 'my_policies' list above") 
    print("3. Customize the 'my_metrics' list with metrics you're interested in")
    print("4. Run this cell again")
    print("\n💡 Example policy names might look like:")
    print("   - 'my_policy_name:v123'")
    print("   - 'baseline_experiment:v456'") 
    print("   - 'new_approach:v789'")


In [None]:
import ipywidgets as widgets
from IPython.display import display

from experiments.notebooks.utils.heatmap_widget import HeatmapWidget, create_demo_heatmap, create_heatmap_widget

# Create a demo heatmap with sample data
demo_widget = create_demo_heatmap()
display(demo_widget)

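# Width and height belong on `layout`; ButtonStyle has no such traits
w = widgets.Button(description="Click me", layout=widgets.Layout(width="200px", height="50px"))
display(w)
# print() shows only the widget's text repr, not the rendered widget
print(demo_widget)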

**Try interacting with the heatmap above:**
- Hover over cells to see detailed information
- Click a policy name (the label at the left of each row) to open that policy's WandB URL in a new tab (in the demo, these links won't work)
- Adjust the "Policies to show" input to change how many policies are displayed


## Example 2: Creating Your Own Heatmap Data

Here's how to create a heatmap with your own data:
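
A minimal sketch of assembling the nested `cells` structure from flat result records. It assumes each record is a `(policy, eval_name, metric, value, replay_url)` tuple; that record layout and the `build_cells` helper are hypothetical and not part of the widget's API. The nested layout follows the Data Format Reference at the end of this notebook.

```python
from collections import defaultdict


def build_cells(records):
    """Group flat (policy, eval, metric, value, replay_url) records into nested cells."""
    cells = defaultdict(dict)
    for policy, eval_name, metric, value, replay_url in records:
        # Create the cell on first sight, then add each metric to it
        cell = cells[policy].setdefault(
            eval_name,
            {"metrics": {}, "replayUrl": replay_url, "evalName": eval_name},
        )
        cell["metrics"][metric] = value
    return dict(cells)


# Hypothetical sample records for illustration only
records = [
    ("my_policy_v1", "task_a/level1", "custom_score", 85.2, "https://example.com/replay1.json"),
    ("my_policy_v1", "task_a/level1", "reward", 0.9, "https://example.com/replay1.json"),
    ("my_policy_v2", "task_a/level1", "custom_score", 22.5, "https://example.com/replay4.json"),
]

example_cells = build_cells(records)
example_eval_names = sorted({r[1] for r in records})
example_policy_names = sorted({r[0] for r in records})
print(example_policy_names, example_eval_names)
```

The resulting dictionary and name lists have the shape that `set_data(cells=..., eval_names=..., policy_names=...)` takes in the custom metrics example later in this notebook.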


## Example 4: Multiple Metrics with Working selectedMetric

Now let's see the `selectedMetric` functionality working properly! This example shows a heatmap where changing the metric actually changes the displayed values:


In [None]:
# Create a multi-metric heatmap widget
from experiments.notebooks.utils.heatmap_widget import create_multi_metric_demo

multi_metric_widget = create_multi_metric_demo()

# Display the widget
multi_metric_widget


In [None]:
# Now try changing the metric to see the values actually change!
print("🔄 Changing metric to 'episode_length'...")
multi_metric_widget.update_metric('episode_length')

# NOTE: The values in the heatmap widget change as you switch metrics. Do not
# display the widget again to change the metric: that creates a separate copy
# of the widget in a new output cell. Instead, reference the one you originally
# rendered, call its methods, and watch it update in its Jupyter notebook cell,
# as we just did. Let's do it again in the next cell.


In [None]:
# One more time. Run this cell then scroll back up again to see the change.
print("\n🔄 Changing metric to 'success_rate'...")
multi_metric_widget.update_metric('success_rate')


In [None]:
# Last one. Switch back to the earlier metric, then scroll up again to see the change.
print("\n🔄 Changing metric to 'episode_length'...")
multi_metric_widget.update_metric('episode_length')


## Custom Metrics

Cells can carry any metric data you want. This is useful because we plan to support many kinds of metrics. Here is an example that uses an arbitrary metric, `custom_score`:

In [None]:
# Create a new heatmap widget
custom_widget = create_heatmap_widget()

# Define your data structure
# This should match the format expected by the observatory dashboard
cells_data = {
    'my_policy_v1': {
        'task_a/level1': {
            'metrics': {
                'custom_score': 85.2,
            },
            'replayUrl': 'https://example.com/replay1.json', 
            'evalName': 'task_a/level1'
        },
        'task_a/level2': {
            'metrics': {
                'custom_score': 87.5,
            },
            'replayUrl': 'https://example.com/replay2.json', 
            'evalName': 'task_a/level2'
        },
        'task_b/challenge1': {
            'metrics': {
                'custom_score': 92.5,
            },
            'replayUrl': 'https://example.com/replay3.json', 
            'evalName': 'task_b/challenge1'
        },
    },
    'my_policy_v2': {
        'task_a/level1': {
            'metrics': {
                'custom_score': 22.5,
            },
            'replayUrl': 'https://example.com/replay4.json', 
            'evalName': 'task_a/level1'
        },
        'task_a/level2': {
            'metrics': {
                'custom_score': 42.5,
            },
            'replayUrl': 'https://example.com/replay5.json', 
            'evalName': 'task_a/level2'
        },
        'task_b/challenge1': {
            'metrics': {
                'custom_score': 62.5,
            },
            'replayUrl': 'https://example.com/replay6.json', 
            'evalName': 'task_b/challenge1'
        },
    },
}

eval_names = ['task_a/level1', 'task_a/level2', 'task_b/challenge1']
policy_names = ['my_policy_v1', 'my_policy_v2']
# Mean of custom_score across each policy's three evaluations
policy_averages = {
    'my_policy_v1': 88.4,
    'my_policy_v2': 42.5,
}

# Set the data
custom_widget.set_data(
    cells=cells_data,
    eval_names=eval_names,
    policy_names=policy_names,
    policy_average_scores=policy_averages,
    selected_metric="custom_score"
)

# Display the widget
custom_widget
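
The `policy_average_scores` values above are written by hand. A minimal sketch of deriving them from the cell data instead, assuming the average is the plain mean of one metric across a policy's evaluations (the widget and the observatory dashboard may define it differently). `average_scores` is a hypothetical helper, not part of the widget's API.

```python
def average_scores(cells, metric):
    """Return {policy: mean of `metric` over that policy's evaluations}."""
    averages = {}
    for policy, evals in cells.items():
        values = [
            cell["metrics"][metric]
            for cell in evals.values()
            if metric in cell.get("metrics", {})
        ]
        if values:  # skip policies that never report this metric
            averages[policy] = round(sum(values) / len(values), 1)
    return averages


# Trimmed copy of the sample data: only the metrics are needed here
sample_cells = {
    "my_policy_v1": {
        "task_a/level1": {"metrics": {"custom_score": 85.2}},
        "task_a/level2": {"metrics": {"custom_score": 87.5}},
        "task_b/challenge1": {"metrics": {"custom_score": 92.5}},
    },
    "my_policy_v2": {
        "task_a/level1": {"metrics": {"custom_score": 22.5}},
        "task_a/level2": {"metrics": {"custom_score": 42.5}},
        "task_b/challenge1": {"metrics": {"custom_score": 62.5}},
    },
}

sample_averages = average_scores(sample_cells, "custom_score")
print(sample_averages)
```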


In [None]:
# NOTE: print() does not work inside these callbacks; that is how Jupyter
# widgets behave. Once a cell finishes running and outputs a widget, the widget
# can no longer affect that cell's output. One simple way to get output from a
# widget callback is to write to a file (or perhaps use a thread). An example
# follows below.

# Create another widget for callback demonstration
callback_widget = create_heatmap_widget()

# Set up the same data as before
callback_widget.set_data(
    cells=cells_data,
    eval_names=eval_names,
    policy_names=policy_names,
    policy_average_scores=policy_averages,
    selected_metric="custom_score"  # must name a metric present in cells_data
)

# Define callback functions
def handle_cell_selection(cell_info):
    """Called when user hovers over a cell (not 'overall' column)."""
    # utf-8 so the emoji writes cleanly regardless of the platform default encoding
    with open("output_cell_selection.txt", "w", encoding="utf-8") as f:
        f.write(f"📍 Cell selected: {cell_info['policyUri']} on evaluation '{cell_info['evalName']}'\n")

def handle_replay_opened(replay_info):
    """Called when user clicks to open a replay."""
    with open("output_replay_opened.txt", "w", encoding="utf-8") as f:
        # End each line with a newline so the three fields do not run together
        f.write(f"🎬 Replay opened: {replay_info['replayUrl']}\n")
        f.write(f"   Policy: {replay_info['policyUri']}\n")
        f.write(f"   Evaluation: {replay_info['evalName']}\n")

# Register the callbacks
callback_widget.on_cell_selected(handle_cell_selection)
callback_widget.on_replay_opened(handle_replay_opened)

# Display the widget
callback_widget


In [None]:
# Print, then delete, the files created by the callbacks in the previous cell, if they exist
import os

for fname in ["output_cell_selection.txt", "output_replay_opened.txt"]:
    try:
        with open(fname, "r", encoding="utf-8") as f:
            print(f.read())
        os.remove(fname)
        print(f"File {fname} deleted")
    except FileNotFoundError:
        # The callback has not fired yet, so there is nothing to show
        pass


**Try interacting with the heatmap above, then run the previous cell to see the callback messages written to the output files!**
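
As an alternative to writing files, callbacks can append to a plain Python list that a later cell inspects. This is a sketch, assuming the callbacks receive dictionaries with the `policyUri`, `evalName`, and `replayUrl` keys used above; `event_log` and `make_logger` are hypothetical names, not part of the widget's API.

```python
from datetime import datetime

event_log = []  # lives in the kernel, so callbacks can append to it at any time


def make_logger(kind):
    """Return a callback that records each event payload with a timestamp."""
    def _log(info):
        event_log.append({"kind": kind, "time": datetime.now().isoformat(), **info})
    return _log


log_cell_selected = make_logger("cell_selected")
log_replay_opened = make_logger("replay_opened")

# Simulate the payloads the widget is assumed to pass to the callbacks
log_cell_selected({"policyUri": "my_policy_v1", "evalName": "task_a/level1"})
log_replay_opened({
    "policyUri": "my_policy_v1",
    "evalName": "task_a/level1",
    "replayUrl": "https://example.com/replay1.json",
})

for event in event_log:
    print(event["kind"], event["policyUri"], event["evalName"])
```

With a real widget, the loggers would be registered with `callback_widget.on_cell_selected(log_cell_selected)` and `callback_widget.on_replay_opened(log_replay_opened)`, and `event_log` would be printed from a separate cell after interacting with the heatmap.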

## Data Format Reference

The heatmap widget expects data in a specific format that matches the
observatory dashboard:

```python
cells = {
    'policy_name': {
        'eval_name': {
            'metrics': {
                'reward': 50,
                'heart.get': 98,
                'action.move.success': 5,
                'ore_red.get': 24.2,
                # ... more metrics
            },
            'replayUrl': str,         # URL to replay file
            'evalName': str,          # Should match the key
        },
        # ... more evaluations
    },
    # ... more policies
}
```
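
A quick structural check before calling `set_data` can catch mismatched keys early. The sketch below assumes that every cell needs a numeric `metrics` dictionary, a string `replayUrl`, and an `evalName` equal to its key; the widget itself may be more lenient. `validate_cells` is a hypothetical helper.

```python
from numbers import Real


def validate_cells(cells):
    """Return a list of human-readable problems found in a cells dictionary."""
    problems = []
    for policy, evals in cells.items():
        for eval_key, cell in evals.items():
            where = f"{policy} / {eval_key}"
            metrics = cell.get("metrics")
            if not isinstance(metrics, dict):
                problems.append(f"{where}: 'metrics' must be a dict")
            else:
                for name, value in metrics.items():
                    # bool is a subclass of int, so reject it explicitly
                    if isinstance(value, bool) or not isinstance(value, Real):
                        problems.append(f"{where}: metric '{name}' is not numeric")
            if cell.get("evalName") != eval_key:
                problems.append(f"{where}: 'evalName' does not match its key")
            if not isinstance(cell.get("replayUrl"), str):
                problems.append(f"{where}: 'replayUrl' must be a string")
    return problems


good_cells = {
    "policy_name": {
        "eval_name": {
            "metrics": {"reward": 50, "heart.get": 98},
            "replayUrl": "https://example.com/replay.json",
            "evalName": "eval_name",
        }
    }
}
bad_cells = {
    "policy_name": {
        "eval_name": {
            "metrics": {"reward": "high"},
            "replayUrl": "https://example.com/replay.json",
            "evalName": "other_name",
        }
    }
}

print(validate_cells(good_cells))
for problem in validate_cells(bad_cells):
    print(problem)
```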

**Important notes:**
- Evaluation names with "/" will be grouped by category (the part before "/")
- The heatmap shows policies sorted by average score (worst to best, bottom to top)
- Policy names that contain ":v" will have WandB URLs generated automatically
- Replay URLs should be accessible URLs or file paths
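
The grouping and ordering rules in the notes above can be sketched in plain Python. This is an illustration only, not the widget's actual implementation: `group_by_category` and `sort_policies` are hypothetical helpers, and treating a name without "/" as its own category is an assumption.

```python
from collections import defaultdict


def group_by_category(eval_names):
    """Group evaluation names by the part before the first '/'."""
    groups = defaultdict(list)
    for name in eval_names:
        category = name.split("/", 1)[0]  # a name without '/' is its own category
        groups[category].append(name)
    return dict(groups)


def sort_policies(policy_average_scores):
    """List policies top to bottom, i.e. from best to worst average score."""
    return sorted(policy_average_scores, key=policy_average_scores.get, reverse=True)


def has_version_suffix(policy_name):
    """True for names such as 'my_experiment_1:v100' that contain ':v'."""
    return ":v" in policy_name


grouped = group_by_category(["task_a/level1", "task_a/level2", "task_b/challenge1"])
ordered = sort_policies({"baseline_policy:v50": 42.5, "my_experiment_1:v100": 88.4})
print(grouped)
print(ordered)
```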

This widget provides the same interactive functionality as the observatory dashboard but in a Python environment, making it well suited to exploratory analysis and sharing results via Jupyter notebooks!
