##### Copyright 2019 The TensorFlow Authors.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Custom training with tf.distribute.Strategy

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/distribute/custom_training"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/distribute/custom_training.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/custom_training.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/distribute/custom_training.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

This tutorial demonstrates how to use `tf.distribute.Strategy`—a TensorFlow API that provides an abstraction for [distributing your training](../../guide/distributed_training.ipynb) across multiple processing units (GPUs, multiple machines, or TPUs)—with custom training loops. In this example, you will train a simple convolutional neural network on the [Fashion MNIST dataset](https://github.com/zalandoresearch/fashion-mnist) containing 70,000 images of size 28 x 28.

[Custom training loops](../customization/custom_training_walkthrough.ipynb) provide flexibility and greater control over training. They also make it easier to debug the model and the training loop.

In [1]:
# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)

2.9.1


## Download the Fashion MNIST dataset

In [2]:
fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Add a dimension to the array -> new shape == (28, 28, 1)
# This is done because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Scale the images to the [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

## Create a strategy to distribute the variables and the graph

How does `tf.distribute.MirroredStrategy` work?

*   All the variables and the model graph are replicated across the replicas.
*   Input is evenly distributed across the replicas.
*   Each replica calculates the loss and gradients for the input it received.
*   The gradients are synced across all the replicas by summing them.
*   After the sync, the same update is made to the copies of the variables on each replica.
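
The steps above can be illustrated without TensorFlow. The following is a minimal NumPy sketch, not how `tf.distribute.MirroredStrategy` is implemented: the toy linear model, `NUM_REPLICAS`, `LEARNING_RATE`, and the helper `scaled_gradient` are all assumptions made for illustration. It checks that splitting a batch, summing per-replica gradients, and applying the same update everywhere matches a single-device update on the full batch.

```python
import numpy as np

rng = np.random.default_rng(0)

NUM_REPLICAS = 4        # hypothetical replica count
GLOBAL_BATCH_SIZE = 64
LEARNING_RATE = 0.1     # hypothetical learning rate

# Toy linear-regression problem: predictions = x @ w.
x = rng.normal(size=(GLOBAL_BATCH_SIZE, 3))
y = x @ np.array([1.0, -2.0, 0.5])

def scaled_gradient(w, x_shard, y_shard):
  """Gradient of sum(squared error) / GLOBAL_BATCH_SIZE on one shard."""
  error = x_shard @ w - y_shard
  return 2.0 * x_shard.T @ error / GLOBAL_BATCH_SIZE

# 1. Every replica starts with an identical copy of the variables.
replica_weights = [np.zeros(3) for _ in range(NUM_REPLICAS)]

# 2. The global batch is split evenly across the replicas.
x_shards = np.split(x, NUM_REPLICAS)
y_shards = np.split(y, NUM_REPLICAS)

# 3. Each replica computes gradients for the input it received.
per_replica_grads = [
    scaled_gradient(w, xs, ys)
    for w, xs, ys in zip(replica_weights, x_shards, y_shards)
]

# 4. The gradients are synced across replicas by summing them.
synced_grad = np.sum(per_replica_grads, axis=0)

# 5. Every replica applies the same update, so the copies stay identical.
replica_weights = [w - LEARNING_RATE * synced_grad for w in replica_weights]

# Reference: one device processing the whole global batch.
single_device_weights = np.zeros(3) - LEARNING_RATE * scaled_gradient(
    np.zeros(3), x, y)

print(np.allclose(replica_weights[0], single_device_weights))
```

Because each shard's gradient is already divided by the global batch size, the summed gradient equals the full-batch gradient, which is why the replicas end up with the same weights as the single-device reference.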

Note: You can put all the code below inside a single scope. This example divides it into several code cells for illustration purposes.


In [4]:
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.set_visible_devices(physical_devices[0:3],'GPU')

In [5]:
# If the list of devices is not specified in
# `tf.distribute.MirroredStrategy` constructor, they will be auto-detected.
strategy = tf.distribute.MirroredStrategy()



2022-07-25 13:37:06.886146: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-07-25 13:37:08.457682: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9648 MB memory:  -> device: 0, name: GeForce RTX 2080 Ti, pci bus id: 0000:01:00.0, compute capability: 7.5
2022-07-25 13:37:08.458761: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 9648 MB memory:  -> device: 1, name: GeForce RTX 2080 Ti, pci bus id: 0000:02:00.0, compute capability: 7.5
2022-07-25 13:37:08.459793: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2')


In [6]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 3


## Set up the input pipeline

In [7]:
BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

Create the datasets and distribute them:

In [6]:
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

2022-07-25 13:30:36.075295: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"

## Create the model

Create a model using `tf.keras.Sequential`. You can also use the [Model Subclassing API](https://www.tensorflow.org/guide/keras/custom_layers_and_models) or the [functional API](https://www.tensorflow.org/guide/keras/functional) to do this.

In [7]:
def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model

In [8]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

## Define the loss function

Normally, on a single machine with a single GPU/CPU, the loss is divided by the number of examples in the input batch.

*So, how should the loss be calculated when using a `tf.distribute.Strategy`?*

* For example, suppose you have 4 GPUs and a global batch size of 64. Each batch of input is split
across the replicas (4 GPUs), so each replica receives 16 examples.

* The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (`BATCH_SIZE_PER_REPLICA` = 16), the loss should be divided by the `GLOBAL_BATCH_SIZE` (64).

*Why do this?*

* This needs to be done because after the gradients are calculated on each replica, they are synced across the replicas by **summing** them.
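
The arithmetic behind this can be checked with a short NumPy sketch. The loss values here are randomly generated placeholders, not results from the tutorial's model; the point is only to compare the two ways of scaling before the cross-replica sum.

```python
import numpy as np

rng = np.random.default_rng(0)

NUM_REPLICAS = 4
GLOBAL_BATCH_SIZE = 64

# Hypothetical per-example losses for one global batch.
per_example_loss = rng.uniform(0.0, 2.0, size=GLOBAL_BATCH_SIZE)
shards = np.split(per_example_loss, NUM_REPLICAS)  # 16 examples per replica

# Divide by the global batch size on each replica, then sum across replicas:
# this recovers the mean loss over the whole global batch.
scaled_by_global = sum(shard.sum() / GLOBAL_BATCH_SIZE for shard in shards)

# Divide by the per-replica batch size instead, then sum across replicas:
# the result is NUM_REPLICAS times too large.
scaled_by_local = sum(shard.sum() / len(shard) for shard in shards)

global_mean = per_example_loss.mean()
print(np.isclose(scaled_by_global, global_mean))
print(np.isclose(scaled_by_local, NUM_REPLICAS * global_mean))
```

The same reasoning applies to gradients, since the gradient of a sum is the sum of the gradients.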

*How to do this in TensorFlow?*

* If you're writing a custom training loop, as in this tutorial, sum the per-example losses and divide the sum by the `GLOBAL_BATCH_SIZE`:
`scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)`.
Alternatively, use `tf.nn.compute_average_loss`, which takes the per-example loss,
optional sample weights, and `GLOBAL_BATCH_SIZE` as arguments and returns the scaled loss.

* If your model uses regularization losses, divide their value by the number of replicas. You can do this with the `tf.nn.scale_regularization_loss` function.

* Using `tf.reduce_mean` is not recommended. Doing so divides the loss by the actual per-replica batch size, which may vary from step to step.

* This reduction and scaling is done automatically in Keras `Model.compile` and `Model.fit`.

* If using `tf.keras.losses` classes (as in the example below), the loss reduction must be explicitly set to either `NONE` or `SUM`. `AUTO` and `SUM_OVER_BATCH_SIZE` are disallowed when used with `tf.distribute.Strategy`. `AUTO` is disallowed because you should explicitly choose a reduction and confirm that it is correct in the distributed case. `SUM_OVER_BATCH_SIZE` is disallowed because it currently divides only by the per-replica batch size and leaves the division by the number of replicas to you, which is easy to miss. So, instead, you need to do the reduction yourself explicitly.
* If `labels` is multi-dimensional, then average the `per_example_loss` across the number of elements in each sample. For example, if the shape of `predictions` is `(batch_size, H, W, n_classes)` and `labels` is `(batch_size, H, W)`, you will need to update `per_example_loss` like: `per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)`

  Caution: **Verify the shape of your loss**.
  Loss functions in `tf.losses`/`tf.keras.losses` typically
  return the average over the last dimension of the input. The loss
  classes wrap these functions. Passing `reduction=Reduction.NONE` when
  creating an instance of a loss class means "no **additional** reduction".
  For categorical losses with an example input shape of `[batch, W, H, n_classes]`, the `n_classes`
  dimension is reduced. For pointwise losses like
  `losses.mean_squared_error` or `losses.binary_crossentropy`, include a
  dummy axis so that `[batch, W, H, 1]` is reduced to `[batch, W, H]`. Without
  the dummy axis, `[batch, W, H]` will be incorrectly reduced to `[batch, W]`.
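
To make the multi-dimensional case concrete, here is a small NumPy sketch of the shapes involved. The function `sparse_softmax_cross_entropy` is a hand-written stand-in that reduces only the `n_classes` axis; it is an assumption for illustration and not the TensorFlow implementation. The sizes `H`, `W`, and `n_classes` are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)

GLOBAL_BATCH_SIZE = 8
batch_size, H, W, n_classes = 4, 5, 6, 3  # one replica's share of the batch

logits = rng.normal(size=(batch_size, H, W, n_classes))
labels = rng.integers(0, n_classes, size=(batch_size, H, W))

def sparse_softmax_cross_entropy(labels, logits):
  """NumPy stand-in: reduces only the last (n_classes) axis."""
  shifted = logits - logits.max(axis=-1, keepdims=True)
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)
  return -picked[..., 0]

# With no additional reduction there is one loss value per element of `labels`.
per_element_loss = sparse_softmax_cross_entropy(labels, logits)
print(per_element_loss.shape)  # (4, 5, 6)

# Average across the number of elements in each sample (H * W here).
per_example_loss = per_element_loss / np.prod(labels.shape[1:])

# Sum everything and divide by the global batch size.
scaled_loss = per_example_loss.sum() / GLOBAL_BATCH_SIZE
print(scaled_loss)
```

Checking the shape before scaling, as the `print` call does here, is a cheap way to catch a loss that was reduced over the wrong axis.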


In [9]:
with strategy.scope():
  # Set reduction to `NONE` so you can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
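
As a quick sanity check of the scaling described in this section, the sketch below compares the manual formula with `tf.nn.compute_average_loss` on made-up logits and labels. It runs outside any `strategy.scope()`, so it assumes a single replica; under that assumption `tf.nn.scale_regularization_loss` leaves the penalty unchanged. The values of `GLOBAL_BATCH_SIZE` and `regularization_loss` are hypothetical.

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 8  # hypothetical: this replica sees 4 of the 8 examples

labels = tf.constant([0, 2, 1, 2])
logits = tf.constant([[2.0, 0.5, 0.1],
                      [0.3, 0.2, 1.5],
                      [0.1, 1.0, 0.4],
                      [1.2, 0.7, 0.9]])

# One loss value per example, shape (4,).
per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)

# Manual scaling: sum, then divide by the global batch size.
manual = tf.reduce_sum(per_example_loss) * (1. / GLOBAL_BATCH_SIZE)

# The helper performs the same computation.
helper = tf.nn.compute_average_loss(
    per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

# Hypothetical regularization penalty. The helper divides it by the number
# of replicas in sync, which is 1 when no strategy is active.
regularization_loss = tf.constant(0.02)
scaled_regularization = tf.nn.scale_regularization_loss(regularization_loss)

total_loss = helper + scaled_regularization
print(float(manual), float(helper), float(total_loss))
```

In a model with regularizers, the penalty passed to `tf.nn.scale_regularization_loss` would typically come from summing `model.losses`.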

## Define the metrics to track loss and accuracy

These metrics track the test loss and training and test accuracy. You can use `.result()` to get the accumulated statistics at any time.

In [10]:
with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

## Training loop

In [11]:
# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [12]:
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)

In [13]:
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print(template.format(epoch + 1, train_loss,
                         train_accuracy.result() * 100, test_loss.result(),
                         test_accuracy.result() * 100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()

INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1


2022-07-25 13:31:48.334788: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8201
2022-07-25 13:31:49.109578: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8201
2022-07-25 13:31:50.048347: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8201
2022-07-25 13:31:51.427168: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8201


INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1
Epoch 1, Loss: 0.6790518760681152, Accuracy: 75.53666687011719, Test Loss: 0.47022944688796997, Test Accuracy: 83.30000305175781
Epoch 2, Loss: 0.4078426659107208, Accuracy: 85.54166412353516, Test Loss: 0.3945533335208893, Test Accuracy: 86.16999816894531
Epoch 3, Loss: 0.35287538170814514, Accuracy: 87.32833099365234, Test Loss: 0.36147063970565796, Test Accuracy: 87.12999725341797
Epoch 4, Loss: 0.3205961585044861, Accuracy: 88.49666595458984, Test Loss: 0.35138148069381714, Test Accuracy: 87.52999877929688
Epoch 5, Loss: 0.3009766638278961, Accuracy: 89.25, Test Loss: 0.3221170902252197, Test Accuracy: 88.52000427246094
Epoch 6, Loss: 0.2848018705844879, Accuracy: 89.73500061035156, Test Loss: 0.33328384160995483, Test Accuracy: 88.02000427246094
Epoch 7, Loss: 0.27027127146720886, Accuracy: 90.20000457763672, Test Loss: 0.2944231629371643, Test Accuracy: 89.52999877929688
Epoch 8, Loss: 0.25513446

Things to note in the example above:

* Iterate over `train_dist_dataset` and `test_dist_dataset` using a `for x in ...` construct.
* The scaled loss is the return value of `distributed_train_step`. This value is aggregated across replicas using the `tf.distribute.Strategy.reduce` call, and then across batches by summing the return values of the `tf.distribute.Strategy.reduce` calls.
* `tf.keras.metrics` objects should be updated inside `train_step` and `test_step`, which are executed by `tf.distribute.Strategy.run`.
* `tf.distribute.Strategy.run` returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can call `tf.distribute.Strategy.reduce` to get an aggregated value, or `tf.distribute.Strategy.experimental_local_results` to get the list of values contained in the result, one per local replica.
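
The two ways of consuming the result of `tf.distribute.Strategy.run` can be tried in isolation. This is a minimal sketch that assumes a single-replica `MirroredStrategy` pinned to the CPU so that it runs without GPUs; the names `cpu_strategy`, `step`, and `distributed_step` are hypothetical and separate from the tutorial's own training code.

```python
import tensorflow as tf

# Assumption: one CPU replica, so the example runs on any machine.
cpu_strategy = tf.distribute.MirroredStrategy(devices=["/cpu:0"])

def step(x):
  # Stand-in for a per-replica computation such as `train_step`.
  return tf.reduce_sum(x)

@tf.function
def distributed_step(x):
  return cpu_strategy.run(step, args=(x,))

per_replica_result = distributed_step(tf.constant([1.0, 2.0, 3.0]))

# Option 1: aggregate the per-replica values into a single tensor.
total = cpu_strategy.reduce(
    tf.distribute.ReduceOp.SUM, per_replica_result, axis=None)

# Option 2: unpack the values, one per local replica.
local_values = cpu_strategy.experimental_local_results(per_replica_result)

print(float(total), len(local_values))
```

With more than one replica, `local_values` would hold one entry per local device and `total` would be their sum.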


## Restore the latest checkpoint and test

A model checkpointed with a `tf.distribute.Strategy` can be restored with or without a strategy.

In [14]:
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

In [15]:
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)

In [16]:
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result() * 100))

Accuracy after restoring the saved model without strategy: 89.96000671386719


## Alternate ways of iterating over a dataset

### Using iterators

If you want to iterate over a given number of steps and not through the entire dataset, you can create an iterator using the `iter` call and explicitly call `next` on the iterator. You can choose to iterate over the dataset both inside and outside the `tf.function`. Here is a small snippet demonstrating iteration of the dataset outside the `tf.function` using an iterator.


In [17]:
for epoch in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  # Run a fixed number of steps (10) per epoch instead of the full dataset.
  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))
  train_accuracy.reset_states()

Epoch 1, Loss: 0.23941445350646973, Accuracy: 90.46875
Epoch 2, Loss: 0.260286420583725, Accuracy: 90.1171875
Epoch 3, Loss: 0.225071981549263, Accuracy: 91.8359375
Epoch 4, Loss: 0.21715669333934784, Accuracy: 91.9140625
Epoch 5, Loss: 0.23235765099525452, Accuracy: 92.109375
Epoch 6, Loss: 0.2389928102493286, Accuracy: 92.03125
Epoch 7, Loss: 0.237949401140213, Accuracy: 92.109375
Epoch 8, Loss: 0.21342167258262634, Accuracy: 92.7734375
Epoch 9, Loss: 0.20895743370056152, Accuracy: 92.34375
Epoch 10, Loss: 0.20401696860790253, Accuracy: 92.3046875


### Iterating inside a `tf.function`

You can also iterate over the entire input `train_dist_dataset` inside a `tf.function` using the `for x in ...` construct or by creating iterators like you did above. The example below demonstrates wrapping one epoch of training with a `@tf.function` decorator and iterating over `train_dist_dataset` inside the function.

In [18]:
@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))

  train_accuracy.reset_states()



INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1
Epoch 1, Loss: 0.21808364987373352, Accuracy: 92.03666687011719
Epoch 2, Loss: 0.21285371482372284, Accuracy: 92.2550048828125
Epoch 3, Loss: 0.1999495029449463, Accuracy: 92.65833282470703
Epoch 4, Loss: 0.1943306177854538, Accuracy: 92.76666259765625
Epoch 5, Loss: 0.18401119112968445, Accuracy: 93.29499816894531
Epoch 6, Loss: 0.1812204271554947, Accuracy: 93.31500244140625
Epoch 7, Loss: 0.17226870357990265, Accuracy: 93.68666076660156
Epoch 8, Loss: 0.16289854049682617, Accuracy: 94.09333801269531
Epoch 9, Loss: 0.15651777386665344, Accuracy: 94.2266616821289
Epoch 10, Loss: 0.1523996889591217, Accuracy: 94.36499786376953


### Tracking training loss across replicas

Note: As a general rule, you should use `tf.keras.metrics` to track per-sample values and avoid values that have been aggregated within a replica.

Because of the loss scaling computation that is carried out, it's not recommended to use `tf.keras.metrics.Mean` to track the training loss across different replicas.

For example, if you run a training job with the following characteristics:

* Two replicas
* Two samples are processed on each replica
* Resulting loss values: [2,  3] and [4,  5] on each replica
* Global batch size = 4

With loss scaling, you calculate the scaled loss on each replica by adding the per-sample loss values and then dividing by the global batch size. In this case: `(2 + 3) / 4 = 1.25` and `(4 + 5) / 4 = 2.25`. Summing these across replicas gives 3.5, which is the mean loss over all four samples.

If you use `tf.keras.metrics.Mean` to track loss across the two replicas, the result is different. In this example, you end up with a `total` of 3.50 and `count` of 2, which results in `total`/`count` = 1.75 when `result()` is called on the metric. The loss reported by the metric is therefore scaled down by an additional factor equal to the number of replicas in sync.
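
The numbers in this example can be reproduced in plain Python. The running `total` and `count` below only mimic the bookkeeping of a mean-style metric; they are an illustrative stand-in, not the `tf.keras.metrics.Mean` implementation.

```python
NUM_REPLICAS = 2
GLOBAL_BATCH_SIZE = 4
replica_losses = [[2.0, 3.0], [4.0, 5.0]]

# Loss scaling: each replica sums its losses and divides by the global batch size.
scaled = [sum(losses) / GLOBAL_BATCH_SIZE for losses in replica_losses]
print(scaled)  # [1.25, 2.25]

# Summing the scaled values across replicas gives the mean over all samples.
summed = sum(scaled)
print(summed)  # 3.5

# A mean-style metric fed the already-scaled values divides a second time.
total, count = 0.0, 0
for value in scaled:
  total += value
  count += 1
metric_result = total / count
print(metric_result)  # 1.75

# The metric is off by a factor equal to the number of replicas.
print(metric_result * NUM_REPLICAS == summed)  # True
```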

### Guide and examples

Here are some examples of using distribution strategies with custom training loops:

1. [Distributed training guide](../../guide/distributed_training)
2. [DenseNet](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/densenet/distributed_train.py) example using `MirroredStrategy`.
3. [BERT](https://github.com/tensorflow/models/blob/master/official/legacy/bert/run_classifier.py) example trained using `MirroredStrategy` and `TPUStrategy`.
   This example is particularly helpful for understanding how to load from a checkpoint and generate periodic checkpoints during distributed training.
4. [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) example trained using `MirroredStrategy` that can be enabled using the `keras_use_ctl` flag.
5. [NMT](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/nmt_with_attention/distributed_train.py) example trained using `MirroredStrategy`.

You can find more examples listed under _Examples and tutorials_ in the [Distribution strategy guide](../../guide/distributed_training.ipynb).

## Next steps

*   Try out the new `tf.distribute.Strategy` API on your models.
*   Visit the [Better performance with `tf.function`](../../guide/function.ipynb) and [TensorFlow Profiler](../../guide/profiler.md) guides to learn more about tools to optimize the performance of your TensorFlow models.
*   Check out the [Distributed training in TensorFlow](../../guide/distributed_training.ipynb) guide, which provides an overview of the available distribution strategies.