In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import os
from gfn_environments.single_color_ramp import *

def screenshot_viewport_to_png(filepath: str, resolution_x: int = 800, resolution_y: int = 600):
    """
    Render the current scene to a PNG using a temporary camera and lights.

    A throw-away camera, a sun light and an area fill light are linked into
    the active collection, the scene is rendered with EEVEE, and the temporary
    datablocks are removed again. The scene's active camera, render engine and
    render output settings (path, file format, resolution) are put back to
    their previous values afterwards, even when rendering raises.

    Args:
        filepath: Path where the PNG should be saved (e.g., "./output/screenshot.png")
        resolution_x: Width of the image in pixels
        resolution_y: Height of the image in pixels
    """
    # os.path.dirname() returns "" for a bare file name and os.makedirs("")
    # raises, so only create the directory when there is one.
    output_dir = os.path.dirname(filepath)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    scene = bpy.context.scene
    render = scene.render
    image_settings = render.image_settings

    # Remember everything this function overwrites so it can be restored.
    original_camera = scene.camera
    original_engine = render.engine
    original_file_format = image_settings.file_format
    original_filepath = render.filepath
    original_resolution = (
        render.resolution_x,
        render.resolution_y,
        render.resolution_percentage,
    )

    # Temporary datablocks are tracked as they are created so the cleanup in
    # `finally` only touches what actually exists.
    temp_objects = []
    temp_cameras = []
    temp_lights = []

    try:
        # A dedicated camera is always created, regardless of whether the
        # scene already has one (adjust location/rotation as needed).
        camera_data = bpy.data.cameras.new(name="TempCamera")
        temp_cameras.append(camera_data)
        camera_object = bpy.data.objects.new("TempCamera", camera_data)
        temp_objects.append(camera_object)
        bpy.context.collection.objects.link(camera_object)
        camera_object.location = (7, -7, 5)
        camera_object.rotation_euler = (1.1, 0, 0.785)
        scene.camera = camera_object

        # Sun light for overall illumination
        sun_data = bpy.data.lights.new(name="TempSun", type='SUN')
        temp_lights.append(sun_data)
        sun_data.energy = 3.0  # Brightness
        sun_object = bpy.data.objects.new("TempSun", sun_data)
        temp_objects.append(sun_object)
        bpy.context.collection.objects.link(sun_object)
        sun_object.location = (5, 5, 10)
        sun_object.rotation_euler = (0.7, 0.3, 0)

        # Area fill light for better visibility
        fill_data = bpy.data.lights.new(name="TempFill", type='AREA')
        temp_lights.append(fill_data)
        fill_data.energy = 100.0
        fill_data.size = 5.0
        fill_object = bpy.data.objects.new("TempFill", fill_data)
        temp_objects.append(fill_object)
        bpy.context.collection.objects.link(fill_object)
        fill_object.location = (-5, -5, 8)
        fill_object.rotation_euler = (1.2, 0, -0.785)

        # EEVEE is registered as 'BLENDER_EEVEE_NEXT' in Blender 4.2+ 4.x
        # builds and as 'BLENDER_EEVEE' in other versions; assigning an
        # unknown enum identifier raises TypeError, so fall back on that.
        try:
            render.engine = 'BLENDER_EEVEE_NEXT'
        except TypeError:
            render.engine = 'BLENDER_EEVEE'

        image_settings.file_format = 'PNG'
        render.filepath = filepath
        render.resolution_x = resolution_x
        render.resolution_y = resolution_y
        render.resolution_percentage = 100

        # Quick render; write_still=True saves the result to render.filepath
        bpy.ops.render.render(write_still=True)
    finally:
        # Remove the temporary objects first, then the data they used.
        for temp_object in temp_objects:
            bpy.data.objects.remove(temp_object, do_unlink=True)
        for temp_camera in temp_cameras:
            bpy.data.cameras.remove(temp_camera)
        for temp_light in temp_lights:
            bpy.data.lights.remove(temp_light)

        # Restore the settings captured above.
        scene.camera = original_camera
        render.engine = original_engine
        image_settings.file_format = original_file_format
        render.filepath = original_filepath
        (
            render.resolution_x,
            render.resolution_y,
            render.resolution_percentage,
        ) = original_resolution

    print(f"Viewport screenshot saved to: {filepath}")

def visualize_action_render_state(
    action_tensor: torch.Tensor,
    action_name: str,
    render_filepath: str,
    heightmap: torch.Tensor,
    previous_heightmap: torch.Tensor,  # Previous heightmap for diff
    state_tensor: torch.Tensor,
    state: 'State',  # State object, read for its color assignments
    output_filepath: str
):
    """
    Create a 7-panel figure, save it to disk and display it.

    Layout (3 rows):
    1. Action tensor as a bar chart (top left)
    2. Blender render loaded from ``render_filepath`` (top middle)
    3. State tensor as a bar chart (top right)
    4. Color ramp gradient spanning the whole middle row
    5. Previous heightmap (bottom left)
    6. Current heightmap (bottom middle)
    7. Heightmap diff, current minus previous (bottom right)

    Args:
        action_tensor: The action tensor applied
        action_name: Name of the action for labeling
        render_filepath: Path to the rendered PNG image
        heightmap: Tensor of the heightmap from Blender
        previous_heightmap: Previous heightmap for comparison
        state_tensor: The resulting state tensor
        state: State object; its ``color_assignments`` mapping (slot ->
            palette index) drives the color ramp panel
        output_filepath: Path to save the comparison figure
    """
    # Image.open() is lazy and keeps the file handle open; copy the pixels
    # into memory and release the handle straight away.
    with Image.open(render_filepath) as opened_image:
        render_img = opened_image.copy()

    # Convert tensors to numpy for visualization
    heightmap_np = heightmap.detach().cpu().numpy()
    previous_heightmap_np = previous_heightmap.detach().cpu().numpy()

    # Make sure the target directory exists ("" means the current directory).
    output_dir = os.path.dirname(output_filepath)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    fig = plt.figure(figsize=(18, 14))
    try:
        gs = fig.add_gridspec(3, 3, hspace=0.4, wspace=0.3, height_ratios=[1, 0.3, 1])

        # 1. Action Tensor (top left)
        ax1 = fig.add_subplot(gs[0, 0])
        action_np = action_tensor.detach().cpu().numpy()
        ax1.bar(range(len(action_np)), action_np)
        ax1.set_title(f'Action Tensor: {action_name}', fontsize=12, fontweight='bold')
        ax1.set_xlabel('Action Index')
        ax1.set_ylabel('Value')
        ax1.grid(True, alpha=0.3)

        # 2. Blender Screenshot (top middle)
        ax2 = fig.add_subplot(gs[0, 1])
        ax2.imshow(render_img)
        ax2.set_title('Blender Render', fontsize=12, fontweight='bold')
        ax2.axis('off')

        # 3. State Tensor (top right)
        ax3 = fig.add_subplot(gs[0, 2])
        state_np = state_tensor.detach().cpu().numpy()
        ax3.bar(range(len(state_np)), state_np)
        ax3.set_title('State Tensor', fontsize=12, fontweight='bold')
        ax3.set_xlabel('State Index')
        ax3.set_ylabel('Value')
        ax3.grid(True, alpha=0.3)

        # 4. Color Ramp Visualization (middle row, spans all columns)
        ax4 = fig.add_subplot(gs[1, :])
        if state.color_assignments:
            palette = ActionRegistry.COLOR_PALETTE

            # Slots in ascending order, each paired with its palette index
            sorted_slots = sorted(state.color_assignments.items())
            slot_colors = [palette[palette_idx] for _, palette_idx in sorted_slots]
            num_slots = len(slot_colors)

            # Build a 1 x ramp_width RGB strip, linearly blending between
            # neighbouring slot colors.
            ramp_width = 100
            colors_array = np.zeros((1, ramp_width, 3))
            for i in range(ramp_width):
                position = i / (ramp_width - 1) * (num_slots - 1)
                lower_idx = int(np.floor(position))
                upper_idx = min(int(np.ceil(position)), num_slots - 1)
                blend = position - lower_idx

                lower_color = slot_colors[lower_idx]
                upper_color = slot_colors[upper_idx]
                colors_array[0, i] = [
                    lower_color[channel] * (1 - blend) + upper_color[channel] * blend
                    for channel in range(3)
                ]

            ax4.imshow(colors_array, aspect='auto')
            ax4.set_title(f'Color Ramp ({num_slots} colors)', fontsize=12, fontweight='bold')

            # Slot-by-slot RGB labels joined with a right arrow (written as
            # an escape so the separator survives re-encoding of this file).
            label_text = " \u2192 ".join([
                f"Slot {slot}: RGB{tuple(int(c*255) for c in palette[palette_idx])}"
                for slot, palette_idx in sorted_slots
            ])
            ax4.set_xlabel(label_text, fontsize=9)
        else:
            ax4.text(0.5, 0.5, 'No colors assigned yet',
                    ha='center', va='center', fontsize=14)
            ax4.set_title('Color Ramp (empty)', fontsize=12, fontweight='bold')

        ax4.set_yticks([])
        ax4.set_xticks([])

        # 5. Previous Heightmap (bottom left)
        ax5 = fig.add_subplot(gs[2, 0])
        im5 = ax5.imshow(previous_heightmap_np, cmap='terrain', interpolation='nearest')
        ax5.set_title('Previous Heightmap', fontsize=12, fontweight='bold')
        ax5.axis('off')
        fig.colorbar(im5, ax=ax5, fraction=0.046, pad=0.04)

        # 6. Current Heightmap (bottom middle)
        ax6 = fig.add_subplot(gs[2, 1])
        im6 = ax6.imshow(heightmap_np, cmap='terrain', interpolation='nearest')
        ax6.set_title('Current Heightmap', fontsize=12, fontweight='bold')
        ax6.axis('off')
        fig.colorbar(im6, ax=ax6, fraction=0.046, pad=0.04)

        # 7. Heightmap Diff (bottom right)
        ax7 = fig.add_subplot(gs[2, 2])
        diff = heightmap_np - previous_heightmap_np
        # Symmetric color limits keep zero change on the neutral centre of
        # the diverging colormap. When nothing changed (e.g. the initial
        # state compared with itself) the max is 0, which would collapse the
        # range to vmin == vmax, so fall back to a unit range instead.
        max_abs_change = float(np.abs(diff).max()) if diff.size else 0.0
        if not (np.isfinite(max_abs_change) and max_abs_change > 0):
            max_abs_change = 1.0
        im7 = ax7.imshow(diff, cmap='RdBu_r', interpolation='nearest',
                         vmin=-max_abs_change, vmax=max_abs_change)
        ax7.set_title('Heightmap Change (Diff)', fontsize=12, fontweight='bold')
        ax7.axis('off')
        cbar = fig.colorbar(im7, ax=ax7, fraction=0.046, pad=0.04)
        cbar.set_label('Change', rotation=270, labelpad=15)

        fig.savefig(output_filepath, dpi=150, bbox_inches='tight')
        plt.show()  # Display in notebook
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Visualization saved to: {output_filepath}")


def test_state_with_visualization():
    """
    Drive a State through set_w, set_scale and a run of add_color actions.

    After the initial state and after every action, a Blender screenshot is
    rendered and a comparison figure is written under ./tests/file_dump/.
    """
    load_blend_single_color_ramp()
    blender = BlenderTerrainAPI()
    current = get_initial_environment_state()

    def one_hot(name, idx):
        # One-hot encoding of (action type, value index) in the flat action space.
        encoded = torch.zeros(State.get_action_tensor_dim())
        encoded[ActionRegistry.get_action_offset(name) + idx] = 1.0
        return encoded

    def advance(state_before, terrain_before, name, idx, label, shot_path, figure_path):
        # Apply one action, push it to Blender, then capture and plot the result.
        state_after = state_before.apply_action(name, idx)
        state_after.execute_action_on_blender(blender, name, idx)
        encoded = one_hot(name, idx)

        screenshot_viewport_to_png(filepath=shot_path)
        reference = terrain_before.clone()
        terrain_after = blender.get_heightmap()
        snapshot = state_after.to_state_tensor()

        visualize_action_render_state(
            action_tensor=encoded,
            action_name=label,
            render_filepath=shot_path,
            heightmap=terrain_after,
            previous_heightmap=reference,
            state_tensor=snapshot,
            state=state_after,
            output_filepath=figure_path,
        )
        return state_after, terrain_after

    # --- Initial state: nothing applied yet, so the diff panel compares the
    # heightmap with a copy of itself and the action tensor is all zeros.
    current.apply_to_blender(blender_api=blender)
    screenshot_viewport_to_png(filepath="./tests/file_dump/initial_screenshot.png")
    terrain = blender.get_heightmap()
    terrain_copy = terrain.clone()
    initial_snapshot = current.to_state_tensor()
    no_action = torch.zeros(State.get_action_tensor_dim())

    visualize_action_render_state(
        action_tensor=no_action,
        action_name="Initial (No Action)",
        render_filepath="./tests/file_dump/initial_screenshot.png",
        heightmap=terrain,
        previous_heightmap=terrain_copy,
        state_tensor=initial_snapshot,
        state=current,
        output_filepath="./tests/file_dump/viz_initial.png",
    )

    # --- Action 1: set_w
    name, idx = 'set_w', 2
    current, terrain = advance(
        current, terrain, name, idx,
        f"{name} (idx={idx})",
        "./tests/file_dump/after_w_screenshot.png",
        "./tests/file_dump/viz_after_w.png",
    )

    # --- Action 2: set_scale
    name, idx = 'set_scale', 1
    current, terrain = advance(
        current, terrain, name, idx,
        f"{name} (idx={idx})",
        "./tests/file_dump/after_scale_screenshot.png",
        "./tests/file_dump/viz_after_scale.png",
    )

    # --- Fill the ramp, stopping 5 short of MAX_COLORS
    total_colors = ActionRegistry.MAX_COLORS - 5
    for nth in range(total_colors):
        name = 'add_color'
        # Wrap around the palette when more colors are added than it holds
        idx = nth % len(ActionRegistry.VALID_COLOR_INDICES)

        current, terrain = advance(
            current, terrain, name, idx,
            f"{name} #{nth} (palette_idx={idx})",
            f"./tests/file_dump/after_color_{nth}_screenshot.png",
            f"./tests/file_dump/viz_after_color_{nth}.png",
        )
        print(f"Added color {nth + 1}/{total_colors}")

    print("All visualizations complete!")

# Run the State walkthrough as soon as this notebook cell executes; it renders
# with Blender and writes its screenshots and figures under ./tests/file_dump/.
test_state_with_visualization()

#


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import os

def test_step_w_experiment_with_visualization():
    """
    Run a scripted StepWExperimentDefinition trajectory with visualization.

    The script takes two step_w actions, one step_scale action, five
    add_color actions, three more step_w actions and a final stop. A
    screenshot and a comparison figure are written under ./tests/file_dump/
    for the initial state and for every action except the stop. Afterwards
    the flat-index action interface is exercised on a second trajectory.
    """
    load_blend_single_color_ramp()
    blender_api = BlenderTerrainAPI()
    experiment = StepWExperimentDefinition

    trajectory = experiment.Trajectory(blender_api)

    def report(prefix):
        # Print the current noise parameters behind a caller-supplied prefix.
        now = trajectory.current_state
        print(f"{prefix} - W: {now.noise_w}, Scale: {now.noise_scale}")

    def step_and_record(name, idx, describe, shot_path, figure_path):
        # Take one zero-reward step, then render and plot the resulting state.
        nonlocal terrain
        taken = experiment.Action(action_name=name, value_idx=idx)
        trajectory.step(taken, reward=0.0)

        encoded = torch.zeros(experiment.State.get_action_dim())
        encoded[taken.to_flat_index()] = 1.0

        screenshot_viewport_to_png(filepath=shot_path)
        earlier = terrain.clone()
        terrain = blender_api.get_heightmap()
        snapshot = trajectory.get_state_tensor()

        visualize_action_render_state(
            action_tensor=encoded,
            action_name=describe(taken),
            render_filepath=shot_path,
            heightmap=terrain,
            previous_heightmap=earlier,
            state_tensor=snapshot,
            state=trajectory.current_state,
            output_filepath=figure_path,
        )

    # --- Initial state: zero action tensor, heightmap compared with its own copy
    screenshot_viewport_to_png(filepath="./tests/file_dump/stepw_initial_screenshot.png")
    terrain = blender_api.get_heightmap()
    terrain_copy = terrain.clone()
    initial_snapshot = trajectory.get_state_tensor()
    no_action = torch.zeros(experiment.State.get_action_dim())

    visualize_action_render_state(
        action_tensor=no_action,
        action_name="Initial (No Action)",
        render_filepath="./tests/file_dump/stepw_initial_screenshot.png",
        heightmap=terrain,
        previous_heightmap=terrain_copy,
        state_tensor=initial_snapshot,
        state=trajectory.current_state,
        output_filepath="./tests/file_dump/stepw_viz_initial.png",
    )
    report("Initial state")

    # --- Steps 1-3: two step_w actions followed by one step_scale action,
    # all with value index 0
    step_and_record(
        'step_w', 0,
        lambda act: f"step_w +{act.value}",
        "./tests/file_dump/stepw_after_step1.png",
        "./tests/file_dump/stepw_viz_step1.png",
    )
    report("After step_w")

    step_and_record(
        'step_w', 0,
        lambda act: f"step_w +{act.value}",
        "./tests/file_dump/stepw_after_step2.png",
        "./tests/file_dump/stepw_viz_step2.png",
    )
    report("After 2nd step_w")

    step_and_record(
        'step_scale', 0,
        lambda act: f"step_scale +{act.value}",
        "./tests/file_dump/stepw_after_step3.png",
        "./tests/file_dump/stepw_viz_step3.png",
    )
    report("After step_scale")

    # --- Steps 4-8: five colors, value indices spaced 6 apart
    for nth in range(5):
        step_and_record(
            'add_color', nth * 6,
            lambda act: f"add_color palette_idx={act.value}",
            f"./tests/file_dump/stepw_after_color_{nth}.png",
            f"./tests/file_dump/stepw_viz_color_{nth}.png",
        )
        print(f"Added color {nth + 1}/5 - Colors assigned: {trajectory.current_state.num_colors_assigned}")

    # --- Three more step_w actions to show continued stepping
    for nth in range(3):
        step_and_record(
            'step_w', 0,
            lambda act, nth=nth: f"step_w +{act.value} (step {nth+1})",
            f"./tests/file_dump/stepw_final_step_{nth}.png",
            f"./tests/file_dump/stepw_viz_final_step_{nth}.png",
        )
        report(f"Final step_w {nth+1}")

    # --- Terminate the trajectory (no render for the stop action)
    stop_action = experiment.Action(action_name='stop', value_idx=0)
    trajectory.step(stop_action, reward=1.0)

    final = trajectory.current_state
    print("\nTrajectory complete!")
    print(f"Total steps: {len(trajectory)}")
    print(f"Final W: {final.noise_w}")
    print(f"Final Scale: {final.noise_scale}")
    print(f"Total colors: {final.num_colors_assigned}")
    print(f"Is terminal: {trajectory.is_terminal()}")

    # --- Flat action interface, exercised on a fresh trajectory
    print("\n--- Testing flat action interface ---")
    second_trajectory = experiment.Trajectory(blender_api)

    start_mask = second_trajectory.get_action_mask()
    print(f"Valid actions at start: {start_mask.sum().item()}")

    flat_idx = experiment.encode_action('step_w', 0)
    print(f"Encoded step_w action to flat index: {flat_idx}")
    second_trajectory.step_from_flat_action(flat_idx, reward=0.1)
    print(f"After flat action step - W: {second_trajectory.current_state.noise_w}")

# Run the scripted StepW trajectory as soon as this notebook cell executes;
# it renders with Blender and writes its output under ./tests/file_dump/.
test_step_w_experiment_with_visualization()

# Random sampling sandbox (StepW experiment)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import os
import random

def test_random_trajectory_with_visualization(max_steps: int = 20):
    """
    Test StepWExperimentDefinition with randomly sampled actions.

    For every step except the last, an ACTION TYPE (step_w, step_scale,
    add_color) is sampled from the types that have at least one value allowed
    by the trajectory's action mask, and then a value index is chosen within
    that type from the mask-allowed indices only:

    * add_color: uniformly among the allowed value indices.
    * step_w / step_scale: value index 0 when the mask allows it, otherwise a
      random allowed index.

    The mask is read as ``mask[offset + value_idx]`` for each action type,
    where ``offset`` comes from ``get_action_offset``.

    The stop action is only taken at the final step. If no non-stop action
    type is allowed, the loop ends early without taking a stop action.

    A screenshot and comparison figure are written under ./tests/file_dump/
    for the initial state and after every step.

    Args:
        max_steps: Maximum number of steps in the trajectory

    Returns:
        The trajectory object after the run.
    """
    load_blend_single_color_ramp()
    blender_api = BlenderTerrainAPI()

    # Create trajectory
    trajectory = StepWExperimentDefinition.Trajectory(blender_api)

    # Initial state visualization
    screenshot_viewport_to_png(filepath="./tests/file_dump/random_initial_screenshot.png")
    heightmap = blender_api.get_heightmap()
    previous_heightmap = heightmap.clone()
    state_tensor = trajectory.get_state_tensor()

    action_tensor = torch.zeros(StepWExperimentDefinition.State.get_action_dim())

    visualize_action_render_state(
        action_tensor=action_tensor,
        action_name="Initial (No Action)",
        render_filepath="./tests/file_dump/random_initial_screenshot.png",
        heightmap=heightmap,
        previous_heightmap=previous_heightmap,
        state_tensor=state_tensor,
        state=trajectory.current_state,
        output_filepath="./tests/file_dump/random_viz_initial.png"
    )

    print(f"Initial state - W: {trajectory.current_state.noise_w}, Scale: {trajectory.current_state.noise_scale}")
    print(f"Starting random trajectory with max {max_steps} steps\n")

    # Sample random trajectory
    for step_idx in range(max_steps):
        if step_idx == max_steps - 1:
            # Last step - take stop action
            action_name = 'stop'
            value_idx = 0
            action_name_str = "stop (final)"
        else:
            action_mask = trajectory.get_action_mask()

            # For each non-stop action type, collect the value indices the
            # mask currently allows. Stop is deliberately left out: the
            # trajectory only stops at the end.
            type_sizes = (
                ('step_w', len(StepWExperimentDefinition.STEP_W)),
                ('step_scale', len(StepWExperimentDefinition.STEP_SCALE)),
                ('add_color', len(StepWExperimentDefinition.VALID_COLOR_INDICES)),
            )
            allowed_by_type = {}
            for type_name, num_values in type_sizes:
                offset = StepWExperimentDefinition.get_action_offset(type_name)
                type_mask = action_mask[offset:offset + num_values].tolist()
                allowed = [idx for idx, is_allowed in enumerate(type_mask) if is_allowed]
                if allowed:
                    allowed_by_type[type_name] = allowed

            if not allowed_by_type:
                print(f"No valid action types at step {step_idx}. Ending trajectory early.")
                break

            # Randomly choose an action TYPE
            action_name = random.choice(list(allowed_by_type))
            allowed_values = allowed_by_type[action_name]

            # Choose a VALUE within that type, never one the mask rejects.
            if action_name == 'add_color' or 0 not in allowed_values:
                value_idx = random.choice(allowed_values)
            else:
                # step_w / step_scale keep using the first value when allowed
                value_idx = 0

            action_info = StepWExperimentDefinition.ACTIONS[action_name]
            action_name_str = f"{action_name} (value_idx={value_idx}, value={action_info['valid_values'][value_idx]})"

        # Create action and take step
        action = StepWExperimentDefinition.Action(action_name=action_name, value_idx=value_idx)
        trajectory.step(action, reward=random.random())

        # Create action tensor (one-hot)
        action_tensor = torch.zeros(StepWExperimentDefinition.State.get_action_dim())
        action_tensor[action.to_flat_index()] = 1.0

        # Render and visualize
        screenshot_filepath = f"./tests/file_dump/random_step_{step_idx:02d}.png"
        viz_filepath = f"./tests/file_dump/random_viz_step_{step_idx:02d}.png"

        screenshot_viewport_to_png(filepath=screenshot_filepath)
        previous_heightmap = heightmap.clone()
        heightmap = blender_api.get_heightmap()
        state_tensor = trajectory.get_state_tensor()

        visualize_action_render_state(
            action_tensor=action_tensor,
            action_name=action_name_str,
            render_filepath=screenshot_filepath,
            heightmap=heightmap,
            previous_heightmap=previous_heightmap,
            state_tensor=state_tensor,
            state=trajectory.current_state,
            output_filepath=viz_filepath
        )

        # Print step info
        print(f"Step {step_idx + 1}/{max_steps}:")
        print(f"  Action type chosen: {action_name}")
        print(f"  Full action: {action_name_str}")
        print(f"  W: {trajectory.current_state.noise_w:.2f}, Scale: {trajectory.current_state.noise_scale:.2f}")
        print(f"  Colors: {trajectory.current_state.num_colors_assigned}/{StepWExperimentDefinition.MAX_COLORS}")
        print(f"  Is terminal: {trajectory.is_terminal()}")
        print()

    # Print summary
    print("\n" + "="*60)
    print("TRAJECTORY SUMMARY")
    print("="*60)
    print(f"Total steps taken: {len(trajectory)}")
    print(f"Final W: {trajectory.current_state.noise_w:.2f}")
    print(f"Final Scale: {trajectory.current_state.noise_scale:.2f}")
    print(f"Total colors assigned: {trajectory.current_state.num_colors_assigned}")
    print(f"Is terminal: {trajectory.is_terminal()}")
    print(f"Total reward: {sum(trajectory.rewards):.3f}")

    # Action breakdown: how many times each action type was taken
    action_counts = {}
    for action in trajectory.actions:
        action_counts[action.action_name] = action_counts.get(action.action_name, 0) + 1

    print("\nAction breakdown:")
    for action_name, count in sorted(action_counts.items()):
        print(f"  {action_name}: {count}")

    print("="*60)

    return trajectory


# Run a single 20-step random trajectory with full visualization as soon as
# this notebook cell executes; the returned trajectory is kept in a
# module-level name so later cells can inspect it.
print("Testing single random trajectory with visualization...")
trajectory = test_random_trajectory_with_visualization(max_steps=20)