<a href="https://colab.research.google.com/github/DongheeKang/MachineLearning/blob/master/Tensorflow2_save_and_load_distributed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to save and load a model using a distribution strategy

## Overview

This notebook shows how I can save and load models in the SavedModel format with `tf.distribute.Strategy`, during or after training. Two pairs of APIs are covered:

  * `tf.keras.Model.save`, `tf.keras.models.load_model`
  * `tf.saved_model.save`, `tf.saved_model.load`





In [2]:
import tensorflow_datasets as tfds

import tensorflow as tf


In [3]:
mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


In [4]:
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)

Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7f242a3de6d0>

## Save and load the model

Let's now explore the saving/loading APIs. 
There are two kinds of APIs available:

*   High-level (Keras): `Model.save` and `tf.keras.models.load_model`
*   Low-level: `tf.saved_model.save` and `tf.saved_model.load`


### The Keras API

In [5]:
keras_model_path = '/tmp/keras_save'
model.save(keras_model_path)

INFO:tensorflow:Assets written to: /tmp/keras_save/assets




Restore the model without `tf.distribute.Strategy`:

In [6]:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7f242ab50890>

After restoring the model, you can continue training on it, even without needing to call `Model.compile` again, since it was already compiled before saving. The model is saved in TensorFlow's standard `SavedModel` proto format.
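To see what the SavedModel format looks like on disk, here is a minimal sketch that saves a tiny object and lists the resulting directory. The `Counter` module and the temporary directory are hypothetical stand-ins, not part of this notebook, and the exact set of files can vary between TensorFlow versions.

```python
import os
import tempfile

import tensorflow as tf


class Counter(tf.Module):
  """Hypothetical module holding a single variable (illustration only)."""

  def __init__(self):
    super().__init__()
    self.count = tf.Variable(0, dtype=tf.int64)


# Assumption: a throwaway temporary directory instead of /tmp/keras_save.
export_dir = tempfile.mkdtemp()
tf.saved_model.save(Counter(), export_dir)

# A SavedModel is a directory: a `saved_model.pb` proto describing the
# graph(s), plus a `variables/` subdirectory holding the checkpoint.
entries = sorted(os.listdir(export_dir))
print(entries)

# `contains_saved_model` checks for the proto file in the directory.
is_saved_model = tf.saved_model.contains_saved_model(export_dir)
print(is_saved_model)
```

The same check can be pointed at `keras_model_path` after `model.save` to confirm that a SavedModel was written there.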

Now, restore the model and train it using a `tf.distribute.Strategy`:

In [7]:
another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)

Epoch 1/2
Epoch 2/2


As the `Model.fit` output shows, loading works as expected with `tf.distribute.Strategy`. The strategy used here does not have to be the same strategy used before saving. 

### The `tf.saved_model` API

In [9]:
model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets




Loading can be done with `tf.saved_model.load`. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

In [10]:
DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. The `"serving_default"` key is the default key for the inference function with a saved Keras model. To do inference with this function: 
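As a side illustration of how several functions end up under different keys, here is a minimal sketch that exports two signatures explicitly and inspects them after loading. The `Scaler` module, the `'shift'` key, and the output names are hypothetical and chosen only for this example; a saved Keras model normally exposes just `serving_default`.

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
  """Hypothetical module exposing two functions (illustration only)."""

  def __init__(self):
    super().__init__()
    self.factor = tf.Variable(2.0)

  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def scale(self, x):
    return {'scaled': x * self.factor}

  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def shift(self, x):
    return {'shifted': x + self.factor}


module = Scaler()
export_dir = tempfile.mkdtemp()

# Each entry in `signatures` becomes one key in `loaded.signatures`.
tf.saved_model.save(
    module, export_dir,
    signatures={'serving_default': module.scale, 'shift': module.shift})

loaded = tf.saved_model.load(export_dir)
signature_keys = sorted(loaded.signatures.keys())
print(signature_keys)

# Signature functions take keyword arguments and return a dict of tensors.
scale_fn = loaded.signatures['serving_default']
print(scale_fn.structured_input_signature)
print(scale_fn.structured_outputs)
result = scale_fn(x=tf.constant([1.0, 2.0]))
print(result['scaled'].numpy())
```

Printing `structured_input_signature` and `structured_outputs` is a quick way to find the argument and output names that a signature expects, such as the `dense_3` key in the output below.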

In [11]:
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))

{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.02800646e-01,  6.70271814e-02,  6.90521002e-02,
         5.75432032e-02, -4.39634696e-02,  8.99827387e-03,
         8.54578465e-02,  2.14327965e-02,  1.82602569e-04,
         1.30286179e-02],
       [-1.51366323e-01,  2.36753188e-02,  1.81642938e-02,
        -1.41145885e-01, -2.01356143e-01, -6.08870909e-02,
        -1.80455938e-01, -3.05397213e-02, -5.42015061e-02,
        -6.75992668e-02],
       [-1.83226258e-01,  2.72723567e-02, -3.34406793e-02,
        -3.96496840e-02,  9.36522335e-02, -3.71609330e-02,
         1.22122101e-01, -3.55634578e-02, -2.24769711e-01,
         8.45894516e-02],
       [-1.81401655e-01, -4.09738161e-02,  1.41574349e-02,
        -1.11232266e-01, -6.90684281e-03, -2.02334151e-02,
         2.01331303e-02, -4.43165489e-02, -6.57890961e-02,
         1.50010571e-01],
       [-2.44683117e-01,  6.53201491e-02, -1.53596955e-03,
        -7.09154271e-03,  3.79599631e-02, -2.30212793e-01,
        

I can also load the model and do inference in a distributed manner:

In [12]:
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    break





INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)








{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.02800646e-01,  6.70271814e-02,  6.90521002e-02,
         5.75432032e-02, -4.39634696e-02,  8.99827387e-03,
         8.54578465e-02,  2.14327965e-02,  1.82602569e-04,
         1.30286179e-02],
       [-1.51366323e-01,  2.36753188e-02,  1.81642938e-02,
        -1.41145885e-01, -2.01356143e-01, -6.08870909e-02,
        -1.80455938e-01, -3.05397213e-02, -5.42015061e-02,
        -6.75992668e-02],
       [-1.83226258e-01,  2.72723567e-02, -3.34406793e-02,
        -3.96496840e-02,  9.36522335e-02, -3.71609330e-02,
         1.22122101e-01, -3.55634578e-02, -2.24769711e-01,
         8.45894516e-02],
       [-1.81401655e-01, -4.09738161e-02,  1.41574349e-02,
        -1.11232266e-01, -6.90684281e-03, -2.02334151e-02,
         2.01331303e-02, -4.43165489e-02, -6.57890961e-02,
         1.50010571e-01],
       [-2.44683117e-01,  6.53201491e-02, -1.53596955e-03,
        -7.09154271e-03,  3.79599631e-02, -2.30212793e-01,
        

Calling the restored function is just a forward pass on the saved model (`tf.keras.Model.predict`). 
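The value returned by `Strategy.run` holds one result per replica. The following minimal sketch shows one way to collect those results with `Strategy.experimental_local_results`. It assumes a single CPU device and uses a hypothetical `step` function and a toy dataset in place of the loaded signature and the MNIST data.

```python
import tensorflow as tf

# Assumption: pin the strategy to one CPU device so the sketch behaves the
# same on any machine; the notebook uses the default device list instead.
strategy = tf.distribute.MirroredStrategy(devices=['/cpu:0'])

# Toy data standing in for `predict_dataset`.
dataset = tf.data.Dataset.from_tensor_slices(
    tf.range(8, dtype=tf.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def step(batch):
  # Hypothetical stand-in for the loaded inference function.
  return batch * 2.0


for batch in dist_dataset:
  per_replica = strategy.run(step, args=(batch,))
  # Unpack into a tuple with one entry per local replica.
  local_results = strategy.experimental_local_results(per_replica)
  # Concatenate the per-replica pieces back into one global batch.
  merged = tf.concat(local_results, axis=0)
  break

print(len(local_results), merged.numpy())
```

With more than one replica, each entry of `local_results` holds that replica's slice of the global batch.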

What if I want to continue training the loaded function, or need to embed it into a bigger model? A common practice is to wrap the loaded object into a Keras layer. TensorFlow Hub provides `hub.KerasLayer` for this purpose:

In [13]:
import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)





INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)




Epoch 1/2
Epoch 2/2


In the above example, TensorFlow Hub's `hub.KerasLayer` wraps the result loaded back from `tf.saved_model.load` into a Keras layer that is used to build another model. This is very useful for transfer learning.

### Which API should I use?

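For saving: if the model is a Keras model, use `tf.keras.Model.save`; if it is not a Keras model, use the lower-level `tf.saved_model.save`.

For loading: to get a Keras model back, use `tf.keras.models.load_model` (this works only if a Keras model was saved); otherwise, use `tf.saved_model.load`.

The APIs can be mixed. For example, a model saved with `Model.save` can be loaded as a non-Keras object with `tf.saved_model.load`: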



In [14]:
model = get_model()

# Saving the model using Keras `Model.save`
model.save(keras_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)

INFO:tensorflow:Assets written to: /tmp/keras_save/assets








INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)




### Saving/Loading from a local device

When saving and loading from a local I/O device while training on remote devices—for example, when using a Cloud TPU—you must use the option `experimental_io_device` in `tf.saved_model.SaveOptions` and `tf.saved_model.LoadOptions` to set the I/O device to `localhost`. For example:

In [15]:
model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets








INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)




### Caveats

One special case is when I create Keras models in certain ways, and then save them before training. For example:

In [16]:
class SubclassedModel(tf.keras.Model):
  """Example model defined by subclassing `tf.keras.Model`."""

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
try:
  my_model.save(keras_model_path)
except ValueError as e:
  print(f'{type(e).__name__}: ', *e.args)





ValueError:  Model <__main__.SubclassedModel object at 0x7f242a208950> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.


A SavedModel saves the `tf.types.experimental.ConcreteFunction` objects generated when I trace a `tf.function`. If I get a `ValueError` it's because `Model.save` was not able to find or create a traced `ConcreteFunction`.

**Caution:** I shouldn't save a model without at least one `ConcreteFunction`, since the low-level API will otherwise generate a SavedModel with no `ConcreteFunction` signatures:

In [17]:
tf.saved_model.save(my_model, saved_model_path)
x = tf.saved_model.load(saved_model_path)
x.signatures









INFO:tensorflow:Assets written to: /tmp/tf_save/assets




_SignatureMap({})


Usually the model's forward pass—the `call` method—will be traced automatically when the model is called for the first time, often via the Keras `Model.fit` method. A `ConcreteFunction` can also be generated by the Keras [Sequential](https://www.tensorflow.org/guide/keras/sequential_model) and [Functional](https://www.tensorflow.org/guide/keras/functional) APIs, if I set the input shape, for example, by making the first layer either a `tf.keras.layers.InputLayer` or another layer type, and passing it the `input_shape` keyword argument. 
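Outside Keras, a `ConcreteFunction` can also be traced explicitly with `get_concrete_function` and passed to `tf.saved_model.save` as a signature. The following is a minimal sketch of that idea; the `Doubler` module and the temporary directories are hypothetical and not part of this notebook.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
  """Hypothetical module with one tf.function that starts out untraced."""

  @tf.function
  def double(self, x):
    return x * 2.0


module = Doubler()

# Saving before any tracing: there is no ConcreteFunction to export,
# so the loaded object has an empty signature map.
untraced_dir = tempfile.mkdtemp()
tf.saved_model.save(module, untraced_dir)
untraced_keys = list(tf.saved_model.load(untraced_dir).signatures.keys())
print(untraced_keys)

# Trace the function for a specific input spec, then export it explicitly.
concrete = module.double.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.float32))
traced_dir = tempfile.mkdtemp()
tf.saved_model.save(module, traced_dir, signatures=concrete)
traced_keys = list(tf.saved_model.load(traced_dir).signatures.keys())
print(traced_keys)
```

A single function passed as `signatures` is stored under the `serving_default` key, the same key used for the inference function of a saved Keras model.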

To verify whether the model has any traced `ConcreteFunction`s, check whether `Model.save_spec()` returns `None`:

In [18]:
print(my_model.save_spec() is None)

True


Let's use `tf.keras.Model.fit` to train the model, and notice that the `save_spec` gets defined and model saving will work:

In [19]:
BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
    ).repeat(dataset_size).batch(BATCH_SIZE)

my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

print(my_model.save_spec() is None)
my_model.save(keras_model_path)

Epoch 1/2
Epoch 2/2
False
INFO:tensorflow:Assets written to: /tmp/keras_save/assets


