In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, models, callbacks
from tensorflow.keras.datasets import mnist
import os

# Load the MNIST data

In [2]:
# Location of the MNIST archive, resolved against the current working directory.
file_path = os.path.abspath('./mnist.npz')

# How many examples to keep from each split, and the length of one flattened
# 28x28 image.
NUM_SAMPLES = 1000
NUM_PIXELS = 28*28

(train_x, train_y), (test_x, test_y) = datasets.mnist.load_data(path=file_path)

# Labels: keep only the leading NUM_SAMPLES entries of each split.
train_y = train_y[:NUM_SAMPLES]
test_y = test_y[:NUM_SAMPLES]

# Images: keep the leading NUM_SAMPLES entries, flatten every image into a
# NUM_PIXELS-long row, and divide the pixel values by 255.
train_x = train_x[:NUM_SAMPLES].reshape(-1, NUM_PIXELS) / 255.0
test_x = test_x[:NUM_SAMPLES].reshape(-1, NUM_PIXELS) / 255.0

# Build and evaluate the model

In [3]:
def build_model():
    """Construct and compile a new classifier.

    The network is a ``keras.Sequential`` stack of three layers:
    ``Dense(512, relu)`` declared with ``input_shape=(784,)``,
    ``Dropout(0.2)``, and ``Dense(10, softmax)``. It is compiled with the
    ``'adam'`` optimizer, the ``'sparse_categorical_crossentropy'`` loss and
    the ``'accuracy'`` metric.

    Returns:
        The compiled ``keras.Sequential`` model.
    """
    network = keras.Sequential()
    network.add(layers.Dense(512, activation='relu', input_shape=(784, )))
    network.add(layers.Dropout(0.2))
    network.add(layers.Dense(10, activation='softmax'))

    network.compile(
        loss='sparse_categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return network

def evaluate(target_model):
    """Score ``target_model`` on the notebook's test set and print the accuracy.

    Reads the module-level ``test_x`` and ``test_y`` arrays.

    Args:
        target_model: Object whose ``evaluate(test_x, test_y)`` call yields
            exactly two values; the second one is reported as the accuracy.
    """
    results = target_model.evaluate(test_x, test_y)
    # The first value is not reported; only the second one is printed.
    _, accuracy = results
    message = "Restore model, accuracy: {:5.2f}%".format(100*accuracy)
    print(message)

## 自动保存checkpoints
> tf.keras.callbacks.ModelCheckpoint 回调可以实现这一点

In [4]:
# File-name template for the saved weights. It follows str.format syntax; the
# `epoch` field is filled in when a checkpoint is written.
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# period=10: write a checkpoint once every 10 epochs.
# NOTE(review): the recorded run warns that `period` is deprecated in favor of
# `save_freq` -- confirm the units `save_freq` expects in the installed
# TensorFlow version before switching.
cp_callback = callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    period=10,
    verbose=1,
)

model = build_model()
# Store the weights of the newly built model under epoch 0 before training.
model.save_weights(checkpoint_path.format(epoch=0))
model.fit(
    x=train_x,
    y=train_y,
    epochs=50,
    validation_data=(test_x, test_y),
    callbacks=[cp_callback],
    verbose=0,
)

W1223 23:54:32.870669 4610692544 callbacks.py:863] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.



Epoch 00010: saving model to training_2/cp-0010.ckpt

Epoch 00020: saving model to training_2/cp-0020.ckpt

Epoch 00030: saving model to training_2/cp-0030.ckpt

Epoch 00040: saving model to training_2/cp-0040.ckpt

Epoch 00050: saving model to training_2/cp-0050.ckpt


<tensorflow.python.keras.callbacks.History at 0xb3000cfd0>

## 加载权重

In [5]:
# Start from a newly built model, then load the most recent checkpoint found
# in checkpoint_dir into it (in the recorded run: 'training_2/cp-0050.ckpt').
model = build_model()
latest = tf.train.latest_checkpoint(checkpoint_dir)
model.load_weights(latest)

evaluate(model)



Restore model, accuracy: 87.20%


## 手动保存权重

In [6]:
# Save the weights by hand, then load them back into a newly built model.
manual_checkpoint_path = "./checkpoints/mannul_checkpoint"
model.save_weights(manual_checkpoint_path)

model = build_model()
model.load_weights(manual_checkpoint_path)
evaluate(model)

W1223 23:54:48.363552 4610692544 util.py:144] Unresolved object in checkpoint: (root).optimizer.iter
W1223 23:54:48.365276 4610692544 util.py:144] Unresolved object in checkpoint: (root).optimizer.beta_1
W1223 23:54:48.366477 4610692544 util.py:144] Unresolved object in checkpoint: (root).optimizer.beta_2
W1223 23:54:48.372128 4610692544 util.py:144] Unresolved object in checkpoint: (root).optimizer.decay
W1223 23:54:48.375066 4610692544 util.py:144] Unresolved object in checkpoint: (root).optimizer.learning_rate




Restore model, accuracy: 87.20%


## 保存整个模型

In [7]:
# Write the whole model to a single file; the next cell reads the same file
# back with models.load_model.
models.save_model(model, "my_model.h5")

## 从 HDF5 文件中恢复完整的模型

In [8]:
# Rebuild the model from the saved file and score it on the test set.
saved_file = "my_model.h5"
new_model = models.load_model(saved_file)
evaluate(new_model)

W1223 23:54:53.261749 4610692544 hdf5_format.py:198] Error in loading the saved optimizer state. As a result, your model is starting with a freshly initialized optimizer.




Restore model, accuracy: 87.20%


## saved model

In [9]:
import time

# Export into a directory named after the current Unix time in whole seconds.
export_stamp = int(time.time())
saved_model_path = "./saved_model/{}".format(export_stamp)

# NOTE(review): the recorded run flags export_saved_model as deprecated and
# suggests model.save(..., save_format="tf"). It is kept here because the
# loading cell uses the matching experimental loader.
tf.keras.experimental.export_saved_model(model, saved_model_path)

W1223 23:54:55.601330 4610692544 deprecation.py:323] From <ipython-input-9-56d34e489058>:3: export_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `model.save(..., save_format="tf")` or `tf.keras.models.save_model(..., save_format="tf")`.
W1223 23:54:55.677302 4610692544 deprecation.py:506] From /Users/sunchao/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
W1223 23:54:56.168562 4610692544 deprecation.py:323] From /Users/sunchao/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_mode

## 恢复模型并进行预测

In [10]:
# Restore the exported model from saved_model_path and run inference with the
# restored model. The original cell called predict() on `model`, so the
# model that had just been loaded was never exercised.
# NOTE(review): the recorded run flags load_from_saved_model as deprecated in
# favor of tf.keras.models.load_model. It is kept here because the export
# cell uses the matching experimental exporter.
new_model = tf.keras.experimental.load_from_saved_model(saved_model_path)
new_model.predict(test_x).shape

W1223 23:54:59.141978 4610692544 deprecation.py:323] From <ipython-input-10-a3de6acb5550>:1: load_from_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
The experimental save and load functions have been  deprecated. Please switch to `tf.keras.models.load_model`.


(1000, 10)

In [11]:
# Compile the restored model and evaluate it on the test set.
# NOTE(review): this reuses the optimizer object that belongs to `model`;
# consider passing 'adam' as build_model() does, so that the two models do
# not share one optimizer instance.
compile_options = dict(
    optimizer=model.optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
new_model.compile(**compile_options)
evaluate(new_model)



Restore model, accuracy: 87.20%
