Skip to content
This repository has been archived by the owner on Mar 17, 2021. It is now read-only.

Added GAN parameter for logging generated images to Tensorboard #477

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions niftynet/application/gan_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,11 @@ def switch_sampler(for_training):
stddev=1.0,
dtype=tf.float32)
conditioning = data_dict['conditioning']
net_output = self.net(
fake_image, real_logits, fake_logits = self.net(
noise, images, conditioning, self.is_training)

loss_func = LossFunction(
loss_type=self.action_param.loss_type)
real_logits = net_output[1]
fake_logits = net_output[2]
lossG, lossD = loss_func(real_logits, fake_logits)
if self.net_param.decay > 0:
reg_losses = tf.get_collection(
Expand Down Expand Up @@ -220,6 +218,23 @@ def switch_sampler(for_training):
outputs_collector.add_to_collection(
var=lossG, name='lossG', average_over_devices=False,
collection=TF_SUMMARIES)
# images to display in tensorboard
if self.gan_param.tensorboard_n_fake_images > 0:
outputs_collector.add_to_collection(
var=fake_image[:self.gan_param.tensorboard_n_fake_images],
name='fake_image_sagittal',
collection=TF_SUMMARIES, summary_type='image3_sagittal_n')

outputs_collector.add_to_collection(
var=fake_image[:self.gan_param.tensorboard_n_fake_images],
name='fake_image_coronal',
collection=TF_SUMMARIES, summary_type='image3_coronal_n')

outputs_collector.add_to_collection(
var=fake_image[:self.gan_param.tensorboard_n_fake_images],
name='fake_image_axial',
collection=TF_SUMMARIES, summary_type='image3_axial_n')


with tf.name_scope('Optimiser'):
optimiser_class = OptimiserFactory.create(
Expand Down
9 changes: 6 additions & 3 deletions niftynet/engine/application_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from tensorflow.contrib.framework import list_variables

from niftynet.io.misc_io import \
image3_axial, image3_coronal, image3_sagittal, resolve_checkpoint
image3_axial, image3_coronal, image3_sagittal, \
image3_axial_n, image3_coronal_n, image3_sagittal_n, resolve_checkpoint
from niftynet.utilities import util_common as util
from niftynet.utilities.restore_initializer import restore_initializer

Expand All @@ -22,8 +23,10 @@
'image': tf.summary.image,
'image3_sagittal': image3_sagittal,
'image3_coronal': image3_coronal,
'image3_axial': image3_axial}

'image3_axial': image3_axial,
'image3_sagittal_n': image3_sagittal_n,
'image3_coronal_n': image3_coronal_n,
'image3_axial_n': image3_axial_n}

class GradientsCollector(object):
"""
Expand Down
49 changes: 49 additions & 0 deletions niftynet/io/misc_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,55 @@ def image3_axial(name,
return image3(name, tensor, max_outputs, collections, [3], [1, 2])


def image3_sagittal_n(name,
tensor,
collections=(tf.GraphKeys.SUMMARIES, )):
"""
Create 2D image summary in the sagittal view. An image will be generated for
every element in axis 0 of the tensor.

:param name:
:param tensor:
:param max_outputs:
:param collections:
:return:
"""
return image3(name, tensor, tensor.shape.as_list()[0], collections, [1], [2, 3])


def image3_coronal_n(name,
tensor,
collections=(tf.GraphKeys.SUMMARIES, )):
"""
Create 2D image summary in the coronal view. An image will be generated for
every element in axis 0 of the tensor.

:param name:
:param tensor:
:param max_outputs:
:param collections:
:return:
"""
return image3(name, tensor, tensor.shape.as_list()[0], collections, [2], [1, 3])


def image3_axial_n(name,
tensor,
max_outputs=3,
collections=(tf.GraphKeys.SUMMARIES, )):
"""
Create 2D image summary in the axial view. An image will be generated for
every element in axis 0 of the tensor.

:param name:
:param tensor:
:param max_outputs:
:param collections:
:return:
"""
return image3(name, tensor, tensor.shape.as_list()[0], collections, [3], [1, 2])


def set_logger(file_name=None):
"""
Writing logs to a file if file_name,
Expand Down
6 changes: 6 additions & 0 deletions niftynet/utilities/user_parameters_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ def __add_gan_args(parser):
type=int,
default=10)

parser.add_argument(
"--tensorboard_n_fake_images",
help="the number of fake images to log to Tensorboard in every update",
type=int,
default=0)

from niftynet.application.gan_application import SUPPORTED_INPUT
parser = add_input_name_args(parser, SUPPORTED_INPUT)
return parser
Expand Down