In [None]:
import torch
from torch import nn 
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Resize

# VGG块构造

In [None]:
def vgg_block(num_convs, in_channels, out_channels):
    """Build one VGG block: ``num_convs`` conv+ReLU pairs followed by a max-pool.

    Args:
        num_convs: Number of Conv2d/ReLU pairs placed before the pooling layer.
        in_channels: Channel count fed into the first convolution.
        out_channels: Channel count produced by every convolution in the block.

    Returns:
        An ``nn.Sequential`` containing the layers in order.
    """
    layers = []
    channels = in_channels
    for _ in range(num_convs):
        # kernel_size=3 with padding=1 preserves the spatial size:
        # output = (input + 2 * 1 - 3) / 1 + 1 = input.
        layers += [nn.Conv2d(channels, out_channels, kernel_size=3, padding=1), nn.ReLU()]
        # Every convolution after the first maps out_channels -> out_channels.
        channels = out_channels
    # The stride-2 pooling halves both height and width.
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

In [None]:
# Convolutional layout: one (num_convs, in_channels, out_channels) triple per VGG block.
conv_arch = (
    (1, 1, 64),
    (1, 64, 128),
    (2, 128, 256),
    (2, 256, 512),
    (2, 512, 512),
)
# Each of the five blocks halves width and height, so a 224x224 input ends up 224 / 32 = 7 wide.
# Flattened size is channels * width * height = 512 * 7 * 7.
fc_features = conv_arch[-1][2] * (224 // 2 ** len(conv_arch)) ** 2
# Width of the hidden fully connected layers (free choice).
fc_hidden_units = 4096


# VGG网络构造

In [None]:
def vgg(conv_arch, fc_features, fc_hidden_units=4096, num_classes=10, dropout=0.5):
    """Assemble a VGG-style network: stacked VGG blocks followed by a 3-layer classifier.

    Args:
        conv_arch: Iterable of ``(num_convs, in_channels, out_channels)`` triples,
            one per VGG block, applied in order.
        fc_features: Number of features entering the first fully connected layer,
            i.e. the flattened size of the last block's output.
        fc_hidden_units: Width of the two hidden fully connected layers.
        num_classes: Size of the final output layer. Defaults to 10, the value
            that used to be hard-coded.
        dropout: Dropout probability used after each hidden fully connected
            layer. Defaults to 0.5, the value that used to be hard-coded.

    Returns:
        An ``nn.Sequential`` whose children are named ``vgg_block_1`` ...
        ``vgg_block_N`` followed by ``fc``.
    """
    net = nn.Sequential()
    # Convolutional part: every VGG block halves the spatial width and height.
    for index, (num_convs, in_channels, out_channels) in enumerate(conv_arch, start=1):
        net.add_module(f"vgg_block_{index}", vgg_block(num_convs, in_channels, out_channels))

    # Fully connected part: flatten, two hidden layers with dropout, then the output layer.
    classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(fc_features, fc_hidden_units),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(fc_hidden_units, fc_hidden_units),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(fc_hidden_units, num_classes),
    )
    net.add_module("fc", classifier)
    return net


In [None]:
# Sanity check: build the full-size network and feed it one random
# single-channel 224x224 image, printing the shape after each stage.
net = vgg(conv_arch, fc_features, fc_hidden_units)
X = torch.rand((1, 1, 224, 224))

# named_children() yields only the direct submodules together with their names;
# named_modules() would additionally descend into nested submodules.
for name, blk in net.named_children():
    X = blk(X)
    print(name, 'output shape: ', X.shape)

# 获取数据和训练模型

In [None]:
# To cut training time, build a narrower model whose channel counts
# (and fully connected widths) are divided by `ratio`.
ratio = 8
small_conv_arch = [
    (1, 1, 64 // ratio),  # the input image keeps its single channel
    (1, 64 // ratio, 128 // ratio),
    (2, 128 // ratio, 256 // ratio),
    (2, 256 // ratio, 512 // ratio),
    (2, 512 // ratio, 512 // ratio),
]
net = vgg(
    small_conv_arch,
    fc_features=fc_features // ratio,
    fc_hidden_units=fc_hidden_units // ratio,
)
print(net)


In [None]:
def load_fasion_mnist_data(batch_size, resize=None, root="../fasionMNIST/data", num_workers=8):
    """Create the Fashion-MNIST training and test data loaders.

    Args:
        batch_size: Number of samples per batch for both loaders.
        resize: Optional target size passed to ``Resize``; when falsy the
            images are left at their stored size.
        root: Directory holding the dataset. Downloading is disabled, so the
            files are expected to be present there already.
        num_workers: Number of worker processes used by each loader.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple.
    """
    steps = [Resize(resize)] if resize else []
    steps.append(ToTensor())
    transform = Compose(steps)

    train_data = datasets.FashionMNIST(
        root=root,
        download=False,
        train=True,
        transform=transform,
    )
    # train=False selects the held-out test split. Passing train=True here
    # (as the code previously did) would evaluate the model on the very
    # samples it was trained on.
    test_data = datasets.FashionMNIST(
        root=root,
        download=False,
        train=False,
        transform=transform,
    )

    train_dataloader = DataLoader(train_data, batch_size, shuffle=True, num_workers=num_workers)
    # Evaluation metrics do not depend on sample order, so the test loader is not shuffled.
    test_dataloader = DataLoader(test_data, batch_size, shuffle=False, num_workers=num_workers)
    return train_dataloader, test_dataloader

In [None]:
# Batches of 64 images, resized to 224 to match the input size the network was sized for.
train_dataloader, test_dataloader = load_fasion_mnist_data(batch_size=64, resize=224)

def train(dataloader, model, loss_fn, optimizer):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches exposing ``.dataset``.
        model: Module to optimise; it is switched to training mode.
        loss_fn: Callable mapping ``(predictions, targets)`` to a scalar loss tensor.
        optimizer: Optimizer holding the model's parameters.

    The loss is printed for batch 0 and every 100th batch after it.
    """
    model.train()
    size = len(dataloader.dataset)
    for step, (inputs, targets) in enumerate(dataloader):
        batch_loss = loss_fn(model(inputs), targets)

        # Clear stale gradients, back-propagate, then apply the update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if step % 100:
            continue
        # Progress counter: batch index times the size of the current batch.
        loss, current = batch_loss.item(), step * len(inputs)
        print(f"loss: {loss:.6f}  [{current:5d}/{size:5d}]")

def test(dataloader, model, loss_fn):
    model.eval()
    loss, correct = 0, 0
    num_batches = len(dataloader)
    size = len(dataloader.dataset)
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).float().sum().item()
        loss /= num_batches
        correct /= size
    print("Avg Loss: {:.6f}\tAccuracy: {:.2f}%".format(loss, correct*100))        

# Optimise with Adam on a cross-entropy objective; after every epoch,
# evaluate on test_dataloader.
num_epochs = 5
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
for epoch in range(num_epochs):
    print("Epoch {}:\n----------".format(epoch+1))
    train(dataloader=train_dataloader, model=net, loss_fn=loss_fn, optimizer=optimizer)
    test(dataloader=test_dataloader, model=net, loss_fn=loss_fn)