## Set up the Environment

The code in `env.py` does not need to be changed for a different training process.

In [1]:
# Execute env.py in this kernel. IPython's %run copies the script's top-level names
# into the notebook namespace afterwards (presumably including `os`, which the Spark
# config cell uses - TODO confirm against env.py).
%run env.py
# %env  # uncomment to print the kernel's environment variables

## Import Libraries

Import and initialize `findspark` before importing any other PySpark libraries; `findspark.init()` is what makes the PySpark installation importable.

In [2]:
import os
from argparse import Namespace

import findspark
findspark.init()  # must run before any pyspark / tensorflowonspark import
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SparkSession

import numpy
import tensorflow as tf
from tensorflowonspark import TFCluster

from mnist_train_tf import map_fun

## Set up Spark Config

Our cluster runs Spark 2.3.
Spark env properties can be found [here](https://spark.apache.org/docs/2.3.0/configuration.html).
Yarn env properties can be found [here](https://spark.apache.org/docs/2.3.0/running-on-yarn.html).

Some notes:

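**spark.submit.deployMode**: keep this as `client` when running from this notebook. A `SparkContext` created directly from Python cannot use YARN `cluster` mode; Spark rejects it with "Deployment to YARN is not supported directly by SparkContext. Please use spark-submit." If the driver must run inside the cluster (for example, because you submit from a different location), package the code and launch it with `spark-submit --deploy-mode cluster` instead. Details on choosing between the two modes can be found [here](https://stackoverflow.com/questions/41124428/spark-yarn-cluster-vs-client-how-to-choose-which-one-to-use?rq=1).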

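**spark.executor.instances**: the number of executors. It should equal `cluster_size` in the arguments below, because TensorFlowOnSpark runs one TensorFlow node (parameter server or worker) per executor.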

**spark.executor.memory**: check the remaining memory in the YARN ResourceManager UI first; you may need to kill some stale applications to free memory.

**spark.executorEnv.CLASSPATH**: sets `CLASSPATH` on the executors, i.e. the location of user-defined classes and packages; sometimes this is needed for backward compatibility.

In [3]:
# Environment variables forwarded from the driver to every executor; presumably
# populated by env.py (cell 1) - TODO confirm.
EXECUTOR_ENV_VARS = ("LD_LIBRARY_PATH", "HADOOP_HDFS_HOME", "CLASSPATH")
missing_env = [name for name in EXECUTOR_ENV_VARS if name not in os.environ]
if missing_env:
  # Fail with an actionable message instead of a bare KeyError from os.environ[...] below.
  raise EnvironmentError("missing environment variables: {} (did env.py run?)".format(", ".join(missing_env)))

conf = SparkConf()

conf.setAll([("spark.app.name", "mnist-yarn-train"), # your app name
             ("spark.master", "yarn"), # use YARN as the cluster manager; please leave this unchanged
             ("spark.yarn.queue", "default"), # value=default for CPU only, value=gpu if GPU enabled
             ("spark.yarn.maxAppAttempts", "1"), # please leave this unchanged
             ("spark.submit.deployMode", "client"), # keep "client": SparkContext rejects YARN "cluster" mode outside spark-submit
             ("spark.executor.instances", "2"), # one executor per TensorFlow node; must equal args.cluster_size
             ("spark.executor.memory", "3G"), # check free memory in the YARN ResourceManager UI first
             ("spark.dynamicAllocation.enabled", "false"), # please leave this unchanged
             ("spark.executorEnv.LD_LIBRARY_PATH", os.environ["LD_LIBRARY_PATH"]), # please leave this unchanged
             ("spark.executorEnv.HADOOP_HDFS_HOME", os.environ["HADOOP_HDFS_HOME"]), # please leave this unchanged
             ("spark.executorEnv.CLASSPATH", os.environ["CLASSPATH"])]) # please leave this unchanged

sc = SparkContext(conf=conf)
# Ship the training module so executors can import mnist_train_tf (where map_fun lives).
sc.addPyFile("mnist_train_tf.py")

## Set up Arguments

In [4]:
args = Namespace(
  batch_size=100, # number of records per batch
  epochs=1, # number of epochs
  format="csv", # example format: ["csv", "tfr"]
  images="/data/mnist/csv/train/images", # HDFS path to MNIST images in parallelized format
  labels="/data/mnist/csv/train/labels", # HDFS path to MNIST labels in parallelized format (read only for "csv")
  model="/data/mnist/model", # HDFS path to save/load model during train/inference
  # Number of TensorFlow nodes (parameter servers + workers). TensorFlowOnSpark places one
  # node per executor, so derive it from the Spark config instead of repeating the literal;
  # a mismatch leaves the cluster waiting for reservations that never arrive.
  cluster_size=int(conf.get("spark.executor.instances")),
  num_ps=1, # number of parameter servers; the remaining cluster_size - num_ps nodes are workers
  output="/data/mnist/predictions", # HDFS path to save test/inference output
  readers=1, # number of reader/enqueue threads
  steps=1000, # maximum number of steps
  tensorboard=True, # launch tensorboard process
  mode="train", # train|inference
  rdma=False # use rdma connection
)

## Main

In [5]:
# Build an RDD of (image, label) records in the requested input format.
if args.format == "tfr":
  images = sc.newAPIHadoopFile(args.images, "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
                               keyClass="org.apache.hadoop.io.BytesWritable",
                               valueClass="org.apache.hadoop.io.NullWritable")

  def toNumpy(bytestr):
    """Parse one serialized tf.train.Example into an (image, label) tuple of numpy arrays."""
    example = tf.train.Example()
    example.ParseFromString(bytestr)
    features = example.features.feature
    image = numpy.array(features['image'].int64_list.value)
    label = numpy.array(features['label'].int64_list.value)
    return (image, label)

  # Records arrive as (BytesWritable key, NullWritable value); the serialized example is the key.
  dataRDD = images.map(lambda x: toNumpy(bytes(x[0])))

elif args.format == "csv":
  # One comma-separated line per record: integer image values, float label values.
  images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
  labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
  # RDD.zip pairs records positionally and requires both RDDs to have the same number
  # of partitions and elements per partition.
  dataRDD = images.zip(labels)

else:
  raise ValueError("unsupported format {!r}; expected 'csv' or 'tfr'".format(args.format))

# Validate the mode before reserving executors for the TensorFlow cluster.
if args.mode not in ("train", "inference"):
  raise ValueError("unsupported mode {!r}; expected 'train' or 'inference'".format(args.mode))

try:
  cluster = TFCluster.run(sc, map_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.SPARK)
  try:
    if args.mode == "train":
      cluster.train(dataRDD, args.epochs)
    else:
      labelRDD = cluster.inference(dataRDD)
      labelRDD.saveAsTextFile(args.output)
  finally:
    # Shut the TensorFlow nodes down even when training/inference raised.
    cluster.shutdown()
finally:
  # Always release the YARN executors held by this SparkContext.
  sc.stop()

2019-03-15 21:39:34,950 INFO (MainThread-97634) Reserving TFSparkNodes w/ TensorBoard
2019-03-15 21:39:34,951 INFO (MainThread-97634) cluster_template: {'ps': range(0, 1), 'worker': range(1, 2)}
2019-03-15 21:39:34,954 INFO (MainThread-97634) listening for reservations at ('172.29.0.3', 46679)
2019-03-15 21:39:34,955 INFO (MainThread-97634) Starting TensorFlow on executors
2019-03-15 21:39:34,962 INFO (MainThread-97634) Waiting for TFSparkNodes to start
2019-03-15 21:39:34,964 INFO (MainThread-97634) waiting for 2 reservations
2019-03-15 21:39:35,967 INFO (MainThread-97634) waiting for 2 reservations
2019-03-15 21:39:36,969 INFO (MainThread-97634) waiting for 2 reservations
2019-03-15 21:39:37,971 INFO (MainThread-97634) all reservations completed
2019-03-15 21:39:37,972 INFO (MainThread-97634) All TFSparkNodes started
2019-03-15 21:39:37,972 INFO (MainThread-97634) {'executor_id': 1, 'host': '172.29.0.5', 'job_name': 'worker', 'task_index': 0, 'port': 38519, 'tb_pid': 88175, 'tb_port'