## 一.导入包

In [None]:
import os
import ast
import argparse
from src.config import mnist_cfg as cfg
from src.dataset import create_dataset
from src.lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.common import set_seed

## 二.设备设置

In [None]:
# Command-line parser for the script; the options themselves are registered
# on it separately.
parser = argparse.ArgumentParser(
    description="MindSpore Lenet Example",
)

In [None]:
# Register the run options: target device, dataset location, checkpoint
# directory and whether dataset sink mode is used.
parser.add_argument(
    '--device_target',
    type=str,
    default="CPU",
    choices=['Ascend', 'GPU', 'CPU'],
    # %(default)s is expanded by argparse, so the help text always reports
    # the real default instead of a hard-coded (and possibly stale) name.
    help='device where the code will be implemented (default: %(default)s)',
)
parser.add_argument(
    '--data_path',
    type=str,
    default="./MNIST_Data",
    help='path where the dataset is saved',
)
parser.add_argument(
    '--ckpt_path',
    type=str,
    default="./ckpt",
    # Adjacent literals are concatenated by the compiler; a backslash
    # continuation inside a single literal would embed the next line's
    # indentation in the help text.
    help='if is test, must provide '
         'path where the trained ckpt file',
)
parser.add_argument(
    '--dataset_sink_mode',
    # The option text is evaluated as a Python literal, e.g. "True" -> True.
    type=ast.literal_eval,
    default=True,
    help='dataset_sink_mode is False or True',
)

# parse_known_args() returns (namespace, leftover_args); only the namespace is
# kept, so unrecognized arguments are ignored instead of aborting the run.
# NOTE(review): presumably chosen so the cell also works under a notebook
# kernel, whose own argv entries are not declared here — confirm.
args = parser.parse_known_args()[0]

## 三.数据集处理

In [None]:
# Run in graph mode on the device selected on the command line.
context.set_context(device_target=args.device_target, mode=context.GRAPH_MODE)

# The training and evaluation sets are read from the "train" and "test"
# sub-directories of --data_path, both batched with cfg.batch_size.
train_dir = os.path.join(args.data_path, "train")
ds_train = create_dataset(train_dir, cfg.batch_size)

test_dir = os.path.join(args.data_path, "test")
# NOTE(review): the third positional argument is presumably the repeat count
# of create_dataset — confirm against src/dataset.py.
ds_eval = create_dataset(test_dir, cfg.batch_size, 1)

# Network under training, built from the configured number of classes.
network = LeNet5(cfg.num_classes)

## 四.准备工作

In [None]:
# Loss function: softmax cross-entropy configured with sparse=True and mean
# reduction.
net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean", sparse=True)

# Optimizer: Momentum over the network's trainable parameters, with the
# learning rate and momentum taken from cfg.
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

# Callbacks: a time monitor sized by the training dataset, and a checkpoint
# writer whose save interval and retention limit come from cfg.
train_size = ds_train.get_dataset_size()
time_cb = TimeMonitor(data_size=train_size)
config_ck = CheckpointConfig(
    keep_checkpoint_max=cfg.keep_checkpoint_max,
    save_checkpoint_steps=cfg.save_checkpoint_steps,
)
ckpoint_cb = ModelCheckpoint(
    config=config_ck,
    directory=args.ckpt_path,
    prefix="checkpoint_lenet",
)

# Assemble the model from the network, loss, optimizer and accuracy metric.
eval_metrics = {"Accuracy": Accuracy()}
model = Model(network, net_loss, net_opt, metrics=eval_metrics)

## 五.开始训练

In [None]:
# Train for the configured number of epochs on the training set, reporting
# timing and loss and writing checkpoints through the callbacks.
# cfg.epoch_size uses attribute access, matching how every other cfg setting
# in this script is read.
# NOTE(review): --device_target defaults to "CPU" while --dataset_sink_mode
# defaults to True; confirm the installed MindSpore version accepts sink mode
# on CPU, or pass --dataset_sink_mode False.
model.train(
    cfg.epoch_size,
    ds_train,
    callbacks=[time_cb, ckpoint_cb, LossMonitor()],
    dataset_sink_mode=args.dataset_sink_mode,
)