# 基于 VGG-11 的 FashionMNIST 数据集的分类
## 1. 导入 FashionMNIST 数据集
首先我们下载 FashionMNIST 数据集。这需要定义一个数据导入函数 load_FashionMNIST_dataset，以及一个保存类别名称的列表 labels（下标即类别编号）。

In [None]:
# imports
import torch
from torch import nn
from torch import functional
from torch import optim
from torch.utils import data
# from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter("runs/FashionMNIST/VGG")

import torchvision
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt


# Directory where torchvision stores (and looks for) the downloaded dataset.
ROOT = './data'

def load_FashionMNIST_dataset(BatchSize, root=ROOT):
    """Build the FashionMNIST training and test data loaders.

    Every image is converted to a tensor, resized to 224 and normalized with
    mean 0.5 and std 0.5. The dataset is downloaded into ``root`` if needed.

    Args:
        BatchSize: Number of samples per batch for both loaders.
        root: Directory used to store the dataset files.

    Returns:
        A ``(trainloader, testloader)`` tuple. Only the training loader
        shuffles its samples.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize(0.5, 0.5)
    ])

    def _make_loader(is_train):
        # The training split is shuffled, the test split keeps its order.
        split = torchvision.datasets.FashionMNIST(
            root, train=is_train, transform=preprocess, download=True)
        return data.DataLoader(
            split, batch_size=BatchSize, shuffle=is_train, num_workers=2)

    trainloader = _make_loader(True)
    testloader = _make_loader(False)
    return trainloader, testloader

# Human-readable class names; the list index is the integer class id.
labels = [
    't-shirt',
    'trouser',
    'pullover',
    'dress',
    'coat',
    'sandal',
    'shirt',
    'sneaker',
    'bag',
    'ankle boot',
]

In [None]:
def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    The tensor is rescaled with ``img / 2 + 0.5`` (the inverse of a
    mean-0.5 / std-0.5 normalization) and its axes are reordered from
    (channel, height, width) to (height, width, channel) before plotting.

    Args:
        img: Image tensor that supports ``.numpy()``.
    """
    rescaled = img / 2 + 0.5
    channels_last = np.transpose(rescaled.numpy(), (1, 2, 0))
    plt.imshow(channels_last)

In [None]:
# Build both loaders once; later cells reuse `trains` and `tests`.
BatchSize = 128
trains, tests = load_FashionMNIST_dataset(BatchSize)

# Pull a single training batch to preview.
trainiter = iter(trains)
X, y = next(trainiter)

# Tile the whole batch into one image, show it and list its class names.
# NOTE(review): the saved output of this cell is "ValueError: ignored"; the
# cause is not visible here -- re-run the cell to confirm where it is raised.
img_grid = torchvision.utils.make_grid(X)
print(img_grid.shape)
imshow(img_grid)
print([labels[class_id] for class_id in y])

# writer.add_images("Some Samples",X)
# writer.close()

ValueError: ignored

## 2. VGG
### 基本架构
VGG 的输入、输出以及全连接层与 AlexNet 基本相同；不同的是 VGG 引入了 5 个卷积块，实现了块状设计。其基本架构分为两部分：第一部分是主要由卷积层和汇聚层组成的卷积块，第二部分由全连接层组成。


输入：图片（$3\times 224\times 224$）

模块一（卷积块1-2）：
- $3\times 3$ 卷积（64），填充为 1 (64 @ 224*224)
- ReLU 函数激活
- $2\times 2$ 最大汇聚，步幅为 2 （64 @ 112*112）
- $3\times 3$ 卷积（128），填充为 1 （128 @ 112*112）
- ReLU 函数激活
- $2\times 2$ 最大汇聚，步幅为 2 （128 @ 56*56）

模块二（卷积块3-5）：
- $3\times 3$ 卷积（256），填充为 1 （256 @ 56*56）
- ReLU 函数激活
- $3\times 3$ 卷积（256），填充为 1 （256 @ 56*56）
- ReLU 函数激活
- $2\times 2$ 最大汇聚，步幅为 2 （256 @ 28*28）
- $3\times 3$ 卷积（512），填充为 1 （512 @ 28*28）
- ReLU 函数激活
- $3\times 3$ 卷积（512），填充为 1 （512 @ 28*28）
- ReLU 函数激活
- $2\times 2$ 最大汇聚，步幅为 2 （512 @ 14*14）
- $3\times 3$ 卷积（512），填充为 1 （512 @ 14*14）
- ReLU 函数激活
- $3\times 3$ 卷积（512），填充为 1 （512 @ 14*14）
- ReLU 函数激活
- $2\times 2$ 最大汇聚，步幅为 2 （512 @ 7*7）
- Flatten 展平
- 全连接层（$512\times 7\times 7$，4096）
- ReLU 函数激活
- Dropout(0.5)
- 全连接层（4096，4096）
- ReLU 函数激活
- Dropout(0.5)
- 全连接层（4096，1000）
- softmax 函数分类输出

输出：1000 个类别的预测概率（以上是原始 VGG-11 的结构；本文针对 FashionMNIST 的实现将输入改为单通道 $1\times 224\times 224$，最后一个全连接层输出 10 个类别）



In [None]:
def vgg_block(num_convs, in_chanels, out_chanels):
    """Build one VGG block: ``num_convs`` conv+ReLU pairs and a max-pool.

    Args:
        num_convs: Number of 3x3 convolutions (padding 1) in the block.
        in_chanels: Input channels of the first convolution.
        out_chanels: Output channels of every convolution in the block.

    Returns:
        An ``nn.Sequential`` ending with a 2x2, stride-2 ``nn.MaxPool2d``.
    """
    # Channel count seen at each stage: only the first conv changes width.
    widths = [in_chanels] + [out_chanels] * num_convs
    stages = []
    for src, dst in zip(widths, widths[1:]):
        stages += [nn.Conv2d(src, dst, kernel_size=3, padding=1), nn.ReLU()]
    stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*stages)

# (number of convolutions, output channels) for each of the five VGG-11 blocks.
conv_archs = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))

def VGG11(conv_archs, in_channels=1, num_classes=10, image_size=224):
    """Assemble a VGG-style classifier from a block configuration.

    The width of the first fully connected layer is derived from the
    arguments instead of being fixed to ``512*7*7``, so configurations with
    a different depth or final channel count still produce a consistent
    network. The defaults reproduce the original single-channel, 10-class,
    224x224 model.

    Args:
        conv_archs: Sequence of ``(num_convs, out_channels)`` pairs, one per
            convolutional block.
        in_channels: Number of channels of the input images.
        num_classes: Number of output logits.
        image_size: Height/width of the (square) input images.

    Returns:
        An ``nn.Sequential`` made of the convolutional blocks, a flatten
        layer and three fully connected layers with ReLU and dropout.
    """
    conv_blks = []
    in_chanel = in_channels
    for num_convs, out_chanel in conv_archs:
        conv_blks.append(vgg_block(num_convs, in_chanel, out_chanel))
        in_chanel = out_chanel

    # Each block ends with a 2x2, stride-2 max-pool that halves (and floors)
    # the spatial size, so the final feature map is image_size // 2**blocks.
    feature_size = image_size // 2 ** len(conv_archs)
    flat_features = in_chanel * feature_size * feature_size

    return nn.Sequential(
        *conv_blks,
        nn.Flatten(),
        nn.Linear(flat_features, 4096),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(4096, num_classes))

# Instantiate the network from the module-level conv_archs configuration.
vgg11 = VGG11(conv_archs)
# writer.add_graph(vgg11,X)
# writer.close()

In [None]:
def train(epochs, net, criterion, opt, train_set, device):
    """Train ``net`` and return it in evaluation mode.

    Conv2d and Linear weights are re-initialized with Xavier uniform
    initialization before training starts. The average loss of every
    window of 100 batches is printed during each epoch.

    Args:
        epochs: Number of passes over ``train_set``.
        net: Model to train; it is moved to ``device``.
        criterion: Loss function called as ``criterion(pred, y)``.
        opt: Optimizer that updates the parameters of ``net``.
        train_set: Iterable yielding ``(X, y)`` batches.
        device: Device on which the model and the batches are placed.

    Returns:
        ``net`` after training, switched to evaluation mode.
    """
    def init_weights(m):
        # Exact type match, as before, so subclasses keep their own init.
        if type(m) in (nn.Conv2d, nn.Linear):
            nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    print("Training on!")

    net.to(device)

    for epoch in range(epochs):
        running_loss = 0.0
        net.train()
        for i, (X, y) in enumerate(train_set):
            X, y = X.to(device), y.to(device)
            opt.zero_grad()
            pred = net(X)
            loss = criterion(pred, y)
            loss.backward()
            opt.step()
            running_loss += loss.item()
            # Report inside the batch loop: the check used to sit after the
            # loop, so it only saw the last batch index of each epoch.
            if i % 100 == 99:
                print(f"epoch {epoch}, i = {i}: loss={running_loss/100}")
                running_loss = 0.0
    print("Finish Training")
    return net.eval()

epochs = 2
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

criterion = nn.CrossEntropyLoss()
# NOTE(review): the saved outputs report 0.1 accuracy on both splits, which
# is chance level for 10 classes; lr=0.1 together with momentum=0.9 may be
# too aggressive for this network -- confirm by trying a smaller rate.
optimizer = optim.SGD(vgg11.parameters(), lr=0.1, momentum=0.9)

net = train(epochs, vgg11, criterion, optimizer, trains, device)

Training on!
Finish Training


In [None]:
def accuracy(net, datasets, device):
    """Return the fraction of samples in ``datasets`` classified correctly.

    The model is evaluated in eval mode (so dropout does not perturb the
    predictions) and without gradient tracking; its previous train/eval
    mode is restored afterwards.

    Args:
        net: Model mapping a batch ``X`` to per-class scores.
        datasets: Iterable yielding ``(X, y)`` batches.
        device: Device the batches are moved to; ``net`` is expected to be
            on it already.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when ``datasets`` yields
        no samples.
    """
    # Plain callables have no train/eval mode to manage.
    is_module = isinstance(net, nn.Module)
    was_training = net.training if is_module else False
    if is_module:
        net.eval()

    total = 0
    correct = 0
    try:
        with torch.no_grad():
            for X, y in datasets:
                X, y = X.to(device), y.to(device)
                outputs = net(X)
                _, preds = torch.max(outputs, 1)
                total += y.size(0)
                correct += (preds == y).sum().item()
    finally:
        if is_module:
            net.train(was_training)

    # Guard against an empty loader instead of dividing by zero.
    return correct / total if total else 0.0

# Evaluate the trained model on each split and report the result.
train_accuracy = accuracy(net, trains, device)
print("Accuracy of train sets:", train_accuracy)
test_accuracy = accuracy(net, tests, device)
print("Accuracy of test sets:", test_accuracy)

Accuracy of train sets: 0.1
Accuracy of test sets: 0.1


In [None]:
def pred_to_probs(net, X):
    """Return the predicted class and its softmax probability per sample.

    Args:
        net: Model mapping a batch ``X`` to a 2-D tensor of per-class scores
            (one row per sample).
        X: Input batch accepted by ``net``.

    Returns:
        A ``(preds, probs)`` tuple: ``preds`` is a 1-D NumPy array of class
        indices (also for a batch of one sample) and ``probs`` is a list with
        the softmax probability of each predicted class.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = net(X)
        _, preds_tensor = torch.max(output, 1)
        # Pick, for every row, the probability of the predicted class.
        chosen = torch.softmax(output, dim=1).gather(
            1, preds_tensor.unsqueeze(1)).squeeze(1)
    # No np.squeeze here: it turned a single prediction into a 0-d array
    # that can be neither iterated nor indexed by the callers.
    preds = preds_tensor.cpu().numpy()
    return preds, chosen.tolist()

def plot_class_preds(net, X, y, num_images=4):
    """Plot the first images of a batch with predicted and true labels.

    Each title shows the predicted class, its probability and the true
    label; it is green when the prediction is correct and red otherwise.

    Args:
        net: Model used for the predictions.
        X: Batch of images.
        y: True class indices for ``X``.
        num_images: Maximum number of images to show. Fewer are drawn when
            the batch is smaller.

    Returns:
        The matplotlib figure containing one subplot per displayed image.
    """
    # Never index past the end of the batch.
    count = max(0, min(num_images, len(X)))
    fig = plt.figure(figsize=(12, 48))
    if count == 0:
        return fig

    # Only the displayed images are run through the network.
    preds, probs = pred_to_probs(net, X[:count])
    for idx in range(count):
        ax = fig.add_subplot(1, count, idx + 1, xticks=[], yticks=[])
        img_grid = torchvision.utils.make_grid(X[idx])
        imshow(img_grid)
        is_correct = preds[idx] == y[idx].item()
        ax.set_title("{0},{1:.1f}%\n(label: {2})".format(
            labels[preds[idx]],
            probs[idx]*100,
            labels[y[idx]]),
            color=("green" if is_correct else "red"))
    return fig

# The batch drawn below stays on the CPU, so the model is moved there too.
net.cpu()

# Visualize predictions for the first batch of the test loader.
testiter = iter(tests)
X, y = next(testiter)
plot_class_preds(net, X, y)