In [38]:
import tkinter as tk
from tkinter import ttk
from tkinter import filedialog, messagebox, simpledialog
import cv2
from PIL import Image, ImageTk
import os
import random
import csv
from autodistill_grounding_dino import GroundingDINO
from autodistill.detection import CaptionOntology


class ImageAnnotator:
    def __init__(self, root):
        """Build the annotator window, then prompt for images and a first class.

        Args:
            root: The Tk window that hosts every widget of the annotator.
        """
        self.root = root
        self.root.title("Image Annotator")

        # Application state, created in one place before any widget exists so
        # that no callback can observe a missing attribute.
        self.images = []                 # paths of the loaded image files
        self.current_image_index = None  # index into self.images, or None
        self.current_image = None        # image array returned by cv2.imread
        self.bbox_start = None           # canvas (x, y) where the drag began
        self.bbox_list = {}              # image path -> boxes not yet committed
        self.annotations_dict = {}       # image path -> committed (bbox, class)
        self.selected_class = None       # class applied to newly drawn boxes
        self.classes = {}                # class name -> outline colour

        # Create a frame for the files section
        self.files_frame = tk.Frame(root, bg="white")
        self.files_frame.pack(side=tk.LEFT, fill=tk.BOTH,
                              expand=True, padx=10, pady=10)

        # Create a label for the files section
        self.files_label = tk.Label(
            self.files_frame, text="Image Files", font=("Helvetica", 12), bg="white")
        self.files_label.pack(side=tk.TOP, pady=(0, 5))

        # Create a listbox to display the image files.
        # exportselection=False: by default Tk lets only one listbox own the
        # selection, so picking a class would deselect the image and
        # "Delete Image" would then find nothing selected.
        self.files_listbox = tk.Listbox(
            self.files_frame, width=40, height=20, bg="#f0f0f0",
            exportselection=False)
        self.files_listbox.pack(
            side=tk.TOP, fill=tk.BOTH, padx=(0, 5), pady=(0, 5))
        self.files_listbox.bind('<<ListboxSelect>>', self.load_image)

        # Create buttons for loading and deleting image files
        self.load_button = tk.Button(
            self.files_frame, text="Load Images", command=self.load_images)
        self.load_button.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        self.delete_button = tk.Button(
            self.files_frame, text="Delete Image", command=self.delete_image)
        self.delete_button.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        # Button to annotate all the images at once using the AI model
        self.ai_annotate_all_button = tk.Button(
            self.files_frame, text="AI Annotate All", command=self.annotate_all_with_model)
        self.ai_annotate_all_button.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        # Create a frame for the classes section
        self.class_frame = tk.Frame(root, bg="white")
        self.class_frame.pack(side=tk.RIGHT, fill=tk.BOTH,
                              expand=True, padx=10, pady=10)

        # Create a label for the classes section
        self.class_label = tk.Label(
            self.class_frame, text="Annotation Classes", font=("Helvetica", 12), bg="white")
        self.class_label.pack(side=tk.TOP, pady=(0, 5))

        # Create a listbox to display the available classes (same
        # exportselection reasoning as for the files listbox). It starts
        # empty; add_class() fills it and colours each row.
        self.class_listbox = tk.Listbox(
            self.class_frame, width=20, height=10, bg="#f0f0f0",
            exportselection=False)
        self.class_listbox.pack(
            side=tk.TOP, fill=tk.BOTH, padx=(0, 5), pady=(0, 5))
        self.class_listbox.bind('<<ListboxSelect>>', self.select_class)

        # Create buttons for adding and deleting classes
        self.add_class_button = tk.Button(
            self.class_frame, text="Add Class", command=self.add_class)
        self.add_class_button.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        self.delete_class_button = tk.Button(
            self.class_frame, text="Delete Class", command=self.delete_class)
        self.delete_class_button.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        # Create a button to trigger AI-assisted annotation
        self.ai_annotate_button = tk.Button(
            root, text="AI Annotate", command=self.annotate_with_model)
        self.ai_annotate_button.pack(side=tk.BOTTOM, pady=10)

        # Create a button to clear annotations for the current image
        self.clear_button = tk.Button(
            root, text="Clear Annotations", command=self.clear_annotations_for_image)
        self.clear_button.pack(side=tk.BOTTOM, pady=10)

        # Create a button to save annotations
        self.save_button = tk.Button(
            root, text="Save Annotations", command=self.save_annotations)
        self.save_button.pack(side=tk.BOTTOM, pady=10)

        # Create buttons for navigating between images
        self.prev_button = tk.Button(
            self.files_frame, text="Previous", command=self.prev_image)
        self.prev_button.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        self.next_button = tk.Button(
            self.files_frame, text="Next", command=self.next_image)
        self.next_button.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)

        # Create a canvas to display the image and annotations
        self.canvas = tk.Canvas(
            self.root, bg="white", highlightbackground="gray", highlightthickness=1)
        self.canvas.pack(side=tk.BOTTOM, fill=tk.BOTH,
                         padx=10, pady=10, expand=True)
        self.canvas.bind("<Button-1>", self.start_bbox)
        self.canvas.bind("<B1-Motion>", self.draw_bbox)
        self.canvas.bind("<ButtonRelease-1>", self.end_bbox)

        # Ask the user to load images
        self.load_images()
        # Ask the user to update the classes
        self.add_class()

    def next_image(self):
        if self.current_image_index is not None and self.current_image_index < len(self.images) - 1:
            self.current_image_index += 1
            self.load_selected_image()

    def prev_image(self):
        if self.current_image_index is not None and self.current_image_index > 0:
            self.current_image_index -= 1
            self.load_selected_image()

    def load_selected_image(self):
        self.save_annotations_temp()
        self.current_image = cv2.imread(self.images[self.current_image_index])
        self.display_image()
        image_path = self.images[self.current_image_index]
        if image_path in self.annotations_dict:
            for bbox, cls in self.annotations_dict[image_path]:
                self.canvas.create_rectangle(
                    *bbox, outline=self.classes[cls], tags="bbox")
    def load_image(self, event):
        selected_index = self.files_listbox.curselection()
        if selected_index:
            # Save annotations before loading new image
            self.save_annotations_temp()

            self.current_image_index = int(selected_index[0])  # Convert to integer
            self.current_image = cv2.imread(
                self.images[self.current_image_index])
            self.display_image()

            # Display annotations if available for the loaded image
            image_path = self.images[self.current_image_index]
            if image_path in self.annotations_dict:
                for bbox, cls in self.annotations_dict[image_path]:
                    self.canvas.create_rectangle(
                        *bbox, outline=self.classes[cls], tags="bbox")


    def add_class(self):
        """Ask for a class name and register it with a random colour.

        Surrounding whitespace is stripped and blank input is ignored, because
        the name is later used verbatim as a detection prompt and CSV value.
        A name that matches an existing class case-insensitively is rejected
        with a warning.
        """
        entered = simpledialog.askstring(
            "Add Class", "Object to detect:")
        if entered is None:
            # Dialog was cancelled.
            return
        new_class = entered.strip()
        if not new_class:
            return

        known = {name.casefold() for name in self.classes}
        if new_class.casefold() in known:
            messagebox.showwarning(
                "Duplicate Class", "This class already exists.")
            return

        color = f"#{random.randint(0, 0xFFFFFF):06x}"
        self.class_listbox.insert(tk.END, new_class)
        # Colour the listbox row so it matches the outline of its boxes.
        self.class_listbox.itemconfig(tk.END, {'bg': color})
        self.classes[new_class] = color
        messagebox.showinfo(
            "Class Added", "New class added successfully.")

    def delete_class(self):
        selected_index = self.class_listbox.curselection()
        if selected_index:
            selected_class = self.class_listbox.get(selected_index[0])
            del self.classes[selected_class]
            self.class_listbox.delete(selected_index[0])
            messagebox.showinfo("Class Deleted", "Class deleted successfully.")

    def select_class(self, event):
        selected_index = self.class_listbox.curselection()
        if selected_index:
            self.selected_class = self.class_listbox.get(selected_index[0])

    def start_bbox(self, event):
        if self.selected_class and self.current_image is not None:
            self.bbox_start = (event.x, event.y)
            self.draw_bbox(event)

    def draw_bbox(self, event):
        if self.bbox_start is not None:
            x0, y0 = self.bbox_start
            x1, y1 = (event.x, event.y)
            # Delete previous bounding boxes and redraw existing ones
            self.canvas.delete("bbox")
            if self.current_image_index is not None:
                image_path = self.images[self.current_image_index]
                if image_path in self.annotations_dict:
                    for bbox, cls in self.annotations_dict[image_path]:
                        self.canvas.create_rectangle(
                            *bbox, outline=self.classes[cls], tags="bbox")
            # Draw the current bounding box
            self.canvas.create_rectangle(
                x0, y0, x1, y1, outline=self.classes[self.selected_class], tags="bbox")

    def end_bbox(self, event):
        if self.bbox_start is not None:
            x0, y0 = self.bbox_start
            x1, y1 = (event.x, event.y)
            # Save the bounding box coordinates and class label
            bbox = ((x0, y0, x1, y1), self.selected_class)
            image_path = self.images[self.current_image_index]
            if image_path not in self.bbox_list:
                self.bbox_list[image_path] = []
            self.bbox_list[image_path].append(bbox)
            self.bbox_start = None
        self.save_annotations_temp()  # Save the annotations
        self.load_selected_image()  # Reload the image with the new annotations

    def display_image(self):
        # Convert image from OpenCV BGR format to RGB format
        img_rgb = cv2.cvtColor(self.current_image, cv2.COLOR_BGR2RGB)
        # Resize image to fit into the canvas while maintaining aspect ratio
        img_height, img_width, _ = img_rgb.shape
        canvas_width = self.canvas.winfo_width()
        canvas_height = self.canvas.winfo_height()
        if canvas_width / img_width < canvas_height / img_height:
            resize_factor = canvas_width / img_width
        else:
            resize_factor = canvas_height / img_height
        resized_img = cv2.resize(
            img_rgb, (int(img_width * resize_factor), int(img_height * resize_factor)))
        # Convert resized image to ImageTk format
        img_tk = ImageTk.PhotoImage(Image.fromarray(resized_img))
        # Display image on canvas
        self.canvas.create_image(0, 0, anchor=tk.NW, image=img_tk)
        # Keep a reference to the image to prevent it from being garbage collected
        self.canvas.image = img_tk

    def load_images(self):
        """Let the user pick image files and replace the current image list.

        Cancelling the dialog keeps the existing list. After a successful
        pick, the selection, the displayed image and all annotations of the
        previous list are reset, because the old index and paths no longer
        describe ``self.images``.
        """
        try:
            # Patterns are passed as a tuple; a single "*.jpg; *.jpeg; *.png"
            # string is not a portable Tk pattern list.
            file_paths = filedialog.askopenfilenames(
                filetypes=[("Image files", ("*.jpg", "*.jpeg", "*.png"))])
            if not file_paths:
                return

            self.images = list(file_paths)
            self.current_image_index = None
            self.current_image = None
            self.bbox_start = None
            self.annotations_dict = {}

            self.files_listbox.delete(0, tk.END)  # Clear previous entries
            for image_path in self.images:
                self.files_listbox.insert(tk.END, os.path.basename(image_path))
            # Remove the picture that belonged to the previous list.
            self.canvas.delete("all")

            messagebox.showinfo(
                "Images Loaded", "Image files loaded successfully.")
            self.clear_annotations()  # Clear annotations when loading new images
        except Exception as e:
            messagebox.showerror(
                "Error", f"An error occurred while loading images: {str(e)}")

    def delete_image(self):
        try:
            selected_index = self.files_listbox.curselection()
            if selected_index:
                # Save annotations before clearing
                self.save_annotations_temp()

                del self.images[selected_index[0]]
                self.files_listbox.delete(selected_index[0])
                messagebox.showinfo(
                    "Image Deleted", "Image file deleted successfully.")
        except Exception as e:
            messagebox.showerror(
                "Error", f"An error occurred while deleting image: {str(e)}")

    def save_annotations(self):
        """Export every annotation to a CSV file chosen by the user.

        Columns: ``Image, Class, X_min, Y_min, X_max, Y_max``. Coordinates are
        converted from canvas pixels back to original-image pixels using the
        same fit-to-canvas factor as the display code, computed from the
        canvas size at the time of saving (boxes are therefore assumed to
        have been drawn at the current canvas size).

        The stored annotations are never modified, so saving twice or
        cancelling the file dialog leaves them intact.
        """
        try:
            # Commit the current image's pending boxes first.
            self.save_annotations_temp()

            # Export committed boxes plus any still pending for other images.
            # Work on copies so the stored lists stay in canvas coordinates.
            merged = {path: list(boxes)
                      for path, boxes in self.annotations_dict.items()}
            for path, boxes in self.bbox_list.items():
                merged.setdefault(path, []).extend(boxes)
            merged = {path: boxes for path, boxes in merged.items() if boxes}

            # check if there are any annotations to save
            if not merged:
                messagebox.showwarning("No Annotations", "No annotations to save.")
                return

            save_path = filedialog.asksaveasfilename(
                defaultextension=".csv", filetypes=[("CSV files", "*.csv")])
            if not save_path:
                return

            canvas_width = self.canvas.winfo_width()
            canvas_height = self.canvas.winfo_height()
            rows = []
            for image_path, annotations in merged.items():
                img = cv2.imread(image_path)
                if img is None:
                    # cv2.imread returns None for a missing/unreadable file.
                    raise OSError(f"could not read image: {image_path}")
                img_height, img_width = img.shape[:2]
                resize_factor = min(canvas_width / img_width,
                                    canvas_height / img_height)
                image_name = os.path.basename(image_path)
                for (x_min, y_min, x_max, y_max), cls in annotations:
                    rows.append([image_name, cls,
                                 x_min / resize_factor, y_min / resize_factor,
                                 x_max / resize_factor, y_max / resize_factor])

            # Save the annotations to a CSV file
            with open(save_path, 'w', newline='') as csvfile:
                csv_writer = csv.writer(csvfile)
                csv_writer.writerow(
                    ["Image", "Class", "X_min", "Y_min", "X_max", "Y_max"])
                csv_writer.writerows(rows)

            messagebox.showinfo("Annotations Saved",
                                "Annotations saved successfully.")
        except Exception as e:
            messagebox.showerror(
                "Error", f"An error occurred while saving annotations: {str(e)}")


    def save_annotations_temp(self):
        """
        Save annotations temporarily before clearing the annotations dictionary.
        """
        if self.current_image_index is not None:
            image_path = self.images[self.current_image_index]
            if image_path not in self.annotations_dict:
                self.annotations_dict[image_path] = self.bbox_list.get(
                    image_path, [])
            else:
                self.annotations_dict[image_path].extend(
                    self.bbox_list.get(image_path, []))
            # Clear annotations for the current image
            self.bbox_list.pop(image_path, None)

    def clear_annotations(self):
        self.bbox_list = {}  # Clear the bounding box list
        self.canvas.delete("bbox")  # Clear annotations displayed on the canvas

    def clear_annotations_for_image(self):
        if self.current_image_index is not None:
            # Clear annotations for the current image
            image_path = self.images[self.current_image_index]
            self.bbox_list.pop(image_path, None)
            self.canvas.delete("bbox")
            # Clear the annotations dictionary for the current image
            self.annotations_dict.pop(image_path, None)

            # Optionally, you can also reset the selected class
            self.selected_class = None

            # Show a message informing the user that annotations are cleared
            messagebox.showinfo(
                "Annotations Cleared", "Annotations for the current image cleared. You can now redraw annotations.")

            # You may also want to reset any other relevant attributes or UI elements
            # For example, if you want to allow users to select a new class for annotations:
            self.class_listbox.selection_clear(0, tk.END)
            self.selected_class = None

    def annotate_with_model(self):
        if self.current_image is not None:
            # Create a progress bar
            progress_bar = ttk.Progressbar(self.root, orient='horizontal', mode='determinate')
            progress_bar.pack(side=tk.BOTTOM, fill=tk.X, padx=10, pady=5)
            progress_bar['value']=20
            progress_bar.update_idletasks()
            # Get the classes entered by the user
            ontology_dict = self.get_classes_from_user()
            ontology = CaptionOntology(ontology_dict)
            base_model = GroundingDINO(ontology=ontology)
            # resize the image to the canvas size
            img_rgb = cv2.cvtColor(self.current_image, cv2.COLOR_BGR2RGB)
            img_height, img_width, _ = img_rgb.shape
            canvas_width = self.canvas.winfo_width()
            canvas_height = self.canvas.winfo_height()
            if canvas_width / img_width < canvas_height / img_height:
                resize_factor = canvas_width / img_width
            else:
                resize_factor = canvas_height / img_height
            resized_img = cv2.resize(
                img_rgb, (int(img_width * resize_factor), int(img_height * resize_factor)))
            result = base_model.predict(resized_img)
            progress_bar['value']=50
            progress_bar.update_idletasks()
            

            for i, (bbox, cls, conf) in enumerate(zip(result.xyxy, result.class_id, result.confidence), start=1):
               
                
                # save the bounding box coordinates and class label
                bbox = ((bbox[0], bbox[1], bbox[2], bbox[3]),
                        list(ontology_dict.keys())[cls])
                image_path = self.images[self.current_image_index]
                if image_path not in self.bbox_list:
                    self.bbox_list[image_path] = []
                self.bbox_list[image_path].append(bbox)
            progress_bar['value']=100
            progress_bar.update_idletasks()

            # Hide the progress bar when prediction is complete
            progress_bar.pack_forget()

        self.save_annotations_temp()  # Save the annotations
        self.load_selected_image()  # Reload the image with the new annotations

    def get_classes_from_user(self):
        classes = {}
        for i in range(self.class_listbox.size()):
            class_name = self.class_listbox.get(i)
            classes[class_name] = class_name
        return classes

    def on_closing(self):
        """Handle the window-close request, offering to save first.

        Yes saves and quits, No quits without saving, and Cancel (or closing
        the dialog) keeps the application open.
        """
        # askyesnocancel returns True, False or None (cancelled).
        answer = messagebox.askyesnocancel(
            "Quit", "Do you want to save annotations before quitting?")
        if answer is None:
            return
        if answer:
            self.save_annotations()
        self.root.destroy()

    # Method to annotate all the images at once using the AI model
    def annotate_all_with_model(self):
        # Create a progress bar
        progress_bar = ttk.Progressbar(self.root, orient='horizontal', mode='determinate')
        progress_bar.pack(side=tk.BOTTOM, fill=tk.X, padx=10, pady=5)
        progress_bar['value']=20
        progress_bar.update_idletasks()
        # Get the classes entered by the user
        ontology_dict = self.get_classes_from_user()
        ontology = CaptionOntology(ontology_dict)
        base_model = GroundingDINO(ontology=ontology)
        for i, image_path in enumerate(self.images, start=1):
            img = cv2.imread(image_path)
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_height, img_width, _ = img_rgb.shape
            canvas_width = self.canvas.winfo_width()
            canvas_height = self.canvas.winfo_height()
            if canvas_width / img_width < canvas_height / img_height:
                resize_factor = canvas_width / img_width
            else:
                resize_factor = canvas_height / img_height
            resized_img = cv2.resize(
                img_rgb, (int(img_width * resize_factor), int(img_height * resize_factor)))
            result = base_model.predict(resized_img)
            for bbox, cls, conf in zip(result.xyxy, result.class_id, result.confidence):
                # save the bounding box coordinates and class label
                bbox = ((bbox[0], bbox[1], bbox[2], bbox[3]),
                        list(ontology_dict.keys())[cls])
                if image_path not in self.bbox_list:
                    self.bbox_list[image_path] = []
                self.bbox_list[image_path].append(bbox)
            progress_bar['value']=20 + (i/len(self.images))*80
            progress_bar.update_idletasks()
        progress_bar['value']=100
        progress_bar.update_idletasks()
        # Hide the progress bar when prediction is complete
        progress_bar.pack_forget()
        self.save_annotations_temp()  # Save the annotations
        self.load_selected_image()  # Reload the image with the new annotations

# Script entry point: build the window and run the annotator.
if __name__ == "__main__":
    # Top-level Tk window that hosts the whole UI.
    root = tk.Tk()
    # Constructing the annotator builds the widgets and immediately opens the
    # "load images" and "add class" dialogs (see ImageAnnotator.__init__).
    app = ImageAnnotator(root)
    # Bind on_closing method to close window event
    root.protocol("WM_DELETE_WINDOW", app.on_closing)
    # Run the Tk event loop.
    root.mainloop()


trying to load grounding dino directly
final text_encoder_type: bert-base-uncased
trying to load grounding dino directly
final text_encoder_type: bert-base-uncased
