In [None]:
import tensorflow as tf
import horovod.tensorflow.keras as hvd

import matplotlib.pyplot as plt
%matplotlib inline

# Horovod: initialize the runtime so hvd.rank()/hvd.size()/hvd.local_rank() are
# available. When this notebook runs in a single kernel (not launched via
# horovodrun/mpirun), hvd.size() is 1.
hvd.init()

# Horovod: pin this process to one GPU (selected by its local rank) and let
# TensorFlow grow GPU memory on demand rather than reserving all of it up front.
# This must run before any GPU is initialized. On a CPU-only machine `gpus` is
# empty and both steps are skipped.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

In [None]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale grayscale pixel intensities from [0, 255] to [0.0, 1.0].
x_train, x_test = x_train / 255.0, x_test / 255.0

# Horovod: give each worker a disjoint 1/hvd.size() slice of the training set, so
# N workers together cover one epoch instead of each repeating the full dataset.
# With a single process (hvd.size() == 1) this keeps every training example.
# The test set is left whole so every worker evaluates on the same data.
x_train = x_train[hvd.rank()::hvd.size()]
y_train = y_train[hvd.rank()::hvd.size()]

In [None]:
# Simple MLP classifier: 28x28 image -> 128 hidden units -> 10 class probabilities.
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

# Horovod: scale the base Adam learning rate (0.001, Adam's default) by the number
# of workers, because the effective global batch grows with hvd.size().
# With a single process hvd.size() == 1, so this is plain Adam.
scaled_lr = 0.001 * hvd.size()

# Horovod: wrap the optimizer so gradients are averaged across all workers before
# being applied. With hvd.size() == 1 this behaves like the wrapped optimizer.
opt = hvd.DistributedOptimizer(tf.optimizers.Adam(scaled_lr))

model.compile(optimizer=opt,
              loss=tf.losses.SparseCategoricalCrossentropy(),
              # Horovod: makes TensorFlow use hvd.DistributedOptimizer to compute
              # gradients (needed on TF 2.0/2.1; NOTE(review): confirm whether your
              # TF version still needs or accepts this flag).
              experimental_run_tf_function=False,
              metrics=['accuracy'])

In [None]:
callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),

    # Horovod: average metrics among workers at the end of every epoch.
    #
    # Note: This callback must be in the list before the ReduceLROnPlateau,
    # TensorBoard or other metrics-based callbacks.
    hvd.callbacks.MetricAverageCallback(),

    # Horovod: starting directly at a learning rate scaled by hvd.size() leads to
    # worse final accuracy. Instead, ramp `lr = initial_lr / hvd.size()` --->
    # `lr = initial_lr` over the first three epochs.
    # See https://arxiv.org/abs/1706.02677 for details.
    # `initial_lr` is taken from the compiled optimizer so the warmup target always
    # matches the rate chosen at compile time. Newer Horovod releases require it
    # explicitly; the `initial_lr` keyword needs Horovod >= 0.20.
    hvd.callbacks.LearningRateWarmupCallback(
        initial_lr=float(tf.keras.backend.get_value(model.optimizer.learning_rate)),
        warmup_epochs=3,
        verbose=1),
]

# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
if hvd.rank() == 0:
    callbacks.append(tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))

In [None]:
# Horovod: only rank 0 reports progress so every worker does not print the same logs.
verbose = 2 if hvd.rank() == 0 else 0

history = model.fit(x_train, y_train, callbacks=callbacks, epochs=5, verbose=verbose)
model.evaluate(x_test,  y_test, verbose=verbose)

In [None]:
# Visual spot-check of the trained model on one arbitrarily chosen test image.
image_index = 5557

# Slice (rather than index) to keep the batch dimension that model.predict expects.
pred = model.predict(x_test[image_index:image_index+1])
predicted_digit = pred.argmax()
actual_digit = y_test[image_index]

fig, ax = plt.subplots()
ax.imshow(x_test[image_index].reshape(28, 28), cmap='binary')
ax.set_title('predicted: {}, actual: {}'.format(predicted_digit, actual_digit))
ax.axis('off')
plt.show()

print(predicted_digit)