# Softmax的简单实现

In [13]:
import torch
from torch import nn
from d2l import torch as d2l
from torch.utils.data import DataLoader
import torchvision

In [14]:
# Mini-batch size shared by the training and the test loader.
batch_size = 256

# Fashion-MNIST training split, stored under ../data, with every sample
# passed through ToTensor.
train_data = torchvision.datasets.FashionMNIST(
    root='../data',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# Matching test split (train=False), same storage location and transform.
test_data = torchvision.datasets.FashionMNIST(
    root='../data',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# Training batches are reshuffled on every pass; test batches keep a fixed order.
train_iter = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_iter = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

softmax回归的输出层本质上就是一个全连接层，因此实现该模型只需要一个线性层即可：其输入维度为784（28×28的图像展平后的长度），输出维度为10（对应10个类别）。

In [None]:
# Softmax regression model: flatten each input sample into a vector, then
# apply one fully connected layer mapping 784 features to 10 raw scores.
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))


def init_weights(m):
    """Initialise the weight of a linear layer from a normal distribution.

    Meant to be handed to ``nn.Module.apply``, which invokes it on every
    submodule. Only ``nn.Linear`` modules are modified: their weight is
    redrawn with ``std=0.01`` (mean left at the ``nn.init.normal_`` default).
    The bias and all other module types are left untouched.

    Args:
        m: The module currently being visited.
    """
    # A single isinstance check replaces the duplicated, nested
    # ``type(m) == nn.Linear`` tests of the earlier version.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)


# Kept as the final expression so that a notebook cell displays the model.
net.apply(init_weights)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=10, bias=True)
)

In [17]:
# Number of full passes over the training data.
num_epochs = 10

# Cross-entropy with reduction='none' yields one loss value per sample;
# whoever calls it is responsible for reducing those values to a scalar.
loss = nn.CrossEntropyLoss(reduction='none')

# Plain SGD over every parameter of the network, learning rate 0.1.
trainer = torch.optim.SGD(params=net.parameters(), lr=0.1)

In [23]:
# Enable the GPU: use the first CUDA device when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Move the model and the loss module onto the selected device.
net.to(device)
loss = loss.to(device)


In [26]:
# Hand-written training procedure: each epoch makes one pass over the
# training batches and then measures accuracy on the test batches.
for epoch in range(num_epochs):
    net.train()
    # enumerate(start=1) supplies the 1-based step counter used for logging.
    for train_step, (X, y) in enumerate(train_iter, start=1):
        trainer.zero_grad()
        X = X.to(device)
        y = y.to(device)
        y_hat = net(X)
        # Reduce the loss output to a scalar so backward() can be called.
        batch_loss = loss(y_hat, y).mean()
        batch_loss.backward()
        trainer.step()
        # Report the current batch loss every 100 steps.
        if train_step % 100 == 0:
            print(f'epoch {epoch + 1}, step {train_step}, loss {batch_loss.item():.4f}')

    # Evaluate the model on the test set.
    net.eval()
    acc_num = 0  # number of correctly classified test samples
    all_num = 0  # number of test samples seen
    with torch.no_grad():
        for X, y in test_iter:
            X = X.to(device)
            y = y.to(device)
            # The predicted class is the index of the largest score.
            y_hat = net(X).argmax(dim=1)
            correct = y_hat.type(y.dtype) == y
            acc_num += correct.sum().item()
            all_num += y.shape[0]
    print(f'epoch {epoch + 1}, test acc {acc_num / all_num:.4f}')
        

epoch 1, step 100, loss 0.4363
epoch 1, step 200, loss 0.4708
epoch 1, test acc 0.8346
epoch 2, step 100, loss 0.4964
epoch 2, step 200, loss 0.4582
epoch 2, test acc 0.8314
epoch 3, step 100, loss 0.4245
epoch 3, step 200, loss 0.4239
epoch 3, test acc 0.8355
epoch 4, step 100, loss 0.4638
epoch 4, step 200, loss 0.3814
epoch 4, test acc 0.8354
epoch 5, step 100, loss 0.4290
epoch 5, step 200, loss 0.3722
epoch 5, test acc 0.8398
epoch 6, step 100, loss 0.5360
epoch 6, step 200, loss 0.4015
epoch 6, test acc 0.8309
epoch 7, step 100, loss 0.3839
epoch 7, step 200, loss 0.5361
epoch 7, test acc 0.8329
epoch 8, step 100, loss 0.4683
epoch 8, step 200, loss 0.4952
epoch 8, test acc 0.8388
epoch 9, step 100, loss 0.4076
epoch 9, step 200, loss 0.4453
epoch 9, test acc 0.8209
epoch 10, step 100, loss 0.4220
epoch 10, step 200, loss 0.3938
epoch 10, test acc 0.8393
