### 划分训练集、验证集和测试集

In [2]:
import os
import random

# Root folder that initially holds one sub-folder of images per fruit class.
CustomDatasetPath = r'mmpretrain/data/fruit/'
# Names of the split folders created inside the dataset root.
dataset_type = ['train', 'val', 'test']


def split_dataset(root, splits=('train', 'val', 'test'), cut_points=(0.8, 0.9), seed=0):
    """Move every class folder under ``root`` into per-split sub-folders.

    Each class folder ``root/<class>`` is emptied into
    ``root/<split>/<class>`` for the three splits and then removed.

    Args:
        root: Dataset root containing one folder per class.
        splits: Names of the three split folders, in train/val/test order.
        cut_points: Cumulative fractions ``(train_end, val_end)``; the default
            sends the first 80% of the files to the first split, the next 10%
            to the second and the remainder to the third.
        seed: Seed for the per-class shuffle so the split is reproducible.

    Returns:
        dict mapping each processed class name to a tuple with the number of
        files moved into each split.

    The function is safe to run again: existing split folders are kept (they
    hold the only copy of files that were already moved) and are never
    treated as classes, so a second run finds nothing left to split.
    """
    split_names = list(splits)
    for split_name in split_names:
        os.makedirs(os.path.join(root, split_name), exist_ok=True)

    # Only real class folders are split; split folders and stray files are skipped.
    class_names = sorted(
        entry for entry in os.listdir(root)
        if entry not in split_names and os.path.isdir(os.path.join(root, entry)))

    rng = random.Random(seed)
    train_cut, val_cut = cut_points
    summary = {}
    for class_name in class_names:
        class_dir = os.path.join(root, class_name)
        # os.listdir order is arbitrary: sort first, then shuffle with the
        # seeded generator so the split is both random and repeatable.
        file_names = sorted(os.listdir(class_dir))
        rng.shuffle(file_names)

        train_end = int(len(file_names) * train_cut)
        val_end = int(len(file_names) * val_cut)
        parts = (file_names[:train_end],
                 file_names[train_end:val_end],
                 file_names[val_end:])

        for split_name, part in zip(split_names, parts):
            target_dir = os.path.join(root, split_name, class_name)
            os.makedirs(target_dir, exist_ok=True)
            for file_name in part:
                os.rename(os.path.join(class_dir, file_name),
                          os.path.join(target_dir, file_name))

        # Every file has been moved; rmdir removes only this folder and,
        # unlike removedirs, never prunes parent folders.
        os.rmdir(class_dir)
        summary[class_name] = tuple(len(part) for part in parts)
    return summary


if os.path.isdir(CustomDatasetPath):
    # Folder listing as it was before the split.
    CustomDatasetFile = os.listdir(CustomDatasetPath)
    split_summary = split_dataset(CustomDatasetPath, dataset_type)
else:
    CustomDatasetFile = []
    print(f'Dataset folder {CustomDatasetPath!r} not found under {os.getcwd()!r}; nothing was split.')

In [2]:
# Inspect the training split; folders with more than 30 entries are summarized instead of expanded.

!tree mmpretrain/data/fruit/train/ --filelimit=30

mmpretrain/data/fruit/train/
├── 哈密瓜 [120 entries exceeds filelimit, not opening dir]
├── 圣女果 [122 entries exceeds filelimit, not opening dir]
├── 山竹 [114 entries exceeds filelimit, not opening dir]
├── 杨梅 [119 entries exceeds filelimit, not opening dir]
├── 柚子 [118 entries exceeds filelimit, not opening dir]
├── 柠檬 [95 entries exceeds filelimit, not opening dir]
├── 桂圆 [122 entries exceeds filelimit, not opening dir]
├── 梨 [120 entries exceeds filelimit, not opening dir]
├── 椰子 [123 entries exceeds filelimit, not opening dir]
├── 榴莲 [118 entries exceeds filelimit, not opening dir]
├── 火龙果 [116 entries exceeds filelimit, not opening dir]
├── 猕猴桃 [120 entries exceeds filelimit, not opening dir]
├── 石榴 [120 entries exceeds filelimit, not opening dir]
├── 砂糖橘 [113 entries exceeds filelimit, not opening dir]
├── 胡萝卜 [116 entries exceeds filelimit, not opening dir]
├── 脐橙 [120 entries exceeds filelimit, not opening dir]
├── 芒果 [105 entries exceeds filelimit, not opening dir]
├── 苦瓜 [115 ent

### 使用 MMPreTrain 算法库，编写配置文件，正确加载预训练模型

配置文件 `resnet50_fintune.py`（保存在 `mmpretrain/data/` 目录下）：

In [None]:
# ---------------------------------------------------------------------------
# Model: ResNet-50 classifier with a 30-class linear head, initialised from
# the resnet50_8xb32_in1k checkpoint.
# ---------------------------------------------------------------------------
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet', depth=50, num_stages=4, out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        in_channels=2048,
        num_classes=30,
        topk=(1, 5),
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0)),
    init_cfg=dict(
        type='Pretrained',
        checkpoint='https://download.openmmlab.com/mmclassification/v0/'
        'resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth'))

# ---------------------------------------------------------------------------
# Data: folder-per-class dataset under data/fruit/{train,val,test}.
# ---------------------------------------------------------------------------
dataset_type = 'CustomDataset'
data_preprocessor = dict(
    num_classes=30,
    to_rgb=True,
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375])

# Training-time augmentation: random resized crop plus horizontal flip.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]
# Evaluation-time preprocessing: resize the short edge, then centre crop.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=256, edge='short'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=12,
    pin_memory=True,
    persistent_workers=True,
    collate_fn=dict(type='default_collate'),
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root='data/fruit/train',
        pipeline=train_pipeline))

val_dataloader = dict(
    batch_size=64,
    num_workers=12,
    pin_memory=True,
    persistent_workers=True,
    collate_fn=dict(type='default_collate'),
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root='data/fruit/val',
        pipeline=test_pipeline))
val_evaluator = dict(type='Accuracy', topk=(1, 5))

test_dataloader = dict(
    batch_size=64,
    num_workers=12,
    pin_memory=True,
    persistent_workers=True,
    collate_fn=dict(type='default_collate'),
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root='data/fruit/test',
        pipeline=test_pipeline))
# The test split is scored with the same metric settings as validation.
test_evaluator = val_evaluator

# ---------------------------------------------------------------------------
# Optimisation schedule: SGD for 10 epochs, LR multiplied by 0.1 at epochs
# 3, 6 and 9, validation every 2 epochs.
# ---------------------------------------------------------------------------
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[3, 6, 9], gamma=0.1)
train_cfg = dict(by_epoch=True, max_epochs=10, val_interval=2)
val_cfg = dict()
test_cfg = dict()
auto_scale_lr = dict(base_batch_size=256)

# ---------------------------------------------------------------------------
# Runtime: hooks, environment, visualisation and bookkeeping.
# ---------------------------------------------------------------------------
default_scope = 'mmpretrain'
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', interval=1, max_keep_ckpts=2, save_best='auto'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='VisualizationHook', enable=False))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='UniversalVisualizer', vis_backends=vis_backends)
log_level = 'INFO'
load_from = None
resume = False
randomness = dict(seed=23, deterministic=False)
launcher = 'none'
work_dir = './exp3_resnet50'

In [None]:
# Train (recommended: run this from a command line inside the mmpretrain folder)
!python tools/train.py data/resnet50_fintune.py

## val结果

![result](pic/fruit_result_epoch10.png)

### 使用 MMPreTrain 的 ImageClassificationInferencer 接口，对网络水果图像，或自己拍摄的水果图像，使用训练好的模型进行分类

In [9]:
from mmpretrain import ImageClassificationInferencer

# Build the inferencer from the fine-tuning config and the best top-1
# checkpoint written to the training work directory.
inferencer = ImageClassificationInferencer('mmpretrain/data/resnet50_fintune.py',
                                           pretrained='mmpretrain/exp3_resnet50/best_accuracy_top1_epoch_6.pth')

image_list = ['mmpretrain/data/apple.jpg', 'mmpretrain/data/banana.jpg', 'mmpretrain/data/grapes.jpg']

# Batched prediction: one result dict per input image, in input order.
results = inferencer(image_list, batch_size=4)
for image_path, result in zip(image_list, results):
    print(f"file name: {image_path}")
    # Read the fields by name rather than by position in the dict, so the
    # output does not depend on the order in which the keys were inserted.
    print(result['pred_label'])
    # The class name is printed only when the result carries one.
    if 'pred_class' in result:
        print(result['pred_class'])
    print()

Loads checkpoint by local backend from path: mmpretrain/exp3_resnet50/best_accuracy_top1_epoch_6.pth


file name: mmpretrain/data/apple.jpg
18
苹果-红

file name: mmpretrain/data/banana.jpg
28
香蕉

file name: mmpretrain/data/grapes.jpg
24
葡萄-红

