In [1]:
# from torchvision import datasets, transforms  # 不推荐
import torchvision.datasets as datasets  # 推荐
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

用torch实现一个对照组

In [2]:
# 读入数据

transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(
    root='../data',
    train=True,
    download=False,
    transform=transform
)

test_dataset = datasets.MNIST(
    root='../data',
    train=False,
    download=False,
    transform=transform
)

In [3]:
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(28 * 28, 100)
        self.linear2 = nn.Linear(100, 10)
    
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.linear1(x))
        x = self.linear2(x)

        return x

In [4]:
net = Net()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1)

In [5]:
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(test_dataset, shuffle=False)

x_train = train_loader.dataset.data.float().div(255)  # type: ignore
x_train = x_train.view(-1, 28 * 28)
t_train = train_loader.dataset.targets  # type: ignore

x_test = test_loader.dataset.data.float().div(255)  # type: ignore
x_test = x_test.view(-1, 28 * 28)
t_test = test_loader.dataset.targets  # type: ignore

In [6]:
epoch_num = 16

train_acc_list = []
test_acc_list = []

# 训练开始前的预测
net.eval()
with torch.no_grad():
    # 计算训练损失
    y_train_pred = net(x_train)
    train_loss = criterion(y_train_pred, t_train).item()
    
    # 计算训练准确率
    train_preds = torch.argmax(y_train_pred, dim=1)
    train_acc = (train_preds == t_train).float().mean().item()
    
    # 计算测试准确率
    y_test_pred = net(x_test)
    test_preds = torch.argmax(y_test_pred, dim=1)
    test_acc = (test_preds == t_test).float().mean().item()

print(f"train_loss: {train_loss}")
print(f"train_acc: {train_acc}")
print(f"test_acc: {test_acc}")

for epoch in range(epoch_num):
    net.train()
    for batch_idx, (x_batch, t_batch) in enumerate(train_loader):
        # 前向传播
        y_batch = net(x_batch)
        loss = criterion(y_batch, t_batch)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    net.eval()
    with torch.no_grad():
        y_train = net(x_train)
        train_preds = torch.argmax(y_train, dim=1)
        train_acc = (train_preds == t_train).float().mean().item() * 100
        train_acc_list.append(train_acc)

        y_test = net(x_test)
        test_preds = torch.argmax(y_test, dim=1)
        test_acc = (test_preds == t_test).float().mean().item() * 100
        test_acc_list.append(test_acc)

    print(f"epoch {epoch + 1} train_acc: {train_acc:.2f}%")
    print(f"epoch {epoch + 1} test_acc: {test_acc:.2f}%")

train_loss: 2.2934887409210205
train_acc: 0.13118332624435425
test_acc: 0.12540000677108765
epoch 1 train_acc: 91.26%
epoch 1 test_acc: 91.53%
epoch 2 train_acc: 93.05%
epoch 2 test_acc: 93.03%
epoch 3 train_acc: 94.55%
epoch 3 test_acc: 94.53%
epoch 4 train_acc: 95.47%
epoch 4 test_acc: 95.46%
epoch 5 train_acc: 96.11%
epoch 5 test_acc: 95.95%
epoch 6 train_acc: 96.69%
epoch 6 test_acc: 96.33%
epoch 7 train_acc: 97.00%
epoch 7 test_acc: 96.50%
epoch 8 train_acc: 97.33%
epoch 8 test_acc: 96.73%
epoch 9 train_acc: 97.54%
epoch 9 test_acc: 96.91%
epoch 10 train_acc: 97.79%
epoch 10 test_acc: 97.23%
epoch 11 train_acc: 97.96%
epoch 11 test_acc: 97.15%
epoch 12 train_acc: 98.09%
epoch 12 test_acc: 97.30%
epoch 13 train_acc: 98.25%
epoch 13 test_acc: 97.32%
epoch 14 train_acc: 98.47%
epoch 14 test_acc: 97.35%
epoch 15 train_acc: 98.47%
epoch 15 test_acc: 97.44%
epoch 16 train_acc: 98.63%
epoch 16 test_acc: 97.48%
