In [26]:
import torch
import torch.nn as nn
import torchvision
from torch.utils import data
from torchvision import transforms
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including deprecation warnings from torch/torchvision. Consider narrowing it
# (e.g. with a specific `category=` or `message=`) once the noisy warning is known.
warnings.filterwarnings('ignore')

In [3]:
# Convert the PIL images returned by FashionMNIST into float tensors in [0, 1].
trans = transforms.ToTensor()
batch_size = 256
train_data = torchvision.datasets.FashionMNIST(root='../data', train=True, download=True, transform=trans)
test_data = torchvision.datasets.FashionMNIST(root='../data', train=False, download=True, transform=trans)
# Shuffle only the training set: per-epoch reshuffling helps SGD, while evaluation
# results do not depend on sample order, so the test loader stays deterministic.
train_iter = data.DataLoader(train_data, shuffle=True, batch_size=batch_size)
test_iter = data.DataLoader(test_data, shuffle=False, batch_size=batch_size)

In [4]:
# Seed torch's global RNG so weight initialization (and the DataLoader shuffling
# that happens later in training) is reproducible on Restart & Run All.
torch.manual_seed(0)

# MLP: flatten each 28x28 image to 784 features, one hidden layer of 256 ReLU
# units, then 10 output logits (one per class).
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Linear(256, 10))


def init_weights(m, std=0.01):
    """Re-initialize the weights of ``nn.Linear`` layers from N(0, std**2).

    Meant to be passed to ``net.apply``, which calls it once per submodule.
    Modules that are not ``nn.Linear`` (Flatten, ReLU, ...) are left untouched,
    and Linear biases keep PyTorch's default initialization.

    Args:
        m: A submodule of the network.
        std: Standard deviation of the normal distribution (default 0.01).
    """
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=std)

net.apply(init_weights)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=256, bias=True)
  (2): ReLU()
  (3): Linear(in_features=256, out_features=10, bias=True)
)

In [5]:
# Log-softmax + negative log-likelihood fused into one criterion; 'mean' (the
# default) averages the per-sample losses over the mini-batch.
loss = nn.CrossEntropyLoss(reduction='mean')

In [6]:
# Plain mini-batch SGD (no momentum, no weight decay) over every parameter of `net`.
lr = 0.01
optimizer = torch.optim.SGD(params=net.parameters(), lr=lr)

In [27]:
num_epochs = 10
for epoch in range(num_epochs):
    net.train()  # make training mode explicit in case another cell switched to eval()
    training_loss = 0.0  # running sum of per-batch mean losses (plain Python float)
    correct = 0          # number of correctly classified training samples
    num_batches = 0
    num_samples = 0
    for X, y in train_iter:
        y_hat = net(X)
        loss_batch = loss(y_hat, y)  # mean loss over this mini-batch
        optimizer.zero_grad()
        loss_batch.backward()
        optimizer.step()

        # .item() yields a Python number. Accumulating the tensor itself would
        # chain every batch's autograd graph into `training_loss` and keep it
        # alive for the whole epoch.
        training_loss += loss_batch.item()

        # Training accuracy uses this batch's predictions made before the update.
        correct += (y_hat.argmax(dim=1) == y).sum().item()

        num_batches += 1
        num_samples += len(y)

    print(f'epoch {epoch + 1}: loss : {training_loss / num_batches}, accuracy : {correct / num_samples}')

epoch 1: loss : 0.6433107256889343, accuracy : 0.7769166827201843
epoch 2: loss : 0.6220905184745789, accuracy : 0.7856000065803528
epoch 3: loss : 0.6038233041763306, accuracy : 0.7930499911308289
epoch 4: loss : 0.5874991416931152, accuracy : 0.7997999787330627
epoch 5: loss : 0.5738343000411987, accuracy : 0.8050833344459534
epoch 6: loss : 0.5613183975219727, accuracy : 0.8095666766166687
epoch 7: loss : 0.5504812598228455, accuracy : 0.8126333355903625
epoch 8: loss : 0.5404790043830872, accuracy : 0.8162333369255066
epoch 9: loss : 0.5315883159637451, accuracy : 0.8194833397865295
epoch 10: loss : 0.5236487984657288, accuracy : 0.822350025177002
