##### Copyright 2019 The TensorFlow Authors.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Distributed training with Keras

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://tensorflow.google.cn/tutorials/distribute/keras"><img src="https://tensorflow.google.cn/images/tf_logo_32px.png" />View on tensorflow.google.cn</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/distribute/keras.ipynb"><img src="https://tensorflow.google.cn/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/distribute/keras.ipynb"><img src="https://tensorflow.google.cn/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/tutorials/distribute/keras.ipynb"><img src="https://tensorflow.google.cn/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

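Note: This document was translated by the TensorFlow community. Because community translations are best-effort, there is no guarantee that it is accurate or that it reflects the latest
[official English documentation](https://tensorflow.google.cn/?hl=en). If you have suggestions for improving this translation, please submit a pull request to the
[tensorflow/docs](https://github.com/tensorflow/docs) GitHub repository. To volunteer to write or review translations, join the
[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn).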

## Overview

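The `tf.distribute.Strategy` API provides an abstraction for distributing training across multiple processing units. Its goal is to let users enable distributed training with their existing models and training code, with only minimal changes.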

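This tutorial uses `tf.distribute.MirroredStrategy`, which performs in-graph replication with synchronous training on multiple GPUs on a single machine. In essence, it copies all of the model's variables to each processor, then uses [all-reduce](http://mpitutorial.com/tutorials/mpi-reduce-and-allreduce/) to combine the gradients from all processors and applies the combined result to every replica.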
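
The synchronous update described above can be pictured with a small pure-Python sketch. It is only a toy model: the "replicas" are plain lists, the function names `all_reduce_mean` and `apply_update` are hypothetical, and averaging the gradients is an assumption of this sketch rather than a statement about how TensorFlow reduces them internally.

```python
# Toy illustration of synchronous data-parallel training with all-reduce.
# Replicas are plain Python lists here; in practice they would be devices.

def all_reduce_mean(per_replica_grads):
    """Average gradients element-wise and hand every replica the same result."""
    num_replicas = len(per_replica_grads)
    summed = [sum(values) for values in zip(*per_replica_grads)]
    mean = [s / num_replicas for s in summed]
    return [list(mean) for _ in range(num_replicas)]


def apply_update(weights, grads, learning_rate):
    """Plain gradient-descent step on one replica's copy of the weights."""
    return [w - learning_rate * g for w, g in zip(weights, grads)]


# Two replicas start from identical (mirrored) weights.
replica_weights = [[1.0, 2.0], [1.0, 2.0]]
# Each replica computes gradients on its own slice of the batch.
local_grads = [[0.5, 1.0], [1.5, 3.0]]

# After the all-reduce, every replica holds the same combined gradient.
reduced = all_reduce_mean(local_grads)
replica_weights = [
    apply_update(w, g, learning_rate=0.5)
    for w, g in zip(replica_weights, reduced)
]
print(replica_weights)
```

Because every replica applies the same combined gradient to the same starting weights, the copies stay identical after the step, which is the property that "mirrored" refers to.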

`MirroredStrategy` is one of several distribution strategies available in TensorFlow. You can read about the others in the [distribution strategy guide](../../guide/distribute_strategy.ipynb).


### Keras API

This example uses the `tf.keras` API to build and train the model. For custom training loops, see the [tf.distribute.Strategy with training loops](training_loops.ipynb) tutorial.

## Import dependencies

In [2]:
# Import TensorFlow and TensorFlow Datasets

import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os

In [3]:
print(tf.__version__)

2.3.0


## Download the dataset

Download the MNIST dataset and load it from [TensorFlow Datasets](https://tensorflow.google.cn/datasets). This returns a dataset in `tf.data` format.

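Setting `with_info` to `True` also returns the metadata for the entire dataset, which is saved here in `info`. Among other things, this metadata object includes the number of training and test examples.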


In [4]:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

## Define the distribution strategy

Create a `MirroredStrategy` object. It handles the distribution and provides a context manager (`tf.distribute.MirroredStrategy.scope`) in which to build your model.

In [5]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [6]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


## Set up the input pipeline

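When training a model with multiple GPUs, you can make effective use of the extra computing power by increasing the batch size. In general, use the largest batch size that fits in GPU memory, and tune the learning rate accordingly.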

In [7]:
# You can also use info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
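
To see how the global batch size grows with the number of replicas, here is a small standalone sketch. `BASE_LEARNING_RATE` and the linear scaling in `scaled_learning_rate` are hypothetical: the tutorial only says to tune the learning rate accordingly, so treat the scaling rule as one possible starting point, not a recommendation from the source.

```python
BATCH_SIZE_PER_REPLICA = 64
BASE_LEARNING_RATE = 1e-3  # hypothetical learning rate for a single replica


def global_batch_size(num_replicas):
    """Each replica processes BATCH_SIZE_PER_REPLICA examples per step."""
    return BATCH_SIZE_PER_REPLICA * num_replicas


def scaled_learning_rate(num_replicas):
    """Linear scaling heuristic (an assumption of this sketch)."""
    return BASE_LEARNING_RATE * num_replicas


for n in (1, 2, 4, 8):
    print(n, global_batch_size(n), scaled_learning_rate(n))
```

With the single device reported above, the global batch size equals the per-replica batch size of 64.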

Pixel values, which range from 0 to 255, [must be normalized to the 0-1 range](https://en.wikipedia.org/wiki/Feature_scaling). Define this scaling in a function.

In [8]:
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

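Apply this function to the training and test data, shuffle the training data, and [batch it for training](https://tensorflow.google.cn/api_docs/python/tf/data/Dataset#batch). Note that an in-memory cache of the training data is also kept to improve performance.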


In [9]:
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
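
As a quick sanity check on the batching, the number of steps per epoch can be worked out by hand. The sketch below assumes the standard MNIST split sizes (60,000 training and 10,000 test examples) and a batch size of 64; `steps_per_epoch` is a hypothetical helper, and its `drop_remainder` flag mirrors the option of the same name on `Dataset.batch`.

```python
import math


def steps_per_epoch(num_examples, batch_size, drop_remainder=False):
    """Number of batches produced when splitting num_examples into batches."""
    if drop_remainder:
        # The final partial batch is discarded.
        return num_examples // batch_size
    # The final partial batch is kept.
    return math.ceil(num_examples / batch_size)


train_steps = steps_per_epoch(60000, 64)
eval_steps = steps_per_epoch(10000, 64)
print(train_steps, eval_steps)
```

The training value matches the `938` step count that appears in the progress output during training.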

## Create the model

Create and compile the Keras model in the context of `strategy.scope`.

In [10]:
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
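
The final `Dense(10)` layer has no activation, so the model outputs raw logits and the loss is built with `from_logits=True`. The following pure-Python sketch shows what such a loss computes for a single example. The helper name is hypothetical and this is not the Keras implementation, only the standard log-softmax formula.

```python
import math


def sparse_categorical_crossentropy_from_logits(logits, label):
    """Negative log-probability of `label` under softmax(logits)."""
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


# An untrained model that treats all 10 digits as equally likely.
uniform_loss = sparse_categorical_crossentropy_from_logits([0.0] * 10, label=3)
# A model that is confident and correct.
confident_loss = sparse_categorical_crossentropy_from_logits(
    [0.0] * 9 + [10.0], label=9)
print(uniform_loss, confident_loss)
```

For ten equally likely classes the loss is ln(10), roughly 2.30, which is close to the first-batch loss of 2.3194 reported in the training output.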

## Define the callbacks

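The callbacks used here are:

*   *TensorBoard*: This callback writes a log for TensorBoard, which lets you visualize the graphs.
*   *Model Checkpoint*: This callback saves the model after every epoch.
*   *Learning Rate Scheduler*: This callback lets you schedule the learning rate to change after every epoch/batch.

For illustration, add a print callback that displays the *learning rate* in the notebook.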

In [11]:
# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
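
The `{epoch}` placeholder in the prefix is an ordinary Python format field. A minimal sketch of how it expands, assuming the checkpoint callback fills it with a 1-based epoch number:

```python
import os

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Assumption: the placeholder is filled with a 1-based epoch number.
paths = [checkpoint_prefix.format(epoch=epoch) for epoch in range(1, 4)]
for path in paths:
    print(path)
```

Each epoch therefore gets its own checkpoint name instead of overwriting the previous one.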

In [12]:
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
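
To check the schedule before training, the function can be evaluated for every epoch index. The sketch below repeats an equivalent `decay` so that it runs on its own. Epoch indices passed to the function start at 0, while the printed epoch numbers in the training log start at 1.

```python
def decay(epoch):
    """Piecewise-constant learning rate, indexed by 0-based epoch."""
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5


schedule = [decay(epoch) for epoch in range(12)]
for epoch, lr in enumerate(schedule):
    # Add 1 to match the 1-based epoch numbers shown during training.
    print('epoch {:2d}: learning rate {}'.format(epoch + 1, lr))
```

The values printed during training look slightly different (for example `0.0010000000474974513`) because the optimizer stores the learning rate as a 32-bit float.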

In [13]:
# Callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      self.model.optimizer.lr.numpy()))

In [14]:
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

## Train and evaluate

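Now train the model in the usual way, calling `fit` on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether or not you are distributing the training.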

In [15]:
model.fit(train_dataset, epochs=12, callbacks=callbacks)

Epoch 1/12
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


  1/938 [..............................] - ETA: 0s - loss: 2.3194 - accuracy: 0.0938

Instructions for updating:
use `tf.profiler.experimental.stop` instead.


  6/938 [..............................] - ETA: 8s - loss: 2.1802 - accuracy: 0.3307

 18/938 [..............................] - ETA: 5s - loss: 1.7413 - accuracy: 0.5686

 29/938 [..............................] - ETA: 5s - loss: 1.4024 - accuracy: 0.6633

 40/938 [>.............................] - ETA: 4s - loss: 1.1742 - accuracy: 0.7121

 51/938 [>.............................] - ETA: 4s - loss: 1.0183 - accuracy: 0.7491

 62/938 [>.............................] - ETA: 4s - loss: 0.9055 - accuracy: 0.7737

 72/938 [=>............................] - ETA: 4s - loss: 0.8273 - accuracy: 0.7912

 83/938 [=>............................] - ETA: 4s - loss: 0.7671 - accuracy: 0.8055

 94/938 [==>...........................] - ETA: 4s - loss: 0.7184 - accuracy: 0.8165

105/938 [==>...........................] - ETA: 4s - loss: 0.6768 - accuracy: 0.8246

116/938 [==>...........................] - ETA: 4s - loss: 0.6391 - accuracy: 0.8322

127/938 [===>..........................] - ETA: 4s - loss: 0.6072 - accuracy: 0.8397

138/938 [===>..........................] - ETA: 3s - loss: 0.5834 - accuracy: 0.8441

149/938 [===>..........................] - ETA: 3s - loss: 0.5583 - accuracy: 0.8501

160/938 [====>.........................] - ETA: 3s - loss: 0.5374 - accuracy: 0.8558

171/938 [====>.........................] - ETA: 3s - loss: 0.5167 - accuracy: 0.8613

182/938 [====>.........................] - ETA: 3s - loss: 0.4999 - accuracy: 0.8653

193/938 [=====>........................] - ETA: 3s - loss: 0.4861 - accuracy: 0.8685

204/938 [=====>........................] - ETA: 3s - loss: 0.4738 - accuracy: 0.8711

215/938 [=====>........................] - ETA: 3s - loss: 0.4607 - accuracy: 0.8748


























































































































Learning rate for epoch 1 is 0.0010000000474974513


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


Epoch 2/12
  1/938 [..............................] - ETA: 0s - loss: 0.1368 - accuracy: 0.9844

 20/938 [..............................] - ETA: 2s - loss: 0.1061 - accuracy: 0.9727

 39/938 [>.............................] - ETA: 2s - loss: 0.0834 - accuracy: 0.9780

 58/938 [>.............................] - ETA: 2s - loss: 0.0948 - accuracy: 0.9747

 77/938 [=>............................] - ETA: 2s - loss: 0.0919 - accuracy: 0.9740

 96/938 [==>...........................] - ETA: 2s - loss: 0.0885 - accuracy: 0.9754

115/938 [==>...........................] - ETA: 2s - loss: 0.0893 - accuracy: 0.9743

134/938 [===>..........................] - ETA: 2s - loss: 0.0898 - accuracy: 0.9739

153/938 [===>..........................] - ETA: 2s - loss: 0.0858 - accuracy: 0.9754

172/938 [====>.........................] - ETA: 2s - loss: 0.0828 - accuracy: 0.9761

190/938 [=====>........................] - ETA: 2s - loss: 0.0813 - accuracy: 0.9764

208/938 [=====>........................] - ETA: 1s - loss: 0.0805 - accuracy: 0.9763














































































Learning rate for epoch 2 is 0.0010000000474974513


Epoch 3/12
  1/938 [..............................] - ETA: 0s - loss: 0.0461 - accuracy: 0.9844

 19/938 [..............................] - ETA: 2s - loss: 0.0678 - accuracy: 0.9762

 38/938 [>.............................] - ETA: 2s - loss: 0.0609 - accuracy: 0.9790

 56/938 [>.............................] - ETA: 2s - loss: 0.0566 - accuracy: 0.9807

 75/938 [=>............................] - ETA: 2s - loss: 0.0563 - accuracy: 0.9817

 94/938 [==>...........................] - ETA: 2s - loss: 0.0557 - accuracy: 0.9822

113/938 [==>...........................] - ETA: 2s - loss: 0.0556 - accuracy: 0.9827

132/938 [===>..........................] - ETA: 2s - loss: 0.0570 - accuracy: 0.9828

151/938 [===>..........................] - ETA: 2s - loss: 0.0564 - accuracy: 0.9833

170/938 [====>.........................] - ETA: 2s - loss: 0.0546 - accuracy: 0.9836

188/938 [=====>........................] - ETA: 2s - loss: 0.0551 - accuracy: 0.9836

207/938 [=====>........................] - ETA: 2s - loss: 0.0542 - accuracy: 0.9837














































































Learning rate for epoch 3 is 0.0010000000474974513


Epoch 4/12


  1/938 [..............................] - ETA: 0s - loss: 0.0605 - accuracy: 0.9688

 20/938 [..............................] - ETA: 2s - loss: 0.0450 - accuracy: 0.9875

 39/938 [>.............................] - ETA: 2s - loss: 0.0376 - accuracy: 0.9900

 58/938 [>.............................] - ETA: 2s - loss: 0.0340 - accuracy: 0.9903

 77/938 [=>............................] - ETA: 2s - loss: 0.0348 - accuracy: 0.9894

 96/938 [==>...........................] - ETA: 2s - loss: 0.0338 - accuracy: 0.9901

115/938 [==>...........................] - ETA: 2s - loss: 0.0351 - accuracy: 0.9901

134/938 [===>..........................] - ETA: 2s - loss: 0.0362 - accuracy: 0.9897

153/938 [===>..........................] - ETA: 2s - loss: 0.0351 - accuracy: 0.9904

172/938 [====>.........................] - ETA: 2s - loss: 0.0350 - accuracy: 0.9904

191/938 [=====>........................] - ETA: 1s - loss: 0.0346 - accuracy: 0.9905

210/938 [=====>........................] - ETA: 1s - loss: 0.0362 - accuracy: 0.9902












































































Learning rate for epoch 4 is 9.999999747378752e-05


Epoch 5/12
  1/938 [..............................] - ETA: 0s - loss: 0.0149 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0315 - accuracy: 0.9891

 39/938 [>.............................] - ETA: 2s - loss: 0.0327 - accuracy: 0.9880

 58/938 [>.............................] - ETA: 2s - loss: 0.0344 - accuracy: 0.9884

 78/938 [=>............................] - ETA: 2s - loss: 0.0311 - accuracy: 0.9902

 98/938 [==>...........................] - ETA: 2s - loss: 0.0321 - accuracy: 0.9904

117/938 [==>...........................] - ETA: 2s - loss: 0.0324 - accuracy: 0.9907

136/938 [===>..........................] - ETA: 2s - loss: 0.0324 - accuracy: 0.9906

155/938 [===>..........................] - ETA: 2s - loss: 0.0316 - accuracy: 0.9906

174/938 [====>.........................] - ETA: 2s - loss: 0.0308 - accuracy: 0.9909

192/938 [=====>........................] - ETA: 2s - loss: 0.0317 - accuracy: 0.9907

211/938 [=====>........................] - ETA: 1s - loss: 0.0307 - accuracy: 0.9911














































































Learning rate for epoch 5 is 9.999999747378752e-05


Epoch 6/12


  1/938 [..............................] - ETA: 0s - loss: 0.0556 - accuracy: 0.9844

 20/938 [..............................] - ETA: 2s - loss: 0.0334 - accuracy: 0.9891

 39/938 [>.............................] - ETA: 2s - loss: 0.0302 - accuracy: 0.9912

 57/938 [>.............................] - ETA: 2s - loss: 0.0280 - accuracy: 0.9921

 76/938 [=>............................] - ETA: 2s - loss: 0.0286 - accuracy: 0.9912

 95/938 [==>...........................] - ETA: 2s - loss: 0.0279 - accuracy: 0.9911

114/938 [==>...........................] - ETA: 2s - loss: 0.0289 - accuracy: 0.9907

133/938 [===>..........................] - ETA: 2s - loss: 0.0291 - accuracy: 0.9907

152/938 [===>..........................] - ETA: 2s - loss: 0.0279 - accuracy: 0.9913

171/938 [====>.........................] - ETA: 2s - loss: 0.0277 - accuracy: 0.9920

190/938 [=====>........................] - ETA: 2s - loss: 0.0273 - accuracy: 0.9920

209/938 [=====>........................] - ETA: 1s - loss: 0.0285 - accuracy: 0.9916














































































Learning rate for epoch 6 is 9.999999747378752e-05


Epoch 7/12
  1/938 [..............................] - ETA: 0s - loss: 0.0066 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0387 - accuracy: 0.9906

 40/938 [>.............................] - ETA: 2s - loss: 0.0309 - accuracy: 0.9914

 59/938 [>.............................] - ETA: 2s - loss: 0.0306 - accuracy: 0.9915

 78/938 [=>............................] - ETA: 2s - loss: 0.0293 - accuracy: 0.9916

 97/938 [==>...........................] - ETA: 2s - loss: 0.0293 - accuracy: 0.9916

116/938 [==>...........................] - ETA: 2s - loss: 0.0269 - accuracy: 0.9926

135/938 [===>..........................] - ETA: 2s - loss: 0.0250 - accuracy: 0.9932

154/938 [===>..........................] - ETA: 2s - loss: 0.0248 - accuracy: 0.9932

173/938 [====>.........................] - ETA: 2s - loss: 0.0247 - accuracy: 0.9930

193/938 [=====>........................] - ETA: 1s - loss: 0.0243 - accuracy: 0.9932

212/938 [=====>........................] - ETA: 1s - loss: 0.0238 - accuracy: 0.9934














































































Learning rate for epoch 7 is 9.999999747378752e-05


Epoch 8/12
  1/938 [..............................] - ETA: 0s - loss: 0.0129 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0222 - accuracy: 0.9953

 39/938 [>.............................] - ETA: 2s - loss: 0.0214 - accuracy: 0.9944

 58/938 [>.............................] - ETA: 2s - loss: 0.0221 - accuracy: 0.9938

 77/938 [=>............................] - ETA: 2s - loss: 0.0215 - accuracy: 0.9935

 96/938 [==>...........................] - ETA: 2s - loss: 0.0217 - accuracy: 0.9937

115/938 [==>...........................] - ETA: 2s - loss: 0.0242 - accuracy: 0.9931

134/938 [===>..........................] - ETA: 2s - loss: 0.0229 - accuracy: 0.9935

153/938 [===>..........................] - ETA: 2s - loss: 0.0239 - accuracy: 0.9932

172/938 [====>.........................] - ETA: 2s - loss: 0.0233 - accuracy: 0.9933

191/938 [=====>........................] - ETA: 2s - loss: 0.0224 - accuracy: 0.9936

210/938 [=====>........................] - ETA: 1s - loss: 0.0225 - accuracy: 0.9935














































































Learning rate for epoch 8 is 9.999999747378752e-06


Epoch 9/12


  1/938 [..............................] - ETA: 0s - loss: 0.0049 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0296 - accuracy: 0.9914

 39/938 [>.............................] - ETA: 2s - loss: 0.0313 - accuracy: 0.9900

 58/938 [>.............................] - ETA: 2s - loss: 0.0318 - accuracy: 0.9916

 77/938 [=>............................] - ETA: 2s - loss: 0.0292 - accuracy: 0.9923

 96/938 [==>...........................] - ETA: 2s - loss: 0.0274 - accuracy: 0.9927

115/938 [==>...........................] - ETA: 2s - loss: 0.0257 - accuracy: 0.9932

134/938 [===>..........................] - ETA: 2s - loss: 0.0256 - accuracy: 0.9929

153/938 [===>..........................] - ETA: 2s - loss: 0.0251 - accuracy: 0.9930

172/938 [====>.........................] - ETA: 2s - loss: 0.0242 - accuracy: 0.9933

191/938 [=====>........................] - ETA: 1s - loss: 0.0242 - accuracy: 0.9931

210/938 [=====>........................] - ETA: 1s - loss: 0.0241 - accuracy: 0.9932














































































Learning rate for epoch 9 is 9.999999747378752e-06


Epoch 10/12


  1/938 [..............................] - ETA: 0s - loss: 0.0110 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0206 - accuracy: 0.9953

 39/938 [>.............................] - ETA: 2s - loss: 0.0183 - accuracy: 0.9956

 58/938 [>.............................] - ETA: 2s - loss: 0.0210 - accuracy: 0.9952

 76/938 [=>............................] - ETA: 2s - loss: 0.0246 - accuracy: 0.9938

 95/938 [==>...........................] - ETA: 2s - loss: 0.0228 - accuracy: 0.9942

114/938 [==>...........................] - ETA: 2s - loss: 0.0220 - accuracy: 0.9941

133/938 [===>..........................] - ETA: 2s - loss: 0.0253 - accuracy: 0.9933

152/938 [===>..........................] - ETA: 2s - loss: 0.0251 - accuracy: 0.9931

171/938 [====>.........................] - ETA: 2s - loss: 0.0248 - accuracy: 0.9931

190/938 [=====>........................] - ETA: 2s - loss: 0.0251 - accuracy: 0.9928

209/938 [=====>........................] - ETA: 1s - loss: 0.0247 - accuracy: 0.9929














































































Learning rate for epoch 10 is 9.999999747378752e-06


Epoch 11/12


  1/938 [..............................] - ETA: 0s - loss: 0.0079 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0173 - accuracy: 0.9969

 39/938 [>.............................] - ETA: 2s - loss: 0.0238 - accuracy: 0.9924

 58/938 [>.............................] - ETA: 2s - loss: 0.0236 - accuracy: 0.9925

 77/938 [=>............................] - ETA: 2s - loss: 0.0221 - accuracy: 0.9929

 96/938 [==>...........................] - ETA: 2s - loss: 0.0217 - accuracy: 0.9933

115/938 [==>...........................] - ETA: 2s - loss: 0.0243 - accuracy: 0.9921

134/938 [===>..........................] - ETA: 2s - loss: 0.0249 - accuracy: 0.9928

153/938 [===>..........................] - ETA: 2s - loss: 0.0250 - accuracy: 0.9930

172/938 [====>.........................] - ETA: 2s - loss: 0.0241 - accuracy: 0.9933

191/938 [=====>........................] - ETA: 1s - loss: 0.0244 - accuracy: 0.9934

210/938 [=====>........................] - ETA: 1s - loss: 0.0239 - accuracy: 0.9935














































































Learning rate for epoch 11 is 9.999999747378752e-06


Epoch 12/12
  1/938 [..............................] - ETA: 0s - loss: 0.0433 - accuracy: 0.9844

 20/938 [..............................] - ETA: 2s - loss: 0.0258 - accuracy: 0.9898

 39/938 [>.............................] - ETA: 2s - loss: 0.0236 - accuracy: 0.9928

 58/938 [>.............................] - ETA: 2s - loss: 0.0223 - accuracy: 0.9935

 77/938 [=>............................] - ETA: 2s - loss: 0.0234 - accuracy: 0.9929

 96/938 [==>...........................] - ETA: 2s - loss: 0.0217 - accuracy: 0.9935

115/938 [==>...........................] - ETA: 2s - loss: 0.0219 - accuracy: 0.9936

134/938 [===>..........................] - ETA: 2s - loss: 0.0214 - accuracy: 0.9937

153/938 [===>..........................] - ETA: 2s - loss: 0.0219 - accuracy: 0.9937

172/938 [====>.........................] - ETA: 2s - loss: 0.0224 - accuracy: 0.9936

191/938 [=====>........................] - ETA: 1s - loss: 0.0225 - accuracy: 0.9934

210/938 [=====>........................] - ETA: 1s - loss: 0.0218 - accuracy: 0.9936














































































Learning rate for epoch 12 is 9.999999747378752e-06


<tensorflow.python.keras.callbacks.History at 0x7fe470118978>
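The log reports the learning rate as 9.999999747378752e-06 rather than a round 1e-05. A likely explanation, assuming the schedule defined earlier in the notebook returns `1e-5` for these epochs and the optimizer stores the value as a 32-bit float, is that `1e-5` has no exact float32 representation, so the nearest float32 value is printed at double precision. The sketch below reproduces that rounding with the standard library only; `to_float32` is an illustrative helper, not a TensorFlow function.

```python
import struct


def to_float32(value: float) -> float:
    """Round a 64-bit Python float to the nearest 32-bit float.

    Hypothetical helper for illustration: packing with the 'f' format
    performs the float64 -> float32 conversion.
    """
    return struct.unpack("<f", struct.pack("<f", value))[0]


# Assumed schedule value for the late epochs (not defined in this section).
scheduled_lr = 1e-5
stored_lr = to_float32(scheduled_lr)

print("Scheduled learning rate:", scheduled_lr)
print("Stored as float32:      ", stored_lr)
```

The difference is on the order of one part in 10^8 and should have no practical effect on training.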

As shown below, the checkpoints are being saved.

In [16]:
# Check the checkpoint directory
!ls {checkpoint_dir}

checkpoint		     ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index		     ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index		     ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index		     ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index		     ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index		     ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index
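Note that `ls` sorts names lexicographically, so `ckpt_10` is listed before `ckpt_2`. `tf.train.latest_checkpoint` does not depend on that ordering, because it reads the `checkpoint` state file in the directory. If you want to pick a particular epoch by hand, the sketch below derives checkpoint prefixes from a listing and orders them numerically. It assumes the `ckpt_<epoch>` naming shown above, and `checkpoint_prefixes` is a hypothetical helper name.

```python
import re

# A subset of the file names from the directory listing above.
listing = [
    "checkpoint",
    "ckpt_1.data-00000-of-00001", "ckpt_1.index",
    "ckpt_10.data-00000-of-00001", "ckpt_10.index",
    "ckpt_12.data-00000-of-00001", "ckpt_12.index",
    "ckpt_2.data-00000-of-00001", "ckpt_2.index",
]

# Assumption: each checkpoint has exactly one "ckpt_<epoch>.index" file.
INDEX_PATTERN = re.compile(r"^(ckpt_(\d+))\.index$")


def checkpoint_prefixes(filenames):
    """Return checkpoint prefixes sorted by epoch number, oldest first."""
    found = []
    for name in filenames:
        match = INDEX_PATTERN.match(name)
        if match:
            found.append((int(match.group(2)), match.group(1)))
    return [prefix for _, prefix in sorted(found)]


prefixes = checkpoint_prefixes(listing)
print(prefixes)      # prefixes in numeric epoch order
print(prefixes[-1])  # prefix with the highest epoch number
```

A prefix such as `ckpt_12`, joined with the checkpoint directory, is the form of path that `model.load_weights` accepts for checkpoints in this format.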


To see how well the model performs, load the latest checkpoint and call `evaluate` on the test data.

Call `evaluate` as before, passing the appropriate dataset.

In [17]:
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

  1/157 [..............................] - ETA: 0s - loss: 0.1031 - accuracy: 0.9688

 11/157 [=>............................] - ETA: 0s - loss: 0.0515 - accuracy: 0.9886

 20/157 [==>...........................] - ETA: 0s - loss: 0.0407 - accuracy: 0.9883

 29/157 [====>.........................] - ETA: 0s - loss: 0.0381 - accuracy: 0.9876





























Eval loss: 0.03988004848361015, Eval Accuracy: 0.9861000180244446
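The step counts in the progress bars (938 during training, 157 during evaluation) count batches, not examples. They are consistent with a global batch size of 64 applied to 60,000 training and 10,000 test examples, with the final partial batch kept. The batch size and dataset sizes are assumptions here, since they are set earlier in the notebook rather than in this section; `steps_per_epoch` is an illustrative helper.

```python
import math


def steps_per_epoch(num_examples: int, global_batch_size: int) -> int:
    """Number of batches needed to cover the dataset once.

    Rounds up because the last, smaller batch is still processed.
    """
    return math.ceil(num_examples / global_batch_size)


# Assumed values (defined earlier in the notebook, not in this section).
GLOBAL_BATCH_SIZE = 64
NUM_TRAIN_EXAMPLES = 60_000
NUM_TEST_EXAMPLES = 10_000

train_steps = steps_per_epoch(NUM_TRAIN_EXAMPLES, GLOBAL_BATCH_SIZE)
eval_steps = steps_per_epoch(NUM_TEST_EXAMPLES, GLOBAL_BATCH_SIZE)

print("Training steps per epoch:", train_steps)
print("Evaluation steps:", eval_steps)
```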


To view the output, you can download the TensorBoard logs and inspect them from a terminal.

```
$ tensorboard --logdir=path/to/log-directory
```

In [18]:
!ls -sh ./logs

total 4.0K
4.0K train


## Export to SavedModel

Export the graph and the variables to the platform-agnostic SavedModel format. After the model is saved, you can load it with or without the scope.
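For orientation, a SavedModel is a directory rather than a single file. It typically holds a `saved_model.pb` file with the serialized graph, a `variables/` subdirectory with the weights, and an optional `assets/` subdirectory. The sketch below is a standard-library check for that layout. `looks_like_saved_model` is a hypothetical helper, not a TensorFlow API, and the demo builds an empty stand-in directory instead of saving a real model.

```python
import tempfile
from pathlib import Path


def looks_like_saved_model(path) -> bool:
    """Heuristic check: graph file plus a variables directory are present."""
    root = Path(path)
    return (root / "saved_model.pb").is_file() and (root / "variables").is_dir()


with tempfile.TemporaryDirectory() as tmp:
    # Stand-in directory for illustration only; no real model is written.
    fake_export = Path(tmp) / "saved_model"
    (fake_export / "variables").mkdir(parents=True)
    (fake_export / "assets").mkdir()

    # Without the graph file the directory is incomplete.
    before_graph_file = looks_like_saved_model(fake_export)

    (fake_export / "saved_model.pb").touch()
    after_graph_file = looks_like_saved_model(fake_export)

print("Before adding saved_model.pb:", before_graph_file)
print("After adding saved_model.pb: ", after_graph_file)
```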

In [19]:
path = 'saved_model/'

In [20]:
model.save(path, save_format='tf')

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.



INFO:tensorflow:Assets written to: saved_model/assets



Load the model without `strategy.scope`.

In [21]:
unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

  1/157 [..............................] - ETA: 0s - loss: 0.1031 - accuracy: 0.9688

 14/157 [=>............................] - ETA: 0s - loss: 0.0481 - accuracy: 0.9888

 27/157 [====>.........................] - ETA: 0s - loss: 0.0364 - accuracy: 0.9878



















Eval loss: 0.03988004848361015, Eval Accuracy: 0.9861000180244446


Load the model with `strategy.scope`.

In [22]:
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

  1/157 [..............................] - ETA: 0s - loss: 0.1031 - accuracy: 0.9688

 13/157 [=>............................] - ETA: 0s - loss: 0.0455 - accuracy: 0.9892

 24/157 [===>..........................] - ETA: 0s - loss: 0.0375 - accuracy: 0.9876

 35/157 [=====>........................] - ETA: 0s - loss: 0.0375 - accuracy: 0.9888























Eval loss: 0.03988004848361015, Eval Accuracy: 0.9861000180244446


### Examples and tutorials
Here are some examples of using distribution strategies with Keras `fit`/`compile`:
1. A [Transformer](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer_main.py) example trained with `tf.distribute.MirroredStrategy`.
2. An [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) example trained with `tf.distribute.MirroredStrategy`.

More examples are listed in the [Distribution strategy guide](../../guide/distribute_strategy.ipynb#examples_and_tutorials).

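## Next steps

* Read the [Distribution strategy guide](../../guide/distribute_strategy.ipynb).
* Read the [distributed training with custom training loops](training_loops.ipynb) tutorial.

Note: `tf.distribute.Strategy` is under active development, and more examples and tutorials will be added in the near future. Please give it a try. We welcome your feedback via [issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new).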