In [None]:
# %run executes commons.ipynb and loads its top-level names into this notebook's
# namespace; -i runs it directly in the interactive namespace.
# get_dataset, get_model, deserialize, display_image and predict_number are used
# below but never defined in this notebook — presumably they come from commons.ipynb.
%run -i 'commons.ipynb'

In [None]:
def train_hvd(num_classes=10, learning_rate=0.001, batch_size=128, epochs=10):
    """Train one Horovod worker's model; launched once per process by ``horovod.run``.

    Each worker loads its own portion of the data via
    ``get_dataset(hvd.rank(), hvd.size())``, builds the model with
    ``get_model(num_classes)`` and trains it with a Horovod
    ``DistributedOptimizer`` whose learning rate is scaled by the number of
    workers. Only rank 0 writes a checkpoint (best training ``accuracy``),
    reads it back as bytes and returns it.

    Args:
        num_classes: Passed through to ``get_model``.
        learning_rate: Base learning rate; the optimizer actually uses
            ``learning_rate * hvd.size()``.
        batch_size: Batch size given to ``model.fit`` on each worker.
        epochs: Number of training epochs.

    Returns:
        On rank 0: ``(history_dict, checkpoint_bytes)`` where ``history_dict``
        is ``History.history`` from ``model.fit`` and ``checkpoint_bytes`` is
        the raw content of the saved ``.h5`` checkpoint. On every other rank:
        ``None``.
    """
    import os
    import tempfile
    import tensorflow as tf
    import horovod.tensorflow.keras as hvd

    hvd.init()
    is_chief = hvd.rank() == 0

    (x_train, y_train), (x_test, y_test) = get_dataset(hvd.rank(), hvd.size())
    model = get_model(num_classes)

    # Horovod: scale the learning rate by the number of workers and wrap the
    # optimizer in DistributedOptimizer so gradients are combined across workers.
    # `learning_rate=` replaces the deprecated `lr=` keyword.
    scaled_lr = learning_rate * hvd.size()
    optimizer = hvd.DistributedOptimizer(tf.optimizers.Adam(learning_rate=scaled_lr))

    # experimental_run_tf_function=False follows the Horovod TF2 Keras examples,
    # which set it so gradients are computed through hvd.DistributedOptimizer.
    model.compile(optimizer=optimizer,
                  loss='sparse_categorical_crossentropy',
                  experimental_run_tf_function=False,
                  metrics=['accuracy'])

    callbacks = [
        # Horovod: broadcast rank 0's initial variables so all workers start identical.
        hvd.callbacks.BroadcastGlobalVariablesCallback(0),
        # Horovod: warm the learning rate up to `scaled_lr` over the first 3 epochs.
        # Recent Horovod releases require `initial_lr` to be passed explicitly.
        hvd.callbacks.LearningRateWarmupCallback(initial_lr=scaled_lr,
                                                 warmup_epochs=3,
                                                 verbose=1 if is_chief else 0),
    ]

    # Only rank 0 prints per-epoch logs; otherwise every worker prints its own copy.
    fit_kwargs = dict(batch_size=batch_size,
                      epochs=epochs,
                      verbose=2 if is_chief else 0,
                      validation_data=(x_test, y_test))

    if not is_chief:
        # Horovod: non-chief workers train but never checkpoint, so they cannot
        # corrupt the checkpoint written by worker 0.
        model.fit(x_train, y_train, callbacks=callbacks, **fit_kwargs)
        return None

    # Worker 0 checkpoints into a temporary directory that is removed as soon as
    # the checkpoint bytes have been read back (the original mkdtemp() leaked it).
    with tempfile.TemporaryDirectory() as ckpt_dir:
        ckpt_file = os.path.join(ckpt_dir, 'checkpoint.h5')
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            ckpt_file, monitor='accuracy', mode='max', save_best_only=True))
        history = model.fit(x_train, y_train, callbacks=callbacks, **fit_kwargs)
        with open(ckpt_file, 'rb') as f:
            # Tuple of the training history and the serialized best model.
            return history.history, f.read()

In [None]:
# Horovod: run training on NUM_WORKERS local processes using the Gloo controller.
# hvd_run.run returns one result per process, indexed by rank; train_hvd returns
# (history, checkpoint_bytes) on rank 0 and None on every other rank.
import horovod.run as hvd_run

NUM_WORKERS = 4
worker_results = hvd_run.run(train_hvd, np=NUM_WORKERS, use_gloo=True, verbose=2)
history, model_bytes = worker_results[0]

# Load the evaluation data in this process: the x_test/y_test used inside
# train_hvd are local to the worker processes, so they do not exist here.
(x_train, y_train), (x_test, y_test) = get_dataset()

model = deserialize(model_bytes)
model.evaluate(x_test, y_test, verbose=2)

In [None]:
# Index of the test sample used for the display / prediction demo.
image_index = 4443

# Load the dataset in the notebook process (called without rank/size arguments,
# unlike the per-worker call inside train_hvd), then show the chosen test image.
(x_train, y_train), (x_test, y_test) = get_dataset()
display_image(x_test, image_index)

In [None]:
# Presumably predicts and reports the class of x_test[image_index] using the
# deserialized model; predict_number is not defined in this notebook and is
# assumed to come from commons.ipynb.
predict_number(model, x_test, image_index)