## 5.7 VGG

### 5.7.1 VGG 块

In [None]:
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from livelossplot.tf_keras import PlotLossesCallback
from  skimage import transform
import numpy as np

In [None]:
def vgg_block(x, num_conv, num_channels):
    """Append one VGG block to the Keras tensor ``x`` and return the result.

    The block consists of ``num_conv`` repetitions of a 3x3, ``'same'``-padded
    ReLU convolution with ``num_channels`` filters, each followed by batch
    normalization. It ends with a 2x2 max-pool of stride 2, which halves the
    height and width (rounding down).
    """
    layers = tf.keras.layers
    out = x
    for _ in range(num_conv):
        conv = layers.Conv2D(num_channels, kernel_size=3,
                             padding='same', activation='relu')
        out = layers.BatchNormalization()(conv(out))
    return layers.MaxPool2D(pool_size=2, strides=2)(out)

### 5.7.2 VGG 网络

In [None]:
# One (num_convs, num_channels) pair per VGG block: five blocks holding
# 8 convolutional layers in total, widening from 64 to 512 channels.
conv_arch = (
    (1, 64),
    (1, 128),
    (2, 256),
    (2, 512),
    (2, 512),
)

In [None]:
def vgg(inputs, conv_arch, num_classes=10, fc_units=4096, dropout_rate=0.5):
    """Stack VGG blocks followed by a three-layer fully connected head.

    Args:
        inputs: Keras tensor fed into the first convolutional block.
        conv_arch: Iterable of ``(num_convs, num_channels)`` pairs; each pair
            becomes one ``vgg_block``.
        num_classes: Width of the final Dense layer (number of classes).
        fc_units: Width of each of the two hidden Dense layers.
        dropout_rate: Dropout rate applied after each hidden Dense layer.

    Returns:
        Keras tensor of shape ``(batch, num_classes)`` holding raw logits. The
        last Dense layer has no activation, so any loss applied to it must
        treat its input as logits.
    """
    x = inputs
    # Convolutional part: every vgg_block ends in a stride-2 max-pool, so each
    # block halves the spatial resolution.
    for num_convs, num_channels in conv_arch:
        x = vgg_block(x, num_convs, num_channels)
    # Fully connected part: two hidden Dense + Dropout layers.
    x = tf.keras.layers.Flatten()(x)
    for _ in range(2):
        x = tf.keras.layers.Dense(fc_units, activation='relu')(x)
        x = tf.keras.layers.Dropout(dropout_rate)(x)
    # Output layer without activation: produces logits.
    x = tf.keras.layers.Dense(num_classes)(x)
    return x

# Fashion-MNIST images are 28x28x1. The five stride-2 max-pools in conv_arch
# would shrink a 28x28 image to nothing, so images are upsampled to 224x224
# inside the model (224 / 2**5 = 7x7 feature maps before Flatten).
inputs = tf.keras.Input((28, 28, 1))
# No `input_shape=` on the Lambda layer. The shape already comes from
# `inputs`, the argument is ignored when a layer is called on an existing
# tensor, and Keras 3 warns about it.
y = tf.keras.layers.Lambda(lambda img: tf.image.resize(img, (224, 224)))(inputs)
y = vgg(y, conv_arch)
net = tf.keras.Model(inputs, y)
net.summary()

### 5.7.3 获取数据和训练模型

In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# Scale uint8 pixels to [0, 1]. Casting to float32 before dividing keeps the
# arrays in float32 (Keras' default float type) instead of float64, which
# halves their memory footprint.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

In [None]:
def metric_accuracy(y_true, y_pred):
    """Return the fraction of samples whose arg-max prediction matches the label.

    ``y_true`` is flattened; it is assumed to hold one integer class label per
    sample (TODO confirm against the dataset). ``y_pred`` holds one score per
    class along its last axis.
    """
    predicted_class = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    labels = K.flatten(y_true)
    return K.mean(K.equal(labels, predicted_class))

# The final Dense layer built by `vgg` has no activation, so `net` outputs
# raw logits. The loss must be told so. With the default from_logits=False,
# it would treat the logits as probabilities, which breaks training.
net.compile(optimizer=keras.optimizers.Adam(),
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=[metric_accuracy])

# Add the trailing channel axis expected by the (28, 28, 1) model input.
net.fit(x_train.reshape(-1, 28, 28, 1), y_train, epochs=5, batch_size=128,
        validation_data=(x_test.reshape(-1, 28, 28, 1), y_test),
        callbacks=[PlotLossesCallback()])