In [81]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchvision.models as models
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [82]:
# Build the FashionMNIST training and test datasets.
# download=False means the files must already be present under ../data/.
_fashion_mnist_kwargs = dict(root='../data/', transform=ToTensor(), download=False)
train_data = datasets.FashionMNIST(train=True, **_fashion_mnist_kwargs)
test_data = datasets.FashionMNIST(train=False, **_fashion_mnist_kwargs)

In [83]:
# Wrap the datasets in DataLoaders that yield mini-batches of 64 samples.
# Training batches are drawn in shuffled order. The test set is only evaluated,
# so it is iterated in its stored order: shuffling it would change which samples
# end up in the final, smaller batch and make the reported mean-of-batch-means
# loss vary slightly from run to run for no benefit.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

In [84]:
# Define a simple CNN
class simple_CNN(nn.Module):
    """Small convolutional classifier: two conv + max-pool stages, then three linear layers.

    The layer sizes assume single-channel 28x28 inputs (shape ``(N, 1, 28, 28)``):
    each 5x5 convolution without padding trims 4 pixels per side length and each
    2x2 pooling halves it, leaving 16 feature maps of 4x4 before the linear head.
    ``forward`` returns raw logits of shape ``(N, 10)``; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Registered as a sub-module (it shows up in print(model)), but forward()
        # flattens with torch.flatten instead of calling it.
        self.flatten = nn.Flatten()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # 16 channels x 4 x 4 spatial positions = 256 flattened features.
        self.linear1 = nn.Linear(16 * 4 * 4, 120)
        self.linear2 = nn.Linear(120, 84)
        self.linear3 = nn.Linear(84, 10)

    def forward(self, x):
        # x: (N, 1, 28, 28)
        feats = F.relu(self.conv1(x))                  # (N, 6, 24, 24)
        feats = F.max_pool2d(feats, kernel_size=2)     # (N, 6, 12, 12)
        feats = F.relu(self.conv2(feats))              # (N, 16, 8, 8)
        feats = F.max_pool2d(feats, kernel_size=2)     # (N, 16, 4, 4)
        flat = torch.flatten(feats, start_dim=1)       # (N, 256)
        hidden = F.relu(self.linear1(flat))            # (N, 120)
        hidden = F.relu(self.linear2(hidden))          # (N, 84)
        return self.linear3(hidden)                    # (N, 10) logits


# Instantiate the network and print its layer summary.
model = simple_CNN()
print(model)
        

simple_CNN(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (linear1): Linear(in_features=256, out_features=120, bias=True)
  (linear2): Linear(in_features=120, out_features=84, bias=True)
  (linear3): Linear(in_features=84, out_features=10, bias=True)
)


In [85]:
# Define the optimizer and the loss:
# plain SGD over all model parameters with a fixed learning rate of 0.01,
# and cross-entropy computed on the model's output logits.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

In [86]:
# Train the network
# Retained for any other cell that still reads it; train_loop no longer depends
# on this global and counts the samples it has actually processed instead.
batch_size = 64
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader`` and print periodic progress.

    Args:
        dataloader: iterable of ``(image, label)`` batches exposing ``.dataset``
            (its length is used as the epoch size in the progress line).
        model: module to train; switched to training mode before the loop.
        loss_fn: callable ``loss_fn(pred, label)`` returning a scalar tensor.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.

    Returns:
        None. Every 100th batch (starting with the first) a line with the batch
        loss and the number of samples processed so far is printed.
    """
    model.train()
    size = len(dataloader.dataset)
    seen = 0  # samples processed so far in this epoch

    for batch_idx, (image, label) in enumerate(dataloader):
        pred = model(image)
        loss = loss_fn(pred, label)

        # Standard step: clear old gradients, back-propagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count real samples rather than assuming a batch size, so the progress
        # figure stays correct for any loader configuration.
        seen += len(image)
        if batch_idx % 100 == 0:
            print(f'loss = {loss.item():>5f} [{seen:>5d}/{size:>5d}]')
    


In [87]:
def test_loop(dataloader, model, loss_fn):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    loss, accuracy = 0, 0

    with torch.no_grad():
        for image, label in dataloader:
            pred = model(image)
            loss += loss_fn(pred, label).item()
            accuracy += (pred.argmax(1) == label).type(torch.float).sum().item()

    loss /= num_batches
    accuracy /= size
    print(f'Test Error: \n Accuracy:{(100*accuracy):>0.1f}%, Avg loss:{loss:>8f} \n')

In [88]:
# Alternate one training pass and one evaluation pass per epoch.
num_epochs = 5
for epoch in range(num_epochs):
    print(f'Epoch:{epoch + 1}')
    train_loop(dataloader=train_dataloader, model=model,
               loss_fn=loss_fn, optimizer=optimizer)
    test_loop(dataloader=test_dataloader, model=model, loss_fn=loss_fn)
print("Done!")

Epoch:1
loss = 2.300874 [   64/60000]
loss = 2.305364 [ 6464/60000]
loss = 2.295405 [12864/60000]
loss = 2.295435 [19264/60000]
loss = 2.293531 [25664/60000]
loss = 2.275216 [32064/60000]
loss = 2.255158 [38464/60000]
loss = 2.190359 [44864/60000]
loss = 1.784206 [51264/60000]
loss = 1.243173 [57664/60000]
Test Error: 
 Accuracy:52.4%, Avg loss:1.154366 

Epoch:2
loss = 1.120431 [   64/60000]
loss = 1.039873 [ 6464/60000]
loss = 0.986027 [12864/60000]
loss = 1.048869 [19264/60000]
loss = 0.973168 [25664/60000]
loss = 0.784739 [32064/60000]
loss = 0.869435 [38464/60000]
loss = 0.807244 [44864/60000]
loss = 0.749585 [51264/60000]
loss = 0.856343 [57664/60000]
Test Error: 
 Accuracy:68.5%, Avg loss:0.817646 

Epoch:3
loss = 0.888514 [   64/60000]
loss = 0.693393 [ 6464/60000]
loss = 0.618818 [12864/60000]
loss = 0.922308 [19264/60000]
loss = 0.680940 [25664/60000]
loss = 0.702810 [32064/60000]
loss = 0.493246 [38464/60000]
loss = 0.791322 [44864/60000]
loss = 0.598319 [51264/60000]
loss =