In [1]:
import os
import json
import shutil
from collections import OrderedDict
import yaml # For data.yaml
from PIL import Image # To verify/get image dimensions if needed
import concurrent.futures
import time # For timing

# --- Utility Functions (largely unchanged but crucial) ---
def convert_yolo_format(image_width, image_height, points, class_id):
    """
    Converts a single bounding box to YOLO format.

    Args:
        image_width: Image width in pixels.
        image_height: Image height in pixels.
        points: Corner points of the box in absolute pixel coordinates,
            normally two opposite corners [[x1, y1], [x2, y2]]. The corners
            may be listed in any order.
        class_id: Class index written as the first field of the result.

    Returns:
        A string "class_id x_center y_center width height" where the four box
        values are divided by the image size and formatted with six decimals,
        or None when the box has no area left after clipping to the image.
    """
    x_coords = [p[0] for p in points]
    y_coords = [p[1] for p in points]

    # Clip to the image extent [0, width] x [0, height]. The upper bound is the
    # image size itself (not size - 1): normalized coordinates run up to 1.0,
    # so a box touching the right/bottom edge must keep its full extent.
    x_min_abs = max(0, min(x_coords))
    y_min_abs = max(0, min(y_coords))
    x_max_abs = min(image_width, max(x_coords))
    y_max_abs = min(image_height, max(y_coords))

    # Degenerate or fully out-of-image boxes carry no usable annotation.
    # This also covers non-positive image sizes, so no division by zero below.
    if x_min_abs >= x_max_abs or y_min_abs >= y_max_abs:
        return None

    box_width_abs = x_max_abs - x_min_abs
    box_height_abs = y_max_abs - y_min_abs

    x_center_norm = (x_min_abs + box_width_abs / 2.0) / image_width
    y_center_norm = (y_min_abs + box_height_abs / 2.0) / image_height
    width_norm = box_width_abs / image_width
    height_norm = box_height_abs / image_height

    return f"{class_id} {x_center_norm:.6f} {y_center_norm:.6f} {width_norm:.6f} {height_norm:.6f}"

# --- Parallelized Class Discovery ---
def _get_labels_from_single_json(json_path):
    """
    Collect the distinct, non-empty shape labels stored in one JSON file.

    Any error raised while opening, parsing or walking the file is reported
    as a warning on stdout instead of being raised; labels gathered before
    the error occurred are still returned.
    """
    found = set()
    try:
        with open(json_path, 'r') as handle:
            document = json.load(handle)
        shapes = document.get("shapes", [])
        for entry in shapes:
            name = entry.get("label")
            if not name:
                # Missing or empty labels are not classes.
                continue
            found.add(name)
    except Exception as e:
        print(f"Warning (Class Discovery Worker): Error reading {json_path}: {e}")
    return found

def discover_all_classes_parallel(source_root_dir, all_fold_names, max_workers=None):
    """
    Build the global label -> class id mapping from every JSON file in the given folds.

    Each fold is a sub-directory of source_root_dir; folds that do not exist
    are skipped with a warning. The JSON files are read by a process pool of
    max_workers workers. Class ids are assigned in sorted label order.

    Returns an OrderedDict mapping label to id; it is empty when no JSON file
    or no label was found.
    """
    print("Discovering all class labels across all specified folds (parallelized)...")
    started_at = time.time()

    # Gather every *.json directly inside each existing fold directory.
    annotation_files = []
    for fold in all_fold_names:
        fold_dir = os.path.join(source_root_dir, fold)
        if not os.path.isdir(fold_dir):
            print(f"Warning (Class Discovery): Source fold '{fold_dir}' not found. Skipping.")
            continue
        annotation_files.extend(
            os.path.join(fold_dir, entry)
            for entry in os.listdir(fold_dir)
            if entry.lower().endswith(".json")
        )

    if not annotation_files:
        print("Warning: No JSON files found for class discovery.")
        return OrderedDict()

    collected = set()
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as pool:
        pending = {}
        for json_path in annotation_files:
            pending[pool.submit(_get_labels_from_single_json, json_path)] = json_path
        for finished in concurrent.futures.as_completed(pending):
            try:
                collected.update(finished.result())
            except Exception as exc:
                print(f"Warning (Class Discovery Main): Generated an exception for {pending[finished]}: {exc}")

    # Sorted order keeps the id assignment independent of completion order.
    class_ids = OrderedDict()
    for label in sorted(collected):
        class_ids[label] = len(class_ids)

    elapsed = time.time() - started_at
    if not class_ids:
        print(f"Warning: No classes discovered after parallel processing in {elapsed:.2f} seconds.")
        return class_ids

    print(f"Discovered {len(class_ids)} unique classes in {elapsed:.2f} seconds:")
    for name, idx in class_ids.items():
        print(f"  '{name}': {idx}")
    return class_ids

# --- Parallelized File Processing for a Single Fold ---
def _process_single_image_json_pair(args_tuple):
    """
    Convert one PNG/JSON pair into a copied image plus a YOLO label file.

    args_tuple is (img_filename_no_ext, source_fold_path, master_class_to_id_map,
    output_images_dir, output_labels_dir), packed into one argument so the
    function can be submitted to an executor as a single-argument task.

    Returns True only when a label file with at least one annotation was
    written. Returns False when the pair is incomplete, when the copy or the
    JSON load fails, when the image size cannot be determined (the copied
    image is removed in those cases), or when no shape yields a valid
    annotation (the copied image is then left in place).
    """
    (img_filename_no_ext, source_fold_path, master_class_to_id_map,
     output_images_dir, output_labels_dir) = args_tuple

    image_name = f"{img_filename_no_ext}.png"
    src_image = os.path.join(source_fold_path, image_name)
    src_json = os.path.join(source_fold_path, f"{img_filename_no_ext}.json")

    # Incomplete pairs are skipped quietly to keep the output readable.
    if not os.path.exists(src_image) or not os.path.exists(src_json):
        return False

    copied_image = os.path.join(output_images_dir, image_name)

    def discard_copy():
        # Undo the image copy and report failure.
        if os.path.exists(copied_image):
            os.remove(copied_image)
        return False

    try:
        shutil.copy2(src_image, copied_image)
        with open(src_json, 'r') as handle:
            annotation = json.load(handle)
    except Exception:
        return discard_copy()

    height = annotation.get("imageHeight")
    width = annotation.get("imageWidth")

    # Fall back to the real image size for whichever dimension the JSON lacks.
    if height is None or width is None:
        try:
            with Image.open(src_image) as picture:
                actual_width, actual_height = picture.size
        except Exception:
            return discard_copy()
        height = actual_height if height is None else height
        width = actual_width if width is None else width

    if not (height and width):  # Handles 0 or None
        return discard_copy()

    lines = []
    for shape in annotation.get("shapes", []):
        label = shape.get("label")
        corners = shape.get("points")
        kind = shape.get("shape_type")

        # Only labelled two-point rectangles can be converted.
        usable = label and corners and kind == "rectangle" and len(corners) == 2
        if not usable:
            continue
        # Labels unknown to the master map are skipped (rare if discovery ran on the same data).
        if label not in master_class_to_id_map:
            continue

        yolo_line = convert_yolo_format(width, height, corners, master_class_to_id_map[label])
        if yolo_line:
            lines.append(yolo_line)

    if not lines:
        # The image stays copied but, without labels, is not counted as processed.
        return False

    label_path = os.path.join(output_labels_dir, f"{img_filename_no_ext}.txt")
    with open(label_path, 'w') as label_file:
        label_file.write("\n".join(lines) + "\n")
    return True


def process_single_fold_for_yolo_parallel(source_fold_path, image_basenames_in_fold,
                                          master_class_to_id_map,
                                          output_images_dir, output_labels_dir, max_workers=None):
    """
    Processes files from a single source fold in parallel.

    Args:
        source_fold_path: Directory holding the fold's PNG/JSON pairs.
        image_basenames_in_fold: Iterable of file names without extension.
        master_class_to_id_map: Mapping from label name to class id.
        output_images_dir: Directory receiving the copied images.
        output_labels_dir: Directory receiving the YOLO label files.
        max_workers: Size of the process pool (None lets the executor choose).

    Returns:
        The number of pairs for which the worker reported success.
    """
    tasks = [
        (img_basename, source_fold_path, master_class_to_id_map,
         output_images_dir, output_labels_dir)
        for img_basename in image_basenames_in_fold
    ]

    # Avoid spinning up a process pool when there is nothing to do.
    if not tasks:
        return 0

    processed_count = 0
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_to_basename = {
            executor.submit(_process_single_image_json_pair, task_args): task_args[0]
            for task_args in tasks
        }
        for future in concurrent.futures.as_completed(future_to_basename):
            try:
                succeeded = future.result()
            except Exception as exc:
                # A crashed task must not abort the fold, but it should be visible:
                # report it the same way class discovery reports worker failures.
                print(f"Warning (Fold Processor): Task for {future_to_basename[future]} generated an exception: {exc}")
                continue
            if succeeded:
                processed_count += 1
    return processed_count

# --- Main Orchestrator Function for Jupyter Notebook ---
def _png_basenames(fold_path):
    """Return the sorted, de-duplicated extension-less names of the PNG files directly inside fold_path."""
    return sorted({
        os.path.splitext(entry)[0]
        for entry in os.listdir(fold_path)
        if entry.lower().endswith(".png")
    })


def create_kfold_yolo_datasets(source_dir, base_dest_dir, source_fold_names_str, max_workers=None):
    """
    Main function to create k-fold cross-validation datasets in YOLO format.
    Designed to be called from a Jupyter Notebook.

    For k source folds, k datasets are written to
    base_dest_dir/cv_iteration_<n>/, each with images/train, images/val,
    labels/train, labels/val and a data.yaml. In iteration n the n-th fold is
    the validation set and all other folds form the training set.

    Args:
        source_dir: Directory containing one sub-directory per source fold.
        base_dest_dir: Root directory for the generated datasets (created if missing).
        source_fold_names_str: Comma-separated fold directory names. Blank
            entries and repeated names are ignored.
        max_workers: Process-pool size forwarded to the parallel helpers.

    Returns:
        None. Progress, warnings and errors are printed; on invalid input the
        function returns early without writing anything.
    """
    overall_start_time = time.time()

    requested_fold_names = [name.strip() for name in source_fold_names_str.split(',') if name.strip()]

    # A fold listed twice would be used for validation while also staying in
    # the training split of that same iteration, so keep first occurrences only.
    all_original_fold_names = list(OrderedDict.fromkeys(requested_fold_names))
    if len(all_original_fold_names) != len(requested_fold_names):
        print(f"Warning: Duplicate source fold names were ignored. Using: {', '.join(all_original_fold_names)}")

    if len(all_original_fold_names) < 2:
        print("Error: Please provide at least two source fold names for cross-validation.")
        return

    if not os.path.isdir(source_dir):
        print(f"Error: Source directory '{source_dir}' not found.")
        return

    os.makedirs(base_dest_dir, exist_ok=True)

    # 1. Discover all classes globally first (parallelized), so every CV
    #    iteration shares the same class ids.
    master_class_to_id_map = discover_all_classes_parallel(source_dir, all_original_fold_names, max_workers)
    if not master_class_to_id_map:
        print("Error: No classes were found in any of the specified folds. Cannot proceed.")
        return

    num_cv_folds = len(all_original_fold_names)
    print(f"\nPreparing data for {num_cv_folds}-Fold Cross-Validation...")

    # 2. Loop for each CV iteration
    for i, current_val_fold_name in enumerate(all_original_fold_names):
        cv_iteration_start_time = time.time()
        current_train_fold_names = [f_name for idx, f_name in enumerate(all_original_fold_names) if idx != i]

        cv_iteration_dir_name = f"cv_iteration_{i+1}"
        current_cv_split_output_root = os.path.join(base_dest_dir, cv_iteration_dir_name)

        print(f"\n--- Processing CV Iteration {i+1}/{num_cv_folds} ---")
        print(f"  Validation Fold: {current_val_fold_name}")
        print(f"  Training Folds: {', '.join(current_train_fold_names)}")
        print(f"  Output to: {current_cv_split_output_root}")

        train_images_dir = os.path.join(current_cv_split_output_root, "images", "train")
        train_labels_dir = os.path.join(current_cv_split_output_root, "labels", "train")
        val_images_dir = os.path.join(current_cv_split_output_root, "images", "val")
        val_labels_dir = os.path.join(current_cv_split_output_root, "labels", "val")

        for out_dir in (train_images_dir, train_labels_dir, val_images_dir, val_labels_dir):
            os.makedirs(out_dir, exist_ok=True)

        # --- Process Training Folds for this CV Iteration (Parallel within each source fold) ---
        print(f"  Processing training data for CV Iteration {i+1}...")
        total_train_images_for_cv_iter = 0
        for train_fold_name in current_train_fold_names:
            fold_proc_start_time = time.time()
            source_fold_path = os.path.join(source_dir, train_fold_name)
            if not os.path.isdir(source_fold_path):
                print(f"  Warning: Training source fold '{source_fold_path}' not found. Skipping.")
                continue

            image_basenames = _png_basenames(source_fold_path)
            if not image_basenames:
                print(f"  No PNG images found in training source fold '{source_fold_path}'.")
                continue

            count = process_single_fold_for_yolo_parallel(
                source_fold_path, image_basenames, master_class_to_id_map,
                train_images_dir, train_labels_dir, max_workers
            )
            total_train_images_for_cv_iter += count
            fold_proc_time = time.time() - fold_proc_start_time
            print(f"    Processed {count} images from source train fold '{train_fold_name}' in {fold_proc_time:.2f}s.")
        print(f"  Total training images for CV Iteration {i+1}: {total_train_images_for_cv_iter}")

        # --- Process Validation Fold for this CV Iteration (Parallel within the source fold) ---
        print(f"  Processing validation data for CV Iteration {i+1}...")
        total_val_images_for_cv_iter = 0
        source_val_fold_path = os.path.join(source_dir, current_val_fold_name)
        if not os.path.isdir(source_val_fold_path):
            print(f"  Warning: Validation source fold '{source_val_fold_path}' not found. Skipping val set for this iter.")
        else:
            fold_proc_start_time = time.time()
            image_basenames_val = _png_basenames(source_val_fold_path)
            if not image_basenames_val:
                print(f"  No PNG images found in validation source fold '{source_val_fold_path}'.")
            else:
                count_val = process_single_fold_for_yolo_parallel(
                    source_val_fold_path, image_basenames_val, master_class_to_id_map,
                    val_images_dir, val_labels_dir, max_workers
                )
                total_val_images_for_cv_iter = count_val
                fold_proc_time = time.time() - fold_proc_start_time
                print(f"    Processed {count_val} images from source validation fold '{current_val_fold_name}' in {fold_proc_time:.2f}s.")
        print(f"  Total validation images for CV Iteration {i+1}: {total_val_images_for_cv_iter}")

        # --- Create data.yaml for this CV Iteration ---
        # 'train' and 'val' are relative to 'path'; 'names' is ordered by class id.
        data_yaml_content = {
            'path': os.path.abspath(current_cv_split_output_root),
            'train': os.path.join('images', 'train'),
            'val': os.path.join('images', 'val'),
            'nc': len(master_class_to_id_map),
            'names': list(master_class_to_id_map.keys())
        }
        data_yaml_path = os.path.join(current_cv_split_output_root, "data.yaml")
        try:
            with open(data_yaml_path, 'w') as f:
                yaml.dump(data_yaml_content, f, sort_keys=False, default_flow_style=False)
            cv_iteration_time = time.time() - cv_iteration_start_time
            print(f"  Successfully created 'data.yaml' for CV Iteration {i+1}. Iteration took {cv_iteration_time:.2f}s.")
        except Exception as e:
            # A failed data.yaml should not stop the remaining iterations.
            print(f"  Error writing data.yaml for CV Iteration {i+1}: {e}")

    total_script_time = time.time() - overall_start_time
    print(f"\n{num_cv_folds}-Fold Cross-Validation dataset preparation complete in {total_script_time:.2f} seconds!")
    print(f"All CV iteration datasets are ready under: {base_dest_dir}")

# To run this in a Jupyter Notebook cell:

# 1. Make sure you have PyYAML and Pillow installed: pip install PyYAML Pillow
# 2. Define your configuration variables in the cell:
#
# SOURCE_DIR = "/path/to/your/original_dataset_with_folds"  # e.g., "/mnt/my_original_data"
# BASE_DEST_DIR = "/path/to/your_yolo_kfold_datasets_output" # e.g., "/mnt/yolo_5fold_datasets_parallel"
# SOURCE_FOLD_NAMES_STR = "original_fold_A,original_fold_B,original_fold_C,original_fold_D,original_fold_E" # Comma-separated
#
# # Optional: Set max_workers for parallel processing. None uses os.cpu_count().
# # For I/O bound tasks, you can sometimes benefit from more workers than CPUs.
# # For CPU bound inside processes, os.cpu_count() is a good default.
# MAX_WORKERS = None # or os.cpu_count() or a specific number like 4 or 8
#
# # 3. Call the main function:
# if __name__ == '__main__': # This check is good practice, though not strictly needed if only run in a cell
#    # Example of how to run it directly for testing (if not in Jupyter)
#    # For Jupyter, you'd just define the variables above and then call:
#    # create_kfold_yolo_datasets(SOURCE_DIR, BASE_DEST_DIR, SOURCE_FOLD_NAMES_STR, MAX_WORKERS)
#    
#    # --- EXAMPLE USAGE FOR JUPYTER (after defining variables above) ---
#    # print(f"Using source: {SOURCE_DIR}")
#    # print(f"Using destination: {BASE_DEST_DIR}")
#    # print(f"Using folds: {SOURCE_FOLD_NAMES_STR}")
#    # print(f"Max workers: {MAX_WORKERS if MAX_WORKERS is not None else 'Default (os.cpu_count())'}")
#    #
#    # create_kfold_yolo_datasets(SOURCE_DIR, BASE_DEST_DIR, SOURCE_FOLD_NAMES_STR, max_workers=MAX_WORKERS)
#    pass # Pass here, as the actual call would be made directly in the notebook cell after defining vars

In [2]:
# --- Configuration for Your Dataset ---
# SOURCE_DIR holds one sub-directory per fold; BASE_DEST_DIR receives one
# cv_iteration_<n> dataset per fold listed in SOURCE_FOLD_NAMES_STR.
SOURCE_DIR = "/blue/hulcr/gmarais/PhD/phase_1_data/1_data_splitting/classification_folds_output"  # CHANGE THIS
BASE_DEST_DIR = "/blue/hulcr/gmarais/PhD/phase_1_data/3_classification_phase_2/ultralytics" # CHANGE THIS
SOURCE_FOLD_NAMES_STR = "fold1,fold2,fold3,fold4,fold5" # CHANGE THIS

# Optional: Control the number of parallel processes.
# None lets the executor pick its default; an explicit number caps the pool.
# You might want to experiment with this value.
MAX_WORKERS = 6 # e.g. os.cpu_count(), or a fixed number such as 4
# import os # if you want to use os.cpu_count() explicitly
# MAX_WORKERS = os.cpu_count()


# --- Run the Dataset Creation ---
# The guard keeps worker processes that re-import this code (spawn start
# method, or the notebook exported as a script) from starting the run again.
# Inside a notebook cell __name__ is "__main__", so the run still happens.
if __name__ == "__main__":
    print("Starting dataset creation with the following parameters:")
    print(f"  Source Directory: {SOURCE_DIR}")
    print(f"  Base Destination Directory: {BASE_DEST_DIR}")
    print(f"  Source Fold Names: {SOURCE_FOLD_NAMES_STR}")
    print(f"  Max Workers for Parallelization: {MAX_WORKERS if MAX_WORKERS is not None else 'Default (os.cpu_count())'}")
    print("-" * 30)

    create_kfold_yolo_datasets(SOURCE_DIR, BASE_DEST_DIR, SOURCE_FOLD_NAMES_STR, max_workers=MAX_WORKERS)

Starting dataset creation with the following parameters:
  Source Directory: /blue/hulcr/gmarais/PhD/phase_1_data/1_data_splitting/classification_folds_output
  Base Destination Directory: /blue/hulcr/gmarais/PhD/phase_1_data/3_classification_phase_2/ultralytics
  Source Fold Names: fold1,fold2,fold3,fold4,fold5
  Max Workers for Parallelization: 6
------------------------------
Discovering all class labels across all specified folds (parallelized)...
Discovered 63 unique classes in 1.62 seconds:
  'Ambrosiodmus_minor': 0
  'Ambrosiophilus_atratus': 1
  'Anisandrus_dispar': 2
  'Anisandrus_sayi': 3
  'Cnestus_mutilatus': 4
  'Coccotrypes_carpophagus': 5
  'Coccotrypes_dactyliperda': 6
  'Coptoborus_ricini': 7
  'Cryptocarenus_heveae': 8
  'Ctonoxylon_hagedorn': 9
  'Cyclorhipidion_pelliculosum': 10
  'Dendroctonus_rufipennis': 11
  'Dendroctonus_terebrans': 12
  'Dendroctonus_valens': 13
  'Dryocoetes_autographus': 14
  'Euplatypus_compositus': 15
  'Euwallacea_fornicatus': 16
  'Euwal