## 準備:
- 資料夾結構：
        |-- ori_photo
            |-- class_0
            |-- class_1
            |-- class_2
            |-- class_3
        |-- data_to_train+val.ipynb

## 結果:
- 生成的資料夾:
        |-- data
            |-- train
                |-- class_0
                |-- class_1
                |-- class_2
                |-- class_3
            |-- val
                |-- class_0
                |-- class_1
                |-- class_2
                |-- class_3

In [1]:
import os
from shutil import copy, rmtree
import random


def make_dir(file_path):
    """Create ``file_path`` as an empty directory.

    If something already exists at ``file_path`` it is removed with
    ``rmtree`` first, so any previous contents are lost. Missing parent
    directories are created as needed.
    """
    if not os.path.exists(file_path):
        # Nothing to clear: just build the directory (and its parents).
        os.makedirs(file_path)
        return
    # Start from a clean slate: drop the old tree, then recreate it empty.
    rmtree(file_path)
    os.makedirs(file_path)


def split_data(input_file_path, output_file_path, split_rate, seed='random'):
    """Copy a class-per-folder dataset into ``train`` and ``val`` subsets.

    Every sub-directory directly under ``input_file_path`` is treated as one
    class. For each class, ``int(file_count * split_rate)`` files are drawn
    at random for the validation set and the remaining files go to the
    training set. Files are copied, so the input dataset is not modified.

    Args:
        input_file_path: Root of the source dataset. Relative paths are
            resolved against the current working directory.
        output_file_path: Root of the generated dataset, resolved the same
            way. Its ``train`` and ``val`` sub-directories are deleted and
            recreated on every call.
        split_rate: Fraction of each class's files placed in ``val``.
        seed: ``'fixed'`` seeds the random generator with 0 so repeated runs
            pick the same files; any other value reseeds it from the
            system's default entropy source.

    Raises:
        AssertionError: If the input dataset path does not exist.
        ValueError: Propagated from ``random.sample`` when ``split_rate``
            yields a sample size that is negative or larger than the number
            of files in a class.
    """
    if seed == 'fixed':
        random.seed(0)
    else:
        random.seed()

    # Resolve both roots against the current working directory.
    cwd = os.getcwd()
    input_dataset_path = os.path.join(cwd, input_file_path)
    output_dataset_path = os.path.join(cwd, output_file_path)
    assert os.path.exists(input_dataset_path), f"path '{input_dataset_path}' does not exist."

    # os.listdir returns entries in arbitrary order; sorting keeps the order
    # in which classes consume random numbers stable, which the 'fixed' seed
    # needs in order to be reproducible.
    dataset_classes = sorted(
        entry for entry in os.listdir(input_dataset_path)
        if os.path.isdir(os.path.join(input_dataset_path, entry))
    )

    # (Re)create an empty train/ and val/ tree with one folder per class.
    train_path = os.path.join(output_dataset_path, 'train')
    val_path = os.path.join(output_dataset_path, 'val')
    for subset_path in (train_path, val_path):
        make_dir(subset_path)
        for dataset_class in dataset_classes:
            make_dir(os.path.join(subset_path, dataset_class))

    for dataset_class in dataset_classes:
        input_dataset_class_path = os.path.join(input_dataset_path, dataset_class)
        # Only regular files can be copied; nested directories are skipped.
        # Sorted for the same reproducibility reason as the class list.
        images = sorted(
            entry for entry in os.listdir(input_dataset_class_path)
            if os.path.isfile(os.path.join(input_dataset_class_path, entry))
        )
        # A set gives O(1) membership tests in the copy loop below.
        val_images = set(random.sample(images, k=int(len(images) * split_rate)))
        for image in images:
            subset_path = val_path if image in val_images else train_path
            copy(os.path.join(input_dataset_class_path, image),
                 os.path.join(subset_path, dataset_class))
    print('Process Finished.')


In [2]:
if __name__ == '__main__':
    # Source dataset: one sub-folder per class, relative to the working directory.
    original_data_file_path = 'ori_photo'
    # Output root that receives the generated train/ and val/ folders.
    spilit_data_file_path = 'data'
    # Fraction of each class's files copied into the validation set.
    split_rate = 0.2
    split_data(
        input_file_path=original_data_file_path,
        output_file_path=spilit_data_file_path,
        split_rate=split_rate,
    )

Process Finished.
