# 划分训练集和测试集
# Delineate training and test sets

## 导入工具包
## Import Toolkit

In [4522]:
import os
import shutil
import random
import pandas as pd

## 获得所有类别名称
## Get all category names

In [4523]:
# 指定数据集路径/Specify the dataset path
dataset_path = '../3-【Pytorch】迁移学习训练自己的图像分类模型/data/r-ds'#R-X/R-100

In [4524]:
dataset_name = dataset_path.split('_')[0]
print('数据集', dataset_name)

数据集 ../3-【Pytorch】迁移学习训练自己的图像分类模型/data/r-ds


In [4525]:
classes = os.listdir(dataset_path)

In [4526]:
len(classes)

114

In [4527]:
classes

['snowman',
 'sun',
 'cake',
 'submarine',
 'suitcase',
 'skull',
 'trombone',
 'bridge',
 'flower',
 'dog',
 'axe',
 'piano',
 'campfire',
 'banana',
 'church',
 'lantern',
 'basketball',
 'spider',
 'skateboard',
 'parrot',
 'envelope',
 'tornado',
 'asparagus',
 'strawberry',
 'drums',
 'truck',
 'penguin',
 'apple',
 'rhinoceros',
 'flying-saucer',
 'kangaroo',
 'dragon',
 'horse',
 'bat',
 'swan',
 'onion',
 'owl',
 'cloud',
 'butterfly',
 'harp',
 'car',
 'tractor',
 'pig',
 'cat',
 'mountain',
 'hockey-puck',
 'monkey',
 'bicycle',
 'lighthouse',
 'cups',
 'ocean',
 'anvil',
 'hat',
 'circle',
 'microphone',
 'stethoscope',
 'hourglass',
 'lightning',
 'sheep',
 'fish',
 'light-bulb',
 'leaf',
 'mushroom',
 'guitar',
 'soccer-ball',
 'yoga',
 'watermelon',
 'crab',
 'airplane',
 'blueberry',
 'ant',
 'windmill',
 'candle',
 'saxophone',
 'clock',
 'rain',
 'lollipop',
 'sailboat',
 'map',
 'duck',
 'brain',
 'slippers',
 'vase',
 'tent',
 'train',
 'camel',
 'mouse',
 'triangle'

## 创建训练集文件夹和测试集文件夹
## Create training set folder and test set folder

In [4528]:
# 训练和测试集名/Training and test set names
train = 'tra'
val = 'val'

In [4529]:
# 创建 train 文件夹/Create train folder
os.mkdir(os.path.join(dataset_path, train))

# 创建 test 文件夹/Create test folder
os.mkdir(os.path.join(dataset_path, val))

# 在 train 和 test 文件夹中创建各类别子文件夹/Create subfolders for each category in the train and test folders
for fruit in classes:
    os.mkdir(os.path.join(dataset_path, train, fruit))
    os.mkdir(os.path.join(dataset_path, val, fruit))

## 划分训练集、测试集，移动文件
## Delineate training set, test set, move file

In [4530]:
test_frac = 0.2# 测试集比例/Test set ratio
random.seed(123) # 随机数种子，便于复现/Random number seeds for easy reproduction

In [4531]:
df = pd.DataFrame()

print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))

for fruit in classes: # 遍历每个类别/Iterate through each category

    # 读取该类别的所有图像文件名/Retrieve the names of all image files in this category
    old_dir = os.path.join(dataset_path, fruit)
    images_filename = os.listdir(old_dir)
    random.shuffle(images_filename) # 随机打乱/randomly upset

    # 划分训练集和测试集/Divide the training set and test set
    testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数/Number of images in the test set
    # 获取拟移动至 test 目录的测试集图像文件名/Get the name of the test set image file to be moved to the test directory
    testset_images = images_filename[:testset_numer]      
    # 获取拟移动至 train 目录的训练集图像文件名/Get the filename of the training set image to be moved to the train directory
    trainset_images = images_filename[testset_numer:]     

    # 移动图像至 test 目录/Moving images to the test catalog
    for image in testset_images:
        old_img_path = os.path.join(dataset_path, fruit, image) # 获取原始文件路径/Get original file path
        # 获取 test 目录的新文件路径/Get the path to the new file in the test directory
        new_test_path = os.path.join(dataset_path, val, fruit, image) 
        shutil.move(old_img_path, new_test_path) # 移动文件/Moving files

    # 移动图像至 train 目录/Moving images to the train catalog
    for image in trainset_images:
        old_img_path = os.path.join(dataset_path, fruit, image) # 获取原始文件路径/Get original file path
        # 获取 train 目录的新文件路径/Get the path to the new file in the train directory
        new_train_path = os.path.join(dataset_path, train, fruit, image) 
        shutil.move(old_img_path, new_train_path) # 移动文件/Moving files
    
    # 删除旧文件夹/Deleting old folders
    # 确保旧文件夹中的所有图像都被移动走/Make sure all the images in the old folder are moved away
    assert len(os.listdir(old_dir)) == 0 
    shutil.rmtree(old_dir) # 删除文件夹/Delete Folder
    
    # 工整地输出每一类别的数据个数/neatly output the number of data in each category
    print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))
    
    # 保存到表格中/Save to table
    df = df.append({'class':fruit, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)

# 重命名数据集文件夹/Rename the dataset folder
shutil.move(dataset_path, dataset_name + '_' + 'split')

# 数据集各类别数量统计表格，导出为 csv 文件/Table of statistics on the number of categories in the dataset, exported as a csv file
df['total'] = df['trainset'] + df['testset']
df.to_csv('数据量统计.csv', index=False)

        类别              训练集数据个数            测试集数据个数      
     snowman              100                 25        
       sun                100                 25        
       cake               100                 25        
    submarine             100                 25        
     suitcase             100                 25        
      skull               100                 25        
     trombone             100                 25        
      bridge              100                 25        
      flower              100                 25        
       dog                100                 25        
       axe                100                 25        
      piano               100                 25        
     campfire             100                 25        
      banana              100                 25        
      church              100                 25        
     lantern              100                 25        
    basketball            100  

In [4532]:
df

Unnamed: 0,class,trainset,testset,total
0,snowman,100.0,25.0,125.0
1,sun,100.0,25.0,125.0
2,cake,100.0,25.0,125.0
3,submarine,100.0,25.0,125.0
4,suitcase,100.0,25.0,125.0
...,...,...,...,...
109,snake,100.0,25.0,125.0
110,elephant,100.0,25.0,125.0
111,tiger,100.0,25.0,125.0
112,The-Great-Wall-of-China,100.0,25.0,125.0


## 查看文件目录结构
## Viewing file directory structure

In [4533]:
#!sudo snap install tree

In [4534]:
#!tree fruit81_split -L 2