# LeNet + MNIST

LeNet-5 reference: http://yann.lecun.com/exdb/lenet/

MNIST handwritten-digit database: http://yann.lecun.com/exdb/mnist/

In [5]:
import os
from urllib.parse import urlparse
import argparse
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
import mindspore.ops.operations as P
from mindspore.common.initializer import TruncatedNormal

# import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.vision.c_transforms as CV

import mindspore.dataset.transforms.c_transforms as C

# from mindspore.dataset.transforms.vision import Inter
from mindspore.dataset.vision import Inter

from mindspore.nn.metrics import Accuracy
from mindspore.common import dtype as mstype
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits


# import mindspore.dataset.vision.py_transforms as py_vision
# from mindspore.dataset.transforms import c_transforms

def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST input pipeline used for both training and evaluation.

    Labels are cast to int32. Images are, in order: resized to 32x32
    (bilinear), scaled from [0, 255] into [0, 1], standardized with
    mean 0.1307 / std 0.3081, and transposed from HWC to CHW. The result
    is shuffled (buffer of 10000), batched with any incomplete final batch
    dropped, and repeated.

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch.
        repeat_size: How many times the batched dataset is repeated.
        num_parallel_workers: Worker count passed to every ``map`` call.

    Returns:
        The configured MindSpore dataset.
    """
    pipeline = ds.MnistDataset(data_path)

    # Change the label dtype to int32 to fit the network.
    pipeline = pipeline.map(input_columns="label",
                            operations=C.TypeCast(mstype.int32),
                            num_parallel_workers=num_parallel_workers)

    # Image transforms; the order matters (scale to [0, 1] before
    # standardizing, and transpose to channel-first last).
    mean, std = 0.1307, 0.3081
    image_ops = (
        CV.Resize((32, 32), interpolation=Inter.LINEAR),
        CV.Rescale(1.0 / 255.0, 0.0),
        CV.Rescale(1 / std, -1 * mean / std),
        CV.HWC2CHW(),
    )
    for image_op in image_ops:
        pipeline = pipeline.map(input_columns="image", operations=image_op,
                                num_parallel_workers=num_parallel_workers)

    pipeline = pipeline.shuffle(buffer_size=10000)
    pipeline = pipeline.batch(batch_size, drop_remainder=True)
    return pipeline.repeat(repeat_size)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Create a bias-free 2-D convolution with truncated-normal weights.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Explicit zero padding. With the default of 0 the layer uses
            ``pad_mode="valid"`` exactly as before. A non-zero value switches
            to ``pad_mode="pad"``, the only mode in which MindSpore honours
            an explicit ``padding`` (other modes reject a non-zero value).

    Returns:
        An ``nn.Conv2d`` layer.
    """
    weight = weight_variable()
    pad_mode = "pad" if padding else "valid"
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode=pad_mode)


def fc_with_initialize(input_channels, out_channels):
    """Return a dense layer whose weight and bias both use ``weight_variable()``.

    Args:
        input_channels: Number of input features.
        out_channels: Number of output features.
    """
    weight_init = weight_variable()
    bias_init = weight_variable()
    return nn.Dense(input_channels, out_channels, weight_init, bias_init)


def weight_variable():
    """Return the shared parameter initializer: ``TruncatedNormal`` with sigma 0.02."""
    sigma = 0.02
    return TruncatedNormal(sigma)


class LeNet5(nn.Cell):
    """LeNet-5 network.

    conv(5x5, 6) -> ReLU -> maxpool(2) -> conv(5x5, 16) -> ReLU -> maxpool(2)
    -> flatten -> fc(120) -> ReLU -> fc(84) -> ReLU -> fc(num_class)

    The first fully-connected layer expects 16 * 5 * 5 features, which
    corresponds to a 32x32 input (32 -> 28 -> 14 -> 10 -> 5 with "valid"
    convolutions and 2x2 pooling).

    Args:
        num_class: Number of output logits. Defaults to 10.
        num_channel: Number of input image channels. Defaults to 1.
    """

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # Retained for backward compatibility only; construct() no longer
        # depends on it, so the network accepts any batch size.
        self.batch_size = 32
        self.conv1 = conv(num_channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        # Retained for backward compatibility; flattening uses self.flatten.
        self.reshape = P.Reshape()
        # Flattens every dimension except the batch dimension, so the batch
        # size is taken from the input instead of being hard-coded.
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Compute class logits for a batch of images ``x`` (assumed NCHW)."""
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb):
    """Train ``model`` on the MNIST ``train`` split.

    Args:
        args: Parsed command-line arguments (not used here).
        model: MindSpore ``Model`` wrapping network, loss and optimizer.
        epoch_size: Number of epochs passed to ``model.train``.
        mnist_path: Root directory containing a ``train`` sub-directory.
        repeat_size: Repeat count forwarded to ``create_dataset``.
        ckpoint_cb: Checkpoint callback run during training.
    """
    print("============== Starting Training ==============")
    train_dir = os.path.join(mnist_path, "train")
    train_ds = create_dataset(train_dir, batch_size=32, repeat_size=repeat_size)
    # LossMonitor prints the per-step loss; ckpoint_cb saves checkpoints.
    callbacks = [ckpoint_cb, LossMonitor()]
    model.train(epoch_size, train_ds, callbacks=callbacks,
                dataset_sink_mode=False)


def test_net(args, network, model, mnist_path,
             ckpt_path="checkpoint_lenet-1_1875.ckpt"):
    """Load a checkpoint into ``network`` and evaluate it on the MNIST test split.

    Args:
        args: Parsed command-line arguments (not used here).
        network: Network whose parameters are overwritten from the checkpoint.
        model: ``Model`` wrapping ``network``; used for evaluation.
        mnist_path: Root directory containing a ``test`` sub-directory.
        ckpt_path: Checkpoint file to load. The default is the name produced
            by a ``checkpoint_lenet`` run saved at epoch 1, step 1875. Pass
            the actual file when training for a different number of epochs
            or steps, otherwise a stale or missing checkpoint is used.

    Returns:
        The metric results returned by ``model.eval``.
    """
    print("============== Starting Testing ==============")
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(network, param_dict)
    ds_eval = create_dataset(os.path.join(mnist_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))
    return acc


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: Ascend)')

    # Inside a Jupyter notebook sys.argv belongs to the kernel, so the device
    # is passed explicitly. When running as a script use:
    #     args = parser.parse_args()
    args = parser.parse_args(args=['--device_target', 'CPU'])

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    # Optimizer and training hyper-parameters.
    lr = 0.01
    momentum = 0.9
    epoch_size = 1
    mnist_path = "./MNIST"
    # sparse=True: labels are int32 class indices (see create_dataset),
    # not one-hot vectors.
    net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    # model.train() already iterates epoch_size epochs over the dataset, so
    # the dataset must not also be repeated epoch_size times; otherwise each
    # epoch would contain epoch_size copies of the data.
    repeat_size = 1
    # create the network
    network = LeNet5()
    # define the optimizer
    net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
    # 1875 steps is presumably one epoch (60000 training images / batch 32).
    config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
    # save the network model and parameters for subsequent fine-tuning
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    # group layers into an object with training and evaluation features
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    train_net(args, model, epoch_size, mnist_path, repeat_size, ckpoint_cb)
    test_net(args, network, model, mnist_path)



epoch: 1 step: 1, loss is 2.3020830154418945
epoch: 1 step: 2, loss is 2.297511100769043
epoch: 1 step: 3, loss is 2.3021178245544434
epoch: 1 step: 4, loss is 2.2974588871002197
epoch: 1 step: 5, loss is 2.2943406105041504
epoch: 1 step: 6, loss is 2.306159734725952
epoch: 1 step: 7, loss is 2.3057701587677
epoch: 1 step: 8, loss is 2.3064053058624268
epoch: 1 step: 9, loss is 2.3006913661956787
epoch: 1 step: 10, loss is 2.312077522277832
epoch: 1 step: 11, loss is 2.296149969100952
epoch: 1 step: 12, loss is 2.3007116317749023
epoch: 1 step: 13, loss is 2.3053767681121826
epoch: 1 step: 14, loss is 2.3007004261016846
epoch: 1 step: 15, loss is 2.304736614227295
epoch: 1 step: 16, loss is 2.3093299865722656
epoch: 1 step: 17, loss is 2.2995779514312744
epoch: 1 step: 18, loss is 2.2984941005706787
epoch: 1 step: 19, loss is 2.293814182281494
epoch: 1 step: 20, loss is 2.3049755096435547
epoch: 1 step: 21, loss is 2.300083875656128
epoch: 1 step: 22, loss is 2.3048899173736572
epoch: 



