### 划分训练集、验证集和测试集（8:1:1）

In [3]:
import os
import random

# Root of the raw dataset: one sub-folder per fruit class holding its images.
CustomDatasetPath = r'mmpretrain/data/'
# Names of the split folders created under CustomDatasetPath.
dataset_type = ['train', 'val', 'test']


def split_dataset(root, splits=('train', 'val', 'test'), train_ratio=0.8,
                  val_ratio=0.1, seed=0):
    """Move images from ``root/<class>/`` into ``root/<split>/<class>/``.

    Every sub-directory of ``root`` that is not itself one of ``splits`` is
    treated as a class folder. Its entries are put in sorted order, shuffled
    with a seeded RNG, and cut into train / val / test parts
    (``train_ratio``, ``val_ratio`` and the remainder). They are then moved
    into the matching split folders, and the emptied class folder is removed.

    Re-running is safe. After a completed split only the split folders remain
    under ``root``, so nothing further is moved. Existing split folders are
    never deleted, because they hold the images that were *moved* (not copied)
    on the previous run.

    Args:
        root: Dataset root folder.
        splits: Names of the (train, val, test) folders to create.
        train_ratio: Fraction of each class placed in the train split.
        val_ratio: Fraction of each class placed in the val split; the rest
            goes to the test split.
        seed: Seed for the per-class shuffle, so the split is reproducible.
            ``None`` skips shuffling and splits in sorted file-name order.
    """
    train_name, val_name, test_name = splits
    for split in splits:
        os.makedirs(os.path.join(root, split), exist_ok=True)

    # Only real directories are classes; skip the split folders and stray files.
    class_names = sorted(
        name for name in os.listdir(root)
        if name not in splits and os.path.isdir(os.path.join(root, name))
    )
    rng = random.Random(seed)
    for class_name in class_names:
        class_dir = os.path.join(root, class_name)
        # Sort first so the result does not depend on filesystem listing order.
        files = sorted(os.listdir(class_dir))
        if seed is not None:
            rng.shuffle(files)
        train_end = int(len(files) * train_ratio)
        val_end = int(len(files) * (train_ratio + val_ratio))
        parts = {
            train_name: files[:train_end],
            val_name: files[train_end:val_end],
            test_name: files[val_end:],
        }
        for split, split_files in parts.items():
            target_dir = os.path.join(root, split, class_name)
            os.makedirs(target_dir, exist_ok=True)
            for file_name in split_files:
                os.rename(os.path.join(class_dir, file_name),
                          os.path.join(target_dir, file_name))
        # The class folder is empty now. rmdir (unlike removedirs) never
        # tries to remove the parent folders.
        os.rmdir(class_dir)


if __name__ == '__main__':
    # Inside a notebook cell __name__ is '__main__', so executing the cell
    # performs the split exactly once per run.
    split_dataset(CustomDatasetPath, splits=tuple(dataset_type))

In [7]:
# Inspect the training split: `tree` lists one entry per class folder under
# mmpretrain/data/train/. With --filelimit=30, folders holding more than 30
# entries are summarized by their entry count instead of being expanded.
# (`!` runs the line as a shell command from the notebook.)

!tree mmpretrain/data/train/ --filelimit=30

mmpretrain/data/train/
├── 山竹 [114 entries exceeds filelimit, not opening dir]
├── 柚子 [118 entries exceeds filelimit, not opening dir]
├── 榴莲 [118 entries exceeds filelimit, not opening dir]
├── 石榴 [120 entries exceeds filelimit, not opening dir]
├── 砂糖橘 [113 entries exceeds filelimit, not opening dir]
├── 苦瓜 [115 entries exceeds filelimit, not opening dir]
├── 苹果-青 [120 entries exceeds filelimit, not opening dir]
└── 葡萄-白 [99 entries exceeds filelimit, not opening dir]

8 directories, 0 files


### 使用 MMPreTrain 算法库，编写配置文件，正确加载预训练模型

resnet18_fintune.py

In [None]:
# Self-contained config: the settings below are adapted from mmpretrain's
# resnet18 / imagenet_bs32 / imagenet_bs256 / default_runtime base configs
# and written out in full rather than inherited through `_base_`.

# ---------------------------------------------------------------------------
# Model: ImageNet-pretrained ResNet-18 backbone + a fresh linear head.
# ---------------------------------------------------------------------------
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),  # only the last stage's feature map feeds the neck
        style='pytorch',
        # Fine-tuning: load only the backbone weights from the ImageNet
        # checkpoint. `prefix='backbone'` keeps the `backbone.*` keys, so the
        # checkpoint's 1000-class head is not loaded into the new head, which
        # keeps its own initialization.
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',
            prefix='backbone',
        ),
    ),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        # NOTE(review): must equal the number of class folders under the
        # training data_root. Confirm 30 matches the dataset actually used.
        num_classes=30,
        in_channels=512,  # must match the backbone's output channels (512 for ResNet-18)
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ),
)
 
 
 
# ---------------------------------------------------------------------------
# Dataset settings
# ---------------------------------------------------------------------------
# No `ann_file` is given, so CustomDataset scans `data_root` and treats each
# sub-folder as one class (class names come from the folder names).
dataset_type = 'CustomDataset'
data_preprocessor = dict(
    num_classes=30,
    # Per-channel RGB mean / std on the 0-255 scale (standard ImageNet values).
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # Convert images from BGR to RGB so they line up with mean / std above.
    to_rgb=True,
)

# Training augmentation: random resized 224x224 crop plus horizontal flip.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

# Evaluation: resize the short edge to 256, then take the central 224x224 crop.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=256, edge='short'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs'),
]

# NOTE(review): each data_root is resolved relative to the working directory
# that tools/train.py is launched from. Confirm the `fruit/...` folders exist
# there.
train_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='fruit/train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='fruit/val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)

# Held-out test split. This dataloader is used as-is; it is no longer replaced
# by `val_dataloader`, so `tools/test.py` really evaluates on fruit/test.
test_dataloader = dict(
    batch_size=64,
    num_workers=12,
    dataset=dict(
        type=dataset_type,
        data_root='fruit/test',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)

# Top-1 accuracy for both validation and test.
val_evaluator = dict(type='Accuracy', topk=(1, ))
test_evaluator = val_evaluator
 
 
# ---------------------------------------------------------------------------
# Schedule settings
# ---------------------------------------------------------------------------
# SGD with momentum 0.9 and weight decay 1e-4, starting at lr 1e-2.
optim_wrapper = {
    'optimizer': {
        'type': 'SGD',
        'lr': 1e-2,
        'momentum': 0.9,
        'weight_decay': 1e-4,
    },
}

# Epoch-based step decay: multiply the LR by 0.5 at epochs 3, 6 and 9.
param_scheduler = {
    'type': 'MultiStepLR',
    'by_epoch': True,
    'milestones': [3, 6, 9],
    'gamma': 0.5,
}

# Epoch-based training loop: 10 epochs in total, validating every 2nd epoch.
train_cfg = {'by_epoch': True, 'max_epochs': 10, 'val_interval': 2}
# Default validation and test loops.
val_cfg = {}
test_cfg = {}

# Automatic LR scaling by the actual training batch size is left disabled.
# To enable it, add: auto_scale_lr = dict(base_batch_size=256)
 
 
# ---------------------------------------------------------------------------
# Runtime settings
# ---------------------------------------------------------------------------
# Resolve registry names (model, dataset, hook types...) in mmpretrain by default.
default_scope = 'mmpretrain'

# Default hooks.
default_hooks = dict(
    # Record the time taken by every iteration.
    timer=dict(type='IterTimerHook'),

    # Print a training log line every 2 iterations.
    logger=dict(type='LoggerHook', interval=2),

    # Step the parameter schedulers defined in `param_scheduler`.
    param_scheduler=dict(type='ParamSchedulerHook'),

    # Save a checkpoint every epoch and keep only the 2 most recent. Also keep
    # the best checkpoint, with the metric chosen automatically.
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2, save_best='auto'),

    # Set the sampler seed in distributed environments.
    sampler_seed=dict(type='DistSamplerSeedHook'),

    # Visualization of validation results; set enable=True to turn it on.
    visualization=dict(type='VisualizationHook', enable=False),
)

# Execution environment.
env_cfg = dict(
    # Whether to enable cudnn benchmark mode.
    cudnn_benchmark=False,

    # Multi-process parameters: fork-started workers, OpenCV threading set to 0.
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),

    # Distributed backend.
    dist_cfg=dict(backend='nccl'),
)

# Visualizer writing to the local backend.
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='UniversalVisualizer', vis_backends=vis_backends)

# Log level.
log_level = 'INFO'

# Checkpoint to load weights from before training (None = rely on init_cfg).
load_from = None

# Whether to resume training state from the loaded checkpoint.
resume = False

# Fixed random seed; `deterministic` stays disabled.
randomness = dict(seed=23, deterministic=False)

In [None]:
# Train the model. Recommended: run this from a terminal inside the mmpretrain
# folder, since `tools/train.py` and the config path are relative to it.
# NOTE(review): assumes the config was saved as mmpretrain/data/resnet18_fintune.py — confirm.
!python tools/train.py data/resnet18_fintune.py

### 使用 MMPreTrain 的 ImageClassificationInferencer 接口，对网络水果图像，或自己拍摄的水果图像，使用训练好的模型进行分类

In [None]:
from mmpretrain import ImageClassificationInferencer

# NOTE(review): this loads the `resnet18_fintuneM` config/weights, while the
# training command uses `resnet18_fintune.py`. Confirm which run is intended.
inferencer = ImageClassificationInferencer('mmpretrain/resnet18_fintuneM.py',
                                           pretrained='mmpretrain/work_dirs/resnet18_fintuneM/epoch_10.pth')

image_list = ['mmpretrain/data/apple.jpeg', 'mmpretrain/data/banana.jpeg', 'mmpretrain/data/fruit.jpeg', 'mmpretrain/data/grapes.jpg']

# Batch prediction: all four images are classified in a single call.
results = inferencer(image_list, batch_size=4)
for image_path, result in zip(image_list, results):
    print(f"file name: {image_path}")
    # Read predictions by key name rather than by position in `result.keys()`.
    # NOTE(review): in mmpretrain `pred_class` is only added when the model
    # carries class names; verify for the installed version.
    print(result['pred_label'])
    print(result.get('pred_class', 'unknown class (model has no class names)'))
    print()
 