# Interactive Heatmap Widget Example

This notebook demonstrates how to use the `HeatmapWidget`, an anywidget-based implementation of the observatory heatmap component for Jupyter notebooks.

The widget provides interactive policy evaluation heatmaps with:
- Hover effects showing detailed information
- Double-click to open replay URLs
- Dynamic control over the number of policies displayed
- Automatic organization by evaluation categories


## Installation

First, make sure you have the required dependencies (`httpx` is used by the API client further down):
`pip install anywidget traitlets httpx`


## Import and Basic Setup


In [155]:
%load_ext autoreload
%autoreload 2
import os
import sys
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from experiments.notebooks.utils.metrics import fetch_metrics, find_training_jobs
from experiments.notebooks.utils.monitoring import monitor_training_statuses
from experiments.notebooks.utils.replays import show_replay
from experiments.notebooks.utils.training import launch_training

%matplotlib inline
plt.style.use("default")

%load_ext anywidget

print("Setup complete! Auto-reload enabled.")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The anywidget extension is already loaded. To reload it, use:
  %reload_ext anywidget
Setup complete! Auto-reload enabled.


## Example 1: Demo Heatmap with Sample Data

Let's start with a simple demo that includes sample data:


In [156]:
from experiments.notebooks.utils.heatmap_widget import HeatmapWidget, create_demo_heatmap, create_heatmap_widget
from IPython.display import display

# Create a demo heatmap with sample data
demo_widget = create_demo_heatmap()

# Display the widget
demo_widget

🎯 Creating demo heatmap widget...
🚀 HeatmapWidget initialized successfully!
📊 Data set with 3 policies and 4 evaluations
📈 Selected metric: reward
✅ Demo heatmap widget created with sample data!


HeatmapWidget(heatmap_data={'cells': {'policy_alpha_v1': {'navigation/maze1': {'value': 85.2, 'replayUrl': 'sa…

In [157]:
api_base_url = "https://api.observatory.softmax-research.net"
# Auth token for this API, from ~/.metta/observatory_tokens.yaml
auth_token = ""


In [158]:
# 🔧 Fixed API Client Implementation
import asyncio
import httpx
from typing import Dict, List, Optional, Any

class MettaAPIClient:
    """Fixed client that properly handles authentication and response parsing."""
    
    def __init__(self, base_url: str, auth_token: Optional[str] = None):
        print(base_url)
        self.base_url = base_url.rstrip("/")
        self.headers = {"Content-Type": "application/json"}
        if auth_token:
            # Use X-Auth-Token header format like extract_training_rewards.py
            self.headers["X-Auth-Token"] = auth_token

    async def _make_request(self, method: str, endpoint: str, **kwargs):
        """Make an HTTP request to the API."""
        url = f"{self.base_url}{endpoint}"
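        print(f"🔍 Making {method} request to: {url}")
        # Mask the auth token so it does not end up in saved notebook output
        safe_headers = {k: ("***" if k == "X-Auth-Token" and v else v) for k, v in self.headers.items()}
        print(f"🔑 Headers: {safe_headers}")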
        if 'json' in kwargs:
            print(f"📦 Payload: {kwargs['json']}")
            
        async with httpx.AsyncClient() as client:
            response = await client.request(method, url, headers=self.headers, timeout=30.0, **kwargs)
            print(f"📨 Response status: {response.status_code}")
            if response.status_code >= 400:
                print(f"❌ Response body: {response.text}")
            response.raise_for_status()
            return response.json()
    
    async def get_policies(self, search_text: Optional[str] = None, page_size: int = 50):
        """Get available policies and training runs."""
        url = "/heatmap/policies"
        payload = {
            "search_text": search_text,  # Use None instead of empty string
            "pagination": {"page": 1, "page_size": page_size}
        }
        return await self._make_request("POST", url, json=payload)
    
    async def get_eval_names(self, training_run_ids: List[str], run_free_policy_ids: Optional[List[str]] = None):
        """Get evaluation names for selected policies."""
        url = "/heatmap/evals"
        payload = {
            "training_run_ids": training_run_ids,
            # Avoid a shared mutable default argument
            "run_free_policy_ids": run_free_policy_ids or []
        }
        return await self._make_request("POST", url, json=payload)
    
    async def get_available_metrics(self, training_run_ids: List[str], run_free_policy_ids: List[str], eval_names: List[str]):
        """Get available metrics for selected policies and evaluations."""
        url = "/heatmap/metrics"
        payload = {
            "training_run_ids": training_run_ids,
            "run_free_policy_ids": run_free_policy_ids,
            "eval_names": eval_names
        }
        return await self._make_request("POST", url, json=payload)
    
    async def generate_heatmap(self, training_run_ids: List[str], run_free_policy_ids: List[str],
                               eval_names: List[str], metric: str, policy_selector: str = "best"):
        """Generate heatmap data."""
        url = "/heatmap/heatmap"
        payload = {
            "training_run_ids": training_run_ids,
            "run_free_policy_ids": run_free_policy_ids,
            "eval_names": eval_names,
            "metric": metric,
            "training_run_policy_selector": policy_selector
        }
        return await self._make_request("POST", url, json=payload)

    async def get_all_training_runs(self, search_text: Optional[str] = None, page_size: int = 100):
        """Get the first page of policies and training runs (same endpoint as get_policies)."""
        return await self.get_policies(search_text=search_text, page_size=page_size)

# Test the fixed client with better debugging
async def test_api_connection_fixed(api_base_url: str, auth_token: str):
    """Test API connection with the fixed client."""
    print(f"🧪 Testing fixed API client...")
    client = MettaAPIClient(api_base_url, auth_token)
    
    try:
        # Test getting policies
        print("\n📋 Testing /heatmap/policies endpoint...")
        policies_response = await client.get_policies(page_size=5)
        print(f"✅ Success! Got response: {type(policies_response)}")
        if isinstance(policies_response, dict) and 'policies' in policies_response:
            print(f"📊 Found {len(policies_response['policies'])} policies")
            if policies_response['policies']:
                print(f"🔍 Sample policy: {policies_response['policies'][0]}")
        else:
            print(f"⚠️  Unexpected response structure: {policies_response}")
        return True
        
    except Exception as e:
        print(f"❌ Test failed: {e}")
        return False

# Run the test
print("🚀 Testing the fixed API client...")
test_result = await test_api_connection_fixed(
    api_base_url=api_base_url,
    auth_token=auth_token
)


🚀 Testing the fixed API client...
🧪 Testing fixed API client...
https://api.observatory.softmax-research.net

📋 Testing /heatmap/policies endpoint...
🔍 Making POST request to: https://api.observatory.softmax-research.net/heatmap/policies
🔑 Headers: {'Content-Type': 'application/json', 'X-Auth-Token': ''}
📦 Payload: {'search_text': None, 'pagination': {'page': 1, 'page_size': 5}}
📨 Response status: 200
✅ Success! Got response: <class 'dict'>
📊 Found 5 policies
🔍 Sample policy: {'id': 'db04a711-77fb-4c4d-bee6-bfac77356980', 'type': 'training_run', 'name': 'jack.research.0724_1312', 'user_id': 'jack@stem.ai', 'created_at': '2025-07-25 09:36:30.650867', 'tags': ['research', 'experiment', 'user:unknown']}
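All of the client's endpoints take a JSON body sent with POST. To check the `/heatmap/heatmap` request shape without a network call, here is a sketch of a hypothetical helper, `build_heatmap_payload`, that assembles the same keys `generate_heatmap` sends. It rejects selectors other than `"best"` and `"latest"`, the two named in `fetch_real_heatmap_data`'s docstring; whether the server accepts any other values is not covered here.

```python
from typing import Any, Dict, List, Optional

# The two selector values named in fetch_real_heatmap_data's docstring
VALID_SELECTORS = {"best", "latest"}


def build_heatmap_payload(
    training_run_ids: List[str],
    eval_names: List[str],
    metric: str,
    policy_selector: str = "best",
    run_free_policy_ids: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Assemble the JSON body that generate_heatmap POSTs to /heatmap/heatmap."""
    if policy_selector not in VALID_SELECTORS:
        raise ValueError(
            f"policy_selector must be one of {sorted(VALID_SELECTORS)}, got {policy_selector!r}"
        )
    return {
        "training_run_ids": list(training_run_ids),
        # Build a fresh list on every call instead of sharing a mutable default
        "run_free_policy_ids": list(run_free_policy_ids or []),
        "eval_names": list(eval_names),
        "metric": metric,
        "training_run_policy_selector": policy_selector,
    }


payload = build_heatmap_payload(["run-id-1"], ["arena/basic", "arena/combat"], "reward")
print(payload["training_run_policy_selector"])  # best
```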


In [159]:
async def fetch_real_heatmap_data(
    training_run_names: List[str],
    metrics: List[str],
    policy_selector: str = "best",
    api_base_url: str = "http://localhost:8000",
    auth_token: Optional[str] = None,
    max_policies: int = 20
) -> HeatmapWidget:
    """
    Fetch real evaluation data using the metta HTTP API (same as repo.ts).
    
    Args:
        training_run_names: List of training run names (e.g., ["daveey.arena.rnd.16x4.2"])
        metrics: List of metrics to include (e.g., ["reward", "heart.get"])
        policy_selector: "best" or "latest" policy selection strategy
        api_base_url: Base URL for the metta API from ~/.metta/observatory_tokens.yaml
        auth_token: Auth token from ~/.metta/observatory_tokens.yaml
        max_policies: Maximum number of policies to display
        
    Returns:
        HeatmapWidget with real data
    """
    try:
        client = MettaAPIClient(api_base_url, auth_token)
        
        # Step 1: Get available policies to find training run IDs
        policies_data = await client.get_policies(page_size=100)
        
        # Find training run IDs that match our training run names
        training_run_ids = []
        for policy in policies_data["policies"]:
            if policy["type"] == "training_run" and any(run_name in policy["name"] for run_name in training_run_names):
                training_run_ids.append(policy["id"])
        
        if not training_run_ids:
            print(f"❌ No training runs found matching: {training_run_names}")
            return create_heatmap_widget()
        
        # Step 2: Get available evaluations for these training runs
        eval_names = await client.get_eval_names(training_run_ids, [])
        if not eval_names:
            print("❌ No evaluations found for selected training runs")
            return create_heatmap_widget()
        
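        # Step 3: Get available metrics. Note: the get_available_metrics helper below builds
        # its own client from the notebook-level `auth_token`, not this function's argument.
        available_metrics = await get_available_metrics(training_run_ids, eval_names, api_base_url)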
        if not available_metrics:
            print("❌ No metrics found")
            return create_heatmap_widget()
        
        # Filter to requested metrics that actually exist
        valid_metrics = [m for m in metrics if m in available_metrics]
        if not valid_metrics:
            print(f"❌ None of the requested metrics {metrics} are available")
            return create_heatmap_widget()
        
        # Step 4: Generate heatmap for the first metric
        primary_metric = valid_metrics[0]
        heatmap_data = await client.generate_heatmap(
            training_run_ids, [], eval_names, primary_metric, policy_selector
        )
        
        if not heatmap_data["policyNames"]:
            print("❌ No heatmap data generated")
            return create_heatmap_widget()
        
        # Limit policies if requested
        policy_names = heatmap_data["policyNames"]
        if len(policy_names) > max_policies:
            # Sort by average score and take top N
            avg_scores = heatmap_data["policyAverageScores"]
            top_policies = sorted(avg_scores.keys(), key=lambda p: avg_scores[p], reverse=True)[:max_policies]
            
            # Filter the data
            filtered_cells = {p: heatmap_data["cells"][p] for p in top_policies if p in heatmap_data["cells"]}
            heatmap_data["policyNames"] = top_policies
            heatmap_data["cells"] = filtered_cells
            heatmap_data["policyAverageScores"] = {p: avg_scores[p] for p in top_policies if p in avg_scores}
        
        # Step 5: Convert to widget format
        cells = {}
        for policy_name in heatmap_data["policyNames"]:
            cells[policy_name] = {}
            for eval_name in heatmap_data["evalNames"]:
                cell = heatmap_data["cells"].get(policy_name, {}).get(eval_name, {})
                cells[policy_name][eval_name] = {
                    'metrics': {primary_metric: cell.get("value", 0.0)},
                    'replayUrl': cell.get("replayUrl"),
                    'evalName': eval_name
                }
        
        # Create widget
        widget = create_heatmap_widget()
        widget.set_multi_metric_data(
            cells=cells,
            eval_names=heatmap_data["evalNames"], 
            policy_names=heatmap_data["policyNames"],
            metrics=[primary_metric],
            selected_metric=primary_metric
        )
        
        return widget
        
    except httpx.ConnectError:
        print("❌ Could not connect to metta API server")
        print("💡 Check ~/.metta/observatory_tokens.yaml for the correct API base URL and auth token")
        print("💡 Check if app_backend server is running on http://localhost:8000")
        print("💡 You can start it with: cd app_backend && uv run python server.py")
        return create_heatmap_widget()
    except Exception as e:
        print(f"❌ Error fetching real data: {e}")
        return create_heatmap_widget()

async def get_available_policies(api_base_url: str = "http://localhost:8000", limit: int = 50):
    """Get available policies and training runs (authenticates with the notebook-level `auth_token`)."""
    try:
        client = MettaAPIClient(api_base_url, auth_token)
        return await client.get_policies(page_size=limit)
    except Exception as e:
        print(f"❌ Error fetching policies: {e}")
        return {"policies": []}

async def get_available_eval_names(training_run_ids: List[str], api_base_url: str = "http://localhost:8000"):
    """Get available evaluation names (authenticates with the notebook-level `auth_token`)."""
    try:
        client = MettaAPIClient(api_base_url, auth_token)
        return await client.get_eval_names(training_run_ids, [])
    except Exception as e:
        print(f"❌ Error fetching eval names: {e}")
        return []

async def get_available_metrics(training_run_ids: List[str], eval_names: List[str], api_base_url: str = "http://localhost:8000"):
    """Get available metrics (authenticates with the notebook-level `auth_token`)."""
    try:
        client = MettaAPIClient(api_base_url, auth_token)
        
        # Debug: Let's try with a simplified version first
        print(f"🔍 Attempting to get metrics for {len(eval_names)} eval names...")
        print(f"📋 Eval names: {eval_names}")
        
        return await client.get_available_metrics(training_run_ids, [], eval_names)
        
    except Exception as e:
        print(f"❌ Error fetching metrics: {e}")
        print(f"📋 Eval names that caused the error: {eval_names}")
        print(f"🔧 This might be a server-side issue with processing eval_names list")
        
        # Let's try with a single eval name to see if it's a list processing issue
        if eval_names:
            try:
                print(f"🔄 Trying with just the first eval name: {eval_names[0]}")
                result = await client.get_available_metrics(training_run_ids, [], [eval_names[0]])
                print(f"✅ Single eval name worked! Got {len(result)} metrics")
                return result
            except Exception as e2:
                print(f"❌ Single eval name also failed: {e2}")
        
        print(f"💡 Falling back to common metrics...")
        # Return some common metrics as fallback
        return ["reward", "heart.get", "ore_red.get", "action.move.success"]

print("🚀 Metta HTTP API client loaded!")
print("📋 Available functions:")
print("   - fetch_real_heatmap_data(training_run_names, metrics, policy_selector)")
print("   - get_available_policies(api_base_url)")  
print("   - get_available_eval_names(training_run_ids, api_base_url)")
print("   - get_available_metrics(training_run_ids, eval_names, api_base_url)")
print("")
print("💡 This uses HTTP API calls exactly like repo.ts!")
print("🔗 Local development needs app_backend server on http://localhost:8000 and an auth token from ~/.metta/observatory_tokens.yaml")
print("🚀 Start with: cd app_backend && uv run python server.py")


🚀 Metta HTTP API client loaded!
📋 Available functions:
   - fetch_real_heatmap_data(training_run_names, metrics, policy_selector)
   - get_available_policies(api_base_url)
   - get_available_eval_names(training_run_ids, api_base_url)
   - get_available_metrics(training_run_ids, eval_names, api_base_url)

💡 This uses HTTP API calls exactly like repo.ts!
🔗 Local development needs app_backend server on http://localhost:8000 and an auth token from ~/.metta/observatory_tokens.yaml
🚀 Start with: cd app_backend && uv run python server.py
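Steps 4 and 5 of `fetch_real_heatmap_data` are pure data reshaping, so they can be pulled out and checked offline. The sketch below reimplements them as two hypothetical helpers, `keep_top_policies` (which returns a new dict instead of mutating the response) and `to_widget_cells`. The sample response is made up; it only uses the keys the notebook code reads (`policyNames`, `evalNames`, `cells`, `policyAverageScores`).

```python
from typing import Any, Dict

WidgetCells = Dict[str, Dict[str, Dict[str, Any]]]


def keep_top_policies(heatmap_data: Dict[str, Any], max_policies: int) -> Dict[str, Any]:
    """Keep the max_policies policies with the highest average score (step 4)."""
    if len(heatmap_data["policyNames"]) <= max_policies:
        return heatmap_data
    avg_scores = heatmap_data["policyAverageScores"]
    top = sorted(avg_scores, key=avg_scores.get, reverse=True)[:max_policies]
    return {
        **heatmap_data,
        "policyNames": top,
        "cells": {p: heatmap_data["cells"][p] for p in top if p in heatmap_data["cells"]},
        "policyAverageScores": {p: avg_scores[p] for p in top},
    }


def to_widget_cells(heatmap_data: Dict[str, Any], metric: str) -> WidgetCells:
    """Reshape API cells ({'value', 'replayUrl'}) into the widget's format (step 5)."""
    cells: WidgetCells = {}
    for policy in heatmap_data["policyNames"]:
        row = heatmap_data["cells"].get(policy, {})
        cells[policy] = {}
        for eval_name in heatmap_data["evalNames"]:
            cell = row.get(eval_name, {})
            cells[policy][eval_name] = {
                "metrics": {metric: cell.get("value", 0.0)},  # missing cell -> 0.0
                "replayUrl": cell.get("replayUrl"),
                "evalName": eval_name,
            }
    return cells


# Made-up response with the same keys the notebook reads
sample = {
    "policyNames": ["run_a:v1", "run_b:v2", "run_c:v3"],
    "evalNames": ["arena/basic", "arena/combat"],
    "policyAverageScores": {"run_a:v1": 0.4, "run_b:v2": 0.9, "run_c:v3": 0.7},
    "cells": {
        "run_a:v1": {"arena/basic": {"value": 0.5, "replayUrl": None}},
        "run_b:v2": {
            "arena/basic": {"value": 1.0, "replayUrl": "https://example.com/b.json"},
            "arena/combat": {"value": 0.8, "replayUrl": None},
        },
        "run_c:v3": {"arena/combat": {"value": 0.7, "replayUrl": None}},
    },
}
trimmed = keep_top_policies(sample, max_policies=2)
print(trimmed["policyNames"])  # ['run_b:v2', 'run_c:v3']
cells = to_widget_cells(trimmed, "reward")
print(cells["run_c:v3"]["arena/basic"])
# {'metrics': {'reward': 0.0}, 'replayUrl': None, 'evalName': 'arena/basic'}
```

Like the notebook code, a missing cell becomes `0.0`, which looks the same as a genuine zero score. Emitting `None` instead would keep the two apart, assuming the widget can render an empty cell (not verified here).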


## Example 2: Using Real Data

Now let's see what's available from the observatory API and create a heatmap with real data:


In [160]:
# For now, let's try some common metrics and see what we find:
real_heatmap = None
try:
    # specific_runs = [
    #     "daveey.arena.rnd.16x4.2",
    #     "relh.skypilot.fff.j20.666",
    #     "bullm.navigation.low_reward.baseline",
    #     "bullm.navigation.low_reward.baseline.07-17",
    #     "bullm.navigation.low_reward.baseline.07-23",
    #     "relh.multigpu.fff.1",
    #     "relh.skypilot.fff.j21.2",
    # ]

    # Common metrics that are likely to exist:
    metrics_to_fetch = ["reward", "heart.get", "ore_red.get", "action.move.success"]

    print(api_base_url)
    client = MettaAPIClient(api_base_url, auth_token)
    runs = await client.get_all_training_runs()
    run_names = [policy["name"] for policy in runs["policies"]]

    real_heatmap = await fetch_real_heatmap_data(
        api_base_url=api_base_url,
        auth_token=auth_token,
        training_run_names=run_names,
        metrics=metrics_to_fetch,
        max_policies=50  # Limit display to keep it manageable
    )

except Exception as e:
    print(f"❌ Error fetching real data: {e}")
    print("💡 This might happen if:")
    print("   - The API base URL is incorrect or the server is unreachable")
    print("   - Your auth token is missing or invalid")
    print("   - You don't have access to the requested data")
    print("\n🔄 Falling back to demo data...")

    # Fall back to demo data if real data fails
    real_heatmap = create_demo_heatmap()

# Jupyter only renders a cell's last expression, so display the widget here
real_heatmap

https://api.observatory.softmax-research.net
https://api.observatory.softmax-research.net
🔍 Making POST request to: https://api.observatory.softmax-research.net/heatmap/policies
🔑 Headers: {'Content-Type': 'application/json', 'X-Auth-Token': ''}
📦 Payload: {'search_text': None, 'pagination': {'page': 1, 'page_size': 100}}
📨 Response status: 200
https://api.observatory.softmax-research.net
🔍 Making POST request to: https://api.observatory.softmax-research.net/heatmap/policies
🔑 Headers: {'Content-Type': 'application/json', 'X-Auth-Token': ''}
📦 Payload: {'search_text': None, 'pagination': {'page': 1, 'page_size': 100}}
📨 Response status: 200
🔍 Making POST request to: https://api.observatory.softmax-research.net/heatmap/evals
🔑 Headers: {'Content-Type': 'application/json', 'X-Auth-Token': ''}
📦 Payload: {'training_run_ids': ['67bfefee-82ca-415d-9f00-7357640107d5', 'db04a711-77fb-4c4d-bee6-bfac77356980', 'fe6d7ecb-3f13-4cc9-9955-5af1341d2e2d', '807c8e5b-f112-49c5-b344-9e97cb4e3d91', 'db92

HeatmapWidget(heatmap_data={'cells': {'yudhister.recipes.arena.2x4.efficiency_baseline.07-24-00-17:v19': {'are…
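`fetch_real_heatmap_data` turns run names into IDs by keeping `training_run` entries whose name contains any requested name as a substring. The sketch below isolates that matching as a hypothetical helper, `match_training_run_ids`, with an optional exact mode. The sample entries follow the shape of the sample policy printed earlier; the `"policy"` type value on the third entry is made up to show the type filter.

```python
from typing import Any, Dict, Iterable, List


def match_training_run_ids(
    policies: Iterable[Dict[str, Any]], run_names: List[str], exact: bool = False
) -> List[str]:
    """Return IDs of training runs whose name matches any requested run name."""
    def matches(name: str) -> bool:
        if exact:
            return name in run_names
        # Substring match, as in fetch_real_heatmap_data
        return any(run_name in name for run_name in run_names)

    return [p["id"] for p in policies if p.get("type") == "training_run" and matches(p["name"])]


policies = [
    {"id": "id-1", "type": "training_run", "name": "relh.skypilot.fff.j20.666"},
    {"id": "id-2", "type": "training_run", "name": "relh.multigpu.fff.1"},
    {"id": "id-3", "type": "policy", "name": "relh.multigpu.fff.1:v12"},  # made-up type value
    {"id": "id-4", "type": "training_run", "name": "relh.multigpu.fff.10"},
]
print(match_training_run_ids(policies, ["relh.multigpu.fff.1"]))  # ['id-2', 'id-4']
print(match_training_run_ids(policies, ["relh.multigpu.fff.1"], exact=True))  # ['id-2']
```

Substring matching is forgiving but can over-match: `relh.multigpu.fff.1` also matches `relh.multigpu.fff.10`. Note too that the notebook only looks at the first page of results (`page_size=100`).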

## Advanced Usage: Custom Training Runs and Metrics

Here's how to create a heatmap from your own training runs and metrics, using `policy_selector="best"` to pick the best policy from each run:


In [161]:
# Example: Create a custom heatmap with specific training runs and metrics

# Step 1, option A: Use training run names (recommended; a policy is selected from each run)
my_training_runs = [
    # Add your training run names here, for example:
    "relh.multigpu.fff.1",
    "relh.skypilot.fff.j21.2",
    "relh.skypilot.fff.j20.666",
]

# Step 1, option B: Use exact policy URIs (if you know exactly which ones you want).
# Note: fetch_real_heatmap_data does not read this list yet; it only matches training run names.
my_specific_policies = [
    # Add exact policy URIs here, for example:
    # "my_policy_name:v123",
    # "baseline_experiment:v456",
    # "new_approach:v789"
]

# Step 2: Define metrics you want to compare
my_metrics = [
    "reward",
    "heart.get",           # Example game-specific metric
    "action.move.success", # Example action success rate
    # Add more metrics as needed
]

# Step 3: Optional - filter to specific evaluations
# Note: fetch_real_heatmap_data has no evaluation-filter argument, so eval_filter is not applied yet.
# eval_filter = "sim_env LIKE '%maze%'"  # Only maze environments
# eval_filter = "sim_env LIKE '%combat%'"  # Only combat environments
eval_filter = None  # No filter - include all evaluations

# Step 4: Create the heatmap
custom_heatmap = None
if my_training_runs:  # Select a policy from each listed training run
    print("🎯 Creating custom heatmap with best policies from training runs...")

    # Select best policies from training runs
    custom_heatmap = await fetch_real_heatmap_data(
        api_base_url=api_base_url,
        auth_token=auth_token,
        training_run_names=my_training_runs,
        metrics=my_metrics,
        policy_selector="best",
        max_policies=20
    )

    # fetch_real_heatmap_data loads only the first available metric, so there is no
    # other metric for update_metric() to switch to on this widget.
    print("📊 Custom heatmap created! Try:")
    print("   - Hovering over cells to see detailed values")
    print("   - Adjusting policies shown: custom_heatmap.set_num_policies(15)")

else:
    print("📝 To use this example:")
    print("   - Add your training run names to the 'my_training_runs' list above")

# Display the widget (Jupyter only renders a cell's last expression)
custom_heatmap

🎯 Creating custom heatmap with best policies from training runs...
https://api.observatory.softmax-research.net
🔍 Making POST request to: https://api.observatory.softmax-research.net/heatmap/policies
🔑 Headers: {'Content-Type': 'application/json', 'X-Auth-Token': ''}
📦 Payload: {'search_text': None, 'pagination': {'page': 1, 'page_size': 100}}
📨 Response status: 200
🔍 Making POST request to: https://api.observatory.softmax-research.net/heatmap/evals
🔑 Headers: {'Content-Type': 'application/json', 'X-Auth-Token': ''}
📦 Payload: {'training_run_ids': ['bf7df3be-f26f-43b7-be8e-6ea22998f120', '5c94c301-15f1-4983-be37-e9ef7d3aee33'], 'run_free_policy_ids': []}
📨 Response status: 200
https://api.observatory.softmax-research.net
🔍 Attempting to get metrics for 6 eval names...
📋 Eval names: ['arena/advanced', 'arena/advanced_poor', 'arena/basic', 'arena/combat', 'arena/tag', 'eval/training_task']
🔍 Making POST request to: https://api.observatory.softmax-research.net/heatmap/metrics
🔑 Headers: {

## Available Evaluation Database URIs

You can choose from several different evaluation databases depending on what type of data you want to analyze:


## Advanced Usage: Custom Policies and Metrics

Here's how to create a heatmap with specific policies and metrics of your choice:


In [162]:
# Example: Create a custom heatmap with specific policies and metrics

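# Step 1: Define your policies of interest
# Note: fetch_real_heatmap_data matches these strings against training run names
# (substring match), so a policy URI with a ":vNNN" suffix will not match a run.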
my_policies = [
    # Add your policy names here, for example:
    # "my_experiment_1:v100",
    # "my_experiment_2:v200", 
    # "baseline_policy:v50"
]

# Step 2: Define metrics you want to compare
my_metrics = [
    "reward",
    "heart.get",           # Example game-specific metric
    "action.move.success", # Example action success rate
    # Add more metrics as needed
]

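# Step 3: Optional - filter to specific evaluations
# Note: fetch_real_heatmap_data has no evaluation-filter argument, so eval_filter is not applied yet.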
# eval_filter = "sim_env LIKE '%maze%'"  # Only maze environments
# eval_filter = "sim_env LIKE '%combat%'"  # Only combat environments  
eval_filter = None  # No filter - include all evaluations

# Step 4: Create the heatmap
custom_heatmap = None
if my_policies:  # Only run if you've specified policies
    print("🎯 Creating custom heatmap...")
    # fetch_real_heatmap_data is async, so await it to get the widget
    custom_heatmap = await fetch_real_heatmap_data(
        api_base_url=api_base_url,
        auth_token=auth_token,
        training_run_names=my_policies,
        metrics=my_metrics,
        max_policies=20
    )

    # Step 5: Display and interact
    # fetch_real_heatmap_data loads only the first available metric, so there is no
    # other metric for update_metric() to switch to on this widget.
    print("📊 Custom heatmap created! Try:")
    print("   - Hovering over cells to see detailed values")
    print("   - Adjusting policies shown: custom_heatmap.set_num_policies(15)")

else:
    print("📝 To use this example:")
    print("1. List available policies and training runs with: await get_available_policies(api_base_url)")
    print("2. Add your policy names to the 'my_policies' list above")
    print("3. Customize the 'my_metrics' list with metrics you're interested in")
    print("4. Run this cell again")
    print("\n💡 Example policy names might look like:")
    print("   - 'my_policy_name:v123'")
    print("   - 'baseline_experiment:v456'")
    print("   - 'new_approach:v789'")

custom_heatmap

📝 To use this example:
1. List available policies and training runs with: await get_available_policies(api_base_url)
2. Add your policy names to the 'my_policies' list above
3. Customize the 'my_metrics' list with metrics you're interested in
4. Run this cell again

💡 Example policy names might look like:
   - 'my_policy_name:v123'
   - 'baseline_experiment:v456'
   - 'new_approach:v789'


## Example 3: Multiple Metrics with a Working `selectedMetric`

Now let's see `selectedMetric` in action. Each cell in this heatmap stores several metrics, so changing the selected metric changes the displayed values:


In [163]:
# Create a multi-metric heatmap widget
from experiments.notebooks.utils.heatmap_widget import create_multi_metric_demo

multi_metric_widget = create_multi_metric_demo()

# Display the widget
multi_metric_widget


🎯 Creating multi-metric demo heatmap widget...
🚀 HeatmapWidget initialized successfully!
📊 Multi-metric data set with 3 policies and 4 evaluations
📈 Available metrics: reward, episode_length, success_rate, completion_time
📈 Selected metric: reward
✅ Multi-metric demo heatmap widget created!
📈 Try widget.update_metric('episode_length') to see values change!


HeatmapWidget(heatmap_data={'cells': {'policy_alpha_v1': {'navigation/maze1': {'metrics': {'reward': 85.2, 'ep…

In [164]:
# Now try changing the metric to see the values actually change!
print("🔄 Changing metric to 'episode_length'...")
multi_metric_widget.update_metric('episode_length')

# NOTE: Notice how the values in the widget rendered above change as you switch
# metrics? Don't display the widget again to see the change; that just renders
# another view of it in a new output cell. Instead, keep a reference to the
# widget you originally rendered, call its methods, and watch it update in its
# original Jupyter cell, as we just did. Let's do it again in the next cell.


🔄 Changing metric to 'episode_length'...
📊 Metric changed to: episode_length


In [165]:
# One more time. Run this cell, then scroll back up to see the change.
print("\n🔄 Changing metric to 'success_rate'...")
multi_metric_widget.update_metric('success_rate')



🔄 Changing metric to 'success_rate'...
📊 Metric changed to: success_rate


In [166]:
# Last one: 'success_rate' is already selected, so this call is a no-op (see the output below).
print("\n🔄 Changing metric to 'success_rate'...")
multi_metric_widget.update_metric('success_rate')



🔄 Changing metric to 'success_rate'...
📈 Metric already set to: success_rate
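A rough mental model of what `update_metric` changes, assuming the widget simply reads the selected key from each cell (its front-end code is not shown in this notebook): every cell carries a `metrics` dict, and switching metrics means reading a different key from each one. The hypothetical `metric_grid` helper below does that in plain Python; the values are illustrative.

```python
from typing import Any, Dict, Optional

Cells = Dict[str, Dict[str, Dict[str, Any]]]


def metric_grid(cells: Cells, metric: str) -> Dict[str, Dict[str, Optional[float]]]:
    """Read one metric out of every cell; None where a cell lacks that metric."""
    return {
        policy: {
            eval_name: cell.get("metrics", {}).get(metric)
            for eval_name, cell in evals.items()
        }
        for policy, evals in cells.items()
    }


cells: Cells = {
    "policy_alpha_v1": {
        "navigation/maze1": {"metrics": {"reward": 85.2, "success_rate": 0.9}},
        "navigation/maze2": {"metrics": {"reward": 70.0}},
    }
}
print(metric_grid(cells, "success_rate"))
# {'policy_alpha_v1': {'navigation/maze1': 0.9, 'navigation/maze2': None}}
```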


## Custom Metrics

Cells can carry any metric data we like. This matters because we plan to track many kinds of metrics. The example below uses an arbitrary metric name, `custom_score`.

Here's how to create a heatmap with your own data:


In [167]:
# Create a new heatmap widget
custom_widget = create_heatmap_widget()

# Define your data structure
# This should match the format expected by the observatory dashboard
cells_data = {
    'my_policy_v1': {
        'task_a/level1': {
            'metrics': {
                'custom_score': 85.2,
            },
            'replayUrl': 'https://example.com/replay1.json', 
            'evalName': 'task_a/level1'
        },
        'task_a/level2': {
            'metrics': {
                'custom_score': 87.5,
            },
            'replayUrl': 'https://example.com/replay2.json', 
            'evalName': 'task_a/level2'
        },
        'task_b/challenge1': {
            'metrics': {
                'custom_score': 92.5,
            },
            'replayUrl': 'https://example.com/replay3.json', 
            'evalName': 'task_b/challenge1'
        },
    },
    'my_policy_v2': {
        'task_a/level1': {
            'metrics': {
                'custom_score': 22.5,
            },
            'replayUrl': 'https://example.com/replay4.json', 
            'evalName': 'task_a/level1'
        },
        'task_a/level2': {
            'metrics': {
                'custom_score': 42.5,
            },
            'replayUrl': 'https://example.com/replay5.json', 
            'evalName': 'task_a/level2'
        },
        'task_b/challenge1': {
            'metrics': {
                'custom_score': 62.5,
            },
            'replayUrl': 'https://example.com/replay6.json', 
            'evalName': 'task_b/challenge1'
        },
    },
}

eval_names = ['task_a/level1', 'task_a/level2', 'task_b/challenge1']
policy_names = ['my_policy_v1', 'my_policy_v2']
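# Average score per policy across its evaluations (about 88.4 and 42.5 here)
policy_averages = {
    policy: sum(cell['metrics']['custom_score'] for cell in evals.values()) / len(evals)
    for policy, evals in cells_data.items()
}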

# Set the data
custom_widget.set_data(
    cells=cells_data,
    eval_names=eval_names,
    policy_names=policy_names,
    policy_average_scores=policy_averages,
    selected_metric="custom_score"
)

# Display the widget
custom_widget


🚀 HeatmapWidget initialized successfully!
📊 Data set with 2 policies and 3 evaluations
📈 Selected metric: custom_score


HeatmapWidget(heatmap_data={'cells': {'my_policy_v1': {'task_a/level1': {'metrics': {'custom_score': 85.2}, 'r…

## Adding Callbacks for Interactivity

You can add Python callbacks to respond to user interactions:


In [168]:
# NOTE: print() output from these callbacks does not appear in this cell, which
# is just how Jupyter widgets behave: once a cell finishes running and outputs a
# widget, callbacks fired later by that widget can no longer write to that
# cell's output. The callbacks below write to files instead; capturing their
# output in an ipywidgets.Output widget is another common approach.

# Create another widget for callback demonstration
callback_widget = create_heatmap_widget()

# Set up the same data as before
callback_widget.set_data(
    cells=cells_data,
    eval_names=eval_names,
    policy_names=policy_names,
    policy_average_scores=policy_averages,
    selected_metric="custom_score"  # must be a key in each cell's 'metrics' dict
)

# Define callback functions
def handle_cell_selection(cell_info):
    """Called when user hovers over a cell (not 'overall' column)."""
    with open("output_cell_selection.txt", "w") as f:
        f.write(f"📍 Cell selected: {cell_info['policyUri']} on evaluation '{cell_info['evalName']}'\n")

def handle_replay_opened(replay_info):
    """Called when user clicks to open a replay."""
    with open("output_replay_opened.txt", "w") as f:
        f.write(f"🎬 Replay opened: {replay_info['replayUrl']}\n")
        f.write(f"   Policy: {replay_info['policyUri']}\n")
        f.write(f"   Evaluation: {replay_info['evalName']}\n")

# Register the callbacks
callback_widget.on_cell_selected(handle_cell_selection)
callback_widget.on_replay_opened(handle_replay_opened)

# Display the widget
callback_widget


🚀 HeatmapWidget initialized successfully!
📊 Data set with 2 policies and 3 evaluations
📈 Selected metric: custom_score


HeatmapWidget(heatmap_data={'cells': {'my_policy_v1': {'task_a/level1': {'metrics': {'custom_score': 85.2}, 'r…

In [169]:
# Print, then delete, the files written by the callbacks in the previous cell (if they exist)
import os

for fname in ["output_cell_selection.txt", "output_replay_opened.txt"]:
    try:
        with open(fname, "r") as f:
            print(f.read())
        os.remove(fname)
        print(f"File {fname} deleted")
    except FileNotFoundError:
        pass


**Try interacting with the callback heatmap above, then rerun the cell just above to print the callback messages saved in the output files!**
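The callbacks above open their files with mode `"w"`, so each event overwrites the previous one. If you want a history of interactions, a sketch like the following appends one JSON line per event instead. `make_file_logger` and `read_events` are hypothetical helpers, and the event keys mirror the ones the callbacks above read (`policyUri`, `evalName`).

```python
import json
import os
import tempfile
from typing import Any, Callable, Dict, List


def make_file_logger(path: str) -> Callable[[Dict[str, Any]], None]:
    """Return a widget callback that appends each event to `path` as one JSON line."""
    def log_event(info: Dict[str, Any]) -> None:
        # Mode "a" keeps earlier events; mode "w" would overwrite them
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(info) + "\n")
    return log_event


def read_events(path: str) -> List[Dict[str, Any]]:
    """Return every event logged so far, or an empty list if nothing was logged."""
    if not os.path.exists(path):
        return []
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Try it without the widget by calling the logger directly
log_path = os.path.join(tempfile.gettempdir(), "heatmap_events.jsonl")
if os.path.exists(log_path):
    os.remove(log_path)  # start from an empty log

log_event = make_file_logger(log_path)
log_event({"policyUri": "my_policy_v1", "evalName": "task_a/level1"})
log_event({"policyUri": "my_policy_v2", "evalName": "task_b/challenge1"})
print(len(read_events(log_path)))  # 2
```

To use it with the widget, register the returned function, for example `callback_widget.on_cell_selected(make_file_logger("output_cell_selection.jsonl"))`.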

## Data Format Reference

The heatmap widget expects data in a specific format that matches the
observatory dashboard:

```python
cells = {
    'policy_name': {
        'eval_name': {
            'metrics': {
                'reward': 50,
                'heart.get': 98,
                'action.move.success': 5,
                'ore_red.get': 24.2,
                # ... more metrics
            },
            'replayUrl': str,         # URL to replay file
            'evalName': str,          # Should match the key
        },
        # ... more evaluations
    },
    # ... more policies
}
```
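Results often start out as flat rows, one per (policy, evaluation, metric). Here is a sketch of folding such rows into the nested structure above, using a hypothetical `build_cells` helper and a made-up tuple layout; it keeps the first non-empty replay URL seen for each cell and fills `evalName` from the key automatically.

```python
from typing import Any, Dict, Iterable, Optional, Tuple

# Hypothetical flat row: (policy_name, eval_name, metric_name, value, replay_url)
Row = Tuple[str, str, str, float, Optional[str]]


def build_cells(rows: Iterable[Row]) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """Fold flat rows into the nested policy -> eval -> cell structure shown above."""
    cells: Dict[str, Dict[str, Dict[str, Any]]] = {}
    for policy, eval_name, metric, value, replay_url in rows:
        cell = cells.setdefault(policy, {}).setdefault(
            eval_name, {"metrics": {}, "replayUrl": None, "evalName": eval_name}
        )
        cell["metrics"][metric] = value
        if replay_url and not cell["replayUrl"]:
            cell["replayUrl"] = replay_url  # keep the first non-empty replay URL
    return cells


rows = [
    ("policy_a", "arena/basic", "reward", 50, "https://example.com/a_basic.json"),
    ("policy_a", "arena/basic", "heart.get", 98, None),
    ("policy_b", "arena/combat", "reward", 12.5, None),
]
cells = build_cells(rows)
print(cells["policy_a"]["arena/basic"]["metrics"])  # {'reward': 50, 'heart.get': 98}
```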

**Important notes:**
- Evaluation names with "/" will be grouped by category (the part before "/")
- The heatmap shows policies sorted by average score (worst to best, bottom to top)
- Policy names that contain ":v" will have WandB URLs generated automatically
- Replay URLs should be accessible URLs or file paths
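The first two notes can be sketched in plain Python. `group_by_category` and `policies_by_average` are hypothetical helpers, not the widget's code; treating a name without `/` as its own category is an assumption about how the widget handles such names.

```python
from typing import Dict, List


def group_by_category(eval_names: List[str]) -> Dict[str, List[str]]:
    """Group evaluation names by the text before the first '/', keeping input order."""
    groups: Dict[str, List[str]] = {}
    for name in eval_names:
        category = name.split("/", 1)[0]  # a name without '/' becomes its own category
        groups.setdefault(category, []).append(name)
    return groups


def policies_by_average(averages: Dict[str, float]) -> List[str]:
    """Order policy names from worst (lowest average) to best (highest average)."""
    return sorted(averages, key=averages.get)


evals = ["arena/basic", "arena/combat", "navigation/maze1", "eval/training_task"]
print(group_by_category(evals))
# {'arena': ['arena/basic', 'arena/combat'], 'navigation': ['navigation/maze1'], 'eval': ['eval/training_task']}
print(policies_by_average({"my_policy_v1": 88.4, "my_policy_v2": 42.5}))
# ['my_policy_v2', 'my_policy_v1']
```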

This widget provides the same interactive functionality as the observatory dashboard inside a Python environment, which makes it well suited to exploratory analysis and to sharing results via Jupyter notebooks!
