# The MNIST Application

This example demonstrates the integration of Horovod, an MPI-based distributed deep learning framework, with the Spark platform, using the MNIST application as the working example.

In [1]:
import os
import time
from datetime import timedelta, datetime, tzinfo

from pyspark.sql import SparkSession
# Create the Spark session for this application, or reuse one that already exists.
spark = SparkSession.builder.appName("spark-horovod-mnist").getOrCreate()

## Initialize the Spark RDD collection associated with MPI workers

In [2]:
# Number of Spark partitions; one environment record is placed in each.
partitions = 4

# Read the PMIx environment variables, one KEY=VALUE pair per line.
env = {}
with open('pmixsrv.env', 'r') as f:
    for line in f.read().splitlines():
        # Split on the first "=" only, so values that themselves contain "="
        # are kept intact. Lines without "=" (e.g. blank lines) or with an
        # empty key are skipped instead of raising.
        key, sep, value = line.partition("=")
        if not sep or not key:
            continue
        env[key] = value

# Forward selected driver-side settings. Unset variables are skipped because
# a None value cannot be assigned to os.environ when the dictionary is applied.
for name in ("PATH", "LD_LIBRARY_PATH", "http_proxy", "https_proxy"):
    value = os.getenv(name)
    if value is not None:
        env[name] = value

# One reference to the same environment dictionary per partition.
arg = [env] * partitions

rdd = spark.sparkContext.parallelize(arg, partitions)

## Download the MNIST dataset

In [3]:
def read_data_sets(pid, it):
    """Fetch the MNIST dataset into a directory specific to one partition.

    Written as a callback for ``mapPartitionsWithIndex``.

    Args:
        pid: Partition index; used as the numeric suffix of the target
            directory name ``MNIST-data-<pid>``.
        it: Iterator over the partition's elements; not consumed here.

    Yields:
        The partition index ``pid``, after the dataset call has returned.
    """
    # The import runs when the generator is first advanced, not at call time.
    import tensorflow as tf

    target_dir = 'MNIST-data-%d' % pid
    tf.contrib.learn.datasets.mnist.read_data_sets(target_dir)

    yield pid

# Run read_data_sets on every partition; collect() returns the yielded partition indices.
rdd.mapPartitionsWithIndex(read_data_sets).collect()

[0, 1, 2, 3]

## Train the Horovod MPI-based distributed engine on the Spark workers

In [4]:
# The train function is adapted from Horovod's TensorFlow MNIST example:
# https://github.com/uber/horovod/blob/master/examples/tensorflow_mnist.py

def train(pid, parts):
    """Run one Horovod worker of the MNIST training job inside a Spark partition.

    Written as a callback for ``mapPartitionsWithIndex``.

    Args:
        pid: Partition index; exported as ``PMIX_RANK`` before Horovod is
            initialized.
        parts: Iterator over the partition's elements. Each element is a
            dict mapping environment variable names to values, which is
            copied into ``os.environ``. Entries whose value is ``None`` are
            skipped.

    Yields:
        A single string: the contents of the buffer returned by
        ``mnist_app.get_log_string``, read after training has finished.
    """
    # Imported locally so the function does not depend on the notebook's globals.
    import os

    import tensorflow as tf
    import horovod.tensorflow as hvd
    import mnist_app

    # NOTE(review): presumably an in-memory buffer that captures log output;
    # confirm against mnist_app.get_log_string.
    log_string = mnist_app.get_log_string(1024)

    # Define the MPI environment variables.
    # NOTE(review): a PMIX_RANK entry in the incoming dictionaries would
    # overwrite the per-partition value set here - confirm that the
    # environment file does not contain one.
    os.environ["PMIX_RANK"] = str(pid)
    for env in parts:
        for key, value in env.items():
            # os.environ accepts only strings; a None value (a variable that
            # was unset where the dictionary was built) would raise TypeError.
            if value is None:
                continue
            os.environ[key] = value

    # Initialize Horovod
    hvd.init()

    # Extract the MNIST dataset
    learn = tf.contrib.learn
    mnist = learn.datasets.mnist.read_data_sets('MNIST-data-%d' % hvd.rank())

    # Build model...
    with tf.name_scope('input'):
        image = tf.placeholder(tf.float32, [None, 784], name='image')
        label = tf.placeholder(tf.float32, [None], name='label')
    predict, loss = mnist_app.conv_model(image, label, tf.contrib.learn.ModeKeys.TRAIN)

    global_step = tf.train.get_or_create_global_step()

    # Horovod: add Horovod Distributed Optimizer.
    # The base learning rate is multiplied by the number of workers.
    opt = tf.train.RMSPropOptimizer(0.001 * hvd.size())
    opt = hvd.DistributedOptimizer(opt)
    train_op = opt.minimize(loss, global_step=global_step)

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing
    # when done or an error occurs.

    # Horovod: save checkpoints only on worker 0
    checkpoint_dir = './checkpoints' if hvd.rank() == 0 else None

    # Create hooks
    hooks = [
        hvd.BroadcastGlobalVariablesHook(0),
        tf.train.StopAtStepHook(last_step=101),
        tf.train.LoggingTensorHook(tensors={'step': global_step, 'loss': loss},
                                   every_n_iter=100),
    ]

    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           hooks=hooks,
                                           config=config) as mon_sess:
        while not mon_sess.should_stop():
            # Run a training step synchronously.
            image_, label_ = mnist.train.next_batch(100)
            mon_sess.run(train_op, feed_dict={image: image_, label: label_})

    # Read the captured log text before closing the buffer.
    log_contents = log_string.getvalue()
    log_string.close()

    yield log_contents
 
# Run train on every partition; collect() gathers the log string yielded by each one.
log_contents = rdd.mapPartitionsWithIndex(train).collect()

In [5]:
# Show the log text returned by the first partition.
print(log_contents[0])

2018-04-05 15:30:17,590 - tensorflow - INFO - Create CheckpointSaverHook.
2018-04-05 15:30:31,923 - tensorflow - INFO - Saving checkpoints for 1 into ./checkpoints/model.ckpt.
2018-04-05 15:30:32,175 - tensorflow - INFO - step = 1, loss = 2.3147516
2018-04-05 15:30:36,699 - tensorflow - INFO - global_step/sec: 20.9312
2018-04-05 15:30:36,701 - tensorflow - INFO - step = 101, loss = 0.46885204 (4.527 sec)
2018-04-05 15:30:36,703 - tensorflow - INFO - Saving checkpoints for 101 into ./checkpoints/model.ckpt.

