In [1]:
import sys

sys.path.append("..")

In [2]:
import torch
import torchvision
from PIL import Image
from typing import List, Union, Tuple
import numpy as np
import matplotlib.pyplot as plt
import cv2
from segment.dino_script import run_dino, format_dino, get_dino_model, transform_image_dino
from segment.utils import get_device, image_handler
from segment.sam_results import format_boxes, format_scores

DEVICE = get_device()


class DinoDetector:
    def __init__(
        self,
        image: Union[str, Image.Image, List[Image.Image]],
        text_prompt: str,
        image_size: int = 1024,
        box_threshold: float = 0.3,
        text_threshold: float = 0.25,
        iou_threshold: float = 0.8,
    ):
        self.image_size = image_size
        self.device = DEVICE
        self.images = image_handler(image, self.image_size)
        self.dino_images = self.image_to_tensor()
        self.text_prompt = text_prompt
        self.model = self._get_dino_model()
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        self.iou_threshold = iou_threshold
        self.boxes = None
        self.scores = None
        self.phrases = None

    def _get_dino_model(self):
        return get_dino_model()

    def image_to_tensor(self):
        with torch.no_grad():
            dino_images = torch.stack(
                [transform_image_dino(image) for image in self.images]
            )
        return dino_images.to(self.device)

    def _run_dino(self):
        self.boxes, self.scores, self.phrases = run_dino(
            self.model,
            self.dino_images,
            self.text_prompt,
            self.box_threshold,
            self.text_threshold,
        )

    def _format_dino(self):
        self.boxes, self.scores, self.phrases = format_dino(
            self.boxes, self.scores, self.phrases, (self.image_size, self.image_size), self.iou_threshold
        )

    def asdict(self):
        boxes = format_boxes(self.boxes[0])
        scores = format_scores(self.scores[0])
        phrases = self.phrases[0]
        return [
            dict(box=box, score=score, label=phrases[idx])
            for idx, (box, score) in enumerate(zip(boxes, scores))
        ]

    def visualize_results(
        self,
        image_num=0,
        cols: int = 4,
    ):
        # Convert PIL Image to numpy array
        image_np = np.array(self.images[image_num])
        boxes = self.boxes[0]
        scores = self.scores[0].squeeze().cpu().numpy().tolist()
        phrases = self.phrases[0]

        # Check image dimensions
        if image_np.ndim != 3:
            raise ValueError("Image must be a 3-dimensional array")

        # Number of detections
        n = len(boxes)
        rows = (n + cols - 1) // cols  # Calculate required number of rows

        # Setting up the plot
        fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
        if n == 1:
            axs = np.array([[axs]])
        elif rows == 1:
            axs = np.array([axs])
        else:
            axs = axs.reshape(rows, cols)

        for i, (box, phrase, score) in enumerate(zip(boxes, phrases, scores)):
            row = i // cols
            col = i % cols

            # Create a copy of the image
            img_with_box = image_np.copy()

            # Draw bounding box
            x1, y1, x2, y2 = box.int().tolist()
            cv2.rectangle(img_with_box, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # Show the image with bounding box
            ax = axs[row, col]
            ax.imshow(img_with_box)
            ax.axis("off")
            ax.set_title(f"Label: {phrase}, Score: {score:.2f}", fontsize=12)

        # Hide unused subplots if the total number of detections is not a multiple of cols
        for idx in range(i + 1, rows * cols):
            row = idx // cols
            col = idx % cols
            axs[row, col].axis("off")

        plt.tight_layout()
        plt.show()

In [3]:
from segment.utils import load_resize_image

In [12]:
im_path = "/Users/jordandavis/Downloads/hailey.webp"
image = load_resize_image(im_path)
text_prompt = "face . hair"
detector = DinoDetector([image], text_prompt)

In [33]:
detector.model = detector.model.to('mps')

In [34]:
detector._run_dino()

RuntimeError: Input type (float) and bias type (c10::Half) should be the same

In [15]:
detector._format_dino()

In [16]:
detector.visualize_results()

ValueError: Number of rows must be a positive integer, not 0

<Figure size 2000x0 with 0 Axes>