Commit 31fd5e1: no moving averages
wyli committed Aug 24, 2017 (1 parent: eba0cc6)
Showing 3 changed files with 45 additions and 43 deletions.
2 changes: 1 addition & 1 deletion niftynet/application/segmentation_application.py
@@ -252,7 +252,7 @@ def connect_data_and_network(self,
var=data_dict['image_location'], name='location',
average_over_devices=False, collection=NETORK_OUTPUT)
init_aggregator = \
-            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
+            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
init_aggregator()

def interpret_output(self, batch_output):
80 changes: 41 additions & 39 deletions niftynet/engine/application_driver.py
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-

from __future__ import absolute_import, print_function, division

import os
@@ -16,7 +17,7 @@
from niftynet.engine.application_variables import \
global_variables_initialize_or_restorer
from niftynet.io.misc_io import touch_folder, get_latest_subfolder
-from niftynet.layer.bn import BN_COLLECTION_NAME
+from niftynet.layer.bn import BN_COLLECTION

FILE_PREFIX = 'model.ckpt'

@@ -166,26 +167,20 @@ def _create_graph(self):
self.outputs_collector,
self.gradients_collector)
# global batch norm statistics from the last device
-                    bn_ops = tf.get_collection(BN_COLLECTION_NAME, scope) \
+                    bn_ops = tf.get_collection(BN_COLLECTION, scope) \
if self.is_training else None

# assemble all training operations
-            if self.is_training:
+            if self.is_training and self.gradients_collector:
updates_op = []
-                # model moving average operation
-                with tf.name_scope('MovingAverages'):
-                    mva_op = ApplicationDriver._model_moving_averaging_op()
-                    if not mva_op.type == "NoOp":
-                        updates_op.extend(mva_op)
# batch normalisation moving averages operation
if bn_ops:
updates_op.extend(bn_ops)
# combine them with model parameter updating operation
with tf.name_scope('ApplyGradients'):
-                    if self.gradients_collector is not None:
-                        with graph.control_dependencies(updates_op):
-                            self.app.set_network_update_op(
-                                self.gradients_collector.gradients)
+                    with graph.control_dependencies(updates_op):
+                        self.app.set_network_update_op(
+                            self.gradients_collector.gradients)

# initialisation operation
with tf.name_scope('Initialization'):
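Note: after this change the only entries left in updates_op are the batch normalisation moving-average ops gathered from BN_COLLECTION (tf.GraphKeys.UPDATE_OPS); they run as control dependencies of the gradient-application op. A minimal, self-contained sketch of the same pattern (not part of this commit), using the stock tf.layers batch norm instead of NiftyNet's BNLayer, with illustrative names throughout:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
is_training = tf.placeholder(tf.bool, [])
# tf.layers.batch_normalization registers its moving-average update ops
# in tf.GraphKeys.UPDATE_OPS, just as BNLayer does via BN_COLLECTION
h = tf.layers.batch_normalization(x, training=is_training)
loss = tf.reduce_mean(tf.square(h))

bn_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
optimiser = tf.train.GradientDescentOptimizer(0.1)
# every run of train_op now also refreshes the batch-norm statistics
with tf.control_dependencies(bn_ops):
    train_op = optimiser.minimize(loss)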
@@ -233,8 +228,13 @@ def _randomly_init_or_restore_variables(self, sess):
checkpoint = '{}-{}'.format(self.session_dir, self.initial_iter)
# restore session
tf.logging.info('Accessing {} ...'.format(checkpoint))
-        self.saver.restore(sess, checkpoint)
-        return
+        try:
+            self.saver.restore(sess, checkpoint)
+        except tf.errors.NotFoundError:
+            tf.logging.fatal(
+                'checkpoint {} not found or variables to restore do not '
+                'match the current application graph'.format(checkpoint))
+            raise

def _training_loop(self, sess, loop_status):
writer = tf.summary.FileWriter(self.summary_dir, sess.graph)
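Note: the restore call is now guarded so that a missing checkpoint, or one whose variables do not match the current graph, is logged before the exception propagates. A hedged stand-alone sketch of the same guard (the variable and path below are invented for illustration):

import tensorflow as tf

w = tf.get_variable('w', shape=[2])          # some model variable
saver = tf.train.Saver()
checkpoint = '/tmp/models/model.ckpt-1000'   # illustrative path
with tf.Session() as sess:
    try:
        saver.restore(sess, checkpoint)
    except tf.errors.NotFoundError:
        tf.logging.fatal(
            'checkpoint %s not found or does not match the graph', checkpoint)
        raise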
@@ -247,32 +247,29 @@ def _training_loop(self, sess, loop_status):
if self._coord.should_stop():
break

-            # prepare variables from the graph to run
+            # variables to the graph
vars_to_run = dict(train_op=train_op)
-            vars_to_run[CONSOLE] = \
-                self.outputs_collector.variables(collection=CONSOLE)
-            vars_to_run[NETORK_OUTPUT] = \
-                self.outputs_collector.variables(collection=NETORK_OUTPUT)
+            vars_to_run[CONSOLE], vars_to_run[NETORK_OUTPUT] = \
+                self.outputs_collector.variables(CONSOLE), \
+                self.outputs_collector.variables(NETORK_OUTPUT)
if iter_i % self.tensorboard_every_n == 0:
# adding tensorboard summary
vars_to_run[TF_SUMMARIES] = \
self.outputs_collector.variables(collection=TF_SUMMARIES)

# run all variables in one go
graph_output = sess.run(vars_to_run)

# process graph outputs
self.app.interpret_output(graph_output[NETORK_OUTPUT])
-            # if application specified summaries
+            console_str = self._console_vars_to_str(graph_output[CONSOLE])
summary = graph_output.get(TF_SUMMARIES, {})
-            if summary != {}:
+            if summary:
writer.add_summary(summary, iter_i)

# save current model
if iter_i % self.save_every_n == 0:
self._save_model(sess, iter_i)

-            # print variables of the updated network
-            console = graph_output.get(CONSOLE, {})
-            console_str = ', '.join(
-                '{}={}'.format(key, val) for (key, val) in console.items())
tf.logging.info('iter {}, {} ({:.3f}s)'.format(
iter_i, console_str, time.time() - local_time))
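Note: the loop now collects everything it needs into one vars_to_run dictionary and evaluates it with a single sess.run call, so console values, network outputs and (every tensorboard_every_n iterations) summaries all come from the same training step. A small illustration of the dictionary-fetch pattern, with made-up tensors standing in for the collector's variables:

import tensorflow as tf

loss = tf.constant(3.0, name='loss')
prediction = tf.constant([1, 2, 3], name='prediction')
train_op = tf.no_op(name='train')

# sess.run accepts a nested dict of fetches and returns a dict
# of the same structure holding the evaluated values
vars_to_run = dict(train_op=train_op)
vars_to_run['console'] = {'loss': loss}
vars_to_run['network_output'] = {'prediction': prediction}

with tf.Session() as sess:
    graph_output = sess.run(vars_to_run)
    print(graph_output['console'])          # {'loss': 3.0}
    print(graph_output['network_output'])   # {'prediction': array([1, 2, 3], dtype=int32)}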

@@ -282,19 +279,22 @@ def _inference_loop(self, sess, loop_status):
local_time = time.time()
if self._coord.should_stop():
break

# build variables to run
vars_to_run = dict()
-            vars_to_run[NETORK_OUTPUT] = \
-                self.outputs_collector.variables(collection=NETORK_OUTPUT)
-            vars_to_run[CONSOLE] = \
-                self.outputs_collector.variables(collection=CONSOLE)
+            vars_to_run[NETORK_OUTPUT], vars_to_run[CONSOLE] = \
+                self.outputs_collector.variables(NETORK_OUTPUT), \
+                self.outputs_collector.variables(CONSOLE)

# evaluate the graph variables
graph_output = sess.run(vars_to_run)

# process the graph outputs
if not self.app.interpret_output(graph_output[NETORK_OUTPUT]):
tf.logging.info('processed all batches.')
loop_status['all_saved_flag'] = True
break
-            console = graph_output.get(CONSOLE, {})
-            console_str = ', '.join(
-                '{}={}'.format(key, val) for (key, val) in console.items())
+            console_str = self._console_vars_to_str(graph_output[CONSOLE])
tf.logging.info('{} ({:.3f}s)'.format(
console_str, time.time() - local_time))
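Note: the inference loop stops when the application's interpret_output returns a falsy value, which is what sets all_saved_flag above; the driver itself has no notion of how many batches there are. A hedged sketch of that convention (the counter-based stopping rule is invented for illustration only):

class ToyApp(object):
    """Illustrative stand-in for a NiftyNet application."""

    def __init__(self, n_batches_expected):
        self.n_batches_expected = n_batches_expected
        self.n_seen = 0

    def interpret_output(self, batch_output):
        # aggregate or write out batch_output here ...
        self.n_seen += 1
        # a falsy return tells the driver all batches have been processed
        return self.n_seen < self.n_batches_expected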

@@ -306,6 +306,14 @@ def _save_model(self, session, iter_i):
global_step=iter_i)
tf.logging.info('iter {} saved: {}'.format(iter_i, self.session_dir))

+    @staticmethod
+    def _console_vars_to_str(console_dict):
+        if not console_dict:
+            return ''
+        console_str = ', '.join(
+            '{}={}'.format(key, val) for (key, val) in console_dict.items())
+        return console_str
+
def _device_string(self, id=0, is_worker=True):
devices = device_lib.list_local_devices()
has_local_gpu = any([x.device_type == 'GPU' for x in devices])
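Note: the new _console_vars_to_str helper formats the evaluated CONSOLE collection as key=value pairs and returns an empty string for an empty collection; both loops above now share it. For example (values invented for illustration):

ApplicationDriver._console_vars_to_str({'loss': 0.41, 'lr': 0.001})
# -> 'loss=0.41, lr=0.001'  (pair order follows the dict)
ApplicationDriver._console_vars_to_str({})
# -> ''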
@@ -330,12 +338,6 @@ def _set_cuda_device(cuda_devices):
# using Tensorflow default choice
pass

-    @staticmethod
-    def _model_moving_averaging_op(decay=0.9):
-        variable_averages = tf.train.ExponentialMovingAverage(decay)
-        trainables = tf.trainable_variables()
-        return variable_averages.apply(var_list=trainables)

@staticmethod
def _create_app(app_type_string):
return ApplicationFactory.create(app_type_string)
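Note: the deleted helper is what the commit title refers to: it kept exponential moving averages of every trainable variable, and its update op used to be appended to updates_op next to the batch-norm ops; after this commit only the batch-norm statistics are still smoothed. For reference, a hedged sketch of what such an op does in TF 1.x, outside the driver (variable names are illustrative):

import tensorflow as tf

w = tf.get_variable('w', initializer=0.0)
assign_w = tf.assign_add(w, 1.0)

# shadow copy: shadow = decay * shadow + (1 - decay) * w
ema = tf.train.ExponentialMovingAverage(decay=0.9)
ema_op = ema.apply([w])          # op that refreshes the shadow copies
w_avg = ema.average(w)           # reads the smoothed value

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        sess.run(assign_w)
        sess.run(ema_op)
    print(sess.run([w, w_avg]))  # the average lags behind w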
6 changes: 3 additions & 3 deletions niftynet/layer/bn.py
@@ -6,7 +6,7 @@

from niftynet.layer.base_layer import TrainableLayer

-BN_COLLECTION_NAME = tf.GraphKeys.UPDATE_OPS
+BN_COLLECTION = tf.GraphKeys.UPDATE_OPS


class BNLayer(TrainableLayer):
@@ -72,8 +72,8 @@ def layer_op(self, inputs, is_training, use_local_stats=False):
moving_mean, mean, self.moving_decay).op
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, self.moving_decay).op
-            tf.add_to_collection(BN_COLLECTION_NAME, update_moving_mean)
-            tf.add_to_collection(BN_COLLECTION_NAME, update_moving_variance)
+            tf.add_to_collection(BN_COLLECTION, update_moving_mean)
+            tf.add_to_collection(BN_COLLECTION, update_moving_variance)

# call the normalisation function
if is_training or use_local_stats:
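Note: assign_moving_average updates the stored statistic in place; with zero-debiasing disabled the rule is simply variable = decay * variable + (1 - decay) * value, and adding the resulting ops to BN_COLLECTION (an alias of tf.GraphKeys.UPDATE_OPS) is what lets application_driver pick them up in _create_graph. A minimal hedged sketch of one such update (BNLayer itself calls the function with its default arguments; zero_debias is switched off here only to keep the arithmetic obvious):

import tensorflow as tf
from tensorflow.python.training import moving_averages

moving_mean = tf.get_variable(
    'moving_mean', initializer=0.0, trainable=False)
batch_mean = tf.constant(10.0)

# moving_mean <- 0.9 * moving_mean + 0.1 * batch_mean
update_op = moving_averages.assign_moving_average(
    moving_mean, batch_mean, 0.9, zero_debias=False).op
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    print(sess.run(moving_mean))   # 1.0 after one update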
