# The MNIST Application

This example demonstrates how to integrate the Horovod MPI-based distributed deep learning framework with the Spark platform, using the MNIST application.

In [1]:
import os
import time
from datetime import timedelta, datetime, tzinfo

from pyspark.sql import SparkSession
# Create (or reuse, via getOrCreate) the Spark session for this notebook.
spark = (
    SparkSession.builder
    .appName("spark-horovod-mnist")
    .getOrCreate()
)

## Initialize the Spark RDD with one partition per MPI worker

In [2]:
partitions = 2  # number of Spark partitions, i.e. number of MPI/Horovod workers


def _read_env_file(path):
    """Parse a file with one ``KEY=VALUE`` entry per line into a dict of strings.

    Only the first ``=`` separates key from value, so values that themselves
    contain ``=`` are preserved intact. Blank lines are skipped; any other
    line without ``=`` raises ``ValueError`` naming the offending line.
    Later duplicates of a key override earlier ones.
    """
    parsed = {}
    with open(path, 'r') as f:
        for lineno, line in enumerate(f.read().splitlines(), start=1):
            if not line.strip():
                continue
            key, sep, value = line.partition("=")
            if not sep:
                raise ValueError("%s:%d: expected KEY=VALUE, got %r" % (path, lineno, line))
            parsed[key] = value
    return parsed


# Read the PMIx environment variables
env = _read_env_file('pmixsrv.env')

# Forward selected driver-side variables to the workers. Unset variables are
# skipped: os.environ only accepts str values, so a None here would make the
# workers raise TypeError when they export this dict.
env.update({
    name: os.environ[name]
    for name in ("PATH", "LD_LIBRARY_PATH", "http_proxy", "https_proxy")
    if name in os.environ
})

# One copy of the environment dict per partition, so every MPI worker receives
# the same PMIx/driver settings; train() derives each worker's rank from its
# partition index. A comprehension avoids rebinding the builtin `id` globally.
arg = [env for _ in range(partitions)]

rdd = spark.sparkContext.parallelize(arg, partitions)

## Download the MNIST dataset

In [3]:
def read_data_sets(pid, it):
    """Fetch MNIST into a per-partition directory on the executor.

    Intended for ``rdd.mapPartitionsWithIndex``: ``pid`` is the partition
    index and ``it`` (the partition's records) is not used. Yields ``pid``
    once so the driver can reduce over the partitions.
    """
    import tensorflow as tf

    data_dir = 'MNIST-data-%d' % pid
    tf.contrib.learn.datasets.mnist.read_data_sets(data_dir)

    yield pid

# Run the download on every partition. sum() is the action that triggers the
# job; it adds up the yielded partition indices (0 + 1 for two partitions).
(rdd
 .mapPartitionsWithIndex(read_data_sets)
 .sum())

1

## Train the model with the Horovod MPI-based distributed engine

In [4]:
# The train function is adapted from Horovod's TensorFlow MNIST example:
# https://github.com/uber/horovod/blob/master/examples/tensorflow_mnist.py

def _export_worker_env(pid, parts):
    """Export the driver-supplied environment into this executor process.

    Args:
        pid: Spark partition index; exported as ``PMIX_RANK``.
        parts: Iterable of ``{name: value}`` dicts (the partition's records).

    ``None`` values are skipped because ``os.environ`` accepts only strings.
    ``PMIX_RANK`` is written last so the per-partition rank cannot be
    overwritten by a ``PMIX_RANK`` entry in the shared environment dicts.
    Values are deliberately not printed: they include proxy settings and
    other variables that should not end up in executor logs.
    """
    for worker_env in parts:
        for key, value in worker_env.items():
            if value is not None:
                os.environ[key] = value
    os.environ["PMIX_RANK"] = str(pid)


def train(pid, parts):
    """Run one Horovod/TensorFlow MNIST training worker inside a Spark partition.

    Intended for ``rdd.mapPartitionsWithIndex``.

    Args:
        pid: Spark partition index, used as this worker's PMIx rank.
        parts: Iterator over the partition's records (environment dicts).

    Yields:
        ``pid`` once the training loop has finished.
    """
    import tensorflow as tf
    import horovod.tensorflow as hvd
    import mnist_factory

    tf.logging.set_verbosity(tf.logging.INFO)

    # Define the MPI/PMIx environment variables before Horovod initializes.
    _export_worker_env(pid, parts)

    # Initialize Horovod.
    hvd.init()

    # Extract the MNIST dataset into a per-rank directory.
    learn = tf.contrib.learn
    mnist = learn.datasets.mnist.read_data_sets('MNIST-data-%d' % hvd.rank())

    # Build the model: flattened 784-pixel images and one label per image.
    with tf.name_scope('input'):
        image = tf.placeholder(tf.float32, [None, 784], name='image')
        label = tf.placeholder(tf.float32, [None], name='label')
    predict, loss = mnist_factory.make_model(image, label, tf.contrib.learn.ModeKeys.TRAIN)

    global_step = tf.contrib.framework.get_or_create_global_step()

    # Horovod: scale the learning rate by the number of workers and wrap the
    # optimizer in Horovod's DistributedOptimizer.
    opt = tf.train.RMSPropOptimizer(0.001 * hvd.size())
    opt = hvd.DistributedOptimizer(opt)
    train_op = opt.minimize(loss, global_step=global_step)

    # Create hooks (presumably including the stop condition for the loop
    # below -- see mnist_factory.make_hooks).
    hooks = mnist_factory.make_hooks(hvd.size(), global_step, loss)

    # Horovod: save checkpoints only on worker 0; other ranks pass None.
    checkpoint_dir = './checkpoints' if hvd.rank() == 0 else None

    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           hooks=hooks,
                                           config=config) as mon_sess:
        while not mon_sess.should_stop():
            # Run a training step synchronously on a batch of 100 examples.
            image_, label_ = mnist.train.next_batch(100)
            mon_sess.run(train_op, feed_dict={image: image_, label: label_})

    yield pid
 
# Launch one Horovod worker per partition. sum() is the action that triggers
# the job; it adds up the yielded partition indices once training completes.
(rdd
 .mapPartitionsWithIndex(train)
 .sum())

1