# Split the dataset into train/validation/test sets

In [1]:
import os
import random
import shutil
from pathlib import Path
from tqdm import tqdm

In [2]:
def split_dataset(
    image_dir,
    mask_dir=None,
    output_dir="data/split-dataset",
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
    seed=42,
    mask_suffix="-mask.png",
    image_extensions=(".png", ".jpg", ".jpeg"),
):
    """
    Splits images (and corresponding masks, if provided) into train/val/test directories.

    Images are collected (non-recursively) from ``image_dir``, sorted by path so the
    starting order does not depend on filesystem listing order, then shuffled with a
    private RNG seeded by ``seed``. The first ``int(total * train_ratio)`` images go to
    ``train``, the next ``int(total * val_ratio)`` to ``val``, and the remainder to
    ``test`` (so ``test`` absorbs any rounding leftovers).

    Output layout::

        output_dir/<split>/img/<image file name>
        output_dir/<split>/mask/<image stem><mask_suffix>   (only when mask_dir is given)

    Existing output directories are reused; files already in them are kept and
    same-named files are overwritten.

    Args:
        image_dir (str | Path): Directory containing the images.
        mask_dir (str | Path | None): Directory containing masks, or None/"" to copy
            images only.
        output_dir (str | Path): Root directory for the output split folders.
        train_ratio (float): Proportion of images for training.
        val_ratio (float): Proportion of images for validation.
        test_ratio (float): Proportion of images for testing. Only used to check that
            the three ratios sum to 1; ``test`` receives whatever is left.
        seed (int): Seed for the shuffle, for reproducibility.
        mask_suffix (str): Appended to an image's stem to form its mask file name
            (e.g. ``img01.jpg`` -> ``img01-mask.png`` with the default).
        image_extensions (tuple[str, ...]): File suffixes (including the dot) treated
            as images; matched case-insensitively.

    Raises:
        ValueError: If any ratio is negative or the ratios do not sum to 1.
    """
    ratios = (train_ratio, val_ratio, test_ratio)
    # Explicit raises instead of ``assert``: asserts are stripped under ``python -O``.
    if any(r < 0 for r in ratios):
        raise ValueError("Ratios must be non-negative")
    # Written as ``not (... < tol)`` so NaN ratios are rejected too.
    if not abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6:
        raise ValueError("Ratios must sum to 1")

    # A private RNG yields the same permutation as ``random.seed(seed)`` followed by
    # ``random.shuffle(...)``, without resetting the global ``random`` state for the
    # rest of the program.
    rng = random.Random(seed)

    image_dir = Path(image_dir)
    mask_dir = Path(mask_dir) if mask_dir else None
    output_dir = Path(output_dir)

    extensions = {ext.lower() for ext in image_extensions}
    # ``is_file()`` skips sub-directories whose names happen to end in an image suffix
    # (shutil.copy would fail on them).
    image_files = sorted(
        f for f in image_dir.iterdir()
        if f.is_file() and f.suffix.lower() in extensions
    )
    rng.shuffle(image_files)

    total = len(image_files)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)

    subsets = {
        "train": image_files[:train_end],
        "val": image_files[train_end:val_end],
        "test": image_files[val_end:],
    }

    for split_name, file_list in subsets.items():
        split_img_dir = output_dir / split_name / "img"
        split_img_dir.mkdir(parents=True, exist_ok=True)

        split_mask_dir = None
        if mask_dir is not None:
            split_mask_dir = output_dir / split_name / "mask"
            split_mask_dir.mkdir(parents=True, exist_ok=True)

        print(f"Copying {split_name} set ({len(file_list)} files)...")

        for img_path in tqdm(file_list, desc=f"{split_name.capitalize():>5}", unit="file"):
            shutil.copy(img_path, split_img_dir / img_path.name)

            if mask_dir is None:
                continue

            mask_name = img_path.stem + mask_suffix
            mask_path = mask_dir / mask_name
            if mask_path.is_file():
                shutil.copy(mask_path, split_mask_dir / mask_name)
            else:
                # tqdm.write prints above the bar instead of corrupting the active
                # progress-bar line the way a plain print() does.
                tqdm.write(f"[WARNING] Missing mask for: {img_path.name}")

In [3]:
# Shared path prefixes for the v5 damaged-image dataset and its split output.
V5_DATA_ROOT = "../../data"
V5_SOURCE_DATASET = f"{V5_DATA_ROOT}/v5-damaged-and-mask-dataset"

# Split with the function's default ratios (0.7 / 0.15 / 0.15) and seed (42).
split_dataset(
    image_dir=f"{V5_SOURCE_DATASET}/generated-damaged-images",
    mask_dir=f"{V5_SOURCE_DATASET}/generated-damage-masks",
    output_dir=f"{V5_DATA_ROOT}/v5-split-dataset",
)

Copying train set (24206 files)...


Train: 100%|██████████| 24206/24206 [02:45<00:00, 146.59file/s]


Copying val set (5187 files)...


  Val: 100%|██████████| 5187/5187 [00:39<00:00, 130.90file/s]


Copying test set (5188 files)...


 Test: 100%|██████████| 5188/5188 [00:40<00:00, 127.77file/s]
