In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
import math
import random

# Palette in OpenCV's BGR channel order, because boxes are drawn on a BGR copy of the image.
_DETECTION_PALETTE_BGR = [
    (255, 0, 0),      # Blue
    (0, 255, 0),      # Green
    (0, 0, 255),      # Red
    (255, 255, 0),    # Cyan
    (255, 0, 255),    # Magenta
    (0, 255, 255),    # Yellow
    (128, 128, 0),    # Teal
    (128, 0, 128),    # Purple
    (0, 128, 128),    # Olive
    (128, 128, 128),  # Gray
]


def _draw_detections(rgb_img, results, class_names):
    """Return an RGB copy of ``rgb_img`` with each detection's box and ``name: score`` label drawn.

    ``results`` is one per-image entry from ``processor.post_process_object_detection``: a dict
    with parallel ``scores``, ``labels`` and ``boxes`` tensors, each box unpacked as
    (x1, y1, x2, y2) and drawn directly in ``rgb_img`` pixel coordinates. Detections whose label
    id is not a valid index into ``class_names`` are skipped.
    """
    canvas = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR)
    # Images wider than 800 px get bigger text and thicker lines, presumably so annotations
    # stay legible once the image is shrunk into a small grid cell.
    if canvas.shape[1] > 800:
        font_scale, font_thickness, line_thickness = 3, 4, 7
    else:
        font_scale, font_thickness, line_thickness = 1, 1, 2

    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        class_id = int(label.item())
        if not 0 <= class_id < len(class_names):
            continue  # label id the caller has no display name for
        x1, y1, x2, y2 = map(int, box.tolist())
        color = _DETECTION_PALETTE_BGR[class_id % len(_DETECTION_PALETTE_BGR)]
        text = f"{class_names[class_id]}: {float(score):.2f}"
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, line_thickness)
        # Text sits 5 px above the box; its baseline is clamped to y >= 15 for boxes near the top.
        cv2.putText(canvas, text, (x1, max(y1 - 5, 15)), cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale, color, font_thickness, cv2.LINE_AA)

    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)


def plot_rtdetr_predictions(model, dataset, processor, class_names, device="cuda", num_images=24,
                            threshold=0.1, seed=None):
    """Run ``model`` on a random sample of ``dataset`` images and show the detections in a grid.

    Images are laid out 4 per row, each titled ``Image k`` and annotated with boxes and
    ``name: score`` labels. Nothing is returned; the figure is displayed with ``plt.show()``.

    Args:
        model: Detection model called as ``model(pixel_values)``. It is switched to eval mode
            and moved to ``device`` in place (the caller's object is modified).
        dataset: Object with ``__len__`` and a ``base_dataset`` attribute whose items are
            ``(image, target)`` pairs; ``image`` must support ``.convert("RGB")`` (e.g. PIL).
            Only the image is used. Without ``base_dataset`` a message is printed and nothing
            is plotted.
        processor: Preprocessor producing ``pixel_values`` and providing
            ``post_process_object_detection`` for the model outputs.
        class_names: Display name for each predicted label id (list index == label id).
        device: Torch device used for inference.
        num_images: Maximum number of images to show; capped at ``len(dataset)``.
        threshold: Score threshold forwarded to ``post_process_object_detection``.
        seed: If given, seeds a private RNG so the sampled images are reproducible;
            ``None`` uses the global ``random`` module state.
    """
    if not hasattr(dataset, "base_dataset"):
        # Drawing needs the raw image, which is only reachable through the base dataset.
        print("Base dataset not found. Skipping.")
        return
    n_shown = min(num_images, len(dataset))
    if n_shown <= 0:
        return

    model.eval().to(device)
    sampler = random if seed is None else random.Random(seed)
    indices = sampler.sample(range(len(dataset)), n_shown)

    cols = 4
    rows = math.ceil(n_shown / cols)
    # squeeze=False keeps a 2-D array for every grid shape, so flatten() is always valid.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 3), dpi=150, squeeze=False)
    axes = axes.flatten()
    for ax in axes:
        ax.axis("off")  # hidden by default so unused trailing cells stay blank

    for slot, idx in enumerate(indices):
        pil_img, _ = dataset.base_dataset[idx]
        # Feed the model the same RGB pixels that get drawn, whatever the source image mode.
        rgb_pil = pil_img.convert("RGB")
        rgb_img = np.array(rgb_pil)
        pixel_values = processor(images=rgb_pil, return_tensors="pt").pixel_values.to(device)

        with torch.no_grad():
            outputs = model(pixel_values)

        height, width = rgb_img.shape[:2]
        # Request boxes for the displayed image's own (height, width).
        results = processor.post_process_object_detection(
            outputs=outputs,
            target_sizes=torch.tensor([[height, width]], device=device),
            threshold=threshold,
        )[0]

        ax = axes[slot]
        ax.imshow(_draw_detections(rgb_img, results, class_names), interpolation="none")
        ax.set_title(f"Image {slot + 1}")

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the (potentially very large) figure once it has been displayed

In [None]:
# Visualize validation predictions of the checkpoint from the 50-epoch training run.
# NOTE(review): AutoProcessor / AutoModelForObjectDetection and `valid_processed` are not
# defined in this cell; they must come from earlier cells — confirm on a fresh kernel.
# NOTE(review): the processor is loaded from the base hub model, not the fine-tuned
# checkpoint — assumes preprocessing was unchanged during fine-tuning; verify.
model_name = "PekingU/rtdetr_r101vd_coco_o365"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForObjectDetection.from_pretrained("datasets/detr_surgical_best_model50")
# Index i is the display name used for predicted label id i.
class_names = ['no', 'Grasper', 'Harmonic_Ace', 'Myoma_Screw',
               'Needle_Holder', 'Suction', 'Trocar']
# Fall back to CPU so the cell still runs on a kernel without a GPU.
inference_device = "cuda" if torch.cuda.is_available() else "cpu"

plot_rtdetr_predictions(
    model=model,
    dataset=valid_processed,
    processor=processor,
    class_names=class_names,
    device=inference_device,
    num_images=100,
)

In [None]:
# Visualize validation predictions of the `datasets/detr_surgical_best_model` checkpoint.
# TODO: record which training run / epoch count produced this checkpoint.
# NOTE(review): AutoProcessor / AutoModelForObjectDetection and `valid_processed` are not
# defined in this cell; they must come from earlier cells — confirm on a fresh kernel.
model_name = "PekingU/rtdetr_r101vd_coco_o365"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForObjectDetection.from_pretrained("datasets/detr_surgical_best_model")
# Index i is the display name used for predicted label id i.
class_names = ['no', 'Grasper', 'Harmonic_Ace', 'Myoma_Screw',
               'Needle_Holder', 'Suction', 'Trocar']
# Fall back to CPU so the cell still runs on a kernel without a GPU.
inference_device = "cuda" if torch.cuda.is_available() else "cpu"

plot_rtdetr_predictions(
    model=model,
    dataset=valid_processed,
    processor=processor,
    class_names=class_names,
    device=inference_device,
    num_images=100,
)