In [None]:
import tensorflow as tf
import horovod.tensorflow.keras as hvd
import horovod.spark
import pyspark
from pyspark.sql import SparkSession

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def train_function(model, verbose=None):
    """Train ``model`` on MNIST with Horovod and return it.

    Intended to run once per Horovod worker (the notebook launches it through
    ``horovod.spark.run``). Each worker initialises Horovod, loads MNIST,
    compiles the model with a Horovod-wrapped Adam optimizer and calls ``fit``.

    Args:
        model: Keras model to train. It is (re)compiled inside this function
            with sparse categorical cross-entropy and an accuracy metric so
            that the learning-rate-scaled distributed optimizer created here
            is the one actually used; compile settings applied by the caller
            are replaced.
        verbose: Verbosity passed to ``model.fit``. When ``None`` (default),
            rank 0 uses ``1`` and every other rank uses ``0``.

    Returns:
        The same ``model`` object after ``fit`` has finished.
    """
    import tensorflow as tf
    import horovod.tensorflow.keras as hvd
    import horovod.spark

    hvd.init()

    # The file name is rank-specific, presumably so that workers sharing a host
    # do not write the downloaded archive to the same path.
    (mnist_images, mnist_labels), _ = \
        tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank())

    # Add a trailing channel axis and divide pixel values by 255.
    # NOTE(review): samples get an extra trailing axis here; confirm the model
    # passed in accepts that input shape.
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
         tf.cast(mnist_labels, tf.int64))
    )
    dataset = dataset.repeat().shuffle(10000).batch(128)

    # Horovod: scale the learning rate by the number of workers and wrap the
    # optimizer so gradients are combined across workers.
    opt = tf.optimizers.Adam(0.001 * hvd.size())
    opt = hvd.DistributedOptimizer(opt)

    # Previously `opt` was built but never attached to the model, so the scaled
    # learning rate had no effect. Compile here, after hvd.init(), so the
    # worker-side optimizer is the one used by fit().
    model.compile(optimizer=opt,
                  loss=tf.losses.SparseCategoricalCrossentropy(),
                  experimental_run_tf_function=False,
                  metrics=['accuracy'])

    callbacks = [
        # Start every worker from rank 0's variable values.
        hvd.callbacks.BroadcastGlobalVariablesCallback(0),
        # Report metrics averaged over workers.
        hvd.callbacks.MetricAverageCallback(),
        # Ramp the learning rate up over the first 3 epochs.
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=3, verbose=1),
    ]

    # Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
    if hvd.rank() == 0:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint('./chckpnt-{epoch}.h5'))

    # `verbose` used to be an undefined name here (NameError at fit time).
    if verbose is None:
        verbose = 1 if hvd.rank() == 0 else 0

    # Split the 500 steps across workers, but never drop to zero steps.
    steps_per_epoch = max(1, 500 // hvd.size())

    model.fit(dataset, steps_per_epoch=steps_per_epoch, callbacks=callbacks, epochs=24, verbose=verbose)
    return model

In [None]:
# Small fully-connected classifier: flatten -> dense(relu) -> dropout -> softmax.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# Horovod: wrap Adam in Horovod's DistributedOptimizer. This cell runs in the
# notebook process, i.e. in non-distributed mode.
opt = hvd.DistributedOptimizer(tf.optimizers.Adam())

model.compile(
    loss=tf.losses.SparseCategoricalCrossentropy(),
    optimizer=opt,
    metrics=['accuracy'],
    experimental_run_tf_function=False,
)

In [None]:
import socket

# Resolve this machine's host name to an IP address; it is used below as
# spark.driver.host.
localIpAddress = socket.gethostbyname(socket.gethostname())

# Spark-on-Kubernetes settings: API server URL, driver address, namespace and
# the container image to run.
conf = (
    pyspark.SparkConf()
    .set("spark.master", "k8s://https://kubernetes.default.svc.cluster.local:443")
    .set("spark.driver.host", localIpAddress)
    .set("spark.kubernetes.namespace", "spark")
    .set("spark.kubernetes.container.image", "akirillov/spark:spark-2.4.3-hadoop-2.9-k8s-horovod")
)

spark = (
    SparkSession.builder
    .appName("HorovodOnSpark")
    .config(conf=conf)
    .getOrCreate()
)

# Horovod: run train_function with 2 processes on the Spark cluster.
# NOTE(review): Horovod documents run() as returning one result per rank, so
# `model` is presumably rebound to a list here, and re-running this cell would
# then pass that list back in as `model` - confirm before reusing it.
model = horovod.spark.run(
    train_function,
    args=(model,),
    num_proc=2,
    verbose=2,
)

# The Spark session is left running here; it is stopped in the next cell.

In [None]:
# Shut down the Spark session created earlier in the notebook.
spark.stop()