In [127]:
import torch
from tool import load_data_fashion_mnist

# 准备数据

In [128]:
# Mini-batch size handed to the data loader.
batch_size = 256
# `load_data_fashion_mnist` is imported from the local `tool` module; presumably
# it returns a (train, test) pair of batch iterables — TODO confirm in tool.py.
train_iter, test_iter = load_data_fashion_mnist(batch_size)
# Sanity check: display the shape of the second element of the first training
# batch (presumably the label tensor, one entry per example — verify).
next(iter(train_iter))[1].shape

torch.Size([256])

In [129]:
def softmax(X):
    """Row-wise softmax of a 2-D tensor of logits.

    Args:
        X: tensor whose dimension 1 holds the per-class scores of one example.

    Returns:
        Tensor of the same shape as ``X`` in which every row is non-negative
        and sums to 1.
    """
    # Softmax is invariant to adding a constant to every entry of a row, so
    # subtracting the row maximum changes nothing mathematically. It does keep
    # exp() from overflowing to inf (and the division from producing NaN) when
    # the logits are large.
    row_max = X.max(dim=1, keepdim=True).values
    X_exp = torch.exp(X - row_max)
    X_sum = X_exp.sum(dim=1, keepdim=True)
    return X_exp / X_sum


In [130]:
# Trainable weights: one column per class, one row per pixel of a flattened
# 28x28 image (28 * 28 = 784). Drawn from a standard normal distribution.
w = torch.normal(mean=0.0, std=1.0, size=(784, 10), requires_grad=True)
# Trainable bias: one zero-initialised entry per output column of `w`.
b = torch.zeros(w.shape[1], requires_grad=True)

In [131]:
def net(X):
    """Softmax-regression forward pass using the notebook-level `w` and `b`.

    Args:
        X: input batch; it is reshaped to ``(-1, w.shape[0])`` so that each
            row lines up with the rows of `w`.

    Returns:
        The output of `softmax` applied to ``X_flat @ w + b``.
    """
    num_inputs = w.shape[0]
    flattened = X.reshape(-1, num_inputs)
    logits = flattened @ w + b
    return softmax(logits)

In [132]:
def cross_entropy(y_hat, y):
    """Per-example cross-entropy loss.

    Args:
        y_hat: tensor indexed as ``y_hat[i, y[i]]``, i.e. one row of predicted
            values per example.
        y: integer class index for every row of `y_hat`.

    Returns:
        Tensor with one entry per example: the negative log of the value
        `y_hat` holds at that example's class index (no reduction applied).
    """
    row_ids = range(len(y_hat))
    # Pick, for every row, the entry that belongs to the labelled class.
    picked = y_hat[row_ids, y]
    return -picked.log()

In [133]:
def SGD(params, lr, batch_size):
    """Apply one in-place mini-batch SGD step and clear the gradients.

    Args:
        params: iterable of tensors to update.
        lr: learning rate.
        batch_size: divisor applied to each gradient before the update.

    Parameters whose ``.grad`` is ``None`` (they took no part in the last
    backward pass) are left untouched instead of raising a ``TypeError``.
    """
    # The update itself must not be recorded by autograd.
    with torch.no_grad():
        for param in params:
            grad = param.grad
            if grad is None:
                continue
            param -= lr * grad / batch_size
            # Reset so the next backward() starts from zero instead of accumulating.
            grad.zero_()

In [134]:
def accuracy(y_hat, y):
    """Count how many predictions match the labels.

    Args:
        y_hat: either a tensor with more than one dimension and more than one
            column (reduced with ``argmax`` over dimension 1), or a tensor
            that is compared to `y` directly.
        y: label tensor.

    Returns:
        The number of matching entries, as a Python float.
    """
    has_class_columns = y_hat.dim() > 1 and y_hat.shape[1] > 1
    predicted = y_hat.argmax(dim=1) if has_class_columns else y_hat
    # Compare in the label dtype so e.g. float predictions match int labels.
    hits = predicted.to(y.dtype) == y
    return float(hits.to(y.dtype).sum())

In [135]:
class Accumulator:
    """Keeps `n` running float sums that are updated together."""

    def __init__(self, n):
        """Start with `n` slots, each set to 0.0."""
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add each positional value (converted with ``float``) to its slot.

        Slots and values are paired with ``zip``, so the stored list ends up
        as long as the shorter of the two.
        """
        updated = []
        for running_total, increment in zip(self.data, args):
            updated.append(running_total + float(increment))
        self.data = updated

    def reset(self):
        """Set every slot back to 0.0, keeping the current number of slots."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        """Return the running sum stored at `idx`."""
        return self.data[idx]

In [136]:
def evaluate_accuracy(net, data_iter):
    """Fraction of correctly classified examples over a data iterator.

    Args:
        net: callable mapping a batch to predictions; switched to eval mode
            first when it is a ``torch.nn.Module``.
        data_iter: iterable yielding ``(X, y)`` batches.

    Returns:
        Total matches reported by `accuracy` divided by the total number of
        label elements seen.
    """
    if isinstance(net, torch.nn.Module):
        net.eval()
    # Slot 0: number of correct predictions, slot 1: number of labels seen.
    totals = Accumulator(2)
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for features, labels in data_iter:
            batch_correct = accuracy(net(features), labels)
            totals.add(batch_correct, labels.numel())
    num_correct, num_seen = totals[0], totals[1]
    return num_correct / num_seen

In [137]:
def train_epoch(net, train_iter, loss_func, optimizer, params=None, lr=0.01):
    """Train `net` for one pass over `train_iter`.

    Args:
        net: callable mapping a batch to predictions; put into train mode
            first when it is a ``torch.nn.Module``.
        train_iter: iterable yielding ``(X, y)`` batches.
        loss_func: callable ``(y_hat, y) -> loss tensor``.
        optimizer: either a ``torch.optim.Optimizer`` or a callable invoked as
            ``optimizer(params, lr=..., batch_size=...)``.
        params: tensors handed to a callable `optimizer`. Defaults to the
            notebook-level ``[w, b]``. Ignored for ``torch.optim`` optimizers.
        lr: learning rate handed to a callable `optimizer`. Ignored for
            ``torch.optim`` optimizers.

    Returns:
        ``(mean training loss, training accuracy)`` for the epoch, each
        normalised by the number of label elements seen.
    """
    if isinstance(net, torch.nn.Module):
        net.train()
    uses_torch_optimizer = isinstance(optimizer, torch.optim.Optimizer)
    # Slots: summed loss, number of correct predictions, number of labels.
    metric = Accumulator(3)
    for X, y in train_iter:
        y_hat = net(X)
        loss = loss_func(y_hat, y)
        if uses_torch_optimizer:
            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()
        else:
            # Gradients are summed here; the callable receives batch_size so
            # it can do the averaging itself.
            loss.sum().backward()
            # Resolve the default lazily so existing callers that rely on the
            # notebook globals behave exactly as before.
            step_params = [w, b] if params is None else params
            optimizer(step_params, lr=lr, batch_size=X.shape[0])
        metric.add(loss.sum(), accuracy(y_hat, y), y.numel())
    return metric[0] / metric[2], metric[1] / metric[2]

In [138]:
def train(num_epochs, net, train_iter, test_iter, loss_func, optimizer):
    """Run the training loop and print one progress line per epoch.

    Args:
        num_epochs: number of passes over `train_iter`.
        net: model passed on to `train_epoch` and `evaluate_accuracy`.
        train_iter: iterable of training batches.
        test_iter: iterable of held-out batches, evaluated after every epoch.
        loss_func: loss callable forwarded to `train_epoch`.
        optimizer: optimizer forwarded to `train_epoch`.
    """
    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(net, train_iter, loss_func, optimizer)
        test_acc = evaluate_accuracy(net, test_iter)
        # The test accuracy used to be computed and then dropped; report it so
        # the per-epoch evaluation pass is not wasted.
        print(f'epoch {epoch}, train loss: {train_loss}, train acc: {train_acc}, test acc: {test_acc}')

In [139]:
# Use the from-scratch loss and optimizer defined in this notebook.
loss_func, optimizer = cross_entropy, SGD

In [141]:
# Train for 10 epochs with the components configured in the previous cells.
train(
    num_epochs=10,
    net=net,
    train_iter=train_iter,
    test_iter=test_iter,
    loss_func=loss_func,
    optimizer=optimizer,
)

epoch 0, train loss: 3.04766295598348, train acc: 0.5101
epoch 1, train loss: 2.7442292922973635, train acc: 0.53975
epoch 2, train loss: 2.537665101114909, train acc: 0.5617833333333333
epoch 3, train loss: 2.385638902537028, train acc: 0.5788333333333333
epoch 4, train loss: 2.2678290891011557, train acc: 0.59205
epoch 5, train loss: 2.1728394027709963, train acc: 0.6043666666666667
epoch 6, train loss: 2.0933712020874022, train acc: 0.6131166666666666
epoch 7, train loss: 2.0259221177419025, train acc: 0.6217666666666667
epoch 8, train loss: 1.9674454442342122, train acc: 0.6285166666666666
epoch 9, train loss: 1.9158699361165366, train acc: 0.6352333333333333
