In [1]:
import os
import json
import shutil
from collections import OrderedDict
# import yaml # Not strictly needed for COCO JSON output, but was in previous script
from PIL import Image
import concurrent.futures
import time
import datetime

# --- COCO Bounding Box Conversion ---
def convert_to_coco_bbox(points, image_height, image_width):
    """
    Convert corner points to a COCO bbox ``[x_min, y_min, width, height]``.

    The box is the axis-aligned hull of ``points`` (normally the two opposite
    corners of a LabelMe rectangle, given in any order), clamped to the image
    extent ``[0, image_width] x [0, image_height]``. COCO coordinates are
    continuous, so the far image edge is ``image_width`` / ``image_height``;
    clamping to ``size - 1`` would shave a pixel off every box that touches
    the right or bottom border.

    Args:
        points: sequence of ``[x, y]`` pairs in absolute pixel coordinates.
        image_height: image height in pixels.
        image_width: image width in pixels.

    Returns:
        A list of four floats, or ``None`` if ``points`` is empty or the
        clamped box has zero or negative width/height (e.g. fully outside
        the image).
    """
    if not points:
        return None

    x_coords = [float(p[0]) for p in points]
    y_coords = [float(p[1]) for p in points]

    x_min = max(0.0, min(x_coords))
    y_min = max(0.0, min(y_coords))
    x_max = min(float(image_width), max(x_coords))
    y_max = min(float(image_height), max(y_coords))

    # Degenerate or out-of-image boxes are not valid COCO annotations.
    if x_min >= x_max or y_min >= y_max:
        return None

    return [x_min, y_min, x_max - x_min, y_max - y_min]

# --- Fixed Category Definition for COCO ---
def get_fixed_coco_category_list(class_name="bark_beetle", class_id=1):
    """
    Build the COCO ``categories`` list for a dataset with exactly one class.

    Args:
        class_name: used for both the category ``name`` and its ``supercategory``.
        class_id: the COCO category id assigned to that class.

    Returns:
        A one-element list containing the category dict
        (keys: ``id``, ``name``, ``supercategory``).
    """
    category = dict(id=class_id, name=class_name, supercategory=class_name)
    print(f"Defining fixed COCO category: ID: {class_id}, Name: {class_name}")
    return [category]

# --- Worker for processing one image-JSON pair for COCO (Single Class) ---
def _process_single_image_to_coco_data_single_class(args_tuple):
    """
    Convert one image/JSON pair into COCO image info plus annotation data.

    Every ``rectangle`` shape with exactly two points becomes one annotation
    of ``fixed_class_name``; all other shapes are ignored. COCO image and
    annotation IDs are NOT assigned here — the caller does that so IDs stay
    sequential within a split.

    Args:
        args_tuple: ``(img_filename_no_ext, source_fold_path, new_img_filename,
            fixed_class_name)``, packed into one tuple so the function can be
            submitted directly to a process pool.

    Returns:
        ``{"image_info": {...}, "annotations_data": [...],
        "original_source_path": str}`` on success. ``image_info`` holds
        ``file_name``, ``height``, ``width`` and ``original_path``. Images with
        no valid rectangles are still returned, with an empty
        ``annotations_data`` list.

        ``None`` when:
          - the ``.png`` or ``.json`` file is missing (returned silently);
          - the image size cannot be determined;
          - reading or parsing fails (a warning is printed).
    """
    img_filename_no_ext, source_fold_path, new_img_filename, fixed_class_name = args_tuple

    source_png_path = os.path.join(source_fold_path, f"{img_filename_no_ext}.png")
    source_json_path = os.path.join(source_fold_path, f"{img_filename_no_ext}.json")

    # Unpaired files (e.g. an image without a label file) are skipped quietly.
    if not (os.path.exists(source_png_path) and os.path.exists(source_json_path)):
        return None

    try:
        # Read as UTF-8 explicitly instead of relying on the platform locale.
        with open(source_json_path, 'r', encoding='utf-8') as f:
            user_json_data = json.load(f)

        image_height = user_json_data.get("imageHeight")
        image_width = user_json_data.get("imageWidth")

        # Fall back to the actual image file for any dimension the JSON omits.
        if image_height is None or image_width is None:
            with Image.open(source_png_path) as img_pil:
                pil_width, pil_height = img_pil.size
            if image_width is None:
                image_width = pil_width
            if image_height is None:
                image_height = pil_height

        # Normalise once so bbox clamping and the COCO record use the same ints.
        image_height = int(image_height)
        image_width = int(image_width)
        if image_height <= 0 or image_width <= 0:
            return None

        image_info = {
            "file_name": new_img_filename,
            "height": image_height,
            "width": image_width,
            "original_path": source_png_path
        }

        annotations_data = []
        for shape in user_json_data.get("shapes", []):
            points = shape.get("points")
            # Only two-point rectangles map onto axis-aligned COCO boxes.
            if not points or shape.get("shape_type") != "rectangle" or len(points) != 2:
                continue

            coco_bbox = convert_to_coco_bbox(points, image_height, image_width)
            if coco_bbox:
                annotations_data.append({
                    "category_name": fixed_class_name,  # every object maps to the single class
                    "bbox": coco_bbox,
                    "area": coco_bbox[2] * coco_bbox[3]
                })

        return {"image_info": image_info, "annotations_data": annotations_data, "original_source_path": source_png_path}

    except Exception as e:
        # Worker boundary: skip this image but make the failure visible rather than
        # silently dropping it from the dataset.
        print(f"Warning (COCO Worker - Single Class): Error processing {source_png_path} or {source_json_path}: {e}")
        return None


# --- Main Orchestrator Function for Jupyter Notebook (Single Class COCO) ---
def _collect_coco_split_tasks(source_dir, source_fold_list, split_type, fixed_class_name):
    """
    Return worker argument tuples for every ``*.png`` in the given source folds.

    Folds are visited in list order and files in sorted order, so the task
    list (and the COCO image IDs assigned from it) is deterministic. A missing
    fold directory is reported and skipped. It still consumes its ``fold_idx``,
    so the filename prefixes of the remaining folds are unchanged.
    """
    tasks = []
    for fold_idx, fold_name in enumerate(source_fold_list):
        source_fold_path = os.path.join(source_dir, fold_name)
        if not os.path.isdir(source_fold_path):
            print(f"    Warning: Source fold '{source_fold_path}' for {split_type} not found. Skipping.")
            continue

        image_basenames = sorted(
            os.path.splitext(f)[0] for f in os.listdir(source_fold_path) if f.lower().endswith(".png")
        )
        for img_basename in image_basenames:
            new_img_filename = f"{split_type}_fold{fold_idx}_{img_basename}.png"
            tasks.append((img_basename, source_fold_path, new_img_filename, fixed_class_name))
    return tasks


def _run_coco_workers(tasks, max_workers):
    """Run the per-image worker over ``tasks`` in a process pool; return successful results in task order."""
    results = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_process_single_image_to_coco_data_single_class, task) for task in tasks]
        # Collect in submission order (not as_completed) so that image and
        # annotation IDs are reproducible from run to run.
        for task, future in zip(tasks, futures):
            try:
                result = future.result()
            except Exception as e:
                # The worker handles its own errors, so this is a pool-level failure
                # (e.g. a crashed/killed worker process). Report it instead of hiding it.
                print(f"    Warning: Worker failed for '{task[0]}' in '{task[1]}': {e}")
                continue
            if result:
                results.append(result)
    return results


def create_kfold_coco_datasets_single_class(source_dir, base_dest_dir, source_fold_names_str,
                                            fixed_class_name="bark_beetle", fixed_class_id=1,
                                            max_workers=None):
    """
    Create k-fold cross-validation datasets in COCO format, mapping every
    object to a single class.

    Each comma-separated name in ``source_fold_names_str`` is a sub-directory
    of ``source_dir`` containing ``<name>.png`` / ``<name>.json`` pairs. In CV
    iteration ``i``, fold ``i`` is the validation split and every other fold
    is the training split. The output layout per iteration is::

        base_dest_dir/cv_iteration_<i+1>/
            train/                             copied training images
            val/                               copied validation images
            annotations/instances_train.json
            annotations/instances_val.json

    Copied images are renamed ``<split>_fold<j>_<basename>.png``. Here ``j``
    is the fold's position within that split's fold list, not its position
    in ``source_fold_names_str``.

    Args:
        source_dir: directory containing the fold sub-directories.
        base_dest_dir: root directory for all CV iteration outputs (created if missing).
        source_fold_names_str: comma-separated fold names; at least two are required.
        fixed_class_name: name of the single COCO category.
        fixed_class_id: id of the single COCO category.
        max_workers: process-pool size; ``None`` uses the executor's default.

    Returns:
        None. Progress, warnings and errors are printed.
    """
    overall_start_time = time.time()
    all_original_fold_names = [name.strip() for name in source_fold_names_str.split(',') if name.strip()]

    if len(all_original_fold_names) < 2:
        print("Error: Please provide at least two source fold names for cross-validation.")
        return
    if not os.path.isdir(source_dir):
        print(f"Error: Source directory '{source_dir}' not found.")
        return
    os.makedirs(base_dest_dir, exist_ok=True)

    # A single category; the map lets annotations resolve their category_id by name.
    coco_categories = get_fixed_coco_category_list(class_name=fixed_class_name, class_id=fixed_class_id)
    category_name_to_id = {cat['name']: cat['id'] for cat in coco_categories}

    num_cv_folds = len(all_original_fold_names)
    print(f"\nPreparing data for {num_cv_folds}-Fold Cross-Validation (COCO Format, Single Class: '{fixed_class_name}')...")

    for i, current_val_fold_name in enumerate(all_original_fold_names):
        cv_iteration_start_time = time.time()
        current_train_fold_names = all_original_fold_names[:i] + all_original_fold_names[i + 1:]

        current_cv_split_output_root = os.path.join(base_dest_dir, f"cv_iteration_{i+1}")
        print(f"\n--- Processing CV Iteration {i+1}/{num_cv_folds} ---")
        print(f"  Output to: {current_cv_split_output_root}")

        train_img_dir = os.path.join(current_cv_split_output_root, "train")
        val_img_dir = os.path.join(current_cv_split_output_root, "val")
        annotations_dir = os.path.join(current_cv_split_output_root, "annotations")
        for directory in (train_img_dir, val_img_dir, annotations_dir):
            os.makedirs(directory, exist_ok=True)

        for split_type, source_fold_list in (("train", current_train_fold_names), ("val", [current_val_fold_name])):
            print(f"  Processing {split_type} data for CV Iteration {i+1}...")
            split_start_time = time.time()
            target_image_dir_for_split = train_img_dir if split_type == "train" else val_img_dir

            coco_output_data = {
                "info": {
                    "description": f"COCO-style dataset for CV Iteration {i+1} - {split_type} (Single Class: {fixed_class_name})",
                    "version": "1.0", "year": datetime.date.today().year,
                    # datetime.utcnow() is deprecated (3.12+); drop tzinfo to keep the same naive format.
                    "date_created": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat(' ')
                },
                "licenses": [{"name": "Placeholder License", "id": 0, "url": ""}],
                "categories": coco_categories,
                "images": [],
                "annotations": []
            }

            tasks_for_split = _collect_coco_split_tasks(source_dir, source_fold_list, split_type, fixed_class_name)
            if not tasks_for_split:
                print(f"    No images found to process for {split_type} set in this CV iteration.")
                processed_results = []
            else:
                processed_results = _run_coco_workers(tasks_for_split, max_workers)
                num_skipped = len(tasks_for_split) - len(processed_results)
                if num_skipped:
                    print(f"    Note: Skipped {num_skipped} of {len(tasks_for_split)} images "
                          f"(missing PNG/JSON pair, unknown image size, or processing error).")

            # IDs are 1-based and restart per split because each split gets its own JSON file.
            current_annotation_id = 1
            for current_image_id, result_data in enumerate(processed_results, start=1):
                img_entry = result_data["image_info"]
                shutil.copy2(result_data["original_source_path"],
                             os.path.join(target_image_dir_for_split, img_entry["file_name"]))
                img_entry["id"] = current_image_id
                img_entry.pop("original_path", None)  # internal bookkeeping, not a COCO field
                coco_output_data["images"].append(img_entry)

                for ann_data in result_data["annotations_data"]:
                    category_id = category_name_to_id.get(ann_data["category_name"])
                    if category_id is None:
                        print(f"    Logic Error: Fixed class name '{ann_data['category_name']}' not in map. This should not happen.")
                        continue

                    coco_output_data["annotations"].append({
                        "id": current_annotation_id,
                        "image_id": current_image_id,
                        "category_id": category_id,
                        "bbox": ann_data["bbox"],
                        "area": ann_data["area"],
                        "iscrowd": 0,
                        "segmentation": []
                    })
                    current_annotation_id += 1

            output_json_filename = f"instances_{split_type}.json"
            output_json_path = os.path.join(annotations_dir, output_json_filename)
            try:
                with open(output_json_path, 'w', encoding='utf-8') as f:
                    json.dump(coco_output_data, f, indent=4)
                split_processing_time = time.time() - split_start_time
                print(f"    Successfully created '{output_json_filename}' with {len(coco_output_data['images'])} images and {len(coco_output_data['annotations'])} annotations in {split_processing_time:.2f}s.")
            except Exception as e:
                print(f"    Error writing COCO JSON '{output_json_filename}': {e}")

        cv_iteration_time = time.time() - cv_iteration_start_time
        print(f"  CV Iteration {i+1} processing took {cv_iteration_time:.2f}s.")

    total_script_time = time.time() - overall_start_time
    print(f"\n{num_cv_folds}-Fold Cross-Validation COCO dataset (Single Class: '{fixed_class_name}') preparation complete in {total_script_time:.2f} seconds!")
    print(f"All CV iteration datasets are ready under: {base_dest_dir}")

In [2]:
# --- Configuration for Your Dataset (COCO Format - Single Class) ---
import os  # used below for os.cpu_count() in the header printout

# **IMPORTANT**: Modify these paths and names to match your actual dataset and desired output.
SOURCE_DIR = "/blue/hulcr/gmarais/PhD/phase_1_data/1_data_splitting"
BASE_DEST_DIR = "/blue/hulcr/gmarais/PhD/phase_1_data/2_object_detection_phase_2/coco/test"
# Comma-separated sub-directories of SOURCE_DIR, one per CV fold.
# NOTE: k-fold CV needs at least TWO folds. With a single name (as below),
# create_kfold_coco_datasets_single_class prints an error and writes nothing.
SOURCE_FOLD_NAMES_STR = "test_set_output"

# Define the single class name and its ID (COCO typically starts IDs at 1)
FIXED_CLASS_NAME = "bark_beetle"
FIXED_CLASS_ID = 1  # Make sure this ID is used consistently if your model expects a specific ID

# Optional: Control the number of parallel processes
MAX_WORKERS = None  # None -> ProcessPoolExecutor's default (typically os.cpu_count())
# MAX_WORKERS = 4  # Or set a specific number

# Parse fold names the same way create_kfold_coco_datasets_single_class does
# (strip whitespace, drop empty entries) so the header reports what will be used.
PARSED_FOLD_NAMES = [name.strip() for name in SOURCE_FOLD_NAMES_STR.split(',') if name.strip()]

# --- Run the Dataset Creation ---
print(f"Starting COCO dataset creation for Co-DETR with K-Fold CV (Single Class: '{FIXED_CLASS_NAME}'):")
print(f"  Source Directory: {SOURCE_DIR}")
print(f"  Base Destination Directory: {BASE_DEST_DIR}")
print(f"  Source Fold Names: {PARSED_FOLD_NAMES}")
print(f"  Max Workers for Parallelization: {MAX_WORKERS if MAX_WORKERS is not None else f'Default (likely {os.cpu_count()})'}")
print("-" * 30)

# Ensure the function create_kfold_coco_datasets_single_class is defined by running the cell above first!
create_kfold_coco_datasets_single_class(
    source_dir=SOURCE_DIR,
    base_dest_dir=BASE_DEST_DIR,
    source_fold_names_str=SOURCE_FOLD_NAMES_STR,
    fixed_class_name=FIXED_CLASS_NAME,
    fixed_class_id=FIXED_CLASS_ID,
    max_workers=MAX_WORKERS
)

print("\nCOCO Single-Class Script execution finished in this cell.")

Starting COCO dataset creation for Co-DETR with K-Fold CV (Single Class: 'bark_beetle'):
  Source Directory: /blue/hulcr/gmarais/PhD/phase_1_data/1_data_splitting/test_set_output
  Base Destination Directory: /blue/hulcr/gmarais/PhD/phase_1_data/2_object_detection_phase_2/coco/test
  Source Fold Names: ['/']
  Max Workers for Parallelization: Default (likely 128)
------------------------------
Error: Please provide at least two source fold names for cross-validation.

COCO Single-Class Script execution finished in this cell.
