In [None]:
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import torch.nn.functional as F
import numpy as np

learning_rate = 1e-4  # Step size handed to the optimizer
keep_prob_rate = 0.7  # Intended probability of keeping a unit in the dropout layer
max_epoch = 3  # Number of full passes over the training set
BATCH_SIZE = 50  # Number of samples per mini-batch

# Download MNIST only when the local dataset folder is missing or empty.
DOWNLOAD_MNIST = False
if not os.path.exists('./mnist/') or not os.listdir('./mnist/'):
    DOWNLOAD_MNIST = True

# Training set, served in shuffled mini-batches of BATCH_SIZE samples.
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# Test set: only the first 500 samples are used for the periodic accuracy check.
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# `data` / `targets` replace the deprecated `test_data` / `test_labels`
# accessors, and the `Variable(..., volatile=True)` wrapper is dropped because
# `volatile` no longer has any effect. Slicing before the float conversion
# avoids converting images that are never used. No transform is applied to
# `.data`, so the images are scaled by 1/255 by hand, and a channel axis is
# inserted at dim 1 to match the input layout of the convolution layers.
test_x = torch.unsqueeze(test_data.data[:500], dim=1).float() / 255.
test_y = test_data.targets[:500].numpy()

class CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel 28x28 images.

    Architecture: conv(7x7, 32) -> ReLU -> maxpool(2) -> conv(5x5, 64) -> ReLU
    -> maxpool(2) -> linear(7*7*64 -> 1024) -> ReLU -> dropout -> linear(1024 -> 10).

    Args:
        keep_prob: Probability of keeping a unit in the dropout layer. When
            omitted, the module-level ``keep_prob_rate`` is used.
    """

    def __init__(self, keep_prob=None):
        super().__init__()
        if keep_prob is None:
            keep_prob = keep_prob_rate

        # Both convolutions use "same" padding, so only the two 2x2 poolings
        # shrink the spatial size: 28 -> 14 -> 7.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,  # grayscale input
                out_channels=32,
                kernel_size=7,
                stride=1,
                padding=3,
            ),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.out1 = nn.Linear(7 * 7 * 64, 1024, bias=True)  # hidden fully connected layer

        # nn.Dropout expects the probability of *zeroing* a unit, so the keep
        # probability has to be converted before it is passed in.
        self.dropout = nn.Dropout(1.0 - keep_prob)
        self.out2 = nn.Linear(1024, 10, bias=True)  # one output per class

    def forward(self, x):
        """Return unnormalized class scores (logits) of shape (batch, 10).

        No softmax is applied here: nn.CrossEntropyLoss applies log-softmax
        internally, and the arg-max over the scores is the same with or
        without softmax. Call ``F.softmax(scores, dim=1)`` on the result when
        probabilities are needed.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 7*7*64)
        hidden = self.dropout(F.relu(self.out1(x)))
        return self.out2(hidden)


def test(cnn):
    global prediction
    y_pre = cnn(test_x)  # 获取模型的预测输出
    _, pre_index = torch.max(y_pre, 1)  # 获取预测结果的最大值索引
    pre_index = pre_index.view(-1)  # 转换成一维
    prediction = pre_index.data.numpy()  # 转换为numpy数组
    correct = np.sum(prediction == test_y)  # 计算正确预测的数量
    return correct / 500.0  # 返回测试准确率


def train(cnn, loader=None, epochs=None, lr=None):
    """Train ``cnn`` in place with Adam and cross-entropy loss.

    Every 20th step (except step 0) the accuracy reported by ``test(cnn)`` is
    printed.

    Args:
        cnn: Model to optimize; its parameters are updated in place.
        loader: Iterable of ``(inputs, targets)`` batches. Defaults to the
            module-level ``train_loader``.
        epochs: Number of passes over ``loader``. Defaults to ``max_epoch``.
        lr: Adam learning rate. Defaults to ``learning_rate``.
    """
    loader = train_loader if loader is None else loader
    epochs = max_epoch if epochs is None else epochs
    lr = learning_rate if lr is None else lr

    optimizer = torch.optim.Adam(cnn.parameters(), lr=lr)
    # nn.CrossEntropyLoss applies log-softmax itself, so it expects the model
    # to return unnormalized scores.
    loss_func = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for step, (x, y) in enumerate(loader):
            # Re-assert training mode on every batch so dropout is active even
            # if an evaluation pass left the model in eval mode.
            cnn.train()
            optimizer.zero_grad()
            loss = loss_func(cnn(x), y)
            loss.backward()
            optimizer.step()

            if step != 0 and step % 20 == 0:
                print("=" * 10, step, "=" * 5, "=" * 5, "test accuracy is ", test(cnn), "=" * 10)


if __name__ == "__main__":
    # Script entry point: build a fresh network, then run the training loop on it.
    cnn = CNN()
    train(cnn)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:11<00:00, 882kB/s] 


Extracting ./mnist/MNIST\raw\train-images-idx3-ubyte.gz to ./mnist/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./mnist/MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:01<00:00, 17.2kB/s]


Extracting ./mnist/MNIST\raw\train-labels-idx1-ubyte.gz to ./mnist/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist/MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:07<00:00, 233kB/s] 


Extracting ./mnist/MNIST\raw\t10k-images-idx3-ubyte.gz to ./mnist/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 3.10MB/s]
  test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:500] / 255.


Extracting ./mnist/MNIST\raw\t10k-labels-idx1-ubyte.gz to ./mnist/MNIST\raw

