## 准备数据

In [16]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn, optim
import numpy as np

def mnist_dataset(batch_size=32, root='./data'):
    """Build DataLoaders over the MNIST training and test splits.

    Args:
        batch_size: Number of samples per batch for both loaders (default 32).
        root: Directory passed to ``datasets.MNIST``; the files are downloaded
            there when they are not already present (``download=True``).

    Returns:
        (train_loader, test_loader). The training loader is created with
        ``shuffle=True``; the test loader iterates in dataset order.
    """
    # Convert images to tensors, then normalize the single channel with
    # mean 0.5 and std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    train_dataset = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root=root, train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [17]:
class MyModel(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 128 -> 10 with ReLU between layers."""

    def __init__(self):
        """Declare the three linear layers that hold the model parameters."""
        super(MyModel, self).__init__()
        self.layer1 = nn.Linear(784, 256)
        self.layer2 = nn.Linear(256, 128)
        self.layer3 = nn.Linear(128, 10)

    def forward(self, x):
        """Return unnormalized logits with one row per sample and 10 columns.

        ``x`` may have any shape whose elements regroup into rows of 784
        values, e.g. (N, 1, 28, 28) or (N, 784).
        """
        # reshape (rather than view) also accepts non-contiguous inputs such as
        # transposed or sliced batches; view raises a RuntimeError on those.
        x = x.reshape(-1, 784)
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        return self.layer3(x)

# Instantiate the network, then give Adam every parameter it registered.
# No hyperparameters are passed, so the optimizer uses the library defaults.
model = MyModel()
optimizer = optim.Adam(params=model.parameters())

## 计算 loss

In [25]:
from torch.nn import functional as F

# Loss computation
def compute_loss(logits, labels):
    """Return the cross-entropy between unnormalized ``logits`` and ``labels``.

    The arguments are forwarded unchanged to ``F.cross_entropy``.
    """
    loss = F.cross_entropy(input=logits, target=labels)
    return loss

# Accuracy computation
def compute_accuracy(logits, labels):
    """Return the fraction of samples whose arg-max logit matches the label.

    Args:
        logits: Tensor of shape (N, C) with one score per class.
        labels: Either one-hot / per-class rows of shape (N, C), or class
            indices of shape (N,).

    Returns:
        A Python float in [0, 1]; 0.0 for an empty batch.
    """
    total = labels.size(0)
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    predictions = torch.argmax(logits, dim=1)
    # Labels with a class axis are reduced to indices; index labels are used as-is.
    targets = torch.argmax(labels, dim=1) if labels.dim() > 1 else labels
    correct = torch.sum(predictions == targets).item()
    return correct / total

# Single training step
def train_one_step(model, optimizer, x, y):
    """Run one optimizer update on the batch ``(x, y)``.

    Returns:
        (loss, accuracy): the tensor produced by ``compute_loss`` and the value
        produced by ``compute_accuracy``, both evaluated on the logits computed
        before the parameter update.
    """
    model.train()
    # Clear gradients left over from the previous step before the new backward pass.
    optimizer.zero_grad()
    outputs = model(x)
    batch_loss = compute_loss(outputs, y)
    batch_loss.backward()
    optimizer.step()
    return batch_loss, compute_accuracy(outputs, y)

In [19]:
# Evaluation function
def test(model, data_loader):
    """Evaluate ``model`` on every batch of ``data_loader`` and print the results.

    Prints the sample-weighted mean loss and the accuracy. Labels are compared
    to ``logits.argmax(1)`` directly, so they must be class indices. The model
    is switched to eval mode and left there. Nothing is returned.
    """
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for X, y in data_loader:
            logits = model(X)
            batch_size = X.size(0)
            # compute_loss returns a batch mean; re-weight it by the batch size
            # so a smaller final batch does not skew the average.
            total_loss += compute_loss(logits, y).item() * batch_size
            num_correct += (logits.argmax(1) == y).sum().item()
            num_seen += batch_size
    # Divide by the samples actually iterated rather than len(dataset), which
    # differs when the loader drops the last batch or uses a subset sampler.
    test_loss = total_loss / num_seen
    accuracy = num_correct / num_seen
    print(f'Test Loss: {test_loss}, Test Accuracy: {accuracy}')

## 实际训练与测试

In [28]:
train_loader, test_loader = mnist_dataset()
num_train_samples = len(train_loader.dataset)
num_test_samples = len(test_loader.dataset)
num_classes = 10  # MNIST has 10 digit classes


def _collect_batches(loader, num_classes):
    """Concatenate every batch yielded by ``loader`` into two in-memory tensors.

    Returns:
        (images, labels): images flattened to one row per sample, and labels
        one-hot encoded as float32 with ``num_classes`` columns. Rows follow the
        order in which the loader yields samples, so images and labels stay
        aligned even when the loader shuffles.
    """
    image_batches, label_batches = [], []
    for images, labels in loader:
        image_batches.append(images.reshape(images.size(0), -1))
        label_batches.append(F.one_hot(labels, num_classes=num_classes))
    # float32 targets keep the loss in single precision; float64 targets would
    # silently promote the loss to float64.
    return torch.cat(image_batches), torch.cat(label_batches).to(torch.float32)


train_images, train_labels = _collect_batches(train_loader, num_classes)
test_images, test_labels = _collect_batches(test_loader, num_classes)

# Sanity checks: every sample of each split was collected exactly once.
assert train_images.shape[0] == train_labels.shape[0] == num_train_samples
assert test_images.shape[0] == test_labels.shape[0] == num_test_samples

# Each epoch performs a single full-batch update on the entire training set.
for epoch in range(50):
    loss, accuracy = train_one_step(model, optimizer, train_images, train_labels)
    print('epoch', epoch, ': loss', loss.item(), '; accuracy', accuracy)

epoch 0 : loss 1.8548702682649096 ; accuracy 0.30428333333333335
epoch 1 : loss 1.8051024422374864 ; accuracy 0.25035
epoch 2 : loss 1.5769686650415262 ; accuracy 0.37895
epoch 3 : loss 1.366286783773328 ; accuracy 0.5770333333333333
epoch 4 : loss 1.28166658418489 ; accuracy 0.6205333333333334
epoch 5 : loss 1.2788800315156579 ; accuracy 0.6307333333333334
epoch 6 : loss 1.290646019763189 ; accuracy 0.6416833333333334
epoch 7 : loss 1.3042788376756012 ; accuracy 0.6574833333333333
epoch 8 : loss 1.3071002357546706 ; accuracy 0.6673833333333333
epoch 9 : loss 1.2713769735541505 ; accuracy 0.6734
epoch 10 : loss 1.2165665359789506 ; accuracy 0.67795
epoch 11 : loss 1.159590957188451 ; accuracy 0.6882666666666667
epoch 12 : loss 1.095249577836568 ; accuracy 0.7192
epoch 13 : loss 1.0304454223492494 ; accuracy 0.7416
epoch 14 : loss 0.9734991618441418 ; accuracy 0.7370166666666667
epoch 15 : loss 0.9307171760216355 ; accuracy 0.7363833333333333
epoch 16 : loss 0.9097373141319491 ; accurac

In [29]:
# Evaluate the model trained above on the MNIST test split; prints loss and accuracy.
test(model, test_loader)

Test Loss: 0.507794650888443, Test Accuracy: 0.8624
