In [None]:
# Import required libraries
import torch
from PIL import Image
from rex_omni import RexOmniWrapper, RexOmniVisualize
import matplotlib.pyplot as plt
import numpy as np

print("✅ All libraries imported successfully!")


In [None]:
# Initialize Rex Omni wrapper
model_path = "IDEA-Research/Rex-Omni"  # Replace with your model path

print("🚀 Initializing Rex Omni model...")

rex_model = RexOmniWrapper(
    model_path=model_path,
    backend="transformers",  # Choose "transformers" or "vllm"
    max_tokens=2048,
    temperature=0.0,
    top_p=0.05,
    top_k=1,
    repetition_penalty=1.05,
)

print("✅ Model initialized successfully!")


In [None]:
# Load image
image_path = "examples/test_images/cafe.jpg"  # Replace with your image path

try:
    image = Image.open(image_path).convert("RGB")
    print(f"✅ Image loaded successfully!")
    print(f"📏 Image size: {image.size}")
    
    # Display the image
    plt.figure(figsize=(10, 8))
    plt.imshow(image)
    plt.axis('off')
    plt.title('Input Image for Object Detection')
    plt.show()
    
except FileNotFoundError:
    print(f"❌ Image not found at: {image_path}")
    print("Please update the image_path variable with a valid image path")
    
    # Create a dummy image for demonstration
    print("Creating a dummy image for demonstration...")
    image = Image.new('RGB', (640, 480), color='lightblue')
    print("✅ Using dummy image")


In [None]:
# Define categories to detect
categories = [
    "man",
    "woman", 
    "yellow flower",
    "sofa",
    "robot-shope light",
    "blanket",
    "microwave",
    "laptop",
    "cup",
    "white chair",
    "lamp",
]

print("🎯 Categories to detect:")
for i, category in enumerate(categories, 1):
    print(f"  {i}. {category}")

print(f"\n📊 Total categories: {len(categories)}")


In [None]:
# Run object detection inference
print("🔍 Running object detection...")

results = rex_model.inference(
    images=image, 
    task="detection", 
    categories=categories
)

print("✅ Inference completed!")

# Display results
result = results[0]
if result["success"]:
    predictions = result["extracted_predictions"]
    raw_output = result["raw_output"]
    
    print(f"🎯 Found {sum(len(preds) for preds in predictions.values())} objects")
    print("\n📋 Detection Results:")
    
    for category, detections in predictions.items():
        if detections:
            print(f"\n  {category.upper()}:")
            for i, detection in enumerate(detections):
                coords = detection.get("coords", [])
                if len(coords) == 4:
                    x0, y0, x1, y1 = coords
                    print(f"    Box {i+1}: ({x0:.1f}, {y0:.1f}, {x1:.1f}, {y1:.1f})")
else:
    print(f"❌ Inference failed: {result['error']}")
    print("Raw output:", result.get('raw_output', 'No output available'))


In [None]:
# Visualize detection results
if result["success"] and predictions:
    print("🎨 Creating visualization...")
    
    # Create visualization using Rex Omni's built-in function
    vis_image = RexOmniVisualize(
        image=image,
        predictions=predictions,
        font_size=20,
        draw_width=5,
        show_labels=True,
    )
    
    # Display the visualization
    plt.figure(figsize=(15, 10))
    plt.imshow(vis_image)
    plt.axis('off')
    plt.title('Object Detection Results', fontsize=16, fontweight='bold')
    plt.show()
    
    # Save visualization (optional)
    try:
        output_path = "detection_results.jpg"
        vis_image.save(output_path)
        print(f"💾 Visualization saved to: {output_path}")
    except Exception as e:
        print(f"⚠️ Could not save image: {e}")
        
else:
    print("❌ No predictions to visualize")


In [None]:
# Example: Using VLLM backend (commented out - uncomment to use)
# rex_model_vllm = RexOmniWrapper(
#     model_path="IDEA-Research/Rex-Omni",
#     backend="vllm",  # Use VLLM for faster inference
#     max_tokens=2048,
#     temperature=0.0,
#     top_p=0.05,
#     top_k=1,
#     repetition_penalty=1.05,
#     # VLLM-specific parameters
#     gpu_memory_utilization=0.8,
#     tensor_parallel_size=1,
# )

print("💡 VLLM backend provides faster inference for production use cases")
print("💡 Uncomment the code above to use VLLM instead of transformers")


In [None]:
# Examples of other supported tasks (commented out)

# 1. Keypoint Detection with Skeleton Visualization
# keypoint_results = rex_model.inference(
#     images=image,
#     task="keypoint",
#     categories=["person"]
# )

# 2. OCR with Bounding Boxes
# ocr_results = rex_model.inference(
#     images=image,
#     task="ocr_box",
#     categories=["text"]
# )

# 3. Visual Prompting (Point-based)
# pointing_results = rex_model.inference(
#     images=image,
#     task="pointing",
#     categories=["object at point"]
# )

print("📝 Supported tasks:")
print("  • detection - Object detection with bounding boxes")
print("  • keypoint - Keypoint detection with skeleton visualization")
print("  • ocr_box - OCR with bounding boxes")
print("  • pointing - Visual prompting with points")
print("  • visual_prompting - Advanced visual prompting")

print("\n🎨 Visualization features:")
print("  • Automatic skeleton drawing for keypoints")
print("  • Color-coded bounding boxes")
print("  • Category labels")
print("  • Customizable fonts and line widths")


# Rex Omni Object Detection Tutorial

This notebook demonstrates how to use Rex Omni for object detection tasks (tasks that output in box format)

## Features
- Easy-to-use API with automatic model initialization
- Support for both Transformers and VLLM backends
- Built-in visualization capabilities
- Flexible configuration options

## Step 1: Initialize Rex Omni Model

In [1]:
# Import required libraries
import torch
from PIL import Image
from rex_omni import RexOmniWrapper, visualize_predictions
import matplotlib.pyplot as plt
import numpy as np

# Rex-Omni supports both Transformers and VLLM backends by switching the backend parameter.
model_path = "IDEA-Research/Rex-Omni"  # Replace with your model path

print("🚀 Initializing Rex Omni model...")

rex_model = RexOmniWrapper(
    model_path=model_path,
    backend="transformers",  # Choose "transformers" or "vllm"
    max_tokens=2048,
    temperature=0.0,
    top_p=0.05,
    top_k=1,
    repetition_penalty=1.05,
)

ImportError: cannot import name 'visualize_predictions' from 'rex_omni' (/comp_robot/jiangqing/projects/2023/research/R1/QwenSFTOfficial/open_source/rex_omni/__init__.py)

## Step 2: Object Detection Example

Let's load an image for object detection task. You can replace this with your own image path.

In [None]:
# Load image
image_path = "examples/test_images/cafe.jpg"  # Replace with your image path
image = Image.open(image_path).convert("RGB")
print(f"✅ Image loaded successfully!")
print(f"📏 Image size: {image.size}")

# Display the image
plt.figure(figsize=(10, 8))
plt.imshow(image)
plt.axis('off')
plt.title('Input Image for Object Detection')
plt.show()

# Define categories to detect
categories = [
    "man",
    "woman",
    "yellow flower",
    "sofa",
    "robot-shope light",
    "blanket",
    "microwave",
    "laptop",
    "cup",
    "white chair",
    "lamp",
]

# inference
results = rex_model.inference(images=image, task="detection", categories=categories)

lets visualize the predicted results using RexOmniVisualize function