In [1]:
import tensorflow as tf
import os
import time

  return f(*args, **kwds)


In [2]:
def get_loss_and_accuracy(images, labels):
  """
  Build the loss and zero-one accuracy tensors for a single batch.

  The batch is run through the fully connected MNIST tutorial network
  (two hidden layers of 100 units each). The batch size is taken from the
  first dimension of `images`.

  Args:
    images: Batch of model inputs, passed to `mnist.inference`.
    labels: Labels for `images`, passed to `mnist.evaluation` and `mnist.loss`.

  Returns:
    A tuple of tensors `(loss, accuracy)`. `accuracy` is the result of
    `mnist.evaluation` (treated here as the number of correct predictions)
    divided by the batch size, both cast to float32.
  """
  # Keep this import local: the function body is executed on the Spark
  # executors, so the mnist module has to be importable from there.
  from tensorflow.examples.tutorials.mnist import mnist

  # Adapted from https://github.com/tensorflow/tensorflow/blob/v1.4.0/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py#L131
  # Forward pass of the inference model.
  logits = mnist.inference(images, hidden1_units=100, hidden2_units=100)

  # Zero-one accuracy: evaluation result normalised by the number of examples in the batch.
  num_examples = tf.shape(images)[0]
  num_correct = mnist.evaluation(logits, labels)
  accuracy = tf.cast(num_correct, tf.float32) / tf.cast(num_examples, tf.float32)

  # The loss op is added to the graph last, after the accuracy ops.
  loss = mnist.loss(logits, labels)
  return (loss, accuracy)

In [3]:
def add_optimizers_and_launch_sess(loss, learning_rate, checkpoint_dir, task_index, num_workers, global_step, server):
  """
  Attach a synchronous Adam training op to `loss` and open a supervised session.

  Args:
    loss: Loss tensor to minimise.
    learning_rate: Learning rate handed to the Adam optimizer.
    checkpoint_dir: Directory passed to the Supervisor as its `logdir` (where checkpoints are written).
    task_index: Index of the current worker; the worker with index 0 acts as chief.
    num_workers: Number of worker replicas. Used both as the number of replicas
      whose gradients are aggregated per update and as the total replica count.
    global_step: Global step passed to `apply_gradients` and to the Supervisor.
    server: Object exposing a `.target` attribute used as the session target
      (presumably a `tf.train.Server` -- confirm against the caller).

  Returns:
    A tuple `(sess, train_op)`. Running `train_op` yields the value of `loss`
    and forces the gradient update to be applied.
  """
  # Determine if current process is the chief worker
  is_chief = (task_index == 0)

  # Create Adam optimizer and wrap it in a SyncReplicasOptimizer, which coordinates updates across workers
  # For more information see https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/train/SyncReplicasOptimizer
  # Use the `learning_rate` argument rather than a notebook-level global, so the
  # function works wherever it is executed and honours the caller's value.
  opt = tf.train.AdamOptimizer(learning_rate)
  opt = tf.train.SyncReplicasOptimizer(
      opt,
      replicas_to_aggregate=num_workers,
      total_num_replicas=num_workers,
      name="mnist_sync_replicas")
  # Compute gradients with respect to the loss.
  grads = opt.compute_gradients(loss)
  apply_gradients_op = opt.apply_gradients(grads, global_step=global_step)
  # Tie the update to the loss: fetching train_op returns the loss value and
  # guarantees apply_gradients_op has run.
  with tf.control_dependencies([apply_gradients_op]):
    train_op = tf.identity(loss, name='train_op')

  init_op = tf.global_variables_initializer()
  # Get additional ops that must be run to initialize the SyncReplicasOptimizer
  chief_queue_runner = opt.get_chief_queue_runner()
  init_tokens_op = opt.get_init_tokens_op()
  # NOTE(review): the SyncReplicasOptimizer docs also describe local-init ops
  # (opt.local_step_init_op / opt.chief_init_op, opt.ready_for_local_init_op);
  # confirm whether the Supervisor below should be given them.

  # Create a Supervisor to manage model checkpointing
  # See https://www.tensorflow.org/api_docs/python/tf/train/Supervisor for more info
  sv = tf.train.Supervisor(
      is_chief=is_chief,
      logdir=checkpoint_dir,  # the argument, not a notebook-level global
      init_op=init_op,
      global_step=global_step,
      save_model_secs=30,
      recovery_wait_secs=1
  )

  # Create session config
  sess_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=False,
      device_filters=["/job:ps", "/job:worker/task:%d" % task_index])

  # The chief worker (task_index==0) will prepare the session,
  # while the remaining workers will wait for the session to be available.
  sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

  if is_chief:
    # Chief worker will start the chief queue runner and call the init op.
    sv.start_queue_runners(sess, [chief_queue_runner])
    sess.run(init_tokens_op)

  return (sess, train_op)