## 5.12 稠密连接网络（DenseNet）
### 5.12.1 稠密块

In [25]:
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from livelossplot.tf_keras import PlotLossesCallback

In [2]:
def conv_block(x, num_channels):
    """Apply batch normalization, ReLU and a 3x3 'same' convolution to ``x``.

    Args:
        x: Input tensor fed to the first (batch-normalization) layer.
        num_channels: Number of filters of the 3x3 convolution.

    Returns:
        The output tensor of the convolution.
    """
    normalized = keras.layers.BatchNormalization()(x)
    activated = keras.layers.Activation('relu')(normalized)
    # padding='same' with the default stride keeps the spatial size unchanged.
    conv = keras.layers.Conv2D(num_channels, kernel_size=3, padding='same')
    return conv(activated)

In [16]:
class DenseBlock(keras.layers.Layer):
    """Dense block: repeated BN -> ReLU -> 3x3 conv units with concatenation.

    Each unit receives the concatenation of the block input and the outputs
    of all previous units, and contributes ``num_channels`` new channels.

    The sub-layers are created once in ``__init__`` and stored as attributes.
    Creating them inside ``call`` would instantiate fresh layers on every
    invocation and leave them outside this layer's attributes, so Keras could
    not track their weights.

    Args:
        num_convs: Number of BN -> ReLU -> conv units in the block.
        num_channels: Number of filters of every convolution (growth rate).
        trainable, name, dtype, dynamic, **kwargs: Forwarded unchanged to
            ``keras.layers.Layer``.
    """

    def __init__(self, num_convs, num_channels, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
        # Keyword arguments keep the call independent of the positional
        # order of the base-class constructor.
        super(DenseBlock, self).__init__(
            trainable=trainable, name=name, dtype=dtype, dynamic=dynamic, **kwargs)
        self.num_conv = num_convs
        self.num_channels = num_channels
        # One BatchNormalization / Conv2D pair per unit; flat lists of layers
        # assigned as attributes are tracked by Keras.
        self.bn_layers = [keras.layers.BatchNormalization() for _ in range(num_convs)]
        self.conv_layers = [
            keras.layers.Conv2D(num_channels, kernel_size=3, padding='same')
            for _ in range(num_convs)
        ]
        # ReLU holds no weights, so a single instance is shared by all units.
        self.relu = keras.layers.Activation('relu')

    def call(self, inputs, training=None, **kwargs):
        """Run the block and return the input concatenated with all unit outputs.

        Args:
            inputs: Input tensor.
            training: Forwarded to the batch-normalization layers so they
                can distinguish training from inference.

        Returns:
            Tensor formed by concatenating ``inputs`` and the output of every
            unit along the last axis.
        """
        x = inputs
        for bn, conv in zip(self.bn_layers, self.conv_layers):
            y = conv(self.relu(bn(x, training=training)))
            # K.concatenate joins along the last axis by default.
            x = K.concatenate([x, y])
        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(DenseBlock, self).get_config()
        config.update({
            'num_convs': self.num_conv,
            'num_channels': self.num_channels,
        })
        return config
    

In [18]:
# Shape check: wrap a DenseBlock with 2 convolutions of 10 filters each in a
# model fed with a 224x224 single-channel input and inspect the output shape.
inp = keras.Input(shape=(224, 224, 1))
dense = DenseBlock(num_convs=2, num_channels=10)(inp)
model = keras.Model(inputs=inp, outputs=dense)
model.output_shape

### 5.12.2 过渡层

In [19]:
def transition_block(x, num_channels):
    """Apply BN -> ReLU -> 1x1 convolution -> 2x2 average pooling to ``x``.

    Args:
        x: Input tensor.
        num_channels: Number of filters of the 1x1 convolution.

    Returns:
        The output tensor of the average-pooling layer.
    """
    pipeline = (
        keras.layers.BatchNormalization(),
        keras.layers.Activation('relu'),
        # The 1x1 convolution sets the channel count to num_channels.
        keras.layers.Conv2D(num_channels, kernel_size=1),
        # Pooling with window 2 and stride 2.
        keras.layers.AvgPool2D(pool_size=2, strides=2),
    )
    out = x
    for layer in pipeline:
        out = layer(out)
    return out

In [20]:
# Shape check: apply a transition block with 10 output channels to a
# 224x224 single-channel input and inspect the resulting output shape.
inp = keras.Input(shape=(224, 224, 1))
tran = transition_block(inp, num_channels=10)
model = keras.Model(inputs=inp, outputs=tran)
model.output_shape

### 5.12.3 DenseNet模型

In [30]:
# Network stem: resize the 28x28 input to 96x96, then apply a strided 7x7
# convolution, batch normalization, ReLU and a strided 3x3 max pooling.
inp = keras.Input(shape=(28, 28, 1))
stem_layers = [
    keras.layers.Lambda(lambda img: tf.image.resize(img, (96, 96)), input_shape=(28, 28, 1)),
    keras.layers.Conv2D(64, kernel_size=7, strides=2, padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Activation('relu'),
    keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same'),
]
x = inp
for stem_layer in stem_layers:
    x = stem_layer(x)

In [31]:
# num_channels tracks the current number of channels; growth_rate is the
# number of channels each convolution inside a dense block adds.
num_channels = 64
growth_rate = 32
num_convs_in_dense_blocks = [4, 4, 4, 4]
last_block_index = len(num_convs_in_dense_blocks) - 1

for i, num_convs in enumerate(num_convs_in_dense_blocks):
    x = DenseBlock(num_convs, growth_rate)(x)
    # Channel count at the output of the dense block just added.
    num_channels = num_channels + num_convs * growth_rate
    if i == last_block_index:
        # No transition layer follows the final dense block.
        continue
    # Between dense blocks, insert a transition layer that halves the
    # number of channels.
    num_channels = num_channels // 2
    x = transition_block(x, num_channels)
        

In [32]:
# Classification head: BN -> ReLU -> global average pooling -> 10-unit dense
# layer -> softmax. The final tensor is stored back in x.
normalized = keras.layers.BatchNormalization()(x)
activated = keras.layers.Activation('relu')(normalized)
pooled = keras.layers.GlobalAvgPool2D()(activated)
logits = keras.layers.Dense(10)(pooled)
x = keras.layers.Softmax()(logits)

# Assemble the full network from the input placeholder to the softmax output.
model = keras.Model(inputs=inp, outputs=x)

### 5.12.4 获取数据并训练模型

In [33]:
# Load the Fashion-MNIST train/test splits bundled with Keras.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# Divide the pixel values by 255 to scale them down. The same scaling is
# applied to both splits so that the test images can be used for validation
# or evaluation with inputs on the same scale as the training images.
x_train = x_train / 255.
x_test = x_test / 255.

In [None]:
def metric_accuracy(y_true, y_pred):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        y_true: Tensor of class indices; it is flattened before comparison.
        y_pred: Tensor of per-class scores; the predicted class is the index
            of the largest value along the last axis.

    Returns:
        Scalar tensor holding the mean of the element-wise matches.
    """
    # Cast both sides to the backend float type so the comparison also works
    # when the labels arrive as integers instead of floats.
    labels = K.cast(K.flatten(y_true), K.floatx())
    predicted = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    matches = K.cast(K.equal(labels, predicted), K.floatx())
    return K.mean(matches)

# Configure training: Adam optimizer, sparse categorical cross-entropy loss
# (labels are class indices) and the custom accuracy metric.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.sparse_categorical_crossentropy,
    metrics=[metric_accuracy],
)

# Append a trailing channel axis so the images have shape (N, 28, 28, 1).
train_images = x_train.reshape((x_train.shape[0], 28, 28, 1))

# Train for 5 epochs, holding out 30% of the training data for validation.
# Alternative (disabled): validate on the test split instead by passing
# validation_data=(x_test.reshape(x_test.shape[0], 28, 28, 1), y_test).
model.fit(
    train_images,
    y_train,
    batch_size=128,
    epochs=5,
    validation_split=0.3,
    callbacks=[PlotLossesCallback()],
)