In [2]:
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from ipywidgets import Button, VBox, HBox, Output, IntText, Dropdown, Label, Layout, BoundedIntText
from IPython.display import display
from PIL import Image
import cv2
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache


In [3]:
# Shared worker pool: AnnotationEditor submits load_image calls here so the image
# is read while the label file is parsed on the calling thread.
_LOADER_WORKERS = 8
executor = ThreadPoolExecutor(max_workers=_LOADER_WORKERS)

In [4]:
@lru_cache(maxsize=100)
def load_image(image_path, max_size=800):
    """Read an image as an RGB array, downscaled so its longest side is at most ``max_size``.

    Results are memoised per ``(image_path, max_size)`` by ``lru_cache`` (up to 100
    entries), so repeated calls return the *same* array object; callers must not
    modify it in place.

    Args:
        image_path: Path to an image file readable by OpenCV.
        max_size: Maximum length, in pixels, of the longer side after resizing.

    Returns:
        numpy array of shape (H, W, 3) in RGB channel order.

    Raises:
        OSError: if OpenCV cannot read the file (missing, unreadable, or an
            unsupported format). ``cv2.imread`` signals this by returning ``None``
            rather than raising, so it is checked explicitly here.
    """
    img = cv2.imread(image_path)
    if img is None:
        # Without this check the failure surfaces later as an opaque cv2.error from cvtColor.
        raise OSError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Resize image if it's too large, preserving the aspect ratio.
    h, w = img.shape[:2]
    longest = max(h, w)
    if longest > max_size:
        scale = max_size / longest
        # Clamp to 1 px so extreme aspect ratios don't yield a zero-sized target,
        # which cv2.resize rejects.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)

    return img

In [5]:
def load_annotations(annotation_path):
    """Parse a YOLO-style label file into a list of boxes.

    Each non-blank line must start with ``class_id x_center y_center width height``
    (coordinates as fractions of the image size, as ``draw_boxes`` expects); any
    extra columns are ignored.

    Args:
        annotation_path: Path to the ``.txt`` label file.

    Returns:
        List of ``(class_id, x_center, y_center, width, height)`` tuples with
        ``class_id`` as ``int`` and the rest as ``float``. A missing file yields an
        empty list: an image without a label file simply has no boxes.

    Raises:
        ValueError: if a non-blank line has fewer than five fields or a field is
            not numeric.
    """
    if not os.path.exists(annotation_path):
        return []

    boxes = []
    with open(annotation_path, 'r') as file:
        for line_number, line in enumerate(file, start=1):
            parts = line.split()
            if not parts:
                # Skip blank lines (e.g. a trailing empty line) instead of failing on parts[0].
                continue
            if len(parts) < 5:
                raise ValueError(
                    f"{annotation_path}:{line_number}: expected 5 fields, got {len(parts)}"
                )
            class_id = int(parts[0])
            x_center, y_center, width, height = map(float, parts[1:5])
            boxes.append((class_id, x_center, y_center, width, height))
    return boxes


In [6]:
def save_annotations(annotation_path, boxes):
    """Write boxes back to a label file, one space-separated box per line.

    The data is written to a sibling ``.tmp`` file first and then moved over the
    target with ``os.replace``, so an interrupted write (disk quota, crash) cannot
    leave a truncated label file behind. When the target already exists, its
    permission bits are copied to the replacement so shared-storage access is kept.

    Args:
        annotation_path: Destination ``.txt`` path; overwritten if it exists.
        boxes: Iterable of sequences such as
            ``(class_id, x_center, y_center, width, height)``; each element is
            written with ``str()``.
    """
    import shutil

    tmp_path = annotation_path + '.tmp'
    try:
        with open(tmp_path, 'w') as file:
            for box in boxes:
                file.write(' '.join(map(str, box)) + '\n')
        if os.path.exists(annotation_path):
            shutil.copymode(annotation_path, tmp_path)
        os.replace(tmp_path, annotation_path)
    except BaseException:
        # Don't leave a stray partial .tmp file next to the labels.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

In [7]:
def draw_boxes(ax, image, boxes):
    """Show `image` on `ax` and outline each box in red, tagged in blue with its list index.

    `boxes` holds (class_id, x_center, y_center, width, height) tuples whose
    coordinates are fractions of the image size; they are scaled to pixels with
    `image.shape`. The axis frame and ticks are hidden afterwards.
    """
    ax.imshow(image)
    for box_number, (_cls, cx, cy, bw, bh) in enumerate(boxes):
        img_h, img_w = image.shape[0], image.shape[1]
        left_px = (cx - bw / 2) * img_w
        top_px = (cy - bh / 2) * img_h
        outline = patches.Rectangle(
            (left_px, top_px),
            bw * img_w,
            bh * img_h,
            linewidth=1,
            edgecolor='r',
            facecolor='none',
        )
        ax.add_patch(outline)
        ax.text(left_px, top_px, str(box_number), color='blue')
    ax.axis('off')

In [8]:
class AnnotationEditor:
    """Interactive ipywidgets tool for reviewing and pruning label boxes one image at a time.

    The immediate sub-folders of ``image_dir`` are offered in a dropdown; images are
    found recursively inside the chosen sub-folder. Each image's label file lives at
    the same relative path under ``annotation_dir`` with a ``.txt`` suffix. The user
    can step or jump between images, remove a single box by the index drawn next to
    it, or delete the image together with its label file.

    Args:
        image_dir: Root directory whose immediate sub-folders are listed.
        annotation_dir: Root directory mirroring ``image_dir``'s layout for labels.
    """

    # File suffixes (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, image_dir, annotation_dir):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir

        # Navigation state is initialised up front so button callbacks are safe
        # even when no folder or image is available.
        self.current_folder = None
        self.image_paths = []
        self.current_index = 0
        self.boxes = []
        self.image = None

        # Get all folder paths and sort them
        self.folder_paths = sorted(self.get_all_folder_paths())

        # Dropdown to select the folder
        self.folder_dropdown = Dropdown(
            options=[(os.path.basename(folder), folder) for folder in self.folder_paths],
            description='Folder:',
            layout=Layout(margin='10px 0 0 0')
        )
        self.folder_dropdown.observe(self.update_image_list, names='value')

        # 1-based index of the displayed image; bounds are kept in sync with
        # image_paths by _sync_index_input.
        self.image_index_input = BoundedIntText(
            description='Image Index:',
            min=1,
            layout=Layout(margin='0 0 0 0')
        )
        self.image_index_input.observe(self.update_image_index, names='value')

        self.image_path_label = Label(value="", layout=Layout(margin='0 0 0 0'))

        # 0-based index of the box to remove (the number drawn next to each box).
        # IntText has no `min` trait, so range checking is done in remove_selected_box.
        self.box_index_input = IntText(
            description='Box Index:',
            layout=Layout(margin='0 0 0 0')
        )

        self.remove_button = Button(description="Remove Selected", layout=Layout(margin='30px 0 0px 0px'))
        self.remove_button.on_click(self.remove_selected_box)

        self.next_button = Button(description="Next", layout=Layout(margin='10px 10px 10px 0'))
        self.next_button.on_click(self.next_image)

        self.prev_button = Button(description="Previous", layout=Layout(margin='10px 0 10px 10px'))
        self.prev_button.on_click(self.prev_image)

        # Add delete button
        self.delete_button = Button(description="Delete Image",
                                    layout=Layout(margin='10px 0 10px 0'),
                                    button_style='danger')  # 'danger' gives it a red color
        self.delete_button.on_click(self.delete_current_image)

        self.status_label = Label(value="", layout=Layout(margin='0 0 0 0'))
        self.index_label = Label(value="", layout=Layout(margin='0 0 0 0'))

        self.out = Output()

        display_elements = [
            self.delete_button,
            self.folder_dropdown,
            self.image_index_input,
            self.image_path_label,
            self.index_label,
            HBox([self.prev_button, self.next_button]),
            self.box_index_input,
            self.remove_button,
            self.status_label,
            self.out,
        ]

        display(VBox(display_elements))

        if self.folder_paths:
            self.update_image_list({'new': self.folder_paths[0]})
        else:
            self.display_image_with_boxes({'new': None})
            self.image_path_label.value = f"No sub-folders found in {image_dir}."

    def _sync_index_input(self, value):
        """Set the image-index box's bounds and value without re-triggering update_image_index."""
        self.image_index_input.unobserve(self.update_image_index, names='value')
        try:
            # BoundedIntText rejects max < min (min is 1), so keep max >= 1 even for
            # an empty folder and disable the box instead.
            self.image_index_input.max = max(1, len(self.image_paths))
            self.image_index_input.value = value
            self.image_index_input.disabled = not self.image_paths
        finally:
            self.image_index_input.observe(self.update_image_index, names='value')

    def delete_current_image(self, b):
        """Delete the current image and its label file from disk, then show the image now at that position.

        This is irreversible: both files are removed with ``os.remove``.
        """
        if not self.image_paths:
            self.status_label.value = "No image to delete."
            return

        current_image_path = self.image_paths[self.current_index]
        current_annotation_path = self.get_annotation_path(current_image_path)

        for path in (current_image_path, current_annotation_path):
            if os.path.exists(path):
                os.remove(path)

        # Re-scan so the list reflects what is actually left on disk.
        self.image_paths = self.get_all_image_paths()

        if not self.image_paths:
            self.current_index = 0
            self.display_image_with_boxes({'new': None})
            self.status_label.value = "All images deleted."
            return

        # Stay at the same position (now the following image) unless the last one was deleted.
        self.current_index = min(self.current_index, len(self.image_paths) - 1)
        self.display_image_with_boxes({'new': self.image_paths[self.current_index]})
        # Set after displaying, because display_image_with_boxes clears the status label.
        self.status_label.value = "Image and annotation deleted. Showing next image."

    def get_all_folder_paths(self):
        """Return the full paths of the immediate sub-directories of ``self.image_dir``."""
        return [os.path.join(self.image_dir, folder) for folder in os.listdir(self.image_dir) if os.path.isdir(os.path.join(self.image_dir, folder))]

    def update_image_list(self, change):
        """Dropdown observer: rescan the selected folder and show its first image."""
        self.current_folder = change['new']
        # The dropdown value can be None when it has no options; os.walk(None) would raise.
        self.image_paths = self.get_all_image_paths() if self.current_folder else []
        self.current_index = 0
        self.status_label.value = ""
        self.display_image_with_boxes({'new': self.image_paths[0] if self.image_paths else None})

    def get_all_image_paths(self):
        """Return a sorted list of image files found recursively under ``self.current_folder``."""
        return sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(self.current_folder)
            for name in files
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def get_annotation_path(self, image_path):
        """Map an image path to its ``.txt`` label path, mirroring its location relative to ``image_dir``."""
        rel_path = os.path.relpath(image_path, self.image_dir)
        return os.path.join(self.annotation_dir, os.path.splitext(rel_path)[0] + '.txt')

    def display_image_with_boxes(self, change):
        """Render the image at ``change['new']`` with its boxes, or an empty state when it is falsy.

        Updates ``current_index``, ``boxes``, ``image``, the path/index labels and the
        image-index box. ``change['new']`` must be an entry of ``self.image_paths``.
        """
        image_path = change['new']

        if not image_path:
            self.image_path_label.value = "No images found."
            self.index_label.value = ""
            self._sync_index_input(1)
            self.out.clear_output()
            return

        annotation_path = self.get_annotation_path(image_path)
        self.current_index = self.image_paths.index(image_path)
        self.status_label.value = ""

        # Read the image on the worker pool while the label file is parsed here.
        future = executor.submit(load_image, image_path)

        try:
            self.boxes = load_annotations(annotation_path)
        except FileNotFoundError:
            # An image without a label file simply has no boxes yet.
            self.boxes = []

        self.image_path_label.value = f"Image Path: {image_path}"
        self.index_label.value = f"Image {self.current_index + 1} of {len(self.image_paths)}"
        self._sync_index_input(self.current_index + 1)

        self.out.clear_output()
        with self.out:
            try:
                self.image = future.result()
            except Exception as exc:
                # Report an unreadable image in the status line instead of a traceback.
                self.status_label.value = f"Could not load image: {exc}"
                return

            fig, ax = plt.subplots(figsize=(8, 8 * self.image.shape[0] / self.image.shape[1]))
            draw_boxes(ax, self.image, self.boxes)
            plt.show()
            # Release the figure so repeated navigation doesn't accumulate open figures.
            plt.close(fig)

    def update_image_index(self, change):
        """Image-index observer: jump to the 1-based index typed by the user, if it is valid."""
        index = change['new'] - 1
        if 0 <= index < len(self.image_paths):
            self.current_index = index
            self.display_image_with_boxes({'new': self.image_paths[index]})

    def remove_selected_box(self, b):
        """Remove the box whose index is in ``box_index_input``, save the labels, and redraw."""
        if not self.image_paths:
            self.status_label.value = "No image selected."
            return

        index_to_remove = self.box_index_input.value
        if not 0 <= index_to_remove < len(self.boxes):
            self.status_label.value = (
                f"Box index {index_to_remove} is out of range; "
                f"this image has {len(self.boxes)} box(es)."
            )
            return

        image_path = self.image_paths[self.current_index]
        del self.boxes[index_to_remove]
        save_annotations(self.get_annotation_path(image_path), self.boxes)
        self.display_image_with_boxes({'new': image_path})
        # Set after displaying, because display_image_with_boxes clears the status label.
        self.status_label.value = f"Removed box {index_to_remove}."

    def next_image(self, b):
        """Advance to the next image in the folder, if there is one."""
        if not self.image_paths:
            self.status_label.value = "No images in this folder."
        elif self.current_index < len(self.image_paths) - 1:
            self.current_index += 1
            self.display_image_with_boxes({'new': self.image_paths[self.current_index]})
            self.status_label.value = ""  # Clear status message
        else:
            self.status_label.value = "Reached the end of the folder."

    def prev_image(self, b):
        """Go back to the previous image in the folder, if there is one."""
        if not self.image_paths:
            self.status_label.value = "No images in this folder."
        elif self.current_index > 0:
            self.current_index -= 1
            self.display_image_with_boxes({'new': self.image_paths[self.current_index]})
            self.status_label.value = ""  # Clear status message
        else:
            self.status_label.value = "Reached the start of the folder."

In [11]:
# Split to review: images live under <split_root>/images/<sub-folder>/..., and the
# editable label copies sit under <split_root>/labels_10_copy/ with the same layout.
split_root = '/blue/hulcr/share/eric.kuo/Beetle_classifier/Data/00_Preprocessed_composite_images/train'
image_folder = f'{split_root}/images'
annotation_folder = f'{split_root}/labels_10_copy'

editor = AnnotationEditor(image_folder, annotation_folder)

VBox(children=(Button(button_style='danger', description='Delete Image', layout=Layout(margin='10px 0 10px 0')…

In [10]:
# import os

# def count_completed(train_folder, start_folder_name):
#     subfolders = sorted([f for f in os.listdir(train_folder) if os.path.isdir(os.path.join(train_folder, f))])

#     if start_folder_name not in subfolders:
#         raise ValueError(f"Folder '{start_folder_name}' not found in '{train_folder}'.")

#     start_index = subfolders.index(start_folder_name)

#     total_images = 0

#     for subfolder in subfolders[start_index:]:
#         subfolder_path = os.path.join(train_folder, subfolder)
#         image_files = [file for file in os.listdir(subfolder_path) if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff'))]
#         num_images = len(image_files)
#         total_images += num_images
#         print(f"Folder: {subfolder}, Images: {num_images}")

#     print(f"\nTotal images starting from '{start_folder_name}': {total_images}")
#     return total_images


# train_folder = '/blue/hulcr/share/eric.kuo/Beetle_classifier/Data/00_Preprocessed_composite_images/train/images'
# start_folder_name = 'Ambrosiodmus_asperatus'  

# # Count the images
# count_completed(train_folder, start_folder_name)
