[3.10 多层感知机的简洁实现](https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.10_mlp-pytorch)

In [1]:
lr = 0.5
batch_size = 256
num_workers = 4
num_epochs = 10

In [2]:
import numpy as np
import time
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [3]:
num_inputs = 784
num_hiddens = 256
num_outputs = 10

In [4]:
class Perceptron(nn.Module):
    def __init__(self):
        super(Perceptron, self).__init__()
        self.linear1 = nn.Linear(num_inputs, num_hiddens)
        self.linear2 = nn.Linear(num_hiddens, num_outputs)
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, x):
        '''
        x: (b, h * w)
        '''
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x
        

In [5]:
loss = torch.nn.CrossEntropyLoss()
net = Perceptron()
optimizer = torch.optim.SGD(net.parameters(), lr=lr)

In [6]:
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
print(len(mnist_train), len(mnist_test))
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

60000 10000


In [7]:
start = time.time()
for X, y in train_iter:
    continue
end = time.time()
print('%.2fs' % (end - start))

2.57s


In [8]:
for i in range(num_epochs):
    print('epoch:', i + 1)
    loss_sum = 0.0
    train_sum = 0
    correct_sum = 0
    for X, y in train_iter:
        y_hat = net(X.view(-1, num_inputs))
        # print(y_hat.shape)
        # print(y.shape)
        l = loss(y_hat, y)
        net.zero_grad()
        l.backward()
        optimizer.step()
        loss_sum += l.item()
        train_sum += y.shape[0]
        correct_sum += (y == y_hat.argmax(dim=1)).sum().item()
    print(f'train loss: %.4f' % (loss_sum / train_sum))
    print(f'train acc: %.2f' % (correct_sum / train_sum))
    loss_sum = 0.0
    test_sum = 0
    correct_sum = 0
    for X, y in test_iter:
        y_hat = net(X.view(-1, num_inputs))
        l = loss(y_hat, y)
        loss_sum += l.item()
        test_sum += y.shape[0]
        correct_sum += (y == y_hat.argmax(dim=1)).sum().item()
    print(f'test loss: %.4f' % (loss_sum / test_sum))
    print(f'test acc: %.2f' % (correct_sum / test_sum))
    

epoch: 1
train loss: 0.0071
train acc: 0.68
test loss: 0.0068
test acc: 0.77
epoch: 2
train loss: 0.0066
train acc: 0.79
test loss: 0.0067
test acc: 0.79
epoch: 3
train loss: 0.0065
train acc: 0.80
test loss: 0.0067
test acc: 0.80
epoch: 4
train loss: 0.0065
train acc: 0.80
test loss: 0.0067
test acc: 0.79
epoch: 5
train loss: 0.0064
train acc: 0.83
test loss: 0.0066
test acc: 0.83
epoch: 6
train loss: 0.0064
train acc: 0.84
test loss: 0.0065
test acc: 0.83
epoch: 7
train loss: 0.0064
train acc: 0.84
test loss: 0.0065
test acc: 0.83
epoch: 8
train loss: 0.0063
train acc: 0.85
test loss: 0.0065
test acc: 0.83
epoch: 9
train loss: 0.0063
train acc: 0.85
test loss: 0.0067
test acc: 0.80
epoch: 10
train loss: 0.0063
train acc: 0.85
test loss: 0.0066
test acc: 0.83
