# Roboflow Dataset Augmentation and Synthetic Data Generation
1. **Traditional augmentations:** horizontal flip, vertical flip, rotation, tilt, brightness
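2. **Synthetic data generation:** 70% augmented-synthetic images (object-bank cut-outs pasted onto existing training images) + 30% purely synthetic images (object-bank cut-outs pasted onto a generated gradient background)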

In [None]:
import os
import json
import cv2
import numpy as np
import yaml
from PIL import Image, ImageDraw, ImageEnhance
import random
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from collections import defaultdict, Counter
from pathlib import Path
import albumentations as A

print("Libraries imported successfully!")

## 1. Configuration

In [None]:
from pathlib import Path

# --- Paths ---
# All dataset locations are derived from BASE_DIR; change it when running on another machine.
BASE_DIR = Path('/home/andrea/work/AI-waste-detection/')
K_FOLD_CV_DIR = BASE_DIR / 'datasets/k_fold_cv'  # Source k-fold CV dataset (fold_0, fold_1, ...)
K_FOLD_CV_AUGMENTED_DIR = BASE_DIR / 'datasets/k_fold_cv_augmented'  # Output augmented dataset
OBJECT_BANK_DIR = BASE_DIR / 'datasets/object_bank_for_balancing'  # Pre-existing object bank (one sub-directory of PNGs per class)

# --- K-Fold Configuration ---
NUM_FOLDS = 5  # Number of folds to process (fold_0 to fold_4)

# --- Generation Parameters ---
IMAGE_SIZE = (640, 640)  # (width, height) of generated synthetic images
OBJECTS_PER_IMAGE_RANGE = (1, 7)  # NOTE(review): not referenced in the visible part of this notebook
SCALE_RANGE = (0.2, 0.7)  # Object size bound as a fraction of the image size (pure synthetic images)
ROTATION_RANGE = (-20, 20)  # Random rotation, in degrees, applied to pasted objects
OVERLAP_THRESHOLD = 0.05  # Maximum IoU tolerated between a new object and already placed boxes

# --- Augmentation Parameters ---
# Each *_PROB is the chance that the corresponding augmented variant is written for a source image.
HORIZONTAL_FLIP_PROB = 0.5
VERTICAL_FLIP_PROB = 0.5
ROTATION_PROB = 0.25  # 25% chance of emitting a RandomRotate90 variant (0/90/180/270 degrees)
TILT_RANGE = (-15, 15)  # degrees; one angle is drawn per source image for the affine tilt
BRIGHTNESS_RANGE = (0.92, 1.08)  # ±8% brightness change (converted to a brightness_limit by subtracting 1)

# --- Synthetic Data Distribution ---
# The two ratios describe how synthetic images are split between the two generators.
AUGMENTED_SYNTHETIC_RATIO = 0.7  # 70% augmented images
PURE_SYNTHETIC_RATIO = 0.3       # 30% pure synthetic images

# Create the output root. parents=True mirrors the per-fold mkdir calls used later
# and avoids a FileNotFoundError when the intermediate 'datasets' directory is missing.
K_FOLD_CV_AUGMENTED_DIR.mkdir(parents=True, exist_ok=True)

print(f"Source K-Fold CV Directory: {K_FOLD_CV_DIR}")
print(f"Output Augmented Directory: {K_FOLD_CV_AUGMENTED_DIR}")
print(f"Object Bank Directory: {OBJECT_BANK_DIR}")
print(f"Number of folds to process: {NUM_FOLDS}")

# Report whether the source dataset, and each expected fold inside it, is present.
if not K_FOLD_CV_DIR.exists():
    print(f"❌ ERROR: Source k-fold CV dataset not found at {K_FOLD_CV_DIR}")
else:
    print("✅ Source k-fold CV dataset found")
    for fold_idx in range(NUM_FOLDS):
        fold_dir = K_FOLD_CV_DIR / f'fold_{fold_idx}'
        if fold_dir.exists():
            print(f"  ✅ Found fold_{fold_idx}")
        else:
            print(f"  ❌ Missing fold_{fold_idx}")

# The object bank is only required for the synthetic generation phase, hence a warning.
if not OBJECT_BANK_DIR.exists():
    print(f"❌ WARNING: Object bank not found at {OBJECT_BANK_DIR}")
    print("  You may need to run synthetic_data_generation.ipynb first to create the object bank")
else:
    print("✅ Object bank found")

## 2. Analyze Original Dataset

In [None]:
def load_dataset_config(dataset_dir):
    """Load the ``data.yaml`` configuration of a YOLO dataset directory.

    Args:
        dataset_dir: ``pathlib.Path`` of the dataset (fold) directory that is
            expected to contain ``data.yaml``.

    Returns:
        The parsed YAML content. An empty ``data.yaml`` yields an empty dict
        instead of ``None`` so callers can safely use ``.get(...)`` on it.

    Raises:
        FileNotFoundError: If ``data.yaml`` does not exist in ``dataset_dir``.
    """
    config_path = dataset_dir / 'data.yaml'
    if not config_path.exists():
        raise FileNotFoundError(f"data.yaml not found in {dataset_dir}")
    # Read as UTF-8 explicitly so class names do not depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty document.
    return config if config is not None else {}

def get_class_distribution(labels_dir, num_classes):
    """Count class instances across the YOLO label files of a directory.

    Args:
        labels_dir: ``pathlib.Path`` of a directory containing ``*.txt`` YOLO
            label files (one annotation per line, class id first).
        num_classes: Number of valid classes; only ids in
            ``[0, num_classes)`` are counted.

    Returns:
        A ``Counter`` mapping class id to its number of annotation lines.
        Empty when ``labels_dir`` is not a directory.
    """
    class_counts = Counter()
    if not labels_dir.is_dir():
        return class_counts
    for label_file in labels_dir.glob('*.txt'):
        with open(label_file, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # blank line
                # Convert to float first, then to int to handle cases like '4.0'
                class_id = int(float(fields[0]))
                # Reject negative ids as well as ids beyond the class list.
                if 0 <= class_id < num_classes:
                    class_counts[class_id] += 1
    return class_counts

def analyze_kfold_dataset_distribution():
    """Analyze the class distribution of all folds in the k-fold CV dataset.

    Class names are read from the ``data.yaml`` of the first fold directory
    that exists. For every existing fold the train and validation label
    directories are counted, printed, and plotted as bar charts.

    Returns:
        A tuple ``(target_classes, num_classes, fold_distributions)`` where
        ``target_classes`` is the list of class names, ``num_classes`` its
        length and ``fold_distributions`` maps a fold index to
        ``{'train': Counter, 'val': Counter}``. On any error the message is
        printed and ``([], 0, {})`` is returned instead of raising.
    """
    print("📊 Analyzing K-Fold CV Dataset...")

    fold_distributions = {}
    target_classes = None
    num_classes = 0

    try:
        # Get class info from first available fold
        for fold_idx in range(NUM_FOLDS):
            fold_dir = K_FOLD_CV_DIR / f'fold_{fold_idx}'
            if fold_dir.exists():
                # An empty data.yaml parses to None; treat it as "no names".
                dataset_config = load_dataset_config(fold_dir) or {}
                names = dataset_config.get('names') or []
                # data.yaml may list names as a mapping {class_id: name};
                # normalize to a list ordered by class id so that enumerate()
                # and .index() behave the same for both layouts.
                if isinstance(names, dict):
                    names = [names[key] for key in sorted(names)]
                target_classes = names
                num_classes = len(target_classes)
                print(f"Target classes from fold_{fold_idx}/data.yaml: {target_classes}")
                break

        if not target_classes:
            raise FileNotFoundError("No valid data.yaml found in any fold")

        # Analyze each fold
        for fold_idx in range(NUM_FOLDS):
            fold_dir = K_FOLD_CV_DIR / f'fold_{fold_idx}'
            if not fold_dir.exists():
                print(f"⚠️  Warning: fold_{fold_idx} not found, skipping...")
                continue

            print(f"\n📁 Analyzing fold_{fold_idx}...")

            train_distribution = get_class_distribution(fold_dir / 'train/labels', num_classes)
            val_distribution = get_class_distribution(fold_dir / 'val/labels', num_classes)

            fold_distributions[fold_idx] = {
                'train': train_distribution,
                'val': val_distribution
            }

            print("  Train Set Distribution:")
            for i, class_name in enumerate(target_classes):
                print(f"    - {class_name} (ID {i}): {train_distribution[i]} instances")

            print("  Validation Set Distribution:")
            for i, class_name in enumerate(target_classes):
                print(f"    - {class_name} (ID {i}): {val_distribution[i]} instances")

        # Plot distributions for all folds: row 0 = train, row 1 = val, one column per fold
        if fold_distributions:
            num_folds_found = len(fold_distributions)
            # squeeze=False keeps `axes` two-dimensional even for a single fold.
            _, axes = plt.subplots(2, num_folds_found, figsize=(5*num_folds_found, 10), squeeze=False)

            for col, (fold_idx, distributions) in enumerate(fold_distributions.items()):
                for row, (split_key, split_label) in enumerate((('train', 'Train'), ('val', 'Val'))):
                    split_df = pd.DataFrame.from_dict(distributions[split_key], orient='index').sort_index()
                    if split_df.empty:
                        continue  # nothing counted for this split; leave the axis blank
                    split_df.plot(kind='bar', legend=False, ax=axes[row, col],
                                  title=f'Fold {fold_idx} - {split_label} Set')
                    axes[row, col].set_xlabel('Class ID')
                    axes[row, col].set_ylabel('Instances')

            plt.tight_layout()
            plt.show()

        return target_classes, num_classes, fold_distributions

    except Exception as e:
        # Notebook-level boundary: report the problem and return empty results.
        print(f"❌ Error analyzing k-fold dataset: {e}")
        return [], 0, {}

# Analyze the k-fold dataset and publish the results as notebook-wide globals.
# NOTE: when the analysis fails the function returns ([], 0, {}), so NUM_CLASSES is 0
# and the integer division by NUM_CLASSES in create_augmented_kfold_dataset() would
# raise ZeroDivisionError.
TARGET_CLASSES, NUM_CLASSES, FOLD_DISTRIBUTIONS = analyze_kfold_dataset_distribution()

## 3. Traditional Augmentation Functions

In [None]:
def parse_yolo_label(label_line, img_width, img_height):
    """Turn one YOLO annotation line into a class id and a pixel-space box.

    Args:
        label_line: Text of the form ``"<class> <cx> <cy> <w> <h>"`` with the
            box given as normalized center and size.
        img_width: Image width in pixels used to scale x coordinates.
        img_height: Image height in pixels used to scale y coordinates.

    Returns:
        ``(class_id, [x1, y1, x2, y2])`` with integer corner coordinates
        (fractions are truncated by ``int``).
    """
    fields = label_line.strip().split()
    # The class token may be written as a float such as '4.0', hence float -> int.
    class_id = int(float(fields[0]))
    center_x, center_y, box_w, box_h = (float(value) for value in fields[1:5])

    half_w = box_w / 2
    half_h = box_h / 2
    corners = [
        int((center_x - half_w) * img_width),
        int((center_y - half_h) * img_height),
        int((center_x + half_w) * img_width),
        int((center_y + half_h) * img_height),
    ]
    return class_id, corners

def bbox_to_yolo(bbox, img_width, img_height):
    """Express a corner-format box as normalized YOLO center/size values.

    Args:
        bbox: Sequence ``(x1, y1, x2, y2)`` in the same units as the image size.
        img_width: Image width used to normalize x values.
        img_height: Image height used to normalize y values.

    Returns:
        Tuple ``(cx, cy, w, h)`` normalized by the image dimensions.
    """
    left, top, right, bottom = bbox
    return (
        (left + right) / 2 / img_width,
        (top + bottom) / 2 / img_height,
        (right - left) / img_width,
        (bottom - top) / img_height,
    )

def apply_traditional_augmentations(img_path, label_path, output_img_dir, output_label_dir, base_filename):
    """Write augmented variants of one image together with their YOLO labels.

    Up to five variants are produced, one per augmentation type: horizontal
    flip, vertical flip, 90-degree-step rotation, tilt and brightness change.
    The first three are emitted with probability ``HORIZONTAL_FLIP_PROB``,
    ``VERTICAL_FLIP_PROB`` and ``ROTATION_PROB`` respectively; tilt and
    brightness are always emitted.

    Args:
        img_path: Path of the source image.
        label_path: ``pathlib.Path`` of the matching YOLO label file; a missing
            file is treated as an image without annotations.
        output_img_dir: ``pathlib.Path`` of the directory receiving the
            augmented ``.jpg`` files.
        output_label_dir: ``pathlib.Path`` of the directory receiving the
            augmented ``.txt`` label files.
        base_filename: File name stem; ``_aug_<index>`` is appended per variant.

    Returns:
        List of ``(image_path, label_path)`` tuples, one per variant written.
        Empty when the source image cannot be read.
    """
    image = cv2.imread(str(img_path))
    if image is None:
        return []

    # Convert BGR to RGB for albumentations
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_height, img_width = image.shape[:2]

    # Load labels as normalized [x_min, y_min, x_max, y_max] boxes
    bboxes = []
    class_labels = []
    if label_path.exists():
        with open(label_path, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                class_id, (x1, y1, x2, y2) = parse_yolo_label(line, img_width, img_height)
                # Clamp to the image: a box that pokes outside [0, 1] would make
                # albumentations reject the whole sample.
                x_min, y_min, x_max, y_max = (
                    min(max(value, 0.0), 1.0)
                    for value in (x1 / img_width, y1 / img_height, x2 / img_width, y2 / img_height)
                )
                # Boxes that collapse to zero width/height are equally invalid; skip them
                # instead of letting them make every augmentation of this image fail.
                if x_max <= x_min or y_max <= y_min:
                    continue
                bboxes.append([x_min, y_min, x_max, y_max])
                class_labels.append(class_id)

    augmented_files = []

    # (augmentation, probability of emitting this variant)
    transform_configs = [
        # Horizontal flip
        (A.HorizontalFlip(p=1.0), HORIZONTAL_FLIP_PROB),
        # Vertical flip
        (A.VerticalFlip(p=1.0), VERTICAL_FLIP_PROB),
        # Random rotation (0, 90, 180, 270)
        (A.RandomRotate90(p=1.0), ROTATION_PROB),
        # Tilt (affine rotation); the angle is drawn once per source image
        (A.Affine(rotate=random.uniform(*TILT_RANGE), p=1.0), 1.0),
        # Brightness change; the multiplicative range is shifted to a +/- limit
        (A.RandomBrightnessContrast(
            brightness_limit=(BRIGHTNESS_RANGE[0]-1, BRIGHTNESS_RANGE[1]-1),
            contrast_limit=0,
            p=1.0
        ), 1.0)
    ]

    for aug_idx, (augmentation, prob) in enumerate(transform_configs):
        if random.random() >= prob:
            continue

        try:
            transform = A.Compose([augmentation],
                                  bbox_params=A.BboxParams(format='albumentations',
                                                           label_fields=['class_labels'],
                                                           min_visibility=0.1))

            # Empty box/label lists are valid input, so one call covers both cases.
            augmented = transform(image=image, bboxes=bboxes, class_labels=class_labels)
            aug_image = augmented['image']
            aug_bboxes = augmented['bboxes']
            aug_class_labels = augmented['class_labels']

            # Convert back to BGR for saving
            aug_image_bgr = cv2.cvtColor(aug_image, cv2.COLOR_RGB2BGR)

            aug_filename = f"{base_filename}_aug_{aug_idx}"
            aug_img_path = output_img_dir / f"{aug_filename}.jpg"
            aug_label_path = output_label_dir / f"{aug_filename}.txt"

            # imwrite reports failure through its return value, not an exception;
            # do not leave a label file behind without its image.
            if not cv2.imwrite(str(aug_img_path), aug_image_bgr):
                print(f"Failed to apply augmentation {aug_idx}: could not write {aug_img_path}")
                continue

            with open(aug_label_path, 'w') as f:
                for class_id, norm_bbox in zip(aug_class_labels, aug_bboxes):
                    # Boxes are normalized corners, so a unit-sized image converts
                    # them to YOLO center/size values directly.
                    cx, cy, w, h = bbox_to_yolo(norm_bbox[:4], 1.0, 1.0)
                    # int(): label fields may come back as floats, and a YOLO
                    # class index must be written as an integer.
                    f.write(f"{int(class_id)} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n")

            augmented_files.append((aug_img_path, aug_label_path))

        except Exception as e:
            # Best effort: one failing augmentation must not abort the others.
            print(f"Failed to apply augmentation {aug_idx}: {e}")
            continue

    return augmented_files

## 4. Synthetic Data Generation (Based on synthetic_data_generation.ipynb)

In [None]:
def load_object_bank():
    """Collect the object-bank PNG paths, grouped by class directory name.

    Every sub-directory of ``OBJECT_BANK_DIR`` is treated as one class and its
    ``*.png`` files as the objects of that class.

    Returns:
        ``defaultdict(list)`` mapping class name to a list of PNG paths. Empty
        when the bank directory does not exist.
    """
    bank = defaultdict(list)
    print("🔎 Loading Object Bank...")

    if OBJECT_BANK_DIR.exists():
        class_dirs = (entry for entry in OBJECT_BANK_DIR.iterdir() if entry.is_dir())
        for class_dir in class_dirs:
            cutouts = list(class_dir.glob('*.png'))
            # Only register classes that actually provide objects.
            if cutouts:
                bank[class_dir.name].extend(cutouts)

        print("  Object Bank loaded successfully.")
        for class_name in bank:
            print(f"  - Found {len(bank[class_name])} objects for class '{class_name}'")
    else:
        print("  Object bank directory not found.")

    return bank

def create_gradient_background(size):
    """Build a background image filled with a random top-to-bottom gradient.

    Two random colors (each channel drawn from ``[150, 255)``) are linearly
    blended along the vertical axis; every row has a single uniform color.

    Args:
        size: ``(width, height)`` of the image to create.

    Returns:
        A PIL image built from the ``(height, width, 3)`` uint8 array.
    """
    width, height = size
    top_color = np.random.randint(150, 255, 3)
    bottom_color = np.random.randint(150, 255, 3)

    # One blend factor per row: 0 at the top, approaching 1 at the bottom.
    blend = np.arange(height) / height
    row_colors = (
        top_color[np.newaxis, :] * (1 - blend)[:, np.newaxis]
        + bottom_color[np.newaxis, :] * blend[:, np.newaxis]
    ).astype(np.uint8)

    # Spread each row color across the full image width.
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    canvas[:, :, :] = row_colors[:, np.newaxis, :]

    return Image.fromarray(canvas)

def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Args:
        box1: First box as corner coordinates.
        box2: Second box as corner coordinates.

    Returns:
        Intersection area divided by union area, or ``0`` when the union area
        is not positive.
    """
    # Overlap extent along each axis, floored at zero for disjoint boxes.
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = overlap_w * overlap_h

    area_first = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_second = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_first + area_second - intersection

    if union > 0:
        return intersection / union
    return 0

def generate_pure_synthetic_image(object_bank, classes_to_add, image_size):
    """Generate a completely synthetic image on a gradient background.

    For every requested class one random object from the bank is scaled,
    rotated and pasted at a random position whose IoU with all previously
    placed objects does not exceed ``OVERLAP_THRESHOLD``. An object for which
    no such position is found within 50 attempts is dropped.

    Args:
        object_bank: Mapping of class name to a list of object image paths.
        classes_to_add: Iterable of class names to place; names that are not
            in the global ``TARGET_CLASSES`` or have no objects are skipped.
        image_size: ``(width, height)`` of the image to create.

    Returns:
        ``(image, yolo_annotations)`` where ``image`` is the PIL image and
        ``yolo_annotations`` a list of ``"class cx cy w h"`` strings.
    """
    background = create_gradient_background(image_size)
    canvas_w, canvas_h = image_size
    placed_object_bboxes = []
    yolo_annotations = []

    for class_name in classes_to_add:
        if class_name not in TARGET_CLASSES:
            continue
        class_id = TARGET_CLASSES.index(class_name)

        # .get() works for plain dicts too and does not add empty entries to a defaultdict.
        candidates = object_bank.get(class_name)
        if not candidates:
            continue

        obj_path = random.choice(candidates)
        # Force RGBA: the object is used as its own paste mask below, which
        # fails for images without an alpha channel (e.g. RGB or palette PNGs).
        with Image.open(obj_path) as source:
            obj_img = source.convert('RGBA')

        # Random transformations
        scale = random.uniform(*SCALE_RANGE)
        new_size = (int(canvas_w * scale), int(canvas_h * scale))
        obj_img.thumbnail(new_size, Image.Resampling.LANCZOS)
        rotation = random.uniform(*ROTATION_RANGE)
        obj_img = obj_img.rotate(rotation, expand=True, resample=Image.Resampling.BICUBIC)

        # rotate(expand=True) pads the canvas with transparent corners; crop to
        # the visible pixels so the label box encloses the object tightly.
        visible_box = obj_img.getchannel('A').getbbox()
        if visible_box is None:
            continue  # fully transparent object, nothing to paste
        obj_img = obj_img.crop(visible_box)

        max_x = canvas_w - obj_img.width
        max_y = canvas_h - obj_img.height
        if max_x < 0 or max_y < 0:
            continue  # object does not fit on the canvas

        # Find a valid placement
        for _ in range(50):  # 50 attempts
            pos_x = random.randint(0, max_x)
            pos_y = random.randint(0, max_y)

            new_bbox_corners = (pos_x, pos_y, pos_x + obj_img.width, pos_y + obj_img.height)

            is_overlapping = any(
                calculate_iou(new_bbox_corners, placed) > OVERLAP_THRESHOLD
                for placed in placed_object_bboxes
            )
            if is_overlapping:
                continue

            background.paste(obj_img, (pos_x, pos_y), obj_img)
            placed_object_bboxes.append(new_bbox_corners)

            # YOLO format: class_id cx cy w h
            cx = (pos_x + obj_img.width / 2) / canvas_w
            cy = (pos_y + obj_img.height / 2) / canvas_h
            w = obj_img.width / canvas_w
            h = obj_img.height / canvas_h
            yolo_annotations.append(f"{class_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}")
            break

    return background, yolo_annotations

def generate_augmented_synthetic_image(original_img_path, original_label_path, object_bank, target_classes):
    """Paste 1-3 object-bank objects onto an existing image.

    The source image is resized to ``IMAGE_SIZE``; its annotations are kept and
    new objects are only pasted where their IoU with every existing or
    previously pasted box does not exceed ``OVERLAP_THRESHOLD``. An object for
    which no such position is found within 30 attempts is dropped.

    Args:
        original_img_path: Path of the source image.
        original_label_path: ``pathlib.Path`` of its YOLO label file; a missing
            file is treated as an image without annotations.
        object_bank: Mapping of class name to a list of object image paths.
        target_classes: List of class names; the position of a name is its
            class id.

    Returns:
        ``(image, yolo_annotations)`` with the resulting PIL image and a list
        of ``"class cx cy w h"`` strings, or ``(None, [])`` when the source
        image cannot be read.
    """
    original_img = cv2.imread(str(original_img_path))
    if original_img is None:
        return None, []

    original_pil = Image.fromarray(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))

    # Resize to target size
    original_pil = original_pil.resize(IMAGE_SIZE, Image.Resampling.LANCZOS)
    canvas_w, canvas_h = IMAGE_SIZE

    # Load existing annotations
    existing_bboxes = []
    yolo_annotations = []

    if original_label_path.exists():
        with open(original_label_path, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                # Pixel-space box, used only for the overlap test below.
                class_id, bbox = parse_yolo_label(line, canvas_w, canvas_h)
                existing_bboxes.append(bbox)
                # Resizing does not change normalized coordinates, so keep the
                # original values instead of round-tripping through truncated
                # pixel integers (which would shift and shrink the boxes).
                cx, cy, w, h = map(float, line.split()[1:5])
                yolo_annotations.append(f"{class_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}")

    # Add 1-3 synthetic objects
    num_objects_to_add = random.randint(1, 3)
    # .get() works for plain dicts too and does not add empty entries to a defaultdict.
    available_classes = [cls for cls in target_classes if object_bank.get(cls)]

    if not available_classes:
        return original_pil, yolo_annotations

    classes_to_add = random.choices(available_classes, k=min(num_objects_to_add, len(available_classes)))

    for class_name in classes_to_add:
        class_id = target_classes.index(class_name)

        obj_path = random.choice(object_bank[class_name])
        # Force RGBA: the object is used as its own paste mask below, which
        # fails for images without an alpha channel (e.g. RGB or palette PNGs).
        with Image.open(obj_path) as source:
            obj_img = source.convert('RGBA')

        # Random transformations
        scale = random.uniform(0.1, 0.4)  # Smaller scale for augmented images
        new_size = (int(canvas_w * scale), int(canvas_h * scale))
        obj_img.thumbnail(new_size, Image.Resampling.LANCZOS)
        rotation = random.uniform(*ROTATION_RANGE)
        obj_img = obj_img.rotate(rotation, expand=True, resample=Image.Resampling.BICUBIC)

        # rotate(expand=True) pads the canvas with transparent corners; crop to
        # the visible pixels so the label box encloses the object tightly.
        visible_box = obj_img.getchannel('A').getbbox()
        if visible_box is None:
            continue  # fully transparent object, nothing to paste
        obj_img = obj_img.crop(visible_box)

        max_x = canvas_w - obj_img.width
        max_y = canvas_h - obj_img.height
        if max_x < 0 or max_y < 0:
            continue  # object does not fit on the canvas

        # Find placement that doesn't overlap with existing objects
        for _ in range(30):  # 30 attempts
            pos_x = random.randint(0, max_x)
            pos_y = random.randint(0, max_y)

            new_bbox = (pos_x, pos_y, pos_x + obj_img.width, pos_y + obj_img.height)

            is_overlapping = any(
                calculate_iou(new_bbox, existing_bbox) > OVERLAP_THRESHOLD
                for existing_bbox in existing_bboxes
            )
            if is_overlapping:
                continue

            original_pil.paste(obj_img, (pos_x, pos_y), obj_img)
            existing_bboxes.append(new_bbox)

            # Add new annotation
            cx = (pos_x + obj_img.width / 2) / canvas_w
            cy = (pos_y + obj_img.height / 2) / canvas_h
            w = obj_img.width / canvas_w
            h = obj_img.height / canvas_h
            yolo_annotations.append(f"{class_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}")
            break

    return original_pil, yolo_annotations

## 5. Dataset Augmentation and Balancing
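`TARGET_TOTAL_INSTANCES_PER_FOLD` is the desired total number of annotated object instances in the training set of each fold; it is divided evenly across the classes to obtain the per-class target.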

In [None]:
def create_augmented_kfold_dataset():
    """Creates the augmented k-fold CV dataset with controlled size and proper balancing."""
    print("🚀 Creating Augmented K-Fold CV Dataset...")
    
    # Target parameters per fold
    TARGET_TOTAL_INSTANCES_PER_FOLD = 27000
    TARGET_INSTANCES_PER_CLASS_PER_FOLD = TARGET_TOTAL_INSTANCES_PER_FOLD // NUM_CLASSES
    
    print(f"🎯 Target per fold: {TARGET_TOTAL_INSTANCES_PER_FOLD} total instances ({TARGET_INSTANCES_PER_CLASS_PER_FOLD} per class)")
    
    # Process each fold
    for fold_idx in range(NUM_FOLDS):
        fold_dir = K_FOLD_CV_DIR / f'fold_{fold_idx}'
        augmented_fold_dir = K_FOLD_CV_AUGMENTED_DIR / f'fold_{fold_idx}'
        
        if not fold_dir.exists():
            print(f"⚠️  Skipping fold_{fold_idx}: source directory not found")
            continue
            
        print(f"\n🔄 Processing fold_{fold_idx}...")
        
        # Create output directory structure for this fold
        train_img_dir = augmented_fold_dir / 'train/images'
        train_label_dir = augmented_fold_dir / 'train/labels'
        val_img_dir = augmented_fold_dir / 'val/images'
        val_label_dir = augmented_fold_dir / 'val/labels'
        
        for dir_path in [train_img_dir, train_label_dir, val_img_dir, val_label_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)
        
        # Copy data.yaml for this fold
        shutil.copy2(fold_dir / 'data.yaml', augmented_fold_dir / 'data.yaml')
        
        # Copy validation set unchanged
        print(f"  📁 Copying validation set for fold_{fold_idx}...")
        val_source_dir = fold_dir / 'val'
        if val_source_dir.exists():
            if (val_source_dir / 'images').exists():
                for img_file in (val_source_dir / 'images').glob('*'):
                    shutil.copy2(img_file, val_img_dir)
            if (val_source_dir / 'labels').exists():
                for label_file in (val_source_dir / 'labels').glob('*'):
                    shutil.copy2(label_file, val_label_dir)
        
        # Copy original training set
        print(f"  📁 Copying original training set for fold_{fold_idx}...")
        train_source_dir = fold_dir / 'train'
        original_train_files = []
        
        if (train_source_dir / 'images').exists():
            for img_file in (train_source_dir / 'images').glob('*'):
                shutil.copy2(img_file, train_img_dir)
                original_train_files.append(img_file.stem)
                
        if (train_source_dir / 'labels').exists():
            for label_file in (train_source_dir / 'labels').glob('*'):
                shutil.copy2(label_file, train_label_dir)
        
        # Get current class distribution for this fold
        original_dist = get_class_distribution(train_label_dir, NUM_CLASSES)
        total_original = sum(original_dist.values())
        
        print(f"  📊 Original distribution for fold_{fold_idx} (total: {total_original}):")
        for i, class_name in enumerate(TARGET_CLASSES):
            count = original_dist.get(i, 0)
            print(f"    - {class_name}: {count} instances")
        
        # Calculate what we need per class to reach target
        needs_per_class = {}
        total_needed_instances = 0
        for i in range(NUM_CLASSES):
            current_count = original_dist.get(i, 0)
            needed = max(0, TARGET_INSTANCES_PER_CLASS_PER_FOLD - current_count)
            needs_per_class[i] = needed
            total_needed_instances += needed
        
        print(f"    🔢 Total instances needed for fold_{fold_idx}: {total_needed_instances}")
        
        if total_needed_instances == 0:
            print(f"    ✅ Fold_{fold_idx} is already balanced at target size!")
            continue
        
        # Calculate distribution of how to generate the needed instances
        traditional_target_instances = int(total_needed_instances * 0.4)  # 40% traditional
        synthetic_target_instances = total_needed_instances - traditional_target_instances  # 60% synthetic
        
        augmented_synthetic_target_instances = int(synthetic_target_instances * AUGMENTED_SYNTHETIC_RATIO)
        pure_synthetic_target_instances = synthetic_target_instances - augmented_synthetic_target_instances
        
        print(f"    📋 Generation plan for fold_{fold_idx}:")
        print(f"      - Traditional augmentations: {traditional_target_instances} instances")
        print(f"      - Synthetic images: {synthetic_target_instances} instances")
        print(f"        ├── Augmented synthetic: {augmented_synthetic_target_instances} instances")
        print(f"        └── Pure synthetic: {pure_synthetic_target_instances} instances")
        
        # Collect images by class for targeted augmentation for this fold
        class_to_images = defaultdict(list)
        
        for label_file in (train_source_dir / 'labels').glob('*.txt'):
            img_file = train_source_dir / 'images' / f"{label_file.stem}.jpg"
            if not img_file.exists():
                for ext in ['.png', '.jpeg', '.JPG', '.PNG']:
                    alt_path = train_source_dir / 'images' / f"{label_file.stem}{ext}"
                    if alt_path.exists():
                        img_file = alt_path
                        break
            
            if not img_file.exists():
                continue
            
            # Read classes in this image
            classes_in_image = set()
            try:
                with open(label_file, 'r') as f:
                    for line in f:
                        if line.strip():
                            class_id = int(float(line.split()[0]))
                            if class_id < NUM_CLASSES:
                                classes_in_image.add(class_id)
            except:
                continue
            
            # Add this image to all classes it contains
            for class_id in classes_in_image:
                class_to_images[class_id].append((img_file, label_file))
        
        print(f"    📋 Images available per class for fold_{fold_idx}:")
        for i in range(NUM_CLASSES):
            count = len(class_to_images[i])
            print(f"      - Class {i} ({TARGET_CLASSES[i]}): {count} images")
        
        # Phase 1: Apply controlled traditional augmentations (40% of needed instances)
        print(f"    🔄 Phase 1: Applying traditional augmentations for fold_{fold_idx} (target: {traditional_target_instances} instances)...")
        augmented_count = 0
        traditional_needs = needs_per_class.copy()
        
        # Limit traditional augmentations per class to maintain balance
        for class_id in range(NUM_CLASSES):
            max_traditional_for_class = max(1, int(traditional_target_instances / NUM_CLASSES))
            traditional_needs[class_id] = min(traditional_needs[class_id], max_traditional_for_class)
        
        for class_id in range(NUM_CLASSES):
            needed = traditional_needs[class_id]
            if needed <= 0:
                continue
                
            class_name = TARGET_CLASSES[class_id]
            available_images = class_to_images[class_id]
            
            if not available_images:
                print(f"      ⚠️  Warning: No images found for class {class_name} in fold_{fold_idx}")
                continue
            
            print(f"      🎨 Generating {needed} traditional augmentations for {class_name} in fold_{fold_idx}...")
            
            generated_for_class = 0
            attempts = 0
            max_attempts = len(available_images) * 5  # Limit attempts
            
            while generated_for_class < needed and attempts < max_attempts:
                img_path, label_path = random.choice(available_images)
                
                base_filename = f"{img_path.stem}_fold{fold_idx}_cls{class_id}_aug{attempts}"
                aug_files = apply_traditional_augmentations(
                    img_path, label_path, train_img_dir, train_label_dir, base_filename
                )
                
                # Count instances of our target class that were generated
                for aug_img_path, aug_label_path in aug_files:
                    if aug_label_path.exists():
                        with open(aug_label_path, 'r') as f:
                            for line in f:
                                if line.strip():
                                    try:
                                        line_class_id = int(float(line.split()[0]))
                                        if line_class_id == class_id:
                                            generated_for_class += 1
                                            augmented_count += 1
                                            if generated_for_class >= needed:
                                                break
                                    except:
                                        continue
                        if generated_for_class >= needed:
                            break
                
                attempts += 1
            
            # Update original needs
            needs_per_class[class_id] = max(0, needs_per_class[class_id] - generated_for_class)
            print(f"        ✅ Generated {generated_for_class} traditional augmentations for {class_name} in fold_{fold_idx}")
        
        print(f"    ✅ Total traditional augmentations for fold_{fold_idx}: {augmented_count}")
        
        # Phase 2: Generate synthetic images for remaining needs
        remaining_needed = sum(needs_per_class.values())
        if remaining_needed > 0:
            print(f"    🧬 Phase 2: Generating synthetic data for fold_{fold_idx} - {remaining_needed} remaining instances...")
            print(f"      📊 Target distribution: {AUGMENTED_SYNTHETIC_RATIO:.0%} augmented synthetic + {PURE_SYNTHETIC_RATIO:.0%} pure synthetic")
            
            # Load object bank
            object_bank = load_object_bank()
            
            if not any(object_bank.values()):
                print(f"      ❌ Cannot generate synthetic data for fold_{fold_idx}: Object bank is empty")
            else:
                synthetic_count = 0
                augmented_synthetic_count = 0
                pure_synthetic_count = 0
                needs = needs_per_class.copy()  # Working copy
                
                # Target synthetic image counts (estimate 2-3 instances per image)
                estimated_images_needed = remaining_needed // 2
                target_augmented_synthetic_images = int(estimated_images_needed * AUGMENTED_SYNTHETIC_RATIO)
                target_pure_synthetic_images = int(estimated_images_needed * PURE_SYNTHETIC_RATIO)
                
                print(f"      🎯 Estimated target for fold_{fold_idx}: ~{target_augmented_synthetic_images} augmented synthetic + ~{target_pure_synthetic_images} pure synthetic images")
                
                pbar = tqdm(total=remaining_needed, desc=f"Synthetic generation fold_{fold_idx}")
                
                while any(n > 0 for n in needs.values()) and synthetic_count < remaining_needed * 2:  # Safety limit
                    needed_classes_ids = [cid for cid, n in needs.items() if n > 0]
                    
                    if not needed_classes_ids:
                        break
                    
                    # Decide whether to generate augmented synthetic or pure synthetic
                    current_total_synthetic = augmented_synthetic_count + pure_synthetic_count
                    current_aug_ratio = augmented_synthetic_count / max(1, current_total_synthetic)
                    
                    should_generate_augmented = (
                        current_aug_ratio < AUGMENTED_SYNTHETIC_RATIO and 
                        augmented_synthetic_count < target_augmented_synthetic_images and
                        len(class_to_images.get(random.choice(needed_classes_ids), [])) > 0  # Has source images
                    )
                    
                    if should_generate_augmented:
                        # Generate augmented synthetic image
                        candidate_classes = [cid for cid in needed_classes_ids if len(class_to_images.get(cid, [])) > 0]
                        if candidate_classes:
                            class_id = random.choice(candidate_classes)
                            source_img_path, source_label_path = random.choice(class_to_images[class_id])
                            
                            img, annotations = generate_augmented_synthetic_image(
                                source_img_path, source_label_path, object_bank, TARGET_CLASSES
                            )
                            
                            if img and annotations:
                                timestamp = pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')
                                img_filename = f"fold{fold_idx}_aug_synthetic_{timestamp}.jpg"
                                img.save(train_img_dir / img_filename)
                                
                                label_filename = Path(img_filename).with_suffix('.txt')
                                with open(train_label_dir / label_filename, 'w') as f:
                                    f.write("\n".join(annotations))
                                
                                augmented_synthetic_count += 1
                            else:
                                should_generate_augmented = False
                        else:
                            should_generate_augmented = False
                    
                    if not should_generate_augmented:
                        # Generate pure synthetic image
                        num_objects = random.randint(1, min(4, len(needed_classes_ids)))
                        classes_to_request_ids = random.sample(needed_classes_ids, num_objects)
                        classes_to_request_names = [TARGET_CLASSES[cid] for cid in classes_to_request_ids]

                        img, annotations = generate_pure_synthetic_image(object_bank, classes_to_request_names, IMAGE_SIZE)
                        
                        if img and annotations:
                            timestamp = pd.Timestamp.now().strftime('%Y%m%d%H%M%S%f')
                            img_filename = f"fold{fold_idx}_pure_synthetic_{timestamp}.jpg"
                            img.save(train_img_dir / img_filename)
                            
                            label_filename = Path(img_filename).with_suffix('.txt')
                            with open(train_label_dir / label_filename, 'w') as f:
                                f.write("\n".join(annotations))
                            
                            pure_synthetic_count += 1
                        else:
                            continue
                    
                    # Update needs and progress (for both types)
                    if annotations:
                        instances_added = 0
                        for ann in annotations:
                            class_id = int(ann.split()[0])
                            if class_id < NUM_CLASSES and needs[class_id] > 0:
                                needs[class_id] -= 1
                                instances_added += 1
                        
                        if instances_added > 0:
                            pbar.update(instances_added)
                        synthetic_count += 1
                
                pbar.close()
                total_synthetic_images = augmented_synthetic_count + pure_synthetic_count
                print(f"      ✅ Generated {total_synthetic_images} synthetic images for fold_{fold_idx}:")
                print(f"        - Augmented synthetic: {augmented_synthetic_count} images")
                print(f"        - Pure synthetic: {pure_synthetic_count} images")
                
                if total_synthetic_images > 0:
                    actual_aug_ratio = augmented_synthetic_count / total_synthetic_images
                    actual_pure_ratio = pure_synthetic_count / total_synthetic_images
                    print(f"        - Actual distribution: {actual_aug_ratio:.1%} augmented + {actual_pure_ratio:.1%} pure")
        
        # Final summary for this fold
        final_dist = get_class_distribution(train_label_dir, NUM_CLASSES)
        final_total = sum(final_dist.values())
        
        print(f"    🎉 Fold_{fold_idx} augmentation complete!")
        print(f"      - Original instances: {total_original}")
        print(f"      - Final instances: {final_total}")
        print(f"      - Total increase: +{final_total - total_original} (+{((final_total/total_original - 1)*100):.1f}%)")
        print(f"      - Total training images: {len(list(train_img_dir.glob('*')))}")
        
        # Check balance for this fold
        min_count = min(final_dist.values()) if final_dist else 0
        max_count = max(final_dist.values()) if final_dist else 1
        balance_ratio = min_count / max_count if max_count > 0 else 0
        
        print(f"      📊 Final class distribution for fold_{fold_idx}:")
        for i, class_name in enumerate(TARGET_CLASSES):
            count = final_dist.get(i, 0)
            percentage = (count / final_total * 100) if final_total > 0 else 0
            print(f"        - {class_name}: {count} instances ({percentage:.1f}%)")
        
        print(f"      Balance ratio: {balance_ratio:.2f} (1.0 = perfect balance)")
        if balance_ratio >= 0.8:
            print(f"      ✅ Fold_{fold_idx} is well balanced")
        else:
            print(f"      ⚠️  Fold_{fold_idx} could be better balanced")
    
    print(f"\n🎉 All folds augmentation completed!")
    print(f"Augmented k-fold CV dataset available at: {K_FOLD_CV_AUGMENTED_DIR}")
    
    return K_FOLD_CV_AUGMENTED_DIR

## 6. Verification and Analysis

In [None]:
def verify_augmented_kfold_dataset():
    """Report and plot the class distribution of the augmented k-fold CV dataset.

    For every ``fold_<i>`` (``i`` in ``range(NUM_FOLDS)``) present under
    ``K_FOLD_CV_AUGMENTED_DIR`` this prints the per-class train and val
    instance counts, compares the train counts with the matching fold under
    ``K_FOLD_CV_DIR`` (when that fold exists), draws before/after bar charts,
    and prints a per-fold and an overall summary.

    Returns:
        dict | None: ``{fold_idx: {'original_train', 'final_train', 'val'}}``
        holding the distributions returned by ``get_class_distribution``;
        ``None`` when the augmented dataset directory does not exist.
    """
    print("📊 Verifying Augmented K-Fold CV Dataset...")

    if not K_FOLD_CV_AUGMENTED_DIR.exists():
        print("❌ Augmented k-fold CV dataset not found.")
        return None

    fold_results = {}

    for fold_idx in range(NUM_FOLDS):
        augmented_fold_dir = K_FOLD_CV_AUGMENTED_DIR / f'fold_{fold_idx}'
        original_fold_dir = K_FOLD_CV_DIR / f'fold_{fold_idx}'

        if not augmented_fold_dir.exists():
            print(f"⚠️  Augmented fold_{fold_idx} not found, skipping...")
            continue

        print(f"\n📁 Verifying fold_{fold_idx}...")

        # Distributions of the augmented fold (train and val splits).
        final_distribution = get_class_distribution(augmented_fold_dir / 'train/labels', NUM_CLASSES)
        val_distribution = get_class_distribution(augmented_fold_dir / 'val/labels', NUM_CLASSES)

        # Original train distribution for comparison; stays empty when the
        # source fold is missing, so every delta equals the final count.
        original_train_dist = {}
        if original_fold_dir.exists():
            original_train_dist = get_class_distribution(original_fold_dir / 'train/labels', NUM_CLASSES)

        print(f"  Final Train Set Distribution for fold_{fold_idx}:")
        if not final_distribution:
            print("    No labels found.")
        else:
            for i, class_name in enumerate(TARGET_CLASSES):
                # NOTE(review): direct indexing assumes get_class_distribution
                # yields an entry for every class ID - confirm against its definition.
                final_count = final_distribution[i]
                increase = final_count - original_train_dist.get(i, 0)
                # The '+' format flag prints the real sign ("+12" / "-3"), never "+-3".
                print(f"    - {class_name} (ID {i}): {final_count} instances ({increase:+})")

        print(f"  Final Val Set Distribution for fold_{fold_idx}:")
        for i, class_name in enumerate(TARGET_CLASSES):
            print(f"    - {class_name} (ID {i}): {val_distribution[i]} instances")

        # Store results for this fold
        fold_results[fold_idx] = {
            'original_train': original_train_dist,
            'final_train': final_distribution,
            'val': val_distribution
        }

    # Plot comparison for all folds
    if fold_results:
        num_folds_found = len(fold_results)
        # squeeze=False keeps `axes` two-dimensional even for a single fold,
        # so axes[row, col] indexing works without reshaping.
        _, axes = plt.subplots(2, num_folds_found, figsize=(6*num_folds_found, 10), squeeze=False)

        # Identical for every fold, so computed once outside the loop.
        class_names = [f"C{i}" for i in range(NUM_CLASSES)]
        x = np.arange(len(class_names))
        width = 0.35

        for idx, (fold_idx, results) in enumerate(fold_results.items()):
            # Before vs After comparison for train set
            original_counts = [results['original_train'].get(i, 0) for i in range(NUM_CLASSES)]
            final_counts = [results['final_train'][i] for i in range(NUM_CLASSES)]

            top_ax = axes[0, idx]
            top_ax.bar(x - width/2, original_counts, width, label='Original', alpha=0.8)
            top_ax.bar(x + width/2, final_counts, width, label='Augmented', alpha=0.8)
            top_ax.set_xlabel('Classes')
            top_ax.set_ylabel('Instances')
            top_ax.set_title(f'Fold {fold_idx} - Before vs After Augmentation')
            top_ax.set_xticks(x)
            top_ax.set_xticklabels(class_names)
            top_ax.legend()

            # Final train distribution
            bottom_ax = axes[1, idx]
            final_df = pd.DataFrame.from_dict(results['final_train'], orient='index').sort_index()
            final_df.plot(kind='bar', legend=False, ax=bottom_ax,
                          title=f'Fold {fold_idx} - Final Train Distribution')
            bottom_ax.set_xlabel('Class ID')
            bottom_ax.set_ylabel('Instances')

        plt.tight_layout()
        plt.show()

    # Overall summary statistics
    print(f"\n📈 Overall Summary:")
    total_original_across_folds = 0
    total_final_across_folds = 0

    for fold_idx, results in fold_results.items():
        original_total = sum(results['original_train'].values())
        final_total = sum(results['final_train'].values())
        increase_total = final_total - original_total
        increase_percent = (increase_total / original_total) * 100 if original_total > 0 else 0

        total_original_across_folds += original_total
        total_final_across_folds += final_total

        print(f"  Fold {fold_idx}:")
        print(f"    - Original training instances: {original_total}")
        print(f"    - Final training instances: {final_total}")
        print(f"    - Increase: {increase_total:+} ({increase_percent:+.1f}%)")

        # Balance = rarest class count / most frequent class count (1.0 = perfect).
        final_train = results['final_train']
        min_count = min(final_train.values()) if final_train else 0
        max_count = max(final_train.values()) if final_train else 1
        balance_ratio = min_count / max_count if max_count > 0 else 0

        print(f"    - Balance ratio: {balance_ratio:.2f}")

        if balance_ratio > 0.8:
            print(f"    - ✅ Well balanced!")
        elif balance_ratio > 0.6:
            print(f"    - ⚠️ Moderate imbalance")
        else:
            print(f"    - ❌ Significant imbalance")

    # Overall stats (a positive original total implies fold_results is
    # non-empty, so the average below cannot divide by zero).
    if total_original_across_folds > 0:
        overall_increase = total_final_across_folds - total_original_across_folds
        overall_increase_percent = (overall_increase / total_original_across_folds) * 100

        print(f"\n🌟 Across all folds:")
        print(f"  - Total original instances: {total_original_across_folds}")
        print(f"  - Total final instances: {total_final_across_folds}")
        print(f"  - Total increase: {overall_increase:+} ({overall_increase_percent:+.1f}%)")
        print(f"  - Average instances per fold: {total_final_across_folds // len(fold_results)}")

    return fold_results

## 7. Run Augmentation Pipeline

In [None]:
# Run the complete k-fold augmentation pipeline.
# The pipeline only starts when class metadata is available (TARGET_CLASSES is
# non-empty and NUM_CLASSES is positive); both are defined earlier in the notebook.
if TARGET_CLASSES and NUM_CLASSES > 0:
    print("🚀 Starting K-Fold CV Dataset Augmentation Pipeline...")
    print(f"Classes to process: {TARGET_CLASSES}")
    print(f"Number of folds: {NUM_FOLDS}")

    # Create the augmented k-fold CV dataset
    augmented_dataset_path = create_augmented_kfold_dataset()

    if augmented_dataset_path:
        print(f"\n✅ Augmented k-fold CV dataset created at: {augmented_dataset_path}")

        # Verify the results
        final_results = verify_augmented_kfold_dataset()

        print("\n🎉 K-Fold CV augmentation pipeline completed successfully!")
        print(f"Dataset available at: {K_FOLD_CV_AUGMENTED_DIR}")
        print("Structure:")
        # Draw the directory tree from NUM_FOLDS instead of hard-coding five
        # folds: fold_0 is expanded in full, later folds are abbreviated.
        for fold_idx in range(NUM_FOLDS):
            is_last = fold_idx == NUM_FOLDS - 1
            branch = "└──" if is_last else "├──"
            if fold_idx > 0:
                print(f"  {branch} fold_{fold_idx}/ (same structure)")
                continue
            # The vertical trunk continues below fold_0 only if more folds follow.
            trunk = " " if is_last else "│"
            print(f"  {branch} fold_0/")
            print(f"  {trunk}   ├── train/ (original + augmented + synthetic)")
            print(f"  {trunk}   │   ├── images/")
            print(f"  {trunk}   │   └── labels/")
            print(f"  {trunk}   ├── val/ (copied from original)")
            print(f"  {trunk}   │   ├── images/")
            print(f"  {trunk}   │   └── labels/")
            print(f"  {trunk}   └── data.yaml")
    else:
        print("❌ Failed to create augmented k-fold CV dataset")
else:
    print("❌ Cannot proceed: No target classes found or dataset not properly loaded.")
    print("Please check that the k-fold CV dataset exists and has valid data.yaml files.")

## 8. Optional: Quick Dataset Statistics

In [None]:
# Quick statistics about the created k-fold CV dataset
def _count_dir_entries(directory):
    """Return how many entries ``glob('*')`` yields in *directory* (0 if it is missing)."""
    return sum(1 for _ in directory.glob('*')) if directory.exists() else 0


def _classify_train_images(train_img_dir, fold_idx):
    """Count the entries of *train_img_dir* per generation type in one directory pass.

    Categories are decided from the file stem:
      * ``aug_synthetic``   - starts with ``fold<idx>_aug_synthetic``
      * ``pure_synthetic``  - starts with ``fold<idx>_pure_synthetic``
      * ``traditional_aug`` - contains ``_aug_`` and is not one of the above
      * ``original``        - contains none of the three markers anywhere

    Returns:
        dict: the four category names mapped to their counts.
    """
    aug_prefix = f'fold{fold_idx}_aug_synthetic'
    pure_prefix = f'fold{fold_idx}_pure_synthetic'
    counts = {'original': 0, 'traditional_aug': 0, 'aug_synthetic': 0, 'pure_synthetic': 0}

    for entry in train_img_dir.glob('*'):
        stem = entry.stem
        if stem.startswith(aug_prefix):
            counts['aug_synthetic'] += 1
        elif stem.startswith(pure_prefix):
            counts['pure_synthetic'] += 1
        elif '_aug_' in stem:
            counts['traditional_aug'] += 1
        elif pure_prefix not in stem:
            # No '_aug_' here also rules out aug_prefix (it contains '_aug_'),
            # so only the pure-synthetic marker still needs to be excluded.
            counts['original'] += 1

    return counts


def show_kfold_dataset_statistics():
    """Print file-count statistics for the augmented k-fold CV dataset.

    For each existing ``fold_<i>`` under ``K_FOLD_CV_AUGMENTED_DIR`` this
    prints the number of train/val images and labels, a breakdown of the
    training images by generation type (original, traditional augmentation,
    augmented synthetic, pure synthetic) and the synthetic split compared
    with ``AUGMENTED_SYNTHETIC_RATIO`` / ``PURE_SYNTHETIC_RATIO``. Totals
    across all folds are printed at the end. Returns ``None``.
    """
    if not K_FOLD_CV_AUGMENTED_DIR.exists():
        print("Augmented k-fold CV dataset not found.")
        return

    print("📊 K-Fold CV Dataset Statistics:")

    total_train_images = 0
    total_train_labels = 0
    total_val_images = 0
    total_val_labels = 0

    fold_stats = {}

    for fold_idx in range(NUM_FOLDS):
        augmented_fold_dir = K_FOLD_CV_AUGMENTED_DIR / f'fold_{fold_idx}'
        if not augmented_fold_dir.exists():
            continue

        train_img_dir = augmented_fold_dir / 'train/images'
        train_label_dir = augmented_fold_dir / 'train/labels'
        val_img_dir = augmented_fold_dir / 'val/images'
        val_label_dir = augmented_fold_dir / 'val/labels'

        # Count files for this fold
        fold_train_images = _count_dir_entries(train_img_dir)
        fold_train_labels = _count_dir_entries(train_label_dir)
        fold_val_images = _count_dir_entries(val_img_dir)
        fold_val_labels = _count_dir_entries(val_label_dir)

        total_train_images += fold_train_images
        total_train_labels += fold_train_labels
        total_val_images += fold_val_images
        total_val_labels += fold_val_labels

        print(f"  Fold {fold_idx}:")
        print(f"    Training set: {fold_train_images} images, {fold_train_labels} labels")
        print(f"    Validation set: {fold_val_images} images, {fold_val_labels} labels")
        print(f"    Fold total: {fold_train_images + fold_val_images} images")

        # The per-type breakdown needs the training image directory.
        if not train_img_dir.exists():
            continue

        counts = _classify_train_images(train_img_dir, fold_idx)
        original_count = counts['original']
        traditional_aug_count = counts['traditional_aug']
        aug_synthetic_count = counts['aug_synthetic']
        pure_synthetic_count = counts['pure_synthetic']
        total_synthetic = aug_synthetic_count + pure_synthetic_count

        print("    📈 Training set breakdown:")
        print(f"      Original images: {original_count}")
        print(f"      Traditional augmentations: {traditional_aug_count}")
        print(f"      Synthetic images total: {total_synthetic}")
        print(f"        ├── Augmented synthetic: {aug_synthetic_count}")
        print(f"        └── Pure synthetic: {pure_synthetic_count}")

        # Show percentages for this fold
        if fold_train_images > 0:
            print("      📊 Distribution percentages:")
            print(f"        Original: {(original_count/fold_train_images)*100:.1f}%")
            print(f"        Traditional augmentations: {(traditional_aug_count/fold_train_images)*100:.1f}%")
            print(f"        Synthetic: {(total_synthetic/fold_train_images)*100:.1f}%")

            if total_synthetic > 0:
                aug_ratio = aug_synthetic_count / total_synthetic
                pure_ratio = pure_synthetic_count / total_synthetic
                print("      🎨 Synthetic data distribution:")
                print(f"        Augmented synthetic: {aug_ratio:.1%} (target: {AUGMENTED_SYNTHETIC_RATIO:.1%})")
                print(f"        Pure synthetic: {pure_ratio:.1%} (target: {PURE_SYNTHETIC_RATIO:.1%})")

        fold_stats[fold_idx] = {
            'train_images': fold_train_images,
            'val_images': fold_val_images,
            'original': original_count,
            'traditional_aug': traditional_aug_count,
            'aug_synthetic': aug_synthetic_count,
            'pure_synthetic': pure_synthetic_count
        }

    print("\n🌟 Overall Statistics:")
    print("  Total across all folds:")
    print(f"    Training images: {total_train_images}")
    print(f"    Training labels: {total_train_labels}")
    print(f"    Validation images: {total_val_images}")
    print(f"    Validation labels: {total_val_labels}")
    print(f"    Grand total images: {total_train_images + total_val_images}")

    if fold_stats:
        # Aggregate statistics
        total_original = sum(stats['original'] for stats in fold_stats.values())
        total_traditional_aug = sum(stats['traditional_aug'] for stats in fold_stats.values())
        total_aug_synthetic = sum(stats['aug_synthetic'] for stats in fold_stats.values())
        total_pure_synthetic = sum(stats['pure_synthetic'] for stats in fold_stats.values())
        total_all_synthetic = total_aug_synthetic + total_pure_synthetic

        print("\n  📈 Aggregate breakdown:")
        print(f"    Total original images: {total_original}")
        print(f"    Total traditional augmentations: {total_traditional_aug}")
        print(f"    Total synthetic images: {total_all_synthetic}")
        print(f"      ├── Augmented synthetic: {total_aug_synthetic}")
        print(f"      └── Pure synthetic: {total_pure_synthetic}")

        # Overall percentages
        if total_train_images > 0:
            print("\n  📊 Overall distribution percentages:")
            print(f"    Original: {(total_original/total_train_images)*100:.1f}%")
            print(f"    Traditional augmentations: {(total_traditional_aug/total_train_images)*100:.1f}%")
            print(f"    Synthetic: {(total_all_synthetic/total_train_images)*100:.1f}%")

            if total_all_synthetic > 0:
                overall_aug_ratio = total_aug_synthetic / total_all_synthetic
                overall_pure_ratio = total_pure_synthetic / total_all_synthetic
                print("\n  🎨 Overall synthetic data distribution:")
                print(f"    Augmented synthetic: {overall_aug_ratio:.1%} (target: {AUGMENTED_SYNTHETIC_RATIO:.1%})")
                print(f"    Pure synthetic: {overall_pure_ratio:.1%} (target: {PURE_SYNTHETIC_RATIO:.1%})")

                # Show if we're close to target distribution
                aug_diff = abs(overall_aug_ratio - AUGMENTED_SYNTHETIC_RATIO)
                if aug_diff < 0.1:  # Within 10 percentage points of the target
                    print("    ✅ Overall synthetic distribution matches target well!")
                else:
                    print(f"    ⚠️  Overall synthetic distribution differs from target by {aug_diff:.1%}")

# Show statistics: print the per-fold and aggregate image/label counts for the
# augmented dataset (a short notice is printed instead if its directory is missing).
show_kfold_dataset_statistics()

In [None]:
def show_image_with_bboxes(image_path, label_path, title, ax):
    """Draw an image on *ax* with the boxes from its YOLO label file overlaid.

    Every non-blank label line is read as ``class cx cy w h``; the four box
    values are scaled by the image width/height to obtain pixel corners. Each
    box is outlined in red and tagged with its class ID at the top-left
    corner. When *label_path* does not exist the image is shown without boxes.

    Args:
        image_path: Path of the image to open.
        label_path: Path object of the matching label file (may not exist).
        title: Title set on the axis.
        ax: Matplotlib axis that receives the image.
    """
    canvas = Image.open(image_path).convert("RGB")
    pen = ImageDraw.Draw(canvas)
    # Drawing never resizes the image, so its size is read once up front.
    width_px, height_px = canvas.size

    if label_path.exists():
        with open(label_path, 'r') as handle:
            for raw_line in handle:
                if not raw_line.strip():
                    continue  # skip blank lines

                class_token, *box_tokens = raw_line.split()
                class_id = int(float(class_token))
                cx, cy, w, h = (float(token) for token in box_tokens)

                # Centre/size representation -> pixel corner coordinates.
                left = int((cx - w / 2) * width_px)
                top = int((cy - h / 2) * height_px)
                right = int((cx + w / 2) * width_px)
                bottom = int((cy + h / 2) * height_px)

                pen.rectangle([left, top, right, bottom], outline='red', width=2)
                pen.text((left, top), str(class_id), fill='white')

    ax.imshow(canvas)
    ax.set_title(title)
    ax.axis('off')

def show_example_images_kfold():
    """Plot one example image per generation type from the augmented k-fold dataset.

    Folds are scanned in order. For each of the three types (traditional
    augmentation, augmented synthetic, pure synthetic) the first matching
    ``.jpg`` found is kept together with its label path, and scanning stops as
    soon as all three types have an example. The examples are drawn side by
    side with their bounding boxes; panels without an example are left empty
    and titled "No example found".
    """
    if not K_FOLD_CV_AUGMENTED_DIR.exists():
        print("Augmented k-fold CV dataset not found.")
        return

    kinds = ('traditional', 'aug_synthetic', 'pure_synthetic')
    # kind -> (image_path, label_path, title); one slot per type so that a
    # later fold can only fill a type that is still missing.
    examples_by_kind = {}

    for fold_idx in range(NUM_FOLDS):
        fold_dir = K_FOLD_CV_AUGMENTED_DIR / f'fold_{fold_idx}'
        train_img_dir = fold_dir / 'train/images'
        if not train_img_dir.exists():
            continue

        train_label_dir = fold_dir / 'train/labels'
        fold_prefix = f'fold{fold_idx}_'

        candidates = {
            # '*_aug_*' also matches 'fold<idx>_aug_synthetic_*' files, so the
            # synthetic ones are filtered out rather than rejecting the search
            # when the first match happens to be synthetic.
            'traditional': (
                (path for path in train_img_dir.glob('*_aug_*.jpg')
                 if not path.stem.startswith(fold_prefix)),
                f"Traditional Augmentation Example (Fold {fold_idx})",
            ),
            'aug_synthetic': (
                train_img_dir.glob(f'fold{fold_idx}_aug_synthetic_*.jpg'),
                f"Augmented Synthetic Example (Fold {fold_idx})",
            ),
            'pure_synthetic': (
                train_img_dir.glob(f'fold{fold_idx}_pure_synthetic_*.jpg'),
                f"Pure Synthetic Example (Fold {fold_idx})",
            ),
        }

        for kind, (matches, title) in candidates.items():
            if kind in examples_by_kind:
                continue
            img_path = next(matches, None)
            if img_path is not None:
                # Keep the label path with the example so the fold never has
                # to be re-derived from the title text.
                label_path = train_label_dir / f"{img_path.stem}.txt"
                examples_by_kind[kind] = (img_path, label_path, title)

        # Stop scanning once every type has an example.
        if len(examples_by_kind) == len(kinds):
            break

    # Found examples first (in type order), then padding for missing types.
    examples = [examples_by_kind[kind] for kind in kinds if kind in examples_by_kind]
    examples += [(None, None, "No example found")] * (len(kinds) - len(examples))

    # Create the plot
    _, axes = plt.subplots(1, len(kinds), figsize=(12, 4))
    found = False

    for ax, (img_path, label_path, title) in zip(axes, examples):
        if img_path is not None and img_path.exists():
            show_image_with_bboxes(img_path, label_path, title, ax)
            found = True
        else:
            ax.axis('off')
            ax.set_title(title)

    plt.tight_layout()
    plt.show()

    if not found:
        print("No example images found in the augmented k-fold CV dataset.")
    else:
        print("Example images shown above represent the different types of data generation applied to the k-fold CV dataset.")

# Plot example images (with their bounding boxes) from the augmented dataset.
show_example_images_kfold()