In [None]:
import os
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
from PIL import Image
 
# Initialize the figure and persistent axes once: an 8-row x 16-column grid
# (128 cells) that every page of images is drawn into.
fig, axes = plt.subplots(8, 16, figsize=(20, 10))
fig.subplots_adjust(left=0.01, right=0.99, bottom=0.03, top=0.99, wspace=0.05, hspace=0.05)  # Leave space for the button at the bottom
axes = axes.flatten()  # 1-D so grid cells can be zipped with a flat list of paths

# Set paths and parameters
root_folder     = 'Eurosat_Images'  # its sub-directories are paged through below
chunk_size      = 128  # images per page; one per cell of the 8x16 grid
press_count     = 1  # Initialize the counter for "Next" button presses
# Page/selection counter drawn in the top-right corner of the figure.
# NOTE(review): 27000 looks like an assumed total image count (TODO confirm).
# Pages never span folders (next_chunk restarts at index 0 per folder), so the
# real number of pages can exceed the int(27000/128)+1 shown here.
counter_text    = fig.text(0.98, 0.98, f'It: {press_count}/{int(27000/128)+1}\n0 selected', ha='right', va='top', fontsize=14, color="red", transform=fig.transFigure)
selected_images = []  # selected paths not yet appended to selected_images.txt

# Initialize the folder list: sorted sub-directories of root_folder, visited in order
folders    = sorted([os.path.join(root_folder, folder) for folder in os.listdir(root_folder) if os.path.isdir(os.path.join(root_folder, folder))])
folder_idx = 0  # Start with the first folder
image_idx  = 0  # Start with the first image
 
def get_selected_image_count():
    """Return the number of paths recorded in ``selected_images.txt``.

    The file holds one path per line, so its line count is the number of
    selections saved so far. It is only ever opened for appending elsewhere in
    this script, so the count includes earlier runs too. Returns 0 when the
    file does not exist yet.
    """
    try:
        handle = open('selected_images.txt', 'r')
    except FileNotFoundError:
        return 0
    total = 0
    with handle:
        for _line in handle:
            total += 1
    return total
 
# Function to load one page (chunk) of image paths from a folder
def load_image_chunk(folder_path, start_idx, chunk_size=128):
    """Return one page of ``.png`` paths from ``folder_path``.

    The folder's ``.png`` file names are sorted. The slice
    ``[start_idx:start_idx + chunk_size]`` of them is returned as full paths.
    Sorting matters because ``os.listdir`` returns entries in arbitrary order.
    Paging by index is only consistent if every call sees the same order.

    Args:
        folder_path: directory to list.
        start_idx: index (in sorted order) of the first image of the page.
        chunk_size: maximum number of paths to return.

    Returns:
        A list of at most ``chunk_size`` paths. It is empty when ``start_idx``
        is past the last image.
    """
    names = sorted(name for name in os.listdir(folder_path) if name.endswith('.png'))
    return [os.path.join(folder_path, name) for name in names[start_idx:start_idx + chunk_size]]
 
# Function to redraw the persistent 16x8 grid of images without clearing the figure
def update_display(image_paths):
    """Draw ``image_paths`` into the persistent grid of axes.

    Paths are drawn in order, one per cell of the module-level ``axes`` array.
    Each filled cell is made pickable and remembers its path in
    ``ax._img_path`` so the pick handler can toggle its selection. Cells
    beyond ``len(image_paths)`` are blanked, made non-pickable, and have
    ``_img_path`` reset to ``None``.

    Args:
        image_paths: image file paths for the current page. Only the first
            ``len(axes)`` are shown.
    """
    for ax, img_path in zip(axes, image_paths):
        # imshow() adds a new AxesImage on every call. Clear first so a cell
        # does not accumulate one hidden image per page visited.
        ax.clear()
        ax.imshow(plt.imread(img_path))
        ax.axis('off')
        ax.set_picker(True)
        ax._img_path = img_path  # Read back by the pick handler on click

    # Blank any cells left over when the page is shorter than the grid.
    for ax in axes[len(image_paths):]:
        ax.clear()
        ax.axis('off')
        # Disable picking (None, not False: any non-None picker stays
        # pickable) and forget the old path. Otherwise a click on an empty
        # cell would toggle an image from a previous page.
        ax.set_picker(None)
        ax._img_path = None

    plt.draw()
 
def on_key_press(event):
    """Keyboard shortcut: the Enter key acts like the "Next" button.

    Any other key is ignored.
    """
    if event.key != 'enter':
        return
    next_chunk(event)
 
# Callback for selecting images
def on_image_click(event):
    """Toggle selection of the grid image under a left-click.

    Selecting an image appends its path to ``selected_images`` and redraws it
    in grayscale. Clicking it again removes it from the list and redraws the
    original color image. Clicks with buttons other than the left one are
    ignored, as are picked artists that carry no image path.

    Args:
        event: Matplotlib ``PickEvent``. ``event.artist`` is the picked axis.
    """
    # Only the left mouse button (button 1) toggles selection.
    if event.mouseevent.button != 1:
        return

    ax = event.artist
    # Only axes tagged with an image path have something to toggle.
    img_path = getattr(ax, '_img_path', None)
    if img_path is None:
        return

    # Replace the shown image instead of stacking a new one on top of it.
    for shown in ax.get_images():
        shown.remove()

    if img_path in selected_images:
        selected_images.remove(img_path)
        # Reload and display the original color image.
        with Image.open(img_path) as img:
            ax.imshow(img)
    else:
        selected_images.append(img_path)
        # Show the selected image in grayscale as visual feedback.
        with Image.open(img_path) as img:
            ax.imshow(img.convert('L'), cmap='gray')
    plt.draw()
 
# Function to save selected image paths to a file
def save_selected_images():
    """Append the pending selection to ``selected_images.txt`` and reset it.

    Paths are written one per line, in selection order, after whatever the
    file already holds. The file is created if missing, even when nothing is
    selected. The in-memory ``selected_images`` list is then emptied.
    """
    with open('selected_images.txt', 'a') as out_file:
        out_file.writelines(f"{path}\n" for path in selected_images)
    selected_images.clear()
 
# Callback for loading the next chunk
def next_chunk(event):
    """Save the current selection and show the next page of images.

    Advances ``image_idx`` by ``chunk_size`` within the current folder. Once
    that folder's ``.png`` files are exhausted, it moves to the first page of
    the next folder that contains any. When no folder is left, it prints a
    message and closes the figure.

    Args:
        event: the Matplotlib event that triggered the call (button click or
            key press). It is not inspected.
    """
    global folder_idx, image_idx, press_count

    def png_count(folder_path):
        # Apply the same '.png' filter as load_image_chunk, so the bound
        # matches the pages it can return. Counting every directory entry
        # would produce blank pages when other files are present.
        return sum(1 for name in os.listdir(folder_path) if name.endswith('.png'))

    press_count += 1
    save_selected_images()  # persist and clear this page's selection

    image_idx += chunk_size
    # Roll over to the next folder. The loop also skips folders with no images.
    while folder_idx < len(folders) and image_idx >= png_count(folders[folder_idx]):
        folder_idx += 1
        image_idx = 0

    if folder_idx >= len(folders):
        print("All folders and images processed.")
        plt.close()  # Close the figure after all folders are processed
        return

    images = load_image_chunk(folders[folder_idx], image_idx, chunk_size)

    # Refresh the top-right counter: page number and number of saved paths.
    counter_text.set_text(f'It: {press_count}/{int(27000/128)+1}\n{get_selected_image_count()} sel')

    update_display(images)
 
def previous_chunk(event):
    """Show the previous page of images.

    Steps ``image_idx`` back by ``chunk_size`` within the current folder.
    From a folder's first page, it jumps to the last page of the nearest
    earlier folder that contains ``.png`` files. On the very first page it
    does nothing, leaving the page counter and the display untouched.
    Selections are not saved here. They stay in ``selected_images`` until
    the next call to ``save_selected_images``.

    Args:
        event: the Matplotlib event that triggered the call. It is not inspected.
    """
    global folder_idx, image_idx, press_count

    def png_count(folder_path):
        # Same '.png' filter as load_image_chunk.
        return sum(1 for name in os.listdir(folder_path) if name.endswith('.png'))

    if image_idx >= chunk_size:
        image_idx -= chunk_size
    else:
        # Find the nearest earlier folder that actually has images.
        prev_folder = folder_idx - 1
        count = 0
        while prev_folder >= 0:
            count = png_count(folders[prev_folder])
            if count:
                break
            prev_folder -= 1
        if prev_folder < 0:
            return  # already on the first page; nothing to go back to
        folder_idx = prev_folder
        # Start of that folder's last page. It sits on the same boundaries
        # next_chunk steps through (0, chunk_size, 2*chunk_size, ...), so
        # pages line up whichever direction the user navigates.
        image_idx = (count - 1) // chunk_size * chunk_size

    press_count -= 1
    images = load_image_chunk(folders[folder_idx], image_idx, chunk_size)

    # Refresh the top-right counter: page number and number of saved paths.
    counter_text.set_text(f'It: {press_count}/{int(27000/128)+1}\n{get_selected_image_count()} sel')

    update_display(images)
 
 
# Initial display: first page of the first folder
current_folder = folders[folder_idx]
init_img_names = load_image_chunk(current_folder, image_idx, chunk_size)
update_display(init_img_names)

# Navigation buttons, created once in the strip left free below the grid.
# The Button objects are kept in module-level names: Matplotlib widgets only
# stay responsive while a reference to them is held.
ax_next = fig.add_axes([0.45, 0.00, 0.1, 0.03])
btn_next = Button(ax_next, 'Next (Click or <Enter>)')
btn_next.on_clicked(next_chunk)
ax_prev = fig.add_axes([0.34, 0.00, 0.1, 0.03])  # Positioned next to the Next button
btn_prev = Button(ax_prev, 'Go to Previous (Click)')
btn_prev.on_clicked(previous_chunk)

# Click-to-select on image cells, and Enter as a shortcut for "Next".
fig.canvas.mpl_connect('pick_event', on_image_click)
fig.canvas.mpl_connect('key_press_event', on_key_press)
# Selections are only written to disk when "Next" is pressed. Persist what is
# pending on the current page if the window is closed mid-session.
fig.canvas.mpl_connect('close_event', lambda event: save_selected_images())

plt.show()