In [26]:
import os
import ast
import argparse
from src.config import mnist_cfg as cfg
from src.dataset import create_dataset
from src.lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.common import set_seed

# Fix MindSpore's global random seed so that runs of this notebook are
# reproducible (randomness governed by the global seed repeats across runs).
SEED = 1
set_seed(SEED)

In [27]:
if __name__ == "__main__":
    # Compile the network as a graph and run it on the CPU backend.
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

    # dataset
    # Build the path portably. The previous literal '.\MNIST_DATA' contained the
    # invalid escape sequence '\M' (SyntaxWarning on Python 3.12+), and its
    # backslash only acted as a directory separator on Windows.
    train_data_dir = os.path.join('.', 'MNIST_DATA', 'train')
    ds_train = create_dataset(train_data_dir, cfg.batch_size)

    # network: LeNet-5 with cfg.num_classes outputs, trained with softmax
    # cross-entropy on integer class labels (sparse=True) and momentum SGD.
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    # Timing callback, sized from the training dataset.
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())

    # save checkpoint
    # Write a checkpoint every cfg.save_checkpoint_steps steps into ./ckpt
    # (files prefixed "checkpoint_lenet"), keeping at most
    # cfg.keep_checkpoint_max of them.
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory='./ckpt', config=config_ck)

    # model
    # The Accuracy metric is registered for later evaluation (Model.eval);
    # training itself only reports the loss.
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # train
    # dataset_sink_mode=False feeds batches step by step, so LossMonitor
    # prints the loss after every step.
    print("============== Starting Training ==============")
    model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
                dataset_sink_mode=False)



epoch: 1 step: 1, loss is 2.302579402923584
epoch: 1 step: 2, loss is 2.3028666973114014
epoch: 1 step: 3, loss is 2.3025550842285156
epoch: 1 step: 4, loss is 2.3025496006011963
epoch: 1 step: 5, loss is 2.3027262687683105
epoch: 1 step: 6, loss is 2.3019509315490723
epoch: 1 step: 7, loss is 2.3015546798706055
epoch: 1 step: 8, loss is 2.303046941757202
epoch: 1 step: 9, loss is 2.304112434387207
epoch: 1 step: 10, loss is 2.3022210597991943
epoch: 1 step: 11, loss is 2.3011741638183594
epoch: 1 step: 12, loss is 2.3022422790527344
epoch: 1 step: 13, loss is 2.299929618835449
epoch: 1 step: 14, loss is 2.302668809890747
epoch: 1 step: 15, loss is 2.298689603805542
epoch: 1 step: 16, loss is 2.3024749755859375
epoch: 1 step: 17, loss is 2.3005356788635254
epoch: 1 step: 18, loss is 2.302804708480835
epoch: 1 step: 19, loss is 2.297750473022461
epoch: 1 step: 20, loss is 2.3072292804718018
epoch: 1 step: 21, loss is 2.297698497772217
epoch: 1 step: 22, loss is 2.302112102508545
epoch: 