# Bounding Box Overlay Dataset Processor

A custom LeRobot processor that overlays bounding box annotations onto image observations.


In [None]:
from dataclasses import dataclass, field
import torch
import numpy as np
from typing import Any
from PIL import Image, ImageDraw

from lerobot.processor.pipeline import ObservationProcessorStep, ProcessorStepRegistry
from lerobot.configs.types import PipelineFeatureType, PolicyFeature, FeatureType
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES


@dataclass
@ProcessorStepRegistry.register(name="bounding_box_overlay_processor")
class BoundingBoxOverlayProcessor(ObservationProcessorStep):
    """
    A processor that overlays bounding boxes onto image observations.

    The annotated image replaces the original entry in the observation dict.
    Images may be ``torch.Tensor`` or ``np.ndarray`` in any of the layouts
    ``(B, C, H, W)``, ``(C, H, W)``, ``(H, W, C)`` or ``(H, W)``. Floating-point
    images are treated as normalized to ``[0, 1]``; all other dtypes are
    treated as 0-255 pixel values.

    Attributes:
        bbox_key: The key in the observation dict where bounding box data is stored.
                  Either a sequence of boxes (applied to the ``OBS_IMAGE`` entry) or a
                  dict mapping image keys to sequences of boxes.
                  Each bbox should be [x1, y1, x2, y2] or [x1, y1, x2, y2, label].
        box_color: RGB tuple for bounding box color (default: red).
        box_width: Width of the bounding box lines in pixels.
        remove_bbox_key: Whether to remove the bbox_key from observations after processing.
    """

    bbox_key: str = "bounding_boxes"
    box_color: tuple[int, int, int] = (255, 0, 0)  # Red
    box_width: int = 2
    remove_bbox_key: bool = True

    def observation(self, obs: dict[str, Any]) -> dict[str, Any]:
        """
        Process observation by drawing bounding boxes on images.

        Args:
            obs: Observation dictionary containing images and bounding boxes.

        Returns:
            A shallow copy of ``obs`` in which every image that has boxes is
            replaced by its annotated version. The input dict is not mutated.
        """
        new_obs = obs.copy()

        # Nothing to draw when no bounding box entry is present.
        if self.bbox_key not in obs:
            return new_obs

        bboxes = obs[self.bbox_key]

        if isinstance(bboxes, dict):
            # Multi-image case: one sequence of boxes per image key.
            for img_key, img_bboxes in bboxes.items():
                target_key = self._resolve_image_key(new_obs, img_key)
                if target_key is not None:
                    new_obs[target_key] = self._draw_boxes_on_image(new_obs[target_key], img_bboxes)
        elif OBS_IMAGE in new_obs and isinstance(bboxes, (list, tuple, np.ndarray, torch.Tensor)):
            # Single-image case: the boxes belong to the OBS_IMAGE entry.
            new_obs[OBS_IMAGE] = self._draw_boxes_on_image(new_obs[OBS_IMAGE], bboxes)

        # Remove bounding box data if requested.
        if self.remove_bbox_key:
            new_obs.pop(self.bbox_key, None)

        return new_obs

    def _resolve_image_key(self, obs: dict[str, Any], img_key: Any) -> str | None:
        """Map a key of the bounding-box dict to the observation key of its image.

        The ``OBS_IMAGES``-prefixed form is tried first. If that is absent, the
        key is used verbatim when it already names an array/tensor entry (this
        is what lets a dict keyed by ``OBS_IMAGE`` work). Returns ``None`` when
        no matching image exists.
        """
        name = str(img_key)
        prefixed = name if name.startswith(OBS_IMAGES) else f"{OBS_IMAGES}.{name}"
        if prefixed in obs:
            return prefixed
        if name != self.bbox_key and isinstance(obs.get(name), (torch.Tensor, np.ndarray)):
            return name
        return None

    def _draw_boxes_on_image(self, img_data: Any, bboxes: Any) -> torch.Tensor:
        """
        Draw bounding boxes on a single image or on every image of a batch.

        Args:
            img_data: Image data (``torch.Tensor`` or ``np.ndarray``) shaped
                ``(B, C, H, W)``, ``(C, H, W)``, ``(H, W, C)`` or ``(H, W)``.
            bboxes: Sequence of bounding boxes, each as [x1, y1, x2, y2] or
                [x1, y1, x2, y2, label]. For batched input the same boxes are
                drawn on every frame.

        Returns:
            Annotated image as a ``torch.Tensor`` with the same shape as the
            input. Tensor inputs keep their dtype and device; floating-point
            arrays keep their dtype; other arrays come back as ``uint8``.

        Raises:
            ValueError: If the image does not have 2, 3 or 4 dimensions, or if
                its channel count is not 1, 3 or 4.
        """
        is_tensor = isinstance(img_data, torch.Tensor)

        if bboxes is None or len(bboxes) == 0:
            # No boxes to draw: hand the pixels back untouched.
            if is_tensor:
                return img_data
            return torch.from_numpy(np.ascontiguousarray(img_data))

        if is_tensor:
            source = img_data.detach().cpu()
            if source.is_floating_point():
                # float16/bfloat16 are widened so NumPy can represent them.
                source = source.float()
            img_np = source.numpy()
        else:
            img_np = np.asarray(img_data)

        if img_np.ndim not in (2, 3, 4):
            raise ValueError(
                f"Expected an image with 2, 3 or 4 dimensions, got shape {img_np.shape}"
            )

        is_float = np.issubdtype(img_np.dtype, np.floating)
        is_batched = img_np.ndim == 4
        frames = img_np if is_batched else img_np[np.newaxis]

        # Annotate every frame so the batch dimension is preserved.
        annotated = np.stack([self._annotate_frame(frame, bboxes, is_float) for frame in frames])
        if not is_batched:
            annotated = annotated[0]
        if is_float:
            annotated = annotated.astype(img_np.dtype, copy=False)

        result = torch.from_numpy(np.ascontiguousarray(annotated))
        if is_tensor:
            result = result.to(device=img_data.device, dtype=img_data.dtype)
        return result

    def _annotate_frame(self, frame: np.ndarray, bboxes: Any, is_float: bool) -> np.ndarray:
        """Draw ``bboxes`` on one un-batched frame and return it in its original layout.

        The result is ``float32`` in ``[0, 1]`` when ``is_float`` is true and
        ``uint8`` otherwise.
        """
        # Layout detection. A leading dimension of 1, 3 or 4 is taken to be the
        # channel axis, so a channel-last image whose height is 1, 3 or 4 is
        # indistinguishable from channel-first and is treated as such.
        if frame.ndim == 2:
            layout = "hw"
            pixels = frame
        elif frame.shape[0] in (1, 3, 4):
            layout = "chw"
            pixels = np.transpose(frame, (1, 2, 0))
        else:
            layout = "hwc"
            pixels = frame

        if pixels.ndim == 3 and pixels.shape[2] not in (1, 3, 4):
            raise ValueError(
                f"Expected 1, 3 or 4 image channels, got frame of shape {frame.shape}"
            )

        # Convert to 8-bit pixels; clip and round so out-of-range floats do not
        # wrap around and untouched pixels survive the round trip.
        if is_float:
            pixels = np.rint(np.clip(pixels, 0.0, 1.0) * 255.0).astype(np.uint8)
        else:
            pixels = pixels.astype(np.uint8)

        # Single-channel data is handed to PIL as a 2-D array (mode "L").
        # Index the channel instead of squeeze() so H == 1 or W == 1 survives.
        if pixels.ndim == 3 and pixels.shape[2] == 1:
            pixels = pixels[:, :, 0]
        # PIL infers "L", "RGB" or "RGBA" from the uint8 array shape.
        pil_img = Image.fromarray(np.ascontiguousarray(pixels))

        ink = self._ink_for_mode(pil_img.mode)
        draw = ImageDraw.Draw(pil_img)
        for bbox in bboxes:
            if len(bbox) < 4:
                continue
            x1, y1, x2, y2 = (float(value) for value in bbox[:4])
            # PIL rejects rectangles whose corners are given in reverse order.
            left, right = sorted((x1, x2))
            top, bottom = sorted((y1, y2))
            draw.rectangle([left, top, right, bottom], outline=ink, width=self.box_width)

            # Optionally draw label if provided.
            if len(bbox) > 4:
                label = bbox[4]
                if isinstance(label, (torch.Tensor, np.generic)):
                    label = label.item()
                # Keep the label on the canvas for boxes touching the top edge.
                draw.text((left, max(top - 10, 0)), str(label), fill=ink)

        out = np.array(pil_img)

        # Restore the channel axis that was dropped for single-channel images.
        if layout != "hw" and out.ndim == 2:
            out = out[:, :, np.newaxis]
        if is_float:
            out = out.astype(np.float32) / 255.0
        if layout == "chw":
            out = np.transpose(out, (2, 0, 1))
        return out

    def _ink_for_mode(self, mode: str) -> Any:
        """Return ``box_color`` in a form PIL accepts for an image of ``mode``."""
        # tuple() also covers a box_color that arrives as a list (e.g. from JSON).
        color = tuple(int(channel) for channel in self.box_color)
        if mode == "L" and len(color) >= 3:
            # Single-band images take an integer ink; use the ITU-R 601-2 luma
            # weights that PIL itself uses for RGB -> L conversion.
            red, green, blue = color[:3]
            return (red * 299 + green * 587 + blue * 114) // 1000
        return color

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        Declare feature transformations - images remain the same shape, bbox data is removed.

        Args:
            features: Input feature specifications.

        Returns:
            Modified feature specifications (bbox key removed if remove_bbox_key=True).
            The input mapping and its observation sub-mapping are not mutated.
        """
        new_features = features.copy()

        # Remove bounding box feature if it exists and we're configured to remove it.
        if self.remove_bbox_key and PipelineFeatureType.OBSERVATION in new_features:
            obs_features = new_features[PipelineFeatureType.OBSERVATION].copy()
            obs_features.pop(self.bbox_key, None)
            new_features[PipelineFeatureType.OBSERVATION] = obs_features

        # Image features remain unchanged (same dimensions).
        return new_features

    def get_config(self) -> dict[str, Any]:
        """Return configuration for serialization."""
        return {
            "bbox_key": self.bbox_key,
            "box_color": self.box_color,
            "box_width": self.box_width,
            "remove_bbox_key": self.remove_bbox_key,
        }


# Bounding Box Overlay Processor

This processor takes bounding box data and overlays it onto image observations, then replaces the original images with the annotated versions.

## Features:
- Handles single images (`observation.image`) or multiple images (`observation.images.*`)
- Supports both batched and unbatched image formats
- Works with normalized (0-1) or unnormalized (0-255) images
- Automatically handles channel-first (C,H,W) or channel-last (H,W,C) formats
- Removes bounding box data from observations after processing (configurable)

## Expected Bounding Box Format:
- For single image: `{"bounding_boxes": [[x1, y1, x2, y2], [x1, y1, x2, y2, label], ...]}`
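- For multiple images: `{"bounding_boxes": {"camera1": [[x1, y1, x2, y2], ...], "camera2": [...]}}` — each camera name is matched to the observation key `observation.images.<camera name>` (a key that already starts with `observation.images` is used as-is)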


In [None]:
# Example 2: Using with multiple cameras
# This cell is self-contained: it imports the helpers and builds the processor
# itself instead of relying on names defined in another example cell.
from lerobot.processor.converters import create_transition
from lerobot.processor import TransitionKey

# The processor resolves each camera name in "bounding_boxes" to the
# observation key f"{OBS_IMAGES}.<camera>", so the images are stored under
# exactly those keys.
observation_multi = {
    f"{OBS_IMAGES}.front": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
    f"{OBS_IMAGES}.wrist": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
    "bounding_boxes": {
        "front": [[50, 50, 150, 150], [200, 200, 300, 300]],
        "wrist": [[100, 100, 200, 200, "gripper"]],
    },
}

# Create processor
bbox_processor = BoundingBoxOverlayProcessor(
    bbox_key="bounding_boxes",
    box_color=(255, 0, 0),  # Red
    box_width=3,
)

# Process multiple camera observations
transition_multi = create_transition(observation=observation_multi)
processed_multi = bbox_processor(transition_multi)

print("Multi-camera processing complete")
print("Keys:", list(processed_multi[TransitionKey.OBSERVATION].keys()))


## Usage in Dataset Recording

When recording a dataset with bounding box data, you can use this processor to create annotated images:

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import create_initial_features, aggregate_pipeline_dataset_features

# Define your robot's features including bounding boxes
initial_features = create_initial_features(
    observation={
        "pixels": (480, 640, 3),  # Camera image
        "bounding_boxes": list,    # Bounding box data
        "agent_pos": (7,),         # Robot state
    },
    action={"joint_positions": (7,)}
)

# Create pipeline with bounding box overlay
recording_pipeline = PolicyProcessorPipeline(
    steps=[
        BoundingBoxOverlayProcessor(bbox_key="bounding_boxes"),
        VanillaObservationProcessorStep(),
    ],
    to_transition=batch_to_transition,
    to_output=transition_to_batch
)

# Get final features after processing
final_features = aggregate_pipeline_dataset_features(
    pipeline=recording_pipeline,
    initial_features=initial_features,
    use_videos=True
)

# Create dataset with processed features
dataset = LeRobotDataset.create(
    repo_id="user/dataset_with_bbox_overlay",
    features=final_features,
    # ... other parameters
)
```

The bounding boxes will be drawn on the images before they're saved to the dataset!


## Customization Options

You can customize the appearance and behavior of the bounding box overlay:

```python
# Custom colors and line widths
processor_green = BoundingBoxOverlayProcessor(
    bbox_key="detections",          # Custom key name
    box_color=(0, 255, 0),          # Green boxes
    box_width=5,                    # Thicker lines
    remove_bbox_key=False           # Keep bbox data in observations
)

# Multiple colors for different object types (requires custom logic)
# You can extend the processor to support different colors per bbox
```

## Key Features:
- ✅ Automatically handles batched/unbatched images
- ✅ Works with channel-first (C,H,W) and channel-last (H,W,C) formats
- ✅ Preserves normalized (0-1) or unnormalized (0-255) ranges
- ✅ Supports single or multiple camera setups
- ✅ Optional text labels on bounding boxes
- ✅ Integrates seamlessly with LeRobot processor pipelines
- ✅ Registered in ProcessorStepRegistry for easy configuration


In [None]:
# Example 3: Using in a PolicyProcessorPipeline
# Builds a two-step pipeline (bounding-box overlay, then the vanilla
# observation step), pushes one sample batch through it and prints which keys
# are present in the output.
from lerobot.processor.pipeline import PolicyProcessorPipeline
from lerobot.processor import VanillaObservationProcessorStep
from lerobot.processor.converters import batch_to_transition, transition_to_batch

# Create a pipeline that:
# 1. Draws bounding boxes on images
# 2. Processes observations (converts to LeRobot format)
# NOTE(review): the overlay step runs first and only looks up the OBS_IMAGE /
# f"{OBS_IMAGES}.<name>" keys, while the sample below stores its image under
# "pixels". Presumably no boxes are drawn with this step order -- confirm, and
# consider running the overlay after the step that renames "pixels".
pipeline = PolicyProcessorPipeline(
    steps=[
        BoundingBoxOverlayProcessor(bbox_key="bounding_boxes"),
        VanillaObservationProcessorStep(),
    ],
    name="bbox_overlay_pipeline",
    to_transition=batch_to_transition,
    to_output=transition_to_batch
)

# Create sample batch data: one random 480x640 RGB uint8 image, two boxes
# without labels, and a 3-element state vector.
# NOTE(review): verify that batch_to_transition places these un-prefixed keys
# in the observation part of the transition -- not visible from this notebook.
sample_batch = {
    "pixels": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
    "bounding_boxes": [[100, 150, 200, 250], [300, 300, 400, 400]],
    "agent_pos": np.array([0.1, 0.2, 0.3]),
}

# Process through pipeline
processed_batch = pipeline(sample_batch)

# Report which keys are present in the pipeline output.
print("Pipeline processing complete!")
print("Output keys:", list(processed_batch.keys()))
print("Has observation.image:", "observation.image" in processed_batch)
print("Has observation.state:", "observation.state" in processed_batch)
print("Bounding boxes removed:", "bounding_boxes" not in processed_batch)


In [None]:
# Example 1: Using the processor with a single image
from lerobot.processor.converters import create_transition
from lerobot.processor import TransitionKey

# Create a sample observation with image and bounding boxes
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
bboxes = [
    [100, 100, 200, 200],  # x1, y1, x2, y2
    [300, 300, 400, 450, "object"],  # with label
]

# The processor looks the single image up under OBS_IMAGE, so the image must
# be stored under that key for the boxes to be drawn on it.
observation = {
    OBS_IMAGE: image,
    "bounding_boxes": bboxes,
}

# Create processor
bbox_processor = BoundingBoxOverlayProcessor(
    bbox_key="bounding_boxes",
    box_color=(255, 0, 0),  # Red
    box_width=3,
)

# Process the observation
transition = create_transition(observation=observation)
processed_transition = bbox_processor(transition)

# The processed observation now has bounding boxes drawn on the image
processed_obs = processed_transition[TransitionKey.OBSERVATION]
print("Keys in processed observation:", list(processed_obs.keys()))
print("Bounding boxes removed:", "bounding_boxes" not in processed_obs)
