In [None]:
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, parse_record_fn, num_epochs):
    """Turn a dataset of raw records into batches of parsed examples.

    The pipeline is: prefetch -> (shuffle when training) -> repeat -> parse
    -> batch -> prefetch.

    Args:
        dataset: Dataset of raw records. Must expose `prefetch`, `shuffle`,
            `repeat`, `map` and `batch`, each returning a new dataset.
        is_training: Whether the pipeline feeds training. Enables shuffling
            and is forwarded to `parse_record_fn`.
        batch_size: Number of parsed records per batch; also used as the
            size of the initial prefetch buffer.
        shuffle_buffer: Buffer size handed to `shuffle` when training.
        parse_record_fn: Callable `(raw_record, is_training)` returning the
            parsed example.
        num_epochs: Number of times the dataset is repeated.

    Returns:
        The fully transformed dataset.
    """
    # Prefetch raw records so reading can overlap with downstream processing.
    dataset = dataset.prefetch(buffer_size=batch_size)

    if is_training:
        # Shuffle before repeating so that epochs are not mixed together.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)

    # Repeat the (possibly shuffled) dataset for multi-epoch runs.
    dataset = dataset.repeat(num_epochs)

    # Parse the raw records and group them into batches.
    dataset = dataset.map(lambda record: parse_record_fn(record, is_training), num_parallel_calls=1)
    dataset = dataset.batch(batch_size)

    # Keep one batch ready ahead of the consumer. Dataset transformations
    # return a new dataset instead of mutating in place, so the result has to
    # be reassigned; otherwise this prefetch is silently dropped.
    dataset = dataset.prefetch(1)

    return dataset

In [None]:
def learning_schedule(batch_size, batch_denom, num_images, boundary_epochs, decay_rates):
    """Build a piecewise-constant learning-rate function of the global step.

    The base rate is `0.1 * batch_size / batch_denom`. Each entry of
    `boundary_epochs` is converted to a global-step boundary, and each entry
    of `decay_rates` scales the base rate to give the value used in the
    corresponding interval.

    Args:
        batch_size: Number of images per training batch.
        batch_denom: Divisor used to scale the base rate with the batch size.
        num_images: Number of images in one epoch.
        boundary_epochs: Epochs at which the learning rate changes.
        decay_rates: Multipliers applied to the base rate, one per interval.

    Returns:
        A function mapping `global_step` to the learning-rate tensor produced
        by `tf.train.piecewise_constant`.
    """
    base_rate = 0.1 * batch_size / batch_denom
    steps_per_epoch = num_images / batch_size

    # Epoch boundaries expressed in global steps (one step == one batch).
    step_boundaries = [int(steps_per_epoch * epoch) for epoch in boundary_epochs]
    # Learning rate to use inside each interval delimited by the boundaries.
    rate_values = [base_rate * factor for factor in decay_rates]

    def learning_rate_fn(global_step):
        step = tf.cast(global_step, tf.int32)
        return tf.train.piecewise_constant(step, step_boundaries, rate_values)

    return learning_rate_fn

In [None]:
def resnet_model_fn(features, labels, mode, model_class,
                    resnet_size, weight_decay, learning_rate_fn, momentum,
                    data_format, resnet_version, loss_scale,
                    loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE):
    """Shared functionality for different resnet model_fns.

    Instantiates `model_class` and uses it to build the EstimatorSpec for
    the `mode` in question. For training, this means building the loss, the
    optimizer, and the train op that get passed into the EstimatorSpec.
    For evaluation the EstimatorSpec is returned without a train op, and for
    prediction it is returned early with only predictions and export outputs.

    Args:
        features: tensor representing input images.
        labels: tensor representing class labels for all input images. Used
            as one-hot labels (passed as `onehot_labels` and reduced with
            `tf.argmax(labels, axis=1)`).
        mode: current estimator mode; should be one of
            `tf.estimator.ModeKeys.TRAIN`, `EVAL`, `PREDICT`.
        model_class: a class representing a TensorFlow model that has a
            __call__ function. We assume here that this is a subclass of
            ResnetModel.
        resnet_size: A single integer for the size of the ResNet model.
        weight_decay: weight decay loss rate used to regularize learned
            variables.
        learning_rate_fn: function that returns the current learning rate
            given the current global_step.
        momentum: momentum term used for optimization.
        data_format: Input format ('channels_last', 'channels_first', or
            None), forwarded to `model_class`. None presumably lets the model
            choose based on GPU availability -- confirm in the model class.
        resnet_version: Integer representing which version of the ResNet
            network to use. See README for details. Valid values: [1, 2].
        loss_scale: The factor to scale the loss for numerical stability. A
            detailed summary is present in the arg parser help text.
        loss_filter_fn: function that takes a string variable name and
            returns True if the var should be included in the L2 loss, and
            False otherwise. If None, variables whose name contains
            'batch_normalization' are excluded.
        dtype: the TensorFlow dtype to use for calculations.

    Returns:
        EstimatorSpec parameterized according to the input params and the
        current mode.
    """

    # Generate a summary node for the images
    tf.summary.image('images', features, max_outputs=6)

    features = tf.cast(features, dtype)

    model = model_class(resnet_size, data_format, resnet_version=resnet_version, dtype=dtype)

    # The second argument tells the model whether it is being built for training.
    logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)

    # This acts as a no-op if the logits are already in fp32 (provided logits are
    # not a SparseTensor). If dtype is low precision, logits must be cast to
    # fp32 for numerical stability.
    logits = tf.cast(logits, tf.float32)

    predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Return the predictions and the specification for serving a SavedModel
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'predict': tf.estimator.export.PredictOutput(predictions)
            })

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    # cross entropy part
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # L2 regularization part
    # Default filter: keep every variable except batch-normalization ones.
    def exclude_batch_norm(name):
        return 'batch_normalization' not in name

    loss_filter_fn = loss_filter_fn or exclude_batch_norm

    # Add weight decay to the loss. Variables are cast to fp32 before the L2
    # norm is taken so the sum is computed in full precision.
    l2_loss = weight_decay * tf.add_n(
      [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()
       if loss_filter_fn(v.name)])

    tf.summary.scalar('l2_loss', l2_loss)
    loss = cross_entropy + l2_loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        learning_rate = learning_rate_fn(global_step)

        # Create a tensor named learning_rate for logging purposes
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=momentum)

        if loss_scale != 1:
            # When computing fp16 gradients, often intermediate tensor values are
            # so small, they underflow to 0. To avoid this, we multiply the loss by
            # loss_scale to make these tensor values loss_scale times bigger.
            scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)

            # Once the gradient computation is complete we can scale the gradients
            # back to the correct scale before passing them to the optimizer.
            unscaled_grad_vars = [(grad / loss_scale, var) for grad, var in scaled_grad_vars]
            minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
        else:
            minimize_op = optimizer.minimize(loss, global_step)

        # Run everything registered in the UPDATE_OPS collection together with
        # the optimizer step (presumably batch-norm moving statistics --
        # confirm against the model implementation).
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(minimize_op, update_ops)
    else:
        train_op = None


    if not tf.contrib.distribute.has_distribution_strategy():
        # Labels are one-hot, so reduce them to class indices before comparing.
        accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1), predictions['classes'])
    else:
        # Metrics are currently not compatible with distribution strategies during
        # training. This does not affect the overall performance of the model.
        # A placeholder (no-op, constant 0) pair stands in for the metric.
        accuracy = (tf.no_op(), tf.constant(0))

    metrics = {'accuracy': accuracy}

    # Create a tensor named train_accuracy for logging purposes
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, loss=loss, 
                                      train_op=train_op, eval_metric_ops=metrics)

In [None]:
def resnet_main(flags_obj, model_function, input_function, dataset_name, shape=None):
    """Run the shared train/evaluate loop for ResNet models.

    Builds an Estimator from the parsed flags, then alternates training and
    evaluation for `train_epochs // epochs_between_evals` cycles, stopping
    early once the stop threshold is passed. A SavedModel is exported
    afterwards when `flags_obj.export_dir` is set.

    Args:
        flags_obj: An object containing parsed flags. See
            define_resnet_flags() for details.
        model_function: the function that instantiates the Model and builds
            the ops for train/eval. Passed directly into the estimator.
        input_function: the function that processes the dataset and returns
            a dataset that the estimator can train on. It is wrapped with the
            relevant flags before being handed to the estimator.
        dataset_name: the name of the dataset for training and evaluation.
            Used for logging only.
        shape: list of ints representing the shape of the images used for
            training. Only used if flags_obj.export_dir is set.
    """
    # The Winograd non-fused algorithms provide a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    # Session threading comes from the flags. allow_soft_placement=True is
    # required for multi-GPU and is not harmful for other modes.
    sess_config = tf.ConfigProto(
        inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
        intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
        allow_soft_placement=True)

    # Pick the distribution strategy from the number of GPUs.
    if flags_core.get_num_gpus(flags_obj) == 0:
        strategy = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
    elif flags_core.get_num_gpus(flags_obj) == 1:
        strategy = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
    else:
        strategy = tf.contrib.distribute.MirroredStrategy(
            num_gpus=flags_core.get_num_gpus(flags_obj))

    estimator_config = tf.estimator.RunConfig(
        train_distribute=strategy, session_config=sess_config)

    # Hyper-parameters forwarded to the model function through `params`.
    model_params = {
        'resnet_size': int(flags_obj.resnet_size),
        'data_format': flags_obj.data_format,
        'batch_size': flags_obj.batch_size,
        'resnet_version': int(flags_obj.resnet_version),
        'loss_scale': flags_core.get_loss_scale(flags_obj),
        'dtype': flags_core.get_tf_dtype(flags_obj),
    }
    estimator = tf.estimator.Estimator(
        model_fn=model_function,
        model_dir=flags_obj.model_dir,
        config=estimator_config,
        params=model_params)

    # Run description recorded by the benchmark logger.
    run_info = {
        'batch_size': flags_obj.batch_size,
        'dtype': flags_core.get_tf_dtype(flags_obj),
        'resnet_size': flags_obj.resnet_size,
        'resnet_version': flags_obj.resnet_version,
        'synthetic_data': flags_obj.use_synthetic_data,
        'train_epochs': flags_obj.train_epochs,
    }
    benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
    benchmark_logger.log_run_info('resnet', dataset_name, run_info)

    training_hooks = hooks_helper.get_train_hooks(
        flags_obj.hooks,
        batch_size=flags_obj.batch_size,
        benchmark_log_dir=flags_obj.benchmark_log_dir)

    def _build_input(is_training, num_epochs):
        # Both input functions share the data dir and per-device batch size.
        return input_function(
            is_training=is_training,
            data_dir=flags_obj.data_dir,
            batch_size=per_device_batch_size(
                flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
            num_epochs=num_epochs)

    def input_fn_train():
        return _build_input(True, flags_obj.epochs_between_evals)

    def input_fn_eval():
        return _build_input(False, 1)

    cycle_count = flags_obj.train_epochs // flags_obj.epochs_between_evals

    for cycle in range(cycle_count):
        tf.logging.info('Starting a training cycle: %d/%d', cycle, cycle_count)

        estimator.train(
            input_fn=input_fn_train,
            hooks=training_hooks,
            max_steps=flags_obj.max_train_steps)

        tf.logging.info('Starting to evaluate.')

        # flags_obj.max_train_steps is generally associated with testing and
        # profiling, which frequently use synthetic data that iterates
        # forever. Passing it as `steps` lets the eval terminate in that case.
        # Note that eval runs for max_train_steps on every cycle, regardless
        # of the global_step count.
        eval_results = estimator.evaluate(
            input_fn=input_fn_eval, steps=flags_obj.max_train_steps)

        benchmark_logger.log_evaluation_result(eval_results)

        threshold_reached = model_helpers.past_stop_threshold(
            flags_obj.stop_threshold, eval_results['accuracy'])
        if threshold_reached:
            break

    if flags_obj.export_dir is None:
        return

    # Export a SavedModel for the trained classifier.
    receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape, batch_size=flags_obj.batch_size)
    estimator.export_savedmodel(flags_obj.export_dir, receiver_fn)


def define_resnet_flags(resnet_size_choices=None):
    """Register the command-line flags used by the ResNet models.

    Defines the shared base, performance, image and benchmark flags through
    `flags_core`, then adds the ResNet-specific `resnet_version` and
    `resnet_size` flags.

    Args:
        resnet_size_choices: Optional list of allowed values for
            `--resnet_size`. When None the flag is a free-form string;
            otherwise it is an enum restricted to these values.
    """
    flags_core.define_base()
    flags_core.define_performance(num_parallel_calls=False)
    flags_core.define_image()
    flags_core.define_benchmark()
    flags.adopt_module_key_flags(flags_core)

    flags.DEFINE_enum(
        name='resnet_version', short_name='rv', default='2',
        enum_values=['1', '2'],
        help=flags_core.help_wrap(
            'Version of ResNet. (1 or 2) See README.md for details.'))

    # Arguments shared by both variants of the resnet_size flag.
    choice_kwargs = dict(
        name='resnet_size', short_name='rs', default='50',
        help=flags_core.help_wrap('The size of the ResNet model to use.'))

    if resnet_size_choices is None:
        flags.DEFINE_string(**choice_kwargs)
    else:
        flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs)