In [1]:
import mindspore.nn as nn
from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore import context
import matplotlib.pyplot as plt
from mindspore.train.callback import Callback
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
import os
from mindspore import Model
from mindspore.nn import Accuracy
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from VGG7 import VGG7
from Resnet12 import ResNet12

# Configure MindSpore globally: execute in graph mode with "GPU" as the device target.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


def create_dataset(data_home, repeat_num=1, batch_size=32, do_train=True, device_target="GPU"):
    """
    Build a batched CIFAR-10 input pipeline for training or evaluation.

    Args:
        data_home: Directory handed to ``ds.Cifar10Dataset``.
        repeat_num: Number of times the batched dataset is repeated.
        batch_size: Number of samples per batch; an incomplete final batch
            is dropped.
        do_train: When True, random-crop (32x32 with 4 pixels of padding on
            each side) and random horizontal flip augmentation are applied.
            Pass False for evaluation data so the images stay deterministic.
        device_target: Unused by this function; kept so existing callers
            that pass it keep working.

    Returns:
        The mapped, batched and repeated dataset object.
    """
    num_workers = 8

    cifar_ds = ds.Cifar10Dataset(data_home, num_parallel_workers=num_workers, shuffle=True)

    image_ops = []
    if do_train:
        # Augmentation operates on the raw decoded images, before any scaling.
        image_ops += [
            C.RandomCrop((32, 32), (4, 4, 4, 4)),
            C.RandomHorizontalFlip(prob=0.5),
        ]

    image_ops += [
        # The mean/std below are expressed on a [0, 1] pixel scale, so the
        # raw 0..255 pixel values must be rescaled first; without this step
        # the normalization is off by a factor of 255.
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        C.HWC2CHW(),
    ]

    # Labels are cast to int32 (the training loss is used with sparse labels).
    label_cast_op = C2.TypeCast(mstype.int32)

    cifar_ds = cifar_ds.map(operations=label_cast_op, input_columns="label",
                            num_parallel_workers=num_workers)
    cifar_ds = cifar_ds.map(operations=image_ops, input_columns="image",
                            num_parallel_workers=num_workers)

    cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
    cifar_ds = cifar_ds.repeat(repeat_num)

    return cifar_ds


class StepLossAccInfo(Callback):
    """
    Training callback that records loss and evaluation accuracy over time.

    Args:
        model: Object whose ``eval(dataset, dataset_sink_mode=False)`` returns
            a mapping containing an ``"Accuracy"`` entry.
        eval_dataset: Dataset passed to ``model.eval`` at each evaluation.
        steps_loss: Dict with list values under ``"step"`` and
            ``"loss_value"``; both receive string entries.
        steps_eval: Dict with list values under ``"step"`` and ``"acc"``.
        loss_interval: Record the loss every this many steps (default 10).
        eval_interval: Evaluate the model every this many steps (default 100).

    Raises:
        ValueError: If either interval is smaller than 1.
    """

    def __init__(self, model, eval_dataset, steps_loss, steps_eval,
                 loss_interval=10, eval_interval=100):
        super().__init__()
        if loss_interval < 1 or eval_interval < 1:
            raise ValueError("loss_interval and eval_interval must be >= 1")
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.loss_interval = loss_interval
        self.eval_interval = eval_interval

    def step_end(self, run_context):
        """Record loss and/or accuracy when the current step hits an interval."""
        cb_params = run_context.original_args()
        # NOTE(review): cur_step_num is presumably a running count across
        # epochs rather than a per-epoch index -- confirm against the
        # MindSpore version in use.
        cur_step = cb_params.cur_step_num
        if cur_step % self.loss_interval == 0:
            # Values are stored as strings; consumers convert them back to
            # numbers before plotting.
            self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
            self.steps_loss["step"].append(str(cur_step))
        if cur_step % self.eval_interval == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])


def _plot_training_curves(steps_loss, steps_eval):
    """Plot the recorded loss curve (left) and accuracy curve (right)."""
    loss_steps = [int(step) for step in steps_loss["step"]]
    loss_values = [float(value) for value in steps_loss["loss_value"]]
    acc_steps = [int(step) for step in steps_eval["step"]]
    acc_values = [float(value) for value in steps_eval["acc"]]

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 2, 1)
    plt.plot(loss_steps, loss_values, color="red")
    plt.xlabel("Steps")
    plt.ylabel("Loss_value")
    plt.title("Change chart of model loss value")

    plt.subplot(1, 2, 2)
    plt.plot(acc_steps, acc_values, color="blue")
    plt.xlabel("Steps")
    plt.ylabel("Acc_value")
    plt.title("Change chart of model acc value")

    plt.show()


def train(net_type, epoch_size):
    """
    Train a network on CIFAR-10, evaluate it, and plot loss/accuracy curves.

    Args:
        net_type: ``"vgg7"`` or ``"resnet12"``; selects the network class.
        epoch_size: Number of epochs passed to ``model.train``.

    Raises:
        ValueError: If ``net_type`` is not one of the supported names.
    """
    import glob

    # Select the network architecture.
    if net_type == "vgg7":
        net = VGG7()
    elif net_type == "resnet12":
        net = ResNet12()
    else:
        raise ValueError(
            "unsupported net_type {!r}; expected 'vgg7' or 'resnet12'".format(net_type)
        )

    # Loss function and optimizer.
    loss_fn = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-3)

    # Wrap everything in a Model; the metric name "Accuracy" is the key that
    # eval() results are read back with.
    model = Model(net, loss_fn=loss_fn, optimizer=optimizer, metrics={"Accuracy": Accuracy()})

    # Training and test datasets. The test set must not receive the random
    # crop/flip augmentation, otherwise evaluation is noisy and pessimistic.
    ds_train_path = "./datasets/cifar10/train/"
    ds_test_path = "./datasets/cifar10/test/"
    ds_train = create_dataset(ds_train_path)
    ds_test = create_dataset(ds_test_path, do_train=False)

    # Checkpoint directory; remove artifacts left over from earlier runs.
    # Done best-effort, without spawning a shell, so it also works where
    # `rm` is unavailable.
    model_path = "./models/ckpt/mindspore_vision_application/"
    for pattern in ("*.ckpt", "*.meta", "*.pb"):
        for stale_file in glob.glob(os.path.join(model_path, pattern)):
            try:
                os.remove(stale_file)
            except OSError as err:
                print("could not remove {}: {}".format(stale_file, err))

    # Number of batches (i.e. training steps) in one pass over the dataset.
    batch_num = ds_train.get_dataset_size()

    # Save a checkpoint every `batch_num` steps and keep only the latest one.
    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="train_" + net_type + "_cifar10", directory=model_path, config=config_ck)

    # Loss monitor; the argument is the print interval, so the loss is
    # printed every 142 steps.
    loss_cb = LossMonitor(142)

    # Containers filled by the callback during training and plotted afterwards.
    steps_loss = {"step": [], "loss_value": []}
    steps_eval = {"step": [], "acc": []}
    step_loss_acc_info = StepLossAccInfo(model, ds_test, steps_loss, steps_eval)

    # Run the training.
    print("begin training...")
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, loss_cb, step_loss_acc_info], dataset_sink_mode=False)

    # Evaluate the trained model on the test set and print the metrics.
    res = model.eval(ds_test)
    print("result: ", res)

    # Plot the loss and accuracy curves collected during training.
    _plot_training_curves(steps_loss, steps_eval)

In [None]:
# Train the "resnet12" network for 3 epochs (arguments: net_type, epoch_size).
train("resnet12", 3)

In [None]:
# Train the "vgg7" network for 1 epoch (arguments: net_type, epoch_size).
train("vgg7", 1)