# Saving and restoring a TensorFlow model

We start by loading the necessary libraries.

```python
import tensorflow as tf
```

Next, we load the MNIST dataset, normalize it, and build, compile, and train a small classifier with the Keras Sequential API.

```python
# Load MNIST: 60,000 training and 10,000 test images of 28x28 grayscale digits
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize pixel values from [0, 255] to [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0

model = tf.keras.Sequential()
# Declaring the input shape builds the model immediately
model.add(tf.keras.layers.Flatten(input_shape=(28, 28), name="FLATTEN"))
model.add(tf.keras.layers.Dense(units=128, activation="relu", name="D1"))
model.add(tf.keras.layers.Dense(units=64, activation="relu", name="D2"))
model.add(tf.keras.layers.Dense(units=10, activation="softmax", name="OUTPUT"))

model.compile(optimizer="sgd",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

model.fit(x=x_train,
          y=y_train,
          epochs=5,
          validation_data=(x_test, y_test))
```

## Save and restore an entire model in the SavedModel format

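In TensorFlow 2.x (Keras 2), SavedModel is the default and recommended format for saving an entire model (its architecture, weights, and training configuration) to disk; `model.save` writes it as a directory. In Keras 3, which TensorFlow uses from version 2.16 onward, `model.save` instead expects a filename ending in `.keras` or `.h5`.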

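```python
# Writes a directory named "SavedModel"
model.save("SavedModel")
```

```python
# Rebuild the model, including its weights and compile settings, from that directory
model2 = tf.keras.models.load_model("SavedModel")
```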

## Save and restore an entire model in the Keras H5 format

We can either pass a filename that ends in `.h5` or add the `save_format="h5"` argument.

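```python
# The .h5 extension selects the HDF5 format
model.save("model.h5")
```

```python
# Without the extension, save_format="h5" forces the HDF5 format
model.save("model_save", save_format="h5")
```

```python
# An H5 file is restored with load_model, just like a SavedModel directory
model3 = tf.keras.models.load_model("model.h5")
```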
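A quick way to confirm that an H5 round trip preserved the model is to compare predictions before and after loading. The sketch below assumes random stand-in data and a smaller network than the MNIST model above, and writes to a temporary `model.h5` path; the `.h5` extension is used because it works under both Keras 2 and Keras 3.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Illustrative stand-in data: random 28x28 "images" with labels 0-9
rng = np.random.default_rng(0)
x = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=64)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="sgd",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(x, y, epochs=1, verbose=0)

# The .h5 extension selects the HDF5 format
path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(path)

# load_model rebuilds the architecture and weights and, because the
# model was compiled before saving, also its compile configuration
restored = tf.keras.models.load_model(path)

original_preds = model.predict(x, verbose=0)
restored_preds = restored.predict(x, verbose=0)
np.testing.assert_allclose(original_preds, restored_preds, atol=1e-6)

# The restored model is already compiled, so it can be evaluated directly
loss, acc = restored.evaluate(x, y, verbose=0)
print(f"restored model: loss={loss:.4f}, accuracy={acc:.4f}")
```

Identical predictions show that the weights survived the round trip; being able to call `evaluate` without recompiling shows that the training configuration did too.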

## Save and restore weights as a TensorFlow checkpoint

We can also use a `ModelCheckpoint` callback to save either the entire model or only its weights during training.
The callback is passed to the `fit` method through its `callbacks` argument.

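With `save_freq='epoch'`, the weights are saved at the end of every epoch. Because `filepath` is fixed, each save overwrites the previous one, so only the most recent weights are kept.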
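To keep one file per epoch instead of overwriting a single checkpoint, `ModelCheckpoint` can format the epoch number into `filepath`. This is a minimal sketch assuming random stand-in data, a temporary directory, and the hypothetical file pattern `weights-{epoch:02d}.weights.h5`; the `.weights.h5` suffix is chosen because Keras 3 requires it when `save_weights_only=True`, while Keras 2 simply treats it as an H5 file.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Illustrative stand-in data and model (not the notebook's MNIST setup)
rng = np.random.default_rng(0)
x = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=64)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="sgd", loss="sparse_categorical_crossentropy")

ckpt_dir = tempfile.mkdtemp()
# The callback fills in {epoch:02d} (1-based) at save time, so each
# epoch writes a separate file rather than replacing the last one
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "weights-{epoch:02d}.weights.h5"),
    save_weights_only=True,
    save_freq="epoch",
)
model.fit(x, y, epochs=3, verbose=0, callbacks=[checkpoint_callback])

saved = sorted(os.listdir(ckpt_dir))
print(saved)
```

Any of these files can later be passed to `load_weights` to roll the model back to that epoch.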

```python
# Save only the weights, at the end of every epoch
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath="./checkpoint",
                                                         save_weights_only=True,
                                                         save_freq='epoch')
```

```python
model.fit(x=x_train,
          y=y_train,
          epochs=5,
          validation_data=(x_test, y_test),
          callbacks=[checkpoint_callback])
```

```python
# Load the saved weights into a model with the same architecture
model.load_weights("./checkpoint")
```
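Since a weights-only checkpoint does not store the architecture, restoring it into a new model requires building that model again with the same layers. The sketch below shows the pattern with a hypothetical `build_model` helper and random stand-in data. It also swaps in `save_weights` writing an H5 file (`.weights.h5`) for the TensorFlow checkpoint format used above, which keeps it runnable under both Keras 2 and Keras 3; the restore step is the same `load_weights` call.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    """Hypothetical helper that rebuilds the same architecture every time."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(name="FLATTEN"),
        tf.keras.layers.Dense(128, activation="relu", name="D1"),
        tf.keras.layers.Dense(64, activation="relu", name="D2"),
        tf.keras.layers.Dense(10, activation="softmax", name="OUTPUT"),
    ])
    model.compile(optimizer="sgd",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


# Illustrative stand-in data
rng = np.random.default_rng(0)
x = rng.random((32, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=32)

trained = build_model()
trained.fit(x, y, epochs=1, verbose=0)
weights_path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
trained.save_weights(weights_path)
expected = trained.predict(x, verbose=0)

# A freshly built model starts from new random weights...
fresh = build_model()
before = fresh.predict(x, verbose=0)

# ...and only reproduces the trained model after loading the saved weights
fresh.load_weights(weights_path)
after = fresh.predict(x, verbose=0)
np.testing.assert_allclose(after, expected, atol=1e-6)
```

Keeping the architecture in a single function is what makes the saved weights line up layer by layer with the model that loads them.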