In [None]:
%matplotlib widget
import ipywidgets as widgets
from IPython.display import display, clear_output
import cv2
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import RectangleSelector
import os

# Directory where the cropped images are written.
path2save = '/home/CenteredData/tmp_ted'

# Load the list of images that need to be cropped manually (one path per
# line). Blank lines are skipped: an empty path would only fail later in
# cv2.imread.
with open('/home/CenteredData/TED Federated Learning Project/img_no_landmark.txt', 'r', encoding='utf-8') as f:
    img_no_landmark = [line.strip() for line in f if line.strip()]

# Global state shared by the widget/selector callbacks below.
current_idx = 0          # index into img_no_landmark of the image on screen
current_image = None     # image as returned by cv2.imread (BGR), or None
current_bbox = None      # [y_min, y_max, x_min, x_max] of the last selection
current_cropped = None   # RGB crop of current_image, or None

# Use a single figure for everything; its axes are rebuilt by update_layout().
fig = plt.figure(figsize=(15, 10))
main_ax = None
crop_ax = None
rect_selector = None  # Keep a global reference so the selector is not garbage-collected

def update_layout():
    """Rebuild the figure for the current image and, if present, its crop.

    Clears ``fig`` and recreates the axes. With no crop there is a single
    axes; once ``current_cropped`` is set, a 1x2 grid shows the original
    and the crop. A new ``RectangleSelector`` is attached to the main axes
    on every call, because ``fig.clear()`` discards the axes the previous
    selector was bound to.
    """
    global fig, main_ax, crop_ax, rect_selector

    fig.clear()

    if current_cropped is not None:
        # Original on the left, crop on the right.
        main_ax = fig.add_subplot(1, 2, 1)
        crop_ax = fig.add_subplot(1, 2, 2)
    else:
        main_ax = fig.add_subplot(1, 1, 1)
        crop_ax = None

    if current_image is not None:
        # OpenCV loads images as BGR; matplotlib expects RGB.
        image_rgb = cv2.cvtColor(current_image, cv2.COLOR_BGR2RGB)
        main_ax.imshow(image_rgb)
        name = os.path.basename(img_no_landmark[current_idx])
        main_ax.set_title(f'Image {current_idx+1}/{len(img_no_landmark)}: {name}')

        rect_selector = RectangleSelector(
            main_ax, on_select, useblit=True,
            button=[1],  # Only use left mouse button
            minspanx=5, minspany=5,
            spancoords='pixels',
            interactive=True,
        )
        rect_selector.set_active(True)

    if current_cropped is not None and crop_ax is not None:
        crop_ax.imshow(current_cropped)
        crop_ax.set_title(f'Cropped: {os.path.basename(img_no_landmark[current_idx])}')

    fig.tight_layout()
    # Redraw *this* figure explicitly. plt.draw() targets plt.gcf(), which
    # can be a different figure if one was created after ``fig``.
    fig.canvas.draw_idle()

def crop_and_display():
    """Crop ``current_image`` to ``current_bbox`` and redraw the figure.

    ``current_bbox`` is ``[y_min, y_max, x_min, x_max]`` in pixel indices,
    and the max bounds are exclusive (NumPy slice semantics). The crop is
    stored in ``current_cropped`` as an RGB array.

    If the selection covers no pixels, a message is printed and the
    previous crop (if any) is left unchanged. A zero-size array would
    otherwise fail later in ``imshow`` / ``cv2.imwrite``.
    """
    global current_cropped, current_bbox, current_image

    if current_bbox is None or current_image is None:
        print("No selection made yet!")
        return

    y_min, y_max, x_min, x_max = current_bbox

    # Clamp to the image so out-of-range values cannot wrap around
    # (negative slice indices count from the end in NumPy).
    height, width = current_image.shape[:2]
    y_min, y_max = max(0, y_min), min(height, y_max)
    x_min, x_max = max(0, x_min), min(width, x_max)

    # int() truncation can collapse a tiny drag (e.g. on an upscaled small
    # image) to zero width or height.
    if y_max <= y_min or x_max <= x_min:
        print(f"Selection {current_bbox} is empty; nothing to crop.")
        return

    # OpenCV images are BGR; keep the crop in RGB for display.
    image_rgb = cv2.cvtColor(current_image, cv2.COLOR_BGR2RGB)
    current_cropped = image_rgb[y_min:y_max, x_min:x_max]

    # Switch to the side-by-side layout showing the crop.
    update_layout()

def save_current_crop():
    """Write ``current_cropped`` into ``path2save`` using the source image's file name.

    Returns:
        True if the file was written. False if there is no (non-empty)
        crop, or if OpenCV reported a write failure.
    """
    if current_cropped is None or current_cropped.size == 0:
        print("No crop selection to save!")
        return False

    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(path2save, exist_ok=True)

    # NOTE(review): only the basename is kept, so two inputs with the same
    # file name in different folders overwrite each other here.
    output_path = os.path.join(path2save, os.path.basename(img_no_landmark[current_idx]))

    # current_cropped is held in RGB; OpenCV writes BGR.
    written = cv2.imwrite(output_path, cv2.cvtColor(current_cropped, cv2.COLOR_RGB2BGR))
    if not written:
        print(f"Error: could not write cropped image to {output_path}")
        return False

    print(f"Saved cropped image to {output_path}")
    return True

def on_select(eclick, erelease):
    """RectangleSelector callback: record the dragged box, then show its crop.

    ``eclick`` and ``erelease`` are the press and release mouse events.
    Their data coordinates are truncated to ints and ordered, so the stored
    ``current_bbox`` is ``[y_min, y_max, x_min, x_max]`` whichever direction
    the user dragged.
    """
    global current_bbox

    xs = sorted((int(eclick.xdata), int(erelease.xdata)))
    ys = sorted((int(eclick.ydata), int(erelease.ydata)))

    current_bbox = [ys[0], ys[1], xs[0], xs[1]]
    print(f"Selected bbox: {current_bbox}")

    crop_and_display()

def load_image(idx):
    """Load image ``idx`` of ``img_no_landmark`` and display it.

    Any pending selection and crop are discarded first.

    Returns:
        ``fig`` on success. ``None`` if the list is empty or the image
        could not be read.
    """
    global current_idx, current_image, current_bbox, current_cropped

    current_bbox = None
    current_cropped = None

    if not img_no_landmark:
        print("No images to crop: img_no_landmark is empty.")
        return None

    current_idx = idx
    image_path = img_no_landmark[current_idx]

    current_image = cv2.imread(image_path)  # None if the file cannot be read
    if current_image is None:
        print(f"Error: Could not read image at {image_path}")
        # Clear the figure so the previous image is not left on screen under
        # the new index. With no image, update_layout attaches no selector.
        update_layout()
        return None

    update_layout()
    return fig

def next_image_callback(b):
    """Button handler: save the pending crop (if any) and advance one image.

    ``b`` is the clicked button and is unused. The index wraps back to the
    first image after the last one.
    """
    global current_cropped, current_bbox

    # Avoid ZeroDivisionError in the modulo below.
    if not img_no_landmark:
        print("No images to crop: img_no_landmark is empty.")
        return

    if current_bbox is not None:
        save_current_crop()

    current_cropped = None
    current_bbox = None

    next_idx = (current_idx + 1) % len(img_no_landmark)
    load_image(next_idx)

    print(f"Now showing image {next_idx+1}/{len(img_no_landmark)}")

# "Next Image" button: saves the pending crop (if any) and moves on.
# NOTE(review): in JupyterLab, print() output from widget callbacks may land
# in the log console rather than under this cell; consider widgets.Output.
next_button = widgets.Button(
    description='Next Image',
    tooltip='Save current crop and go to next image',
    button_style='success',
    icon='arrow-right',
)
next_button.on_click(next_image_callback)

# Show the button, then load and draw the first image.
display(next_button)
load_image(current_idx)