In [29]:
# Build the datasets and data loaders
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np

# Seed torch's global RNG so DataLoader shuffling (which draws from it) and the
# weight initialisation of any model constructed afterwards are repeatable.
SEED = 0
torch.manual_seed(SEED)

DATA_ROOT = './dataset/mnist/'
BATCH_SIZE = 64

# ToTensor converts images to float tensors; Normalize then standardises them
# with (0.1307, 0.3081), the mean/std commonly used for MNIST (as in the
# official PyTorch MNIST example).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# download=True only fetches the files when they are not already present under
# DATA_ROOT, so a fresh environment can Restart & Run All without a manual step.
train_data = datasets.MNIST(root=DATA_ROOT,
                            train=True,
                            download=True,
                            transform=transform)

# Shuffle the training set each epoch; keep the test order fixed.
train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

test_data = datasets.MNIST(root=DATA_ROOT,
                           train=False,
                           download=True,
                           transform=transform)

test_dataloader = DataLoader(dataset=test_data,
                             batch_size=BATCH_SIZE,
                             shuffle=False)

In [30]:
# Build the model
import torch.nn.functional as F
class Net(torch.nn.Module):
    """Fully connected classifier for 28x28 MNIST images.

    Every input is flattened to 784 features and passed through four hidden
    Linear+ReLU layers (512 -> 256 -> 128 -> 64) followed by a final Linear
    layer producing 10 raw logits (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Attribute names l1..l5 and their creation order are kept as-is: they
        # determine state_dict keys and the order in which init draws from the RNG.
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Map a batch of images to class logits of shape (N, 10)."""
        hidden = x.view(-1, 784)
        for layer in (self.l1, self.l2, self.l3, self.l4):
            hidden = F.relu(layer(hidden))
        # The output layer has no activation: logits go straight to the loss.
        return self.l5(hidden)

# Single global model instance: the optimizer is built from its parameters,
# and train() / test() run it on the training and test loaders.
model = Net()

In [31]:
# Loss function and optimizer.
# Plain SGD with momentum over every parameter of the global model.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.5,
)
# CrossEntropyLoss applies log-softmax internally, so it consumes the raw
# logits produced by the model's final Linear layer.
criterion = torch.nn.CrossEntropyLoss()

In [32]:
# Training and evaluation
def train(net=None, loader=None, loss_fn=None, opt=None):
    """Run one training epoch and return the mean per-sample loss.

    Every argument is optional. When omitted, it falls back to the notebook
    global with the same role (``model``, ``train_dataloader``, ``criterion``,
    ``optimizer``). The fallback is resolved at call time, so re-running an
    earlier cell is picked up. Passing the arguments explicitly lets the
    function be reused with other models or data, and tested.

    Args:
        net: module to train; switched to training mode before the epoch.
        loader: iterable of ``(inputs, labels)`` batches.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        opt: optimizer over ``net``'s parameters.

    Returns:
        Mean loss over all samples seen this epoch, weighted by batch size.
        Returns ``float('nan')`` if ``loader`` yields no batches.
    """
    net = model if net is None else net
    loader = train_dataloader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt

    # Explicitly enter training mode in case the model was left in eval mode;
    # layers like dropout/batch-norm behave differently per mode.
    net.train()

    loss_sum = 0.0
    n_samples = 0
    for inputs, labels in loader:
        outputs = net(inputs)
        loss = loss_fn(outputs, labels)

        opt.zero_grad()
        loss.backward()
        opt.step()

        batch_size = labels.size(0)
        # Assumes loss_fn averages over the batch (CrossEntropyLoss's default);
        # re-weight so the epoch figure is a per-sample mean.
        loss_sum += loss.item() * batch_size
        n_samples += batch_size

    return loss_sum / n_samples if n_samples else float("nan")

def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for idx, data in enumerate(test_dataloader, 0):
            inputs, labels = data

            outputs = model(inputs)
            _, pred = torch.max(outputs.data, dim=1)
            correct += (pred == labels).sum().item()
            total += pred.size(0)
        print("correct acc is %lf" % (correct / total))

In [33]:
# Train for N_EPOCHS epochs and report test-set accuracy every EVAL_EVERY
# epochs (after epochs 10, 20, ..., 100).
# In the recorded run below, accuracy plateaued near 0.985 after ~30 epochs,
# so N_EPOCHS can likely be reduced without losing accuracy.
N_EPOCHS = 100
EVAL_EVERY = 10

for epoch in range(N_EPOCHS):
    train()
    if (epoch + 1) % EVAL_EVERY == 0:
        test()

correct acc is 0.979400
correct acc is 0.983000
correct acc is 0.984600
correct acc is 0.984600
correct acc is 0.984600
correct acc is 0.984700
correct acc is 0.984800
correct acc is 0.984900
correct acc is 0.984900
correct acc is 0.984800
