In [8]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torch.optim import SGD

In [5]:
# 加载数据
def load_data_fashion_mnist(batch_size):
    trans = transforms.Compose([transforms.ToTensor()])
    mnist_train = torchvision.datasets.FashionMNIST('../data', train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST('../data', train=False, transform=trans, download=True)
    return (
        DataLoader(mnist_train, batch_size=batch_size, shuffle=True),
        DataLoader(mnist_test, batch_size=batch_size, shuffle=True)
    )

In [7]:
train_iter, test_iter = load_data_fashion_mnist(256)

In [13]:
# 建立模型，损失函数，优化器
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

def init_net(module):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight.data, std=0.01)
        nn.init.zeros_(module.bias.data)
net.apply(init_net)

updater = SGD(net.parameters(), lr=0.01, )
loss_func = nn.CrossEntropyLoss(reduction='none')

In [14]:
class Accumulator:
    def __init__(self, n):
        self.data = [0.0] * n
    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]
    def reset(self):
        self.data = [0.0] * len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

In [25]:
y = torch.arange(4)
y.type(torch.float32)


torch.int64

In [33]:
# 训练
def train_epoch(train_iter, net, loss_func, updater):
    for X, y in train_iter:
        y_hat = net(X)
        loss = loss_func(y_hat, y)
        updater.zero_grad()
        loss.sum().backward()
        updater.step()
    
# 准确率
def accuracy(y_hat, y):
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

# 评估
def evaluate_accuracy(test_iter, net):
    metrics = Accumulator(2)
    with torch.no_grad():
        for X, y in test_iter:
            y_hat = net(X)
            acc = accuracy(y_hat, y)
            metrics.add(acc, y.numel())
    return metrics[0]/metrics[1]


In [34]:
# 迭代训练
def train(net, train_iter, test_iter, updater, loss_func, num_epochs):
    for epoch in range(num_epochs):
        train_epoch(train_iter, net, loss_func, updater)
        acc = evaluate_accuracy(test_iter, net)
        print(f'epoch: {epoch}, acc: {acc}')


In [35]:
train(net, train_iter, test_iter, updater, loss_func, 10)

epoch: 0, acc: 0.8081
epoch: 1, acc: 0.8368
epoch: 2, acc: 0.8298
epoch: 3, acc: 0.8092
epoch: 4, acc: 0.8178
epoch: 5, acc: 0.8046
epoch: 6, acc: 0.8249
epoch: 7, acc: 0.8218
epoch: 8, acc: 0.8391
epoch: 9, acc: 0.8291
