## 5.9 含并行连结的网络（GoogLeNet）

### 5.9.1 Inception 块

In [14]:
import tensorflow as tf
from livelossplot.keras import PlotLossesCallback

In [15]:
class Inception(tf.keras.layers.Layer):
    """Inception block: four parallel paths whose outputs are concatenated.

    Every path uses either 1 x 1 kernels or ``padding='same'`` with stride 1,
    so the outputs can be joined along the last (channel) axis.

    Args:
        c1: Number of output channels of the 1 x 1 convolution on path 1.
        c2: Pair of output-channel counts for path 2: the 1 x 1 convolution
            followed by the 3 x 3 convolution.
        c3: Pair of output-channel counts for path 3: the 1 x 1 convolution
            followed by the 5 x 5 convolution.
        c4: Number of output channels of the 1 x 1 convolution that follows
            the 3 x 3 max pooling on path 4.
        trainable: Forwarded to ``tf.keras.layers.Layer``.
        name: Forwarded to ``tf.keras.layers.Layer``.
        dtype: Forwarded to ``tf.keras.layers.Layer``.
        dynamic: Forwarded to ``tf.keras.layers.Layer`` only when truthy.
        **kwargs: Any further base-layer keyword arguments.
    """

    def __init__(self, c1, c2, c3, c4, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
        # Forward the base-layer options by keyword rather than by position:
        # newer Keras releases declare them keyword-only, so a positional
        # call fails there. `dynamic` is forwarded only when requested,
        # because False is the base-class default and not every Keras
        # version accepts the argument.
        if dynamic:
            kwargs['dynamic'] = dynamic
        super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)
        # Path 1: a single 1 x 1 convolution
        self.p1_1 = tf.keras.layers.Conv2D(c1, (1, 1), activation='relu')
        # Path 2: 1 x 1 convolution followed by a 3 x 3 convolution
        self.p2_1 = tf.keras.layers.Conv2D(c2[0], (1, 1), activation='relu')
        self.p2_2 = tf.keras.layers.Conv2D(c2[1], (3, 3), padding='same', activation='relu')
        # Path 3: 1 x 1 convolution followed by a 5 x 5 convolution
        self.p3_1 = tf.keras.layers.Conv2D(c3[0], (1, 1), activation='relu')
        self.p3_2 = tf.keras.layers.Conv2D(c3[1], (5, 5), padding='same', activation='relu')
        # Path 4: 3 x 3 max pooling followed by a 1 x 1 convolution
        self.p4_1 = tf.keras.layers.MaxPool2D(3, strides=1, padding='same')
        self.p4_2 = tf.keras.layers.Conv2D(c4, (1, 1), activation='relu')

    def call(self, inputs, **kwargs):
        """Apply the four paths to ``inputs`` and concatenate on the last axis."""
        p1 = self.p1_1(inputs)
        p2 = self.p2_2(self.p2_1(inputs))
        p3 = self.p3_2(self.p3_1(inputs))
        p4 = self.p4_2(self.p4_1(inputs))
        return tf.keras.backend.concatenate([p1, p2, p3, p4], axis=-1)
    

### 5.9.2 GoogLeNet 模型

In [16]:
# 28 x 28 single-channel input, resized to 96 x 96 before the first stage.
inputs = tf.keras.layers.Input(shape=(28, 28, 1))
a = tf.keras.layers.Lambda(lambda image: tf.image.resize(image, size=(96, 96)))(inputs)

# Stage 1: strided 7 x 7 convolution, then strided 3 x 3 max pooling.
stem_conv = tf.keras.layers.Conv2D(
    filters=64, kernel_size=(7, 7), strides=2, padding='same', activation='relu')
b1 = stem_conv(a)
stem_pool = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')
b1 = stem_pool(b1)
        

In [17]:
# Stage 2: 1 x 1 convolution, 3 x 3 convolution, then strided max pooling,
# applied in sequence to the stage-1 output.
b2 = b1
for layer in (
    tf.keras.layers.Conv2D(filters=64, kernel_size=(1, 1), activation='relu'),
    tf.keras.layers.Conv2D(filters=192, kernel_size=(3, 3), padding='same', activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same'),
):
    b2 = layer(b2)

In [18]:
# Stage 3: two Inception blocks followed by strided max pooling.
# Each spec is (c1, c2, c3, c4) as expected by Inception.
b3 = b2
for spec in (
    (64, (96, 128), (16, 32), 32),
    (128, (128, 192), (32, 96), 64),
):
    b3 = Inception(*spec)(b3)
b3 = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(b3)



In [19]:
# Stage 4: five Inception blocks followed by strided max pooling.
# Each spec is (c1, c2, c3, c4) as expected by Inception.
b4 = b3
for spec in (
    (192, (96, 208), (16, 48), 64),
    (160, (112, 224), (24, 64), 64),
    (128, (128, 256), (24, 64), 64),
    (112, (144, 288), (32, 64), 64),
    (256, (160, 320), (32, 128), 128),
):
    b4 = Inception(*spec)(b4)
b4 = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(b4)



In [20]:
# Stage 5: two Inception blocks followed by global average pooling.
b5 = b4
for spec in (
    (256, (160, 320), (32, 128), 128),
    (384, (192, 384), (48, 128), 128),
):
    b5 = Inception(*spec)(b5)
b5 = tf.keras.layers.GlobalAvgPool2D()(b5)

# Output head: a 10-unit dense layer with no activation.
net = tf.keras.layers.Dense(units=10)(b5)
model = tf.keras.Model(inputs=inputs, outputs=net)



In [21]:
# Print the per-layer table (layer type, output shape, parameter count) for the model.
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 96, 96, 2)]       0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 48, 48, 64)        6336      
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 24, 24, 64)        0         
_________________________________________________________________
conv2d_25 (Conv2D)           (None, 24, 24, 64)        4160      
_________________________________________________________________
conv2d_26 (Conv2D)           (None, 24, 24, 192)       110784    
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 12, 12, 192)       0         
_________________________________________________________________
inception_3 (Inception)      (None, 12, 12, 256)       163696

In [None]:
# Load Fashion-MNIST; only the training images are preprocessed here
# (x_test / y_test are left exactly as loaded).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# Divide by 255 and append a trailing channel axis in one step.
x_train = (x_train / 255.).reshape((x_train.shape[0], 28, 28, 1))

In [None]:
# The model's final Dense(10) layer has no activation, so it outputs raw
# logits. The loss must therefore be built with from_logits=True; the plain
# sparse_categorical_crossentropy function defaults to from_logits=False and
# would treat the logits as probabilities.
# `learning_rate` replaces the deprecated `lr` alias.
# TODO: add an accuracy metric (the d2l.metric_accuracy helper that was
# commented out here is not imported in this notebook).
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# Train for 5 epochs with live loss plotting.
model.fit(x_train, y_train, epochs=5, batch_size=128,
          callbacks=[PlotLossesCallback()])