# 划分训练集和测试集

同济子豪兄 https://space.bilibili.com/1900783

代码运行[云GPU平台](https://featurize.cn/?s=d7ce99f842414bfcaea5662a97581bd1)

2022-7-22

## 导入工具包

In [2]:
import os
import shutil
import random
import pandas as pd

## 删除其它文件夹（如有）

In [1]:
!rm -rf fruit21_full
!rm -rf fruit21_split
!rm -rf melon17-full
!rm -rf __MACOSX

## 下载数据集

In [10]:
# 下载压缩包
# 如报错 Unable to establish SSL connection. 重新运行本代码块即可
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/melon17/melon17_full.zip

--2022-07-31 17:53:20--  https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/melon17/melon17_full.zip
Connecting to 172.16.0.13:5848... connected.
Proxy request sent, awaiting response... 200 OK
Length: 132464836 (126M) [application/zip]
Saving to: ‘melon17_full.zip’


2022-07-31 17:53:24 (54.4 MB/s) - ‘melon17_full.zip’ saved [132464836/132464836]



In [22]:
# 解压
!unzip melon17_full.zip >> /dev/null
# !unzip melon17_full.zip

## 获得所有类别名称

In [23]:
# 指定数据集路径
dataset_path = 'melon17_full'

In [24]:
dataset_name = dataset_path.split('_')[0]
print('数据集', dataset_name)

数据集 melon17


In [25]:
classes = os.listdir(dataset_path)

In [26]:
classes

['苦瓜',
 '冬瓜',
 '南瓜',
 '人参果',
 '羊角蜜',
 '哈密瓜',
 '白兰瓜',
 '西瓜',
 '佛手瓜',
 '丝瓜',
 '西葫芦',
 '甜瓜-白',
 '甜瓜-伊丽莎白',
 '黄瓜',
 '甜瓜-金',
 '甜瓜-绿',
 '木瓜']

## 创建训练集文件夹和测试集文件夹

In [27]:
# 创建 train 文件夹
os.mkdir(os.path.join(dataset_path, 'train'))

# 创建 test 文件夹
os.mkdir(os.path.join(dataset_path, 'val'))

# 在 train 和 test 文件夹中创建各类别子文件夹
for fruit in classes:
    os.mkdir(os.path.join(dataset_path, 'train', fruit))
    os.mkdir(os.path.join(dataset_path, 'val', fruit))

## 划分训练集、测试集，移动文件

In [28]:
test_frac = 0.2  # 测试集比例
random.seed(123) # 随机数种子

In [29]:
df = pd.DataFrame()

print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))

for fruit in classes: # 遍历每个类别

    # 读取该类别的所有图像文件名
    old_dir = os.path.join(dataset_path, fruit)
    images_filename = os.listdir(old_dir)
    random.shuffle(images_filename) # 随机打乱

    # 划分训练集和测试集
    testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数
    testset_images = images_filename[:testset_numer]      # 获取拟移动至 test 目录的测试集图像文件名
    trainset_images = images_filename[testset_numer:]     # 获取拟移动至 train 目录的训练集图像文件名

    # 移动图像至 test 目录
    for image in testset_images:
        old_img_path = os.path.join(dataset_path, fruit, image)         # 获取原始文件路径
        new_test_path = os.path.join(dataset_path, 'val', fruit, image) # 获取 test 目录的新文件路径
        shutil.move(old_img_path, new_test_path) # 移动文件

    # 移动图像至 train 目录
    for image in trainset_images:
        old_img_path = os.path.join(dataset_path, fruit, image)           # 获取原始文件路径
        new_train_path = os.path.join(dataset_path, 'train', fruit, image) # 获取 train 目录的新文件路径
        shutil.move(old_img_path, new_train_path) # 移动文件
    
    # 删除旧文件夹
    assert len(os.listdir(old_dir)) == 0 # 确保旧文件夹中的所有图像都被移动走
    shutil.rmtree(old_dir) # 删除文件夹
    
    # 工整地输出每一类别的数据个数
    print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))
    
    # 保存到表格中
    df = df.append({'class':fruit, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)

# 重命名数据集文件夹
shutil.move(dataset_path, dataset_name+'_split')

# 数据集各类别数量统计表格，导出为 csv 文件
df['total'] = df['trainset'] + df['testset']
df.to_csv('数据量统计.csv', index=False)

        类别              训练集数据个数            测试集数据个数      
        苦瓜                152                 37        
        冬瓜                124                 31        
        南瓜                148                 37        
       人参果                147                 36        
       羊角蜜                158                 39        
       哈密瓜                158                 39        
       白兰瓜                104                 25        
        西瓜                159                 39        
       佛手瓜                130                 32        
        丝瓜                152                 38        
       西葫芦                137                 34        
       甜瓜-白                69                 17        
     甜瓜-伊丽莎白               77                 19        
        黄瓜                145                 36        
       甜瓜-金                43                 10        
       甜瓜-绿                36                 8         
        木瓜                156  

In [30]:
df

Unnamed: 0,class,trainset,testset,total
0,苦瓜,152.0,37.0,189.0
1,冬瓜,124.0,31.0,155.0
2,南瓜,148.0,37.0,185.0
3,人参果,147.0,36.0,183.0
4,羊角蜜,158.0,39.0,197.0
5,哈密瓜,158.0,39.0,197.0
6,白兰瓜,104.0,25.0,129.0
7,西瓜,159.0,39.0,198.0
8,佛手瓜,130.0,32.0,162.0
9,丝瓜,152.0,38.0,190.0


## 查看文件目录结构

In [None]:
!sudo snap install tree

In [31]:
!tree melon17_split -L 2

[01;34mmelon17_split[00m
├── [01;34mtrain[00m
│   ├── [01;34m丝瓜[00m
│   ├── [01;34m人参果[00m
│   ├── [01;34m佛手瓜[00m
│   ├── [01;34m冬瓜[00m
│   ├── [01;34m南瓜[00m
│   ├── [01;34m哈密瓜[00m
│   ├── [01;34m木瓜[00m
│   ├── [01;34m甜瓜-伊丽莎白[00m
│   ├── [01;34m甜瓜-白[00m
│   ├── [01;34m甜瓜-绿[00m
│   ├── [01;34m甜瓜-金[00m
│   ├── [01;34m白兰瓜[00m
│   ├── [01;34m羊角蜜[00m
│   ├── [01;34m苦瓜[00m
│   ├── [01;34m西瓜[00m
│   ├── [01;34m西葫芦[00m
│   └── [01;34m黄瓜[00m
└── [01;34mval[00m
    ├── [01;34m丝瓜[00m
    ├── [01;34m人参果[00m
    ├── [01;34m佛手瓜[00m
    ├── [01;34m冬瓜[00m
    ├── [01;34m南瓜[00m
    ├── [01;34m哈密瓜[00m
    ├── [01;34m木瓜[00m
    ├── [01;34m甜瓜-伊丽莎白[00m
    ├── [01;34m甜瓜-白[00m
    ├── [01;34m甜瓜-绿[00m
    ├── [01;34m甜瓜-金[00m
    ├── [01;34m白兰瓜[00m
    ├── [01;34m羊角蜜[00m
    ├── [01;34m苦瓜[00m
    ├── [01;34m西瓜[00m
    ├── [01;34m西葫芦[00m
    └── [01;34m黄瓜[00m

36 directories, 0 files
