In [None]:
# 1-based labeler number: selects which of the 4 data chunks this notebook
# labels (chunk labeler_id - 1) and names the output file labeler_<id>_labels.csv.
labeler_id = 3

In [None]:
# Install the notebook's third-party dependencies quietly (IPython shell escape).
! pip install -q pillow matplotlib ipywidgets ipython

In [2]:
# Path to the input CSV: a header row, then one row per test image whose first
# cell is the test image name and whose remaining cells are candidate similar
# image names. Change this to where your data file is stored.
csv_file_path = 'data_labeling.csv'

In [None]:
import csv
import os
from PIL import Image
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# Directories holding the images named in the CSV: test images (first column)
# and candidate similar images (remaining columns).
test_image_dir = 'data/test_image_headmind'
similar_image_dir = 'data/DAM'


def load_csv(file_path):
    """Read a CSV file and return its data rows, without the header row.

    Args:
        file_path: Path to the CSV file.

    Returns:
        A list of rows, each a list of strings. An empty file yields ``[]``.
    """
    # newline='' is required by the csv module so quoted fields containing
    # line breaks are read back verbatim instead of being newline-translated.
    with open(file_path, 'r', newline='') as file:
        reader = csv.reader(file)
        next(reader, None)  # Skip the header; tolerate a completely empty file.
        return list(reader)

# Split data into `num_parts` consecutive chunks (4 by default).
def split_data(data, num_parts=4):
    """Partition *data* into ``num_parts`` contiguous slices, in order.

    Each slice holds ``len(data) // num_parts`` items; whatever does not
    divide evenly is appended to the final slice.
    """
    size, leftover = divmod(len(data), num_parts)
    parts = [data[k * size:(k + 1) * size] for k in range(num_parts)]
    if leftover:
        # The tail beyond the evenly sized slices all goes to the last part.
        parts[-1].extend(data[num_parts * size:])
    return parts

# Display the test image and its top similar images in a 5x5 grid.
def display_images(test_image, similar_images):
    """Plot a test image followed by up to 20 of its candidate similar images.

    The test image fills the first cell of a 5x5 grid and the candidates fill
    the next cells in order, each titled with its file name. An image that
    cannot be opened leaves a blank cell titled "... (not found)". Unused
    cells are hidden.

    Args:
        test_image: File name of the test image, relative to ``test_image_dir``.
        similar_images: File names of candidate images, relative to
            ``similar_image_dir``; only the first 20 are shown.
    """
    max_similar = 20  # 5 x 5 = 25 cells: 1 test image + 20 candidates.
    shown = similar_images[:max_similar]
    _, axes = plt.subplots(5, 5, figsize=(15, 12))
    axes = axes.flatten()

    _draw_image(axes[0], os.path.join(test_image_dir, test_image),
                "Test Image", "Test Image (not found)", fontsize=10)
    for ax, similar_image in zip(axes[1:], shown):
        _draw_image(ax, os.path.join(similar_image_dir, similar_image),
                    similar_image, f"{similar_image} (not found)", fontsize=8)

    # Hide cells that received no image. Bounded by how many were actually
    # drawn (not len(similar_images)), so leftover cells are also hidden when
    # more than 20 candidates are supplied.
    for ax in axes[len(shown) + 1:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def _draw_image(ax, path, title, missing_title, fontsize):
    """Draw the image at *path* on *ax*; on failure, leave it blank with *missing_title*."""
    try:
        # The context manager closes the file; imshow converts the PIL image
        # to an array immediately, so the pixels outlive the file handle.
        with Image.open(path) as img:
            ax.imshow(img)
    except OSError:
        # Covers missing files, directories (an empty CSV cell joined onto the
        # image dir) and unreadable images (PIL's UnidentifiedImageError).
        title = missing_title
    ax.set_title(title, fontsize=fontsize)
    ax.axis('off')

# Interactive labeling
def interactive_labeling(data, labeler_id):
    """Run a widget-based labeling session over ``data`` in the notebook.

    Each row is shown with ``display_images``. A dropdown lets the labeler pick
    the most similar candidate (or "None"), and "Save" appends
    ``[test_image, label]`` to ``labeler_<labeler_id>_labels.csv``. Previously
    saved labels are reloaded from that file on start-up; the last row saved
    for a test image wins, so a choice can be corrected by saving again.

    Args:
        data: Rows of the form ``[test_image, similar_1, similar_2, ...]``.
            Rows that are blank or have an empty first cell are skipped.
        labeler_id: Identifier used to name this labeler's output CSV.
    """
    labels_file = f'labeler_{labeler_id}_labels.csv'

    # Create the labels file with its header on first use.
    if not os.path.exists(labels_file):
        with open(labels_file, 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["test_image_name", "label_image_name"])  # Write header
    labels = _load_labels(labels_file)

    # csv.reader yields [] for blank lines; such rows have no test image to show.
    rows = [row for row in data if row and row[0]]
    if not rows:
        print(f"No images to label for labeler {labeler_id}.")
        return

    index = 0
    # Render everything inside one Output widget. On some front ends (e.g.
    # JupyterLab), output produced by widget callbacks is not shown in the
    # cell unless it is captured by an Output widget.
    out = widgets.Output()
    display(out)

    def update_display():
        """Redraw the images and controls for ``rows[index]``."""
        row = rows[index]
        test_image = row[0]
        similar_images = row[1:]

        options = [("None", "None")] + [(img, img) for img in similar_images[:20]]
        saved = labels.get(test_image, "None")
        # Dropdown raises if its value is not one of the options (e.g. a label
        # saved against an older version of the CSV); fall back to "None".
        if saved not in {value for _, value in options}:
            saved = "None"
        dropdown = widgets.Dropdown(
            options=options,
            description="Most Similar:",
            value=saved,
        )

        save_button = widgets.Button(description="Save")

        def save_callback(b):
            selected_label = dropdown.value
            labels[test_image] = selected_label
            # Append rather than rewrite; _load_labels keeps the last row per image.
            with open(labels_file, 'a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([test_image, selected_label])
            with out:
                print(f"Label saved: {test_image} -> {selected_label}")

        save_button.on_click(save_callback)

        # Navigation buttons; the index is clamped to the valid row range.
        prev_button = widgets.Button(description="Previous")
        next_button = widgets.Button(description="Next")

        def prev_callback(b):
            nonlocal index
            index = max(0, index - 1)
            update_display()

        def next_callback(b):
            nonlocal index
            index = min(len(rows) - 1, index + 1)
            update_display()

        prev_button.on_click(prev_callback)
        next_button.on_click(next_callback)

        with out:
            clear_output(wait=True)
            display_images(test_image, similar_images)
            display(dropdown, save_button, widgets.HBox([prev_button, next_button]))

    update_display()


def _load_labels(labels_file):
    """Return ``{test_image_name: label}`` from a labels CSV; the last row per image wins."""
    labels = {}
    with open(labels_file, 'r', newline='') as file:
        reader = csv.reader(file)
        next(reader, None)  # Skip the header; tolerate an empty file.
        for row in reader:
            # Ignore blank or truncated rows instead of crashing on row[1].
            if len(row) >= 2:
                labels[row[0]] = row[1]
    return labels

# Main workflow: load every row, split the rows into one chunk per labeler,
# and start labeling this notebook's chunk.
num_labelers = 4
if not 1 <= labeler_id <= num_labelers:
    # labeler_id is 1-based. Without this check, 0 or a negative id would
    # silently index data_parts from the end and label another person's chunk.
    raise ValueError(
        f"labeler_id must be between 1 and {num_labelers}, got {labeler_id}"
    )

data = load_csv(csv_file_path)
data_parts = split_data(data, num_parts=num_labelers)

interactive_labeling(data_parts[labeler_id - 1], labeler_id)
