In [1]:
import os
import json
import shutil
from collections import OrderedDict
import yaml # For data.yaml
from PIL import Image # To verify/get image dimensions if needed
import concurrent.futures
import time # For timing

# --- Utility Function: Convert to YOLO Format (Identical to Script 1) ---
def convert_yolo_format(image_width, image_height, points, class_id):
    """
    Build one YOLO annotation line from a box described by corner points.

    The smallest and largest x/y values found in `points` define the box, so
    the order of the corners does not matter. The box is clipped to
    [0, image_width - 1] x [0, image_height - 1] and then normalized by the
    image dimensions.

    Returns the string
    "class_id x_center_norm y_center_norm width_norm height_norm" with six
    decimals per coordinate, or None when the clipped box has no positive
    width or height.
    """
    xs = [pt[0] for pt in points]
    ys = [pt[1] for pt in points]

    # Absolute extents, clipped to the image area.
    left = max(0.0, float(min(xs)))
    top = max(0.0, float(min(ys)))
    right = min(float(image_width - 1), float(max(xs)))
    bottom = min(float(image_height - 1), float(max(ys)))

    # Nothing left of the box after clipping (or it was degenerate to begin with).
    if right <= left or bottom <= top:
        return None

    span_x = right - left
    span_y = bottom - top

    normalized = (
        (left + span_x / 2.0) / image_width,   # x center
        (top + span_y / 2.0) / image_height,   # y center
        span_x / image_width,                  # width
        span_y / image_height,                 # height
    )
    return f"{class_id} " + " ".join(f"{value:.6f}" for value in normalized)

# --- Class Discovery for Test Set ---
def _get_labels_from_single_json(json_path):
    """
    Collect the distinct, non-empty "label" values of one annotation JSON.

    The file is expected to hold an object with a "shapes" list (LabelMe
    layout is assumed). Discovery is deliberately best-effort: any error while
    opening, parsing or walking the file is swallowed, and whatever labels
    were gathered before the error are still returned.
    """
    found = set()
    try:
        with open(json_path, 'r') as handle:
            annotation = json.load(handle)
        for entry in annotation.get("shapes", []):
            name = entry.get("label")
            if not name:
                continue  # missing or empty label
            found.add(name)
    except Exception:
        # Intentionally silent: a problematic JSON must not abort class discovery.
        pass
    return found

def discover_classes_for_yolo_test_set(source_test_data_dir, max_workers=None):
    """
    Scan the test data directory and build the class-name -> YOLO id mapping.

    Every file directly inside `source_test_data_dir` whose name ends with
    ".json" (case-insensitive) is handed to `_get_labels_from_single_json` in
    a process pool. The union of all returned labels is sorted and numbered
    from 0.

    Args:
        source_test_data_dir: Directory holding the annotation JSON files.
        max_workers: Forwarded to `ProcessPoolExecutor`; None lets the
            executor choose.

    Returns:
        An OrderedDict mapping label -> 0-based class id in sorted label
        order. It is empty when the directory does not exist, contains no
        JSON files, or no labels were found.
    """
    print("Discovering class labels from the test set (parallelized)...")
    overall_start_time = time.time()

    if not os.path.isdir(source_test_data_dir):
        print(f"Error (Class Discovery): Source directory '{source_test_data_dir}' not found.")
        return OrderedDict()

    json_file_paths = [
        os.path.join(source_test_data_dir, item_name)
        for item_name in os.listdir(source_test_data_dir)
        if item_name.lower().endswith(".json")
    ]

    if not json_file_paths:
        print("Warning: No JSON files found in the test set for class discovery.")
        return OrderedDict()

    unique_labels = set()
    failures = []  # (path, exception) for futures that raised
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_to_path = {executor.submit(_get_labels_from_single_json, path): path for path in json_file_paths}
        for future in concurrent.futures.as_completed(future_to_path):
            try:
                unique_labels.update(future.result())
            except Exception as exc:
                # The worker swallows its own errors, so an exception here is a
                # pool-level problem (e.g. a worker process died). Stay
                # best-effort, but remember it so the run is diagnosable.
                failures.append((future_to_path[future], exc))

    if failures:
        first_path, first_exc = failures[0]
        print(
            f"Warning (Class Discovery): {len(failures)} of {len(json_file_paths)} JSON file(s) "
            f"could not be processed by the worker pool (first: {first_path}: {first_exc!r})."
        )

    # Assign 0-indexed IDs for YOLO, in sorted label order so the mapping is reproducible.
    master_class_to_id_map = OrderedDict((label, i) for i, label in enumerate(sorted(unique_labels)))

    discovery_time = time.time() - overall_start_time
    if master_class_to_id_map:
        print(f"Discovered {len(master_class_to_id_map)} unique classes in {discovery_time:.2f} seconds:")
        for name, idx in master_class_to_id_map.items():
            print(f"  '{name}': {idx}")
    else:
        print("Warning: No classes discovered in the test set after parallel processing.")
    return master_class_to_id_map

# --- Worker for processing one image-JSON pair (Multi-Class) ---
def _process_single_image_json_pair_multiclass(args_tuple):
    """
    Convert one PNG/JSON pair into a YOLO image plus (optional) label file.

    args_tuple: (img_filename_no_ext, source_data_path, class_to_id_map,
                 output_image_path, output_label_path)

    The JSON is read and validated first; the image is copied only once the
    pair is known to be usable. Image dimensions come from the JSON keys
    "imageWidth"/"imageHeight", falling back to the PNG itself (via PIL) when
    either is missing. Only shapes with shape_type "rectangle", exactly two
    points and a label present in class_to_id_map are converted.

    Outputs are kept consistent across re-runs: a label file left at
    output_label_path by an earlier run is removed when the pair now yields
    no annotations or fails, so it cannot silently attach old boxes.

    Returns: True if the image was copied (a label file is written only when
    at least one annotation was converted; otherwise the image is kept as a
    label-less negative). False if either source file is missing, the image
    dimensions are unavailable or zero, or any error occurs; in the latter
    two cases no image or label is left at the output paths.
    """
    img_filename_no_ext, source_data_path, class_to_id_map, \
        output_image_path, output_label_path = args_tuple

    def _discard(path):
        # Remove an output file if present, so no partial or stale artefact survives.
        if os.path.exists(path):
            os.remove(path)

    png_file = f"{img_filename_no_ext}.png"
    json_file = f"{img_filename_no_ext}.json"

    source_png_path = os.path.join(source_data_path, png_file)
    source_json_path = os.path.join(source_data_path, json_file)

    if not (os.path.exists(source_png_path) and os.path.exists(source_json_path)):
        return False

    try:
        with open(source_json_path, 'r') as f:
            data = json.load(f)

        image_height = data.get("imageHeight")
        image_width = data.get("imageWidth")

        if image_height is None or image_width is None:
            # Fall back to the real image size for whichever value is missing.
            with Image.open(source_png_path) as img_pil:
                pil_width, pil_height = img_pil.size
            if image_width is None:
                image_width = pil_width
            if image_height is None:
                image_height = pil_height

        if not image_height or not image_width:
            # Zero/empty dimensions cannot be normalized against.
            _discard(output_image_path)
            _discard(output_label_path)
            return False

        yolo_annotations = []
        for shape in data.get("shapes", []):
            original_label = shape.get("label")
            points = shape.get("points")
            shape_type = shape.get("shape_type")

            if not original_label or not points or shape_type != "rectangle" or len(points) != 2:
                continue

            if original_label not in class_to_id_map:
                # Label unknown to the discovered class map: skip this annotation only.
                continue

            class_id = class_to_id_map[original_label]
            yolo_str = convert_yolo_format(image_width, image_height, points, class_id)
            if yolo_str:
                yolo_annotations.append(yolo_str)

        # Copy only after validation, so a bad JSON never costs a copy + delete.
        shutil.copy2(source_png_path, output_image_path)

        if yolo_annotations:
            with open(output_label_path, 'w') as f_label:
                f_label.write("\n".join(yolo_annotations) + "\n")
        else:
            # Image is kept as a negative sample; drop any label file from an
            # earlier run so it does not carry stale boxes.
            _discard(output_label_path)
        return True

    except Exception:
        # Best-effort worker: report failure through the return value and
        # leave no half-written outputs behind.
        _discard(output_image_path)
        _discard(output_label_path)
        return False

# --- Main Function to Create YOLO Test Dataset (Multi-Class) ---
def create_yolo_test_dataset_multiclass(
    source_test_data_dir,
    output_yolo_test_dir,
    max_workers=None
):
    """
    Create a YOLO-formatted test dataset, discovering and using original class labels.

    Steps:
      1. Discover the class map from the JSON files in `source_test_data_dir`.
      2. Create `<output>/images/test` and `<output>/labels/test`.
      3. Process every PNG that has a same-named JSON in a process pool.
      4. Write `<output>/data.yaml` with the keys path, test, nc and names.

    Args:
        source_test_data_dir: Flat directory containing the PNG/JSON pairs.
        output_yolo_test_dir: Root directory of the generated dataset.
        max_workers: Forwarded to the discovery step and to
            `ProcessPoolExecutor`; None lets the executor choose.

    Returns:
        None. Progress, warnings and errors are reported on stdout. The
        function returns early (before creating any output directory) when
        the source directory is missing or no classes are discovered, and
        before writing data.yaml when no PNG/JSON pairs are found.
    """
    overall_start_time = time.time()
    print("\nProcessing Test Dataset for YOLO (Multi-Class from original labels)...")

    if not os.path.isdir(source_test_data_dir):
        print(f"Error: Source test data directory '{source_test_data_dir}' not found.")
        return

    # 1. Discover all classes from the test set
    class_to_id_map = discover_classes_for_yolo_test_set(source_test_data_dir, max_workers)
    if not class_to_id_map:
        print("Error: No classes discovered. Cannot proceed.")
        return

    test_images_dir = os.path.join(output_yolo_test_dir, "images", "test")
    test_labels_dir = os.path.join(output_yolo_test_dir, "labels", "test")
    os.makedirs(test_images_dir, exist_ok=True)
    os.makedirs(test_labels_dir, exist_ok=True)

    print(f"  Output images to: {test_images_dir}")
    print(f"  Output labels to: {test_labels_dir}")

    # 2. Collect basenames of PNGs that have a matching JSON next to them.
    # NOTE(review): the extension filter is case-insensitive, but the worker
    # looks up "<basename>.png" exactly; on a case-sensitive filesystem a
    # ".PNG" file is therefore counted as a failed pair below.
    image_basenames = sorted(
        os.path.splitext(f)[0] for f in os.listdir(source_test_data_dir)
        if f.lower().endswith(".png") and os.path.exists(os.path.join(source_test_data_dir, f"{os.path.splitext(f)[0]}.json"))
    )

    if not image_basenames:
        print(f"  No matching PNG/JSON pairs found in '{source_test_data_dir}'.")
        return

    tasks = [
        (
            img_basename,
            source_test_data_dir,
            class_to_id_map,
            os.path.join(test_images_dir, f"{img_basename}.png"),
            os.path.join(test_labels_dir, f"{img_basename}.txt"),
        )
        for img_basename in image_basenames
    ]

    # 3. Fan the pairs out to worker processes.
    processed_count = 0
    failed_count = 0
    first_pool_error = None  # (basename, exception) of the first future that raised
    print(f"  Found {len(tasks)} images to process...")
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_to_task = {
            executor.submit(_process_single_image_json_pair_multiclass, task_args): task_args
            for task_args in tasks
        }
        for future in concurrent.futures.as_completed(future_to_task):
            try:
                succeeded = future.result()
            except Exception as exc:
                # The worker reports failure via its return value and prints
                # nothing, so an exception here is unexpected. Count it and
                # keep the first one for the summary; do not abort the run.
                failed_count += 1
                if first_pool_error is None:
                    first_pool_error = (future_to_task[future][0], exc)
                continue
            if succeeded:
                processed_count += 1
            else:
                failed_count += 1

    print(f"  Processed {processed_count} image-JSON pairs.")
    if failed_count:
        print(f"  Warning: {failed_count} image-JSON pair(s) could not be processed and were skipped.")
    if first_pool_error is not None:
        print(f"  Warning: first worker exception was for '{first_pool_error[0]}': {first_pool_error[1]!r}")

    # 4. Describe the dataset; class order matches the ids used in the label files.
    data_yaml_content = {
        'path': os.path.abspath(output_yolo_test_dir),
        'test': os.path.join('images', 'test'),
        'nc': len(class_to_id_map),
        'names': list(class_to_id_map.keys())
    }
    data_yaml_path = os.path.join(output_yolo_test_dir, "data.yaml")
    try:
        with open(data_yaml_path, 'w') as f:
            yaml.dump(data_yaml_content, f, sort_keys=False, default_flow_style=False)
        print(f"  Successfully created 'data.yaml' at {data_yaml_path}")
    except Exception as e:
        print(f"  Error writing data.yaml: {e}")

    total_script_time = time.time() - overall_start_time
    print(f"YOLO Multi-Class Test Dataset preparation complete in {total_script_time:.2f} seconds!")
    print(f"Dataset ready under: {output_yolo_test_dir}")


# --- Configuration and Execution (Multi-Class for Test Set) ---
if __name__ == "__main__":
    # **IMPORTANT**: Point these two paths at your own data before running!
    SOURCE_TEST_DATA_DIR_MULTI_CLASS = "/blue/hulcr/gmarais/PhD/phase_1_data/1_data_splitting/test_set_output"
    OUTPUT_YOLO_TEST_DIR_MULTI_CLASS = "/blue/hulcr/gmarais/PhD/phase_1_data/3_classification_phase_2/ultralytics/test"

    # None lets the executor pick the worker count; set an int (e.g. 4) to cap it.
    MAX_WORKERS_MULTI = None

    print("=" * 50)
    print("Starting YOLO dataset creation for TEST SET (MULTI-CLASS from original labels)")

    # Run only when neither path is still the template placeholder.
    if (
        SOURCE_TEST_DATA_DIR_MULTI_CLASS != "/path/to/your/source_test_data_folder_multi"
        and OUTPUT_YOLO_TEST_DIR_MULTI_CLASS != "/path/to/your_output_yolo_test_dir_multi"
    ):
        create_yolo_test_dataset_multiclass(
            source_test_data_dir=SOURCE_TEST_DATA_DIR_MULTI_CLASS,
            output_yolo_test_dir=OUTPUT_YOLO_TEST_DIR_MULTI_CLASS,
            max_workers=MAX_WORKERS_MULTI,
        )
        print("\nYOLO Multi-Class Test Set script execution finished.")
    else:
        print("\nPLEASE UPDATE 'SOURCE_TEST_DATA_DIR_MULTI_CLASS' and 'OUTPUT_YOLO_TEST_DIR_MULTI_CLASS' before running!")

Starting YOLO dataset creation for TEST SET (MULTI-CLASS from original labels)

Processing Test Dataset for YOLO (Multi-Class from original labels)...
Discovering class labels from the test set (parallelized)...
Discovered 63 unique classes in 0.85 seconds:
  'Ambrosiodmus_minor': 0
  'Ambrosiophilus_atratus': 1
  'Anisandrus_dispar': 2
  'Anisandrus_sayi': 3
  'Cnestus_mutilatus': 4
  'Coccotrypes_carpophagus': 5
  'Coccotrypes_dactyliperda': 6
  'Coptoborus_ricini': 7
  'Cryptocarenus_heveae': 8
  'Ctonoxylon_hagedorn': 9
  'Cyclorhipidion_pelliculosum': 10
  'Dendroctonus_rufipennis': 11
  'Dendroctonus_terebrans': 12
  'Dendroctonus_valens': 13
  'Dryocoetes_autographus': 14
  'Euplatypus_compositus': 15
  'Euwallacea_fornicatus': 16
  'Euwallacea_perbrevis': 17
  'Euwallacea_validus': 18
  'Hylastes_porculus': 19
  'Hylastes_salebrosus': 20
  'Hylesinus_aculeatus': 21
  'Hylesinus_crenatus': 22
  'Hylesinus_toranio': 23
  'Hylesinus_varius': 24
  'Hylurgops_palliatus': 25
  'Hylur