# softmax的简洁实现

In [33]:
import torch
import torchvision
from torch import nn
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l


In [34]:
def load_data_fashion_mnist(batch_size, resize=None):
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../../mnist_data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../../mnist_data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_works()),
            data.DataLoader(mnist_test, batch_size, shuffle=True, num_workers=get_dataloader_works()))


In [35]:
def get_dataloader_works():
    # 使用4个进程读取数据
    return 4


In [36]:
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)


#### 初始化模型参数

In [37]:
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))


def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)


net.apply(init_weights);


In [38]:
"""
'none'表示直接返回N个样本的loss,是一个向量
'sum'指对N个样本的loss求和
'elementwise_mean'为默认情况,N个loss求平均
"""
loss = nn.CrossEntropyLoss(reduction='none')


In [39]:
trainer = torch.optim.SGD(net.parameters(), lr=0.1)


#### 训练

In [40]:
def train_epoch(net, train_iter, loss, updater):
    if isinstance(net, torch.nn.Module):
        net.train()
    metric = d2l.Accumulator(3)
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y) #此处l是一个向量，不能直接反向传播
        if isinstance(updater, torch.optim.Optimizer):
            # 使用Pytorch内置优化器和损失函数
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), d2l.accuracy(y_hat, y), y.numel())
        return metric[0] / metric[2], metric[1] / metric[2]


In [41]:
def train(net, train_iter, test_iter, loss, num_epochs, updater):
    # animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[
    #                         0.3, 0.9], legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch(net, train_iter, loss, updater)
        test_acc = d2l.evaluate_accuracy(net, test_iter)
        # animator.add(epoch+1, train_metrics+(test_acc,))
        c_train_loss, c_train_acc = train_metrics
        print('epoch:%d train_loss:%.4f train_acc:%.2f test_acc:%.2f' % (epoch+1,c_train_loss,c_train_acc,test_acc))
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc


In [42]:
num_epochs=50
train(net,train_iter,test_iter,loss,num_epochs,trainer)

epoch:1 train_loss:2.3206 train_acc:0.11 test_acc:0.13
epoch:2 train_loss:2.1254 train_acc:0.13 test_acc:0.48
epoch:3 train_loss:1.9675 train_acc:0.46 test_acc:0.59
epoch:4 train_loss:1.8187 train_acc:0.58 test_acc:0.60
epoch:5 train_loss:1.7096 train_acc:0.60 test_acc:0.49
epoch:6 train_loss:1.6122 train_acc:0.55 test_acc:0.63
epoch:7 train_loss:1.5648 train_acc:0.63 test_acc:0.61
epoch:8 train_loss:1.5579 train_acc:0.56 test_acc:0.54
epoch:9 train_loss:1.4184 train_acc:0.59 test_acc:0.60
epoch:10 train_loss:1.3529 train_acc:0.62 test_acc:0.63
epoch:11 train_loss:1.2607 train_acc:0.69 test_acc:0.65
epoch:12 train_loss:1.2115 train_acc:0.72 test_acc:0.58
epoch:13 train_loss:1.3085 train_acc:0.56 test_acc:0.63
epoch:14 train_loss:1.2499 train_acc:0.63 test_acc:0.65
epoch:15 train_loss:1.1120 train_acc:0.73 test_acc:0.65
epoch:16 train_loss:1.2236 train_acc:0.64 test_acc:0.66
epoch:17 train_loss:1.1069 train_acc:0.70 test_acc:0.66
epoch:18 train_loss:1.1364 train_acc:0.66 test_acc:0.65
e

AssertionError: 0.7986176609992981