## 5.8 网络中的网络（NiN）

### 5.8.1 NiN 块

In [1]:
import tensorflow as tf
import tensorflow.keras.backend as K
from livelossplot.tf_keras import PlotLossesCallback

In [2]:
def nin_block(num_channels, kernel_size, strides, padding):
    """Return a callable that applies one NiN block to a tensor.

    A NiN block is a ReLU convolution with the requested kernel, stride and
    padding, followed by two ReLU 1x1 convolutions that keep the same number
    of channels.

    Args:
        num_channels: Number of filters used by all three convolutions.
        kernel_size: Kernel size of the first convolution.
        strides: Stride of the first convolution.
        padding: Padding mode of the first convolution.

    Returns:
        A function taking a tensor and returning the block's output tensor.
        The Keras layers are created each time that function is called.
    """
    def apply_block(tensor):
        # First convolution: the only one that uses the caller's kernel,
        # stride and padding settings.
        spatial_conv = tf.keras.layers.Conv2D(
            num_channels, kernel_size, strides=strides, padding=padding, activation='relu')
        features = spatial_conv(tensor)

        # Two stacked 1x1 convolutions with the same channel count.
        for _ in range(2):
            pointwise_conv = tf.keras.layers.Conv2D(num_channels, (1, 1), activation='relu')
            features = pointwise_conv(features)
        return features

    return apply_block
        

### 5.8.2 NiN模型

In [3]:
# Build the NiN classifier with the Keras functional API.
inputs = tf.keras.layers.Input((28, 28, 1))

# Resize the 28x28 inputs to 224x224 before the first convolution.
# No `input_shape` is passed to the Lambda layer: in the functional API the
# shape is already defined by `inputs`, so the argument was redundant.
x = tf.keras.layers.Lambda(lambda img: tf.image.resize(img, (224, 224)))(inputs)

# Three NiN blocks, each followed by 3x3 max pooling with stride 2.
x = nin_block(96, (11, 11), strides=4, padding='valid')(x)
x = tf.keras.layers.MaxPool2D(3, strides=2)(x)
x = nin_block(256, (5, 5), strides=1, padding='same')(x)
x = tf.keras.layers.MaxPool2D(3, strides=2)(x)
x = nin_block(384, (3, 3), strides=1, padding='same')(x)
x = tf.keras.layers.MaxPool2D(3, strides=2)(x)
x = tf.keras.layers.Dropout(0.5)(x)

# Final NiN block emits 10 channels, one per class.
x = nin_block(10, (3, 3), strides=1, padding='same')(x)

# Global average pooling collapses the spatial dimensions, so its output is
# already (batch, 10); the former Flatten layer after it was a no-op and has
# been dropped.
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Softmax()(x)

model = tf.keras.Model(inputs, x)
model.summary()

W1019 18:38:08.358957 139854265767744 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
lambda (Lambda)              (None, 224, 224, 1)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 54, 54, 96)        11712     
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 54, 54, 96)        9312      
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 54, 54, 96)        9312      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 26, 26, 96)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 26, 26, 256)       614656

### 5.8.3 获取数据和训练模型

In [4]:
# Load Fashion-MNIST and divide the pixel values by 255.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train / 255.
# Scale the test images exactly like the training images; otherwise any
# validation or evaluation on x_test would see inputs on a different scale
# than the ones the model was trained on.
x_test = x_test / 255.

In [None]:
def metric_accuracy(y_true, y_pred):
    """Fraction of samples whose arg-max prediction equals the label.

    Args:
        y_true: Class labels; flattened to one dimension before comparison.
        y_pred: Per-class scores; the predicted class is the arg-max over
            the last axis.

    Returns:
        A scalar tensor holding the mean of the element-wise matches.
    """
    # Cast both sides to the same float type so the comparison does not
    # depend on the dtype the labels arrive in (e.g. integer labels when the
    # function is called directly rather than through `model.fit`).
    labels = K.cast(K.flatten(y_true), K.floatx())
    predicted = K.cast(K.argmax(y_pred, axis=-1), K.floatx())
    matches = K.equal(labels, predicted)
    # Convert the boolean matches to floats explicitly before averaging.
    return K.mean(K.cast(matches, K.floatx()))

# Sparse categorical cross-entropy takes the integer labels in `y_train`
# directly; the optimizer is Adam with its default settings.
model.compile(
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[metric_accuracy],
)

# Train for 5 epochs on the images reshaped to (N, 28, 28, 1), holding out
# 30% of the training data for validation.
# Alternative: replace `validation_split` with
# `validation_data=(x_test.reshape(x_test.shape[0], 28, 28, 1), y_test)` to
# validate on the test split (x_test must be scaled the same way as x_train).
model.fit(
    x_train.reshape((x_train.shape[0], 28, 28, 1)),
    y_train,
    batch_size=128,
    epochs=5,
    validation_split=0.3,
    callbacks=[PlotLossesCallback()],
)

Train on 42000 samples, validate on 18000 samples
