In [3]:
import os
import torch
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm
import math


def data_split(src_data_path, target_data_path, train_scale, val_scale, test_scale, num_workers, img_format):
    """Randomly split an image-folder dataset into train/test/val directories.

    The source directory is read with ``torchvision.datasets.ImageFolder``
    (one sub-directory per class). Samples are drawn in shuffled order across
    the whole dataset (not per class) and written to
    ``<target_data_path>/<subset>/<class>/<n>.<img_format>``, where ``n`` is
    the 1-based position of the sample in the shuffled stream.

    Quotas are filled in the order train -> test -> val. Each quota is
    ``ceil(total * scale)``, with test and val clamped so the three quotas
    never exceed the number of remaining images.

    Note that images are decoded to tensors and re-encoded by ``save_image``;
    the original files are not copied byte-for-byte.

    Args:
        src_data_path: Root directory of the source dataset.
        target_data_path: Root directory that receives the split.
        train_scale: Fraction of images assigned to the train subset.
        val_scale: Fraction of images assigned to the val subset.
        test_scale: Fraction of images assigned to the test subset.
        num_workers: Number of ``DataLoader`` worker processes.
        img_format: File extension (without dot) used for the saved images.
    """
    data = datasets.ImageFolder(src_data_path, transforms.ToTensor())
    class_name = list(data.class_to_idx.keys())
    image_size = len(data)
    print("总计:" + str(image_size) + "it")

    train_size = math.ceil(image_size * train_scale)
    test_size = min(image_size - train_size, math.ceil(image_size * test_scale))
    val_size = min(image_size - train_size - test_size, math.ceil(image_size * val_scale))

    # Create the per-class output directories only for subsets that were
    # requested. exist_ok avoids the check-then-create race and makes reruns
    # into an existing target directory safe.
    for subset, scale in (('train', train_scale), ('test', test_scale), ('val', val_scale)):
        if not scale:
            continue
        for name in class_name:
            os.makedirs(os.path.join(target_data_path, subset, name), exist_ok=True)

    # One entry per sample to export, in quota order. Multiplying a list by a
    # non-positive count yields an empty list, so exhausted or negative
    # quotas simply contribute nothing.
    schedule = ['train'] * train_size + ['test'] * test_size + ['val'] * val_size

    loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=True, num_workers=num_workers)

    # zip() stops as soon as the schedule is exhausted, so no further images
    # are loaded and decoded once every quota has been filled.
    progress = tqdm(zip(schedule, loader), total=min(len(schedule), image_size))
    for index, (subset, (image, label)) in enumerate(progress, start=1):
        file_name = str(index) + '.' + img_format
        save_image(image,
                   os.path.join(target_data_path, subset, class_name[label.item()], file_name))

    print("切分完成\n保存路径为：" + target_data_path)


if __name__ == '__main__':
    # Script entry point: split the source image folder 80/20 into train and
    # test subsets (no validation subset) and save the images as JPEG files.
    split_options = dict(
        src_data_path='./Attachment_2/',
        target_data_path='./fruit/',
        train_scale=0.8,
        test_scale=0.2,
        val_scale=0.0,
        num_workers=12,
        img_format='jpg',
    )
    data_split(**split_options)

总计:20705it


20705it [02:08, 161.30it/s]

切分完成
保存路径为：./fruit/



