# In [None]:
import os
import queue
import threading
import tkinter as tk
import uuid
from datetime import datetime
from tkinter import filedialog, messagebox, ttk

import cv2
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from PIL import Image, ImageTk
from roboflow import Roboflow
from ultralytics import YOLO

# Initialize YOLOv11 and Roboflow models
model = YOLO("yolo11m.pt")  # Use YOLOv11 model

# NOTE(security): an API key committed to source is readable by anyone with
# access to this file. Set ROBOFLOW_API_KEY in the environment instead; the
# literal below is only a backward-compatible fallback and should be rotated
# and removed.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY", "7P6wSkFD6Zb39ZYTL84S")
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace("animal-class").project("animal-class-cnxhg")  # Workspace and project name
rf_model = project.version(5).model  # Version 5 of the Roboflow model

# Set up output directories (created up front; existing ones are reused)
output_dir_detected = "Output/Detected_Animals"  # raw crops sent to Roboflow
output_dir_classified = "Output/Classified_Animals"  # annotated full frames
os.makedirs(output_dir_detected, exist_ok=True)
os.makedirs(output_dir_classified, exist_ok=True)

# Function to save cropped images for each detection
def save_cropped_image(image_rgb, detection):
    """Crop one detection out of an RGB frame and save it as a JPEG.

    Args:
        image_rgb: frame as an ``H x W x 3`` array in RGB order; it is
            converted to BGR before being written with ``cv2.imwrite``.
        detection: ``[x1, y1, x2, y2, conf, cls_id]``. Coordinates are
            truncated to ints and clamped to the frame; ``conf`` and
            ``cls_id`` are ignored.

    Returns:
        Path of the written file inside ``output_dir_detected``.

    Raises:
        ValueError: if the clamped box has zero width or height.
        OSError: if OpenCV fails to write the file.
    """
    x1, y1, x2, y2, _conf, _cls_id = map(int, detection)
    height, width = image_rgb.shape[:2]
    # Clamp to the frame: a negative slice start would wrap around to the
    # far edge and silently yield a wrong (or empty) crop.
    x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
    y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
    if x2 <= x1 or y2 <= y1:
        raise ValueError(f"Empty crop for box {(x1, y1, x2, y2)}")

    cropped_image = image_rgb[y1:y2, x1:x2]
    # Microsecond timestamp plus a random suffix: the system clock can be
    # coarser than a microsecond and crops are saved in a tight loop.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    cropped_filename = f"object_{timestamp}_{uuid.uuid4().hex[:8]}.jpg"
    cropped_path = os.path.join(output_dir_detected, cropped_filename)
    # cv2.imwrite reports failure via its return value, not an exception.
    if not cv2.imwrite(cropped_path, cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Could not write cropped image to {cropped_path}")
    return cropped_path

# GUI Setup
class ObjectDetectionApp:
    """Tkinter front end for YOLO detection plus Roboflow classification.

    Images are processed synchronously on the Tk main thread. Videos are
    processed on a daemon worker thread. Tkinter widgets must only be touched
    from the thread running ``mainloop``, so the worker never calls Tk
    itself. Instead it queues ``(callback, args)`` pairs on ``self._ui_queue``,
    and ``_poll_ui_queue`` runs them on the main thread every ``UI_POLL_MS`` ms.
    """

    # Interval (milliseconds) at which the main thread runs UI work queued by
    # the video worker thread.
    UI_POLL_MS = 50
    # Lower-cased file extensions accepted by ``upload_file``.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
    VIDEO_EXTENSIONS = (".mp4",)

    def __init__(self, root):
        """Build all widgets on *root* and start the UI-queue polling loop."""
        self.root = root
        self.root.title("Animal Detection and Classification")
        self.root.geometry("900x800")
        self.root.configure(bg="#F5F5F5")  # Light gray background

        # Header
        header = tk.Label(
            root,
            text="Animal Detection and Classification",
            font=("Arial", 20, "bold"),
            bg="#2E3B4E",
            fg="white",
            pady=10,
        )
        header.pack(fill=tk.X)

        # Upload Section. The button is kept on self so it can be disabled
        # while a video is being processed.
        upload_frame = tk.Frame(root, bg="#F5F5F5")
        upload_frame.pack(pady=20)
        self.upload_button = ttk.Button(
            upload_frame, text="Upload Image or Video", command=self.upload_file
        )
        self.upload_button.pack()

        # Threshold Sliders
        threshold_frame = tk.Frame(root, bg="#F5F5F5")
        threshold_frame.pack(pady=20, fill=tk.X)

        # Detection Threshold
        detection_label = tk.Label(
            threshold_frame, text="Detection Threshold:", font=("Arial", 12), bg="#F5F5F5"
        )
        detection_label.grid(row=0, column=0, padx=10, sticky="w")

        self.detection_threshold = tk.DoubleVar(value=0.1)
        detection_slider = ttk.Scale(
            threshold_frame,
            from_=0.1,
            to=1.0,
            orient=tk.HORIZONTAL,
            variable=self.detection_threshold,
            command=self.update_threshold_label,
        )
        detection_slider.grid(row=0, column=1, padx=10, sticky="we")
        self.detection_threshold_label = tk.Label(
            threshold_frame,
            text=f"{self.detection_threshold.get():.2f}",
            font=("Arial", 12),
            bg="#F5F5F5",
        )
        self.detection_threshold_label.grid(row=0, column=2, padx=10)

        # Classification Threshold
        classification_label = tk.Label(
            threshold_frame, text="Classification Threshold:", font=("Arial", 12), bg="#F5F5F5"
        )
        classification_label.grid(row=1, column=0, padx=10, sticky="w")

        self.classification_threshold = tk.DoubleVar(value=0.5)
        classification_slider = ttk.Scale(
            threshold_frame,
            from_=0.1,
            to=1.0,
            orient=tk.HORIZONTAL,
            variable=self.classification_threshold,
            command=self.update_threshold_label,
        )
        classification_slider.grid(row=1, column=1, padx=10, sticky="we")
        self.classification_threshold_label = tk.Label(
            threshold_frame,
            text=f"{self.classification_threshold.get():.2f}",
            font=("Arial", 12),
            bg="#F5F5F5",
        )
        self.classification_threshold_label.grid(row=1, column=2, padx=10)

        threshold_frame.columnconfigure(1, weight=1)

        # Plain-float snapshot of both sliders. Variable traces refresh it on
        # the main thread, so the video worker can read the current
        # thresholds without calling into Tk.
        self._cache_thresholds()
        self.detection_threshold.trace_add("write", self._cache_thresholds)
        self.classification_threshold.trace_add("write", self._cache_thresholds)

        # Progress Bar
        self.progress_bar = ttk.Progressbar(self.root, orient="horizontal", length=900, mode="determinate")
        self.progress_bar.pack(pady=20)

        # Canvas for Image Display
        self.canvas_frame = tk.Frame(self.root, bg="#D3D3D3", relief=tk.GROOVE, bd=2)
        self.canvas_frame.pack(fill=tk.BOTH, expand=True, padx=20, pady=10)

        self.canvas = tk.Canvas(self.canvas_frame, bg="#FFFFFF")
        self.canvas.pack(fill=tk.BOTH, expand=True)

        # To store the last uploaded image path
        self.last_image_path = None

        # UI work queued by worker threads; drained on the main thread.
        self._ui_queue = queue.Queue()
        self.root.after(self.UI_POLL_MS, self._poll_ui_queue)

    # ------------------------------------------------------------------
    # Thread hand-off helpers
    # ------------------------------------------------------------------
    def _run_on_ui_thread(self, callback, *args):
        """Queue ``callback(*args)`` to run on the Tk main thread."""
        self._ui_queue.put((callback, args))

    def _poll_ui_queue(self):
        """Run every queued UI callback, then reschedule itself."""
        try:
            while True:
                callback, args = self._ui_queue.get_nowait()
                callback(*args)
        except queue.Empty:
            pass
        finally:
            # Reschedule even if a callback raised, so later updates still run.
            self.root.after(self.UI_POLL_MS, self._poll_ui_queue)

    def _cache_thresholds(self, *_trace_args):
        """Copy both slider values into ``self._thresholds`` (main thread only)."""
        self._thresholds = (
            self.detection_threshold.get(),
            self.classification_threshold.get(),
        )

    def _set_progress(self, value):
        """Set the progress bar to *value* percent."""
        self.progress_bar["value"] = value

    def _on_video_finished(self):
        """Reset the progress bar and re-enable uploads after a video run."""
        self.progress_bar["value"] = 0
        self.upload_button.state(["!disabled"])

    # ------------------------------------------------------------------
    # Processing
    # ------------------------------------------------------------------
    def upload_file(self):
        """Ask for an image or video file and dispatch it for processing."""
        file_path = filedialog.askopenfilename(
            filetypes=[("Image/Video files", "*.jpg *.jpeg *.png *.mp4 *.JPG *.JPEG *.PNG *.MP4")]
        )
        if not file_path:  # dialog cancelled
            return
        # Compare extensions case-insensitively so e.g. "IMG_001.JPG" is accepted.
        extension = os.path.splitext(file_path)[1].lower()
        if extension in self.IMAGE_EXTENSIONS:
            self.last_image_path = file_path
            self.process_image(file_path)
        elif extension in self.VIDEO_EXTENSIONS:
            # Block further uploads until the worker finishes, so the module
            # level `model` is never used from two threads at once.
            self.upload_button.state(["disabled"])
            threading.Thread(target=self.process_video, args=(file_path,), daemon=True).start()
        else:
            messagebox.showerror("Invalid File", "Please upload a .jpg, .jpeg, .png, or .mp4 file.")

    def process_image(self, image_path):
        """Detect, classify and display a single image (runs on the main thread)."""
        image = cv2.imread(image_path)
        if image is None:  # cv2.imread signals failure by returning None
            messagebox.showerror("Image Error", f"Could not read image:\n{image_path}")
            return
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        print("Processing Image with YOLO...")
        # Ultralytics treats numpy input as BGR, so pass the frame as read.
        results = model(image)
        processed_img_path = self.detect_and_classify(image_rgb, results)
        self.display_image(processed_img_path)

    def process_video(self, video_path):
        """Detect and classify every frame of *video_path*.

        Runs on a worker thread started by ``upload_file``. Every widget
        update is queued for the main thread instead of being made directly.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                self._run_on_ui_thread(
                    messagebox.showerror, "Video Error", f"Could not open video:\n{video_path}"
                )
                return
            # Some containers/backends report 0 (or -1) frames; in that case
            # the progress bar stays at 0 instead of dividing by zero.
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            processed_frames = 0

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Ultralytics treats numpy input as BGR, so pass the raw frame.
                results = model(frame)
                processed_frame_path = self.detect_and_classify(frame_rgb, results)
                self._run_on_ui_thread(self.display_image, processed_frame_path)

                processed_frames += 1
                if frame_count > 0:
                    # The reported frame count is only an estimate; cap at 100%.
                    progress = min(100.0, processed_frames / frame_count * 100)
                    self._run_on_ui_thread(self._set_progress, progress)
        finally:
            cap.release()
            self._run_on_ui_thread(self._on_video_finished)

    def detect_and_classify(self, image_rgb, results):
        """Annotate *image_rgb* with classified detections and save it.

        Each YOLO box whose confidence reaches the detection threshold is
        cropped and saved via ``save_cropped_image``, resized to 640x640, and
        sent to the Roboflow model. Boxes whose top classification confidence
        reaches the classification threshold are drawn: red for "coyote",
        green otherwise. Per-box failures are printed and skipped, so one bad
        crop or request does not abort the whole frame.

        Safe to call from the video worker thread. Thresholds come from the
        cached ``self._thresholds``, and drawing uses a standalone ``Figure``
        rather than pyplot's global figure state.

        Returns:
            Path of the annotated JPEG inside ``output_dir_classified``.
        """
        detection_threshold, classification_threshold = self._thresholds

        fig = Figure(figsize=(12, 8), dpi=150)
        ax = fig.subplots()
        ax.imshow(image_rgb)

        boxes = getattr(results[0], "boxes", None) if results else None
        if boxes is not None and boxes.data.numel() > 0:
            for det in boxes.data:
                x1, y1, x2, y2, conf, cls_id = det[:6].tolist()
                if conf < detection_threshold:
                    continue
                try:
                    cropped_path = save_cropped_image(image_rgb, [x1, y1, x2, y2, conf, cls_id])

                    # Resize a copy and close the source before overwriting it.
                    with Image.open(cropped_path) as crop:
                        resized_crop = crop.resize((640, 640))
                    resized_crop.save(cropped_path)

                    prediction = rf_model.predict(cropped_path).json()
                    predictions = prediction.get("predictions") or []
                    if not predictions:
                        continue

                    top_prediction = predictions[0]
                    obj_class = top_prediction.get("top", "unknown")
                    confidence = top_prediction.get("confidence", 0)
                    if confidence < classification_threshold:
                        continue

                    color = "red" if obj_class.lower() == "coyote" else "green"
                    ax.add_patch(
                        plt.Rectangle(
                            (x1, y1), x2 - x1, y2 - y1,
                            linewidth=2, edgecolor=color, facecolor="none",
                        )
                    )
                    ax.text(
                        x1, y1, f"{obj_class}: {confidence * 100:.2f}%",
                        color="Black", fontsize=12, backgroundcolor=color,
                    )
                except Exception as e:  # best-effort per box: log and move on
                    print(f"Error processing detection: {e}")

        ax.axis("off")
        # Microsecond timestamp plus random suffix: several video frames can
        # be processed within the same second and must not overwrite each other.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
        output_path = os.path.join(output_dir_classified, f"{timestamp}_{uuid.uuid4().hex[:8]}.jpg")
        fig.savefig(output_path, bbox_inches="tight", dpi=150)
        return output_path

    def display_image(self, path):
        """Show the image at *path* on the canvas, scaled to 900x500 (main thread only)."""
        with Image.open(path) as img:
            resized = img.resize((900, 500), Image.LANCZOS)
        img_tk = ImageTk.PhotoImage(resized)
        # Remove the previous frame's canvas item; otherwise every video frame
        # leaves another image item behind on the canvas.
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, anchor=tk.NW, image=img_tk)
        # Keep a Python reference; the image disappears if the PhotoImage is
        # garbage-collected.
        self.canvas.image = img_tk

    def update_threshold_label(self, event=None):
        """Refresh both threshold labels from the current slider values."""
        self.detection_threshold_label.config(text=f"{self.detection_threshold.get():.2f}")
        self.classification_threshold_label.config(text=f"{self.classification_threshold.get():.2f}")


# Run the app
if __name__ == "__main__":
    # Create the Tk root window, build the GUI on it, then block in Tk's
    # event loop until the window is closed.
    root = tk.Tk()
    app = ObjectDetectionApp(root)
    app.root.mainloop()
