In [3]:
import os
import shutil
import random

In [4]:
def split_dataset(
    images_dir,
    labels_dir,
    output_dir,
    splits=None,
    create_splits=None,
    seed=42
):
    """Shuffle an image/label dataset and copy it into per-split folders.

    Every entry of ``images_dir`` whose name ends in ``.jpg``, ``.jpeg`` or
    ``.png`` (case-insensitive) is paired with ``<stem>.txt`` in
    ``labels_dir``. The images are sorted, shuffled deterministically with
    ``seed``, and cut into consecutive slices, one per name in
    ``create_splits``. Each image/label pair is copied with ``shutil.copy2``
    to ``<output_dir>/<split>/images`` and ``<output_dir>/<split>/labels``.

    Args:
        images_dir: Directory containing the source images.
        labels_dir: Directory containing one ``<stem>.txt`` per image.
        output_dir: Root directory under which the split folders are created.
        splits: Mapping of split name -> fraction of the dataset.
            Defaults to ``{"train": 0.8, "val": 0.2}``.
        create_splits: Ordered split names to materialize. A name absent from
            ``splits`` gets fraction 0. The LAST split also receives every
            image not claimed by earlier splits: the rounding remainder, plus
            any leftover if the fractions sum to less than 1.
            Defaults to ``["train", "val"]``.
        seed: Seed for the shuffle. A private ``random.Random`` is used, so
            the global ``random`` module state is left untouched.

    Returns:
        dict mapping each split name to the list of image filenames copied
        into it, in shuffled order.

    Raises:
        ValueError: If a fraction is negative or the fractions of
            ``create_splits`` sum to more than 1.
        FileNotFoundError: If an image has no label file. This is checked
            before anything is created or copied, so no partial output is
            left behind.
    """
    # None sentinels avoid shared mutable default arguments.
    if splits is None:
        splits = {"train": 0.8, "val": 0.2}
    if create_splits is None:
        create_splits = ["train", "val"]

    images = sorted(
        name for name in os.listdir(images_dir)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    total = len(images)
    print(f"Total images found: {total}")

    # Sorting first makes the shuffle depend only on `seed`, not on the
    # filesystem's listdir order.
    random.Random(seed).shuffle(images)

    ranges = _split_ranges(splits, create_splits, total)

    # Validate every label up front, so a missing one fails before we copy.
    missing = [
        img
        for _, start, end in ranges
        for img in images[start:end]
        if not os.path.isfile(os.path.join(labels_dir, _label_name(img)))
    ]
    if missing:
        raise FileNotFoundError(
            f"{len(missing)} image(s) have no label file in {labels_dir!r}, "
            f"e.g. {missing[:5]}"
        )

    for split, start, end in ranges:
        image_out = os.path.join(output_dir, split, "images")
        label_out = os.path.join(output_dir, split, "labels")
        os.makedirs(image_out, exist_ok=True)
        os.makedirs(label_out, exist_ok=True)
        for img_file in images[start:end]:
            label_file = _label_name(img_file)
            shutil.copy2(os.path.join(images_dir, img_file), os.path.join(image_out, img_file))
            shutil.copy2(os.path.join(labels_dir, label_file), os.path.join(label_out, label_file))
        print(f"{split}: {end-start} images")

    return {split: images[start:end] for split, start, end in ranges}


def _label_name(img_file):
    """Return the label filename (``<stem>.txt``) paired with an image file."""
    return os.path.splitext(img_file)[0] + ".txt"


def _split_ranges(splits, create_splits, total):
    """Return ``[(name, start, end), ...]`` slice bounds over ``total`` items.

    Slices are consecutive and non-overlapping. The last one is extended to
    ``total`` so no item is dropped.
    """
    fractions = [splits.get(name, 0) for name in create_splits]
    if any(frac < 0 for frac in fractions):
        raise ValueError(f"split fractions must be non-negative, got {fractions}")
    # Small tolerance: e.g. 0.7 + 0.2 + 0.1 is 0.9999999999999999 in floats.
    if sum(fractions) > 1 + 1e-9:
        raise ValueError(f"split fractions sum to {sum(fractions)}, which exceeds 1")

    ranges = []
    start = 0
    for name, frac in zip(create_splits, fractions):
        # round(..., 6) strips float noise before flooring. For example,
        # 0.29 * 100 == 28.999999999999996 would otherwise floor to 28.
        count = int(round(frac * total, 6))
        end = min(start + count, total)
        ranges.append((name, start, end))
        start = end

    # The last split absorbs whatever earlier splits did not claim.
    if ranges:
        last_name, last_start, _ = ranges[-1]
        ranges[-1] = (last_name, last_start, total)
    return ranges

In [6]:
if __name__ == "__main__":
    # Example: split the candy dataset 70/20/10 into train/val/test.
    # The split order follows the dict's insertion order.
    candy_ratios = {"train": 0.7, "val": 0.2, "test": 0.1}
    split_dataset(
        images_dir="./data/candy_data/images",
        labels_dir="./data/candy_data/labels",
        output_dir="output",
        splits=candy_ratios,
        create_splits=list(candy_ratios),
    )

Total images found: 162
train: 113 images
val: 32 images
test: 17 images
