## 环境检测

In [2]:
import torch

# Report the installed PyTorch build, then echo whether a CUDA device is usable.
torch_version = torch.__version__
print(torch_version)
cuda_ok = torch.cuda.is_available()
cuda_ok

2.0.0+cu118


True

## 环境安装
因为要基于 MMPretrain 进行训练，所以基于源码使用；如果只要使用其中几个模块，直接安装即可。

命令行中执行：
```
git clone https://github.com/open-mmlab/mmpretrain
cd mmpretrain
pip install openmim
```

使用 `mim --help` 查看帮助信息；
使用以下命令安装多模态所需的额外依赖：
```
mim install -e ".[multimodal]"
```

In [None]:
# Run the following in IPython / Jupyter.

# Check the installed mmpretrain version.
import mmpretrain

print(mmpretrain.__version__)

# Use pretrained models to check inference results.
from mmpretrain import get_model, list_models, inference_model

# List matching models. Jupyter only echoes the LAST expression of a cell,
# so intermediate results must be printed explicitly or they are discarded.
print(list_models(task='Image Classification', pattern='resnet18'))
print(list_models(task='Image Caption', pattern='blip'))

# Inspect the model type.
model = get_model('resnet18_8xb16_cifar10')
print(type(model))

# Inspect the model's backbone (last expression: echoed by Jupyter).
model = get_model('resnet18_8xb32_in1k')
type(model.backbone)

## 推理测试

In [None]:
import mmpretrain
from pathlib import Path
from mmpretrain import get_model, list_models, inference_model

# Only the last expression of a cell is echoed, so print the listing explicitly.
print(list_models(task='Image Caption', pattern='blip'))

# The demo image lives inside the cloned mmpretrain repository. This cell runs
# before the later `os.chdir('mmpretrain')`, so fall back to the path relative
# to the repository's parent directory when the repo-relative path is missing.
demo_image = Path('demo/cat-dog.png')
if not demo_image.exists():
    demo_image = Path('mmpretrain') / demo_image
inference_model('blip-base_3rdparty_caption', str(demo_image), show=True)

## 模型训练
### 准备数据

In [15]:
import os
import random
import shutil

# Work from the root of the cloned repository (created by `git clone` above).
# Only change directory when we are not already in it. This way, re-running
# this cell is a no-op instead of failing, or of descending into a nested
# directory that is also named `mmpretrain` (presumably the Python package
# inside the repo — confirm).
REPO_DIR = 'mmpretrain'
if os.path.basename(os.path.abspath(os.getcwd())) != REPO_DIR:
    os.chdir(REPO_DIR)

In [None]:
# Constants
DATA_DIR = "data"
NUM_IMAGES_PER_FRUIT = 100  # expected images per class; not used by the split below
NUM_IMAGES_TO_CUT = 10      # images per class moved into the validation folder
VAL_DIR_PREFIX = "new_"     # held-out images of class X go to DATA_DIR/new_X
SEED = 42

# Seed the RNG so the same images are held out on every fresh run.
random.seed(SEED)

# Keep only class folders. This skips stray files (e.g. .DS_Store) and the
# validation folders created by a previous run, which would otherwise be
# treated as classes and split again into "new_new_<fruit>". Sorting makes the
# seeded sampling independent of the filesystem's listing order.
fruit_dirs = sorted(
    name
    for name in os.listdir(DATA_DIR)
    if os.path.isdir(os.path.join(DATA_DIR, name))
    and not name.startswith(VAL_DIR_PREFIX)
)

for fruit_dir in fruit_dirs:
    # Source class folder and its matching validation folder.
    fruit_path = os.path.join(DATA_DIR, fruit_dir)
    new_fruit_path = os.path.join(DATA_DIR, VAL_DIR_PREFIX + fruit_dir)

    # A non-empty validation folder means this class was already split.
    # Skip it so re-running the cell does not drain the training set further.
    if os.path.isdir(new_fruit_path) and os.listdir(new_fruit_path):
        continue
    os.makedirs(new_fruit_path, exist_ok=True)

    # All image files (not sub-directories) of the current class.
    image_paths = sorted(
        os.path.join(fruit_path, img)
        for img in os.listdir(fruit_path)
        if os.path.isfile(os.path.join(fruit_path, img))
    )

    # random.sample raises ValueError when asked for more items than exist,
    # so cap the count at the number of images actually available.
    num_to_cut = min(NUM_IMAGES_TO_CUT, len(image_paths))
    if num_to_cut < NUM_IMAGES_TO_CUT:
        print(f"warning: {fruit_dir} has only {len(image_paths)} images; "
              f"moving {num_to_cut} to {new_fruit_path}")

    # Move (not copy) the sampled images into the validation folder.
    for img_path in random.sample(image_paths, num_to_cut):
        new_img_path = os.path.join(new_fruit_path, os.path.basename(img_path))
        shutil.move(img_path, new_img_path)

上述代码从每个水果分类中随机抽取 10 张图片，移动（而非复制）到同级的 `new_<类别名>` 文件夹中作为验证集，剩余图片作为训练集。

## 模型选择
在命令行中执行以下命令，列出 ResNet 相关的配置文件，并打开其中一个查看：
```
ls configs/resnet/
vi configs/resnet/resnet50_8xb32_in1k.py
```
显示如下：
```
_base_ = [
    '../_base_/models/resnet50.py', '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
```
分别对应：模型(model)、数据(data)、训练策略(schedule)、运行设置(runtime)

## 编写配置文件
将上述四个文件的内容复制到同一个配置文件中，再按水果数据集修改（例如将分类头的 `num_classes` 改为 30）。
```
mkdir -p project/fruit
cd project/fruit
vi fruit_resnet50_finetune.py
```

```
############################################ model settings ###########################################
# model settings
NUM_CLASSES = 30  # number of fruit classes; keep head and data_preprocessor in sync

model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch',
        # Load only the backbone weights from the ImageNet checkpoint. Its
        # 1000-class head cannot be loaded into the NUM_CLASSES-way head below,
        # so the head is left randomly initialized for fine-tuning.
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
            prefix='backbone',
        )),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=NUM_CLASSES,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ),
)

############################################ dataset settings###########################################
# NOTE(review): dataset type and paths below are still the ImageNet defaults;
# point them at the fruit data split in the notebook before training.
dataset_type = 'ImageNet'
data_preprocessor = dict(
    # must match the classifier head's num_classes
    num_classes=NUM_CLASSES,
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=256, edge='short'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

############################################ schedules ###########################################
# optimizer
# NOTE(review): lr=0.1 is the from-scratch ImageNet value; fine-tuning usually
# uses a smaller learning rate — confirm.
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[30, 60, 90], gamma=0.1)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=256)

############################################ runtime ###########################################
default_scope = 'mmpretrain'

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type='IterTimerHook'),

    # print log every 100 iterations.
    logger=dict(type='LoggerHook', interval=100),

    # enable the parameter scheduler.
    param_scheduler=dict(type='ParamSchedulerHook'),

    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=5, save_best='auto'),

    # set sampler seed in distributed environment.
    sampler_seed=dict(type='DistSamplerSeedHook'),

    # validation results visualization, set True to enable it.
    visualization=dict(type='VisualizationHook', enable=False),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,

    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),

    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='UniversalVisualizer', vis_backends=vis_backends)

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
```

## 开始训练
在配置文件所在的目录下（如 project/fruit），命令行中执行：
```
mim train mmpretrain fruit_resnet50_finetune.py --work-dir=./exp
```