## 准备数据

In [22]:
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torch.optim as optim

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # or any {'0', '1', '2'}

def mnist_dataset():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.MNIST(root='./data', train=True,
                                   download=True, transform=transform)
    x_train = train_dataset.data.float() / 255.0
    y_train = train_dataset.targets

    test_dataset = datasets.MNIST(root='./data', train=False,
                                  download=True, transform=transform)
    x_test = test_dataset.data.float() / 255.0
    y_test = test_dataset.targets

    return (x_train, y_train), (x_test, y_test)

In [23]:
print(list(zip([1, 2, 3, 4], ['a', 'b', 'c', 'd'])))

[(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')]


## 建立模型

In [24]:
class myModel(nn.Module):
    def __init__(self):
        super(myModel, self).__init__()
        self.W1 = nn.Parameter(torch.randn(784, 128))
        self.b1 = nn.Parameter(torch.zeros(128))
        self.W2 = nn.Parameter(torch.randn(128, 10))
        self.b2 = nn.Parameter(torch.zeros(10))

    def forward(self, x):
        x = x.view(-1, 784)
        h1 = torch.matmul(x, self.W1) + self.b1
        h1 = torch.relu(h1)
        logits = torch.matmul(h1, self.W2) + self.b2
        return logits
    
model = myModel()
optimizer = optim.Adam(model.parameters())

## 计算 loss

In [25]:
def compute_loss(logits, labels):
    criterion = nn.CrossEntropyLoss()
    return criterion(logits, labels)

def compute_accuracy(logits, labels):
    predictions = torch.argmax(logits, dim=1)
    correct = (predictions == labels).float().sum()
    return correct / len(labels)

def train_one_step(model, optimizer, x, y):
    optimizer.zero_grad()
    logits = model(x)
    loss = compute_loss(logits, y)
    loss.backward()
    optimizer.step()
    accuracy = compute_accuracy(logits, y)
    return loss, accuracy

def test(model, x, y):
    with torch.no_grad():
        logits = model(x)
        loss = compute_loss(logits, y)
        accuracy = compute_accuracy(logits, y)
    return loss, accuracy

## 实际训练

In [33]:
train_data, test_data = mnist_dataset()
for epoch in range(50):
    loss, accuracy = train_one_step(model, optimizer,
                                    train_data[0],
                                    train_data[1])
    print(f'epoch {epoch}: loss {loss.item()}; accuracy {accuracy.item()}')

loss, accuracy = test(model,
                      test_data[0],
                      test_data[1])

print(f'test loss {loss.item()}; accuracy {accuracy.item()}')

epoch 0: loss 5.3024187088012695; accuracy 0.8157166838645935
epoch 1: loss 5.288545608520508; accuracy 0.8160499930381775
epoch 2: loss 5.274740695953369; accuracy 0.8163833618164062
epoch 3: loss 5.261007308959961; accuracy 0.8167666792869568
epoch 4: loss 5.247347354888916; accuracy 0.8170666694641113
epoch 5: loss 5.233755588531494; accuracy 0.8173999786376953
epoch 6: loss 5.220232963562012; accuracy 0.8176666498184204
epoch 7: loss 5.206782341003418; accuracy 0.818149983882904
epoch 8: loss 5.193399429321289; accuracy 0.8185166716575623
epoch 9: loss 5.180081844329834; accuracy 0.8189666867256165
epoch 10: loss 5.1668291091918945; accuracy 0.8191666603088379
epoch 11: loss 5.153646945953369; accuracy 0.819516658782959
epoch 12: loss 5.140530109405518; accuracy 0.8199166655540466
epoch 13: loss 5.127477169036865; accuracy 0.8203166723251343
epoch 14: loss 5.1144843101501465; accuracy 0.8205500245094299
epoch 15: loss 5.101556777954102; accuracy 0.820900022983551
epoch 16: loss 5.0