In [1]:
import os
import ast
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet import LeNet5

In [3]:
if __name__ == "__main__":
    # Evaluate a trained LeNet5 checkpoint on the MNIST test split
    # (CPU target, graph mode) and print the metrics returned by Model.eval().

    # Build paths with os.path.join instead of Windows-style literals such as
    # '.\ckpt\...': those only work on Windows, and they rely on invalid
    # escape sequences ('\c', '\M') that newer Python versions flag with a
    # SyntaxWarning. On Windows the resulting paths are unchanged.
    ckpt_file_name = os.path.join(".", "ckpt", "checkpoint_lenet_1-1_1875.ckpt")
    mnist_test_dir = os.path.join(".", "MNIST_DATA", "test")

    # set network
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # Model.eval() does not update weights, so this optimizer is not exercised
    # here; it is kept to preserve the original Model construction.
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # test
    print("============== Starting Testing ==============")
    # Fail early with an actionable message; the checkpoint path is hard-coded
    # and relative to the notebook's working directory.
    if not os.path.isfile(ckpt_file_name):
        raise FileNotFoundError(
            "Checkpoint not found: {} (run training first or fix the path)".format(
                os.path.abspath(ckpt_file_name)))
    param_dict = load_checkpoint(ckpt_file_name=ckpt_file_name)
    # Weights are loaded into `network`, which `model` wraps, so eval below
    # uses the checkpointed parameters.
    load_param_into_net(network, param_dict)
    # NOTE(review): the third argument (1) is passed positionally as in the
    # original; presumably a repeat count — confirm against src.dataset.create_dataset.
    ds_eval = create_dataset(mnist_test_dir, cfg.batch_size, 1)
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== {} ==============".format(acc))




