Skip to content

Commit

Permalink
Update resnet to run with tf r0.12 API. (#833)
Browse files Browse the repository at this point in the history
* Update resnet to run with tf r0.12 API.
1. tf.image.per_image_whitening -> tf.image.per_image_standardization
2. Use tf.summary to replace tf.image_summary, tf.scalar_summary, tf.merge_all_summaries.

* remove log
  • Loading branch information
selectwait authored and drpngx committed Jan 16, 2017
1 parent f88eef9 commit 22036b6
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions resnet/cifar_input.py
Expand Up @@ -84,7 +84,7 @@ def build_input(dataset, data_path, batch_size, mode):
else:
image = tf.image.resize_image_with_crop_or_pad(
image, image_size, image_size)
image = tf.image.per_image_whitening(image)
image = tf.image.per_image_standardization(image)

example_queue = tf.FIFOQueue(
3 * batch_size,
Expand Down Expand Up @@ -112,5 +112,5 @@ def build_input(dataset, data_path, batch_size, mode):
assert labels.get_shape()[1] == num_classes

# Display the training images in the visualizer.
tf.image_summary('images', images)
tf.summary.image('images', images)
return images, labels
4 changes: 2 additions & 2 deletions resnet/resnet_main.py
Expand Up @@ -70,8 +70,8 @@ def train(hps):
summary_hook = tf.train.SummarySaverHook(
save_steps=100,
output_dir=FLAGS.train_dir,
summary_op=[model.summaries,
tf.summary.scalar('Precision', precision)])
summary_op=tf.summary.merge([model.summaries,
tf.summary.scalar('Precision', precision)]))

logging_hook = tf.train.LoggingTensorHook(
tensors={'step': model.global_step,
Expand Down
6 changes: 3 additions & 3 deletions resnet/resnet_model.py
Expand Up @@ -59,7 +59,7 @@ def build_graph(self):
self._build_model()
if self.mode == 'train':
self._build_train_op()
self.summaries = tf.merge_all_summaries()
self.summaries = tf.summary.merge_all()

def _stride_arr(self, stride):
"""Map a stride scalar to the stride array for tf.nn.conv2d."""
Expand Down Expand Up @@ -122,12 +122,12 @@ def _build_model(self):
self.cost = tf.reduce_mean(xent, name='xent')
self.cost += self._decay()

tf.scalar_summary('cost', self.cost)
tf.summary.scalar('cost', self.cost)

def _build_train_op(self):
"""Build training specific ops for the graph."""
self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
tf.scalar_summary('learning rate', self.lrn_rate)
tf.summary.scalar('learning rate', self.lrn_rate)

trainable_variables = tf.trainable_variables()
grads = tf.gradients(self.cost, trainable_variables)
Expand Down

0 comments on commit 22036b6

Please sign in to comment.