Author: Eugene Su

Email: su.eugene@gmail.com

https://sites.google.com/view/smartrobot/lab

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Two Ways to Build a Model in TensorFlow Keras: the Sequential API and the Functional API

## Sequential API

The Sequential API is the simplest way to build a model, but that simplicity comes with the following restrictions:

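*   Layers must be stacked one after another in a fixed, linear order; only the input layer may be specified later
*   Only models with a single input and a single output are supported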
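
A quick way to see the "linear stack" restriction: in a Sequential model each layer simply consumes the previous layer's output, so applying the layers by hand, in the order they were added, should reproduce the model's own output. This is a sketch assuming tf.keras in TensorFlow 2.x; the random input batch x is made-up data for illustration only.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Assumes tf.keras (TensorFlow 2.x). Same architecture as the examples below (Method 1 style).
model = keras.models.Sequential()
model.add(layers.InputLayer(input_shape=(28, 28, 1)))
model.add(layers.Flatten())
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

# Illustrative batch of two random 28x28x1 "images" (made-up data).
x = np.random.default_rng(0).random((2, 28, 28, 1), dtype=np.float32)

# Feed the batch through the layers one at a time, in the order they were added.
# model.layers lists Flatten and the two Dense layers; the InputLayer is not included.
h = tf.convert_to_tensor(x)
for layer in model.layers:
    h = layer(h)

same = np.allclose(np.asarray(model(x)), np.asarray(h), atol=1e-6)
print(same)  # True
```

There is no way to branch or merge inside this stack, which is exactly what the two restrictions above describe.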



**Method 1**: Create an empty model with tf.keras.Sequential() (no arguments), then call add() to append the layers in order

In [None]:
model = keras.models.Sequential()
model.add(layers.InputLayer(input_shape=(28, 28, 1)))
model.add(layers.Flatten())
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
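The Param # column can be checked by hand: a Dense layer with n inputs and u units holds an n x u weight matrix plus u biases, while Flatten has no parameters at all. A small sketch of that arithmetic in plain Python (dense_params is a hypothetical helper, not a Keras function):

```python
def dense_params(n_inputs, units):
    # Hypothetical helper: weight matrix (n_inputs x units) plus one bias per unit.
    return n_inputs * units + units

flatten_out = 28 * 28 * 1                  # Flatten: (28, 28, 1) -> 784 features, no weights
hidden = dense_params(flatten_out, 128)    # first Dense layer
output = dense_params(128, 10)             # output Dense layer
total = hidden + output

print(flatten_out, hidden, output, total)  # 784 100480 1290 101770
```

These match the summary above: 100,480 and 1,290 parameters, 101,770 in total.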


**Method 2**: Pass input_shape to the first layer instead of adding a tf.keras.layers.InputLayer

In [None]:
model = keras.models.Sequential()
model.add(layers.Flatten(input_shape=(28, 28, 1)))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               100480    
_________________________________________________________________
dense_3 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
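Note that input_shape=(28, 28, 1) leaves out the batch dimension, which is why the summary shows None as the first axis: the model accepts batches of any size. A minimal check, assuming tf.keras in TensorFlow 2.x and using all-zero dummy inputs:

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Assumes tf.keras (TensorFlow 2.x). Method 2: input shape given to the first layer, no batch axis.
model = keras.models.Sequential()
model.add(layers.Flatten(input_shape=(28, 28, 1)))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

# The same model accepts different batch sizes (dummy all-zero "images").
shapes = {}
for batch_size in (1, 32):
    probs = np.asarray(model(np.zeros((batch_size, 28, 28, 1), dtype='float32')))
    shapes[batch_size] = probs.shape

print(shapes)  # {1: (1, 10), 32: (32, 10)}
```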


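**Method 3**: Set the input shape by calling build()

Note: with this method the input shape must include the batch dimension (use None to allow any batch size).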

In [None]:
model = keras.models.Sequential()
model.add(layers.Flatten())
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

model.build(input_shape=(None, 28, 28, 1))

model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 128)               100480    
_________________________________________________________________
dense_5 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
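Why the batch dimension matters for build(): the first entry of input_shape is read as the batch axis. If it is left out, the first image dimension (28) is taken as the batch size, Flatten only sees (28, 1) per sample, and the result is a different, much smaller network. A sketch assuming tf.keras in TensorFlow 2.x; make_model is a hypothetical helper that recreates the Method 3 architecture:

```python
from tensorflow import keras
from tensorflow.keras import layers

def make_model():
    # Hypothetical helper: the Method 3 architecture, without an input layer.
    # Assumes tf.keras (TensorFlow 2.x).
    model = keras.models.Sequential()
    model.add(layers.Flatten())
    model.add(layers.Dense(128, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    return model

correct = make_model()
correct.build(input_shape=(None, 28, 28, 1))   # None = any batch size

wrong = make_model()
wrong.build(input_shape=(28, 28, 1))           # 28 is read as the batch size

print(correct.count_params())  # 101770
print(wrong.count_params())    # 5002 (Flatten yields 28 features, so Dense(128) has 28*128 + 128)
```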


## Functional API
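The Functional API can build more complex and flexible models than the Sequential API (for example, models with multiple inputs or outputs), at the cost of somewhat more verbose code.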

In [None]:
input_layer = keras.Input(shape=(28, 28, 1), name='input')
flatten_layer = layers.Flatten()(input_layer)
hidden_layer = layers.Dense(128, activation='relu')(flatten_layer)
output_layer = layers.Dense(10, activation='softmax')(hidden_layer)

model = keras.Model(inputs=input_layer, outputs=output_layer, name='MNIST')
model.summary()

Model: "MNIST"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input (InputLayer)           [(None, 28, 28, 1)]       0         
_________________________________________________________________
flatten_3 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 128)               100480    
_________________________________________________________________
dense_7 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
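To make the "more flexible" point concrete: unlike the Sequential API, the Functional API can wire several inputs into one graph. The sketch below adds a hypothetical second input, extra, holding 4 extra features per sample, and merges it with the image branch via layers.Concatenate. The input names, sizes and dummy data are assumptions for illustration only, and the code assumes tf.keras in TensorFlow 2.x.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Assumes tf.keras (TensorFlow 2.x). The 'extra' input (4 features) is hypothetical.
image_in = keras.Input(shape=(28, 28, 1), name='image')
extra_in = keras.Input(shape=(4,), name='extra')

x = layers.Flatten()(image_in)
x = layers.Dense(128, activation='relu')(x)
x = layers.Concatenate()([x, extra_in])   # merge the two branches: 128 + 4 = 132 features
digit_out = layers.Dense(10, activation='softmax', name='digit')(x)

model = keras.Model(inputs=[image_in, extra_in], outputs=digit_out, name='two_input_sketch')

# Dummy batch of 3 samples for both inputs.
images = np.zeros((3, 28, 28, 1), dtype='float32')
extras = np.zeros((3, 4), dtype='float32')
probs = np.asarray(model([images, extras]))

print(probs.shape)           # (3, 10)
print(model.count_params())  # 101810 = 100480 + (132*10 + 10)
```

A model like this cannot be expressed with Sequential, because it has two inputs and a merge point in the middle of the graph.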
