In [None]:
!pip install tensorflow

In [8]:
import tensorflow as tf

In [None]:
# Display the TensorFlow build in use so the run can be reproduced later.
tf.version.VERSION

In [None]:
!pip install spark_tensorflow_distributor

In [10]:
# Number of Spark worker nodes that will host training tasks.
NUM_WORKERS = 2

# GPUs are only counted on the driver. This assumes every worker node uses the
# same instance type as the driver, so the cluster-wide total is a product.
TOTAL_NUM_GPUS = NUM_WORKERS * len(tf.config.list_logical_devices('GPU'))

# Train on GPUs only when at least one was detected; otherwise fall back to CPU.
USE_GPU = bool(TOTAL_NUM_GPUS)

In [11]:
from spark_tensorflow_distributor import MirroredStrategyRunner

# Adapted from https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
def train(batch_size=64, buffer_size=10000, epochs=3, steps_per_epoch=5):
    """Train a small CNN on MNIST inside one distributed training task.

    With ``local_mode=False`` the runner executes this function in Spark tasks
    rather than in the notebook process. That is why every dependency is
    imported inside the body instead of relying on notebook-level imports.

    The hyperparameters are keyword arguments whose defaults reproduce the
    original hard-coded values. NOTE(review): per the spark-tensorflow-distributor
    docs, ``runner.run(train, **kwargs)`` forwards keyword arguments here; confirm
    for the installed version before relying on it.

    Args:
        batch_size: Batch size passed to ``Dataset.batch``.
        buffer_size: Buffer size passed to ``Dataset.shuffle``.
        epochs: Number of epochs passed to ``Model.fit``.
        steps_per_epoch: Batches per epoch. This is required because the
            dataset repeats forever.

    Returns:
        None. Training progress is only visible in the executor logs.
    """
    import uuid

    import tensorflow as tf

    def make_datasets():
        """Load MNIST and return an endless, shuffled, batched tf.data pipeline."""
        # A fresh file name per call. Presumably this stops concurrent tasks on
        # the same node from racing on one shared download path.
        # NOTE(review): this appears to force a new download on every call, and
        # the file is never deleted afterwards.
        (mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data(
            path=str(uuid.uuid4()) + 'mnist.npz'
        )

        # Cast to float32 *before* scaling. Dividing the integer NumPy array by
        # a Python float would first build a float64 copy (twice the memory)
        # only to cast it down afterwards.
        images = tf.cast(mnist_images[..., tf.newaxis], tf.float32) / 255.0
        labels = tf.cast(mnist_labels, tf.int64)

        dataset = tf.data.Dataset.from_tensor_slices((images, labels))
        return dataset.repeat().shuffle(buffer_size).batch(batch_size)

    def build_and_compile_cnn_model():
        """Build and compile a one-conv-layer classifier for 28x28x1 inputs, 10 classes."""
        model = tf.keras.Sequential([
            # An explicit Input layer replaces the per-layer `input_shape=`
            # argument, which Keras 3 deprecates with a warning.
            tf.keras.Input(shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(32, 3, activation='relu'),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])
        model.compile(
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            metrics=['accuracy'],
        )
        return model

    train_dataset = make_datasets()

    # The pipeline is built from in-memory tensors (no input files), so shard
    # it across workers by element (DATA policy) rather than by file.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA
    )
    train_dataset = train_dataset.with_options(options)

    multi_worker_model = build_and_compile_cnn_model()
    multi_worker_model.fit(
        x=train_dataset,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
    )

In [14]:
# One slot per GPU when GPUs exist. For CPU-only clusters, 4 tasks is a
# reasonable default; tune it to the available executor cores.
if USE_GPU:
    NUM_SLOTS = TOTAL_NUM_GPUS
else:
    NUM_SLOTS = 4

# local_mode=False runs the training tasks on the Spark cluster rather than
# in the notebook's own process.
runner = MirroredStrategyRunner(
    num_slots=NUM_SLOTS,
    local_mode=False,
    use_gpu=USE_GPU,
)

runner.run(train)

Doing CPU training...
Will run with 4 Spark tasks.
Distributed training in progress...
View Spark executor stderr logs to inspect training...
Training with 4 slots is complete!
