# Notebook 02: Interactive Heuristic Tuning Dashboard

**Project:** The Causal Scrutinizer

**Objective:** This notebook provides a rapid, visual feedback loop for tuning the weights of our criticality heuristics. Its purpose is to help us understand the relationship between our calculated scores and the actual visual events in a scenario.

**Workflow:**
1. **(Slow Part) Load a scenario:** Run Cell #2 to load the pre-computed heuristic scores for a scenario from our validation set. It picks a random scenario by default, or a specific one when you pass its scenario ID.
2. **(Fast Part) Tune the weights:** Use the interactive sliders in Cell #3 to see how changing the heuristic weights affects the final, per-timestep criticality score. Updates are near-instant because only the already-loaded scores are recombined; no data is reloaded.
3. **(Slow Part) Verify visually:** Once you find an interesting score profile, run Cell #4 and click **Generate/Show GIF** to render the high-fidelity `pygame` GIF for that scenario. This lets you confirm whether a peak in the score corresponds to a genuinely critical event.

In [1]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from glob import glob
import random
from ipywidgets import interact, FloatSlider, VBox, HBox, Output, Button
from IPython.display import display, clear_output, Image as IPImage

# --- Add project root to path for our src imports ---
# Assumes the notebook's working directory is a direct subdirectory of the project
# root (e.g. <root>/notebooks), so the root is one level above os.getcwd().
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if PROJECT_ROOT not in sys.path:
    # Prepend (rather than append) so the project's `src` package takes precedence
    # over any same-named package already importable from site-packages.
    sys.path.insert(0, PROJECT_ROOT)

from src.utils.config_loader import load_config
from src.utils.generate_gif_for_notebook import render_gif_for_scenario

# --- Load project config ---
# The config file lives under the project root, so build an absolute path from PROJECT_ROOT.
config = load_config(config_path=os.path.join(PROJECT_ROOT, 'configs', 'main_config.yaml'))
sns.set_theme(style="whitegrid")

# --- Root for all of this notebook's output directories ---
OUTPUT_DIR_ROOT = os.path.join(PROJECT_ROOT, 'outputs')

print("Setup Complete.")

pygame 2.6.1 (SDL 2.28.4, Python 3.10.19)
Hello from the pygame community. https://www.pygame.org/contribute.html
✅ Configuration loaded successfully from: /home/anton/casual_scrutinizer/configs/main_config.yaml
Setup Complete.


  from pkg_resources import resource_stream, resource_exists


### Cell #2: Load a New Random Scenario

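Run this cell to load a scenario's pre-computed scores; this is the slow I/O step. The call at the bottom loads a specific scenario ID. Call `load_new_random_scenario()` with no argument to load a random validation scenario instead.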

In [24]:
# --- Module-level state shared by the cells below ---
# Both start empty and are filled in by load_new_random_scenario().
CURRENT_SCENARIO_ID = CURRENT_SCORES_DATA = None

def load_new_random_scenario(scenario_id=None):
    """Load one scenario's pre-computed criticality scores into the module globals.

    Args:
        scenario_id: File stem of a ``.npz`` in ``<PROJECT_ROOT>/<criticality_scores_dir>/validation``.
            When ``None`` (or empty), a random file from that directory is chosen.

    On success, sets ``CURRENT_SCENARIO_ID`` to the file stem and ``CURRENT_SCORES_DATA``
    to a dict mapping every array name stored in the ``.npz`` to a fully-read numpy
    array. On failure, prints an error and leaves the previously loaded scenario
    (if any) unchanged.
    """
    global CURRENT_SCENARIO_ID, CURRENT_SCORES_DATA

    # The configured directory (e.g. "outputs/criticality_scores") is relative to the
    # project root, not to the notebook's working directory.
    scores_dir_config_key = config['data']['criticality_scores_dir']
    scores_dir_abs = os.path.join(PROJECT_ROOT, scores_dir_config_key, 'validation')

    all_score_files = glob(os.path.join(scores_dir_abs, '*.npz'))
    if not all_score_files:
        print(f"❌ CRITICAL ERROR: No score files found in '{scores_dir_abs}'.")
        print("   Please run the 'score_criticality_heuristic.py' script first and check your config path.")
        return

    if scenario_id:
        score_file_path = os.path.join(scores_dir_abs, f"{scenario_id}.npz")
        # Only accept files the glob actually found, so an ID cannot point outside the directory.
        if score_file_path not in all_score_files:
            print(f"❌ ERROR: Scenario ID '{scenario_id}' not found in '{scores_dir_abs}'.")
            return
    else:
        score_file_path = random.choice(all_score_files)

    # np.load on an .npz returns a lazy NpzFile: it keeps the archive open and re-reads
    # and decompresses an array on every key access. Read everything once (closing the
    # file) so the slider callback only touches in-memory arrays.
    with np.load(score_file_path) as npz:
        scores = {name: npz[name] for name in npz.files}

    # Update both globals only after the load succeeded, so they always refer to the same scenario.
    CURRENT_SCENARIO_ID = os.path.splitext(os.path.basename(score_file_path))[0]
    CURRENT_SCORES_DATA = scores
    print(f"--- ✅ Successfully Loaded Scenario: {CURRENT_SCENARIO_ID} ---")
    
# --- Initial load ---
# Pinned to a specific scenario ID; call load_new_random_scenario() with no
# argument to pick a random validation scenario instead.
load_new_random_scenario(scenario_id="599e00c98611394e")

--- ✅ Successfully Loaded Scenario: 599e00c98611394e ---


### Cell #3: Interactive Plotting Dashboard

Adjust the sliders to recalculate the final score in real time; the plot updates immediately.

In [25]:
def plot_weighted_scores(w_volatility, w_interaction, w_off_road, w_density, w_lane_deviation):
    """(Fast Version) Re-calculate and plot the combined criticality score from pre-loaded data.

    Reads the per-timestep heuristic components from ``CURRENT_SCORES_DATA`` (filled by
    ``load_new_random_scenario``), so a slider change only redoes numpy arithmetic and
    plotting.

    Args:
        w_volatility, w_interaction, w_off_road, w_density, w_lane_deviation: Weights for
            the corresponding score components (the sliders range over 0..1).

    The combined score is the weighted sum of the components divided by the sum of the
    weights. This is a weighted mean, so it stays in [0, 1] whenever every component
    does. If all weights are zero, the unnormalized (all-zero) sum is plotted.
    Prints a message and returns without plotting when no scenario is loaded or the
    scenario has no timesteps.
    """
    if CURRENT_SCORES_DATA is None:
        print("Please run the cell above to load a scenario first.")
        return

    # (key in the loaded scores, legend label, weight). This order is also the plotting
    # order of the component panel, which determines each component's colour.
    components = [
        ('volatility', 'Volatility', w_volatility),
        ('interaction', 'Interaction', w_interaction),
        ('lane_deviation', 'Lane Deviation', w_lane_deviation),
        ('off_road', 'Off-Road', w_off_road),
        ('density', 'Density', w_density),
    ]
    series = {key: np.asarray(CURRENT_SCORES_DATA[key], dtype=float) for key, _, _ in components}

    # --- Combine Scores ---
    final_scores = sum(weight * series[key] for key, _, weight in components)
    total_weight = sum(weight for _, _, weight in components)
    if total_weight > 0:
        final_scores = final_scores / total_weight

    if final_scores.size == 0:
        # np.argmax below would raise on an empty array.
        print(f"Scenario {CURRENT_SCENARIO_ID} has no timesteps to plot.")
        return

    # --- Plotting ---
    fig, (ax_final, ax_parts) = plt.subplots(
        2, 1, figsize=(16, 8), sharex=True, gridspec_kw={'height_ratios': [1, 2]}
    )
    fig.suptitle(f'Interactive Weight Tuning for Scenario: {CURRENT_SCENARIO_ID}', fontsize=16)
    timesteps = np.arange(len(final_scores))

    # Top panel: combined score and its peak.
    most_critical_timestep = int(np.argmax(final_scores))
    peak_score = float(final_scores[most_critical_timestep])
    ax_final.plot(timesteps, final_scores, color='black', linewidth=2.5, label='Final Combined Score')
    ax_final.set_ylabel('Final Score')
    # Keep the usual [0, 1] view, but raise the top so that a high peak (or a score above 1,
    # if some component is not normalized) and its label, drawn 0.1 above it, are not clipped.
    ax_final.set_ylim(0, max(1.0, peak_score + 0.15))
    ax_final.axvline(x=most_critical_timestep, color='r', linestyle='--',
                     label=f'Peak Score Timestep: {most_critical_timestep}')
    ax_final.annotate(f'Peak: {peak_score:.2f}', xy=(most_critical_timestep, peak_score),
                      xytext=(most_critical_timestep + 2, peak_score + 0.1), fontsize=12, color='red')
    ax_final.legend()
    ax_final.grid(True, which='both', linestyle=':')

    # Bottom panel: the raw (unweighted) components, labelled with their current weights.
    for key, label, weight in components:
        ax_parts.plot(timesteps, series[key], label=f'{label} (w={weight:.2f})', alpha=0.8)
    ax_parts.set_xlabel('Timestep (0.1s increments)')
    ax_parts.set_ylabel('Raw Score Components')
    ax_parts.legend(loc='upper right')
    ax_parts.grid(True, which='both', linestyle=':')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# --- Create the interactive widget ---
# Slider defaults come from the heuristic weights in the config file.
default_weights = config['scoring']['heuristic']


def _weight_slider(config_key, description):
    """Build a 0-1 weight slider (step 0.05) initialised from `config_key` of the heuristic config."""
    return FloatSlider(min=0.0, max=1.0, step=0.05, value=default_weights[config_key], description=description)


# interact() displays the widget itself and returns the wrapped function. Binding that
# return value keeps the notebook from echoing a stray `<function ...>` cell output.
weight_dashboard = interact(
    plot_weighted_scores,
    w_volatility=_weight_slider('weight_volatility', 'Volatility'),
    w_interaction=_weight_slider('weight_interaction', 'Interaction'),
    w_lane_deviation=_weight_slider('weight_lane_deviation', 'Lane Deviate'),
    w_off_road=_weight_slider('weight_off_road', 'Off-Road'),
    w_density=_weight_slider('weight_density', 'Density'),
)

interactive(children=(FloatSlider(value=0.4, description='Volatility', max=1.0, step=0.05), FloatSlider(value=…

<function __main__.plot_weighted_scores(w_volatility, w_interaction, w_off_road, w_density, w_lane_deviation)>

### Cell #4: Generate and Display GIF

Run this cell, then click **Generate/Show GIF** to render the `pygame` GIF for the currently loaded scenario and display it below.

In [None]:
# This cell handles the GIF generation and display
output_panel = Output()
gif_button = Button(description="Generate/Show GIF", button_style='success')

def on_gif_button_clicked(b):
    with output_panel:
        clear_output(wait=True)
        if CURRENT_SCENARIO_ID is None:
            print("No scenario loaded. Please run Cell #2 first.")
            return
            
        # --- CORRECTED PATH LOGIC ---
        # Build the output path from our OUTPUT_DIR_ROOT
        gif_output_dir = os.path.join(OUTPUT_DIR_ROOT, 'notebook_gifs')
        
        # The helper script handles the actual rendering
        gif_path = render_gif_for_scenario(CURRENT_SCENARIO_ID, gif_output_dir)
        
        if gif_path:
            # Display the generated GIF
            display(IPImage(filename=gif_path))

gif_button.on_click(on_gif_button_clicked)

# Display the button and the output area
display(VBox([gif_button, output_panel]))

VBox(children=(Button(button_style='success', description='Generate/Show GIF', style=ButtonStyle()), Output())…