In [1]:
import sys
import os
import ipywidgets as widgets
from IPython.display import display, clear_output
from PIL import Image
import matplotlib.pyplot as plt
from ultralytics import YOLOE

# Make the project root (for `src.*` imports) and the nuScenes data folder
# importable. Guarded so that re-running this cell does not keep appending
# duplicate entries to sys.path.
for _extra_path in (os.path.abspath('..'), os.path.abspath('../nuscenes_data')):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
del _extra_path

from src.data.loader import NuScenesLoader

# 1. Initialize Loader
print("⏳ Loading NuScenes...")
loader = NuScenesLoader(dataroot="../nuscenes_data",
                        version="v1.0-trainval")
samples = loader.get_all_samples()
# The viewer cell indexes `samples` through a slider bounded by len(samples) - 1;
# fail fast here with a clear message instead of a confusing widget error later.
assert len(samples) > 0, "NuScenesLoader returned no samples - check dataroot/version"
print(f"✅ Loaded {len(samples)} frames.")

# 2. Initialize YOLOE-11 (Large)
print("⏳ Loading YOLOE Model...")
model = YOLOE("yoloe-11l-seg.pt")

# 3. Define WOD-E2E Taxonomy (the open-vocabulary prompts we hunt for).
# Each string becomes its own class, in list order. Near-synonyms
# (e.g. "person" / "pedestrian") are deliberately separate prompts and may
# both fire on the same object.
custom_classes = [
    # 1. VRUs (Vulnerable Road Users)
    "person", "pedestrian", "child",
    "cyclist", "bicyclist", "motorcyclist", "scooter rider",
    "construction worker", "worker in safety vest", "police officer",

    # 2. Vehicles (Specialized)
    "car", "pickup truck", "suv", "van", "sedan", "coupe",
    "truck", "semi truck", "trailer", "cement mixer",
    "bus", "school bus",
    "police car", "police vehicle", "ambulance", "fire truck",
    "construction vehicle", "bulldozer", "excavator", "forklift",
    "road sweeper", "street cleaner",

    # 3. Construction & Barriers
    "traffic cone", "orange cone", "traffic drum",
    "construction barrel", "orange drum",  # Crucial for Highway Construction
    "traffic barrier", "concrete barrier", "jersey barrier",
    "road work sign", "temporary sign",
    "construction fence", "safety fence",
    "scaffolding", "construction scaffolding",
    "construction barricade",

    # 4. Hazards / Debris (FOD)
    "debris", "cardboard box", "tire",
    "plastic bag", "tree branch", "large rock",
    "puddle",

    # 5. Traffic Control
    "traffic light", "traffic signal", "red light",
    "stop sign", "yield sign", "speed limit sign",
    "pedestrian crossing sign", "school zone sign",
    "crosswalk",
]

# Class indices follow list order, so a duplicated prompt would produce two
# indices with the same label; guard against that when editing the list.
assert len(custom_classes) == len(set(custom_classes)), "duplicate prompt in custom_classes"

# Compute text-prompt embeddings for the taxonomy and install them as the
# model's class vocabulary.
model.set_classes(custom_classes, model.get_text_pe(custom_classes))
print("✅ Model Ready with Custom Taxonomy.")

⏳ Loading NuScenes...
Loading NuScenes v1.0-trainval database from ../nuscenes_data...
✅ Loaded 34149 frames.
⏳ Loading YOLOE Model...
✅ Model Ready with Custom Taxonomy.


In [None]:
# --- Widgets ---

# 1. Frame selector: spans every loaded sample (0 .. len(samples) - 1).
#    max(..., 0) keeps the upper bound valid if no samples were loaded
#    (an upper bound below `min` would be rejected by the widget).
#    continuous_update=False: every change triggers a full YOLOE inference,
#    so only fire when the slider is released, not for every value while dragging.
frame_slider = widgets.IntSlider(
    value=50,
    min=0,
    max=max(len(samples) - 1, 0),
    step=1,
    description='Frame Idx:',
    continuous_update=False,
    layout=widgets.Layout(width='50%')
)

# 2. Camera selector (the six nuScenes camera channels).
cam_dropdown = widgets.Dropdown(
    options=[
        "CAM_FRONT", "CAM_FRONT_LEFT", "CAM_FRONT_RIGHT",
        "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT"
    ],
    value="CAM_FRONT",
    description='Camera:',
)

# 3. Confidence threshold, passed as `conf` to model.predict.
#    Also released-only, for the same inference-cost reason as the frame slider.
conf_slider = widgets.FloatSlider(
    value=0.15,
    min=0.05,
    max=1.0,
    step=0.05,
    description='Min Conf:',
    continuous_update=False,
    readout=True
)

# Output area that update_view draws into.
output_plot = widgets.Output()

# --- Logic ---

def update_view(change=None):
    """Run YOLOE on the currently selected frame/camera and redraw the plot.

    Registered as the ``observe`` callback of the frame, camera and confidence
    widgets, and also called once directly for the initial render. It always
    reads the current widget values, so it does not matter which widget fired.

    Args:
        change: ipywidgets change notification. Unused; the default ``None``
            allows calling the function directly.
    """
    idx = frame_slider.value
    cam_name = cam_dropdown.value
    conf_thresh = conf_slider.value

    with output_plot:
        clear_output(wait=True)
        fig = None
        try:
            token = samples[idx]
            # Resolve the path inside the try/output context so a missing
            # sample or camera is reported in the plot area rather than
            # escaping the widget callback.
            img_path = loader.get_camera_paths(token)[cam_name]

            # `with` closes the image file handle once inference has used it.
            with Image.open(img_path) as img:
                results = model.predict(img, conf=conf_thresh, verbose=False)

            # Results.plot() returns the annotated image in BGR channel order;
            # reverse the channel axis because matplotlib expects RGB
            # (otherwise red/blue are swapped, e.g. orange cones render blue).
            annotated_rgb = results[0].plot(line_width=2, font_size=12)[..., ::-1]

            # Build the title before creating the figure so a failure here
            # cannot leave a half-drawn figure open.
            title = f"Frame {idx} | {cam_name} | {loader.get_scene_description(token)}"

            fig, ax = plt.subplots(figsize=(14, 8))
            ax.imshow(annotated_rgb)
            ax.axis('off')
            ax.set_title(title)
            plt.show()

        except Exception as e:
            # Best-effort viewer: drop any partially built figure so it is not
            # shown by a later plt.show(), report the error, keep the UI alive.
            if fig is not None:
                plt.close(fig)
            print(f"Error: {e}")

# Re-run inference whenever any of the three controls changes value.
for _control in (frame_slider, cam_dropdown, conf_slider):
    _control.observe(update_view, names='value')
del _control

# Render the default selection once so the plot is populated on first display.
update_view()

# Layout: controls on a single row, plot output underneath.
ui = widgets.VBox(children=[
    widgets.HBox(children=[frame_slider, cam_dropdown, conf_slider]),
    output_plot,
])

display(ui)

VBox(children=(HBox(children=(IntSlider(value=50, description='Frame Idx:', layout=Layout(width='50%'), max=34…