In [None]:
import os
import shutil
import random
from tqdm import tqdm
from typing import List, Dict


def split_dataset_ss(source_dir: str, target_dir: str, train_ratio: float = 0.8, val_ratio: float = 0.5):

    # defining train, test, val directories
    train_dir = os.path.join(target_dir, 'train')
    test_dir = os.path.join(target_dir, 'test')
    val_dir = os.path.join(target_dir, 'val')

    # creating base dir
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir) 

    os.makedirs(train_dir, exist_ok=False)
    os.makedirs(test_dir, exist_ok=False)
    os.makedirs(val_dir, exist_ok=False)

    # get class folders from source
    class_folders = [name for name in os.listdir(source_dir)
                     if os.path.isdir(os.path.join(source_dir, name))]

    # splitting criteria
    for class_name in tqdm(class_folders, desc=f"Splitting classes from {os.path.basename(source_dir)}"):
        class_path = os.path.join(source_dir, class_name)

        # getting all image files in the current class folder
        images = [f for f in os.listdir(class_path) if os.path.isfile(
            os.path.join(class_path, f))]
        random.shuffle(images)
        total_count = len(images)

        if total_count < 1:
            continue

        # splitting for training images
        train_split_point = int(total_count * train_ratio)
        train_images = images[:train_split_point]
        temp_images = images[train_split_point:]

        # splitting for testing, validation images
        val_split_point = int(len(temp_images) * val_ratio)
        val_images = temp_images[:val_split_point]
        test_images = temp_images[val_split_point:]

        splits: Dict[str, List[str]] = {
            'train': train_images,
            'val': val_images,
            'test': test_images
        }

        # copy files to target dir
        for split_name, image_list in splits.items():

            # example target path : data/train/sugarcane_yellow disease
            target_class_dir = os.path.join(target_dir, split_name, class_name)
            os.makedirs(target_class_dir, exist_ok=True)

            for img in image_list:
                src = os.path.join(class_path, img)
                dst = os.path.join(target_class_dir, img)
                shutil.copy(src, dst)

    print(f"\n Split completed. Check the structure in: {target_dir}")
    print(f"   Train: {int(train_ratio*100)}%, Val: {int((1-train_ratio)*val_ratio*100)}%, Test: {int((1-train_ratio)*(1-val_ratio)*100)}%")

    

In [2]:
source='ss'
target='data_ss'


split_dataset_ss(source, target)

Splitting classes from ss: 100%|██████████| 23/23 [01:08<00:00,  2.99s/it]


✅ Split completed. Check the structure in: data_ss
   Train: 80%, Val: 9%, Test: 9%



