In [3]:
# 导入必要的模块和库
import os
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, ops
from mindspore.dataset import vision, transforms
from mindspore.train import Model
from mindspore.common.initializer import Normal
from mindspore.dataset import MnistDataset

# LeNet5 network definition
class LeNet5(nn.Cell):
    """LeNet5 convolutional neural network.

    Args:
        num_class (int): Number of output classes. Default: 10.
        num_channel (int): Number of input channels. Default: 1.
        include_top (bool): Whether to build and apply the fully connected
            classifier head. When False, ``construct`` returns the feature
            map produced by the second pooling step. Default: True.
    """

    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super().__init__()
        # Feature extractor: two 5x5 'valid' convolutions sharing one ReLU
        # and one 2x2 / stride-2 max-pooling layer.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top  # whether the classifier head is used
        # The classifier head is only created when requested.
        if self.include_top:
            self.flatten = nn.Flatten()
            # 256 = 16 channels * 4 * 4 positions, which is what the
            # conv/pool stack above yields for a 28x28 input
            # (28 -> 24 -> 12 -> 8 -> 4). Other input sizes need a
            # different value here.
            self.fc1 = nn.Dense(256, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

    def construct(self, x):
        """Forward pass of the network.

        Args:
            x: Input batch fed to the first convolution.

        Returns:
            The output of ``fc3`` (one score per class) when ``include_top``
            is True, otherwise the feature map after the second pooling step.
        """
        # Two conv -> ReLU -> max-pool stages.
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Without the head, the pooled feature map is the result.
        if not self.include_top:
            return x
        # With the head: flatten, then three dense layers (no activation
        # after the last one).
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Instantiate LeNet5 with its default configuration, spelled out explicitly.
model = LeNet5(num_class=10, num_channel=1, include_top=True)

In [4]:
def datapipe(path, batch_size):
    """Build a preprocessed, batched MNIST data pipeline.

    Args:
        path (str): Directory of the dataset split to load (the train split
            and the test split live in separate directories).
        batch_size (int): Number of samples per batch.

    Returns:
        The dataset after image preprocessing, label casting and batching.
    """
    # Labels are cast to int32.
    to_int32 = transforms.TypeCast(mindspore.int32)
    # Image operations, applied in order.
    image_ops = [
        vision.Rescale(1.0 / 255.0, 0),  # multiply by 1/255, shift by 0
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),  # standardize
        vision.HWC2CHW(),  # channel-last -> channel-first layout
    ]
    # Load, transform both columns, then batch.
    return (
        MnistDataset(path)
        .map(operations=image_ops, input_columns='image')
        .map(operations=to_int32, input_columns='label')
        .batch(batch_size)
    )

# Build the training and test pipelines, both with batches of 64 samples.
train_dataset = datapipe(path='MNIST_Data/train', batch_size=64)
test_dataset = datapipe(path='MNIST_Data/test', batch_size=64)


In [5]:
# Training configuration: epoch count, learning rate, loss function and optimizer.
epochs = 10
lr = 0.01
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(params=model.trainable_params(), learning_rate=lr)

# Forward function: produces the loss plus the raw model outputs.
def forward_fn(data, label):
    """Run the model on ``data`` and score the result against ``label``.

    Returns:
        tuple: ``(loss, logits)`` where ``logits`` is the model output and
        ``loss`` is ``loss_fn`` evaluated on it.
    """
    logits = model(data)
    return loss_fn(logits, label), logits

# Wrap forward_fn so that one call returns its outputs together with the
# gradients with respect to the optimizer's parameters (no gradients with
# respect to the inputs). has_aux=True: only the first output (the loss) is
# differentiated; the logits are passed through as auxiliary data.
grad_fn = mindspore.value_and_grad(
    forward_fn,
    grad_position=None,
    weights=optimizer.parameters,
    has_aux=True,
)

# Single training step: compute gradients, then update the parameters.
def train_step(data, label):
    """Perform one optimization step on a batch and return its loss."""
    outputs, grads = grad_fn(data, label)
    optimizer(grads)
    # grad_fn yields (loss, logits) as its first result; only the loss is needed.
    return outputs[0]

# Training loop: one pass over the dataset.
def train_loop(model, dataset):
    """Train for one pass over ``dataset``, logging the loss every 100 batches.

    Args:
        model: Network to switch into training mode. The parameter update
            itself is delegated to ``train_step``.
        dataset: Dataset providing ``(data, label)`` batches.
    """
    model.set_train()  # training mode
    size = dataset.get_dataset_size()  # number of batches
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        # Only report on every 100th batch.
        if batch % 100 != 0:
            continue
        loss, current = loss.asnumpy(), batch
        print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

# 定义测试循环，用于评估模型性能
def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()  # 获取数据集大小
    model.set_train(False)  # 设置模型为评估模式
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)  # 做预测
        total += len(data)  # 累计样本数
        test_loss += loss_fn(pred, label).asnumpy()  # 累计损失
        correct += (pred.argmax(1) == label).asnumpy().sum()  # 累计正确预测数
    test_loss /= num_batches  # 计算平均损失
    correct /= total  # 计算准确率
    print(f"Test: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [6]:
# Reuse a saved checkpoint when one exists; otherwise train the model and
# save its parameters so later runs can skip training.
model_path = 'model.ckpt'
if os.path.exists(model_path):
    # Restore the trained parameters into the network.
    param_dict = mindspore.load_checkpoint(model_path)
    mindspore.load_param_into_net(model, param_dict)
    print('Loaded model from', model_path)
else:
    # Train for the configured number of epochs, evaluating on the test set
    # after each one. train_loop and test_loop hold the per-epoch logic.
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_loop(model, train_dataset)
        test_loop(model, test_dataset, loss_fn)

    print("Done!")
    # Save to the same path that the existence check above looks at.
    mindspore.save_checkpoint(model, model_path)

Loaded model from model.ckpt


In [7]:
# Predict the first 10 samples of one test batch and compare with their labels.
# The Model wrapper and the batch are loop-invariant, so they are created once
# instead of on every iteration (each iter() call builds a new dataset
# iterator, which may also deliver a different batch each time).
net = Model(model)
images, label = next(iter(test_dataset))
for i in range(10):
    # Select sample i and add a leading batch dimension for the model.
    img = images[i].unsqueeze(0)
    # Make sure the input is a float32 tensor.
    input_data = Tensor(img, mindspore.float32)
    # Run inference.
    result = net.predict(input_data)
    # The predicted class is the index of the largest output score.
    pred = ops.argmax(result, dim=1).item()
    # Report the prediction next to the ground-truth label.
    print(f'Predicted class: {pred}, actual value: {label[i]}')

Predicted class: 1, actual value: 1
Predicted class: 1, actual value: 1
Predicted class: 3, actual value: 3
Predicted class: 0, actual value: 0
Predicted class: 2, actual value: 2
Predicted class: 5, actual value: 5
Predicted class: 4, actual value: 4
Predicted class: 2, actual value: 2
Predicted class: 0, actual value: 0
Predicted class: 7, actual value: 7
