## Set up the Environment

The code in `env.py` does not need to be changed for a different training process.

In [None]:
%run env.py
# %env # print env
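
The Spark config cell later in this notebook reads `LD_LIBRARY_PATH`, `HADOOP_HDFS_HOME`, and `CLASSPATH` from `os.environ`, so a missing variable only shows up there as a `KeyError`. As a minimal sketch, assuming `env.py` is what exports these variables, a check like the following could be run right after it. `missing_env_vars` is a hypothetical helper for illustration and is not part of `env.py`.

```python
import os

# Variables that the Spark config cell reads from os.environ.
REQUIRED_VARS = ["LD_LIBRARY_PATH", "HADOOP_HDFS_HOME", "CLASSPATH"]


def missing_env_vars(names, environ=None):
    """Return the names that are unset or empty in the given environment mapping."""
    if environ is None:
        environ = os.environ
    return [name for name in names if not environ.get(name)]


missing = missing_env_vars(REQUIRED_VARS)
if missing:
    print("Not set:", ", ".join(missing))
else:
    print("All required variables are set.")
```

Passing the environment as an argument keeps the helper easy to try out with a plain dict instead of the real process environment.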

## Import Libraries

Please import and initialize findspark before importing any other pyspark libraries.

In [None]:
import os  # needed below to read environment variables for the Spark config
import findspark
findspark.init()
from pyspark.context import SparkContext
from pyspark.conf import SparkConf

from argparse import Namespace

from mnist_write import writeMNIST, readMNIST

## Set up Spark Config

Our cluster runs Spark 2.3.
Spark configuration properties can be found [here](https://spark.apache.org/docs/2.3.0/configuration.html).
YARN-specific properties can be found [here](https://spark.apache.org/docs/2.3.0/running-on-yarn.html#configuration).

Some notes:

**spark.submit.deployMode**: change the value from `client` to `cluster` if you submit your work from a different location; details can be found [here](https://stackoverflow.com/questions/41124428/spark-yarn-cluster-vs-client-how-to-choose-which-one-to-use?rq=1).

**spark.executor.instances**: this should match the number of worker nodes.

**spark.executor.memory**: check the remaining memory in the YARN ResourceManager UI first; you may need to kill some dead processes to free memory.

**spark.executorEnv.CLASSPATH**: specifies the location of user-defined classes and packages; it is sometimes needed for backward compatibility.
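
To compare the executor settings against the free memory shown in the YARN ResourceManager UI, a rough estimate of the total request can help. The sketch below is an approximation under stated assumptions: it assumes the default per-executor overhead of max(384 MiB, 10% of executor memory), and it ignores driver memory and any rounding YARN applies to container sizes. `parse_memory_mib` and `estimated_request_mib` are hypothetical helpers, not Spark APIs, and the parser only handles single-letter `k`/`m`/`g`/`t` suffixes.

```python
def parse_memory_mib(value):
    """Convert a size string such as "4G" or "512m" to MiB (k/m/g/t suffixes only)."""
    factors = {"k": 1 / 1024, "m": 1, "g": 1024, "t": 1024 * 1024}
    text = value.strip().lower()
    return int(int(text[:-1]) * factors[text[-1]])


def estimated_request_mib(executor_memory, instances):
    """Estimate the total executor memory requested from YARN, in MiB.

    Assumption: overhead per executor is max(384 MiB, 10% of executor memory).
    """
    memory = parse_memory_mib(executor_memory)
    overhead = max(384, int(memory * 0.10))
    return instances * (memory + overhead)


# Values used in the config cell of this notebook: 2 executors with 4G each.
print(estimated_request_mib("4G", 2))
```

If the estimate exceeds the memory reported as available, lowering `spark.executor.memory` or freeing resources first is likely the safer choice.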

In [None]:
conf = SparkConf()

conf.setAll([("spark.app.name", "mnist-yarn-write"), # your app name
             ("spark.master", "yarn"), # run on the YARN cluster, please leave this unchanged
             ("spark.yarn.queue", "default"), # value=default for CPU only, value=gpu if GPU enabled
             ("spark.yarn.maxAppAttempts", "1"), # please leave this unchanged
             ("spark.submit.deployMode", "client"), # check comments above
             ("spark.executor.instances", "2"), # check comments above
             ("spark.executor.memory", "4G"), # check comments above
             ("spark.dynamicAllocation.enabled", "false"), # please leave this unchanged
             ("spark.executorEnv.LD_LIBRARY_PATH", os.environ["LD_LIBRARY_PATH"]), # please leave this unchanged
             ("spark.executorEnv.HADOOP_HDFS_HOME", os.environ["HADOOP_HDFS_HOME"]), # please leave this unchanged
             ("spark.executorEnv.CLASSPATH", os.environ["CLASSPATH"])]) # please leave this unchanged

sc = SparkContext(conf=conf)
sc.addPyFile("mnist_write.py")

## Set up Arguments

In [None]:
args = Namespace(
  format="csv", # output format: ["csv", "csv2", "pickle", "tf", "tfr"]
  num_partitions=10, # number of output partitions
  output="/data/mnist/csv", # HDFS directory to save examples in parallelized format
  read=False, # read previously saved examples
  verify=True # verify saved examples after writing
)
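
The Main cell appends `/train` and `/test` to `args.output` and passes `args.format` through unchanged. As a minimal sketch of how those paths are derived, and of a sanity check that could catch a mistyped format early, consider the following. `output_dirs` and `SUPPORTED_FORMATS` are hypothetical names used only for illustration; how `mnist_write.py` itself treats an unknown format is not shown in this notebook.

```python
from argparse import Namespace

# Formats listed in the comment on the `format` argument.
SUPPORTED_FORMATS = ["csv", "csv2", "pickle", "tf", "tfr"]


def output_dirs(args):
    """Return the (train, test) output directories derived from args.output."""
    if args.format not in SUPPORTED_FORMATS:
        raise ValueError("unsupported format: " + args.format)
    base = args.output.rstrip("/")  # avoid a double slash if output ends with "/"
    return base + "/train", base + "/test"


example_args = Namespace(format="csv", num_partitions=10, output="/data/mnist/csv")
train_dir, test_dir = output_dirs(example_args)
print(train_dir)
print(test_dir)
```

`Namespace` is just a simple attribute container, so the same object can be built by hand in a notebook or produced by `argparse` in a script.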

## Main

In [None]:
if not args.read:
  # Note: these files are inside the mnist.zip file
  writeMNIST(sc, "data/train-images-idx3-ubyte.gz", "data/train-labels-idx1-ubyte.gz", args.output + "/train", args.format, args.num_partitions)
  writeMNIST(sc, "data/t10k-images-idx3-ubyte.gz", "data/t10k-labels-idx1-ubyte.gz", args.output + "/test", args.format, args.num_partitions)

if args.read or args.verify:
  readMNIST(sc, args.output + "/train", args.format)

sc.stop()