In [1]:
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import tensorflow as tf
from tensorflow import keras

In [2]:
# Fix the random seeds so the train/valid/test splits (and TensorFlow's global
# seed, from which Keras weight initialisation draws) are repeatable across
# "Restart & Run All".
SEED = 42
tf.random.set_seed(SEED)

housing = fetch_california_housing()

# Hold out a test set first, then carve a validation set out of the remaining
# data (both splits use train_test_split's default proportions).
X_train_full, X_test, y_train_full, y_test = train_test_split(
    housing.data, housing.target, random_state=SEED)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train_full, y_train_full, random_state=SEED)

# Fit the scaler on the training set only, then apply that same transformation
# to the validation and test sets so no statistics leak from held-out data.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)

X_valid = scaler.transform(X_valid)
X_test = scaler.transform(X_test)

# Basic Saving & Restoring

In [3]:
# Small regression MLP: two ReLU hidden layers of 30 units each, followed by a
# single linear output unit. The input width comes from the scaled training data.
model = keras.models.Sequential()
model.add(keras.layers.Dense(30, activation="relu", input_shape=X_train.shape[1:]))
model.add(keras.layers.Dense(30, activation="relu"))
model.add(keras.layers.Dense(1))

In [4]:
# Train with plain SGD on the MSE loss, tracking validation loss every epoch,
# then display the loss on the held-out test set.
model.compile(optimizer="sgd", loss="mean_squared_error")
history = model.fit(
    X_train, y_train,
    epochs=20,
    validation_data=(X_valid, y_valid),
)
model.evaluate(X_test, y_test)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


0.34014034271240234

In [5]:
# Save the whole model (architecture, weights and compile configuration) to a
# single file; the ".h5" extension selects the HDF5 format.
model.save(filepath="my_model.h5")

In [6]:
# Rebuild the model from the saved file; no need to redefine or recompile it.
model = keras.models.load_model(filepath="my_model.h5")

In [7]:
# Run inference on the first three (already scaled) test instances and show
# the predictions as the cell output.
X_new = X_test[0:3]
y_pred = model.predict(X_new)
y_pred

array([[0.83272064],
       [1.847893  ],
       [1.5166662 ]], dtype=float32)

In [8]:
# Save only the model's weights (not its architecture), then load them back into
# the same in-memory model. Keep the path in one constant so save and load can't drift.
# NOTE(review): in TF 2.x a ".ckpt" path uses the TensorFlow checkpoint format;
# Keras 3 requires weight files to end in ".weights.h5" — confirm for your version.
WEIGHTS_PATH = "my_model_weights.ckpt"
model.save_weights(WEIGHTS_PATH)
# Trailing ';' hides the returned load-status object, which is only repr noise here.
model.load_weights(WEIGHTS_PATH);

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1109e1c10>

### The difference between `save()` and `save_weights()`

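- `save()`: saves the entire model (takes more disk space). Afterwards you can load the model file directly and start using it, without redefining the network architecture or recompiling the model, because this method stores the model's architecture and weights as well as its loss function and optimizer.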

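- `save_weights()`: saves only the model's weights, not its architecture (takes less disk space). To use them, first define a model with the same architecture as the one used during training, then load the weights into that model.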

# Advanced saving techniques

The `save()` method saves the model once training has finished. In real-world projects, however, training often lasts several hours, so we should not only save the model at the end of training but also save checkpoints at regular intervals during training, so that progress is not lost if the computer crashes.

We can pass callbacks to the `fit()` method to tell it to save checkpoints at regular intervals.

## 1. Save a checkpoint at the end of each epoch

In [9]:
# Fresh, untrained copy of the same two-hidden-layer regression MLP,
# compiled with MSE loss and plain SGD.
model = keras.models.Sequential()
model.add(keras.layers.Dense(30, activation="relu", input_shape=X_train.shape[1:]))
model.add(keras.layers.Dense(30, activation="relu"))
model.add(keras.layers.Dense(1))

model.compile(optimizer="sgd", loss="mean_squared_error")

In [10]:
# ModelCheckpoint writes the model to disk while training runs — by default once
# at the end of every epoch. The path has no placeholders, so each epoch overwrites
# the same file.
# NOTE(review): this reuses "my_model.h5", replacing the file saved earlier with model.save().
checkpoint_cb = keras.callbacks.ModelCheckpoint(filepath="my_model.h5")
history = model.fit(X_train, y_train, epochs=10, callbacks=[checkpoint_cb])

# Without save_best_only, the file holds the model as of the final epoch.
model = keras.models.load_model("my_model.h5")
model.evaluate(X_test, y_test)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


0.36623772978782654

## 2. Save the best model so far (based on performance on the validation set)

In [11]:
# Fresh, untrained copy of the same two-hidden-layer regression MLP,
# compiled with MSE loss and plain SGD.
model = keras.models.Sequential()
model.add(keras.layers.Dense(30, activation="relu", input_shape=X_train.shape[1:]))
model.add(keras.layers.Dense(30, activation="relu"))
model.add(keras.layers.Dense(1))

model.compile(optimizer="sgd", loss="mean_squared_error")

In [12]:
# save_best_only=True only overwrites the file when the monitored metric
# (val_loss by default) improves, which is why validation_data is passed to fit().
checkpoint_best_cb = keras.callbacks.ModelCheckpoint(
    filepath="my_best_model.h5", save_best_only=True)
history = model.fit(
    X_train, y_train,
    validation_data=(X_valid, y_valid),
    epochs=10,
    callbacks=[checkpoint_best_cb],
)

# Roll back to the best epoch's model before measuring the test loss.
model = keras.models.load_model("my_best_model.h5")
model.evaluate(X_test, y_test)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


0.3459774851799011

## 3. Stop training early when there is no more progress

In [13]:
# Fresh, untrained copy of the same two-hidden-layer regression MLP,
# compiled with MSE loss and plain SGD.
model = keras.models.Sequential()
model.add(keras.layers.Dense(30, activation="relu", input_shape=X_train.shape[1:]))
model.add(keras.layers.Dense(30, activation="relu"))
model.add(keras.layers.Dense(1))

model.compile(optimizer="sgd", loss="mean_squared_error")

In [14]:
# EarlyStopping interrupts training once the validation loss has not improved for
# `patience` epochs; restore_best_weights=True rolls the model back to its best epoch.
# Combining it with a save_best_only checkpoint keeps a copy on disk (in case the
# computer crashes) while avoiding wasted time and resources once progress stalls.
#
# Build a *fresh* ModelCheckpoint here instead of reusing the one from the previous
# section: the callback initialises its best-so-far value when it is constructed, not
# at the start of each fit(), so a reused instance would only save this new model if
# it beat the previous model's val_loss. This also removes a hidden dependency on an
# earlier cell.
checkpoint_best_cb = keras.callbacks.ModelCheckpoint("my_best_model.h5", save_best_only=True)
early_stop_cb = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
history = model.fit(X_train, y_train, validation_data=(X_valid, y_valid),
                    epochs=200, callbacks=[checkpoint_best_cb, early_stop_cb])
# NOTE(review): in some TF versions restore_best_weights only takes effect when training
# actually stops early; if all 200 epochs run, reload "my_best_model.h5" instead.
model.evaluate(X_test, y_test)

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200


0.36273834109306335

## 4. Writing our own custom callbacks

In [21]:
class DIYCallback(keras.callbacks.Callback):
    """Minimal custom callback that prints a banner at the start and end of every epoch.

    Keras passes a 0-based epoch index, so 1 is added to match the "Epoch N/M"
    numbering shown in fit()'s progress output.
    """

    def on_epoch_begin(self, epoch, logs=None):
        # `logs=None` matches the base Callback signature, so the hook also works
        # when it is invoked without a logs dict.
        print("\n***** Begin of Epoch {} *****".format(epoch+1))

    def on_epoch_end(self, epoch, logs=None):
        print("\n***** End of Epoch {} *****".format(epoch+1))

In [22]:
# Fresh, untrained copy of the same two-hidden-layer regression MLP,
# compiled with MSE loss and plain SGD.
model = keras.models.Sequential()
model.add(keras.layers.Dense(30, activation="relu", input_shape=X_train.shape[1:]))
model.add(keras.layers.Dense(30, activation="relu"))
model.add(keras.layers.Dense(1))

model.compile(optimizer="sgd", loss="mean_squared_error")

In [23]:
# Train briefly with the custom callback attached so its epoch banners
# are interleaved with Keras' own progress log.
diy_cb = DIYCallback()
history = model.fit(
    X_train, y_train,
    epochs=5,
    validation_data=(X_valid, y_valid),
    callbacks=[diy_cb],
)


***** Begin of Epoch 1 *****
Epoch 1/5
***** End of Epoch 1 *****

***** Begin of Epoch 2 *****
Epoch 2/5
***** End of Epoch 2 *****

***** Begin of Epoch 3 *****
Epoch 3/5
***** End of Epoch 3 *****

***** Begin of Epoch 4 *****
Epoch 4/5
***** End of Epoch 4 *****

***** Begin of Epoch 5 *****
Epoch 5/5
***** End of Epoch 5 *****
