# In [2]:
import os
import shutil
import random

def create_dir(path):
    """Create directory *path*, including any missing parent directories.

    Calling this on a directory that already exists is a no-op.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If *path* exists but is not a directory.
    """
    # exist_ok=True replaces the separate existence check, so another process
    # creating the directory between "check" and "create" cannot cause a crash.
    os.makedirs(path, exist_ok=True)

def split_dataset(
    images_dir, labels_dir,
    output_root,
    train_ratio=0.7, valid_ratio=0.2, test_ratio=0.1,
    seed=42
):
    """Copy matched image/label pairs into train/valid/test folders.

    An image (``.jpg``, ``.png`` or ``.jpeg``, extension matched
    case-insensitively) is paired with the label file ``<image stem>.txt``
    when a file of that exact name exists in *labels_dir*. Images without a
    label, and labels without an image, are skipped. The pairs are shuffled
    reproducibly and copied (the source files are left in place) into
    ``<output_root>/<split>/images`` and ``<output_root>/<split>/labels``
    for each split in ``train``, ``valid`` and ``test``.

    Split sizes are ``int(train_ratio * total)`` and
    ``int(valid_ratio * total)``; the test split receives all remaining
    pairs, so rounding leftovers end up in ``test``.

    Args:
        images_dir: Directory containing the image files.
        labels_dir: Directory containing the ``.txt`` label files.
        output_root: Directory under which the split folders are created.
        train_ratio: Fraction of pairs assigned to the training split.
        valid_ratio: Fraction of pairs assigned to the validation split.
        test_ratio: Fraction of pairs assigned to the test split.
        seed: Seed for the shuffle; the same seed and the same input files
            give the same split.

    Raises:
        ValueError: If a ratio is negative or the ratios do not sum to 1.
        FileNotFoundError: If *images_dir* or *labels_dir* does not exist.
    """
    ratios = (train_ratio, valid_ratio, test_ratio)
    # Raise explicitly rather than ``assert``: assertions are stripped when
    # Python runs with -O, which would silently disable this validation.
    # ``not (... < eps)`` also rejects NaN ratios.
    if not abs(sum(ratios) - 1.0) < 1e-6:
        raise ValueError("Ratios must sum to 1.")
    if any(ratio < 0 for ratio in ratios):
        raise ValueError("Ratios must be non-negative.")

    # Read the inputs before touching the output tree, so a wrong input path
    # fails without leaving empty split folders behind.
    image_files = sorted(
        f for f in os.listdir(images_dir)
        if f.lower().endswith(('.jpg', '.png', '.jpeg'))
    )
    # A set gives O(1) membership tests while pairing.
    label_names = {
        f for f in os.listdir(labels_dir) if f.lower().endswith('.txt')
    }

    # Keep only images whose label file exists; order follows the sorted
    # image names so the shuffle below is reproducible.
    paired_files = []
    for img in image_files:
        lbl = f"{os.path.splitext(img)[0]}.txt"
        if lbl in label_names:
            paired_files.append((img, lbl))

    print(f"Found {len(paired_files)} matched image-label pairs.")

    # A private generator yields the same permutation as seeding the module
    # RNG would, without resetting the process-wide random state.
    random.Random(seed).shuffle(paired_files)

    # Split boundaries; whatever is left after train and valid goes to test.
    total = len(paired_files)
    train_end = int(train_ratio * total)
    valid_end = train_end + int(valid_ratio * total)

    splits = {
        'train': paired_files[:train_end],
        'valid': paired_files[train_end:valid_end],
        'test': paired_files[valid_end:],
    }

    for split_name, file_list in splits.items():
        images_out = os.path.join(output_root, split_name, 'images')
        labels_out = os.path.join(output_root, split_name, 'labels')
        # Folders are created even for an empty split.
        os.makedirs(images_out, exist_ok=True)
        os.makedirs(labels_out, exist_ok=True)
        for img_file, lbl_file in file_list:
            # copy2 also preserves file metadata such as timestamps.
            shutil.copy2(os.path.join(images_dir, img_file),
                         os.path.join(images_out, img_file))
            shutil.copy2(os.path.join(labels_dir, lbl_file),
                         os.path.join(labels_out, lbl_file))

    print("Split completed:")
    print(f"  Train: {len(splits['train'])}")
    print(f"  Valid: {len(splits['valid'])}")
    print(f"  Test : {len(splits['test'])}")

# ==== Example Usage ====
# The paths below are absolute locations on the original author's machine;
# adjust them before running elsewhere. The split folders (train/valid/test)
# are created directly under output_root.
example_kwargs = {
    'images_dir': 'C:/Users/admin/Documents/GitHub/Pill_Identification/Pill_Jpeg_Processed',
    'labels_dir': 'C:/Users/admin/Documents/GitHub/Pill_Identification/Pill_YOLO_Labels',
    'output_root': 'C:/Users/admin/Documents/GitHub/Pill_Identification',
    'train_ratio': 0.7,
    'valid_ratio': 0.2,
    'test_ratio': 0.1,
}
split_dataset(**example_kwargs)


# Output:
# Found 205 matched image-label pairs.
# Split completed:
#   Train: 143
#   Valid: 41
#   Test : 21
