# In [None]:
from pathlib import Path
import random
import shutil
from collections import defaultdict

# ====== USER SETTINGS ======
# Root of the unsplit dataset. main() treats every sub-folder as one identity
# and uses the folder name as that identity's ID.
RAW_DATASET = r"D:\Abhishek\base_256"
OUTPUT_DATASET = r"D:\Abhishek\split"  # where to save train/val/test
# Fractions of each identity's images assigned to train and val. main() rounds
# n * ratio and keeps at least one image in each of the two splits.
TRAIN_RATIO = 0.6
VAL_RATIO = 0.2
# NOTE: main() never reads TEST_RATIO; the test split receives whatever is
# left after train and val have been taken.
TEST_RATIO = 0.2
SEED = 42  # seed passed to random.seed() below
MIN_PER_ID = 3  # identities with fewer images than this are skipped
COPY_FILES = True   # set to False to move instead of copy
# ===========================

# Lower-case file suffixes that list_images() accepts as images.
VALID_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
# Seed the module-level generator once, at import time, so that the shuffles
# in main() start from a known state.
random.seed(SEED)

def list_images(folder: Path):
    """Return the image files located directly inside ``folder``.

    An entry qualifies when its lower-cased suffix is in ``VALID_EXTS`` and it
    is a regular file. Sub-folders are not descended into, and the result
    keeps the order in which ``folder.iterdir()`` yields the entries.
    """
    found = []
    for entry in folder.iterdir():
        # Cheap suffix test first; the filesystem is only queried for matches.
        if entry.suffix.lower() not in VALID_EXTS:
            continue
        if entry.is_file():
            found.append(entry)
    return found

def safe_mkdir(p: Path):
    """Create directory ``p`` together with any missing parent directories.

    Calling it on a directory that already exists is a no-op.
    """
    p.mkdir(exist_ok=True, parents=True)

def _split_sizes(n, train_ratio, val_ratio):
    """Return ``(n_train, n_val)`` for an identity that has ``n`` images.

    Each size is ``round(n * ratio)`` but never less than one. When rounding
    leaves nothing for the test split, images are handed back one at a time,
    from train first and then from val, without shrinking either below one.
    The test split is whatever remains after ``n_train + n_val`` images.
    """
    n_train = max(1, round(n * train_ratio))
    n_val = max(1, round(n * val_ratio))
    while n_train + n_val >= n and (n_train > 1 or n_val > 1):
        if n_train > 1:
            n_train -= 1
        else:
            n_val -= 1
    return n_train, n_val

def main():
    """Split every identity folder of RAW_DATASET into train/val/test.

    Each sub-folder of ``RAW_DATASET`` is one identity. Its images (as found
    by ``list_images``) are shuffled and copied -- or moved when
    ``COPY_FILES`` is false -- into ``OUTPUT_DATASET/<split>/<identity>/``.
    Identities with fewer than ``MIN_PER_ID`` images are skipped with a
    ``[skip]`` message. ``TRAIN_RATIO`` and ``VAL_RATIO`` set the train and
    val sizes; the test split receives the remaining images.

    The shuffle uses a private generator seeded with ``SEED`` on a sorted
    file list, so calling ``main()`` again on the same input yields the same
    partition.

    Raises:
        FileNotFoundError: ``RAW_DATASET`` does not exist.
        NotADirectoryError: ``RAW_DATASET`` exists but is not a directory.
    """
    raw_root = Path(RAW_DATASET)
    out_root = Path(OUTPUT_DATASET)

    # Validate the input before creating anything at the output location.
    if not raw_root.exists():
        raise FileNotFoundError(f"RAW_DATASET does not exist: {raw_root}")
    if not raw_root.is_dir():
        raise NotADirectoryError(f"RAW_DATASET is not a directory: {raw_root}")

    # Prepare output dirs
    for split in ("train", "val", "test"):
        safe_mkdir(out_root / split)

    op = shutil.copy2 if COPY_FILES else shutil.move

    # A per-call generator keeps repeated main() calls from drifting apart;
    # with the shared module-level generator a second run would copy a
    # different partition on top of the first one and mix the splits.
    rng = random.Random(SEED)

    split_counts = {}

    for cid_dir in sorted(raw_root.iterdir()):
        if not cid_dir.is_dir():
            continue
        cid = cid_dir.name
        # Directory listing order is arbitrary, so sort before shuffling to
        # make the seeded shuffle independent of the filesystem.
        imgs = sorted(list_images(cid_dir))
        n = len(imgs)

        if n < MIN_PER_ID:
            print(f"[skip] {cid}: only {n} images (< {MIN_PER_ID})")
            continue

        rng.shuffle(imgs)
        n_train, n_val = _split_sizes(n, TRAIN_RATIO, VAL_RATIO)
        splits = {
            "train": imgs[:n_train],
            "val": imgs[n_train:n_train + n_val],
            "test": imgs[n_train + n_val:],
        }

        # Only reachable when MIN_PER_ID is set below 3: too few images to
        # give every split at least one.
        empty = [name for name, files in splits.items() if not files]
        if empty:
            print(f"[warn] {cid}: no images for {', '.join(empty)} (only {n} images)")

        for split_name, split_imgs in splits.items():
            dst_dir = out_root / split_name / cid
            safe_mkdir(dst_dir)
            for img in split_imgs:
                op(str(img), str(dst_dir / img.name))

        split_counts[cid] = {name: len(files) for name, files in splits.items()}

    print("\n[done] Dataset split complete!")
    for cid, counts in split_counts.items():
        print(f"{cid}: train={counts['train']} val={counts['val']} test={counts['test']}")

# Run the split only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()


# :