In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import colorsys
import itertools
import json
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

from lead.inference.config_open_loop import OpenLoopConfig
from lead.inference.open_loop_inference import OpenLoopInference
from lead.training.config_training import TrainingConfig

# --- Inputs: which logged evaluation frame to visualise ----------------------
input_root = "outputs/local_evaluation/2397"
frame_index = 8
input_path = os.path.join(input_root, f"input_log/{frame_index:05d}.pth")
camera_path = os.path.join(input_root, f"demo_images/{frame_index:05d}.png")

# Calibration of the camera that produced the demo image. Only "fov", "x", "y",
# "z", "pitch" and "yaw" are read by the projection helpers in this notebook;
# "fov" is stored as a string and converted with float() where it is used.
# NOTE(review): z=20 / pitch=-90 looks like a top-down camera 20 m above the
# vehicle - confirm units against the recording configuration.
camera_calibration = {
    "name": "bev_camera",
    "draw_target_points": False,
    "draw_planning": False,
    "image_size_x": "1080",
    "image_size_y": "786",
    "fov": "100",
    "x": 0.0,
    "y": 0.0,
    "z": 20.0,
    "pitch": -90.0,
    "yaw": 0.0,
}

# cv2.imread does not raise on a missing/unreadable file, it returns None;
# fail here with a clear message instead of a cryptic cvtColor error.
camera_image = cv2.imread(camera_path)
if camera_image is None:
    raise FileNotFoundError(f"Could not read camera image: {camera_path}")
camera_image = cv2.cvtColor(camera_image, cv2.COLOR_BGR2RGB)

# Target-point offsets; every (x, y) combination of these gets its own colour.
anchors = [-20, 0, 20]

num = len(anchors) ** 2

# One evenly spaced hue per anchor combination, stored as plain-int 0-255 tuples.
rainbow_colors = []
for k in range(num):
    r, g, b = colorsys.hsv_to_rgb(k / num, 0.9, 0.9)
    rgb = (np.array([r, g, b]) * 255).astype(np.uint8)
    # The colour was previously computed and then discarded, leaving the palette empty.
    rainbow_colors.append(tuple(int(x) for x in rgb))

In [None]:
from lead.common import common_utils


def project_points_to_image(camera_rot, camera_pos, camera_fov, camera_width, camera_height, points):
    """
    Project 2D points (with z=0) to image coordinates of a pinhole camera.

    Args:
        camera_rot: (roll, pitch, yaw) forwarded to common_utils.euler_deg_to_mat
        camera_pos: (x, y, z) camera position in the same frame as the points
        camera_fov: Field of view in degrees
        camera_width: Image width in pixels
        camera_height: Image height in pixels
        points: Array-like of shape (N, 2); a single (x, y) pair is accepted
            as well, and an empty input yields an empty list

    Returns:
        List of ((x, y), inside_bounds) tuples, one per input point. Points
        that are not in front of the camera are reported as ((0, 0), False).
    """
    points_2d = np.asarray(points, dtype=float)
    if points_2d.size == 0:
        # Nothing to project; avoids a broadcasting error further down.
        return []
    # Promote a lone (x, y) pair to shape (1, 2) so it is not read as two points.
    points_2d = np.atleast_2d(points_2d)

    camera_pos = np.array(camera_pos)

    # Make points 3D (z=0)
    points_3d = np.column_stack([points_2d, np.zeros(len(points_2d))])

    # Get rotation matrix
    roll, pitch, yaw = camera_rot
    R = common_utils.euler_deg_to_mat(roll, pitch, yaw)

    # Transform to camera coordinates: translate, then rotate
    points_camera = (R @ (points_3d - camera_pos).T).T

    # Permute axes into the image convention (x=right, y=down, z=forward)
    right = points_camera[:, 1]
    down = -points_camera[:, 2]
    forward = points_camera[:, 0]

    # Camera intrinsics
    fov_rad = np.radians(camera_fov)
    focal_length_y = camera_height / (2 * np.tan(fov_rad / 2))
    # NOTE(review): scaling by the aspect ratio makes fx = width / (2 tan(fov/2)),
    # so fx != fy for non-square images - confirm this matches the camera model.
    focal_length_x = focal_length_y * (camera_width / camera_height)
    cx = camera_width / 2
    cy = camera_height / 2

    # Project to image
    projected = []
    for x_cam, y_cam, z_cam in zip(right, down, forward):
        if z_cam > 1e-6:  # Point in front of camera
            x_img = (x_cam * focal_length_x / z_cam) + cx
            y_img = (y_cam * focal_length_y / z_cam) + cy
            inside = 0 <= x_img < camera_width and 0 <= y_img < camera_height
            projected.append(((x_img, y_img), inside))
        else:
            # Behind (or on) the image plane: placeholder pixel, flagged as outside
            projected.append(((0, 0), False))

    return projected


def draw_route_and_waypoints_with_config(image, pred_route, pred_waypoints, camera_config, route_color=None):
    """
    Draw the predicted waypoints as dots joined by a polyline on a copy of the image.

    Args:
        image: Image array to draw on (not modified; a copy is returned)
        pred_route: Accepted for interface compatibility; not drawn by this function
        pred_waypoints: Predicted waypoints tensor (N, 2) in vehicle coordinates,
            or None / empty to draw nothing
        camera_config: Dictionary with camera parameters (x, y, z, pitch, yaw, fov)
        route_color: Colour tuple for dots and lines, given in the image's own
            channel order. Defaults to (0, 255, 255) when None.

    Returns:
        Copy of the image with the waypoints drawn
    """
    if route_color is None:
        route_color = (0, 255, 255)
    img_with_viz = image.copy()
    camera_height, camera_width = image.shape[:2]

    # Extract camera parameters from config
    camera_fov = float(camera_config["fov"])
    camera_pos = [camera_config["x"], camera_config["y"], camera_config["z"]]
    camera_rot = [0, camera_config["pitch"], camera_config["yaw"]]  # roll, pitch, yaw

    if pred_waypoints is None or len(pred_waypoints) == 0:
        return img_with_viz

    waypoints_xy = pred_waypoints.detach().cpu().float().numpy()
    projected = project_points_to_image(camera_rot, camera_pos, camera_fov, camera_width, camera_height, waypoints_xy)

    # Filled dot on every waypoint that lands inside the image
    for pt, inside in projected:
        if inside:
            cv2.circle(
                img_with_viz,
                (int(pt[0]), int(pt[1])),
                radius=3,
                color=route_color,
                thickness=-1,  # negative thickness = filled circle
                lineType=cv2.LINE_AA,
            )

    # Connect consecutive waypoints; a segment is skipped unless both ends are visible
    for (pt1, inside1), (pt2, inside2) in zip(projected, projected[1:]):
        if inside1 and inside2:
            cv2.line(
                img_with_viz,
                (int(pt1[0]), int(pt1[1])),
                (int(pt2[0]), int(pt2[1])),
                route_color,
                thickness=2,
                lineType=cv2.LINE_AA,
            )

    return img_with_viz


def draw_target_points_bev(image, target_points, camera_config, tp_color=None):
    """
    Draw the current target point on a copy of the BEV camera image.

    The marker is a filled circle with a slightly larger white circle behind
    it as an outline. Only the 'current' entry is drawn; other keys are ignored.

    Args:
        image: Image array to draw on (not modified; a copy is returned)
        target_points: Dictionary that may contain 'current' -> (x, y) in vehicle
            coordinates. None, a missing key or a None value draws nothing.
        camera_config: Dictionary with camera parameters (x, y, z, pitch, yaw, fov)
        tp_color: Marker colour tuple in the image's own channel order.
            Defaults to (255, 0, 0) when None.

    Returns:
        Copy of the image with the target point drawn
    """
    img_with_targets = image.copy()
    camera_height, camera_width = image.shape[:2]

    # Extract camera parameters from config
    camera_fov = float(camera_config["fov"])
    camera_pos = [camera_config["x"], camera_config["y"], camera_config["z"]]
    camera_rot = [0, camera_config["pitch"], camera_config["yaw"]]

    if target_points is None:
        # Callers may legitimately have no target points; treat as "nothing to draw".
        return img_with_targets

    # (key, colour, radius in pixels) for every marker that should be drawn
    targets_config = [
        ("current", (255, 0, 0) if tp_color is None else tp_color, 9),
    ]

    for key, color, size in targets_config:
        point = target_points.get(key)
        if point is None:
            continue

        # Project the single (x, y) target into the image
        target_point = np.array([[point[0], point[1]]])
        projected = project_points_to_image(camera_rot, camera_pos, camera_fov, camera_width, camera_height, target_point)
        if len(projected) == 0:
            continue

        pt, inside = projected[0]
        if not inside:
            continue

        x, y = int(pt[0]), int(pt[1])
        # White disc first, 2 px larger, so it shows as an outline
        cv2.circle(
            img_with_targets,
            (x, y),
            size + 2,
            (255, 255, 255),
            thickness=-1,
            lineType=cv2.LINE_AA,
        )
        # Filled colored circle on top
        cv2.circle(
            img_with_targets,
            (x, y),
            size,
            color,
            thickness=-1,
            lineType=cv2.LINE_AA,
        )

    return img_with_targets


def draw(image, pred_route=None, pred_waypoints=None, pred_bboxes=None, target_points=None, color=None):
    """
    Return a copy of the image with predicted waypoints and the target point overlaid.

    Uses the module-level `camera_calibration` as camera configuration.

    Args:
        image: Image array to draw on (not modified)
        pred_route: Forwarded to draw_route_and_waypoints_with_config
        pred_waypoints: Predicted waypoints tensor (N, 2), or None to skip
        pred_bboxes: Accepted for interface compatibility; currently unused
        target_points: Dictionary of target points, or None to skip the marker
        color: Colour tuple shared by the waypoints and the target marker;
            None lets each helper fall back to its own default

    Returns:
        New image with the overlays drawn
    """
    camera_config = camera_calibration

    # The helper draws on its own copy, so the caller's image is left untouched.
    processed_image = draw_route_and_waypoints_with_config(image, pred_route, pred_waypoints, camera_config, color)

    # target_points defaults to None; only draw the marker when something was given.
    if target_points is not None:
        processed_image = draw_target_points_bev(processed_image, target_points, camera_config, color)
    return processed_image

In [None]:
import colorsys

# NOTE: this cell relies on `model_input` and `open_loop_inference`, which are
# only created by the widget-data generation cell further down. On a fresh
# kernel that cell has to run first; fail with an explicit message otherwise.
_missing = [name for name in ("model_input", "open_loop_inference") if name not in globals()]
if _missing:
    raise NameError(f"{', '.join(_missing)} not defined - run the model-loading cell before this one")

# Start from a clean frame every time so re-running the cell does not stack overlays.
camera_image = cv2.imread(camera_path)
if camera_image is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(f"Could not read camera image: {camera_path}")
camera_image = cv2.cvtColor(camera_image, cv2.COLOR_BGR2RGB)
anchors = [-25, 0, 25]

num = len(anchors) ** 2

# One evenly spaced hue per anchor combination, as plain-int 0-255 tuples.
rainbow_colors = []
for k in range(num):
    r, g, b = colorsys.hsv_to_rgb(k / num, 0.9, 0.9)
    rgb = (np.array([r, g, b]) * 255).astype(np.uint8)
    rainbow_colors.append(tuple(int(x) for x in rgb))

device = model_input["target_point"].device

# Sweep the target point over every anchor combination (except the origin) and
# overlay each prediction in its own colour on the same image.
for i, (ax, ay) in enumerate(itertools.product(anchors, anchors)):
    if ax == 0 and ay == 0:
        continue
    model_input["speed"] = torch.tensor([[15.0]], device=device)
    # Anchors are ints; cast to float so the tensor is floating point like the
    # target points built in the grid sweep, not int64.
    model_input["target_point"] = torch.tensor([[float(ax), float(ay)]], device=device)
    with torch.no_grad():  # visualisation only, no gradients needed
        model_prediction = open_loop_inference.forward(model_input)

    # Missing entries become None instead of raising on `.cpu()`.
    target_points = {}
    for name, key in (
        ("previous", "target_point_previous"),
        ("current", "target_point"),
        ("next", "target_point_next"),
    ):
        value = model_input.get(key)
        target_points[name] = None if value is None else value.cpu().numpy().squeeze()

    camera_image = draw(
        camera_image,
        model_prediction.pred_route[0],
        model_prediction.pred_future_waypoints[0],
        target_points=target_points,
        color=rainbow_colors[i],
    )
plt.imshow(camera_image)
plt.axis("off")
plt.show()
# Target points of the last anchor combination, for a quick sanity check
print(target_points)

In [None]:
def convert_to_image_space(points_2d, camera_config, camera_width, camera_height):
    """
    Map 2D points to pixel coordinates, dropping the ones that are not visible.

    Args:
        points_2d: numpy array of shape (N, 2) with (x, y) coordinates
        camera_config: Dictionary with camera parameters (fov, x, y, z, pitch, yaw)
        camera_width: Image width
        camera_height: Image height

    Returns:
        List with one entry per input point: [x, y] as floats in image space,
        or None when the projection falls outside the image
    """
    fov = float(camera_config["fov"])
    position = [camera_config[axis] for axis in ("x", "y", "z")]
    rotation = [0, camera_config["pitch"], camera_config["yaw"]]  # roll is fixed to 0

    projections = project_points_to_image(rotation, position, fov, camera_width, camera_height, points_2d)

    # Keep the 1:1 correspondence with the input: invisible points become None
    return [[float(px), float(py)] if visible else None for (px, py), visible in projections]


def generate_target_point_grid(x_range, y_range, step=1.0):
    """
    Generate a regular grid of target points, excluding the origin.

    Both range bounds are inclusive, and no generated coordinate exceeds the
    upper bound: the last value on an axis is the largest `min + k * step`
    that is still <= max (up to a small floating-point tolerance).

    Args:
        x_range: (min, max) for x coordinate
        y_range: (min, max) for y coordinate
        step: Grid step size in meters; must be positive

    Returns:
        List of [x, y] float pairs, ordered by x first and then y

    Raises:
        ValueError: If step is not positive
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    def _axis_values(bounds):
        """Return min, min + step, ... up to and including max when it lies on the grid."""
        low, high = bounds
        # Count the steps explicitly instead of using np.arange(low, high + step, step):
        # that form overshoots `high` when the span is not a multiple of `step`
        # and can add an extra element through float rounding.
        count = int(np.floor((high - low) / step + 1e-9)) + 1
        return low + step * np.arange(max(count, 0))

    x_values = _axis_values(x_range)
    y_values = _axis_values(y_range)

    # Fractional steps rarely hit 0.0 exactly, so compare against a tolerance.
    origin_tol = 1e-9 * step

    grid_points = []
    for x in x_values:
        for y in y_values:
            if abs(x) <= origin_tol and abs(y) <= origin_tol:  # Skip origin
                continue
            grid_points.append([float(x), float(y)])

    return grid_points

In [None]:
# Generate JSON data for interactive widget
camera_height, camera_width = camera_image.shape[:2]

# Generate target point grid (1m spacing from -20m to 20m in both x and y)
target_points_grid = generate_target_point_grid(x_range=(-20, 20), y_range=(-20, 20), step=1.0)

print(f"Generated {len(target_points_grid)} target points")

# Pixel position of every target point (None when outside the image). This only
# depends on the camera, so it is computed once and shared by both models.
target_points_image = [
    convert_to_image_space(np.array([[tx, ty]]), camera_calibration, camera_width, camera_height)[0]
    for tx, ty in target_points_grid
]

# Initialize result structure
widget_data = {"image_width": camera_width, "image_height": camera_height, "misaligned": [], "aligned": []}

# Process both models; "key" selects the list in widget_data that receives the results
model_configs = [
    {
        "checkpoint_dir": "outputs/training/700_regnety_032/010_postrain32_0/250913_153308",
        "key": "misaligned",
        "use_tfv5": True,
    },
    {
        "checkpoint_dir": "outputs/training/733_scaled_regnety/012_postrain32_2/251025_182334",
        "key": "aligned",
        "use_tfv5": False,
    },
]

for model_config in model_configs:
    checkpoint_dir = model_config["checkpoint_dir"]
    data_key = model_config["key"]

    print(f"\nProcessing {data_key} model...")

    # Load model
    open_loop_config = OpenLoopConfig()
    open_loop_config.strict_weight_load = False

    with open(os.path.join(checkpoint_dir, "config.json"), encoding="utf-8") as f:
        json_config = json.load(f)

    training_config = TrainingConfig(json_config)
    training_config.use_tfv5_planning_decoder = model_config["use_tfv5"]

    open_loop_inference = OpenLoopInference(
        config_training=training_config,
        config_open_loop=open_loop_config,
        model_path=checkpoint_dir,
        device=torch.device("cuda:0"),
        prefix="model",
    )

    # Load input data (fresh copy per model, since the loop below overwrites entries).
    # NOTE(security): torch.load unpickles the file - only load logs you trust.
    model_input = torch.load(input_path)
    device = model_input["target_point"].device

    # Inference only: disabling autograd avoids building a graph for every forward pass.
    with torch.no_grad():
        # Process each target point
        for i, (tx, ty) in enumerate(target_points_grid):
            if i % 100 == 0:
                print(f"  Processing target point {i + 1}/{len(target_points_grid)}")

            # Set target point and speed
            model_input["speed"] = torch.tensor([[10.0]], device=device)
            model_input["target_point"] = torch.tensor([[tx, ty]], device=device)

            # Run inference
            model_prediction = open_loop_inference.forward(model_input)

            # Convert predictions to image space
            pred_waypoints = model_prediction.pred_future_waypoints[0].detach().cpu().float().numpy()
            waypoints_img = convert_to_image_space(pred_waypoints, camera_calibration, camera_width, camera_height)

            # Store data
            widget_data[data_key].append(
                {
                    "target_point_world": [tx, ty],
                    "target_point_image": target_points_image[i],
                    "waypoints_image": waypoints_img,
                }
            )

    print(f"  Completed {data_key} model: {len(widget_data[data_key])} entries")

# Save to JSON
output_json_path = os.path.join(input_root, "widget_data.json")
with open(output_json_path, "w", encoding="utf-8") as f:
    json.dump(widget_data, f, indent=2)

print(f"\nSaved widget data to: {output_json_path}")
print(f"Total misaligned entries: {len(widget_data['misaligned'])}")
print(f"Total aligned entries: {len(widget_data['aligned'])}")

In [None]:
# Visualize a sample to verify the data looks correct
import matplotlib.pyplot as plt

# Side-by-side sanity check: one panel per model, drawn from the stored widget data
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

for model_name, ax in zip(["misaligned", "aligned"], axes):
    # Background frame and panel title
    ax.imshow(camera_image)
    ax.set_title(f"{model_name.capitalize()} Model - Sample Predictions")

    # Every 10th entry is enough to judge the data and keeps the plot readable
    for entry in widget_data[model_name][::10]:
        # Target point marker (skipped when it projected outside the image)
        target_px = entry["target_point_image"]
        if target_px is not None:
            tx, ty = target_px
            ax.plot(tx, ty, "ro", markersize=8, markeredgecolor="white", markeredgewidth=2)

        # Polyline through the waypoints that are visible in the image
        visible = [w for w in entry["waypoints_image"] if w is not None]
        if visible:
            xs = [px for px, _ in visible]
            ys = [py for _, py in visible]
            ax.plot(xs, ys, "b-", linewidth=2, alpha=0.6)

    ax.axis("off")

plt.tight_layout()
plt.show()

# First stored entry of each model, to eyeball the JSON structure
print("\nData structure preview:")
print(f"Misaligned sample entry: {widget_data['misaligned'][0]}")
print(f"\nAligned sample entry: {widget_data['aligned'][0]}")