# 3.7 softmax的PyTorch实现

In [None]:
import torch
import torch.nn as nn
from torch.utils import data
from torchvision import datasets, transforms
import torch.optim

### Load the data and set the hyperparameters

In [None]:
# Hyperparameters.
BATCH_SIZE = 32
EPOCH = 10       # number of passes over the training set
INPUT = 28 * 28  # each 28x28 image is flattened into this many features
OUTPUT = 10      # number of classes (size of the logit vector)
LR = 0.1

# ToTensor converts PIL images into float tensors scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

# download=True only fetches the files when they are missing under ./data;
# with download=False a fresh checkout raises a RuntimeError instead.
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Shuffle only the training data; the test pass is order-independent.
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Create the network

In [None]:

class softmax_Net(nn.Module):
    """Softmax regression model: flatten each sample, then one affine layer.

    Args:
        input: number of features after flattening a sample.
        output: number of classes, i.e. the length of each logit vector.

    The forward pass returns raw logits with no softmax layer; in this
    notebook the loss (nn.CrossEntropyLoss) applies log-softmax itself.
    """

    def __init__(self, input, output):
        super().__init__()
        flatten = nn.Flatten()
        affine = nn.Linear(input, output)
        # Kept inside a Sequential named `net` so parameter names stay
        # `net.1.weight` / `net.1.bias`.
        self.net = nn.Sequential(flatten, affine)

    def forward(self, x):
        logits = self.net(x)
        return logits

# Instantiate the model: INPUT flattened features in, OUTPUT class logits out.
net = softmax_Net(input=INPUT, output=OUTPUT)


### Train the model

In [None]:
# Cross-entropy on raw logits: nn.CrossEntropyLoss applies log-softmax
# internally, which is why the model has no softmax layer of its own.
loss = nn.CrossEntropyLoss()
# Plain minibatch SGD. LR = 0.1 is a reasonable SGD step for this linear
# model, but it is roughly 100x Adam's usual step size (1e-3) and far too
# large for Adam.
optimizer = torch.optim.SGD(net.parameters(), lr=LR)

In [None]:
# Training loop: one optimizer step per minibatch. After each epoch, report
# the per-sample average loss and the accuracy on the training batches seen.
for epoch in range(EPOCH):
    net.train()
    total_loss, correct, total = 0.0, 0, 0
    for X, y in train_loader:
        y_pred = net(X)
        l = loss(y_pred, y)  # mean over the batch (default reduction)

        optimizer.zero_grad()
        l.backward()
        optimizer.step()

        batch_size = y.size(0)
        # Re-weight the batch mean by batch size so a smaller final batch
        # does not skew the epoch average.
        total_loss += l.item() * batch_size
        # Accuracy uses the logits computed before this step's update.
        correct += (y_pred.argmax(1) == y).sum().item()
        total += batch_size

    train_loss = total_loss / total
    train_acc = correct / total
    print(f"Epoch [{epoch+1}/{EPOCH}]  Loss: {train_loss:.4f}  Accuracy: {train_acc:.4f}")


In [None]:
# Evaluate classification accuracy on the held-out test set.
net.eval()
correct, total = 0, 0
with torch.no_grad():  # inference only: skip building the autograd graph
    for X, y in test_loader:
        y_pred = net(X)
        correct += (y_pred.argmax(1) == y).sum().item()
        total += y.size(0)

print(f"Test Accuracy: {correct / total:.4f}")