In [13]:
import torch
from torch import nn
from d2l import torch as d2l
import torch.optim as optim

In [14]:
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

In [15]:
batch_size, lr, num_epochs = 256, 0.1, 100
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
net = net.to(device)
criterion = nn.CrossEntropyLoss()

def train_epoch(net, train_iter, loss, trainer):
    net.train()
    total_loss = 0
    for X, y in train_iter:
        X, y = X.to(device), y.to(device)  # 将数据移动到GPU上
        optimizer.zero_grad()
        y_hat = net(X)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_iter.dataset)

def evaluate_accuracy(net, data_iter):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            _, predicted = torch.max(y_hat, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    return correct / total

for epoch in range(num_epochs):
    train_loss = train_epoch(net, train_iter, loss, trainer)
    test_acc = evaluate_accuracy(net, test_iter)
    print(f'Epoch {epoch+1}, Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}')

Epoch 1, Loss: 0.0085, Test Accuracy: 0.5057
Epoch 2, Loss: 0.0062, Test Accuracy: 0.6055
Epoch 3, Loss: 0.0045, Test Accuracy: 0.6400
Epoch 4, Loss: 0.0037, Test Accuracy: 0.6690
Epoch 5, Loss: 0.0033, Test Accuracy: 0.6812
Epoch 6, Loss: 0.0031, Test Accuracy: 0.7120
Epoch 7, Loss: 0.0029, Test Accuracy: 0.7331
Epoch 8, Loss: 0.0027, Test Accuracy: 0.7471
Epoch 9, Loss: 0.0026, Test Accuracy: 0.7621
Epoch 10, Loss: 0.0025, Test Accuracy: 0.7700
Epoch 11, Loss: 0.0024, Test Accuracy: 0.7791
Epoch 12, Loss: 0.0024, Test Accuracy: 0.7863
Epoch 13, Loss: 0.0023, Test Accuracy: 0.7889
Epoch 14, Loss: 0.0022, Test Accuracy: 0.7953
Epoch 15, Loss: 0.0022, Test Accuracy: 0.8003
Epoch 16, Loss: 0.0022, Test Accuracy: 0.8020
Epoch 17, Loss: 0.0021, Test Accuracy: 0.8070
Epoch 18, Loss: 0.0021, Test Accuracy: 0.8101
Epoch 19, Loss: 0.0020, Test Accuracy: 0.8123
Epoch 20, Loss: 0.0020, Test Accuracy: 0.8124
Epoch 21, Loss: 0.0020, Test Accuracy: 0.8140
Epoch 22, Loss: 0.0020, Test Accuracy: 0.81

KeyboardInterrupt: 