In [None]:
import torch
import torch.nn
import torchvision

In [None]:
# Model definition: a small CNN classifier for 28x28 images.
class Net(torch.nn.Module):
    """Two 3x3 convolutions + one max-pool, followed by an MLP classifier head.

    Args:
        in_channels: number of input image channels (default 1, as used with
            MNIST below).
        num_classes: number of output logits (default 10).

    The flattened feature size ``128 * 14 * 14`` assumes 28x28 inputs: both
    convolutions use padding=1 so they keep the spatial size, and the
    2x2 / stride-2 max-pool halves it to 14x14.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(stride=2, kernel_size=2),  # max pooling: halves H and W
        )
        self.dense = torch.nn.Sequential(  # fully connected classifier head
            torch.nn.Linear(128 * 14 * 14, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024, num_classes),
        )

    # Forward pass
    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, in_channels, 28, 28). Flattening keeps
        the batch dimension fixed, so an input with the wrong spatial size
        fails with a shape error in the first Linear layer instead of being
        silently re-split into a different number of rows.
        """
        features = self.conv1(x)
        features = torch.flatten(features, start_dim=1)
        return self.dense(features)


In [None]:
# Load MNIST from the current directory; download=False means the dataset
# files must already be present under root='.'.
train_dataset = torchvision.datasets.MNIST(root='.',
        train=True, transform=torchvision.transforms.ToTensor(),
        download=False)         # training data
test_dataset = torchvision.datasets.MNIST(root='.',
        train=False, transform=torchvision.transforms.ToTensor(),
        download=False)         # test data

batch_size = 100

# DataLoader defaults to shuffle=False; without shuffling every epoch would see
# the exact same batches in the same order, which hurts training.
train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation order does not affect accuracy, so keep the test loader in order.
test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
SEED = 0
# Seed the global torch RNG before building the model: this fixes the weight
# init, the dropout masks and (when the loader shuffles with the default
# global RNG) the shuffle order, so the run is reproducible.
torch.manual_seed(SEED)

net = Net()     # instantiate the network
criterion = torch.nn.CrossEntropyLoss()     # cross-entropy loss on the raw logits
optimizer = torch.optim.Adam(net.parameters())  # Adam with default hyperparameters

num_epochs = 5
log_every = 100
steps_per_epoch = len(train_loader)

for epoch in range(num_epochs):
    # Make sure dropout is active, even if an evaluation cell left the model in eval mode.
    net.train()
    for idx, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        preds = net(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        step = idx + 1
        if step % log_every == 0:
            print('epoch [%d/%d], Step [%d/%d], Loss = %.4f'
                % (epoch + 1, num_epochs, step, steps_per_epoch, loss.item()))

In [None]:
# Measure classification accuracy on the held-out test set.
was_training = net.training
net.eval()                  # disable dropout so predictions are deterministic
correct = 0
total = 0

try:
    with torch.no_grad():   # inference only: don't build the autograd graph
        for images, labels in test_loader:
            outputs = net(images)
            pred = torch.argmax(outputs, 1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()
finally:
    net.train(was_training)  # restore whatever mode the model was in before

assert total > 0, 'test_loader yielded no samples'
accuracy = correct / total
print('Accuracy of the network on the %d test images: %.2f %%'
    % (total, 100 * accuracy))     # print accuracy