# 背景：使用基于卷积的深度神经网络 ResNet50 对 30 种水果进行分类

## 1、划分训练集和验证集

In [1]:
import os
import shutil
import random
import pandas as pd

In [2]:
# 指定数据集路径
dataset_path = 'data/fruit30_train'

In [3]:
dataset_name = dataset_path.split('_')[0]
print('数据集', dataset_name)

数据集 data/fruit30


In [4]:
classes = os.listdir(dataset_path)

In [5]:
len(classes)

30

In [6]:
classes

['哈密瓜',
 '圣女果',
 '山竹',
 '杨梅',
 '柚子',
 '柠檬',
 '桂圆',
 '梨',
 '椰子',
 '榴莲',
 '火龙果',
 '猕猴桃',
 '石榴',
 '砂糖橘',
 '胡萝卜',
 '脐橙',
 '芒果',
 '苦瓜',
 '苹果-红',
 '苹果-青',
 '草莓',
 '荔枝',
 '菠萝',
 '葡萄-白',
 '葡萄-红',
 '西瓜',
 '西红柿',
 '车厘子',
 '香蕉',
 '黄瓜']

In [7]:
# 创建 train 文件夹
os.mkdir(os.path.join(dataset_path, 'train'))

# 创建 val 文件夹
os.mkdir(os.path.join(dataset_path, 'val'))

# 在 train 和 test 文件夹中创建各类别子文件夹
for fruit in classes:
    os.mkdir(os.path.join(dataset_path, 'train', fruit))
    os.mkdir(os.path.join(dataset_path, 'val', fruit))

In [8]:
test_frac = 0.2  # 验证集比例
random.seed(123) # 随机数种子，便于复现

In [9]:
# 划分训练集、验证集，移动文件
df = pd.DataFrame()

print('{:^18} {:^18} {:^18}'.format('类别', '训练集数据个数', '测试集数据个数'))

for fruit in classes: # 遍历每个类别

    # 读取该类别的所有图像文件名
    old_dir = os.path.join(dataset_path, fruit)
    images_filename = os.listdir(old_dir)
    random.shuffle(images_filename) # 随机打乱

    # 划分训练集和测试集
    testset_numer = int(len(images_filename) * test_frac) # 测试集图像个数
    testset_images = images_filename[:testset_numer]      # 获取拟移动至 test 目录的测试集图像文件名
    trainset_images = images_filename[testset_numer:]     # 获取拟移动至 train 目录的训练集图像文件名

    # 移动图像至 test 目录
    for image in testset_images:
        old_img_path = os.path.join(dataset_path, fruit, image)         # 获取原始文件路径
        new_test_path = os.path.join(dataset_path, 'val', fruit, image) # 获取 test 目录的新文件路径
        shutil.move(old_img_path, new_test_path) # 移动文件

    # 移动图像至 train 目录
    for image in trainset_images:
        old_img_path = os.path.join(dataset_path, fruit, image)           # 获取原始文件路径
        new_train_path = os.path.join(dataset_path, 'train', fruit, image) # 获取 train 目录的新文件路径
        shutil.move(old_img_path, new_train_path) # 移动文件
    
    # 删除旧文件夹
    assert len(os.listdir(old_dir)) == 0 # 确保旧文件夹中的所有图像都被移动走
    shutil.rmtree(old_dir) # 删除文件夹
    
    # 工整地输出每一类别的数据个数
    print('{:^18} {:^18} {:^18}'.format(fruit, len(trainset_images), len(testset_images)))
    
    # 保存到表格中
    df = df.append({'class':fruit, 'trainset':len(trainset_images), 'testset':len(testset_images)}, ignore_index=True)
# 重命名数据集文件夹
shutil.move(dataset_path, dataset_name+'_split')

# 数据集各类别数量统计表格，导出为 csv 文件
df['total'] = df['trainset'] + df['testset']
df.to_csv('数据量统计.csv', index=False)

        类别              训练集数据个数            测试集数据个数      
       哈密瓜                121                 30        
       圣女果                123                 30        
        山竹                115                 28        
        杨梅                120                 29        
        柚子                119                 29        
        柠檬                 96                 23        
        桂圆                123                 30        
        梨                 121                 30        
        椰子                124                 30        
        榴莲                119                 29        
       火龙果                117                 29        
       猕猴桃                120                 30        
        石榴                120                 30        
       砂糖橘                114                 28        
       胡萝卜                117                 29        
        脐橙                121                 30        
        芒果                106  

In [10]:
df

Unnamed: 0,class,trainset,testset,total
0,哈密瓜,121.0,30.0,151.0
1,圣女果,123.0,30.0,153.0
2,山竹,115.0,28.0,143.0
3,杨梅,120.0,29.0,149.0
4,柚子,119.0,29.0,148.0
5,柠檬,96.0,23.0,119.0
6,桂圆,123.0,30.0,153.0
7,梨,121.0,30.0,151.0
8,椰子,124.0,30.0,154.0
9,榴莲,119.0,29.0,148.0


## 2、按照 MMPreTrain CustomDataset 格式组织训练集和验证集
## 3、在水果数据集上进行微调训练
## 4、使用 MMPreTrain 算法库，编写配置文件，正确加载预训练模型

In [11]:
# 以resnet50_8xb32_in1k为基础进行配置文件config的编写

In [12]:
cd projects/fruit30

F:\mmpretrain\projects\fruit30


In [13]:
!mim train mmpretrain resnet50_finetune.py --work-dir=./exp

Training command is C:\Users\y\anaconda3\envs\openmmlab\python.exe f:\mmpretrain\mmpretrain\.mim\tools\train.py resnet50_finetune.py --launcher none --work-dir=./exp. 
06/10 14:42:28 - mmengine - [4m[97mINFO[0m - 
------------------------------------------------------------
System environment:
    sys.platform: win32
    Python: 3.8.16 (default, Mar  2 2023, 03:18:16) [MSC v.1916 64 bit (AMD64)]
    CUDA available: True
    numpy_random_seed: 1492331612
    GPU 0: NVIDIA GeForce RTX 3090
    CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7
    NVCC: Cuda compilation tools, release 11.7, V11.7.64
    MSVC: 用于 x64 的 Microsoft (R) C/C++ 优化编译器 19.35.32217.1 版
    GCC: n/a
    PyTorch: 2.0.0
    PyTorch compiling details: PyTorch built with:
  - C++ Version: 199711
  - MSVC 193431937
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37

## 5、使用 MMPreTrain 的 ImageClassificationInferencer 接口，对网络水果图像，或自己拍摄的水果图像，使用训练好的模型进行分类

In [14]:
from mmpretrain import ImageClassificationInferencer

In [16]:
cd ..

F:\mmpretrain\projects


In [17]:
cd ..

F:\mmpretrain


In [18]:
inferencer = ImageClassificationInferencer('projects/fruit30/resnet50_finetune.py', pretrained='projects/fruit30/exp/epoch_5.pth')

Loads checkpoint by local backend from path: projects/fruit30/exp/epoch_5.pth


In [19]:
inferencer("projects/fruit30/li.jpg", show=True)



  s, (width, height) = canvas.print_to_buffer()
  s, (width, height) = canvas.print_to_buffer()


[{'pred_scores': array([1.1979168e-01, 1.2588964e-06, 2.2537324e-05, 4.9358977e-07,
         3.7423637e-02, 5.0290688e-03, 4.0473728e-04, 7.8806341e-01,
         5.6093977e-05, 1.7782133e-05, 1.1162100e-06, 4.5782113e-03,
         2.2903263e-05, 7.8769699e-06, 1.0055930e-06, 8.3296269e-04,
         4.3353181e-02, 5.5623860e-07, 2.9802014e-04, 3.6997197e-05,
         8.2375857e-07, 2.5464797e-06, 1.7265518e-05, 1.5533331e-05,
         9.3908022e-07, 7.7071245e-06, 1.6478525e-06, 2.9913932e-07,
         6.3702346e-06, 3.3079939e-06], dtype=float32),
  'pred_label': 7,
  'pred_score': 0.7880634069442749,
  'pred_class': '梨'}]