In [83]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.tensorboard import SummaryWriter

# 载入训练集与测试集

In [84]:
# Build both FashionMNIST splits with identical settings; only the `train`
# flag differs (True -> training split, False -> test split). The files are
# stored under ./data, and ToTensor() is applied to every image on access.
train_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train_split,
        download=True,
        transform=ToTensor(),
    )
    for is_train_split in (True, False)
)

In [85]:
# Number of samples in the training split.
len(train_data)

60000

# 将训练集划分成训练集及验证集（当前未启用，下方代码已注释）

In [86]:
# Optional train/validation hold-out split (currently disabled):
# train_data, val_data = random_split(train_data, [50000,10000])
# NOTE(review): as written this rebinds `train_data`, so running the cell a
# second time would try to split the already-reduced subset with lengths that
# no longer match it. If re-enabled, consider binding the result to new names.

# 将数据转成DataLoader格式，并设置batchsize和shuffle

In [87]:
# Hyperparameters for the run.
batch_size: int = 64           # samples per mini-batch
learning_rate: float = 0.001   # optimizer step size
epochs: int = 10               # number of train/test passes

In [88]:
# Wrap the datasets in DataLoaders that yield mini-batches of `batch_size`.
# Training batches are reshuffled at every epoch.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
#val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
# Evaluation gains nothing from shuffling. A fixed order keeps the reported
# test loss reproducible: the loss is averaged per batch, so with shuffling the
# contents of the (shorter) final batch would change the result from run to run.
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# 定义网络结构

In [89]:
class Net(nn.Module):
    """Fully connected classifier: 28*28 inputs -> 512 -> 512 -> 10 logits."""

    def __init__(self):
        super().__init__()
        # Collapse every non-batch dimension into one feature vector.
        self.flatten = nn.Flatten()
        hidden_units = 512
        # Layers are created in the same order as they are applied.
        layers = [
            nn.Linear(28 * 28, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (unnormalised) class scores for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))


# Module-level model instance.
model = Net()

# 定义loss及优化器

In [90]:
# Cross-entropy loss on the model's logits, plain SGD over all model
# parameters, and a TensorBoard writer that logs under the given run directory.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
writer = SummaryWriter(log_dir='runs/fashion_mnist_experiment_1')

# 定义train

In [91]:
def train_loop(dataloader, model, loss_fn, optimizer, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    On every 100th batch the loss and accuracy of that batch are printed and
    written to TensorBoard through the module-level ``writer``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; must expose ``.dataset``
            so the total sample count can be read.
        model: Module being optimised; it is switched to training mode here.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer used for ``zero_grad()`` / ``step()``.
        epoch: Epoch number used to offset the TensorBoard step, which is
            ``samples_seen + size * epoch``.
    """
    size = len(dataloader.dataset)
    # Ensure train-mode layer behaviour even if an earlier evaluation pass
    # left the model in eval mode.
    model.train()

    # Running count of samples consumed before the current batch. Unlike
    # ``batch * len(X)`` it stays correct when the logged batch is a short
    # final batch.
    samples_seen = 0
    for batch, (X, y) in enumerate(dataloader):
        batch_len = len(X)

        # Forward pass and loss.
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # Metrics are only needed on logging steps, so the accuracy (and
            # its host synchronisation) is computed here rather than per batch.
            # ``pred`` still holds the outputs from before the optimizer step.
            batch_loss = loss.item()
            batch_accuracy = (pred.argmax(1) == y).type(torch.float).sum().item() / batch_len
            print(f"loss: {batch_loss:>7f}, accuracy: {(100*batch_accuracy):>0.1f}% [{samples_seen:>5d}/{size:>5d}]")
            writer.add_scalars('Train',
                               {'loss': batch_loss, 'accuracy': batch_accuracy},
                               samples_seen + size * epoch)
            writer.flush()

        samples_seen += batch_len

    writer.flush()

# 定义test

In [92]:
def test_loop(dataloader, model, loss_fn, optimizer, epoch):
    """Evaluate ``model`` on ``dataloader`` and report loss and accuracy.

    The average per-batch loss and the overall accuracy are printed and
    written to TensorBoard through the module-level ``writer`` at step
    ``epoch * size``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; must support ``len()`` and
            expose ``.dataset``.
        model: Module to evaluate. It is put in eval mode for the duration of
            the call and returned to its previous mode afterwards.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Unused; kept so the signature mirrors ``train_loop``.
        epoch: Epoch number used to compute the TensorBoard step.
    """
    size = len(dataloader.dataset)
    num_batch = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluate with inference-mode layer behaviour, then restore whatever mode
    # the caller had so a following training pass is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, y in dataloader:
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    finally:
        model.train(was_training)

    test_loss /= num_batch   # mean of the per-batch losses
    correct /= size          # fraction of correctly classified samples
    current = epoch * size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    writer.add_scalars('Test',
                       {'loss': test_loss, 'accuracy': correct},
                       current)
    writer.flush()

# 开始训练

In [93]:
# Alternate one training pass and one evaluation pass per epoch.
for t in range(epochs):
    epoch_number = t + 1  # 1-based: shown in the banner and passed to both loops
    print(f"Epoch {epoch_number}\n--------------------------")
    train_loop(train_loader, model, loss_fn, optimizer, epoch=epoch_number)
    test_loop(test_loader, model, loss_fn, optimizer, epoch=epoch_number)
print("Done!")

Epoch 1
--------------------------
loss: 2.307419, accuracy: 9.4% [    0/60000]
loss: 2.289490, accuracy: 9.4% [ 6400/60000]
loss: 2.281082, accuracy: 15.6% [12800/60000]
loss: 2.266302, accuracy: 25.0% [19200/60000]
loss: 2.250245, accuracy: 26.6% [25600/60000]
loss: 2.225015, accuracy: 48.4% [32000/60000]
loss: 2.217156, accuracy: 35.9% [38400/60000]
loss: 2.197673, accuracy: 50.0% [44800/60000]
loss: 2.203388, accuracy: 51.6% [51200/60000]
loss: 2.189970, accuracy: 53.1% [57600/60000]
Test Error: 
 Accuracy: 53.7%, Avg loss: 2.159365 

Epoch 2
--------------------------
loss: 2.182321, accuracy: 42.2% [    0/60000]
loss: 2.141686, accuracy: 59.4% [ 6400/60000]
loss: 2.112536, accuracy: 62.5% [12800/60000]
loss: 2.058203, accuracy: 65.6% [19200/60000]
loss: 2.056675, accuracy: 60.9% [25600/60000]
loss: 2.019469, accuracy: 60.9% [32000/60000]
loss: 1.978402, accuracy: 64.1% [38400/60000]
loss: 1.952573, accuracy: 57.8% [44800/60000]
loss: 1.924772, accuracy: 65.6% [51200/60000]
loss: 