In [16]:
import os
import mindspore as ms
# 导入mindspore中context模块，用于配置当前执行环境，包括执行模式等特性。
import mindspore.context as context
# c_transforms模块提供常用操作，包括OneHotOp和TypeCast
import mindspore.dataset.transforms as C
# vision.c_transforms模块是处理图像增强的高性能模块，用于数据增强图像数据改进训练模型。
import mindspore.dataset.vision as CV
import numpy as np
from mindspore import nn
from mindspore.nn import Accuracy
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
import matplotlib.pyplot as plt
from download import download
# 设置MindSpore的执行模式和设备
context.set_context(mode=context.GRAPH_MODE, device_target='CPU') # Ascend, CPU, GPU

In [17]:
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)

file_sizes: 100%|██████████████████████████| 10.8M/10.8M [00:00<00:00, 18.5MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./


In [11]:
def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32),
                   rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
    data_train = os.path.join(data_dir, 'train') # 训练集信息
    data_test = os.path.join(data_dir, 'test') # 测试集信息
    ds = ms.dataset.MnistDataset(data_train if training else data_test)
    ds = ds.map(input_columns=["image"], operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
    ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32))
    ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True)
    return ds

In [12]:
class LeNet5(nn.Cell):
    def __init__(self):
        super(LeNet5, self).__init__()
        #设置卷积网络（输入输出通道数，卷积核尺寸，步长，填充方式）
        self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
    #构建网络
    def construct(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

In [14]:
def train(data_dir, lr=0.01, momentum=0.9, num_epochs=5):
    ds_train = create_dataset(data_dir)
    ds_eval = create_dataset(data_dir, training=False)
    net = LeNet5()
    #计算softmax交叉熵。
    loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    #设置Momentum优化器
    opt = nn.Momentum(net.trainable_params(), lr, momentum)
    loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size())
    metrics = {"Accuracy": Accuracy(), "Confusion_matrix": nn.ConfusionMatrix(num_classes=10)}
    model = Model(net, loss, opt, metrics)
    model.train(num_epochs, ds_train, callbacks=[loss_cb], dataset_sink_mode=True)
    metrics_result = model.eval(ds_eval)
    res = metrics_result["Confusion_matrix"]
    print('Accuracy:',metrics_result["Accuracy"])
    print('Confusion_matrix:', res)
    return res

In [15]:
train('./MNIST_Data')

epoch: 1 step: 1875, loss is 0.0660281628370285
epoch: 2 step: 1875, loss is 0.0023815338499844074
epoch: 3 step: 1875, loss is 0.19258147478103638
epoch: 4 step: 1875, loss is 0.04538374021649361
epoch: 5 step: 1875, loss is 0.062643863260746
Accuracy: 0.98828125
Confusion_matrix: [[9.720e+02 1.000e+00 1.000e+00 0.000e+00 0.000e+00 1.000e+00 1.000e+00
  0.000e+00 0.000e+00 2.000e+00]
 [0.000e+00 1.124e+03 2.000e+00 3.000e+00 1.000e+00 1.000e+00 0.000e+00
  1.000e+00 0.000e+00 0.000e+00]
 [0.000e+00 1.000e+00 1.026e+03 0.000e+00 4.000e+00 0.000e+00 0.000e+00
  1.000e+00 0.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 3.000e+00 1.005e+03 0.000e+00 0.000e+00 0.000e+00
  1.000e+00 1.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 0.000e+00 9.730e+02 0.000e+00 2.000e+00
  0.000e+00 0.000e+00 4.000e+00]
 [1.000e+00 0.000e+00 2.000e+00 1.100e+01 1.000e+00 8.690e+02 1.000e+00
  1.000e+00 0.000e+00 6.000e+00]
 [3.000e+00 2.000e+00 1.000e+00 1.000e+00 3.000e+00 1.000e+00 9.450e+02
  0.000e+00 0.0

array([[9.720e+02, 1.000e+00, 1.000e+00, 0.000e+00, 0.000e+00, 1.000e+00,
        1.000e+00, 0.000e+00, 0.000e+00, 2.000e+00],
       [0.000e+00, 1.124e+03, 2.000e+00, 3.000e+00, 1.000e+00, 1.000e+00,
        0.000e+00, 1.000e+00, 0.000e+00, 0.000e+00],
       [0.000e+00, 1.000e+00, 1.026e+03, 0.000e+00, 4.000e+00, 0.000e+00,
        0.000e+00, 1.000e+00, 0.000e+00, 0.000e+00],
       [0.000e+00, 0.000e+00, 3.000e+00, 1.005e+03, 0.000e+00, 0.000e+00,
        0.000e+00, 1.000e+00, 1.000e+00, 0.000e+00],
       [0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 9.730e+02, 0.000e+00,
        2.000e+00, 0.000e+00, 0.000e+00, 4.000e+00],
       [1.000e+00, 0.000e+00, 2.000e+00, 1.100e+01, 1.000e+00, 8.690e+02,
        1.000e+00, 1.000e+00, 0.000e+00, 6.000e+00],
       [3.000e+00, 2.000e+00, 1.000e+00, 1.000e+00, 3.000e+00, 1.000e+00,
        9.450e+02, 0.000e+00, 0.000e+00, 0.000e+00],
       [1.000e+00, 0.000e+00, 5.000e+00, 1.000e+00, 0.000e+00, 0.000e+00,
        0.000e+00, 1.013e+03, 0.000e+