# 卷积神经网络

In [1]:
import torch
import torch.utils.data
import torch.nn
import torch.optim
import torchvision.datasets
import torchvision.transforms

# Data loading: MNIST train/test splits, downloaded to ./data/mnist if
# missing, with each image converted to a tensor by ToTensor().
train_dataset = torchvision.datasets.MNIST(root='./data/mnist',
        train=True, transform=torchvision.transforms.ToTensor(),
        download=True)
test_dataset = torchvision.datasets.MNIST(root='./data/mnist',
        train=False, transform=torchvision.transforms.ToTensor(),
        download=True)

batch_size = 100
# Reshuffle the training set every epoch so the optimizer does not see the
# same mini-batches in the same order each time.
train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation order does not affect the accuracy, so no shuffling here.
test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=batch_size)

# 搭建网络结构
class Net(torch.nn.Module):
    """Small convolutional classifier producing 10 scores per image.

    The layer sizes are fixed for single-channel 28x28 inputs: both 3x3
    convolutions use padding=1 (spatial size preserved) and the 2x2
    max-pool with stride 2 halves it, giving 128 feature maps of 14x14
    that feed the fully connected head.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: conv -> ReLU -> conv -> ReLU -> max-pool.
        self.conv1 = torch.nn.Sequential(
                torch.nn.Conv2d(1, 64, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(stride=2, kernel_size=2))
        # Classifier head: hidden layer with dropout, then 10 output scores.
        self.dense = torch.nn.Sequential(
                torch.nn.Linear(128 * 14 * 14, 1024),
                torch.nn.ReLU(),
                torch.nn.Dropout(p=0.5),
                torch.nn.Linear(1024, 10))

    def forward(self, x):
        """Return unnormalised class scores for a batch of images.

        Args:
            x: Tensor of shape (N, 1, 28, 28). A single unbatched image of
                shape (1, 28, 28) is also accepted and treated as a batch
                of one.

        Returns:
            Tensor of shape (N, 10).

        Raises:
            RuntimeError: If the spatial size does not produce the
                128 * 14 * 14 features the first linear layer expects.
        """
        if x.dim() == 3:
            # Add the batch axis so an unbatched image still yields one row.
            x = x.unsqueeze(0)
        x = self.conv1(x)
        # Flatten everything except the batch axis. Unlike
        # view(-1, 128 * 14 * 14), this can never fold a spatial-size
        # mismatch into extra batch rows; the linear layer rejects it instead.
        x = torch.flatten(x, start_dim=1)
        x = self.dense(x)
        return x

# Model instance, classification loss, and an Adam optimizer (default
# hyper-parameters) over all of the model's parameters.
net = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters())

# Training: run num_epochs passes over train_loader, taking one optimizer
# step per mini-batch and logging the loss every 100th batch.
num_epochs = 5
for epoch in range(num_epochs):
    for idx, (images, labels) in enumerate(train_loader):
        # Forward pass and loss for this mini-batch.
        preds = net(images)
        loss = criterion(preds, labels)

        # Clear gradients left from the previous step, then backpropagate
        # and update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 100 != 0:
            continue
        print('epoch {}, batch {}, 损失 = {:g}'.format(
                epoch, idx, loss.item()))

# Evaluation: measure classification accuracy over test_loader.
# Switch to eval mode so Dropout is disabled; otherwise half of the hidden
# activations are randomly zeroed and the reported accuracy is lower and
# varies from run to run. The model stays in eval mode afterwards.
net.eval()
correct = 0
total = 0
# No gradients are needed for inference, so skip building autograd graphs.
with torch.no_grad():
    for images, labels in test_loader:
        preds = net(images)
        # The predicted class is the index of the highest score per sample.
        predicted = torch.argmax(preds, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print('测试数据准确率: {:.1%}'.format(accuracy))

epoch 0, batch 0, 损失 = 2.30691
epoch 0, batch 100, 损失 = 0.196558
epoch 0, batch 200, 损失 = 0.246614
epoch 0, batch 300, 损失 = 0.0901112
epoch 0, batch 400, 损失 = 0.0661785
epoch 0, batch 500, 损失 = 0.019158
epoch 1, batch 0, 损失 = 0.115891
epoch 1, batch 100, 损失 = 0.0850154
epoch 1, batch 200, 损失 = 0.032694
epoch 1, batch 300, 损失 = 0.0368617
epoch 1, batch 400, 损失 = 0.0111592
epoch 1, batch 500, 损失 = 0.0236338
epoch 2, batch 0, 损失 = 0.105644
epoch 2, batch 100, 损失 = 0.0292429
epoch 2, batch 200, 损失 = 0.0154538
epoch 2, batch 300, 损失 = 0.0861663
epoch 2, batch 400, 损失 = 0.00821032
epoch 2, batch 500, 损失 = 0.00200433
epoch 3, batch 0, 损失 = 0.0112382
epoch 3, batch 100, 损失 = 0.0130245
epoch 3, batch 200, 损失 = 0.0235564
epoch 3, batch 300, 损失 = 0.0563519
epoch 3, batch 400, 损失 = 0.00335572
epoch 3, batch 500, 损失 = 0.00445928
epoch 4, batch 0, 损失 = 0.0196497
epoch 4, batch 100, 损失 = 0.00330122
epoch 4, batch 200, 损失 = 0.0172663
epoch 4, batch 300, 损失 = 0.00805396
epoch 4, batch 400, 损失 = 0.00689