# ① 导入相关库

In [None]:
import paddle
import numpy as np
import matplotlib.pyplot as plt

# Fix the random seeds so that Restart & Run All reproduces the numbers below.
# paddle.seed seeds Paddle's own generators (parameter init, dropout, ...);
# numpy's global RNG is seeded as well because the data pipeline may draw its
# shuffling order from it (NOTE(review): confirm for the installed Paddle
# version). GPU kernels can still be nondeterministic even with fixed seeds.
SEED = 42
paddle.seed(SEED)
np.random.seed(SEED)

paddle.__version__

  from collections import MutableMapping
  from collections import Iterable, Mapping
  from collections import Sized


'2.0.0'

# ② 数据准备

## 2.1 数据加载和预处理

In [None]:
import paddle.vision.transforms as T

# Data loading and preprocessing.
# MNIST digits are small grayscale images; they are upscaled and center-cropped
# to the 224x224 input size used by the ImageNet-pretrained ResNet-18 below.
transform = T.Compose([
    T.Resize((256, 256)),
    T.CenterCrop(224),
    # HWC -> CHW; a 2-D grayscale image becomes a single-channel (1, H, W) array.
    T.Transpose(),
    # Scale raw pixel values from [0, 255] to [0, 1], i.e. (x - 0) / 255.
    # Note: `mean=255` (with the default std=1) would compute x - 255 instead,
    # pushing every pixel into [-255, 0]. `to_rgb` is not used because a
    # single-channel image has no BGR/RGB channel order to swap.
    T.Normalize(mean=0., std=255., data_format='CHW'),
    # ImageNet statistics. NOTE(review): the 3-element mean/std presumably
    # broadcast against the single grayscale channel, which is what yields the
    # 3 channels the ResNet stem expects -- confirm with train_dataset[0][0].shape.
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Training set
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)

# Evaluation set
eval_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

print('训练集样本量: {}，验证集样本量: {}'.format(len(train_dataset), len(eval_dataset)))

训练集样本量: 60000，验证集样本量: 10000


# ③ 模型选择和开发

## 3.1 模型组网

In [None]:
# ImageNet-pretrained ResNet-18 with a classification head sized for the
# 10 MNIST digit classes (labels 0-9). Because the head size differs from the
# pretrained checkpoint's, its fc weights cannot be reused -- presumably Paddle
# skips them with a shape-mismatch warning and keeps the fresh initialization
# (NOTE(review): confirm in the cell output).
network = paddle.vision.models.resnet18(num_classes=10, pretrained=True)



## 3.2 模型网络结构可视化

In [None]:
# 模型封装
model = paddle.Model(network)

# 模型可视化
#model.summary((1, 28, 28))

In [None]:
# 配置优化器、损失函数、评估指标
model.prepare(paddle.optimizer.Adam(learning_rate=0.001, parameters=network.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())
              
# 启动模型全流程训练
model.fit(train_dataset,  # 训练数据集
          eval_dataset,   # 评估数据集
          epochs=5,       # 训练的总轮次
          batch_size=64,  # 训练使用的批大小
          verbose=1)      # 日志展示形式

The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/5


  return (isinstance(seq, collections.Sequence) and
  "When training, we now always track global mean and variance.")


Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
Eval samples: 10000
Epoch 2/5
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
Eval samples: 10000
Epoch 3/5
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
Eval samples: 10000
Epoch 4/5
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
Eval samples: 10000
Epoch 5/5
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
Eval samples: 10000


# ⑤ 模型评估测试

## 5.1 模型评估

In [None]:
# 模型评估，根据prepare接口配置的loss和metric进行返回
result = model.evaluate(eval_dataset, verbose=1)

print(result)

Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
Eval samples: 10000
{'loss': [1.1920928e-07], 'acc': 0.9926}


# ⑥ 部署上线

## 6.1 保存模型

In [None]:
# 保存用于后续继续调优训练的模型
model.save('finetuning/mnist')

## 6.2 继续调优训练

In [None]:
from paddle.static import InputSpec


# Re-wrap the network for further fine-tuning. `inputs` is given so the model
# can later be exported for inference. Its shape must describe the tensor the
# network actually receives from the transform pipeline: a batch of 3x224x224
# float images. Raw 28x28 MNIST images do not fit the ResNet input.
model_2 = paddle.Model(
    network,
    inputs=[InputSpec(shape=[-1, 3, 224, 224], dtype='float32', name='image')],
)

# Configure optimizer, loss and metric BEFORE loading the checkpoint:
# Model.load can only restore the saved optimizer state (the .pdopt file) into
# an optimizer that is already attached. Loading first would silently restart
# Adam's moment estimates from scratch.
model_2.prepare(paddle.optimizer.Adam(learning_rate=0.001, parameters=network.parameters()),
                paddle.nn.CrossEntropyLoss(),
                paddle.metric.Accuracy())

# Restore the checkpoint saved earlier (parameters and optimizer state).
model_2.load('finetuning/mnist')

# Continue training for a few more epochs.
model_2.fit(train_dataset,
            eval_dataset,
            epochs=2,
            batch_size=64,
            verbose=1)

The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/2
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
Eval samples: 10000
Epoch 2/2
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
Eval samples: 10000


## 6.3 保存预测模型

In [None]:
# 保存用于后续推理部署的模型
# Save a model for later inference/deployment. training=False selects the
# inference export. The default (training=True) only writes a training
# checkpoint. The inference export relies on the `inputs` InputSpec given to
# paddle.Model, so that spec must match the real network input.
model_2.save('infer/mnist', training=False)

TypeError: can't pickle paddle.fluid.core_avx.BlockDesc objects

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=be243217-dd5d-47ab-9ef2-99bb2e76d181' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>