## 4.1. 模型构造
### 4.1.1 继承Model类来构造模型

In [1]:
import keras
import keras.backend as K
import numpy as np

class MLP(keras.Model):
    def __init__(self, *args, **kwargs):
        super(MLP, self).__init__(*args, **kwargs)
        self._hidden = keras.layers.Dense(256, activation='relu')  # 隐藏层
        self._output = keras.layers.Dense(10)  # 输出层
        
    def call(self, inputs, mask=None):
        return self._output(self._hidden(inputs))
        

Using TensorFlow backend.


In [2]:
X = np.random.uniform(size=(2, 20))
net = MLP()
net.compile(optimizer='SGD')
net.predict(X)

array([[ 0.31569925,  0.06691933,  0.01224808, -0.14013164, -0.05919157,
         0.25708616, -0.13604207,  0.08747526,  0.04870658, -0.541568  ],
       [ 0.44202986, -0.00570045,  0.04320703, -0.15733904, -0.2288207 ,
         0.18503797, -0.2768906 ,  0.10237173,  0.0254905 , -0.5160129 ]],
      dtype=float32)

### 4.1.2 Sequential类继承自Model类

### 4.1.3 构造复杂的模型

In [3]:
class FancyMLP(keras.Model):
    def __init__(self, *args, **kwargs):
        super(FancyMLP, self).__init__(*args, **kwargs)
        self.rand_weight = K.variable(np.random.uniform(size=(20, 20)))
        self.dense = keras.layers.Dense(20, activation='relu')
        
    def call(self, inputs, mask=None):
        x = self.dense(inputs)
        # 使用创建的常数参数，以及backend的relu函数和dot函数
        x = K.relu(K.dot(x, self.rand_weight) + 1)
        # 复用全连接层。等价于两个全连接层共享参数
        x = self.dense(x)
        # 控制流，这里我们需要调用asscalar函数来返回标量进行比较
        x = K.switch(K.greater(K.l2_normalize(x), 1), x/2, x)
        x = K.switch(K.greater(0.8, K.l2_normalize(x)), x*10, x)
        return K.sum(x)
    

In [4]:
net = FancyMLP()
net.compile(optimizer='SGD')
#net.predict(X)

tracking <tf.Variable 'Variable:0' shape=(20, 20) dtype=float32, numpy=
array([[3.7273291e-01, 9.1178727e-01, 1.5650684e-01, 4.4507712e-01,
        1.3424584e-01, 3.9748928e-01, 6.7302930e-01, 1.8150915e-01,
        3.9226794e-01, 3.5764045e-01, 7.4916941e-01, 3.8381547e-01,
        1.8011788e-01, 4.9263665e-01, 8.0566001e-01, 2.5766879e-01,
        3.2081756e-01, 4.5934618e-01, 4.0920606e-01, 9.9059451e-01],
       [2.3059776e-01, 6.0264045e-01, 7.8132749e-02, 1.5034157e-01,
        7.8842467e-01, 1.8663944e-01, 2.7431723e-01, 9.0625089e-01,
        7.2927582e-01, 9.5787877e-01, 5.9977353e-01, 2.2651009e-01,
        8.4545559e-01, 8.0375797e-01, 9.4803047e-01, 6.2814581e-01,
        7.6995414e-01, 1.7266428e-01, 1.5269275e-01, 6.6218454e-01],
       [2.8548551e-01, 4.2093721e-01, 5.5952495e-01, 5.8712727e-01,
        7.4038637e-01, 7.3299044e-01, 6.1662382e-01, 7.0090431e-01,
        2.8839776e-01, 1.2219233e-01, 5.2977294e-01, 7.5558013e-01,
        4.9662212e-01, 8.5082495e-01, 2.82