# The MNIST Application

This example demonstrates the integration of the Horovod MPI-based distributed deep learning framework with the Spark platform, using the MNIST application as the workload.

In [1]:
import os
import subprocess
import time
from datetime import timedelta, datetime, tzinfo

## Initialize the Spark RDD collection associated with MPI workers

In [2]:
from pyspark.sql import SparkSession
# Number of RDD partitions; the same value is later handed to pmiserv via -n.
partitions = 2

# Attach to the already-running Spark session if there is one, otherwise
# create a new session under this application name.
spark = (
    SparkSession.builder
    .appName("spark-horovod-mnist")
    .getOrCreate()
)

# One integer per partition (0 .. partitions-1); the values themselves are
# placeholders, the per-partition functions below work from the partition index.
rdd = spark.sparkContext.parallelize(range(partitions), numSlices=partitions)

## Download the MNIST dataset

In [3]:
def read_data_sets(pid, it):
    """Fetch the MNIST dataset into a directory private to one partition.

    Written for ``rdd.mapPartitionsWithIndex``. TensorFlow is imported inside
    the function body, so the import happens wherever the function is
    executed rather than where it is defined.

    Args:
        pid: Index of the Spark partition; used as the suffix of the target
            directory name (``MNIST-data-<pid>``).
        it: Iterator over the partition's elements (not consumed).

    Yields:
        ``pid``, exactly once, after ``read_data_sets`` has returned.
    """
    import tensorflow as tf

    target_dir = 'MNIST-data-%d' % pid
    tf.contrib.learn.datasets.mnist.read_data_sets(target_dir)

    yield pid

# mapPartitionsWithIndex is lazy; sum() is the action that actually runs
# read_data_sets on every partition. Each partition yields its own index,
# so the displayed result is the sum of the partition indices.
rdd.mapPartitionsWithIndex(read_data_sets).sum()

1

## Start the PMI server

In [4]:
# Start the PMI server as a background child process, passing the partition
# count via -n.
#
# This replaces `os.system("... &")`: with a trailing `&` the shell reports
# status 0 as soon as the job is backgrounded, so the error branch could never
# fire for a missing or immediately-crashing pmiserv.
PMISERV_STARTUP_GRACE_SEC = 1

try:
    # A non-interactive shell gives background jobs /dev/null as stdin;
    # keep that behaviour for the directly spawned process.
    with open(os.devnull) as devnull:
        pmiserv_proc = subprocess.Popen(
            ["/opt/spark-mpi/bin/pmiserv", "-n", str(partitions), "hello"],
            stdin=devnull)
except OSError as err:
    # Executable missing or not runnable.
    pmiserv_proc = None
    print("pmiserv: ERROR")
    print(str(err))
else:
    # Give the process a moment so that an immediate failure becomes visible
    # to poll(). poll() is None while still running; an exit status of 0 is
    # not treated as an error.
    time.sleep(PMISERV_STARTUP_GRACE_SEC)
    if pmiserv_proc.poll():
        print("pmiserv: ERROR")
        print("exit status %d" % pmiserv_proc.returncode)

## Train the Horovod MPI-based distributed engine on the Spark workers

In [5]:
# The train method is adapted from Horovod's TensorFlow MNIST example:
# https://github.com/uber/horovod/blob/master/examples/tensorflow_mnist.py

def train(pid, it):
    """Run Horovod-distributed MNIST training inside one Spark partition.

    Written for ``rdd.mapPartitionsWithIndex``. All heavy imports happen
    inside the function body, so they are resolved wherever the function is
    executed rather than where it is defined.

    Args:
        pid: Index of the Spark partition; exported as ``PMI_ID``.
        it: Iterator over the partition's elements (not consumed).

    Yields:
        ``pid``, exactly once, after the training session has ended.

    Raises:
        RuntimeError: If ``HYDRA_PROXY_PORT`` is not set in the environment
            of the process executing this function.
    """
    import os

    # Fail fast with a readable message instead of the TypeError that
    # concatenating None would produce further down.
    proxy_port = os.getenv("HYDRA_PROXY_PORT")
    if proxy_port is None:
        raise RuntimeError(
            "HYDRA_PROXY_PORT is not set; cannot build PMI_PORT for partition %s"
            % pid)

    # Set before TensorFlow is imported so that no CUDA device is visible to it.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    import tensorflow as tf
    import horovod.tensorflow as hvd
    import mnist_factory

    # Define the MPI environment variables: "<hostname>:<proxy port>" and the
    # partition index. Presumably read by the MPI runtime during hvd.init().
    os.environ["PMI_PORT"] = os.uname()[1] + ":" + proxy_port
    os.environ["PMI_ID"] = str(pid)

    # Initialize Horovod
    hvd.init()

    # Load the MNIST dataset from the directory belonging to this Horovod rank
    learn = tf.contrib.learn
    mnist = learn.datasets.mnist.read_data_sets('MNIST-data-%d' % hvd.rank())

    # Build model...
    with tf.name_scope('input'):
        image = tf.placeholder(tf.float32, [None, 784], name='image')
        label = tf.placeholder(tf.float32, [None], name='label')
    # Only the loss is needed for training; the prediction output is unused.
    _, loss = mnist_factory.make_model(image, label, learn.ModeKeys.TRAIN)

    global_step = tf.contrib.framework.get_or_create_global_step()

    # Horovod: scale the base learning rate by the number of workers and wrap
    # the optimizer in the Horovod distributed optimizer.
    opt = tf.train.RMSPropOptimizer(0.001 * hvd.size())
    opt = hvd.DistributedOptimizer(opt)
    train_op = opt.minimize(loss, global_step=global_step)

    # Create hooks (defined in mnist_factory; presumably these also decide
    # when the session should stop -- confirm there).
    hooks = mnist_factory.make_hooks(hvd.size(), global_step, loss)

    # Horovod: save checkpoints only on worker 0
    checkpoint_dir = './checkpoints' if hvd.rank() == 0 else None

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           hooks=hooks) as mon_sess:
        while not mon_sess.should_stop():
            # Run a training step synchronously on a batch of 100 samples.
            image_, label_ = mnist.train.next_batch(100)
            mon_sess.run(train_op, feed_dict={image: image_, label: label_})

    yield pid
 
# sum() is the action that triggers train on every partition and returns once
# all partitions have yielded. Each partition yields its own index, so the
# displayed result is the sum of the partition indices.
rdd.mapPartitionsWithIndex(train).sum()

1

## Stop the PMI server

In [10]:
# Kill the hydra_pmi_proxy processes with SIGKILL.
#
# pkill is run synchronously (it returns immediately) so that its real exit
# status is checked. The former `os.system("... &")` call backgrounded pkill,
# which made the shell report 0 regardless of the outcome.
# pkill exits non-zero when no process matched or when it failed.
try:
    pkill_status = subprocess.call(["pkill", "-9", "hydra_pmi_proxy"])
except OSError:
    # pkill itself could not be executed.
    pkill_status = -1

if pkill_status != 0:
    print("pkill: ERROR")