From c77181b4154520dfe1d7550288e30a172fc91438 Mon Sep 17 00:00:00 2001
From: Carole Sudre
Date: Thu, 9 Aug 2018 11:10:46 +0100
Subject: [PATCH] Creation of a bias field augmentation layer

Usage is presented in the segmentation_bfaug application in contrib. The
bias field is modelled as a linear combination of polynomial basis
functions; the combination is exponentiated and applied multiplicatively
to the image being augmented.
---
 config/default_segmentation_bf.ini            |  77 ++++
 .../contrib/segmentation_bf_aug/__init__.py   |   0
 .../segmentation_application_bfaug.py         | 356 ++++++++++++++++++
 niftynet/layer/rand_bias_field.py             | 168 +++++++++
 niftynet/utilities/user_parameters_default.py |  15 +
 5 files changed, 616 insertions(+)
 create mode 100755 config/default_segmentation_bf.ini
 create mode 100644 niftynet/contrib/segmentation_bf_aug/__init__.py
 create mode 100755 niftynet/contrib/segmentation_bf_aug/segmentation_application_bfaug.py
 create mode 100644 niftynet/layer/rand_bias_field.py

diff --git a/config/default_segmentation_bf.ini b/config/default_segmentation_bf.ini
new file mode 100755
index 00000000..b46c2f96
--- /dev/null
+++ b/config/default_segmentation_bf.ini
@@ -0,0 +1,77 @@
+############################ input configuration sections
+[modality1]
+csv_file=
+path_to_search = ./example_volumes/monomodal_parcellation
+filename_contains = T1
+filename_not_contains =
+spatial_window_size = (20, 42, 42)
+interp_order = 3
+pixdim=(1.0, 1.0, 1.0)
+axcodes=(A, R, S)
+
+[label]
+path_to_search = ./example_volumes/monomodal_parcellation
+filename_contains = Label
+filename_not_contains =
+spatial_window_size = (20, 42, 42)
+interp_order = 0
+pixdim=(1.0, 1.0, 1.0)
+axcodes=(A, R, S)
+
+############################## system configuration sections
+[SYSTEM]
+cuda_devices = ""
+num_threads = 2
+num_gpus = 1
+model_dir = ./models/model_monomodal_toy
+
+[NETWORK]
+name = toynet
+activation_function = prelu
+batch_size = 1
+decay = 0.1
+reg_type = L2
+
+# volume level preprocessing
+volume_padding_size = 21
+# histogram normalisation
+histogram_ref_file = ./example_volumes/monomodal_parcellation/standardisation_models.txt
+norm_type = percentile
+cutoff = (0.01, 0.99)
+normalisation = False
+whitening = False
+normalise_foreground_only=True
+foreground_type = otsu_plus
+multimod_foreground_type = and
+
+queue_length = 20
+
+
+[TRAINING]
+sample_per_volume = 32
+rotation_angle =
+scaling_percentage =
+bf_order = 3
+bias_field_range = (-0.5, 0.5)
+random_flipping_axes= 1
+lr = 0.01
+loss_type = Dice
+starting_iter = 0
+save_every_n = 100
+max_iter = 10
+max_checkpoints = 20
+
+[INFERENCE]
+border = (0, 0, 1)
+#inference_iter = 10
+save_seg_dir = ./output/toy
+output_interp_order = 0
+spatial_window_size = (0, 0, 3)
+
+############################ custom configuration sections
+[SEGMENTATION]
+image = modality1
+label = label
+output_prob = False
+num_classes = 160
+label_normalisation = True
diff --git a/niftynet/contrib/segmentation_bf_aug/__init__.py b/niftynet/contrib/segmentation_bf_aug/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/niftynet/contrib/segmentation_bf_aug/segmentation_application_bfaug.py b/niftynet/contrib/segmentation_bf_aug/segmentation_application_bfaug.py
new file mode 100755
index 00000000..bf2076fd
--- /dev/null
+++ b/niftynet/contrib/segmentation_bf_aug/segmentation_application_bfaug.py
@@ -0,0 +1,356 @@
+import tensorflow as tf
+
+from niftynet.application.base_application import BaseApplication
+from niftynet.engine.application_factory import \
+    ApplicationNetFactory, InitializerFactory, 
OptimiserFactory +from niftynet.engine.application_variables import \ + CONSOLE, NETWORK_OUTPUT, TF_SUMMARIES +from niftynet.engine.sampler_grid import GridSampler +from niftynet.engine.sampler_resize import ResizeSampler +from niftynet.engine.sampler_uniform import UniformSampler +from niftynet.engine.sampler_weighted import WeightedSampler +from niftynet.engine.windows_aggregator_grid import GridSamplesAggregator +from niftynet.engine.windows_aggregator_resize import ResizeSamplesAggregator +from niftynet.io.image_reader import ImageReader +from niftynet.layer.binary_masking import BinaryMaskingLayer +from niftynet.layer.discrete_label_normalisation import \ + DiscreteLabelNormalisationLayer +from niftynet.layer.histogram_normalisation import \ + HistogramNormalisationLayer +from niftynet.layer.loss_segmentation import LossFunction +from niftynet.layer.mean_variance_normalisation import \ + MeanVarNormalisationLayer +from niftynet.layer.pad import PadLayer +from niftynet.layer.post_processing import PostProcessingLayer +from niftynet.layer.rand_flip import RandomFlipLayer +from niftynet.layer.rand_rotation import RandomRotationLayer +from niftynet.layer.rand_spatial_scaling import RandomSpatialScalingLayer +from niftynet.layer.rand_bias_field import RandomBiasFieldLayer + +SUPPORTED_INPUT = set(['image', 'label', 'weight', 'sampler']) + + +class SegmentationApplicationBFAug(BaseApplication): + REQUIRED_CONFIG_SECTION = "SEGMENTATION" + + def __init__(self, net_param, action_param, is_training): + super(SegmentationApplicationBFAug, self).__init__() + tf.logging.info('starting segmentation application') + self.is_training = is_training + + self.net_param = net_param + self.action_param = action_param + + self.data_param = None + self.segmentation_param = None + self.SUPPORTED_SAMPLING = { + 'uniform': (self.initialise_uniform_sampler, + self.initialise_grid_sampler, + self.initialise_grid_aggregator), + 'weighted': (self.initialise_weighted_sampler, + self.initialise_grid_sampler, + self.initialise_grid_aggregator), + 'resize': (self.initialise_resize_sampler, + self.initialise_resize_sampler, + self.initialise_resize_aggregator), + } + + def initialise_dataset_loader( + self, data_param=None, task_param=None, data_partitioner=None): + + self.data_param = data_param + self.segmentation_param = task_param + + # read each line of csv files into an instance of Subject + if self.is_training: + file_lists = [] + if self.action_param.validation_every_n > 0: + file_lists.append(data_partitioner.train_files) + file_lists.append(data_partitioner.validation_files) + else: + file_lists.append(data_partitioner.train_files) + + self.readers = [] + for file_list in file_lists: + reader = ImageReader(SUPPORTED_INPUT) + reader.initialise(data_param, task_param, file_list) + self.readers.append(reader) + + else: # in the inference process use image input only + inference_reader = ImageReader(['image']) + file_list = data_partitioner.inference_files + inference_reader.initialise(data_param, task_param, file_list) + self.readers = [inference_reader] + + foreground_masking_layer = None + if self.net_param.normalise_foreground_only: + foreground_masking_layer = BinaryMaskingLayer( + type_str=self.net_param.foreground_type, + multimod_fusion=self.net_param.multimod_foreground_type, + threshold=0.0) + + mean_var_normaliser = MeanVarNormalisationLayer( + image_name='image', binary_masking_func=foreground_masking_layer) + histogram_normaliser = None + if self.net_param.histogram_ref_file: + histogram_normaliser = 
HistogramNormalisationLayer( + image_name='image', + modalities=vars(task_param).get('image'), + model_filename=self.net_param.histogram_ref_file, + binary_masking_func=foreground_masking_layer, + norm_type=self.net_param.norm_type, + cutoff=self.net_param.cutoff, + name='hist_norm_layer') + + label_normaliser = None + if self.net_param.histogram_ref_file: + label_normaliser = DiscreteLabelNormalisationLayer( + image_name='label', + modalities=vars(task_param).get('label'), + model_filename=self.net_param.histogram_ref_file) + + normalisation_layers = [] + if self.net_param.normalisation: + normalisation_layers.append(histogram_normaliser) + if self.net_param.whitening: + normalisation_layers.append(mean_var_normaliser) + if task_param.label_normalisation and \ + (self.is_training or not task_param.output_prob): + normalisation_layers.append(label_normaliser) + + augmentation_layers = [] + if self.is_training: + if self.action_param.random_flipping_axes != -1: + augmentation_layers.append(RandomFlipLayer( + flip_axes=self.action_param.random_flipping_axes)) + if self.action_param.scaling_percentage: + augmentation_layers.append(RandomSpatialScalingLayer( + min_percentage=self.action_param.scaling_percentage[0], + max_percentage=self.action_param.scaling_percentage[1])) + if self.action_param.rotation_angle or \ + self.action_param.rotation_angle_x or \ + self.action_param.rotation_angle_y or \ + self.action_param.rotation_angle_z: + rotation_layer = RandomRotationLayer() + if self.action_param.rotation_angle: + rotation_layer.init_uniform_angle( + self.action_param.rotation_angle) + else: + rotation_layer.init_non_uniform_angle( + self.action_param.rotation_angle_x, + self.action_param.rotation_angle_y, + self.action_param.rotation_angle_z) + augmentation_layers.append(rotation_layer) + if self.action_param.bias_field_range: + bias_field_layer = RandomBiasFieldLayer() + bias_field_layer.init_order(self.action_param.bf_order) + bias_field_layer.init_uniform_coeff( + self.action_param.bias_field_range) + augmentation_layers.append(bias_field_layer) + + + volume_padding_layer = [] + if self.net_param.volume_padding_size: + volume_padding_layer.append(PadLayer( + image_name=SUPPORTED_INPUT, + border=self.net_param.volume_padding_size)) + + for reader in self.readers: + reader.add_preprocessing_layers( + volume_padding_layer + + normalisation_layers + + augmentation_layers) + + def initialise_uniform_sampler(self): + self.sampler = [[UniformSampler( + reader=reader, + data_param=self.data_param, + batch_size=self.net_param.batch_size, + windows_per_image=self.action_param.sample_per_volume, + queue_length=self.net_param.queue_length) for reader in + self.readers]] + + def initialise_weighted_sampler(self): + self.sampler = [[WeightedSampler( + reader=reader, + data_param=self.data_param, + batch_size=self.net_param.batch_size, + windows_per_image=self.action_param.sample_per_volume, + queue_length=self.net_param.queue_length) for reader in + self.readers]] + + def initialise_resize_sampler(self): + self.sampler = [[ResizeSampler( + reader=reader, + data_param=self.data_param, + batch_size=self.net_param.batch_size, + shuffle_buffer=self.is_training, + queue_length=self.net_param.queue_length) for reader in + self.readers]] + + def initialise_grid_sampler(self): + self.sampler = [[GridSampler( + reader=reader, + data_param=self.data_param, + batch_size=self.net_param.batch_size, + spatial_window_size=self.action_param.spatial_window_size, + window_border=self.action_param.border, + 
queue_length=self.net_param.queue_length) for reader in + self.readers]] + + def initialise_grid_aggregator(self): + self.output_decoder = GridSamplesAggregator( + image_reader=self.readers[0], + output_path=self.action_param.save_seg_dir, + window_border=self.action_param.border, + interp_order=self.action_param.output_interp_order) + + def initialise_resize_aggregator(self): + self.output_decoder = ResizeSamplesAggregator( + image_reader=self.readers[0], + output_path=self.action_param.save_seg_dir, + window_border=self.action_param.border, + interp_order=self.action_param.output_interp_order) + + def initialise_sampler(self): + if self.is_training: + self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]() + else: + self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]() + + def initialise_network(self): + w_regularizer = None + b_regularizer = None + reg_type = self.net_param.reg_type.lower() + decay = self.net_param.decay + if reg_type == 'l2' and decay > 0: + from tensorflow.contrib.layers.python.layers import regularizers + w_regularizer = regularizers.l2_regularizer(decay) + b_regularizer = regularizers.l2_regularizer(decay) + elif reg_type == 'l1' and decay > 0: + from tensorflow.contrib.layers.python.layers import regularizers + w_regularizer = regularizers.l1_regularizer(decay) + b_regularizer = regularizers.l1_regularizer(decay) + + self.net = ApplicationNetFactory.create(self.net_param.name)( + num_classes=self.segmentation_param.num_classes, + w_initializer=InitializerFactory.get_initializer( + name=self.net_param.weight_initializer), + b_initializer=InitializerFactory.get_initializer( + name=self.net_param.bias_initializer), + w_regularizer=w_regularizer, + b_regularizer=b_regularizer, + acti_func=self.net_param.activation_function) + + def connect_data_and_network(self, + outputs_collector=None, + gradients_collector=None): + # def data_net(for_training): + # with tf.name_scope('train' if for_training else 'validation'): + # sampler = self.get_sampler()[0][0 if for_training else -1] + # data_dict = sampler.pop_batch_op() + # image = tf.cast(data_dict['image'], tf.float32) + # return data_dict, self.net(image, is_training=for_training) + + def switch_sampler(for_training): + with tf.name_scope('train' if for_training else 'validation'): + sampler = self.get_sampler()[0][0 if for_training else -1] + return sampler.pop_batch_op() + + if self.is_training: + # if self.action_param.validation_every_n > 0: + # data_dict, net_out = tf.cond(tf.logical_not(self.is_validation), + # lambda: data_net(True), + # lambda: data_net(False)) + # else: + # data_dict, net_out = data_net(True) + if self.action_param.validation_every_n > 0: + data_dict = tf.cond(tf.logical_not(self.is_validation), + lambda: switch_sampler(for_training=True), + lambda: switch_sampler(for_training=False)) + else: + data_dict = switch_sampler(for_training=True) + image = tf.cast(data_dict['image'], tf.float32) + net_out = self.net(image, is_training=self.is_training) + + with tf.name_scope('Optimiser'): + optimiser_class = OptimiserFactory.create( + name=self.action_param.optimiser) + self.optimiser = optimiser_class.get_instance( + learning_rate=self.action_param.lr) + loss_func = LossFunction( + n_class=self.segmentation_param.num_classes, + loss_type=self.action_param.loss_type) + data_loss = loss_func( + prediction=net_out, + ground_truth=data_dict.get('label', None), + weight_map=data_dict.get('weight', None)) + reg_losses = tf.get_collection( + tf.GraphKeys.REGULARIZATION_LOSSES) + if 
self.net_param.decay > 0.0 and reg_losses: + reg_loss = tf.reduce_mean( + [tf.reduce_mean(reg_loss) for reg_loss in reg_losses]) + loss = data_loss + reg_loss + else: + loss = data_loss + grads = self.optimiser.compute_gradients(loss) + # collecting gradients variables + gradients_collector.add_to_collection([grads]) + # collecting output variables + outputs_collector.add_to_collection( + var=data_loss, name='loss', + average_over_devices=False, collection=CONSOLE) + outputs_collector.add_to_collection( + var=data_loss, name='loss', + average_over_devices=True, summary_type='scalar', + collection=TF_SUMMARIES) + + # outputs_collector.add_to_collection( + # var=image*180.0, name='image', + # average_over_devices=False, summary_type='image3_sagittal', + # collection=TF_SUMMARIES) + + # outputs_collector.add_to_collection( + # var=image, name='image', + # average_over_devices=False, + # collection=NETWORK_OUTPUT) + + # outputs_collector.add_to_collection( + # var=tf.reduce_mean(image), name='mean_image', + # average_over_devices=False, summary_type='scalar', + # collection=CONSOLE) + else: + # converting logits into final output for + # classification probabilities or argmax classification labels + data_dict = switch_sampler(for_training=False) + image = tf.cast(data_dict['image'], tf.float32) + net_out = self.net(image, is_training=self.is_training) + + output_prob = self.segmentation_param.output_prob + num_classes = self.segmentation_param.num_classes + if output_prob and num_classes > 1: + post_process_layer = PostProcessingLayer( + 'SOFTMAX', num_classes=num_classes) + elif not output_prob and num_classes > 1: + post_process_layer = PostProcessingLayer( + 'ARGMAX', num_classes=num_classes) + else: + post_process_layer = PostProcessingLayer( + 'IDENTITY', num_classes=num_classes) + net_out = post_process_layer(net_out) + + outputs_collector.add_to_collection( + var=net_out, name='window', + average_over_devices=False, collection=NETWORK_OUTPUT) + outputs_collector.add_to_collection( + var=data_dict['image_location'], name='location', + average_over_devices=False, collection=NETWORK_OUTPUT) + init_aggregator = \ + self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2] + init_aggregator() + + def interpret_output(self, batch_output): + if not self.is_training: + return self.output_decoder.decode_batch( + batch_output['window'], batch_output['location']) + return True diff --git a/niftynet/layer/rand_bias_field.py b/niftynet/layer/rand_bias_field.py new file mode 100644 index 00000000..3fb51888 --- /dev/null +++ b/niftynet/layer/rand_bias_field.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function + +import numpy as np +import scipy.ndimage +import nibabel as nib + +from niftynet.layer.base_layer import RandomisedLayer + + +class RandomBiasFieldLayer(RandomisedLayer): + """ + generate randomised bias field transformation for data augmentation + """ + + def __init__(self, name='random_bias_field'): + super(RandomBiasFieldLayer, self).__init__(name=name) + self._bf_coeffs = None + self.min_coeff = None + self.max_coeff = None + self.order = None + + def init_uniform_coeff(self, coeff_range=(-10.0, 10.0)): + assert coeff_range[0] < coeff_range[1] + self.min_coeff = float(coeff_range[0]) + self.max_coeff = float(coeff_range[1]) + + def init_order(self, order=3): + self.order = int(order) + + def randomise(self, spatial_rank=3): + self._generate_bias_field_coeffs(spatial_rank) + + def _generate_bias_field_coeffs(self, spatial_rank): + ''' + 
+        Sample the appropriate number of random polynomial coefficients
+        used to build the bias field map.
+        :param spatial_rank: spatial rank of the image to modify
+        :return:
+        '''
+        rand_coeffs = []
+        if spatial_rank == 3:
+            for order_x in range(0, self.order + 1):
+                for order_y in range(0, self.order + 1 - order_x):
+                    for order_z in range(
+                            0, self.order + 1 - (order_x + order_y)):
+                        rand_coeff_new = np.random.uniform(self.min_coeff,
+                                                           self.max_coeff)
+                        rand_coeffs.append(rand_coeff_new)
+        else:
+            for order_x in range(0, self.order + 1):
+                for order_y in range(0, self.order + 1 - order_x):
+                    rand_coeff_new = np.random.uniform(self.min_coeff,
+                                                       self.max_coeff)
+                    rand_coeffs.append(rand_coeff_new)
+        self._bf_coeffs = rand_coeffs
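+
+    # Worked example: for polynomial order N in 3-D there are
+    # (N + 1)(N + 2)(N + 3) / 6 basis terms x^i * y^j * z^k with
+    # i + j + k <= N, so the default order of 3 draws 20 coefficients
+    # (10 in 2-D); the resulting field is exp(sum_i coeff_i * basis_i),
+    # which is identically 1 when every coefficient is 0.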
+
+    def _generate_bias_field_map(self, shape):
+        '''
+        Create the bias field map as a linear combination of polynomial
+        basis functions weighted by the previously sampled coefficients;
+        the combination is then exponentiated.
+        :param shape: shape of the image, used to build the polynomial
+            basis functions
+        :return: bias field map to apply
+        '''
+        spatial_rank = len(shape)
+        x_range = np.arange(-shape[0] / 2, shape[0] / 2)
+        y_range = np.arange(-shape[1] / 2, shape[1] / 2)
+        bf_map = np.zeros(shape)
+        i = 0
+        if spatial_rank == 3:
+            z_range = np.arange(-shape[2] / 2, shape[2] / 2)
+            x_mesh, y_mesh, z_mesh = np.asarray(
+                np.meshgrid(x_range, y_range, z_range), dtype=float)
+            x_mesh /= float(np.max(x_mesh))
+            y_mesh /= float(np.max(y_mesh))
+            z_mesh /= float(np.max(z_mesh))
+            for order_x in range(0, self.order + 1):
+                for order_y in range(0, self.order + 1 - order_x):
+                    for order_z in range(
+                            0, self.order + 1 - (order_x + order_y)):
+                        rand_coeff = self._bf_coeffs[i]
+                        new_map = rand_coeff * np.power(x_mesh, order_x) * \
+                            np.power(y_mesh, order_y) * \
+                            np.power(z_mesh, order_z)
+                        bf_map += np.transpose(new_map, (1, 0, 2))
+                        i += 1
+        if spatial_rank == 2:
+            x_mesh, y_mesh = np.asarray(
+                np.meshgrid(x_range, y_range), dtype=float)
+            x_mesh /= np.max(x_mesh)
+            y_mesh /= np.max(y_mesh)
+            for order_x in range(0, self.order + 1):
+                for order_y in range(0, self.order + 1 - order_x):
+                    rand_coeff = self._bf_coeffs[i]
+                    new_map = rand_coeff * np.power(x_mesh, order_x) * \
+                        np.power(y_mesh, order_y)
+                    bf_map += np.transpose(new_map, (1, 0))
+                    i += 1
+        return np.exp(bf_map)
+
+    def _apply_transformation(self, image):
+        '''
+        Create the bias field map from the randomly sampled coefficients
+        and apply it multiplicatively to the image to augment.
+        :param image: image on which to apply the bias field augmentation
+        :return: modified image
+        '''
+        assert self._bf_coeffs is not None
+        bf_map = self._generate_bias_field_map(image.shape)
+        return image * bf_map
+
+    def layer_op(self, inputs, interp_orders, *args, **kwargs):
+        if inputs is None:
+            return inputs
+        for (field, image) in inputs.items():
+            if field == 'image':
+                for mod_i in range(image.shape[-1]):
+                    if image.ndim == 4:
+                        inputs[field][..., mod_i] = \
+                            self._apply_transformation(
+                                image[..., mod_i])
+                    elif image.ndim == 5:
+                        for t in range(image.shape[-2]):
+                            inputs[field][..., t, mod_i] = \
+                                self._apply_transformation(
+                                    image[..., t, mod_i])
+                    else:
+                        raise NotImplementedError("unknown input format")
+        return inputs
diff --git a/niftynet/utilities/user_parameters_default.py b/niftynet/utilities/user_parameters_default.py
index 7b07f685..e90b3db1 100755
--- a/niftynet/utilities/user_parameters_default.py
+++ b/niftynet/utilities/user_parameters_default.py
@@ -484,6 +484,21 @@ def add_training_args(parser):
         type=float_array,
         default=())
 
+    parser.add_argument(
+        "--bias_field_range",
+        help="[Training only] The [min_coeff, max_coeff] range of the "
+             "random polynomial coefficients of the bias field",
+        type=float_array,
+        default=())
+
+    parser.add_argument(
+        "--bf_order",
+        help="[Training only] The maximum polynomial order used to model "
+             "the bias field",
+        metavar='',
+        type=int,
+        default=3)
+
     parser.add_argument(
         "--random_flipping_axes",
         help="The axes which can be flipped to augment the data. Supply as "
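
Note: the snippet below is an illustrative usage sketch, not part of the patch. It assumes a NiftyNet environment in which niftynet.layer.rand_bias_field is importable; the order and coefficient range mirror the bf_order and bias_field_range values of the example configuration, and interp_orders is passed as None only because this layer does not use it.

    import numpy as np

    from niftynet.layer.rand_bias_field import RandomBiasFieldLayer

    # configure the layer the way segmentation_application_bfaug.py does
    rand_bf = RandomBiasFieldLayer()
    rand_bf.init_order(3)                    # bf_order
    rand_bf.init_uniform_coeff((-0.5, 0.5))  # bias_field_range

    # draw a fresh set of polynomial coefficients for this volume
    rand_bf.randomise(spatial_rank=3)

    # 3-D volume with a trailing modality dimension (the ndim == 4 branch)
    data = {'image': np.random.rand(20, 42, 42, 1).astype(np.float32)}
    augmented = rand_bf.layer_op(data, interp_orders=None)
    print(augmented['image'].shape)  # (20, 42, 42, 1)

With all coefficients at zero the exponentiated field is identically one and the volume is returned unchanged; widening bias_field_range therefore strengthens the smooth multiplicative shading applied to the image.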