In [5]:
import torch
from torch import nn
from torchvision import transforms
import torchvision
from torch.utils import data

In [6]:
def load_data_fashion_mnist(batch_size, resize=None, num_workers=4, root="data"):  #@save
    """Download Fashion-MNIST (if needed) and build train/test data loaders.

    Args:
        batch_size: Number of samples per mini-batch, used for both loaders.
        resize: Optional size handed to ``transforms.Resize``. When falsy,
            no resize step is added to the transform pipeline.
        num_workers: Number of ``DataLoader`` worker processes. Defaults to 4,
            the value that used to be hard-coded; pass 0 to load in the main
            process.
        root: Directory the dataset is stored in / downloaded to. Defaults to
            ``"data"``, the value that used to be hard-coded.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    trans = [transforms.ToTensor()]
    if resize:
        # Placed ahead of ToTensor so the resize is applied first.
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers),
            data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=num_workers))


# Mini-batch size shared by the training and test loaders.
batch_size = 256
# Build both Fashion-MNIST loaders; no resize is requested here.
train_iter, test_iter = load_data_fashion_mnist(batch_size)

In [7]:
# Initialise the model parameters: one hidden layer between input and output.
num_inputs, num_outputs, num_hiddens = 784, 10, 256

# Weights are Gaussian noise scaled by 0.01; biases start at zero.
# nn.Parameter already marks each tensor as requiring gradients.
W1 = nn.Parameter(0.01 * torch.randn(num_inputs, num_hiddens))
b1 = nn.Parameter(torch.zeros(num_hiddens))
W2 = nn.Parameter(0.01 * torch.randn(num_hiddens, num_outputs))
b2 = nn.Parameter(torch.zeros(num_outputs))

# Order matters only for readability; the optimizer receives this list.
params = [W1, b1, W2, b2]

In [8]:
# Activation function
def relu(X):
    """Element-wise max(X, 0), computed against a zero tensor shaped like X."""
    return torch.max(X, torch.zeros_like(X))

In [13]:
# Model
def net(X):
    """Forward pass of the two-layer perceptron.

    The input is flattened to rows of ``num_inputs`` values, passed through
    the hidden affine layer followed by ``relu``, then through the output
    affine layer. The result of the output layer is returned as-is.
    """
    flat = X.reshape((-1, num_inputs))
    # @ is the matrix-multiplication operator
    hidden = relu(flat @ W1 + b1)
    logits = hidden @ W2 + b2
    return logits

In [14]:
# Loss function
# reduction="none" is passed so the loss is not reduced here;
# train_epoch_ch3 applies .mean() or .sum() to the result itself.
loss = nn.CrossEntropyLoss(reduction="none")

In [15]:
# Training hyperparameters: number of epochs and learning rate.
num_epochs, lr = 10, 0.1
# SGD optimizer over the hand-built parameter list `params`.
updater = torch.optim.SGD(params, lr=lr)

In [16]:
class Accumulator:
    """Keeps ``n`` running float sums that are updated together."""

    def __init__(self, n):
        # One slot per tracked quantity, all starting at zero.
        self.data = [0.0] * n

    def add(self, *args):
        """Add each positional value (converted to float) to its slot."""
        updated = []
        for running_total, increment in zip(self.data, args):
            updated.append(running_total + float(increment))
        self.data = updated

    def reset(self):
        """Set every slot back to zero, keeping the number of slots."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        """Return the running sum stored at position ``idx``."""
        return self.data[idx]


def accuracy(y_hat, y):
    """Count how many predictions in ``y_hat`` equal the labels ``y``.

    When ``y_hat`` has more than one dimension and more than one column,
    the argmax along axis 1 is taken as the prediction; otherwise ``y_hat``
    is compared directly. The count is returned as a Python float.
    """
    has_class_columns = len(y_hat.shape) > 1 and y_hat.shape[1] > 1
    predicted = y_hat.argmax(axis=1) if has_class_columns else y_hat
    # Cast predictions to the label dtype before comparing.
    hits = predicted.type(y.dtype) == y
    # Number of correct predictions.
    return float(hits.type(y.dtype).sum())


def train_epoch_ch3(net, train_iter, loss, updater):
    """Run one training pass over ``train_iter``.

    If ``updater`` is a ``torch.optim.Optimizer`` the gradients are zeroed,
    the mean of the loss is back-propagated and ``updater.step()`` is called.
    Otherwise the summed loss is back-propagated and ``updater`` is called
    with the batch size (``X.shape[0]``).

    Returns:
        A pair ``(loss_sum / label_count, correct_count / label_count)``
        accumulated over every batch in ``train_iter``.
    """
    if isinstance(net, torch.nn.Module):
        net.train()
    uses_torch_optimizer = isinstance(updater, torch.optim.Optimizer)
    # Slots: summed loss, number of correct predictions, number of labels.
    totals = Accumulator(3)
    for features, labels in train_iter:
        outputs = net(features)
        batch_loss = loss(outputs, labels)
        if not uses_torch_optimizer:
            batch_loss.sum().backward()
            updater(features.shape[0])
        else:
            updater.zero_grad()
            batch_loss.mean().backward()
            updater.step()
        totals.add(float(batch_loss.sum()), accuracy(outputs, labels), labels.numel())
    loss_sum, correct_count, label_count = totals[0], totals[1], totals[2]
    return loss_sum / label_count, correct_count / label_count


def evaluate_accuracy(net, data_iter):
    """Return the fraction of labels in ``data_iter`` that ``net`` predicts correctly.

    Gradient tracking is disabled for the whole pass. If ``net`` is a
    ``torch.nn.Module`` it is switched to evaluation mode first.
    """
    if isinstance(net, torch.nn.Module):
        net.eval()  # put the model into evaluation mode
    # Slots: number of correct predictions, total number of predictions.
    tally = Accumulator(2)
    with torch.no_grad():
        for features, labels in data_iter:
            outputs = net(features)
            tally.add(accuracy(outputs, labels), labels.numel())
    correct_count, total_count = tally[0], tally[1]
    return correct_count / total_count


def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    """Train ``net`` for ``num_epochs`` epochs, printing metrics after each one.

    Every epoch runs ``train_epoch_ch3`` on ``train_iter`` and then
    ``evaluate_accuracy`` on ``test_iter``; the training metrics tuple and
    the test accuracy are printed on one line. Nothing is returned.
    """
    for _ in range(num_epochs):
        epoch_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        held_out_acc = evaluate_accuracy(net, test_iter)
        print(epoch_metrics, held_out_acc)


# Train the model; each epoch prints ((train loss, train accuracy), test accuracy).
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

(1.0363145255406698, 0.6404666666666666) 0.7155
(0.5988208348592122, 0.7891166666666667) 0.7998
(0.5191916248321533, 0.8179166666666666) 0.7547
(0.47980096282958984, 0.8316666666666667) 0.8256
(0.45324630273183186, 0.84005) 0.8202
(0.4328126366933187, 0.8472) 0.8305
(0.416875963528951, 0.8536333333333334) 0.8405
(0.4052356522878011, 0.85755) 0.818
(0.3930032709757487, 0.86135) 0.8405
(0.38217025489807127, 0.8649833333333333) 0.8522
