In [8]:
# This notebook contains utility functions to be reused in other notebooks
#
# To import the functions use `run` magic: `%run -i 'commons.ipynb'`

def get_dataset(rank=0, size=1):
    """Load MNIST via Keras and return this worker's strided shard, scaled by 1/255.

    Args:
        rank: Offset of the first sample kept; also used to build the
            per-rank cache name ``'MNIST-data-<rank>'`` passed to ``load_data``.
        size: Stride between kept samples (every ``size``-th sample starting
            at ``rank``). With the defaults the full dataset is returned.

    Returns:
        ``(x_train, y_train), (x_test, y_test)`` where both image arrays have
        been divided by 255.0 and the labels are left untouched.
    """
    import tensorflow as tf

    cache_name = 'MNIST-data-%d' % rank
    train_split, test_split = tf.keras.datasets.mnist.load_data(cache_name)

    # Same selection as `[rank::size]`, applied uniformly to images and labels
    # of both splits.
    shard = slice(rank, None, size)
    train_images, train_labels = train_split[0][shard], train_split[1][shard]
    test_images, test_labels = test_split[0][shard], test_split[1][shard]

    # Scale pixel intensities by the maximum 8-bit value (presumably maps
    # them into [0, 1] -- assumes the raw images are uint8).
    train_images = train_images / 255.0
    test_images = test_images / 255.0

    return (train_images, train_labels), (test_images, test_labels)

def get_model(num_classes=10):
    """Build an uncompiled dense classifier for 28x28 inputs.

    Args:
        num_classes: Number of units in the softmax output layer.
            Defaults to 10.

    Returns:
        A ``tf.keras.models.Sequential`` model consisting of
        Flatten(28x28) -> Dense(128, relu) -> Dropout(0.2)
        -> Dense(num_classes, softmax). The model is not compiled here.
    """
    import tensorflow as tf
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        # Output width follows the `num_classes` argument so callers can
        # actually change the number of predicted classes.
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])

    return model

def deserialize(model_bytes):
    """Rebuild a model from its in-memory HDF5 serialization.

    Args:
        model_bytes: Bytes of an HDF5-serialized model.

    Returns:
        Whatever ``horovod.tensorflow.keras.load_model`` returns for the
        HDF5 file backed by ``model_bytes``.
    """
    import horovod.tensorflow.keras as hvd
    import h5py
    import io

    # Wrap the raw bytes in a file-like object so h5py can open them
    # without touching disk.
    buffer = io.BytesIO(model_bytes)
    with h5py.File(buffer, 'a') as h5_file:
        # The model is loaded while the HDF5 handle is still open.
        model = hvd.load_model(h5_file)
    return model

In [9]:
# Utility functions for demo purposes

import matplotlib.pyplot as plt
%matplotlib inline

def plot_mnist_sample(x_train, y_train):
    """Show the first ten images of ``x_train`` in a 2x5 grid and print their labels.

    Args:
        x_train: Indexable collection of images; each of the first ten
            entries must be reshapeable to (28, 28).
        y_train: Labels matching ``x_train``; the first ten are printed
            as two rows of five, mirroring the grid layout.
    """
    # Note: this changes matplotlib's default image colormap globally,
    # not just for this figure.
    plt.rc("image", cmap="binary")

    for slot in range(1, 11):
        plt.subplot(2, 5, slot)
        digit = x_train[slot - 1].reshape(28, 28)
        plt.imshow(digit)
        # Empty tick sequences remove the axis ticks around each image.
        plt.xticks(())
        plt.yticks(())
    plt.tight_layout()
    plt.show()

    # One printed row of labels per row of the image grid.
    for start in (0, 5):
        print(y_train[start:start + 5])

def display_image(x_test, image_index):
    """Render one image from ``x_test`` with the 'binary' colormap.

    Args:
        x_test: Indexable collection of images; the selected entry must be
            reshapeable to (28, 28).
        image_index: Index of the image to draw.
    """
    image = x_test[image_index].reshape(28, 28)
    plt.imshow(image, cmap='binary')

def predict_number(model, x_test, image_index):
    """Print the index of the highest score the model gives for one sample.

    Args:
        model: Object exposing ``predict(batch)`` whose result has ``argmax()``.
        x_test: Sliceable collection of samples.
        image_index: Index of the sample to classify.
    """
    # Slice (rather than index) so the sample keeps its leading batch axis.
    single_batch = x_test[image_index:image_index + 1]
    scores = model.predict(single_batch)
    print(scores.argmax())