# In [None]:
import os
import shutil
from pathlib import Path
from collections import defaultdict
import random

def get_class_from_label(label_file):
    """Return the class ID token of the first annotation in a label file.

    The class ID is the first whitespace-separated token of the first
    non-blank line, returned as a string (e.g. ``"3"``), not an int.

    Args:
        label_file: Path (``str`` or ``os.PathLike``) to a ``.txt`` label file.

    Returns:
        The class ID string, or ``None`` if the file has no non-blank lines.
    """
    with open(label_file, "r", encoding="utf-8") as f:
        # Stream instead of readlines(): only the first annotation matters.
        # Skipping blank lines avoids an IndexError on a leading empty line
        # or a whitespace-only file.
        for line in f:
            tokens = line.split()
            if tokens:
                return tokens[0]
    return None

def count_class_distribution(label_dir):
    """Tally how many label files in *label_dir* belong to each class.

    Every ``*.txt`` file contributes one count, attributed to the class
    returned by ``get_class_from_label``. Files for which that helper
    returns ``None`` are not counted.

    Returns:
        A plain ``dict`` mapping class-ID string -> number of label files.
    """
    tally = {}
    for label_path in Path(label_dir).glob("*.txt"):
        cls_id = get_class_from_label(label_path)
        if cls_id is None:
            continue
        tally[cls_id] = tally.get(cls_id, 0) + 1
    return tally

def get_original_images(images_dir):
    """Return the entries in *images_dir* that are not augmentation outputs.

    Every entry matching ``*.*`` (any name containing a dot) is considered.
    Entries whose stem contains ``"_aug"`` are treated as augmented copies
    and dropped.

    Returns:
        A list of ``Path`` objects in ``Path.glob`` order (not sorted).
    """
    originals = []
    for candidate in Path(images_dir).glob("*.*"):
        if "_aug" in candidate.stem:
            continue
        originals.append(candidate)
    return originals

def _index_train_originals(train_img_dir, train_lbl_dir, rng):
    """Group labelled, non-augmented train images by class, shuffled per class.

    Returns a ``defaultdict(list)`` mapping class-ID string -> list of
    ``(image_path, label_path)`` tuples. Each label file is read once.
    Images are sorted before grouping, and classes are shuffled in sorted
    order, so a seeded ``rng`` gives a reproducible selection regardless of
    the filesystem's directory-listing order.
    """
    by_class = defaultdict(list)
    lbl_dir = Path(train_lbl_dir)
    for img in sorted(get_original_images(train_img_dir)):
        lbl_file = lbl_dir / (img.stem + ".txt")
        if not lbl_file.exists():
            continue  # no label file -> cannot assign this image to a class
        cls = get_class_from_label(lbl_file)
        if cls is not None:
            by_class[cls].append((img, lbl_file))
    for cls in sorted(by_class):
        rng.shuffle(by_class[cls])
    return by_class


def balance_val_set(train_img_dir, train_lbl_dir, val_img_dir, val_lbl_dir,
                    val_ratio=0.2, seed=None):
    """Move original (non-augmented) images from train to val, per class.

    For each class seen in either split, the validation target is
    ``int((train_count + val_count) * val_ratio)``. Counts are numbers of
    label files whose first annotation has that class, as computed by
    ``count_class_distribution``. Randomly chosen train images whose stem
    does not contain ``"_aug"`` are moved, together with their label
    files, until each class reaches its target or runs out of candidates.

    NOTE(review): train counts include augmented label files, so the target
    is relative to the augmented total rather than to originals only. Also,
    augmented copies of moved originals stay in train. Confirm both are
    intended.

    Args:
        train_img_dir: Source image directory.
        train_lbl_dir: Source label directory.
        val_img_dir: Destination image directory (created if missing).
        val_lbl_dir: Destination label directory (created if missing).
        val_ratio: Fraction of each class's total that should end up in
            val. Must be within ``[0, 1]``.
        seed: Optional seed for a private RNG, making the selection
            reproducible. ``None`` uses the global ``random`` module state,
            as before.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= val_ratio <= 1:
        raise ValueError(f"val_ratio must be within [0, 1], got {val_ratio!r}")

    # Collect class distributions from the current train and val splits.
    val_counts = count_class_distribution(val_lbl_dir)
    train_counts = count_class_distribution(train_lbl_dir)

    print("Initial class distribution:")
    print("Train:", train_counts)
    print("Val:", val_counts)

    # Per-class validation target, from the combined train+val count.
    # Sorted so the printed composition and processing order are stable.
    target_val = {
        cls: int((train_counts.get(cls, 0) + val_counts.get(cls, 0)) * val_ratio)
        for cls in sorted(set(train_counts) | set(val_counts))
    }
    val_pct = round(val_ratio * 100)
    train_pct = 100 - val_pct
    print(f"Target Val Composition ({val_pct}%):", target_val)

    rng = random if seed is None else random.Random(seed)
    # Build the class -> candidates index once, instead of re-listing the
    # train dir and re-reading every label file for each class.
    candidates_by_class = _index_train_originals(train_img_dir, train_lbl_dir, rng)

    val_img_dir = Path(val_img_dir)
    val_lbl_dir = Path(val_lbl_dir)
    val_img_dir.mkdir(parents=True, exist_ok=True)
    val_lbl_dir.mkdir(parents=True, exist_ok=True)

    moved = 0
    for cls, target in target_val.items():
        need = target - val_counts.get(cls, 0)
        if need <= 0:
            continue

        taken = 0
        for img_path, lbl_path in candidates_by_class.get(cls, []):
            if taken >= need:
                break
            if not lbl_path.exists():
                # Label already moved along with another image of the same stem.
                continue
            img_dst = val_img_dir / img_path.name
            lbl_dst = val_lbl_dir / lbl_path.name
            if img_dst.exists() or lbl_dst.exists():
                # shutil.move would silently overwrite an existing file on POSIX.
                print(f"⚠️ Skipping {img_path.name}: a file with that name already exists in val.")
                continue
            shutil.move(str(img_path), str(img_dst))
            shutil.move(str(lbl_path), str(lbl_dst))
            taken += 1
        moved += taken
        if taken < need:
            print(f"⚠️ Class {cls}: moved {taken} of {need} needed; not enough original train images.")

    print(f"\n✅ Moved {moved} images to validation set to rebalance {train_pct}:{val_pct} per class.")

# Example usage: point these at your own dataset's split directories.
train_images_path = "Dataset/YOLODatasetFull/images/train"
val_images_path = "Dataset/YOLODatasetFull/images/val"
train_labels_path = "Dataset/YOLODatasetFull/labels/train"
val_labels_path = "Dataset/YOLODatasetFull/labels/val"

balance_val_set(
    train_img_dir=train_images_path,
    train_lbl_dir=train_labels_path,
    val_img_dir=val_images_path,
    val_lbl_dir=val_labels_path,
    val_ratio=0.2,
)
