In [1]:
# Where the training & validation data lives in S3.
BUCKET_NAME = "databricks-public-datasets"  # S3 bucket holding both datasets
S3_REGION = "us-west-2"  # AWS region of that bucket
TRAIN_PREFIX = "mnist/train"  # filename prefix of the training data inside the bucket
EVAL_PREFIX = "mnist/validation"  # filename prefix of the validation data inside the bucket

In [2]:
# Size of the validation set. Each model evaluation step uses this count so that it
# processes the entire validation set.
NUM_VALIDATION_EXAMPLES = 10000

In [3]:
# Persistent storage locations on Databricks' distributed filesystem (DBFS); see
# https://docs.databricks.com/user-guide/dbfs-databricks-file-system.html for background on DBFS.
# Anything written to them is preserved across cluster restarts and after training finishes.
#
# Model checkpoints are stored here. The model training logic assumes this path is on DBFS,
# i.e. that it begins with "/dbfs/" -- keep that prefix if you change the path.
# Because the directory is distributed, its contents can be overwritten by other users or by
# successive training runs; we recommend generating a unique directory name for each run.
CHECKPOINT_DIR = "/dbfs/tmp/mnist-checkpoint-dir"
# Event files describing model performance on the training & validation datasets are stored
# here; they can be consumed by TensorBoard.
DBFS_EVENT_FILE_DIR = "/dbfs/tmp/mnist-event-file-dir"
# When True, the contents of both directories above are destroyed before a run.
REMOVE_OLD_DIRS = True

In [4]:
# Fail fast if either persistent directory is not on DBFS.
# The prefix deliberately includes the trailing slash: checking only "/dbfs" would also accept
# unrelated paths such as "/dbfs-local/..." that are not under the DBFS mount.
assert CHECKPOINT_DIR.startswith("/dbfs/"), "Checkpoint directory must be on DBFS (start with '/dbfs/'), got non-DBFS checkpoint directory %s"%CHECKPOINT_DIR
assert DBFS_EVENT_FILE_DIR.startswith("/dbfs/"), "Event file directory must be on DBFS (start with '/dbfs/'), got non-DBFS event file directory %s"%DBFS_EVENT_FILE_DIR

In [5]:
# Node-local (non-DBFS) scratch directories.
# Training data is stored here on the Spark workers, validation data on the driver.
LOCAL_DATA_DIR = "/tmp/mnist"
# Training event files are stored here on the chief Spark worker, validation event files on the driver.
LOCAL_EVENT_FILE_DIR = "/tmp/mnist-event-files"

In [7]:
# Number of Spark executors in the cluster, taken from the SparkContext `sc`; this value
# likely does not need to be modified.
# NOTE(review): `sc` is not defined in the cells shown here -- presumably it is the SparkContext
# injected by the notebook runtime; confirm before running this notebook elsewhere.
# NOTE(review): `defaultParallelism` is Spark's default partition count, which may differ from
# the executor count depending on cluster configuration -- confirm.
NUM_EXECUTORS = sc.defaultParallelism
# Number of parameter servers. The remaining NUM_EXECUTORS - NUM_PS executors are launched as
# Tensorflow workers for training, so NUM_PS must satisfy 1 <= NUM_PS < NUM_EXECUTORS.
NUM_PS = 1

NameError: name 'sc' is not defined

In [8]:
# Validate the parameter-server count: at least one parameter server is required, and at
# least one executor must be left over to act as a training worker.
assert NUM_PS > 0, (
    "Must use at least one parameter server for model training, got NUM_PS = %s" % NUM_PS
)
assert NUM_PS < NUM_EXECUTORS, (
    "Number of parameter servers (%s) must be less than the number of Spark executors (%s)"
    % (NUM_PS, NUM_EXECUTORS)
)

NameError: name 'NUM_PS' is not defined

In [9]:
# Training hyperparameters.
BATCH_SIZE = 32  # batch size
LEARNING_RATE = 5e-6  # learning rate
TRAIN_STEPS = 100000  # number of (global) training steps to perform