In [None]:
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from datasets import load_dataset
import numpy as np


In [None]:
# --- Display configuration ---------------------------------------------------
# SUBSET: which dataset split to sample images from; one of 'train' / 'test'.
SUBSET = 'train'
# NUM_IMAGES: how many randomly chosen images to plot from that split.
NUM_IMAGES = 5

In [None]:
# Load the highway-vehicles dataset and plot NUM_IMAGES randomly chosen images
# from the SUBSET split, overlaying each image's bounding boxes and class labels.
dataset_dict = load_dataset('PrzemekS/highway-vehicles')

# Access the train or test split; on a typo, report the splits that actually exist.
if SUBSET not in dataset_dict:
    raise KeyError(f"Unknown SUBSET {SUBSET!r}; available splits: {list(dataset_dict)}")
my_dataset = dataset_dict[SUBSET]

# Class names indexed by category ID (ensure this order matches your dataset)
class_names = ['vehicle', 'truck']
id2label = dict(enumerate(class_names))

# Box/label colour per category ID: class 0 red, class 1 green; any other ID is blue.
LABEL_COLORS = {0: 'red', 1: 'green'}
DEFAULT_COLOR = 'blue'

# Set to an int to draw the same images on every run. With None, sampling uses the
# global `random` state, so an earlier random.seed() call still takes effect.
SAMPLE_SEED = None
sampler = random if SAMPLE_SEED is None else random.Random(SAMPLE_SEED)

# Randomly select indices. Clamp the count because random.sample raises ValueError
# when asked for more items than the split contains.
indices = sampler.sample(range(len(my_dataset)), min(NUM_IMAGES, len(my_dataset)))

for idx in indices:
    example = my_dataset[idx]
    objects = example['objects']

    # Convert image to numpy array; a 2-D result (single-channel image) is drawn in
    # grayscale instead of being unpacked as (height, width, channels).
    img_np = np.array(example['image'])

    fig, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(img_np, cmap='gray' if img_np.ndim == 2 else None)

    for bbox, label in zip(objects['bbox'], objects['category']):
        # bbox is [xmin, ymin, width, height] in absolute pixel coordinates
        xmin, ymin, box_width, box_height = bbox[:4]
        color = LABEL_COLORS.get(label, DEFAULT_COLOR)

        ax.add_patch(
            patches.Rectangle(
                (xmin, ymin), box_width, box_height,
                linewidth=2, edgecolor=color, facecolor='none',
            )
        )
        # IDs outside class_names are shown as their number rather than raising KeyError.
        label_name = id2label.get(label, str(label))
        ax.text(xmin, ymin - 5, label_name, color=color, fontsize=12, weight='bold')

    ax.axis('off')
    plt.show()
    # Release the figure so many loop iterations don't accumulate open figures.
    plt.close(fig)
