## 准备数据

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms

# Set an environment variable: TF_CPP_MIN_LOG_LEVEL is presumably TensorFlow's
# C++ log-verbosity setting.
# NOTE(review): this notebook only uses PyTorch, so this looks like a leftover
# from a TensorFlow version and likely has no effect — confirm and remove.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')

def mnist_dataset(root='./data'):
    """Load the MNIST train and test splits as torchvision datasets.

    Every sample's image is passed through ``transforms.ToTensor()``. The files
    are downloaded into ``root`` when they are not already there
    (``download=True``).

    Args:
        root: directory that holds (or will receive) the MNIST files.
            Defaults to ``'./data'``, the path previously hard-coded here.

    Returns:
        ``(train_data, test_data)``: the ``train=True`` and ``train=False``
        ``datasets.MNIST`` splits.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_data = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root=root, train=False, download=True, transform=transform)

    return train_data, test_data

## 建立模型

In [2]:
class myModel(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear -> LogSoftmax."""

    def __init__(self, input_dim=784, output_dim=10, hidden_dim=2048):
        """Declare the model's layers.

        Args:
            input_dim: number of input features per sample (784 for a
                flattened 28x28 MNIST image).
            output_dim: number of classes.
            hidden_dim: width of the single hidden layer.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            # Normalise over the class dimension: the model outputs
            # log-probabilities, not raw logits.
            nn.LogSoftmax(1),
        )

    def forward(self, x):
        """Return per-class log-probabilities (NOT unnormalized logits).

        Args:
            x: batch of shape (N, input_dim). Inputs with more than two
                dimensions (e.g. (N, 1, 28, 28) image batches) are flattened
                to (N, -1) first; 2-D inputs are used unchanged.

        Returns:
            Tensor of shape (N, output_dim) holding log-softmax values, so
            ``exp`` of each row sums to 1.
        """
        if x.dim() > 2:
            x = x.flatten(1)
        return self.net(x)
        
# Seed torch's RNG before building the model so the weight initialisation
# (and therefore the logged training curve) is reproducible on Restart & Run All.
SEED = 0
# Same hyperparameter values as before, now named.
# NOTE(review): 2e-5 is small for full-batch Adam. In the logged run the loss
# only falls from ~2.31 to ~2.20 over the first 20 epochs; consider tuning.
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 1e-9  # NOTE(review): this small a decay is presumably negligible

torch.manual_seed(SEED)
model = myModel()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

## 计算 loss 与准确率，定义训练/测试步骤

In [3]:
def compute_loss(logits, labels):
    """Mean negative log-likelihood of the true classes.

    Despite the parameter name, ``logits`` must already be log-probabilities
    (e.g. the LogSoftmax output of ``myModel``); no softmax is applied here.

    Args:
        logits: (N, C) tensor of log-probabilities.
        labels: either an (N, C) one-hot/score tensor, where the true class is
            the argmax of each row, or an (N,) tensor of class indices.

    Returns:
        0-dim tensor equal to ``-mean_i(logits[i, label_i])``, differentiable
        with respect to ``logits``.
    """
    label = labels if labels.dim() == 1 else torch.argmax(labels, dim=1)
    # nll_loss with its default 'mean' reduction computes exactly
    # -mean(logits[i, label[i]]); it requires int64 targets.
    return nn.functional.nll_loss(logits, label.long())

def compute_accuracy(logits, labels):
    """Fraction of samples whose predicted class equals the true class.

    Args:
        logits: (N, C) scores or log-probabilities; the prediction is the
            argmax of each row.
        labels: either an (N, C) one-hot/score tensor, where the true class is
            the argmax of each row, or an (N,) tensor of class indices.

    Returns:
        Python float in [0, 1].

    Raises:
        ZeroDivisionError: if the batch is empty (N == 0).
    """
    predictions = torch.argmax(logits, dim=1)
    label = labels if labels.dim() == 1 else torch.argmax(labels, dim=1)
    correct = torch.sum(predictions == label).item()
    total = labels.size(0)
    return correct / total


def train_one_step(model, optimizer, x, y):
    """Run one optimisation step of ``model`` on the batch ``(x, y)``.

    Args:
        model: module whose output is fed to ``compute_loss``.
        optimizer: optimiser over ``model``'s parameters.
        x: input batch.
        y: labels in a format ``compute_loss``/``compute_accuracy`` accept.

    Returns:
        ``(loss, accuracy)``. ``loss`` is a 0-dim tensor that still requires
        grad, so call ``.detach()`` or ``.item()`` before converting it.
        ``accuracy`` is a Python float, not a tensor. Both are measured on the
        forward pass made *before* this step's parameter update.
    """
    model.train()
    optimizer.zero_grad()  # clear gradients accumulated by the previous step

    logits = model(x)
    loss = compute_loss(logits, y)
    loss.backward()
    optimizer.step()

    accuracy = compute_accuracy(logits, y)
    return loss, accuracy


def test(model, x, y):
    """Evaluate ``model`` on ``(x, y)`` without updating it.

    Args:
        model: module whose output is fed to ``compute_loss``.
        x: input batch.
        y: labels in a format ``compute_loss``/``compute_accuracy`` accept.

    Returns:
        ``(loss, accuracy)``: a 0-dim loss tensor (no autograd history) and
        the accuracy as a Python float.
    """
    model.eval()
    # No gradients are needed for evaluation. Running under no_grad avoids
    # building an autograd graph (and keeping its activations alive) for the
    # whole evaluation set.
    with torch.no_grad():
        logits = model(x)
        loss = compute_loss(logits, y)
    accuracy = compute_accuracy(logits, y)
    return loss, accuracy

## 实际训练与测试

In [4]:
def _to_tensor_and_one_hot(dataset, num_classes=10):
    """Materialise a map-style ``(image, label)`` dataset in a single pass.

    Args:
        dataset: iterable of ``(image_tensor, int_label)`` pairs.
        num_classes: width of the one-hot label matrix.

    Returns:
        ``(images, one_hot)``. ``images`` is every image stacked along a new
        leading dimension. ``one_hot`` is a float64 NumPy array of shape
        ``(len(dataset), num_classes)`` with 1. at each sample's label.
    """
    images, labels = [], []
    # Iterate once: each access runs the dataset's transform, so the previous
    # two-loop version converted every image twice.
    for image, label in dataset:
        images.append(image)
        labels.append(label)
    one_hot = np.zeros(shape=[len(labels), num_classes])
    one_hot[np.arange(len(labels)), labels] = 1.
    return torch.stack(images), one_hot


train_data, test_data = mnist_dataset()
train_data, train_label = _to_tensor_and_one_hot(train_data)
test_data, test_label = _to_tensor_and_one_hot(test_data)

In [5]:
# Flatten each image (1x28x28 for MNIST after ToTensor) into one feature
# vector per sample. The sample count comes from the tensor itself instead of
# the hard-coded 60000 / 10000, so a subset of the data also works.
train_data = train_data.reshape(train_data.shape[0], -1)
test_data = test_data.reshape(test_data.shape[0], -1)

In [6]:
# torch.as_tensor wraps the float64 NumPy one-hot arrays without an extra copy
# and returns an existing tensor unchanged. Re-running this cell is therefore
# harmless, whereas torch.tensor(tensor) would copy again and emit a UserWarning.
train_label = torch.as_tensor(train_label)
test_label = torch.as_tensor(test_label)

In [7]:
# Full-batch training: each "epoch" is a single optimizer step on the entire
# training set (train_one_step receives all of train_data at once).
NUM_EPOCHS = 50
for epoch in range(NUM_EPOCHS):
    loss, accuracy = train_one_step(model, optimizer, train_data, train_label)
    print('epoch', epoch, ': loss', loss.detach().numpy(), '; accuracy', accuracy)
loss, accuracy = test(model, test_data, test_label)

print('test loss', loss.detach().numpy(), '; accuracy', accuracy)

epoch 0 : loss 2.3108542 ; accuracy 0.04985
epoch 1 : loss 2.3047106 ; accuracy 0.06091666666666667
epoch 2 : loss 2.298585 ; accuracy 0.07336666666666666
epoch 3 : loss 2.292476 ; accuracy 0.08945
epoch 4 : loss 2.2863839 ; accuracy 0.1079
epoch 5 : loss 2.2803078 ; accuracy 0.12588333333333335
epoch 6 : loss 2.2742474 ; accuracy 0.14758333333333334
epoch 7 : loss 2.268202 ; accuracy 0.17235
epoch 8 : loss 2.2621713 ; accuracy 0.19915
epoch 9 : loss 2.256154 ; accuracy 0.22688333333333333
epoch 10 : loss 2.2501507 ; accuracy 0.2553166666666667
epoch 11 : loss 2.24416 ; accuracy 0.28386666666666666
epoch 12 : loss 2.2381806 ; accuracy 0.31315
epoch 13 : loss 2.2322125 ; accuracy 0.3435
epoch 14 : loss 2.2262542 ; accuracy 0.37201666666666666
epoch 15 : loss 2.2203045 ; accuracy 0.3993333333333333
epoch 16 : loss 2.2143629 ; accuracy 0.4252666666666667
epoch 17 : loss 2.208428 ; accuracy 0.44953333333333334
epoch 18 : loss 2.2024987 ; accuracy 0.4741666666666667
epoch 19 : loss 2.196574