In [None]:
%%writefile train.py
import os
import io
import h5py
import tempfile
import tensorflow as tf
import horovod.tensorflow.keras as hvd

batch_size = 128
epochs = 10
num_classes = 10
learning_rate = 0.01

def get_model(num_classes=10):
    """Build the (uncompiled) dense classifier for 28x28 inputs.

    Args:
        num_classes: Width of the final softmax layer, i.e. the number of
            classes the model can predict. Defaults to 10.

    Returns:
        A ``tf.keras.models.Sequential`` made of
        Flatten -> Dense(128, relu) -> Dropout(0.2) -> Dense(num_classes, softmax).
    """
    # Function-local import, kept from the original.
    import tensorflow as tf
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        # The output width follows the argument instead of a hard-coded 10,
        # so callers passing a different num_classes actually get it.
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])

    return model

def get_dataset(rank=0, size=1):
    """Load MNIST and return one strided shard of it, scaled by 1/255.

    Args:
        rank: Index of the first sample to keep; also embedded in the cache
            file name handed to ``load_data``.
        size: Stride between kept samples.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` where every array has been
        sliced with ``[rank::size]`` and both image arrays divided by 255.0.
    """
    import tensorflow as tf

    cache_name = 'MNIST-data-%d' % rank
    train_split, test_split = tf.keras.datasets.mnist.load_data(cache_name)

    # Same strided selection for all four arrays: rank, rank + size, ...
    shard = slice(rank, None, size)
    x_train, y_train = (part[shard] for part in train_split)
    x_test, y_test = (part[shard] for part in test_split)

    # Normalize the pixel values by dividing by the maximum value, 255.
    return (x_train / 255.0, y_train), (x_test / 255.0, y_test)

# Horovod: initialize Horovod.
hvd.init()

# Each worker loads the shard selected by its rank and the total worker count.
(x_train, y_train), (x_test, y_test) = get_dataset(hvd.rank(), hvd.size())
model = get_model(num_classes)

# Horovod: add Horovod DistributedOptimizer. The base learning rate is scaled
# by the number of workers.
# `learning_rate` is the documented Adam argument; `lr` is only a deprecated
# alias that newer Keras releases no longer accept.
optimizer = hvd.DistributedOptimizer(
    tf.optimizers.Adam(learning_rate=learning_rate * hvd.size()))

# NOTE(review): `experimental_run_tf_function=False` follows Horovod's Keras
# examples for TF 2.0-2.3 -- confirm it is still wanted for the installed TF.
model.compile(optimizer=optimizer,
            loss='sparse_categorical_crossentropy',
            experimental_run_tf_function=False,
            metrics=['accuracy'])

callbacks = [
  # Horovod: broadcast initial variable states from rank 0 to all other processes.
  # This is necessary to ensure consistent initialization of all workers when
  # training is started with random weights or restored from a checkpoint.
  hvd.callbacks.BroadcastGlobalVariablesCallback(0),
  # Horovod: average metrics among workers at the end of every epoch. Each
  # worker trains and validates on its own shard, so without this the logged
  # (and checkpoint-monitored) metrics describe a single shard only.
  # This callback must come before any metrics-based callback in the list.
  hvd.callbacks.MetricAverageCallback(),
  # Horovod: using the worker-scaled learning rate from the very beginning leads
  # to worse final accuracy. Ramp the learning rate from `learning_rate` up to
  # `learning_rate * hvd.size()` during the first three epochs.
  # See https://arxiv.org/abs/1706.02677 for details.
  hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=3, verbose=1),
]

# Horovod: the checkpoint callback is registered on worker 0 only, so the other
# workers cannot corrupt the checkpoint file.
# The checkpoint is written inside a freshly created temporary directory.
ckpt_dir = tempfile.mkdtemp()
ckpt_file = os.path.join(ckpt_dir, 'checkpoint.h5')

if hvd.rank() == 0:
    # Keep only the checkpoint with the highest 'accuracy' seen so far.
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(
            filepath=ckpt_file,
            monitor='accuracy',
            mode='max',
            save_best_only=True,
        )
    )

# Train on this worker's shard; its test shard doubles as validation data.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=2,
    callbacks=callbacks,
    validation_data=(x_test, y_test),
)

# Final score on the same test shard.
model.evaluate(x=x_test, y=y_test, verbose=2)

# Persisting saved model
def save_model(model, path):
    """Serialize ``model`` to HDF5 in memory, then write the bytes to ``path``.

    Args:
        model: Object whose ``save`` method accepts an open ``h5py.File``.
        path: Destination file path; opened in binary write mode.
    """
    buffer = io.BytesIO()
    # The HDF5 file is backed by the in-memory buffer; leaving the `with`
    # block closes it before the bytes are read back.
    with h5py.File(buffer, "w") as h5_file:
        model.save(h5_file)

    payload = buffer.getvalue()
    with open(path, 'wb') as out:
        out.write(payload)

# Horovod: write the final model from worker 0 only. Every process launched by
# horovodrun executes this script, and they would otherwise all write the same
# "./trained_model" path at the same time.
if hvd.rank() == 0:
    save_model(model, "./trained_model")

In [None]:
# Due to https://github.com/horovod/horovod/issues/1176 we can't call horovod.run interactively and have to run
# horovod CLI from inside a container via 'k exec'. 
# Command: horovodrun --verbose --num-proc=4 python3 /home/spark/train.py

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

import horovod.tensorflow.keras as hvd

# Restoring model from checkpoint and running some predictions
def restore_model(path):
    """Load a saved Keras model from the HDF5 file at ``path``.

    Args:
        path: Location of the HDF5 file; opened read-only.

    Returns:
        Whatever ``hvd.load_model`` returns for the opened file.
    """
    # Local import: this cell's import lines do not bring in h5py.
    import h5py

    # Open the file the caller asked for rather than a hard-coded location.
    with h5py.File(path, 'r') as f:
        return hvd.load_model(f)

import tensorflow as tf

model = restore_model("./trained_model")

# x_test is not defined by any cell that runs in this kernel (the training cell
# is only written to train.py), so load the MNIST test images here and scale
# them by 1/255 exactly as the training script does.
x_test = tf.keras.datasets.mnist.load_data()[1][0] / 255.0

image_index = 4443
plt.imshow(x_test[image_index].reshape(28, 28), cmap='binary')

# Predict on a one-element slice so the batch dimension is preserved.
pred = model.predict(x_test[image_index:image_index+1])
print(pred.argmax())