Requirements

In [None]:
torch
opencv-python
Pillow
plotly
transformers
numpy
requests

In [None]:
import tkinter as tk
from dataclasses import dataclass
from functools import lru_cache
from tkinter import filedialog, Text, messagebox
from typing import List

import cv2
import numpy as np
import requests
import torch
from PIL import Image, ImageTk
from transformers import pipeline

@dataclass
class BoundingBox:
    """Axis-aligned rectangle in image pixel coordinates.

    ``(xmin, ymin)`` is the top-left corner and ``(xmax, ymax)`` the
    bottom-right corner; the field names match the keys of the ``box``
    mapping that ``DetectionResult.from_dict`` reads.
    """

    xmin: int  # left edge
    ymin: int  # top edge
    xmax: int  # right edge
    ymax: int  # bottom edge

@dataclass
class DetectionResult:
    """A single detection: confidence score, predicted label and its box."""

    score: float  # confidence reported for this detection
    label: str  # predicted label text
    box: BoundingBox  # detected region

    @classmethod
    def from_dict(cls, detection_dict: dict) -> 'DetectionResult':
        """Build a ``DetectionResult`` from one detection dictionary.

        The dictionary must provide ``score``, ``label`` and ``box``, where
        ``box`` itself provides ``xmin``, ``ymin``, ``xmax`` and ``ymax``.
        Any other keys are ignored; a missing key raises ``KeyError``.
        """
        score = detection_dict['score']
        label = detection_dict['label']
        raw_box = detection_dict['box']
        corners = [raw_box[key] for key in ('xmin', 'ymin', 'xmax', 'ymax')]
        return cls(score=score, label=label, box=BoundingBox(*corners))

def load_image(image_str: str, timeout: float = 10.0) -> Image.Image:
    """Load an image from a local path or an http(s) URL as an RGB image.

    Args:
        image_str: Filesystem path, or a URL whose scheme is ``http`` or
            ``https`` (case-insensitive).
        timeout: Seconds to wait on the HTTP server before giving up. Only
            used for URLs.

    Returns:
        A fully loaded ``PIL.Image.Image`` in ``"RGB"`` mode that does not
        depend on any still-open file.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
        OSError: If the file is missing or the data is not a decodable
            image (``FileNotFoundError``, ``PIL.UnidentifiedImageError``).
    """
    import io  # only needed to wrap downloaded bytes for PIL

    # Match real URL schemes only, so a local file such as "httpdump.png"
    # is not mistaken for a URL.
    if image_str.lower().startswith(("http://", "https://")):
        # The timeout keeps the GUI from hanging forever on a dead server;
        # raise_for_status reports HTTP errors clearly instead of letting PIL
        # choke on an HTML error page.
        response = requests.get(image_str, timeout=timeout)
        response.raise_for_status()
        # Use .content (transparently decompressed) rather than .raw, which
        # is not gzip/deflate-decoded by default.
        source = io.BytesIO(response.content)
    else:
        source = image_str

    # ``with`` releases the file handle; convert() returns a new, fully
    # loaded image, so the result stays valid after the source is closed.
    with Image.open(source) as image:
        return image.convert("RGB")

# Hugging Face model id used when the caller does not choose one.
_DEFAULT_DETECTOR_ID = "IDEA-Research/grounding-dino-tiny"


@lru_cache(maxsize=2)
def _get_object_detector(detector_id: str, device: str):
    """Build the zero-shot detection pipeline once per (model, device) and reuse it."""
    # Loading the model is by far the most expensive step; without caching
    # every call to detect() would reload the weights.
    return pipeline(model=detector_id, task="zero-shot-object-detection", device=device)


def detect(
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
    detector_id: str = _DEFAULT_DETECTOR_ID,
) -> List[DetectionResult]:
    """Run zero-shot object detection on *image* for the given text labels.

    Args:
        image: Image to search.
        labels: Free-text object names. Surrounding whitespace is stripped,
            blank entries are skipped, and a trailing period is added when
            missing.
        threshold: Minimum score for a detection to be returned.
        detector_id: Hugging Face model id of the zero-shot detector.

    Returns:
        One ``DetectionResult`` per detection from the pipeline. The list is
        empty when no usable labels are given.
    """
    # Grounding DINO text prompts use a period after each label.
    queries = []
    for label in labels:
        label = label.strip()
        if label:
            queries.append(label if label.endswith(".") else label + ".")
    if not queries:
        return []

    device = "cuda" if torch.cuda.is_available() else "cpu"
    object_detector = _get_object_detector(detector_id, device)
    raw_results = object_detector(image, candidate_labels=queries, threshold=threshold)
    return [DetectionResult.from_dict(result) for result in raw_results]

class AIModelApp:
    """Tkinter GUI for zero-shot object detection on a single image.

    Workflow:

    1. Upload an image.
    2. Enter comma-separated labels and press *Submit* to run ``detect``.
    3. Click boxes to select them (clicking empty space clears the
       selection).
    4. Drag to move the selection.
    5. Press *Delete Selected Boxes* to remove it.

    Detection boxes are kept in original-image pixel coordinates. The image
    shown in the window may be downscaled, so mouse positions are divided by
    ``self.display_scale`` before they are compared with or applied to boxes.
    """

    # Largest (width, height) at which an image is shown in the window.
    MAX_DISPLAY_SIZE = (600, 600)
    # Size of the preview shown right after an upload, before detection runs.
    PREVIEW_SIZE = (300, 300)
    # Colours are BGR because boxes are drawn on an OpenCV (BGR) array.
    BOX_COLOR = (0, 255, 0)  # green: unselected box
    SELECTED_COLOR = (0, 0, 255)  # red: selected box

    def __init__(self, master):
        self.master = master
        master.title("AI Model Image Segmentation")

        self.image_label = tk.Label(master, text="Upload an Image:")
        self.image_label.pack()

        self.upload_button = tk.Button(master, text="Upload Image", command=self.upload_image)
        self.upload_button.pack()

        self.prompt_label = tk.Label(master, text="Enter labels (comma-separated):")
        self.prompt_label.pack()

        self.prompt_entry = Text(master, height=5, width=50)
        self.prompt_entry.pack()

        self.submit_button = tk.Button(master, text="Submit", command=self.process_image)
        self.submit_button.pack()

        self.delete_button = tk.Button(master, text="Delete Selected Boxes", command=self.delete_selected_boxes)
        self.delete_button.pack()

        self.result_label = tk.Label(master, text="")
        self.result_label.pack()

        # No border/padding around the image, so event.x / event.y are pixel
        # offsets into the displayed image with nothing extra to subtract.
        self.display_label = tk.Label(master, borderwidth=0, padx=0, pady=0, highlightthickness=0)
        self.display_label.pack()

        self.image_path = ""
        self.original_image = None  # full-resolution RGB PIL image, cached on upload
        self.display_scale = 1.0  # displayed pixels per original-image pixel
        self.detections = []
        self.selected_boxes = []
        self._drag_origin = None  # last mouse position (image coords) while dragging
        self.bind_mouse_events()

    def upload_image(self):
        """Ask for an image file, cache it, reset detections and show a preview."""
        path = filedialog.askopenfilename()
        if not path:
            return  # dialog cancelled: keep the current image and detections
        try:
            image = load_image(path)
        except OSError as exc:  # missing file or not a decodable image
            messagebox.showerror("Error", f"Could not open image:\n{exc}")
            return

        self.image_path = path
        self.original_image = image
        # Boxes from a previous image are meaningless on the new one.
        self.detections = []
        self.selected_boxes = []
        self._drag_origin = None
        self.result_label.config(text="")

        preview = image.copy()  # keep the cached original at full resolution
        preview.thumbnail(self.PREVIEW_SIZE)
        self.display_image(preview)

    def display_image(self, image: Image.Image):
        """Show *image* in the window, shrunk in place to fit ``MAX_DISPLAY_SIZE``.

        Also updates ``self.display_scale`` relative to the cached original
        image, so mouse events can be mapped back to detection coordinates.
        """
        image.thumbnail(self.MAX_DISPLAY_SIZE, Image.BICUBIC)
        reference_width = self.original_image.width if self.original_image is not None else image.width
        self.display_scale = image.width / reference_width
        self.image_display = ImageTk.PhotoImage(image)
        self.display_label.config(image=self.image_display)
        # Extra reference on the widget so the PhotoImage is not garbage-collected.
        self.display_label.image = self.image_display

    def bind_mouse_events(self):
        """Attach the select / drag / release handlers to the image label."""
        self.display_label.bind("<ButtonPress-1>", self.on_mouse_click)
        self.display_label.bind("<B1-Motion>", self.on_mouse_drag)
        self.display_label.bind("<ButtonRelease-1>", self.on_button_release)

    def _to_image_coords(self, event):
        """Map a mouse event on the displayed image to integer original-image pixels."""
        return round(event.x / self.display_scale), round(event.y / self.display_scale)

    def _is_selected(self, detection):
        """Return True if this exact detection object is currently selected."""
        # Identity, not ==: dataclass equality would conflate two distinct
        # detections that happen to have the same score, label and box.
        return any(selected is detection for selected in self.selected_boxes)

    def on_mouse_click(self, event):
        """Select the box under the cursor (or clear the selection) and start a drag."""
        x, y = self._to_image_coords(event)
        self._drag_origin = (x, y)
        # Search from the last-drawn box so the visually top-most one wins.
        hit = next(
            (
                detection
                for detection in reversed(self.detections)
                if detection.box.xmin <= x <= detection.box.xmax
                and detection.box.ymin <= y <= detection.box.ymax
            ),
            None,
        )
        if hit is None:
            # Clicking empty space deselects everything; otherwise a selection
            # could never be undone short of deleting the boxes.
            self.selected_boxes = []
        elif not self._is_selected(hit):
            self.selected_boxes.append(hit)
        self.redraw_image()

    def on_mouse_drag(self, event):
        """Move every selected box by the mouse displacement since the last event."""
        if self._drag_origin is None or not self.selected_boxes:
            return
        x, y = self._to_image_coords(event)
        dx, dy = x - self._drag_origin[0], y - self._drag_origin[1]
        if dx == 0 and dy == 0:
            return
        self._drag_origin = (x, y)
        for detection in self.selected_boxes:
            box = detection.box
            # Shift both corners so the box moves without changing its size.
            box.xmin += dx
            box.xmax += dx
            box.ymin += dy
            box.ymax += dy
        self.redraw_image()

    def on_button_release(self, event):
        """End the current drag; the selection itself is kept."""
        self._drag_origin = None

    def delete_selected_boxes(self):
        """Remove all selected detections, or tell the user nothing is selected."""
        if not self.selected_boxes:
            messagebox.showinfo("Info", "No bounding boxes selected.")
            return
        selected_ids = {id(detection) for detection in self.selected_boxes}
        self.detections = [d for d in self.detections if id(d) not in selected_ids]
        self.selected_boxes = []
        self._drag_origin = None
        self.redraw_image()

    def redraw_image(self):
        """Re-render the cached original image with the current detections."""
        if self.original_image is None:
            return  # nothing uploaded yet
        # draw_detections returns a new image, so the cached original is never
        # resized by display_image's in-place thumbnail.
        annotated_image = self.draw_detections(self.original_image, self.detections)
        self.display_image(annotated_image)

    def draw_detections(self, image: Image.Image, detections: List[DetectionResult]) -> Image.Image:
        """Return a copy of *image* with each detection's box and caption drawn on it.

        Selected detections are drawn in ``SELECTED_COLOR`` and the rest in
        ``BOX_COLOR``. *image* itself is not modified.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        for detection in detections:
            box = detection.box
            color = self.SELECTED_COLOR if self._is_selected(detection) else self.BOX_COLOR
            # OpenCV drawing functions require integer pixel coordinates.
            top_left = (int(box.xmin), int(box.ymin))
            bottom_right = (int(box.xmax), int(box.ymax))
            cv2.rectangle(image_cv2, top_left, bottom_right, color, 2)
            # Caption goes above the box; if that would be cut off by the top
            # edge of the image, draw it just inside the box instead.
            text_y = top_left[1] - 10 if top_left[1] >= 20 else top_left[1] + 20
            cv2.putText(image_cv2, f'{detection.label}: {detection.score:.2f}',
                        (top_left[0], text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        return Image.fromarray(cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB))

    def process_image(self):
        """Run detection for the entered labels on the uploaded image and show the result."""
        labels_input = self.prompt_entry.get("1.0", tk.END)
        # Accept commas or newlines as separators. Add the trailing period the
        # detector expects only when the user has not already typed one.
        labels = []
        for raw_label in labels_input.replace("\n", ",").split(","):
            label = raw_label.strip()
            if label:
                labels.append(label if label.endswith(".") else label + ".")

        if self.original_image is None or not labels:
            messagebox.showerror("Error", "Please upload an image and enter labels.")
            return

        self.result_label.config(text="Processing...")
        # Detection runs synchronously and blocks the event loop, so flush the
        # status text to the screen first.
        self.master.update_idletasks()
        try:
            detections = detect(self.original_image, labels)
        except Exception as exc:  # UI boundary: report instead of failing silently in the callback
            self.result_label.config(text="")
            messagebox.showerror("Error", f"Detection failed:\n{exc}")
            return

        self.detections = detections
        # Any previous selection refers to detections that no longer exist.
        self.selected_boxes = []
        self._drag_origin = None
        self.redraw_image()
        self.result_label.config(text="Processing complete. Click on annotations to select.")

if __name__ == "__main__":
    # Script entry point: create the Tk root window, attach the application
    # to it, then hand control to Tk's event loop until the window closes.
    root = tk.Tk()
    app = AIModelApp(master=root)
    app.master.mainloop()
