In [1]:
# 1. 保存序列模型和函数式模型

In [2]:
# Build a small 3-layer MLP with the functional API and train it for one epoch on MNIST.
# Its test-set predictions are kept in `predictions` as the reference that the
# save/load round trips in later cells are compared against.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Reset Keras' global state so re-running this cell starts from a clean slate.
tf.keras.backend.clear_session()

inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)

model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer_mlp')
model.summary()

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Flatten each image into a 784-element vector and scale pixel values by 1/255.
# reshape(-1, 784) infers the sample count instead of hard-coding 60000 / 10000.
x_train = x_train.reshape(-1, 784).astype('float32') / 255
x_test = x_test.reshape(-1, 784).astype('float32') / 255

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=1)

# Reference predictions for the round-trip checks in later cells.
predictions = model.predict(x_test)


Model: "3_layer_mlp"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
digits (InputLayer)          [(None, 784)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
predictions (Dense)          (None, 10)                650       
Total params: 55,050
Trainable params: 55,050
Non-trainable params: 0
_________________________________________________________________


W0710 06:11:54.982882 140640123414272 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Train on 60000 samples


In [None]:
# 1.1 保存整个模型
# 可以对整个模型进行保存，保存的内容包括：

# 该模型的架构
# 模型的权重（在训练期间学到的）
# 模型的训练配置（即传递给 compile 的参数），如果有的话
# 优化器及其状态（如果有的话）（这使您可以从中断的地方继续训练）

In [3]:
import numpy as np  # used by the prediction comparisons in the following cells

# Write the whole model (architecture, weights, training config, optimizer state)
# into a single HDF5 file.
model.save(filepath='the_save_model.h5')



In [4]:
# Reload the complete model from the HDF5 file and confirm it reproduces the
# predictions of the model it was saved from.
new_model = keras.models.load_model(filepath='the_save_model.h5')
new_prediction = new_model.predict(x=x_test)
print(new_prediction)
np.testing.assert_allclose(actual=predictions, desired=new_prediction, atol=1e-6)  # identical predictions


[[2.63737980e-04 3.25902676e-07 5.11975493e-04 ... 9.93999839e-01
  1.63902063e-04 3.22614127e-04]
 [6.39909646e-04 1.37444613e-05 9.95352268e-01 ... 1.41465273e-09
  1.21340316e-04 4.18984930e-10]
 [9.75631847e-05 9.88689721e-01 2.55207554e-03 ... 1.67452369e-03
  1.61452102e-03 1.53704765e-04]
 ...
 [1.05973538e-06 2.04239896e-08 8.90532419e-07 ... 1.58701678e-05
  4.01161204e-04 9.76371299e-03]
 [2.22036797e-05 8.02153318e-06 8.83524308e-07 ... 6.46006697e-08
  7.80109968e-03 1.87486614e-06]
 [4.95852646e-06 9.52486354e-11 1.20624591e-05 ... 1.91408264e-10
  7.36301775e-09 1.19961840e-09]]


In [5]:
# 1.2 Export to the TensorFlow SavedModel format and load it back.
# NOTE(review): keras.experimental.* is an experimental API; confirm it exists in the
# installed TF version (newer releases use model.save(..., save_format='tf') instead).
keras.experimental.export_saved_model(model, 'saved_model')
new_model = keras.experimental.load_from_saved_model('saved_model')
new_prediction = new_model.predict(x=x_test)
# The restored model must reproduce the original predictions.
np.testing.assert_allclose(actual=predictions, desired=new_prediction, atol=1e-6)





W0710 06:14:57.024204 140640123414272 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
W0710 06:14:57.026279 140640123414272 export_utils.py:182] Export includes no default signature!
W0710 06:14:57.523731 140640123414272 export_utils.py:182] Export includes no default signature!


In [8]:
### 1.3 Save only the network architecture
# The architecture alone carries no trained parameters: a model rebuilt from the
# config starts from freshly initialised weights.

config = model.get_config()
reinitialized_model = keras.Model.from_config(config)
new_prediction = reinitialized_model.predict(x_test)
# The untrained model must predict differently. Compare element-wise: the old check,
# abs(np.sum(predictions - new_prediction)) > 0, was vacuous, because both outputs are
# softmax rows that each sum to 1, so the summed signed difference is ~0 (only float
# rounding noise) no matter how different the predictions actually are.
assert not np.allclose(predictions, new_prediction)

In [9]:
# The architecture can also be serialised as JSON; again no trained weights come along.
json_config = model.to_json()
reinitialized_model = keras.models.model_from_json(json_config)
new_prediction = reinitialized_model.predict(x_test)
# Element-wise comparison: summing signed differences of two softmax outputs cancels
# to ~0 (every row sums to 1), so a sum-based check cannot detect a difference.
assert not np.allclose(predictions, new_prediction)


In [10]:
# 1.4 Save only the network parameters (weights), without the architecture.
# Snapshot the current parameter values, then write the same values straight back;
# this only demonstrates the get/set API and leaves the model's weights unchanged.
weights = model.get_weights()
model.set_weights(weights=weights)


In [11]:
# Architecture (config) and parameters (weights) can be combined to rebuild the trained model.
config = model.get_config()
weights = model.get_weights()
# The functional model's config dict is turned back into a model via keras.Model.from_config.
new_model = keras.Model.from_config(config)
new_model.set_weights(weights=weights)
new_predictions = new_model.predict(x=x_test)
np.testing.assert_allclose(actual=predictions, desired=new_predictions, atol=1e-6)


In [16]:
# 1.5 完整的模型保存方法


In [17]:
# Persist the architecture (JSON text file) and the weights (HDF5 file) separately ...
json_config = model.to_json()
with open('model_config.json', mode='w') as json_file:
    json_file.write(json_config)
model.save_weights(filepath='path_to_my_weights.h5')

# ... then rebuild the model from the JSON file and load the weights into it.
with open('model_config.json', mode='r') as json_file:
    json_config = json_file.read()
new_model = keras.models.model_from_json(json_config)
new_model.load_weights(filepath='path_to_my_weights.h5')

# The rebuilt model must reproduce the original predictions.
new_predictions = new_model.predict(x=x_test)
np.testing.assert_allclose(actual=predictions, desired=new_predictions, atol=1e-6)


In [18]:
# Or do it in one step: a single HDF5 file holds both architecture and weights.
model.save(filepath='path_to_my_model.h5')
del model  # drop the in-memory model so the reload below is a genuine round trip
model = keras.models.load_model(filepath='path_to_my_model.h5')


W0710 06:21:13.239060 140640123414272 util.py:244] Unresolved object in checkpoint: (root).optimizer.decay
W0710 06:21:13.241223 140640123414272 util.py:244] Unresolved object in checkpoint: (root).optimizer.learning_rate
W0710 06:21:13.243048 140640123414272 util.py:244] Unresolved object in checkpoint: (root).optimizer.momentum
W0710 06:21:13.244048 140640123414272 util.py:244] Unresolved object in checkpoint: (root).optimizer.rho
W0710 06:21:13.244863 140640123414272 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.kernel
W0710 06:21:13.245847 140640123414272 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-0.bias
W0710 06:21:13.246569 140640123414272 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'rms' for (root).layer_with_weights-1.kernel
W0710 06:21:13.247335 140640123414272 util.py:244] Unresolved object in checkpoint: (root).optimizer's 

In [19]:
# 1.6 保存网络权重（TensorFlow checkpoint 格式与 HDF5 格式）

In [20]:
# TensorFlow checkpoint format. The format is stated explicitly so the result does
# not depend on save_weights' implicit default (which is chosen from the file
# extension and has differed between TF versions).
model.save_weights('weight_tf_savedmodel', save_format='tf')
# HDF5 format; it has to be requested explicitly because the path has no .h5 suffix.
model.save_weights('weight_tf_savedmodel_h5', save_format='h5')


In [21]:
# 1.7 子类模型的参数保存
# 子类模型的结构无法保存和序列化，只能保存参数

In [22]:
# Build the model (subclassing API)
class ThreeLayerMLP(keras.Model):
    """Subclassed version of the 3-layer MLP: two 64-unit ReLU Dense layers
    followed by a 10-unit softmax Dense layer.

    NOTE(review): the attribute names (dense_1, dense_2, pred_layer) and their
    creation order are presumably what saved weights are matched against on
    reload — avoid renaming or reordering them.
    """

    def __init__(self, name=None):
        super().__init__(name=name)
        self.dense_1 = layers.Dense(units=64, activation='relu', name='dense_1')
        self.dense_2 = layers.Dense(units=64, activation='relu', name='dense_2')
        self.pred_layer = layers.Dense(units=10, activation='softmax', name='predictions')

    def call(self, inputs):
        """Forward pass: dense_1 -> dense_2 -> pred_layer."""
        hidden = self.dense_2(self.dense_1(inputs))
        return self.pred_layer(hidden)

def get_model():
    """Return a freshly constructed ThreeLayerMLP named '3_layer_mlp'."""
    mlp = ThreeLayerMLP(name='3_layer_mlp')
    return mlp


model = get_model()


In [23]:
# Train the subclassed model on MNIST with the same compile/fit settings as the
# functional model.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Flatten each image into a 784-element vector and scale pixel values by 1/255;
# reshape(-1, 784) infers the sample count instead of hard-coding 60000 / 10000.
x_train = x_train.reshape(-1, 784).astype('float32') / 255
x_test = x_test.reshape(-1, 784).astype('float32') / 255

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=1)


Train on 60000 samples


In [24]:
# Save the weights in TensorFlow checkpoint format (a subclassed model's structure
# cannot be serialised, so only its parameters are saved).
model.save_weights(filepath='my_model_weights', save_format='tf')

# Record reference outputs for a later comparison: test-set predictions, and the
# loss of one training step on the first 64 training samples.
predictions = model.predict(x=x_test)
first_batch_loss = model.train_on_batch(x_train[:64], y_train[:64])

