In [None]:
import tensorflow_datasets as tfds

import tensorflow as tf

In [None]:
# Module-level distribution strategy (default arguments). `get_data` reads its
# replica count to size the global batch, and `get_model` builds the model
# inside its scope.
mirrored_strategy = tf.distribute.MirroredStrategy()

In [None]:
def get_data():
  """Build the batched MNIST training and evaluation input pipelines.

  MNIST is loaded through `tfds.load` in supervised `(image, label)` form.
  Images are cast to float32 and divided by 255. Both splits are batched with
  a global batch size of 64 examples per replica of the module-level
  `mirrored_strategy`.

  Returns:
    A `(train_dataset, eval_dataset)` tuple. The training pipeline is scaled,
    cached, shuffled with a 10000-element buffer and batched; the evaluation
    pipeline is scaled and batched only.
  """
  splits = tfds.load(name='mnist', as_supervised=True)

  shuffle_buffer = 10000
  per_replica_batch = 64
  # The global batch grows with the number of replicas kept in sync.
  global_batch = per_replica_batch * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    # Cast before dividing so the division is carried out in float32.
    return tf.cast(image, tf.float32) / 255, label

  train_dataset = splits['train'].map(scale)
  train_dataset = train_dataset.cache()
  train_dataset = train_dataset.shuffle(shuffle_buffer)
  train_dataset = train_dataset.batch(global_batch)

  eval_dataset = splits['test'].map(scale)
  eval_dataset = eval_dataset.batch(global_batch)

  return train_dataset, eval_dataset

In [None]:

def get_model():
  """Build and compile a small convolutional classifier under `mirrored_strategy`.

  The network is Conv2D(32, 3) -> MaxPooling2D -> Flatten -> Dense(64) ->
  Dense(10) for inputs of shape (28, 28, 1). The last layer has no activation,
  and the loss is constructed with `from_logits=True` accordingly.

  Returns:
    The compiled `tf.keras.Sequential` model (sparse categorical cross-entropy
    loss, Adam optimizer, sparse categorical accuracy metric).
  """
  # Model, loss, optimizer and metric are all created inside the strategy scope.
  with mirrored_strategy.scope():
    layer_stack = [
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10),
    ]
    model = tf.keras.Sequential(layer_stack)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam()
    accuracy = tf.metrics.SparseCategoricalAccuracy()
    model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy])

  return model

In [None]:
# Build the model and the input pipelines here so the notebook runs top to
# bottom on a fresh kernel: this cell saves `model`, and later cells use
# `train_dataset` and `eval_dataset`.
model = get_model()
train_dataset, eval_dataset = get_data()

# Train briefly so the saved file holds trained weights.
model.fit(train_dataset, epochs=2)

# Save the whole model to a single `.keras` file.
keras_model_path = '/tmp/keras_save.keras'
model.save(keras_model_path)

In [None]:
# Reload the model from `keras_model_path` and train it for two more epochs.
# No distribution strategy scope is entered in this cell.
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)

In [None]:
# Load the same saved file under a different strategy: a single-device strategy
# pinned to '/cpu:0'. Both the load and the training run inside its scope.
another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)

In [None]:
# Export with the lower-level `tf.saved_model` API instead of `model.save`.
model = get_model()  # get a fresh (untrained) model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)

In [None]:
# Key of the signature to look up on the loaded object.
DEFAULT_FUNCTION_KEY = 'serving_default'
# Load what was written to `saved_model_path` with `tf.saved_model.load` and
# take the function registered under that key in its `signatures` mapping.
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

In [None]:
# Drop the labels so only images are passed to the inference function, then
# run it on the first batch and print the result.
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))

Distributed Input

In [1]:
import tensorflow as tf

# Helper libraries
import numpy as np
import os

# Show the version of the imported TensorFlow package.
print(tf.__version__)

2.18.0


In [2]:
# Simulate multiple CPUs: configure the first physical CPU with
# N_VIRTUAL_DEVICES logical devices.
N_VIRTUAL_DEVICES = 2
physical_devices = tf.config.list_physical_devices(device_type="CPU")
tf.config.set_logical_device_configuration(
    device=physical_devices[0],
    logical_devices=[
        tf.config.LogicalDeviceConfiguration()
        for _ in range(N_VIRTUAL_DEVICES)
    ],
)

In [3]:
# List every logical device TensorFlow reports, one numbered line per device.
print("Available devices:")
for i, device in enumerate(tf.config.list_logical_devices()):
  print("%d) %s" % (i, device))

Available devices:
0) LogicalDevice(name='/device:CPU:0', device_type='CPU')
1) LogicalDevice(name='/device:CPU:1', device_type='CPU')


In [None]:
# Number of elements grouped into each batch of the dataset below.
global_batch_size = 16
# Create a tf.data.Dataset object: one ([1.], [1.]) element repeated 100 times,
# then grouped into batches of `global_batch_size`.
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)

@tf.function
def train_step(inputs):
  """Compute `labels - 0.3 * features` for one batch.

  Args:
    inputs: A two-element `(features, labels)` structure.

  Returns:
    The value of `labels - 0.3 * features`.
  """
  batch_features, batch_labels = inputs
  scaled_features = 0.3 * batch_features
  return batch_labels - scaled_features

# Iterate over the dataset using the for..in construct, printing the result of
# `train_step` for every batch.
for inputs in dataset:
  print(train_step(inputs))

In [None]:
# Batch size passed to `.batch()` on the dataset, before it is distributed.
global_batch_size = 16
# Strategy created with default arguments; used below to distribute the dataset.
mirrored_strategy = tf.distribute.MirroredStrategy()

# One ([1.], [1.]) element repeated 100 times and batched by `global_batch_size`.
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)
# Distribute input using the `experimental_distribute_dataset`.
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
# 1 global batch of data fed to the model in 1 step.
# Only the first element of the distributed dataset is fetched and printed.
print(next(iter(dist_dataset)))

In [None]:
# tf.data.Dataset.range(6).batch(4, drop_remainder=False)

# Without distribution:
# Batch 1: [0, 1, 2, 3]
# Batch 2: [4, 5]
# With distribution over 2 replicas (the last batch, [4, 5], is split between the 2 replicas):

# Batch 1:

# Replica 1:[0, 1]
# Replica 2:[2, 3]
# Batch 2:

# Replica 1: [4]
# Replica 2: [5]
#
# Next example:
# tf.data.Dataset.range(4).batch(4)

# Without distribution:
# Batch 1: [0, 1, 2, 3]
# With distribution over 5 replicas:
# Batch 1:
# Replica 1: [0]
# Replica 2: [1]
# Replica 3: [2]
# Replica 4: [3]
# Replica 5: []
#
# Next example:
# tf.data.Dataset.range(8).batch(4)

# Without distribution:
# Batch 1: [0, 1, 2, 3]
# Batch 2: [4, 5, 6, 7]
# With distribution over 3 replicas:
# Batch 1:
# Replica 1: [0, 1]
# Replica 2: [2, 3]
# Replica 3: []
# Batch 2:
# Replica 1: [4, 5]
# Replica 2: [6, 7]
# Replica 3: []


In [None]:
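# FILE: presumably the file-based autoshard policy, where each worker only processes the files assigned to it. In the example below worker 0 yields elements 0-5 and worker 1 yields elements 6-11 (assumed: 2 files over 2 workers, with 2 replicas in sync in total; confirm against the AutoShardPolicy docs).

# Worker 0: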
# Batch 1 = Replica 1: [0, 1]
# Batch 2 = Replica 1: [2, 3]
# Batch 3 = Replica 1: [4]
# Batch 4 = Replica 1: [5]
# Worker 1:
# Batch 1 = Replica 2: [6, 7]
# Batch 2 = Replica 2: [8, 9]
# Batch 3 = Replica 2: [10]
# Batch 4 = Replica 2: [11]
# DATA: This will autoshard the elements across all the workers. Each of the workers will read the entire dataset and only process the shard assigned to it. All other shards will be discarded. This is generally used if the number of input files is less than the number of workers and you want better sharding of data across all workers. The downside is that the entire dataset will be read on each worker. For example, let us distribute 1 file over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2.

# Worker 0:
# Batch 1 = Replica 1: [0, 1]
# Batch 2 = Replica 1: [4, 5]
# Batch 3 = Replica 1: [8, 9]
# Worker 1:
# Batch 1 = Replica 2: [2, 3]
# Batch 2 = Replica 2: [6, 7]
# Batch 3 = Replica 2: [10, 11]
# OFF: If you turn off autosharding, each worker will process all the data. For example, let us distribute 1 file over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2. Then each worker will see the following distribution:

# Worker 0:
# Batch 1 = Replica 1: [0, 1]
# Batch 2 = Replica 1: [2, 3]
# Batch 3 = Replica 1: [4, 5]
# Batch 4 = Replica 1: [6, 7]
# Batch 5 = Replica 1: [8, 9]
# Batch 6 = Replica 1: [10, 11]

# Worker 1:

# Batch 1 = Replica 2: [0, 1]

# Batch 2 = Replica 2: [2, 3]

# Batch 3 = Replica 2: [4, 5]

# Batch 4 = Replica 2: [6, 7]

# Batch 5 = Replica 2: [8, 9]

# Batch 6 = Replica 2: [10, 11]