# 划分训练集和测试集

同济子豪兄 https://space.bilibili.com/1900783

2022-7-22

## 导入工具包

In [1]:
import os
import shutil
import random
import pandas as pd

## 删除其它文件夹（如有）

In [2]:
!rm -rf fruit21_full
!rm -rf fruit21_split

## 下载数据集

In [3]:
# 下载压缩包
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit21_full.zip -O fruit21_full.zip

--2022-07-23 10:46:38--  https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit21_full.zip
Connecting to 172.16.0.13:5848... connected.
Proxy request sent, awaiting response... 200 OK
Length: 206385999 (197M) [application/zip]
Saving to: ‘fruit21_full.zip’


2022-07-23 10:46:44 (33.7 MB/s) - ‘fruit21_full.zip’ saved [206385999/206385999]



In [4]:
# 解压
!unzip fruit21_full.zip >> /dev/null
# !unzip fruit21_full.zip

## 获得所有类别名称

In [5]:
# 指定数据集路径
dataset_path = 'fruit21_full'

In [6]:
classes = os.listdir(dataset_path)

In [7]:
classes

['圣女果',
 '芒果',
 '菠萝',
 '猕猴桃',
 '哈密瓜',
 '山楂',
 '脐橙',
 '杨梅',
 '草莓',
 '椰子',
 '西瓜',
 '桂圆',
 '荔枝',
 '香蕉',
 '水蜜桃',
 '柠檬',
 '砂糖橘',
 '樱桃',
 '榴莲',
 '西红柿',
 '油桃']

## 创建训练集文件夹和测试集文件夹

In [8]:
# 创建 train 文件夹
os.mkdir(os.path.join(dataset_path, 'train'))

# 创建 test 文件夹
os.mkdir(os.path.join(dataset_path, 'test'))

# 在 train 和 test 文件夹中创建各类别子文件夹
for fruit in classes:
    os.mkdir(os.path.join(dataset_path, 'train', fruit))
    os.mkdir(os.path.join(dataset_path, 'test', fruit))

## 划分训练集、测试集，移动文件

In [9]:
test_frac = 0.2  # 测试集比例
random.seed(123) # 随机数种子

In [10]:
df = pd.DataFrame()

print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))

for fruit in classes: # 遍历每个类别

    # 读取该类别的所有图像文件名
    old_dir = os.path.join(dataset_path, fruit)
    images_filename = os.listdir(old_dir)
    random.shuffle(images_filename) # 随机打乱

    # 划分训练集和测试集
    testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数
    testset_images = images_filename[:testset_numer]      # 获取拟移动至 test 目录的测试集图像文件名
    trainset_images = images_filename[testset_numer:]     # 获取拟移动至 train 目录的训练集图像文件名

    # 移动图像至 test 目录
    for image in testset_images:
        old_img_path = os.path.join(dataset_path, fruit, image)         # 获取原始文件路径
        new_test_path = os.path.join(dataset_path, 'test', fruit, image) # 获取 test 目录的新文件路径
        shutil.move(old_img_path, new_test_path) # 移动文件

    # 移动图像至 train 目录
    for image in trainset_images:
        old_img_path = os.path.join(dataset_path, fruit, image)           # 获取原始文件路径
        new_train_path = os.path.join(dataset_path, 'train', fruit, image) # 获取 train 目录的新文件路径
        shutil.move(old_img_path, new_train_path) # 移动文件
    
    # 删除旧文件夹
    assert len(os.listdir(old_dir)) == 0 # 确保旧文件夹中的所有图像都被移动走
    shutil.rmtree(old_dir) # 删除文件夹
    
    print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))
    df = df.append({'class':fruit, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)

# 重命名数据集文件夹
shutil.move(dataset_path, 'fruit21_split')

# 数据集各类别数量统计表格，导出为 csv 文件
df['total'] = df['trainset'] + df['testset']
df.to_csv('数据集统计.csv', index=False)

        类别              训练集数据个数            测试集数据个数      
       圣女果                160                 39        
        芒果                119                 29        
        菠萝                160                 40        
       猕猴桃                161                 40        
       哈密瓜                160                 39        
        山楂                159                 39        
        脐橙                148                 37        
        杨梅                152                 38        
        草莓                160                 40        
        椰子                160                 40        
        西瓜                156                 39        
        桂圆                160                 40        
        荔枝                141                 35        
        香蕉                161                 40        
       水蜜桃                153                 38        
        柠檬                149                 37        
       砂糖橘                148  

In [11]:
df

Unnamed: 0,class,trainset,testset,total
0,圣女果,160.0,39.0,199.0
1,芒果,119.0,29.0,148.0
2,菠萝,160.0,40.0,200.0
3,猕猴桃,161.0,40.0,201.0
4,哈密瓜,160.0,39.0,199.0
5,山楂,159.0,39.0,198.0
6,脐橙,148.0,37.0,185.0
7,杨梅,152.0,38.0,190.0
8,草莓,160.0,40.0,200.0
9,椰子,160.0,40.0,200.0


In [12]:
df.to_csv('数据集统计.csv', index=False)

## 查看文件目录结构

In [None]:
!sudo snap install tree

In [14]:
!tree fruit21_split -L 2

[01;34mfruit21_split[00m
├── [01;34mtest[00m
│   ├── [01;34m哈密瓜[00m
│   ├── [01;34m圣女果[00m
│   ├── [01;34m山楂[00m
│   ├── [01;34m杨梅[00m
│   ├── [01;34m柠檬[00m
│   ├── [01;34m桂圆[00m
│   ├── [01;34m椰子[00m
│   ├── [01;34m榴莲[00m
│   ├── [01;34m樱桃[00m
│   ├── [01;34m水蜜桃[00m
│   ├── [01;34m油桃[00m
│   ├── [01;34m猕猴桃[00m
│   ├── [01;34m砂糖橘[00m
│   ├── [01;34m脐橙[00m
│   ├── [01;34m芒果[00m
│   ├── [01;34m草莓[00m
│   ├── [01;34m荔枝[00m
│   ├── [01;34m菠萝[00m
│   ├── [01;34m西瓜[00m
│   ├── [01;34m西红柿[00m
│   └── [01;34m香蕉[00m
└── [01;34mtrain[00m
    ├── [01;34m哈密瓜[00m
    ├── [01;34m圣女果[00m
    ├── [01;34m山楂[00m
    ├── [01;34m杨梅[00m
    ├── [01;34m柠檬[00m
    ├── [01;34m桂圆[00m
    ├── [01;34m椰子[00m
    ├── [01;34m榴莲[00m
    ├── [01;34m樱桃[00m
    ├── [01;34m水蜜桃[00m
    ├── [01;34m油桃[00m
    ├── [01;34m猕猴桃[00m
    ├── [01;34m砂糖橘[00m
    ├── [01;34m脐橙[00m
    ├── [01;34m芒果[00m
    ├── [01;34m草莓[00m
    ├── [01;34m荔枝[00m
    ├