# In [4]:
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
import numpy as np
import json
import os
import random

from skimage import draw
from skimage import io as skio

import torch
from torch.utils.data import Dataset

class AppleFromDronesDetectionDataset(Dataset):
    """Detection dataset built from circle annotations of apples.

    The annotation file is a JSON object that maps each image file name to a
    list of apples, every apple being a dict with the keys ``cx``, ``cy`` and
    ``r``.  Each circle is converted once, at construction time, into an
    axis-aligned box ``(x0, y0, x1, y1)`` clamped to ``[0, image_size - 1]``.
    """

    def __init__(self, json_file, image_dir, transform=None, image_size=256):
        """Load the annotations and precompute one box array per image.

        Args:
            json_file: Path of the JSON annotation file.
            image_dir: Directory the image file names are joined onto.
            transform: Optional callable invoked as
                ``transform(image=..., bboxes=..., class_labels=...)`` that
                returns a mapping with ``"image"`` and ``"bboxes"``
                (presumably an Albumentations-style pipeline -- confirm
                against the caller).
            image_size: Upper bound (exclusive) used to clamp box coordinates.
        """
        self.images_dir = image_dir

        with open(json_file, 'r', encoding='utf-8') as fp:
            self.annotations = json.load(fp)

        self.image_list = list(self.annotations)
        self.transform = transform
        self.boxes = [
            self._circles_to_boxes(self.annotations[name], image_size)
            for name in self.image_list
        ]

        # Lazily computed by num_apples().
        self._num_apples = None

    @staticmethod
    def _circles_to_boxes(apples, image_size):
        """Convert circle annotations into a clamped ``(n, 4)`` box array."""
        upper = image_size - 1
        boxes = [
            (
                max(apple['cx'] - apple['r'], 0),
                max(apple['cy'] - apple['r'], 0),
                min(apple['cx'] + apple['r'], upper),
                min(apple['cy'] + apple['r'], upper),
            )
            for apple in apples
        ]
        # reshape keeps the (n, 4) layout even when an image has no apples,
        # where np.array([]) alone would collapse to shape (0,).
        return np.array(boxes).reshape(-1, 4)

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, target, image_id)`` for the image at ``idx``.

        ``target`` is a dict with ``'boxes'``, ``'labels'`` (a tensor of ones,
        one per box) and ``'image_id'``.
        """
        image_name = self.image_list[idx]
        image_id = f'{idx:05d}:{image_name}'
        image = skio.imread(os.path.join(self.images_dir, image_name))
        boxes = self.boxes[idx]

        if self.transform:
            transformed = self.transform(image=image, bboxes=boxes, class_labels=["apple"] * len(boxes))
            image = transformed["image"]
            boxes = transformed["bboxes"]

        # Count the boxes only after the transform has run: the transform
        # returns its own box collection, whose length may differ from the
        # input, and labels must stay aligned with it.
        labels = torch.ones((len(boxes),), dtype=torch.int64)
        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': image_id
        }

        return image, target, image_id

    def num_apples(self):
        """Return the total number of annotated apples (cached after first call)."""
        if self._num_apples is None:
            self._num_apples = sum(boxes.shape[0] for boxes in self.boxes)
        return self._num_apples


class AppleDetectionGUI:
    """Tkinter window that pages through a detection dataset.

    Each dataset item is expected to be an ``(image, target, image_id)``
    triple where ``target['boxes']`` iterates as ``(x0, y0, x1, y1)`` tuples;
    the boxes are drawn on a copy of the image before it is shown.
    """

    def __init__(self, root, dataset):
        """Build the widgets inside ``root`` and show the first dataset image.

        Args:
            root: The Tk root (or other container) that hosts the widgets.
            dataset: Indexable, sized collection of
                ``(image, target, image_id)`` items.
        """
        self.root = root
        self.dataset = dataset
        self.index = 0

        self.root.title("Apple Detection from Drones")
        self.root.geometry("700x500")

        self.canvas = tk.Canvas(root, width=300, height=300)
        self.canvas.pack()

        self.label = tk.Label(root, text="", font=("Helvetica", 14))
        self.label.pack()

        self.fruit_label = tk.Label(root, text="", font=("Helvetica", 14))
        self.fruit_label.pack()

        self.prev_button = tk.Button(root, text="Previous", command=self.show_prev_image)
        self.prev_button.pack(side=tk.LEFT, padx=10, pady=10)

        self.next_button = tk.Button(root, text="Next", command=self.show_next_image)
        self.next_button.pack(side=tk.RIGHT, padx=10, pady=10)

        self.load_button = tk.Button(root, text="Load Image", command=self.load_image)
        self.load_button.pack(side=tk.BOTTOM, pady=10)

        self.show_image()

    def draw_boxes(self, img, boxes):
        """Draw the outline of every box onto ``img`` in place and return it.

        Colour images get a yellow outline written to their first three
        channels (so an alpha channel, if present, is left alone); 2-D
        grayscale images get a white outline.
        """
        for (x0, y0, x1, y1) in boxes:
            # skimage works in (row, col) order, hence the (y, x) swap.
            rr, cc = draw.rectangle_perimeter(start=(y0, x0), end=(y1, x1), shape=img.shape)
            if img.ndim == 2:
                img[rr, cc] = 255
            else:
                img[rr, cc, :3] = [255, 255, 0]
        return img

    def _display(self, array):
        """Show ``array`` on the canvas, replacing whatever was drawn before."""
        img_tk = ImageTk.PhotoImage(Image.fromarray(array))

        # Remove the previous image item; otherwise every redraw stacks one
        # more item on the canvas.
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, anchor=tk.NW, image=img_tk)
        # Keep a Python reference so the PhotoImage is not garbage-collected
        # while it is still on screen.
        self.canvas.image = img_tk

    def show_image(self):
        """Render the dataset item at ``self.index`` with its boxes and labels."""
        if len(self.dataset) == 0:
            # Nothing to index into; show a placeholder instead of raising.
            self.canvas.delete("all")
            self.label.config(text="No images in dataset")
            self.fruit_label.config(text="Fruit: None")
            return

        img, target, img_id = self.dataset[self.index]
        self._display(self.draw_boxes(img.copy(), target['boxes']))

        num_apples = len(target['boxes'])
        self.label.config(text=f"Image ID: {img_id} \n\nNumber of Apples: {num_apples}\n")

        if num_apples > 0:
            self.fruit_label.config(text="Fruit: Apple")
        else:
            self.fruit_label.config(text="Fruit: None")

    def _step(self, delta):
        """Move ``self.index`` by ``delta`` with wrap-around, then redraw."""
        count = len(self.dataset)
        if count:
            self.index = (self.index + delta) % count
        self.show_image()

    def show_prev_image(self):
        """Show the previous dataset image, wrapping to the last one."""
        self._step(-1)

    def show_next_image(self):
        """Show the next dataset image, wrapping to the first one."""
        self._step(1)

    def load_image(self):
        """Ask for an image file and display it without annotations.

        A file picked from disk has no dataset entry, so there is no image id
        or apple count to report; only the file name is shown.
        """
        file_path = filedialog.askopenfilename()
        if not file_path:
            # Dialog was cancelled.
            return

        self._display(skio.imread(file_path))
        self.label.config(text=f"Loaded Image: {os.path.basename(file_path)}")
        self.fruit_label.config(text="Fruit: N/A")

if __name__ == "__main__":
    # Script entry point: open the viewer on the training annotations.
    root = tk.Tk()

    train_dataset = AppleFromDronesDetectionDataset(
        json_file='./training.json',
        image_dir='./images',
        transform=None,
    )
    gui = AppleDetectionGUI(root=root, dataset=train_dataset)

    # Hand control to Tk's event loop.
    root.mainloop()
