diff --git a/config/test/affine.yaml b/config/test/affine.yaml new file mode 100644 index 000000000..4077a9f05 --- /dev/null +++ b/config/test/affine.yaml @@ -0,0 +1,32 @@ +train: + method: "ddf" # ddf / dvf / conditional + backbone: + name: "global" + num_channel_initial: 1 + extract_levels: [0, 1, 2, 3, 4] + loss: + image: + name: "lncc" + weight: 0.1 + label: + weight: 1.0 + name: "dice" + scales: [0, 1, 2, 4, 8, 16, 32] + regularization: + weight: 0.5 + name: "bending" + preprocess: + batch_size: 2 + shuffle_buffer_num_batch: 1 + optimizer: + name: "adam" + adam: + learning_rate: 1.0e-5 + sgd: + learning_rate: 1.0e-4 + momentum: 0.9 + rms: + learning_rate: 1.0e-4 + momentum: 0.9 + epochs: 2 + save_period: 2 diff --git a/config/unpaired_labeled_ddf.yaml b/config/unpaired_labeled_ddf.yaml index 58ef6163d..f33f59b9b 100644 --- a/config/unpaired_labeled_ddf.yaml +++ b/config/unpaired_labeled_ddf.yaml @@ -10,7 +10,7 @@ dataset: train: # define neural network structure - method: "ddf" # options include "ddf", "dvf", "conditional" and "affine" + method: "ddf" # options include "ddf", "dvf", "conditional" backbone: name: "local" # options include "local", "unet" and "global" - use "global" when method=="affine" num_channel_initial: 1 # number of initial channel in local net, controls the size of the network diff --git a/deepreg/dataset/loader/grouped_loader.py b/deepreg/dataset/loader/grouped_loader.py index 215d151b8..771a76aca 100644 --- a/deepreg/dataset/loader/grouped_loader.py +++ b/deepreg/dataset/loader/grouped_loader.py @@ -5,6 +5,7 @@ Read https://deepreg.readthedocs.io/en/latest/api/loader.html#module-deepreg.dataset.loader.grouped_loader for more details. """ import random +from copy import deepcopy from typing import List from deepreg.dataset.loader.interface import ( @@ -264,7 +265,7 @@ def sample_index_generator(self): else: # sample indices are pre-calculated assert self.sample_indices is not None - sample_indices = self.sample_indices.copy() + sample_indices = deepcopy(self.sample_indices) rnd.shuffle(sample_indices) # shuffle in place for sample_index in sample_indices: group_index1, image_index1, group_index2, image_index2 = sample_index diff --git a/deepreg/model/__init__.py b/deepreg/model/__init__.py index a65e8cd89..6bfafaff0 100644 --- a/deepreg/model/__init__.py +++ b/deepreg/model/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa import deepreg.model.backbone import deepreg.model.loss +import deepreg.model.network diff --git a/deepreg/model/backbone/global_net.py b/deepreg/model/backbone/global_net.py index e7f3109ab..06c54729b 100644 --- a/deepreg/model/backbone/global_net.py +++ b/deepreg/model/backbone/global_net.py @@ -70,7 +70,9 @@ def __init__( units=12, bias_initializer=self.transform_initial ) - def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor: + def call( + self, inputs: tf.Tensor, training=None, mask=None + ) -> (tf.Tensor, tf.Tensor): """ Build GlobalNet graph based on built layers. @@ -88,10 +90,10 @@ def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor: ) # level E of encoding # predict affine parameters theta of shape = [batch, 4, 3] - self.theta = self._dense_layer(h_out) - self.theta = tf.reshape(self.theta, shape=(-1, 4, 3)) + theta = self._dense_layer(h_out) + theta = tf.reshape(theta, shape=(-1, 4, 3)) # warp the reference grid with affine parameters to output a ddf - grid_warped = layer_util.warp_grid(self.reference_grid, self.theta) - output = grid_warped - self.reference_grid - return output + grid_warped = layer_util.warp_grid(self.reference_grid, theta) + ddf = grid_warped - self.reference_grid + return ddf, theta diff --git a/deepreg/model/loss/image.py b/deepreg/model/loss/image.py index 8aebabd83..1daa00f2b 100644 --- a/deepreg/model/loss/image.py +++ b/deepreg/model/loss/image.py @@ -18,13 +18,13 @@ class SumSquaredDifference(tf.keras.losses.Loss): def __init__( self, - reduction: str = tf.keras.losses.Reduction.AUTO, + reduction: str = tf.keras.losses.Reduction.SUM, name: str = "SumSquaredDifference", ): """ Init. - :param reduction: using AUTO reduction, + :param reduction: using SUM reduction over batch axis, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: name of the loss """ @@ -57,7 +57,7 @@ def __init__( self, num_bins: int = 23, sigma_ratio: float = 0.5, - reduction: str = tf.keras.losses.Reduction.AUTO, + reduction: str = tf.keras.losses.Reduction.SUM, name: str = "GlobalMutualInformation", ): """ @@ -65,7 +65,7 @@ def __init__( :param num_bins: number of bins for intensity, the default value is empirical. :param sigma_ratio: a hyper param for gaussian function - :param reduction: using AUTO reduction, + :param reduction: using SUM reduction over batch axis, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: name of the loss """ @@ -248,7 +248,7 @@ def __init__( self, kernel_size: int = 9, kernel_type: str = "rectangular", - reduction: str = tf.keras.losses.Reduction.AUTO, + reduction: str = tf.keras.losses.Reduction.SUM, name: str = "LocalNormalizedCrossCorrelation", ): """ @@ -256,7 +256,7 @@ def __init__( :param kernel_size: int. Kernel size or kernel sigma for kernel_type='gauss'. :param kernel_type: str, rectangular, triangular or gaussian - :param reduction: using AUTO reduction, + :param reduction: using SUM reduction over batch axis, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: name of the loss """ diff --git a/deepreg/model/loss/label.py b/deepreg/model/loss/label.py index 945438b1c..1e8756643 100644 --- a/deepreg/model/loss/label.py +++ b/deepreg/model/loss/label.py @@ -53,7 +53,7 @@ def __init__( self, scales: Optional[List] = None, kernel: str = "gaussian", - reduction: str = tf.keras.losses.Reduction.AUTO, + reduction: str = tf.keras.losses.Reduction.SUM, name: str = "MultiScaleLoss", ): """ @@ -61,7 +61,7 @@ def __init__( :param scales: list of scalars or None, if None, do not apply any scaling. :param kernel: gaussian or cauchy. - :param reduction: using AUTO reduction, + :param reduction: using SUM reduction over batch axis, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: str, name of the loss. """ @@ -144,7 +144,7 @@ def __init__( neg_weight: float = 0.0, scales: Optional[List] = None, kernel: str = "gaussian", - reduction: str = tf.keras.losses.Reduction.AUTO, + reduction: str = tf.keras.losses.Reduction.SUM, name: str = "DiceScore", ): """ @@ -154,7 +154,7 @@ def __init__( :param neg_weight: weight for negative class. :param scales: list of scalars or None, if None, do not apply any scaling. :param kernel: gaussian or cauchy. - :param reduction: using AUTO reduction, + :param reduction: using SUM reduction over batch axis, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: str, name of the loss. """ @@ -214,7 +214,7 @@ def __init__( neg_weight: float = 0.0, scales: Optional[List] = None, kernel: str = "gaussian", - reduction: str = tf.keras.losses.Reduction.AUTO, + reduction: str = tf.keras.losses.Reduction.SUM, name: str = "CrossEntropy", ): """ @@ -224,7 +224,7 @@ def __init__( :param neg_weight: weight for negative class :param scales: list of scalars or None, if None, do not apply any scaling. :param kernel: gaussian or cauchy. - :param reduction: using AUTO reduction, + :param reduction: using SUM reduction over batch axis, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: str, name of the loss. """ @@ -276,7 +276,7 @@ def __init__( binary: bool = False, scales: Optional[List] = None, kernel: str = "gaussian", - reduction: str = tf.keras.losses.Reduction.AUTO, + reduction: str = tf.keras.losses.Reduction.SUM, name: str = "JaccardIndex", ): """ @@ -285,7 +285,7 @@ def __init__( :param binary: if True, project y_true, y_pred to 0 or 1. :param scales: list of scalars or None, if None, do not apply any scaling. :param kernel: gaussian or cauchy. - :param reduction: using AUTO reduction, + :param reduction: using SUM reduction over batch axis, calling the loss like `loss(y_true, y_pred)` will return a scalar tensor. :param name: str, name of the loss. """ diff --git a/deepreg/model/network.py b/deepreg/model/network.py new file mode 100644 index 000000000..76b9dbafa --- /dev/null +++ b/deepreg/model/network.py @@ -0,0 +1,507 @@ +import logging +from abc import abstractmethod +from copy import deepcopy +from typing import Dict, Optional + +import tensorflow as tf + +from deepreg.model import layer, layer_util +from deepreg.model.backbone import GlobalNet +from deepreg.registry import REGISTRY + + +def dict_without(d: dict, key) -> dict: + """ + Return a copy of the given dict without a certain key. + + :param d: dict to be copied. + :param key: key to be removed. + :return: the copy without a key + """ + copied = deepcopy(d) + copied.pop(key) + return copied + + +class RegistrationModel(tf.keras.Model): + """Interface for registration model.""" + + def __init__( + self, + moving_image_size: tuple, + fixed_image_size: tuple, + index_size: int, + labeled: bool, + batch_size: int, + config: dict, + num_devices: int = 1, + name: str = "RegistrationModel", + ): + """ + Init. + + :param moving_image_size: (m_dim1, m_dim2, m_dim3) + :param fixed_image_size: (f_dim1, f_dim2, f_dim3) + :param index_size: number of indices for identify each sample + :param labeled: if the data is labeled + :param batch_size: size of mini-batch + :param config: config for method, backbone, and loss. + :param num_devices: number of GPU used, + global_batch_size = batch_size*num_devices + :param name: name of the model + """ + super().__init__(name=name) + self.moving_image_size = moving_image_size + self.fixed_image_size = fixed_image_size + self.index_size = index_size + self.labeled = labeled + self.batch_size = batch_size + self.config = config + self.num_devices = num_devices + self.global_batch_size = num_devices * batch_size + + self._inputs = None # save inputs of self._model as dict + self._outputs = None # save outputs of self._model as dict + self._model = self.build_model() + self.build_loss() + + def get_config(self): + """Return the config dictionary for recreating this class.""" + return dict( + moving_image_size=self.moving_image_size, + fixed_image_size=self.fixed_image_size, + index_size=self.index_size, + labeled=self.labeled, + batch_size=self.batch_size, + config=self.config, + num_devices=self.num_devices, + name=self.name, + ) + + @abstractmethod + def build_model(self): + """Build the model to be saved as self._model.""" + + def build_inputs(self) -> Dict[str, tf.keras.layers.Input]: + """ + Build input tensors. + + :return: dict of inputs. + """ + # (batch, m_dim1, m_dim2, m_dim3, 1) + moving_image = tf.keras.Input( + shape=self.moving_image_size, + batch_size=self.batch_size, + name="moving_image", + ) + # (batch, f_dim1, f_dim2, f_dim3, 1) + fixed_image = tf.keras.Input( + shape=self.fixed_image_size, + batch_size=self.batch_size, + name="fixed_image", + ) + # (batch, index_size) + indices = tf.keras.Input( + shape=(self.index_size,), + batch_size=self.batch_size, + name="indices", + ) + + if not self.labeled: + return dict( + moving_image=moving_image, fixed_image=fixed_image, indices=indices + ) + + # (batch, m_dim1, m_dim2, m_dim3, 1) + moving_label = tf.keras.Input( + shape=self.moving_image_size, + batch_size=self.batch_size, + name="moving_label", + ) + # (batch, m_dim1, m_dim2, m_dim3, 1) + fixed_label = tf.keras.Input( + shape=self.fixed_image_size, + batch_size=self.batch_size, + name="fixed_label", + ) + return dict( + moving_image=moving_image, + fixed_image=fixed_image, + moving_label=moving_label, + fixed_label=fixed_label, + indices=indices, + ) + + def concat_images( + self, + moving_image: tf.Tensor, + fixed_image: tf.Tensor, + moving_label: Optional[tf.Tensor] = None, + ) -> tf.Tensor: + """ + Adjust image shape and concatenate them together. + + :param moving_image: registration source + :param fixed_image: registration target + :param moving_label: optional, only used for conditional model. + :return: + """ + images = [] + + # (batch, m_dim1, m_dim2, m_dim3, 1) + moving_image = tf.expand_dims(moving_image, axis=4) + moving_image = layer_util.resize3d( + image=moving_image, size=self.fixed_image_size + ) + images.append(moving_image) + + # (batch, m_dim1, m_dim2, m_dim3, 1) + fixed_image = tf.expand_dims(fixed_image, axis=4) + images.append(fixed_image) + + # (batch, m_dim1, m_dim2, m_dim3, 1) + if moving_label is not None: + moving_label = tf.expand_dims(moving_label, axis=4) + moving_label = layer_util.resize3d( + image=moving_label, size=self.fixed_image_size + ) + images.append(moving_label) + + # (batch, f_dim1, f_dim2, f_dim3, 2 or 3) + images = tf.concat(images, axis=4) + return images + + def _build_loss(self, name: str, inputs_dict: dict): + """ + Build and add one weighted loss together with the metrics. + + :param name: name of loss + :param inputs_dict: inputs for loss function + """ + if name not in self.config["loss"]: + # loss config is not defined + logging.warning( + f"The configuration for loss {name} is not defined." + f"Loss is not used." + ) + return + + loss_config = self.config["loss"][name] + + if "weight" not in loss_config: + # default loss weight 1 + logging.warning( + f"The weight for loss {name} is not defined." + f"Default weight = 1.0 is used." + ) + loss_config["weight"] = 1.0 + + # build loss + weight = loss_config["weight"] + + if weight == 0: + logging.warning(f"The weight for loss {name} is zero." f"Loss is not used.") + return + + loss_cls = REGISTRY.build_loss(config=dict_without(d=loss_config, key="weight")) + loss = loss_cls(**inputs_dict) / self.global_batch_size + weighted_loss = loss * weight + + # add loss + self._model.add_loss(weighted_loss) + + # add metric + self._model.add_metric( + loss, name=f"loss/{name}_{loss_cls.name}", aggregation="mean" + ) + self._model.add_metric( + weighted_loss, + name=f"loss/{name}_{loss_cls.name}_weighted", + aggregation="mean", + ) + + @abstractmethod + def build_loss(self): + """Build losses according to configs.""" + + def call( + self, inputs: Dict[str, tf.Tensor], training=None, mask=None + ) -> Dict[str, tf.Tensor]: + """ + Call the self._model. + + :param inputs: a dict of tensors. + :param training: training or not. + :param mask: maks for inputs. + :return: + """ + return self._model(inputs, training=training, mask=mask) # pragma: no cover + + @abstractmethod + def postprocess( + self, + inputs: Dict[str, tf.Tensor], + outputs: Dict[str, tf.Tensor], + ) -> (tf.Tensor, Dict): + """ + Return a dict used for saving inputs and outputs. + + :param inputs: dict of model inputs + :param outputs: dict of model outputs + :return: tuple, indices and a dict. + In the dict, each value is (tensor, normalize, on_label), where + - normalize = True if the tensor need to be normalized to [0, 1] + - on_label = True if the tensor depends on label + """ + + +@REGISTRY.register_model(name="ddf") +class DDFModel(RegistrationModel): + """ + A registration model predicts DDF. + + When using global net as backbone, + the model predicts an affine transformation parameters, + and a DDF is calculated based on that. + """ + + def build_model(self): + """Build the model to be saved as self._model.""" + # build inputs + self._inputs = self.build_inputs() + moving_image = self._inputs["moving_image"] + fixed_image = self._inputs["fixed_image"] + + # build ddf + backbone_inputs = self.concat_images(moving_image, fixed_image) + backbone = REGISTRY.build_backbone( + config=self.config["backbone"], + default_args=dict( + image_size=self.fixed_image_size, + out_channels=3, + out_kernel_initializer="zeros", + out_activation=None, + ), + ) + + if isinstance(backbone, GlobalNet): + # (f_dim1, f_dim2, f_dim3, 3), (4, 3) + ddf, theta = backbone(inputs=backbone_inputs) + self._outputs = dict(ddf=ddf, theta=theta) + else: + # (f_dim1, f_dim2, f_dim3, 3) + ddf = backbone(inputs=backbone_inputs) + self._outputs = dict(ddf=ddf) + + # build outputs + warping = layer.Warping(fixed_image_size=self.fixed_image_size) + # (f_dim1, f_dim2, f_dim3, 3) + pred_fixed_image = warping(inputs=[ddf, moving_image]) + self._outputs["pred_fixed_image"] = pred_fixed_image + + if not self.labeled: + return tf.keras.Model(inputs=self._inputs, outputs=self._outputs) + + # (f_dim1, f_dim2, f_dim3, 3) + moving_label = self._inputs["moving_label"] + pred_fixed_label = warping(inputs=[ddf, moving_label]) + + self._outputs["pred_fixed_label"] = pred_fixed_label + return tf.keras.Model(inputs=self._inputs, outputs=self._outputs) + + def build_loss(self): + """Build losses according to configs.""" + fixed_image = self._inputs["fixed_image"] + ddf = self._outputs["ddf"] + pred_fixed_image = self._outputs["pred_fixed_image"] + + # ddf + self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf)) + + # image + self._build_loss( + name="image", inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image) + ) + + # label + if self.labeled: + fixed_label = self._inputs["fixed_label"] + pred_fixed_label = self._outputs["pred_fixed_label"] + self._build_loss( + name="label", + inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label), + ) + + def postprocess( + self, + inputs: Dict[str, tf.Tensor], + outputs: Dict[str, tf.Tensor], + ) -> (tf.Tensor, Dict): + """ + Return a dict used for saving inputs and outputs. + + :param inputs: dict of model inputs + :param outputs: dict of model outputs + :return: tuple, indices and a dict. + In the dict, each value is (tensor, normalize, on_label), where + - normalize = True if the tensor need to be normalized to [0, 1] + - on_label = True if the tensor depends on label + """ + indices = inputs["indices"] + processed = dict( + moving_image=(inputs["moving_image"], True, False), + fixed_image=(inputs["fixed_image"], True, False), + ddf=(outputs["ddf"], True, False), + pred_fixed_image=(outputs["pred_fixed_image"], True, False), + ) + + # save theta for affine model + if "theta" in outputs: + processed["theta"] = (outputs["theta"], None, None) + + if not self.labeled: + return indices, processed + + processed = { + **dict( + moving_label=(inputs["moving_label"], False, True), + fixed_label=(inputs["fixed_label"], False, True), + pred_fixed_label=(outputs["pred_fixed_label"], False, True), + ), + **processed, + } + + return indices, processed + + +@REGISTRY.register_model(name="dvf") +class DVFModel(DDFModel): + """ + A registration model predicts DVF. + + DDF is calculated based on DVF. + """ + + def build_model(self): + """Build the model to be saved as self._model.""" + # build inputs + self._inputs = self.build_inputs() + moving_image = self._inputs["moving_image"] + fixed_image = self._inputs["fixed_image"] + + # build ddf + backbone_inputs = self.concat_images(moving_image, fixed_image) + backbone = REGISTRY.build_backbone( + config=self.config["backbone"], + default_args=dict( + image_size=self.fixed_image_size, + out_channels=3, + out_kernel_initializer="zeros", + out_activation=None, + ), + ) + dvf = backbone(inputs=backbone_inputs) + ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf) + + # build outputs + warping = layer.Warping(fixed_image_size=self.fixed_image_size) + # (f_dim1, f_dim2, f_dim3, 3) + pred_fixed_image = warping(inputs=[ddf, moving_image]) + + self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image) + + if not self.labeled: + return tf.keras.Model(inputs=self._inputs, outputs=self._outputs) + + # (f_dim1, f_dim2, f_dim3, 3) + moving_label = self._inputs["moving_label"] + pred_fixed_label = warping(inputs=[ddf, moving_label]) + + self._outputs["pred_fixed_label"] = pred_fixed_label + return tf.keras.Model(inputs=self._inputs, outputs=self._outputs) + + def postprocess( + self, + inputs: Dict[str, tf.Tensor], + outputs: Dict[str, tf.Tensor], + ) -> (tf.Tensor, Dict): + """ + Return a dict used for saving inputs and outputs. + + :param inputs: dict of model inputs + :param outputs: dict of model outputs + :return: tuple, indices and a dict. + In the dict, each value is (tensor, normalize, on_label), where + - normalize = True if the tensor need to be normalized to [0, 1] + - on_label = True if the tensor depends on label + """ + indices, processed = super().postprocess(inputs=inputs, outputs=outputs) + processed["dvf"] = (outputs["dvf"], True, False) + return indices, processed + + +@REGISTRY.register_model(name="conditional") +class ConditionalModel(RegistrationModel): + def build_model(self): + """Build the model to be saved as self._model.""" + assert self.labeled + + # build inputs + self._inputs = self.build_inputs() + moving_image = self._inputs["moving_image"] + fixed_image = self._inputs["fixed_image"] + moving_label = self._inputs["moving_label"] + + # build ddf + backbone_inputs = self.concat_images(moving_image, fixed_image, moving_label) + backbone = REGISTRY.build_backbone( + config=self.config["backbone"], + default_args=dict( + image_size=self.fixed_image_size, + out_channels=1, + out_kernel_initializer="glorot_uniform", + out_activation="sigmoid", + ), + ) + # (batch, f_dim1, f_dim2, f_dim3) + pred_fixed_label = backbone(inputs=backbone_inputs) + pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4) + + self._outputs = dict(pred_fixed_label=pred_fixed_label) + return tf.keras.Model(inputs=self._inputs, outputs=self._outputs) + + def build_loss(self): + """Build losses according to configs.""" + fixed_label = self._inputs["fixed_label"] + pred_fixed_label = self._outputs["pred_fixed_label"] + + self._build_loss( + name="label", + inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label), + ) + + def postprocess( + self, + inputs: Dict[str, tf.Tensor], + outputs: Dict[str, tf.Tensor], + ) -> (tf.Tensor, Dict): + """ + Return a dict used for saving inputs and outputs. + + :param inputs: dict of model inputs + :param outputs: dict of model outputs + :return: tuple, indices and a dict. + In the dict, each value is (tensor, normalize, on_label), where + - normalize = True if the tensor need to be normalized to [0, 1] + - on_label = True if the tensor depends on label + """ + indices = inputs["indices"] + processed = dict( + moving_image=(inputs["moving_image"], True, False), + fixed_image=(inputs["fixed_image"], True, False), + pred_fixed_label=(outputs["pred_fixed_label"], True, True), + moving_label=(inputs["moving_label"], False, True), + fixed_label=(inputs["fixed_label"], False, True), + ) + + return indices, processed diff --git a/deepreg/model/network/__init__.py b/deepreg/model/network/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/deepreg/model/network/affine.py b/deepreg/model/network/affine.py deleted file mode 100644 index 9b3592b4f..000000000 --- a/deepreg/model/network/affine.py +++ /dev/null @@ -1,164 +0,0 @@ -import tensorflow as tf - -from deepreg.model import layer, layer_util -from deepreg.model.network.util import ( - add_ddf_loss, - add_image_loss, - add_label_loss, - build_backbone, - build_inputs, -) -from deepreg.registry import Registry - - -def affine_forward( - backbone: tf.keras.Model, - moving_image: tf.Tensor, - fixed_image: tf.Tensor, - moving_label: (tf.Tensor, None), - moving_image_size: tuple, - fixed_image_size: tuple, -) -> (tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor): - """ - Perform the network forward pass. - - :param backbone: model architecture object, e.g. model.backbone.local_net - :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3) - :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3) - :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None - :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3) - :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3) - :return: tuple(affine, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where - - affine is the affine transformation matrix predicted by the network, - of shape (batch, 4, 3) - - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3) - - pred_fixed_image is the predicted (warped) moving image - of shape (batch, f_dim1, f_dim2, f_dim3) - - pred_fixed_label is the predicted (warped) moving label - of shape (batch, f_dim1, f_dim2, f_dim3) - - fixed_grid is the grid of shape(f_dim1, f_dim2, f_dim3, 3) - """ - - # expand dims - # need to be squeezed later for warping - moving_image = tf.expand_dims( - moving_image, axis=4 - ) # (batch, m_dim1, m_dim2, m_dim3, 1) - fixed_image = tf.expand_dims( - fixed_image, axis=4 - ) # (batch, f_dim1, f_dim2, f_dim3, 1) - - # adjust moving image - moving_image = layer_util.resize3d( - image=moving_image, size=fixed_image_size - ) # (batch, f_dim1, f_dim2, f_dim3, 1) - - # ddf, dvf - inputs = tf.concat( - [moving_image, fixed_image], axis=4 - ) # (batch, f_dim1, f_dim2, f_dim3, 2) - ddf = backbone(inputs=inputs) # (batch, f_dim1, f_dim2, f_dim3, 3) - affine = backbone.theta - - # prediction, (batch, f_dim1, f_dim2, f_dim3) - warping = layer.Warping(fixed_image_size=fixed_image_size) - grid_fixed = tf.squeeze(warping.grid_ref, axis=0) # (f_dim1, f_dim2, f_dim3, 3) - pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_image, axis=4)]) - pred_fixed_label = ( - warping(inputs=[ddf, moving_label]) if moving_label is not None else None - ) - return affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed - - -def build_affine_model( - moving_image_size: tuple, - fixed_image_size: tuple, - index_size: int, - labeled: bool, - batch_size: int, - train_config: dict, - registry: Registry, -) -> tf.keras.Model: - """ - Build a model which outputs the parameters for affine transformation. - - :param moving_image_size: (m_dim1, m_dim2, m_dim3) - :param fixed_image_size: (f_dim1, f_dim2, f_dim3) - :param index_size: int, the number of indices for identifying a sample - :param labeled: bool, indicating if the data is labeled - :param batch_size: int, size of mini-batch - :param train_config: config for the model and loss - :param registry: registry to construct class objects - :return: the built tf.keras.Model - """ - - # inputs - (moving_image, fixed_image, moving_label, fixed_label, indices) = build_inputs( - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - index_size=index_size, - batch_size=batch_size, - labeled=labeled, - ) - - # backbone - backbone = build_backbone( - image_size=fixed_image_size, - out_channels=3, - config=train_config["backbone"], - method_name=train_config["method"], - registry=registry, - ) - - # forward - affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = affine_forward( - backbone=backbone, - moving_image=moving_image, - fixed_image=fixed_image, - moving_label=moving_label, - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - ) - - # build model - inputs = { - "moving_image": moving_image, - "fixed_image": fixed_image, - "indices": indices, - } - outputs = {"ddf": ddf, "affine": affine} - model_name = train_config["method"].upper() + "RegistrationModel" - if moving_label is None: # unlabeled - model = tf.keras.Model( - inputs=inputs, outputs=outputs, name=model_name + "WithoutLabel" - ) - else: # labeled - inputs["moving_label"] = moving_label - inputs["fixed_label"] = fixed_label - outputs["pred_fixed_label"] = pred_fixed_label - model = tf.keras.Model( - inputs=inputs, outputs=outputs, name=model_name + "WithLabel" - ) - - # add loss and metric - loss_config = train_config["loss"] - model = add_ddf_loss( - model=model, ddf=ddf, loss_config=loss_config, registry=registry - ) - model = add_image_loss( - model=model, - fixed_image=fixed_image, - pred_fixed_image=pred_fixed_image, - loss_config=loss_config, - registry=registry, - ) - model = add_label_loss( - model=model, - grid_fixed=grid_fixed, - fixed_label=fixed_label, - pred_fixed_label=pred_fixed_label, - loss_config=loss_config, - registry=registry, - ) - - return model diff --git a/deepreg/model/network/build.py b/deepreg/model/network/build.py deleted file mode 100644 index fd9abfa6d..000000000 --- a/deepreg/model/network/build.py +++ /dev/null @@ -1,61 +0,0 @@ -import tensorflow as tf - -from deepreg.model.network.affine import build_affine_model -from deepreg.model.network.cond import build_conditional_model -from deepreg.model.network.ddf_dvf import build_ddf_dvf_model -from deepreg.registry import Registry - - -def build_model( - moving_image_size: tuple, - fixed_image_size: tuple, - index_size: int, - labeled: bool, - batch_size: int, - train_config: dict, - registry: Registry, -) -> tf.keras.Model: - """ - Parsing algorithm types to model building functions. - - :param moving_image_size: [m_dim1, m_dim2, m_dim3] - :param fixed_image_size: [f_dim1, f_dim2, f_dim3] - :param index_size: dataset size - :param labeled: true if the label of moving/fixed images are provided - :param batch_size: mini-batch size - :param train_config: train configuration - :param registry: registry to construct class objects - :return: the built tf.keras.Model - """ - if train_config["method"] in ["ddf", "dvf"]: - return build_ddf_dvf_model( - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - index_size=index_size, - labeled=labeled, - batch_size=batch_size, - train_config=train_config, - registry=registry, - ) - elif train_config["method"] == "conditional": - return build_conditional_model( - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - index_size=index_size, - labeled=labeled, - batch_size=batch_size, - train_config=train_config, - registry=registry, - ) - elif train_config["method"] == "affine": - return build_affine_model( - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - index_size=index_size, - labeled=labeled, - batch_size=batch_size, - train_config=train_config, - registry=registry, - ) - else: - raise ValueError(f"Unknown method {train_config['method']}") diff --git a/deepreg/model/network/cond.py b/deepreg/model/network/cond.py deleted file mode 100644 index 5d53d72f7..000000000 --- a/deepreg/model/network/cond.py +++ /dev/null @@ -1,140 +0,0 @@ -import tensorflow as tf - -from deepreg.model import layer, layer_util -from deepreg.model.network.util import add_label_loss, build_backbone, build_inputs -from deepreg.registry import Registry - - -def conditional_forward( - backbone: tf.keras.Model, - moving_image: tf.Tensor, - fixed_image: tf.Tensor, - moving_label: (tf.Tensor, None), - moving_image_size: tuple, - fixed_image_size: tuple, -) -> [tf.Tensor, tf.Tensor]: - """ - Perform the network forward pass. - - :param backbone: model architecture object, e.g. model.backbone.local_net - :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3) - :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3) - :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None - :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3) - :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3) - :return: (pred_fixed_label, fixed_grid), where - - - pred_fixed_label is the predicted (warped) moving label - of shape (batch, f_dim1, f_dim2, f_dim3) - - fixed_grid is the grid of shape(f_dim1, f_dim2, f_dim3, 3) - """ - - # expand dims - # need to be squeezed later for warping - moving_image = tf.expand_dims( - moving_image, axis=4 - ) # (batch, m_dim1, m_dim2, m_dim3, 1) - fixed_image = tf.expand_dims( - fixed_image, axis=4 - ) # (batch, f_dim1, f_dim2, f_dim3, 1) - moving_label = tf.expand_dims( - moving_label, axis=4 - ) # (batch, m_dim1, m_dim2, m_dim3, 1) - - # adjust moving image - if moving_image_size != fixed_image_size: - moving_image = layer_util.resize3d( - image=moving_image, size=fixed_image_size - ) # (batch, f_dim1, f_dim2, f_dim3, 1) - moving_label = layer_util.resize3d( - image=moving_label, size=fixed_image_size - ) # (batch, f_dim1, f_dim2, f_dim3, 1) - - # conditional - inputs = tf.concat( - [moving_image, fixed_image, moving_label], axis=4 - ) # (batch, f_dim1, f_dim2, f_dim3, 3) - pred_fixed_label = backbone(inputs=inputs) # (batch, f_dim1, f_dim2, f_dim3, 1) - pred_fixed_label = tf.squeeze( - pred_fixed_label, axis=4 - ) # (batch, f_dim1, f_dim2, f_dim3) - - warping = layer.Warping(fixed_image_size=fixed_image_size) - grid_fixed = tf.squeeze(warping.grid_ref, axis=0) # (f_dim1, f_dim2, f_dim3, 3) - - return pred_fixed_label, grid_fixed - - -def build_conditional_model( - moving_image_size: tuple, - fixed_image_size: tuple, - index_size: int, - labeled: bool, - batch_size: int, - train_config: dict, - registry: Registry, -) -> tf.keras.Model: - """ - Build a model which outputs predicted fixed label. - - :param moving_image_size: (m_dim1, m_dim2, m_dim3) - :param fixed_image_size: (f_dim1, f_dim2, f_dim3) - :param index_size: int, the number of indices for identifying a sample - :param labeled: bool, indicating if the data is labeled - :param batch_size: int, size of mini-batch - :param train_config: config for the model and loss - :param registry: registry to construct class objects - :return: the built tf.keras.Model - """ - # inputs - (moving_image, fixed_image, moving_label, fixed_label, indices) = build_inputs( - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - index_size=index_size, - batch_size=batch_size, - labeled=labeled, - ) - - # backbone - backbone = build_backbone( - image_size=fixed_image_size, - out_channels=1, - config=train_config["backbone"], - method_name=train_config["method"], - registry=registry, - ) - - # prediction - pred_fixed_label, grid_fixed = conditional_forward( - backbone=backbone, - moving_image=moving_image, - fixed_image=fixed_image, - moving_label=moving_label, - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - ) # (batch, f_dim1, f_dim2, f_dim3) - - # build model - inputs = { - "moving_image": moving_image, - "fixed_image": fixed_image, - "moving_label": moving_label, - "fixed_label": fixed_label, - "indices": indices, - } - outputs = {"pred_fixed_label": pred_fixed_label} - model = tf.keras.Model( - inputs=inputs, outputs=outputs, name="ConditionalRegistrationModel" - ) - - # loss and metric - model = add_label_loss( - model=model, - grid_fixed=grid_fixed, - fixed_label=fixed_label, - pred_fixed_label=pred_fixed_label, - loss_config=train_config["loss"], - registry=registry, - ) - - return model diff --git a/deepreg/model/network/ddf_dvf.py b/deepreg/model/network/ddf_dvf.py deleted file mode 100644 index 5bff816fa..000000000 --- a/deepreg/model/network/ddf_dvf.py +++ /dev/null @@ -1,176 +0,0 @@ -import tensorflow as tf - -from deepreg.model import layer, layer_util -from deepreg.model.network.util import ( - add_ddf_loss, - add_image_loss, - add_label_loss, - build_backbone, - build_inputs, -) -from deepreg.registry import Registry - - -def ddf_dvf_forward( - backbone: tf.keras.Model, - moving_image: tf.Tensor, - fixed_image: tf.Tensor, - moving_label: (tf.Tensor, None), - moving_image_size: tuple, - fixed_image_size: tuple, - output_dvf: bool, -) -> [(tf.Tensor, None), tf.Tensor, tf.Tensor, (tf.Tensor, None), tf.Tensor]: - """ - Perform the network forward pass. - :param backbone: model architecture object, e.g. model.backbone.local_net - :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3) - :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3) - :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None - :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3) - :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3) - :param output_dvf: bool, if true, model outputs dvf, if false, model outputs ddf - :return: (dvf, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where - - dvf is the dense velocity field of shape (batch, f_dim1, f_dim2, f_dim3, 3) - - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3) - - pred_fixed_image is the predicted (warped) moving image - of shape (batch, f_dim1, f_dim2, f_dim3) - - pred_fixed_label is the predicted (warped) moving label - of shape (batch, f_dim1, f_dim2, f_dim3) - - fixed_grid is the grid of shape(f_dim1, f_dim2, f_dim3, 3) - """ - - # expand dims - # need to be squeezed later for warping - moving_image = tf.expand_dims( - moving_image, axis=4 - ) # (batch, m_dim1, m_dim2, m_dim3, 1) - fixed_image = tf.expand_dims( - fixed_image, axis=4 - ) # (batch, f_dim1, f_dim2, f_dim3, 1) - - # adjust moving image - moving_image = layer_util.resize3d( - image=moving_image, size=fixed_image_size - ) # (batch, f_dim1, f_dim2, f_dim3, 1) - - # ddf, dvf - inputs = tf.concat( - [moving_image, fixed_image], axis=4 - ) # (batch, f_dim1, f_dim2, f_dim3, 2) - backbone_out = backbone(inputs=inputs) # (batch, f_dim1, f_dim2, f_dim3, 3) - if output_dvf: - dvf = backbone_out # (batch, f_dim1, f_dim2, f_dim3, 3) - ddf = layer.IntDVF(fixed_image_size=fixed_image_size)( - dvf - ) # (batch, f_dim1, f_dim2, f_dim3, 3) - else: - dvf = None - ddf = backbone_out # (batch, f_dim1, f_dim2, f_dim3, 3) - - # prediction, (batch, f_dim1, f_dim2, f_dim3) - warping = layer.Warping(fixed_image_size=fixed_image_size) - grid_fixed = tf.squeeze(warping.grid_ref, axis=0) # (f_dim1, f_dim2, f_dim3, 3) - pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_image, axis=4)]) - pred_fixed_label = ( - warping(inputs=[ddf, moving_label]) if moving_label is not None else None - ) - return dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed - - -def build_ddf_dvf_model( - moving_image_size: tuple, - fixed_image_size: tuple, - index_size: int, - labeled: bool, - batch_size: int, - train_config: dict, - registry: Registry, -) -> tf.keras.Model: - """ - Build a model which outputs DDF/DVF. - - :param moving_image_size: (m_dim1, m_dim2, m_dim3) - :param fixed_image_size: (f_dim1, f_dim2, f_dim3) - :param index_size: int, the number of indices for identifying a sample - :param labeled: bool, indicating if the data is labeled - :param batch_size: int, size of mini-batch - :param train_config: config for the model and loss - :param registry: registry to construct class objects - :return: the built tf.keras.Model - """ - - # inputs - (moving_image, fixed_image, moving_label, fixed_label, indices) = build_inputs( - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - index_size=index_size, - batch_size=batch_size, - labeled=labeled, - ) - - # backbone - backbone = build_backbone( - image_size=fixed_image_size, - out_channels=3, - config=train_config["backbone"], - method_name=train_config["method"], - registry=registry, - ) - - # forward - dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = ddf_dvf_forward( - backbone=backbone, - moving_image=moving_image, - fixed_image=fixed_image, - moving_label=moving_label, - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - output_dvf=train_config["method"] == "dvf", - ) - - # build model - inputs = { - "moving_image": moving_image, - "fixed_image": fixed_image, - "indices": indices, - } - outputs = {"ddf": ddf} - if dvf is not None: - outputs["dvf"] = dvf - - model_name = train_config["method"].upper() + "RegistrationModel" - - if moving_label is None: # unlabeled - model = tf.keras.Model( - inputs=inputs, outputs=outputs, name=model_name + "WithoutLabel" - ) - else: # labeled - inputs["moving_label"] = moving_label - inputs["fixed_label"] = fixed_label - outputs["pred_fixed_label"] = pred_fixed_label - model = tf.keras.Model( - inputs=inputs, outputs=outputs, name=model_name + "WithLabel" - ) - - # add loss and metric - loss_config = train_config["loss"] - model = add_ddf_loss( - model=model, ddf=ddf, loss_config=loss_config, registry=registry - ) - model = add_image_loss( - model=model, - fixed_image=fixed_image, - pred_fixed_image=pred_fixed_image, - loss_config=loss_config, - registry=registry, - ) - model = add_label_loss( - model=model, - grid_fixed=grid_fixed, - fixed_label=fixed_label, - pred_fixed_label=pred_fixed_label, - loss_config=loss_config, - registry=registry, - ) - - return model diff --git a/deepreg/model/network/util.py b/deepreg/model/network/util.py deleted file mode 100644 index b2cc24d1b..000000000 --- a/deepreg/model/network/util.py +++ /dev/null @@ -1,236 +0,0 @@ -# coding=utf-8 - -""" -Module to build backbone modules based on passed inputs. -""" - -import tensorflow as tf - -import deepreg.model.loss.label as label_loss -from deepreg.registry import Registry - - -def build_backbone( - image_size: tuple, - out_channels: int, - config: dict, - method_name: str, - registry: Registry, -) -> tf.keras.Model: - """ - Backbone model accepts a single input of shape (batch, dim1, dim2, dim3, ch_in) - and returns a single output of shape (batch, dim1, dim2, dim3, ch_out). - - :param image_size: tuple, dims of image, (dim1, dim2, dim3) - :param out_channels: int, number of out channels, ch_out - :param method_name: str, one of ddf, dvf and conditional - :param config: dict, backbone configuration - :param registry: the registry object having all registered classes - :return: tf.keras.Model - """ - if not ( - (isinstance(image_size, tuple) or isinstance(image_size, list)) - and len(image_size) == 3 - ): - raise ValueError(f"image_size must be tuple of length 3, got {image_size}") - - if method_name in ["ddf", "dvf"]: - out_activation = None - # TODO try random init with smaller number - out_kernel_initializer = "zeros" # to ensure small ddf and dvf - elif method_name in ["conditional"]: - out_activation = "sigmoid" # output is probability - out_kernel_initializer = "glorot_uniform" - elif method_name in ["affine"]: - out_activation = None - out_kernel_initializer = "zeros" - else: - raise ValueError( - f"method name has to be one of ddf/dvf/conditional/affine, " - f"got {method_name}" - ) - - backbone = registry.build_backbone( - config=config, - default_args=dict( - image_size=image_size, - out_channels=out_channels, - out_kernel_initializer=out_kernel_initializer, - out_activation=out_activation, - ), - ) - return backbone - - -def build_inputs( - moving_image_size: tuple, - fixed_image_size: tuple, - index_size: int, - batch_size: int, - labeled: bool, -) -> [tf.keras.Input, tf.keras.Input, tf.keras.Input, tf.keras.Input, tf.keras.Input]: - """ - Configure a pair of moving and fixed images and a pair of - moving and fixed labels as model input - and returns model input tf.keras.Input - - TODO do we absolutely need the batch_size in Input? - - :param moving_image_size: tuple, dims of moving images, (m_dim1, m_dim2, m_dim3) - :param fixed_image_size: tuple, dims of fixed images, (f_dim1, f_dim2, f_dim3) - :param index_size: int, dataset size (number of images) - :param batch_size: int, mini-batch size - :param labeled: Boolean, true if we have label data - :return: 5 (if labeled=True) or 3 (if labeled=False) tf.keras.Input objects - """ - moving_image = tf.keras.Input( - shape=moving_image_size, batch_size=batch_size, name="moving_image" - ) # (batch, m_dim1, m_dim2, m_dim3) - fixed_image = tf.keras.Input( - shape=fixed_image_size, batch_size=batch_size, name="fixed_image" - ) # (batch, f_dim1, f_dim2, f_dim3) - moving_label = ( - tf.keras.Input( - shape=moving_image_size, batch_size=batch_size, name="moving_label" - ) - if labeled - else None - ) # (batch, m_dim1, m_dim2, m_dim3) - fixed_label = ( - tf.keras.Input( - shape=fixed_image_size, batch_size=batch_size, name="fixed_label" - ) - if labeled - else None - ) # (batch, m_dim1, m_dim2, m_dim3) - indices = tf.keras.Input( - shape=(index_size,), batch_size=batch_size, name="indices" - ) # (batch, 2) - return moving_image, fixed_image, moving_label, fixed_label, indices - - -def add_ddf_loss( - model: tf.keras.Model, - ddf: tf.Tensor, - loss_config: dict, - registry: Registry, -) -> tf.keras.Model: - """ - Add regularization loss of ddf into model. - - :param model: tf.keras.Model - :param ddf: tensor of shape (batch, m_dim1, m_dim2, m_dim3, 3) - :param loss_config: config for loss - :param registry: the registry object having all registered classes - """ - if loss_config["regularization"]["weight"] <= 0: - # TODO will refactor the way building models - return model # pragma: no cover - config = loss_config["regularization"].copy() - weight = config.pop("weight", 1) - loss_reg = registry.build_loss(config=config)(inputs=ddf) - weighted_loss_reg = loss_reg * weight - model.add_loss(weighted_loss_reg) - model.add_metric(loss_reg, name="loss/regularization", aggregation="mean") - model.add_metric( - weighted_loss_reg, name="loss/weighted_regularization", aggregation="mean" - ) - return model - - -def add_image_loss( - model: tf.keras.Model, - fixed_image: tf.Tensor, - pred_fixed_image: tf.Tensor, - loss_config: dict, - registry: Registry, -) -> tf.keras.Model: - """ - Add image dissimilarity loss of ddf into model. - - :param model: tf.keras.Model - :param fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3) - :param pred_fixed_image: tensor of shape (batch, f_dim1, f_dim2, f_dim3) - :param loss_config: config for loss - :param registry: the registry object having all registered classes - """ - if loss_config["image"]["weight"] <= 0: - # TODO will refactor the way building models - return model # pragma: no cover - config = loss_config["image"].copy() - weight = config.pop("weight", 1) - - loss_image = registry.build_loss(config=config)( - y_true=fixed_image, - y_pred=pred_fixed_image, - ) - weighted_loss_image = loss_image * weight - model.add_loss(weighted_loss_image) - model.add_metric(loss_image, name="loss/image_dissimilarity", aggregation="mean") - model.add_metric( - weighted_loss_image, - name="loss/weighted_image_dissimilarity", - aggregation="mean", - ) - return model - - -def add_label_loss( - model: tf.keras.Model, - grid_fixed: tf.Tensor, - fixed_label: (tf.Tensor, None), - pred_fixed_label: (tf.Tensor, None), - loss_config: dict, - registry: Registry, -) -> tf.keras.Model: - """ - Add label dissimilarity loss of ddf into model. - - :param model: tf.keras.Model - :param grid_fixed: tensor of shape (f_dim1, f_dim2, f_dim3, 3) - :param fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3) - :param pred_fixed_label: tensor of shape (batch, f_dim1, f_dim2, f_dim3) - :param loss_config: config for loss - :param registry: the registry object having all registered classes - """ - if fixed_label is None: - # TODO will refactor the way building models - return model # pragma: no cover - if loss_config["label"]["weight"] <= 0: - # TODO will refactor the way building models - return model # pragma: no cover - config = loss_config["label"].copy() - weight = config.pop("weight", 1) - loss_label = registry.build_loss(config=config)( - y_true=fixed_label, - y_pred=pred_fixed_label, - ) - weighted_loss_label = loss_label * weight - model.add_loss(weighted_loss_label) - model.add_metric(loss_label, name="loss/label", aggregation="mean") - model.add_metric( - weighted_loss_label, - name="loss/weighted_label", - aggregation="mean", - ) - - # metrics - dice_binary = label_loss.DiceScore(binary=True)( - y_true=fixed_label, y_pred=pred_fixed_label - ) - dice_float = label_loss.DiceScore(binary=False)( - y_true=fixed_label, y_pred=pred_fixed_label - ) - tre = label_loss.compute_centroid_distance( - y_true=fixed_label, y_pred=pred_fixed_label, grid=grid_fixed - ) - foreground_label = label_loss.foreground_proportion(y=fixed_label) - foreground_pred = label_loss.foreground_proportion(y=pred_fixed_label) - model.add_metric(dice_binary, name="metric/dice_binary", aggregation="mean") - model.add_metric(dice_float, name="metric/dice_float", aggregation="mean") - model.add_metric(tre, name="metric/tre", aggregation="mean") - model.add_metric( - foreground_label, name="metric/foreground_label", aggregation="mean" - ) - model.add_metric(foreground_pred, name="metric/foreground_pred", aggregation="mean") - return model diff --git a/deepreg/predict.py b/deepreg/predict.py index 4fba371a5..dbe3270ff 100644 --- a/deepreg/predict.py +++ b/deepreg/predict.py @@ -17,7 +17,6 @@ import deepreg.model.optimizer as opt import deepreg.parser as config_parser from deepreg.callback import build_checkpoint_callback -from deepreg.model.network.build import build_model from deepreg.registry import REGISTRY, Registry from deepreg.util import ( build_dataset, @@ -79,126 +78,63 @@ def predict_on_dataset( sample_index_strs = [] metric_lists = [] - for _, inputs_dict in enumerate(dataset): - batch_size = inputs_dict[list(inputs_dict.keys())[0]].shape[0] - outputs_dict = model.predict(x=inputs_dict, batch_size=batch_size) - - # moving image/label - # (batch, m_dim1, m_dim2, m_dim3) - moving_image = inputs_dict["moving_image"] - moving_label = inputs_dict.get("moving_label", None) - # fixed image/labelimage_index - # (batch, f_dim1, f_dim2, f_dim3) - fixed_image = inputs_dict["fixed_image"] - fixed_label = inputs_dict.get("fixed_label", None) - - # indices to identify the pair - # (batch, num_indices) last indice is for label, -1 means unlabeled data - indices = inputs_dict.get("indices") - # ddf / dvf - # (batch, f_dim1, f_dim2, f_dim3, 3) - ddf = outputs_dict.get("ddf", None) - dvf = outputs_dict.get("dvf", None) - affine = outputs_dict.get("affine", None) # (batch, 4, 3) - - # prediction - # (batch, f_dim1, f_dim2, f_dim3) - pred_fixed_label = outputs_dict.get("pred_fixed_label", None) - pred_fixed_image = ( - layer_util.resample(vol=moving_image, loc=fixed_grid_ref + ddf) - if ddf is not None - else None - ) + for _, inputs in enumerate(dataset): + batch_size = inputs[list(inputs.keys())[0]].shape[0] + outputs = model.predict(x=inputs, batch_size=batch_size) + indices, processed = model.postprocess(inputs=inputs, outputs=outputs) + + # convert to np arrays + indices = indices.numpy() + processed = { + k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2]) + for k, v in processed.items() + } # save images of inputs and outputs - for sample_index in range(moving_image.shape[0]): - # save moving/fixed image under pair_dir - # save moving/fixed label, pred fixed image/label, ddf/dvf under label dir - # if labeled, label dir is a sub dir of pair_dir, otherwise = pair_dir + for sample_index in range(batch_size): + # save label independent tensors under pair_dir, otherwise under label_dir # init output path - indices_i = indices[sample_index, :].numpy().astype(int).tolist() + indices_i = indices[sample_index, :].astype(int).tolist() pair_dir, label_dir = build_pair_output_path( indices=indices_i, save_dir=save_dir ) - # save image/label - # if model is conditional, the pred_fixed_image depends on the input label - conditional = model_method == "conditional" - arr_save_dirs = [ - pair_dir, - pair_dir, - label_dir if conditional else pair_dir, - label_dir, - label_dir, - label_dir, - ] - arrs = [ - moving_image, - fixed_image, - pred_fixed_image, - moving_label, - fixed_label, - pred_fixed_label, - ] - names = [ - "moving_image", - "fixed_image", - "pred_fixed_image", # or warped moving image - "moving_label", - "fixed_label", - "pred_fixed_label", # or warped moving label - ] - for arr_save_dir, arr, name in zip(arr_save_dirs, arrs, names): - if arr is not None: - # for files under pair_dir, do not overwrite - save_array( - save_dir=arr_save_dir, - arr=arr[sample_index, :, :, :], - name=name, - normalize="image" in name, # label's value is already in [0, 1] - save_nifti=save_nifti, - save_png=save_png, - overwrite=arr_save_dir == label_dir, - ) - - # save ddf / dvf - arrs = [ddf, dvf] - names = ["ddf", "dvf"] - for arr, name in zip(arrs, names): - if arr is not None: - save_array( - save_dir=label_dir if conditional else pair_dir, - arr=arr[sample_index, :, :, :], - name=name, - normalize=True, - save_nifti=save_nifti, - save_png=save_png, + for name, (arr, normalize, on_label) in processed.items(): + if name == "theta": + np.savetxt( + fname=os.path.join(pair_dir, "affine.txt"), + X=arr[sample_index, :, :], + delimiter=",", ) - - # save affine - if affine is not None: - np.savetxt( - fname=os.path.join( - label_dir if conditional else pair_dir, "affine.txt" - ), - x=affine[sample_index, :, :].numpy(), - delimiter=",", + continue + + arr_save_dir = label_dir if on_label else pair_dir + save_array( + save_dir=arr_save_dir, + arr=arr[sample_index, :, :, :], + name=name, + normalize=normalize, # label's value is already in [0, 1] + save_nifti=save_nifti, + save_png=save_png, + overwrite=arr_save_dir == label_dir, ) # calculate metric sample_index_str = "_".join([str(x) for x in indices_i]) - if sample_index_str in sample_index_strs: + if sample_index_str in sample_index_strs: # pragma: no cover raise ValueError( "Sample is repeated, maybe the dataset has been repeated." ) sample_index_strs.append(sample_index_str) metric = calculate_metrics( - fixed_image=fixed_image, - fixed_label=fixed_label, - pred_fixed_image=pred_fixed_image, - pred_fixed_label=pred_fixed_label, + fixed_image=processed["fixed_image"][0], + fixed_label=processed["fixed_label"][0] if model.labeled else None, + pred_fixed_image=processed["pred_fixed_image"][0], + pred_fixed_label=processed["pred_fixed_label"][0] + if model.labeled + else None, fixed_grid_ref=fixed_grid_ref, sample_index=sample_index, ) @@ -306,14 +242,16 @@ def predict( optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"]) # model - model = build_model( - moving_image_size=data_loader.moving_image_shape, - fixed_image_size=data_loader.fixed_image_shape, - index_size=data_loader.num_indices, - labeled=config["dataset"]["labeled"], - batch_size=preprocess_config["batch_size"], - train_config=config["train"], - registry=registry, + model = registry.build_model( + config=dict( + name=config["train"]["method"], + moving_image_size=data_loader.moving_image_shape, + fixed_image_size=data_loader.fixed_image_shape, + index_size=data_loader.num_indices, + labeled=config["dataset"]["labeled"], + batch_size=config["train"]["preprocess"]["batch_size"], + config=config["train"], + ) ) # metrics diff --git a/deepreg/registry.py b/deepreg/registry.py index 34372fcc6..9516059aa 100644 --- a/deepreg/registry.py +++ b/deepreg/registry.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Callable BACKBONE_CLASS = "backbone_class" @@ -99,7 +100,7 @@ def build_from_config( raise ValueError(f"config must be a dict, but got {type(config)}") if "name" not in config: raise ValueError(f"`config` must contain the key `name`, but got {config}") - args = config.copy() + args = deepcopy(config) # insert key, value pairs if key is not in args if default_args is not None: @@ -123,9 +124,17 @@ def build_from_config( def copy(self): copied = Registry() - copied._dict = self._dict.copy() + copied._dict = deepcopy(self._dict) return copied + def register_model(self, name: str, cls: Callable = None, force: bool = False): + return self.register(category=MODEL_CLASS, name=name, cls=cls, force=force) + + def build_model(self, config: dict, default_args=None): + return self.build_from_config( + category=MODEL_CLASS, config=config, default_args=default_args + ) + def register_backbone(self, name: str, cls: Callable = None, force: bool = False): return self.register(category=BACKBONE_CLASS, name=name, cls=cls, force=force) diff --git a/deepreg/train.py b/deepreg/train.py index 31678078c..d30e33355 100644 --- a/deepreg/train.py +++ b/deepreg/train.py @@ -12,7 +12,6 @@ import deepreg.model.optimizer as opt import deepreg.parser as config_parser from deepreg.callback import build_checkpoint_callback -from deepreg.model.network.build import build_model from deepreg.registry import REGISTRY, Registry from deepreg.util import build_dataset, build_log_dir @@ -115,21 +114,25 @@ def train( # use strategy to support multiple GPUs # the network is mirrored in each GPU so that we can use larger batch size - # https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_tfkerasmodelfit + # https://www.tensorflow.org/guide/distributed_training # only model, optimizer and metrics need to be defined inside the strategy - if len(tf.config.list_physical_devices("GPU")) > 1: + num_devices = max(len(tf.config.list_physical_devices("GPU")), 1) + if num_devices > 1: strategy = tf.distribute.MirroredStrategy() # pragma: no cover else: strategy = tf.distribute.get_strategy() with strategy.scope(): - model = build_model( - moving_image_size=data_loader_train.moving_image_shape, - fixed_image_size=data_loader_train.fixed_image_shape, - index_size=data_loader_train.num_indices, - labeled=config["dataset"]["labeled"], - batch_size=config["train"]["preprocess"]["batch_size"], - train_config=config["train"], - registry=registry, + model = registry.build_model( + config=dict( + name=config["train"]["method"], + moving_image_size=data_loader_train.moving_image_shape, + fixed_image_size=data_loader_train.fixed_image_shape, + index_size=data_loader_train.num_indices, + labeled=config["dataset"]["labeled"], + batch_size=config["train"]["preprocess"]["batch_size"], + config=config["train"], + num_devices=num_devices, + ) ) optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"]) @@ -150,7 +153,8 @@ def train( callbacks = [tensorboard_callback, ckpt_callback] # train - # it's necessary to define the steps_per_epoch and validation_steps to prevent errors like + # it's necessary to define the steps_per_epoch + # and validation_steps to prevent errors like # BaseCollectiveExecutor::StartAbort Out of range: End of sequence model.fit( x=dataset_train, diff --git a/deepreg/vis.py b/deepreg/vis.py index f129524e6..3e86e0175 100755 --- a/deepreg/vis.py +++ b/deepreg/vis.py @@ -67,7 +67,7 @@ def gif_slices(img_paths, save_path="", interval=50): def tile_slices(img_paths, save_path="", fname=None, slice_inds=None, col_titles=None): """ - Generate a tiled plot of muliple images for multiple slices in the image. + Generate a tiled plot of multiple images for multiple slices in the image. Rows are different slices, columns are different images. :param img_paths: list or comma separated string of image paths @@ -98,13 +98,10 @@ def tile_slices(img_paths, save_path="", fname=None, slice_inds=None, col_titles plt.figure(figsize=(num_imgs * 2, num_inds * 2)) - imgs = [] - for img_path in img_paths: - img = load_nifti_file(img_path) - imgs.append(img) + imgs = [load_nifti_file(p) for p in img_paths] - for img, col_num in zip(imgs, range(num_imgs)): - for index, row_num in zip(slice_inds, range(num_inds)): + for col_num, img in enumerate(imgs): + for row_num, index in enumerate(slice_inds): plt.subplot(num_inds, num_imgs, subplot_mat[row_num, col_num]) plt.imshow(img[:, :, index]) plt.axis("off") diff --git a/demos/grouped_mask_prostate_longitudinal/demo_data.py b/demos/grouped_mask_prostate_longitudinal/demo_data.py index 072ec987f..4833129e9 100644 --- a/demos/grouped_mask_prostate_longitudinal/demo_data.py +++ b/demos/grouped_mask_prostate_longitudinal/demo_data.py @@ -111,7 +111,7 @@ os.mkdir(MODEL_PATH) ZIP_PATH = "grouped_mask_prostate_longitudinal_1" -ORIGIN = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/grouped_mask_prostate_longitudinal_1.zip" +ORIGIN = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/demo/grouped_mask_prostate_longitudinal/20210110.zip" zip_file = os.path.join(MODEL_PATH, ZIP_PATH + ".zip") get_file(os.path.abspath(zip_file), ORIGIN) diff --git a/demos/grouped_mask_prostate_longitudinal/demo_predict.py b/demos/grouped_mask_prostate_longitudinal/demo_predict.py index 416fb8351..2eed3106e 100755 --- a/demos/grouped_mask_prostate_longitudinal/demo_predict.py +++ b/demos/grouped_mask_prostate_longitudinal/demo_predict.py @@ -49,7 +49,7 @@ predict( gpu="0", - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path=ckpt_path, mode="test", batch_size=1, diff --git a/demos/grouped_mask_prostate_longitudinal/demo_train.py b/demos/grouped_mask_prostate_longitudinal/demo_train.py index b31d51555..c39d530b1 100755 --- a/demos/grouped_mask_prostate_longitudinal/demo_train.py +++ b/demos/grouped_mask_prostate_longitudinal/demo_train.py @@ -46,7 +46,7 @@ train( gpu="0", config_path=config_path, - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path="", log_root=log_root, log_dir=log_dir, diff --git a/demos/grouped_mr_heart/demo_data.py b/demos/grouped_mr_heart/demo_data.py index 1ff3ad97f..891369e72 100644 --- a/demos/grouped_mr_heart/demo_data.py +++ b/demos/grouped_mr_heart/demo_data.py @@ -113,14 +113,12 @@ shutil.rmtree(MODEL_PATH) os.mkdir(MODEL_PATH) -num_zipfiles = 21 -zip_filepath = os.path.abspath(os.path.join(MODEL_PATH, "grouped_mr_heart_1.zip")) -zip_file_parts = [ - zip_filepath + ".%03d" % (idx + 1) for idx in range(num_zipfiles) -] # https://github.com/DeepRegNet/deepreg-model-zoo/blob/master/grouped_mr_heart_1/grouped_mr_heart_1.zip.021 -for idx, zip_file in enumerate(zip_file_parts, start=1): +num_zipfiles = 11 +zip_filepath = os.path.abspath(os.path.join(MODEL_PATH, "checkpoint.zip")) +zip_file_parts = [zip_filepath + ".%02d" % idx for idx in range(num_zipfiles)] +for idx, zip_file in enumerate(zip_file_parts): ORIGIN = ( - "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/grouped_mr_heart_1/grouped_mr_heart_1.zip.%03d" + "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/demo/grouped_mr_heart/20210110/part.%02d" % idx ) get_file(zip_file, ORIGIN) diff --git a/demos/grouped_mr_heart/demo_predict.py b/demos/grouped_mr_heart/demo_predict.py index c3d6e1edd..f72c9a79f 100755 --- a/demos/grouped_mr_heart/demo_predict.py +++ b/demos/grouped_mr_heart/demo_predict.py @@ -48,7 +48,7 @@ predict( gpu="0", - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path=ckpt_path, mode="test", batch_size=1, diff --git a/demos/grouped_mr_heart/demo_train.py b/demos/grouped_mr_heart/demo_train.py index f6063585d..3b92374c4 100755 --- a/demos/grouped_mr_heart/demo_train.py +++ b/demos/grouped_mr_heart/demo_train.py @@ -46,7 +46,7 @@ train( gpu="0", config_path=config_path, - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path="", log_root=log_root, log_dir=log_dir, diff --git a/demos/paired_ct_lung/demo_data.py b/demos/paired_ct_lung/demo_data.py index 8f8fcfa7c..7e7621677 100644 --- a/demos/paired_ct_lung/demo_data.py +++ b/demos/paired_ct_lung/demo_data.py @@ -231,7 +231,7 @@ def move_test_cases_into_correct_path(test_cases, path_to_train, path_to_test): ######## DOWNLOAD MODEL CKPT FROM MODEL ZOO ######## -url = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/paired_ct_lung_1.zip" +url = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/demo/paired_ct_lung/20210110.zip" fname = "pretrained.zip" os.chdir(os.path.join(main_path, project_dir)) diff --git a/demos/paired_ct_lung/demo_predict.py b/demos/paired_ct_lung/demo_predict.py index 89b6702e7..fb4bb1f46 100755 --- a/demos/paired_ct_lung/demo_predict.py +++ b/demos/paired_ct_lung/demo_predict.py @@ -49,7 +49,7 @@ predict( gpu="0", - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path=ckpt_path, mode="test", batch_size=1, diff --git a/demos/paired_ct_lung/demo_train.py b/demos/paired_ct_lung/demo_train.py index 2bf7feb63..c13bb7883 100755 --- a/demos/paired_ct_lung/demo_train.py +++ b/demos/paired_ct_lung/demo_train.py @@ -46,7 +46,7 @@ train( gpu="0", config_path=config_path, - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path="", log_root=log_root, log_dir=log_dir, diff --git a/demos/paired_mrus_brain/README.md b/demos/paired_mrus_brain/README.md index 2d0b4b3bf..f311c560a 100644 --- a/demos/paired_mrus_brain/README.md +++ b/demos/paired_mrus_brain/README.md @@ -3,6 +3,9 @@ > **Note**: Please read the > [DeepReg Demo Disclaimer](introduction.html#demo-disclaimer). +> **Warning**: +> [This demo ought to be improved in the future.](https://github.com/DeepRegNet/DeepReg/issues/620). + [Source Code](https://github.com/DeepRegNet/DeepReg/tree/main/demos/paired_mrus_brain) ## Author diff --git a/demos/paired_mrus_brain/demo_predict.py b/demos/paired_mrus_brain/demo_predict.py index 5c0f542f0..6ef115527 100644 --- a/demos/paired_mrus_brain/demo_predict.py +++ b/demos/paired_mrus_brain/demo_predict.py @@ -49,7 +49,7 @@ predict( gpu="0", - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path=ckpt_path, mode="test", batch_size=1, diff --git a/demos/paired_mrus_brain/demo_train.py b/demos/paired_mrus_brain/demo_train.py index 369ba52e6..d377ee511 100755 --- a/demos/paired_mrus_brain/demo_train.py +++ b/demos/paired_mrus_brain/demo_train.py @@ -47,7 +47,7 @@ train( gpu="0", config_path=config_path, - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path="", log_root=log_root, log_dir=log_dir, diff --git a/demos/paired_mrus_prostate/README.md b/demos/paired_mrus_prostate/README.md index 5b7ba782c..2023ea6de 100644 --- a/demos/paired_mrus_prostate/README.md +++ b/demos/paired_mrus_prostate/README.md @@ -3,6 +3,9 @@ > **Note**: Please read the > [DeepReg Demo Disclaimer](introduction.html#demo-disclaimer). +> **Warning**: +> [This demo ought to be improved in the future.](https://github.com/DeepRegNet/DeepReg/issues/621). + [Source Code](https://github.com/DeepRegNet/DeepReg/tree/main/demos/paired_mrus_brain) This demo uses DeepReg to re-implement the algorithms described in diff --git a/demos/paired_mrus_prostate/demo_data.py b/demos/paired_mrus_prostate/demo_data.py index c294b0f7a..4dbd6eb4d 100644 --- a/demos/paired_mrus_prostate/demo_data.py +++ b/demos/paired_mrus_prostate/demo_data.py @@ -68,8 +68,8 @@ shutil.rmtree(MODEL_PATH) os.mkdir(MODEL_PATH) -ZIP_PATH = "paired_mrus_prostate_1" -ORIGIN = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/paired_mrus_prostate_1.zip" +ZIP_PATH = "checkpoint" +ORIGIN = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/demo/paired_mrus_prostate/20210110.zip" zip_file = os.path.join(MODEL_PATH, ZIP_PATH + ".zip") get_file(os.path.abspath(zip_file), ORIGIN) diff --git a/demos/paired_mrus_prostate/demo_predict.py b/demos/paired_mrus_prostate/demo_predict.py index 39a4b67a0..d76c8f522 100755 --- a/demos/paired_mrus_prostate/demo_predict.py +++ b/demos/paired_mrus_prostate/demo_predict.py @@ -48,7 +48,7 @@ predict( gpu="0", - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path=ckpt_path, mode="test", batch_size=1, diff --git a/demos/paired_mrus_prostate/demo_train.py b/demos/paired_mrus_prostate/demo_train.py index c3d930144..4a0221dd6 100755 --- a/demos/paired_mrus_prostate/demo_train.py +++ b/demos/paired_mrus_prostate/demo_train.py @@ -48,7 +48,7 @@ train( gpu="0", config_path=config_path, - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path="", log_root=log_root, log_dir=log_dir, diff --git a/demos/unpaired_ct_abdomen/README.md b/demos/unpaired_ct_abdomen/README.md index 491735313..66bb40705 100644 --- a/demos/unpaired_ct_abdomen/README.md +++ b/demos/unpaired_ct_abdomen/README.md @@ -3,6 +3,9 @@ > **Note**: Please read the > [DeepReg Demo Disclaimer](introduction.html#demo-disclaimer). +> **Warning**: +> [This demo ought to be improved in the future.](https://github.com/DeepRegNet/DeepReg/issues/552). + [Source Code](https://github.com/DeepRegNet/DeepReg/tree/main/demos/unpaired_ct_abdomen) ## Author diff --git a/demos/unpaired_ct_abdomen/demo_predict.py b/demos/unpaired_ct_abdomen/demo_predict.py index 6b61ebf53..e46896960 100755 --- a/demos/unpaired_ct_abdomen/demo_predict.py +++ b/demos/unpaired_ct_abdomen/demo_predict.py @@ -59,7 +59,7 @@ predict( gpu="0", - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path=ckpt_path, mode="test", batch_size=1, diff --git a/demos/unpaired_ct_abdomen/demo_train.py b/demos/unpaired_ct_abdomen/demo_train.py index 28e4e93dd..345ffa8f4 100755 --- a/demos/unpaired_ct_abdomen/demo_train.py +++ b/demos/unpaired_ct_abdomen/demo_train.py @@ -59,7 +59,7 @@ train( gpu="0", config_path=config_path, - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path="", log_root=log_root, log_dir=log_dir, diff --git a/demos/unpaired_ct_lung/demo_data.py b/demos/unpaired_ct_lung/demo_data.py index 4e5b473bf..e99b30ccd 100644 --- a/demos/unpaired_ct_lung/demo_data.py +++ b/demos/unpaired_ct_lung/demo_data.py @@ -259,9 +259,7 @@ def move_test_cases_into_correct_path(test_cases, path_to_train, path_to_test): ######## DOWNLOAD MODEL CKPT FROM MODEL ZOO ######## -url = ( - "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/unpaired_ct_lung_1.zip" -) +url = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/demo/unpaired_ct_lung/20210110.zip" fname = "pretrained.zip" diff --git a/demos/unpaired_ct_lung/demo_predict.py b/demos/unpaired_ct_lung/demo_predict.py index 5013cd224..0bdbda5b3 100755 --- a/demos/unpaired_ct_lung/demo_predict.py +++ b/demos/unpaired_ct_lung/demo_predict.py @@ -33,7 +33,7 @@ "The prediction can also be launched using the following command.\n" "deepreg_predict --gpu '' " f"--config_path demos/{name}/{name}.yaml " - f"--ckpt_path demos/{name}/dataset/pretrained/unpaired_ct_lung_1/ckpt-4000 " + f"--ckpt_path demos/{name}/dataset/pretrained/ckpt-5000 " f"--log_root demos/{name} " "--log_dir logs_predict " "--save_png --mode test\n" @@ -43,14 +43,14 @@ log_root = f"demos/{name}" log_dir = "logs_predict/" + datetime.now().strftime("%Y%m%d-%H%M%S") -ckpt_path = f"{log_root}/dataset/pretrained/unpaired_ct_lung_1/ckpt-4000" +ckpt_path = f"{log_root}/dataset/pretrained/ckpt-5000" config_path = [f"{log_root}/{name}.yaml"] if args.test: config_path.append("config/test/demo_unpaired_grouped.yaml") predict( gpu="0", - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path=ckpt_path, mode="test", batch_size=1, diff --git a/demos/unpaired_ct_lung/demo_train.py b/demos/unpaired_ct_lung/demo_train.py index 799f6c5d8..a849433c3 100755 --- a/demos/unpaired_ct_lung/demo_train.py +++ b/demos/unpaired_ct_lung/demo_train.py @@ -46,7 +46,7 @@ train( gpu="0", config_path=config_path, - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path="", log_root=log_root, log_dir=log_dir, diff --git a/demos/unpaired_mr_brain/README.md b/demos/unpaired_mr_brain/README.md index b906c3938..4108f181c 100644 --- a/demos/unpaired_mr_brain/README.md +++ b/demos/unpaired_mr_brain/README.md @@ -3,6 +3,9 @@ > **Note**: Please read the > [DeepReg Demo Disclaimer](introduction.html#demo-disclaimer). +> **Warning**: +> [This demo ought to be improved in the future.](https://github.com/DeepRegNet/DeepReg/issues/620). + [Source Code](https://github.com/DeepRegNet/DeepReg/tree/main/demos/unpaired_mr_brain) ## Author diff --git a/demos/unpaired_mr_brain/demo_predict.py b/demos/unpaired_mr_brain/demo_predict.py index d91f892ab..9c65e2156 100644 --- a/demos/unpaired_mr_brain/demo_predict.py +++ b/demos/unpaired_mr_brain/demo_predict.py @@ -50,7 +50,7 @@ predict( gpu="0", - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path=ckpt_path, mode="test", batch_size=1, diff --git a/demos/unpaired_mr_brain/demo_train.py b/demos/unpaired_mr_brain/demo_train.py index f1122b7e5..56edb87b9 100644 --- a/demos/unpaired_mr_brain/demo_train.py +++ b/demos/unpaired_mr_brain/demo_train.py @@ -47,7 +47,7 @@ train( gpu="0", config_path=config_path, - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path="", log_root=log_root, log_dir=log_dir, diff --git a/demos/unpaired_us_prostate_cv/demo_data.py b/demos/unpaired_us_prostate_cv/demo_data.py index da57f7d85..2a1ccad58 100644 --- a/demos/unpaired_us_prostate_cv/demo_data.py +++ b/demos/unpaired_us_prostate_cv/demo_data.py @@ -33,7 +33,7 @@ os.mkdir(MODEL_PATH) ZIP_PATH = "unpaired_us_prostate_cv_1" -ORIGIN = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/unpaired_us_prostate_cv_1.zip" +ORIGIN = "https://github.com/DeepRegNet/deepreg-model-zoo/raw/master/demo/unpaired_us_prostate_cv/20210110.zip" zip_file = os.path.join(MODEL_PATH, ZIP_PATH + ".zip") get_file(os.path.abspath(zip_file), ORIGIN) diff --git a/demos/unpaired_us_prostate_cv/demo_predict.py b/demos/unpaired_us_prostate_cv/demo_predict.py index 6d2473e89..8281d7cb6 100755 --- a/demos/unpaired_us_prostate_cv/demo_predict.py +++ b/demos/unpaired_us_prostate_cv/demo_predict.py @@ -49,7 +49,7 @@ predict( gpu="0", - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path=ckpt_path, mode="test", batch_size=1, diff --git a/demos/unpaired_us_prostate_cv/demo_train.py b/demos/unpaired_us_prostate_cv/demo_train.py index 77f5737f2..c4ff8e26b 100755 --- a/demos/unpaired_us_prostate_cv/demo_train.py +++ b/demos/unpaired_us_prostate_cv/demo_train.py @@ -48,7 +48,7 @@ train( gpu="0", config_path=config_path, - gpu_allow_growth=False, + gpu_allow_growth=True, ckpt_path="", log_root=log_root, log_dir=log_dir, diff --git a/docs/CODE_OF_CONDUCT.md b/docs/CODE_OF_CONDUCT.md index 98d50f399..72be879ad 100644 --- a/docs/CODE_OF_CONDUCT.md +++ b/docs/CODE_OF_CONDUCT.md @@ -64,7 +64,8 @@ an explanation as to why the situation has not been resolved. All community leaders are obligated to respect the privacy and security of the reporter of any incident. -## Conflicts of Interest If a complaint concerns a member of the community leaders and +## Conflicts of Interest If a complaint concerns a member of the community leaders and + should the reporter not feel comfortable sharing the report with such member, reports can be made privately to any of the members of the community leaders by contacting them privately. Contact methods are listed at diff --git a/docs/Intro_to_Medical_Image_Registration.ipynb b/docs/Intro_to_Medical_Image_Registration.ipynb index d1ffe363a..cee489f2d 100644 --- a/docs/Intro_to_Medical_Image_Registration.ipynb +++ b/docs/Intro_to_Medical_Image_Registration.ipynb @@ -98,7 +98,7 @@ "# Make a directory \"MICCAI_2020_reg_tutorial\"\n", "if not os.path.exists(\"./MICCAI_2020_reg_tutorial\"):\n", " os.makedirs(\"./MICCAI_2020_reg_tutorial\")\n", - "# Move into the dir\n", + "# Move into the dir\n", "os.chdir(\"./MICCAI_2020_reg_tutorial\")\n", "print(os.getcwd())" ], @@ -123,12 +123,12 @@ "colab": {} }, "source": [ - "# Clone the DeepReg repository which contains the code\n", + "# Clone the DeepReg repository which contains the code\n", "! git clone https://github.com/DeepRegNet/DeepReg\n", "%cd ./DeepReg/\n", "# Switch to a fixed version\n", "! git checkout tags/miccai2020-challenge\n", - "# pip install into the notebook env\n", + "# pip install into the notebook env\n", "! pip install -e . --no-cache-dir\n", "print(os.getcwd())" ], @@ -143,7 +143,7 @@ "colab": {} }, "source": [ - "# We import some utility modules.\n", + "# We import some utility modules.\n", "import nibabel\n", "import tensorflow as tf \n", "import deepreg.model.layer as layer\n", @@ -156,7 +156,7 @@ "import numpy as np\n", "from tensorflow.keras.utils import get_file\n", "\n", - "# We set the plot size to some parameters.\n", + "# We set the plot size to some parameters.\n", "plt.rcParams[\"figure.figsize\"] = (100,100)\n", "print(os.getcwd())\n", "if not os.getcwd() == \"/content/MICCAI_2020_reg_tutorial/DeepReg\":\n", @@ -419,7 +419,7 @@ "colab": {} }, "source": [ - "# We define a function to visualise the results of the overlap for label based loss\n", + "# We define a function to visualise the results of the overlap for label based loss\n", "from skimage.color import label2rgb\n", "def pred_label_comparison(pred, mask, shape_pred, thresh=0.5):\n", " \"\"\"\n", @@ -444,27 +444,27 @@ " colour coded to show areas of intersections between\n", " masks and predictions.\n", " \"\"\"\n", - " # Create outrue_posut np.array to store images\n", + " # Create outrue_posut np.array to store images\n", " label = np.zeros((shape_pred[0], shape_pred[1], shape_pred[2], 3))\n", "\n", - " # Thresholding pred\n", + " # Thresholding pred\n", " pred_thresh = pred > thresh\n", "\n", - " # Creating inverse to the masks and predictions\n", + " # Creating inverse to the masks and predictions\n", " mask_not = np.logical_not(mask)\n", " pred_not = np.logical_not(pred_thresh)\n", "\n", - " # Finding intersections\n", + " # Finding intersections\n", " true_pos_array = np.logical_and(pred_thresh, mask)\n", " false_pos_array = np.logical_and(pred_thresh, mask_not)\n", " false_neg_array = np.logical_and(pred_not, mask)\n", "\n", - " # Labelling via color\n", + " # Labelling via color\n", " false_pos_labels = 2*false_pos_array # green\n", - " false_neg_labels = 3*false_neg_array # red\n", + " false_neg_labels = 3*false_neg_array # red\n", " label_array = true_pos_array + false_pos_labels + false_neg_labels\n", "\n", - " # Compare all preds to masks\n", + " # Compare all preds to masks\n", " for i in range(shape_pred[0]):\n", " label[i, :, :, :] = label2rgb(\n", " label_array[i, :, :],\n", @@ -486,7 +486,7 @@ "colab": {} }, "source": [ - "# We set the plot size to some parameters.\n", + "# We set the plot size to some parameters.\n", "plt.rcParams[\"figure.figsize\"] = (10,10)\n", "\n", "# Opening the file\n", @@ -498,7 +498,7 @@ "\n", "# Getting the 0th slice in the tensor\n", "fixed_image_0 = fixed_image[0, ..., 0]\n", - "# Getting the 0th slice foreground label, at index 1 in the label tensor.\n", + "# Getting the 0th slice foreground label, at index 1 in the label tensor.\n", "fixed_label_0 = fixed_labels[0, ..., 0, 1]" ], "execution_count": null, @@ -529,17 +529,17 @@ "colab": {} }, "source": [ - "# Simulate a warped label\n", - "# The following function generates a random transform.\n", + "# Simulate a warped label\n", + "# The following function generates a random transform.\n", "transform_random = layer_util.random_transform_generator(batch_size=1, scale=0.02, seed=4)\n", "\n", - "# We create a reference grid of image size\n", + "# We create a reference grid of image size\n", "grid_ref = layer_util.get_reference_grid(grid_size=fixed_labels.shape[1:4])\n", "\n", "# We warp our reference grid by our random transform\n", "grid_random = layer_util.warp_grid(grid_ref, transform_random)\n", "# We resample the fixed image with the random transform to create a distorted\n", - "# image, which we will use as our moving image.\n", + "# image, which we will use as our moving image.\n", "moving_label = layer_util.resample(vol=fixed_labels, loc=grid_random)[0, ..., 0, 1]\n", "moving_image = layer_util.resample(vol=fixed_image, loc=grid_random)[0, ..., 0]\n", "\n", @@ -567,7 +567,7 @@ "colab": {} }, "source": [ - "# We set the plot size to some parameters.\n", + "# We set the plot size to some parameters.\n", "plt.rcParams[\"figure.figsize\"] = (10,10)\n", "\n", "comparison = pred_label_comparison(np.expand_dims(moving_label, axis=0), np.expand_dims(fixed_label_0, axis=0), [1, 128, 128], thresh=0.1)\n", @@ -600,7 +600,7 @@ }, "source": [ "from deepreg.model.loss.label import dice_score\n", - "# Calculating dice - we need [batch, dim1, dim2, dim3], so we expand the labels axes'\n", + "# Calculating dice - we need [batch, dim1, dim2, dim3], so we expand the labels axes'\n", "batch_moving_label = np.expand_dims(np.expand_dims(moving_label, axis=0), axis=-1)\n", "batch_fixed_label = np.expand_dims(np.expand_dims(fixed_label_0, axis=0), axis=-1)\n", "\n", @@ -670,11 +670,11 @@ "colab": {} }, "source": [ - "# Illustrate intensity based loss\n", + "# Illustrate intensity based loss\n", "from deepreg.model.loss.image import ssd\n", "\n", - "# The ssd function requires [batch, dim1, dim2, dim3, ch] sized tensors -\n", - "# expand the image dims as ours are [batch, dim1, dim2, dim3]\n", + "# The ssd function requires [batch, dim1, dim2, dim3, ch] sized tensors -\n", + "# expand the image dims as ours are [batch, dim1, dim2, dim3]\n", "moving_image = layer_util.resample(vol=fixed_image, loc=grid_random)\n", "\n", "fig, axs = plt.subplots(1, 2)\n", @@ -762,7 +762,7 @@ "colab": {} }, "source": [ - "# We set the plot size to some parameters.\n", + "# We set the plot size to some parameters.\n", "plt.rcParams[\"figure.figsize\"] = (100,100)" ], "execution_count": null, @@ -776,7 +776,7 @@ "colab": {} }, "source": [ - "# We define some utility functions first\n", + "# We define some utility functions first\n", "## optimisation\n", "@tf.function\n", "def train_step_CT(grid, weights, optimizer, mov, fix):\n", @@ -793,16 +793,16 @@ " :param fix: fixed image, tensor shape[1, f_dim1, f_dim2, f_dim3]\n", " :return loss: image dissimilarity to minimise\n", " \"\"\"\n", - " # We initialise an instance of gradient tape to track operations\n", + " # We initialise an instance of gradient tape to track operations\n", " with tf.GradientTape() as tape:\n", " pred = layer_util.resample(vol=mov, loc=layer_util.warp_grid(grid, weights))\n", - " # Calculate the loss function between the fixed image\n", - " # and the moving image\n", + " # Calculate the loss function between the fixed image\n", + " # and the moving image\n", " loss = image_loss.dissimilarity_fn(\n", " y_true=fix, y_pred=pred, name=image_loss_name\n", " )\n", " gradients = tape.gradient(loss, [weights])\n", - " # Applying the gradients\n", + " # Applying the gradients\n", " optimizer.apply_gradients(zip(gradients, [weights]))\n", " return loss\n", "\n", @@ -816,7 +816,7 @@ " \"\"\"\n", " # Display\n", " plt.figure()\n", - " # Generate a nIdx images in 3s\n", + " # Generate a nIdx images in 3s\n", " for idx in range(nIdx):\n", " axs = plt.subplot(nIdx, 3, 3 * idx + 1)\n", " axs.imshow(moving_image[0, ..., idx_slices[idx]], cmap=\"gray\")\n", @@ -862,7 +862,7 @@ "colab": {} }, "source": [ - "# We re-use the data from the head and neck CT we used to illustrate the losses, so we don't have to redownload it.\n", + "# We re-use the data from the head and neck CT we used to illustrate the losses, so we don't have to redownload it.\n", "\n", "## registration parameters\n", "image_loss_name = \"ssd\"\n", @@ -881,23 +881,23 @@ "\n", "# generate a radomly-affine-transformed moving image using DeepReg utils\n", "fixed_image_size = fixed_image.shape\n", - "# The following function generates a random transform.\n", + "# The following function generates a random transform.\n", "transform_random = layer_util.random_transform_generator(batch_size=1, scale=0.2)\n", "\n", - "# We create a reference grid of image size\n", + "# We create a reference grid of image size\n", "grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size[1:4])\n", "\n", "# We warp our reference grid by our random transform\n", "grid_random = layer_util.warp_grid(grid_ref, transform_random)\n", "# We resample the fixed image with the random transform to create a distorted\n", - "# image, which we will use as our moving image.\n", + "# image, which we will use as our moving image.\n", "moving_image = layer_util.resample(vol=fixed_image, loc=grid_random)\n", "\n", "# warp the labels to get ground-truth using the same random affine transform\n", - "# for validation\n", + "# for validation\n", "fixed_labels = tf.cast(tf.expand_dims(fid[\"label\"], axis=0), dtype=tf.float32)\n", - "# We have multiple labels, so we apply the transform to all the labels by\n", - "# stacking them\n", + "# We have multiple labels, so we apply the transform to all the labels by\n", + "# stacking them\n", "moving_labels = tf.stack(\n", " [\n", " layer_util.resample(vol=fixed_labels[..., idx], loc=grid_random)\n", @@ -915,12 +915,12 @@ " trainable=True,\n", ")\n", "\n", - "# We perform an optimisation by backpropagating the loss through to our \n", - "# trainable weight layer.\n", + "# We perform an optimisation by backpropagating the loss through to our \n", + "# trainable weight layer.\n", "optimiser = tf.optimizers.Adam(learning_rate)\n", "\n", "\n", - "# Perform an optimisation for total_iter number of steps.\n", + "# Perform an optimisation for total_iter number of steps.\n", "for step in range(total_iter):\n", " loss_opt = train_step_CT(grid_ref, var_affine, optimiser, moving_image, fixed_image)\n", " if (step % 50) == 0: # print info\n", @@ -977,7 +977,7 @@ "colab": {} }, "source": [ - "# Check how the labels have been registered\n", + "# Check how the labels have been registered\n", "warped_moving_labels = layer_util.resample(vol=moving_labels, loc=grid_opt)\n", "\n", "# display\n", @@ -1046,17 +1046,17 @@ " \"\"\"\n", " with tf.GradientTape() as tape:\n", " pred = warper(inputs=[weights, mov])\n", - " # Calculating the image loss between the ground truth and prediction\n", + " # Calculating the image loss between the ground truth and prediction\n", " loss_image = image_loss.dissimilarity_fn(\n", " y_true=fix, y_pred=pred, name=image_loss_name\n", " )\n", " # We calculate the deformation loss\n", " loss_deform = deform_loss.local_displacement_energy(weights, deform_loss_name)\n", - " # Total loss is weighted\n", + " # Total loss is weighted\n", " loss = loss_image + weight_deform_loss * loss_deform\n", - " # We calculate the gradients by backpropagating the loss to the trainable layer\n", + " # We calculate the gradients by backpropagating the loss to the trainable layer\n", " gradients = tape.gradient(loss, [weights])\n", - " # Using our tf optimizer, we apply the gradients\n", + " # Using our tf optimizer, we apply the gradients\n", " optimizer.apply_gradients(zip(gradients, [weights]))\n", " return loss, loss_image, loss_deform" ], @@ -1071,7 +1071,7 @@ "colab": {} }, "source": [ - "## We download the data for this example.\n", + "## We download the data for this example.\n", "MAIN_PATH = os.getcwd()\n", "\n", "DATA_PATH = \"dataset\"\n", @@ -1105,10 +1105,10 @@ "source": [ "## We define some registration parameters - play around with these!\n", "image_loss_name = \"lncc\" # local normalised cross correlation loss between images\n", - "deform_loss_name = \"bending\" # Loss to measure the bending energy of the ddf\n", - "weight_deform_loss = 1 # we weight the deformation loss\n", + "deform_loss_name = \"bending\" # Loss to measure the bending energy of the ddf\n", + "weight_deform_loss = 1 # we weight the deformation loss\n", "learning_rate = 0.1\n", - "total_iter = int(3001) # This will train for longer" + "total_iter = int(3001) # This will train for longer" ], "execution_count": null, "outputs": [] @@ -1121,7 +1121,7 @@ "colab": {} }, "source": [ - "# We get our two subject images from our datasets\n", + "# We get our two subject images from our datasets\n", "moving_image = tf.cast(tf.expand_dims(fid[\"image0\"], axis=0), dtype=tf.float32)\n", "fixed_image = tf.cast(tf.expand_dims(fid[\"image1\"], axis=0), dtype=tf.float32)" ], @@ -1136,24 +1136,24 @@ "colab": {} }, "source": [ - "# We initialise our layers\n", + "# We initialise our layers\n", "fixed_image_size = fixed_image.shape\n", "initialiser = tf.random_normal_initializer(mean=0, stddev=1e-3)\n", "\n", - "# Creating our DDF tensor that can be trained\n", + "# Creating our DDF tensor that can be trained\n", "# The DDF will be of shape [IM_SIZE_1, IM_SIZE_2, 3],\n", - "# representing the displacement field at each pixel and xyz dimension.\n", + "# representing the displacement field at each pixel and xyz dimension.\n", "var_ddf = tf.Variable(initialiser(fixed_image_size + [3]), name=\"ddf\", trainable=True)\n", "\n", - "# We create a warping layer and initialise an optimizer\n", + "# We create a warping layer and initialise an optimizer\n", "warping = layer.Warping(fixed_image_size=fixed_image_size[1:4])\n", "optimiser = tf.optimizers.Adam(learning_rate)\n", "\n", "\n", "## Optimising the layer\n", - "## With GPU this takes about 5 minutes.\n", + "## With GPU this takes about 5 minutes.\n", "for step in range(total_iter):\n", - " # Call the gradient tape function\n", + " # Call the gradient tape function\n", " loss_opt, loss_image_opt, loss_deform_opt = train_step(\n", " warping, var_ddf, optimiser, moving_image, fixed_image\n", " )\n", @@ -1168,7 +1168,7 @@ " deform_loss_name,\n", " loss_deform_opt,\n", " )\n", - " # Visualising loss during training\n", + " # Visualising loss during training\n", " # plt.figure()\n", " # fig, axs = plt.subplots(1, 3)\n", " # warped_moving_image = warping(inputs=[var_ddf, moving_image])\n", @@ -1206,7 +1206,7 @@ }, "source": [ "## We can observe the effects of the warping on the moving label using\n", - "# the optimised affine transformation\n", + "# the optimised affine transformation\n", "moving_label = tf.cast(tf.expand_dims(fid[\"label0\"], axis=0), dtype=tf.float32)\n", "fixed_label = tf.cast(tf.expand_dims(fid[\"label1\"], axis=0), dtype=tf.float32)\n", "\n", @@ -1394,7 +1394,7 @@ "\n", "gpu = \"0\"\n", "gpu_allow_growth = False\n", - "# This will take a couple of minutes\n", + "# This will take a couple of minutes\n", "predict(\n", " gpu=gpu,\n", " gpu_allow_growth=gpu_allow_growth,\n", diff --git a/docs/source/assets/grouped_mask_prostate_longitudinal.png b/docs/source/assets/grouped_mask_prostate_longitudinal.png index cc7b651eb..210bcfda4 100644 Binary files a/docs/source/assets/grouped_mask_prostate_longitudinal.png and b/docs/source/assets/grouped_mask_prostate_longitudinal.png differ diff --git a/docs/source/assets/grouped_mr_heart.png b/docs/source/assets/grouped_mr_heart.png index 75907a857..8df73b8fb 100644 Binary files a/docs/source/assets/grouped_mr_heart.png and b/docs/source/assets/grouped_mr_heart.png differ diff --git a/docs/source/assets/paired_ct_lung.png b/docs/source/assets/paired_ct_lung.png index 040a8cea4..b9f011e56 100644 Binary files a/docs/source/assets/paired_ct_lung.png and b/docs/source/assets/paired_ct_lung.png differ diff --git a/docs/source/assets/paired_mrus_prostate.png b/docs/source/assets/paired_mrus_prostate.png index fe730a079..31bfa7821 100644 Binary files a/docs/source/assets/paired_mrus_prostate.png and b/docs/source/assets/paired_mrus_prostate.png differ diff --git a/docs/source/assets/unpaired_ct_lung.png b/docs/source/assets/unpaired_ct_lung.png index e91462c2c..e1328556a 100644 Binary files a/docs/source/assets/unpaired_ct_lung.png and b/docs/source/assets/unpaired_ct_lung.png differ diff --git a/docs/source/assets/unpaired_us_prostate_cv.png b/docs/source/assets/unpaired_us_prostate_cv.png index 3fbe1a92b..9f7706475 100644 Binary files a/docs/source/assets/unpaired_us_prostate_cv.png and b/docs/source/assets/unpaired_us_prostate_cv.png differ diff --git a/docs/source/docs/configuration.md b/docs/source/docs/configuration.md index 78d1d2128..c82820856 100644 --- a/docs/source/docs/configuration.md +++ b/docs/source/docs/configuration.md @@ -1,9 +1,9 @@ # Configuration File -Besides the arguments provided to the command line tools, detailed training and -prediction configuration is specified in a `yaml` file. The configuration file contains +In addition to the arguments provided to the command line tools, detailed training and +prediction configuration is specified in a `YAML` file. The configuration file contains two sections, `dataset` and `train`. Within `dataset` one specifies the data file -formas, sizes, as well as the data loader to use. The `train` section specifies +formats, sizes, as well as the data loader to use. The `train` section specifies parameters related to the neural network. ## Dataset section @@ -11,7 +11,7 @@ parameters related to the neural network. The `dataset` section specifies the path to the data to be used during training, the data loader to use as well as the specific arguments to configure the data loader. -### Dir key - Required +### Dir key - Required The paths to the training, validation and testing data are specified under a `dir` dictionary key like this: @@ -19,9 +19,9 @@ dictionary key like this: ```yaml dataset: dir: - train: "data/test/h5/paired/train" # folder contains training data - valid: "data/test/h5/paired/valid" # folder contains validation data - test: "data/test/h5/paired/test" # folder contains test data + train: "data/test/h5/paired/train" # folder containing training data + valid: "data/test/h5/paired/valid" # folder containing validation data + test: "data/test/h5/paired/test" # folder containing test data ``` Multiple dataset directories can be specified, such that data are sampled across several @@ -39,18 +39,18 @@ dataset: ### Format key - Required -The data file format we supply the data loaders will influence the behaviour, so we must +The data file format we supply the data loaders will influence the behavior, so we must specify the data file format using the `format` key. Currently, DeepReg data loaders -support nifti and h5 file types - alternate file formats will raise errors in the data +support Nifti and H5 file types - alternate file formats will raise errors in the data loaders. To indicate which format to use, pass a string to this field as either "nifti" or "h5": ```yaml dataset: dir: - train: "data/test/h5/paired/train" # folders contains training data - valid: "data/test/h5/paired/valid" # folder contains validation data - test: "data/test/h5/paired/test" # folder contains test data + train: "data/test/h5/paired/train" # folder containing training data + valid: "data/test/h5/paired/valid" # folder containing validation data + test: "data/test/h5/paired/test" # folder containing test data format: "nifti" ``` @@ -58,7 +58,7 @@ Depending on the data file format, DeepReg expects the images and labels to be s specific structures: check the [data loader configuration](dataset_loader.html) for more details. -### Labeled key - Required +### Labeled key - Required The `labeled` key indicates whether segmentation labels should be used during training. A Boolean is used to indicate the usage of labels: @@ -66,9 +66,9 @@ A Boolean is used to indicate the usage of labels: ```yaml dataset: dir: - train: "data/test/h5/paired/train1" # folders contains training data - valid: "data/test/h5/paired/valid" # folder contains validation data - test: "data/test/h5/paired/test" # folder contains test data + train: "data/test/h5/paired/train" # folder containing training data + valid: "data/test/h5/paired/valid" # folder containing validation data + test: "data/test/h5/paired/test" # folder containing test data format: "nifti" labeled: true ``` @@ -79,26 +79,26 @@ available in the associated directories. ### Type key - Required The type of data loader used will depend on how one wants to train the network. -Currently, DeepReg data loaders support the `paired`, `unpaired` and `grouped` training +Currently, DeepReg data loaders support the `paired`, `unpaired`, and `grouped` training strategies. Passing a string that doesn't match any of the above would raise an error. The data loader type would be specified using the `type` key: ```yaml dataset: dir: - train: "data/test/h5/paired/train1" # folders contains training data - valid: "data/test/h5/paired/valid" # folder contains validation data - test: "data/test/h5/paired/test" # folder contains test data + train: "data/test/h5/paired/train" # folder containing training data + valid: "data/test/h5/paired/valid" # folder containing validation data + test: "data/test/h5/paired/test" # folder containing test data format: "nifti" - type: "paired" # one of "paired", "unpaired" or "grouped" + type: "paired" # one of "paired", "unpaired" or "grouped" ``` -#### Data loader dependent keys +#### Data loader dependent keys -Depending on which string is passed to the `type` key, DeepReg will initialise a +Depending on which string is passed to the `type` key, DeepReg will initialize a different data loader instance with different sampling strategies. These are described in depth in the [dataset loader configuration](dataset_loader.html) documentation. Here -we outline the arguments necessary to configure the different dataloaders. +we outline the arguments necessary to configure the different data loaders. ###### Sample_label - Required @@ -118,13 +118,13 @@ be built to sample `all` the data-label pairs, regardless of the argument passed ```yaml dataset: dir: - train: "data/test/h5/paired/train" # folders contains training data - valid: "data/test/h5/paired/valid" # folder contains validation data - test: "data/test/h5/paired/test" # folder contains test data + train: "data/test/h5/paired/train" # folder containing training data + valid: "data/test/h5/paired/valid" # folder containing validation data + test: "data/test/h5/paired/test" # folder containing test data format: "nifti" - type: "paired" # one of "paired", "unpaired" or "grouped" + type: "paired" # one of "paired", "unpaired" or "grouped" labeled: true - sample_label: "sample" # one of "sample", "all" or None + sample_label: "sample" # one of "sample", "all" or None ``` In the case the `labeled` argument is false, the sample_label is unused, but still must @@ -145,18 +145,18 @@ For more details please refer to ```yaml dataset: dir: - train: "data/test/h5/paired/train1" # folders contains training data - valid: "data/test/h5/paired/valid" # folder contains validation data - test: "data/test/h5/paired/test" # folder contains test data + train: "data/test/h5/paired/train" # folder containing training data + valid: "data/test/h5/paired/valid" # folder containing validation data + test: "data/test/h5/paired/test" # folder containing test data format: "nifti" - type: "paired" # one of "paired", "unpaired" or "grouped" + type: "paired" # one of "paired", "unpaired" or "grouped" labeled: true - sample_label: "sample" # one of "sample", "all" or None + sample_label: "sample" # one of "sample", "all" or None moving_image_shape: [16, 16, 3] fixed_image_shape: [16, 16, 3] ``` -##### Unpaired +##### Unpaired - `image_shape`: (list, tuple) of ints, len 3, corresponding to (dim1, dim2, dim3) of the 3D image. @@ -164,13 +164,13 @@ dataset: ```yaml dataset: dir: - train: "data/test/h5/paired/train1" # folders contains training data - valid: "data/test/h5/paired/valid" # folder contains validation data - test: "data/test/h5/paired/test" # folder contains test data + train: "data/test/h5/paired/train" # folder containing training data + valid: "data/test/h5/paired/valid" # folder containing validation data + test: "data/test/h5/paired/test" # folder containing test data format: "nifti" - type: "unpaired" # one of "paired", "unpaired" or "grouped" + type: "unpaired" # one of "paired", "unpaired" or "grouped" labeled: true - sample_label: "sample" # one of "sample", "all" or None + sample_label: "sample" # one of "sample", "all" or None image_shape: [16, 16, 3] ``` @@ -189,13 +189,13 @@ dataset: ```yaml dataset: dir: - train: "data/test/h5/paired/train1" # folders contains training data - valid: "data/test/h5/paired/valid" # folder contains validation data - test: "data/test/h5/paired/test" # folder contains test data + train: "data/test/h5/paired/train" # folder containing training data + valid: "data/test/h5/paired/valid" # folder containing validation data + test: "data/test/h5/paired/test" # folder containing test data format: "nifti" - type: "grouped" # one of "paired", "unpaired" or "grouped" + type: "grouped" # one of "paired", "unpaired" or "grouped" labeled: true - sample_label: "sample" # one of "sample", "all" or None + sample_label: "sample" # one of "sample", "all" or None image_shape: [16, 16, 3] sample_image_in_group: true intra_group_prob: 0.7 @@ -210,19 +210,21 @@ The `train` section defines the neural network training hyper-parameters, by spe subsections, `method`, `backbone`, `loss`, `optimizer`, `preprocess` and other training hyper-parameters, including `epochs` and `save_period`. -### Method - required +### Method - required -The `method` argument defines the registration type. It must be a string type, one of -"ddf", "dvf", "conditional", which are the currently supported registration methods. +The `method` argument defines the registration type. It must be a string. Feasible +values are: `ddf`, `dvf`, and `conditional`, corresponding to the dense displacement +field (DDF) based model, dense velocity field (DDF) based model, and conditional model +presented in the [registration tutorial](../tutorial/registration.html). ```yaml train: method: "ddf" # One of ddf, dvf, conditional ``` -### Backbone - required +### Backbone - required -The `backbone` subsection is used to define the network, with all the network specific +The `backbone` subsection is used to define the network, with all the network-specific arguments under the same indent. The first argument should be the argument `name`, which should be string type, one of "unet", "local" or "global", to define a UNet, LocalNet or GlobalNet backbone, respectively. With Registry functionalities, you can also define @@ -241,10 +243,10 @@ train: #### UNet -The UNet model requires several additional arguments to define it's structure: +The UNet model requires several additional arguments to define its structure: - `depth`: int, defines the depth of the UNet from first to bottom, bottleneck layer. -- `pooling`: Boolean, pooling method used for downsampling. True: non-parametrized +- `pooling`: Boolean, pooling method used for down-sampling. True: non-parametrized pooling will be used, False: conv3d will be used. - `concat_skip`: Boolean, concatenation method for skip layers in UNet. True: concatenation of layers, False: addition is used instead. @@ -260,9 +262,9 @@ train: concat_skip: true ``` -#### Local and GlobalNet +#### Local and GlobalNet -The LocalNet has an encoder-decoder structure, and extracts information from tensors at +The LocalNet has an encoder-decoder structure and extracts information from tensors at one or multiple resolution levels. We can define which levels to extract info from with the `extract_levels` argument. @@ -282,38 +284,41 @@ train: extract_levels: [0, 1, 2] ``` -### Loss - required +### Loss - required This section defines the loss in training. -The losses in DeepReg are defined depending on the type of network to be built, and can -be split into three sections: image and label losses (between moving and fixed tensors), -and regularization losses (on the DDFs predicted). +There are three different categories of losses in DeepReg: -DeepReg uses `tf.keras.Model` `add_loss()` in the Registry method to add losses to the -model, which provides some flexibility in configuration. +- **image loss**: loss between the fixed image and predicted fixed image (warped moving + image). +- **label loss**: loss between the fixed label and predicted fixed label (warped moving + label). +- **regularization loss**: loss on predicted dense displacement field (DDF). -Currently, DeepReg offers conditional, ddf/dvf and affine registration pre-built models. -Traditionally, models have been configured with the following losses: +Not all losses are applicable for all models, the details are in the following table. -- Conditional: label loss. -- DDF/DVF: ddf loss, image loss, label loss. -- Affine: ddf loss, image loss, label loss. +| | DDF / DVF | Conditional | +| ------------------- | ------------------------------ | -------------- | +| Image Loss | Applicable | Non-applicable | +| Label Loss | Applicable if data are labeled | Applicable | +| Regularization Loss | Applicable | Non-applicable | -The above sections are necessary to build a model correctly. Not passing all sections as -defined above may raise errors. Currently you can call one loss per field eg. one label -loss type, one image loss type and one ddf/dvf loss type. +The configuration for non-applicable losses will be ignored without errors. The loss +will also be ignored if the weight is zero. However, each model must define at least one +loss, otherwise error will be raised by TensorFlow. -However, setting the weight to 0 will effectively mean the model ignores the loss. -Additionally, weights on label loss will be ignored if the `labeled` key is false or if -segmentation labels are unavailable. +For each loss, there are multiple existing loss functions to choose. The registry +mechanism can also be used to use custom loss functions. Please read the +[registry documentation](registry.html) for more details. -#### Image +#### Image The image loss calculates dissimilarity between warped image tensors and fixed image tensors. -- `weight`: float type, weight of individual loss element in total loss function. +- `weight`: float type, the weight of individual loss element in the total loss + function. - `name`: string type, one of "lncc", "ssd" or "gmi". ```yaml @@ -347,7 +352,7 @@ same indent level: - `sigma_ratio`: float, optional, default=0.5. A hyperparameter for the Gaussian kernel density estimation. -#### Label +#### Label The label loss calculates dissimilarity between labels. @@ -401,12 +406,12 @@ at the same indent level: 0.5. -#### Regularization +#### Regularization The regularization section configures the losses for the DDF. To instantiate this part of the loss, pass "regularization" into the config file as a field. -- `weight`: float type, weight of the regularization loss. +- `weight`: float type, the weight of the regularization loss. - `name`: string type, the type of deformation energy to compute. Options include "bending", "gradient" @@ -465,7 +470,7 @@ train: name: "dice" ``` -### Optimizer - required +### Optimizer - required The optimizer can be defined by using a `name` and then passing optimizer specific arguments with the same name. All optimizers can use the `learning_rate` argument. @@ -474,17 +479,17 @@ arguments with the same name. All optimizers can use the `learning_rate` argumen "sgd", "rms". - `adam`: If adam is passed into `name`, the `adam` field must be passed. The dictionary - can be empty, which initalises a default + can be empty, which initializes a default [Keras Adam optimizer](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam). Alternatively, fields with names equivalent to those specified in the optimizer documentation can be used. - `sgd`: If sgd is passed into `name`, the `sgd` field must be passed. The dictionary - can be empty, which initalises a default + can be empty, which initializes a default [Keras SGD optimizer](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD). Alternatively, fields with names equivalent to those specified in the optimizer documentation can be used instead. - `rms`: If rms is passed into `name`, the `rms` field must be passed. The dictionary - can be empty, which initalises a default + can be empty, which initializes a default [Keras RMSprop optimizer](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop). Alternatively, fields with names equivalent to those specified in the optimizer documentation can be used instead. @@ -526,9 +531,9 @@ train: nesterov: false ``` -### Preprocess - required +### Preprocess - required -The `preprocess` field defines how the dataloader feeds data into the model. +The `preprocess` field defines how the data loader feeds data into the model. - `batch_size`: int, the batch size to pass to the network on each training step. - `shuffle_buffer_num_batch`: int, helps define how much data should be pre-loaded into @@ -559,7 +564,7 @@ train: ### Epochs - required -The `epochs` field defines number of epochs to train the network for. +The `epochs` field defines the number of epochs to train the network for. ```yaml train: @@ -584,7 +589,7 @@ train: epochs: 1000 ``` -### Saving frequency - required +### Saving frequency - required The `save_period` field defines the save frequency - the model will be saved every `save_period` epochs. diff --git a/test/unit/test_backbone.py b/test/unit/test_backbone.py index 56e1e2c18..58b3b4a14 100644 --- a/test/unit/test_backbone.py +++ b/test/unit/test_backbone.py @@ -78,7 +78,8 @@ def test_call_global_net(): is correct. """ out = 3 - im_size = [1, 2, 3] + im_size = (1, 2, 3) + batch_size = 5 # initialising GlobalNet instance global_test = g.GlobalNet( image_size=im_size, @@ -90,12 +91,14 @@ def test_call_global_net(): ) # pass an input of all zeros inputs = tf.constant( - np.zeros((5, im_size[0], im_size[1], im_size[2], out), dtype=np.float32) + np.zeros( + (batch_size, im_size[0], im_size[1], im_size[2], out), dtype=np.float32 + ) ) # get outputs by calling - output = global_test.call(inputs) - # expected shape is (5, 1, 2, 3, 3) - assert all(x == y for x, y in zip(inputs.shape, output.shape)) + ddf, theta = global_test.call(inputs) + assert ddf.shape == (batch_size, *im_size, 3) + assert theta.shape == (batch_size, 4, 3) class TestLocalNet: diff --git a/test/unit/test_loss_image.py b/test/unit/test_loss_image.py index 842c57f3f..22411ad5d 100644 --- a/test/unit/test_loss_image.py +++ b/test/unit/test_loss_image.py @@ -60,7 +60,7 @@ def test_get_config(self): expected = dict( num_bins=23, sigma_ratio=0.5, - reduction=tf.keras.losses.Reduction.AUTO, + reduction=tf.keras.losses.Reduction.SUM, name="GlobalMutualInformation", ) assert got == expected @@ -114,7 +114,7 @@ def test_get_config(self): expected = dict( kernel_size=9, kernel_type="rectangular", - reduction=tf.keras.losses.Reduction.AUTO, + reduction=tf.keras.losses.Reduction.SUM, name="LocalNormalizedCrossCorrelation", ) assert got == expected diff --git a/test/unit/test_loss_label.py b/test/unit/test_loss_label.py index 18a88c23b..e6c9b597f 100644 --- a/test/unit/test_loss_label.py +++ b/test/unit/test_loss_label.py @@ -44,7 +44,7 @@ def test_get_config(self): expected = dict( scales=None, kernel="gaussian", - reduction=tf.keras.losses.Reduction.AUTO, + reduction=tf.keras.losses.Reduction.SUM, name="MultiScaleLoss", ) assert got == expected @@ -89,7 +89,7 @@ def test_get_config(self): neg_weight=0.0, scales=None, kernel="gaussian", - reduction=tf.keras.losses.Reduction.AUTO, + reduction=tf.keras.losses.Reduction.SUM, name="DiceScore", ) assert got == expected @@ -130,7 +130,7 @@ def test_get_config(self): neg_weight=0.0, scales=None, kernel="gaussian", - reduction=tf.keras.losses.Reduction.AUTO, + reduction=tf.keras.losses.Reduction.SUM, name="CrossEntropy", ) assert got == expected @@ -173,7 +173,7 @@ def test_get_config(self): binary=False, scales=None, kernel="gaussian", - reduction=tf.keras.losses.Reduction.AUTO, + reduction=tf.keras.losses.Reduction.SUM, name="JaccardIndex", ) assert got == expected diff --git a/test/unit/test_network.py b/test/unit/test_network.py new file mode 100644 index 000000000..9fabb6171 --- /dev/null +++ b/test/unit/test_network.py @@ -0,0 +1,299 @@ +# coding=utf-8 + +""" +Tests for deepreg/_model/network/ddf_dvf.py +""" +import itertools +from copy import deepcopy +from unittest.mock import MagicMock, patch + +import pytest + +from deepreg.model.network import RegistrationModel +from deepreg.registry import REGISTRY + +moving_image_size = (1, 3, 5) +fixed_image_size = (2, 4, 6) +index_size = 2 +batch_size = 3 +backbone_args = { + "local": {"extract_levels": [1, 2]}, + "global": {"extract_levels": [1, 2]}, + "unet": {"depth": 2}, +} +config = { + "backbone": { + "num_channel_initial": 4, + }, + "loss": { + "image": {"name": "lncc", "weight": 0.1}, + "label": { + "name": "dice", + "weight": 1, + "scales": [0, 1], + }, + "regularization": {"weight": 0.1, "name": "bending"}, + }, +} + + +@pytest.fixture +def model(method: str, labeled: bool, backbone: str) -> RegistrationModel: + """ + A specific registration model object. + + :param method: name of method + :param labeled: whether the data is labeled + :param backbone: name of backbone + :return: the built object + """ + copied = deepcopy(config) + copied["method"] = method + copied["backbone"]["name"] = backbone + copied["backbone"] = {**backbone_args[backbone], **copied["backbone"]} + return REGISTRY.build_model( + config=dict( + name=method, # TODO we store method twice + moving_image_size=moving_image_size, + fixed_image_size=fixed_image_size, + index_size=index_size, + labeled=labeled, + batch_size=batch_size, + config=copied, + ) + ) + + +def pytest_generate_tests(metafunc): + """ + Test parameter generator. + + This function is called once per each test function. + It takes the attribute `params` from the test class, + and then use the same `params` for all tests inside the class. + This is specific for test of registration models only. + + This is modified from the pytest documentation, + where their version defined the params for each test function separately. + + https://docs.pytest.org/en/stable/example/parametrize.html#parametrizing-test-methods-through-per-class-configuration + + :param metafunc: + :return: + """ + # + funcarglist = metafunc.cls.params + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, [[funcargs[name] for name in argnames] for funcargs in funcarglist] + ) + + +class TestRegistrationModel: + params = [dict(labeled=True), dict(labeled=False)] + + @pytest.fixture + def empty_model(self, labeled: bool) -> RegistrationModel: + """ + A RegistrationModel with build_model and build_loss mocked/overwritten. + + :param labeled: whether the data is labeled + :return: the mocked object + """ + with patch.multiple( + RegistrationModel, + build_model=MagicMock(return_value=None), + build_loss=MagicMock(return_value=None), + ): + return RegistrationModel( + moving_image_size=moving_image_size, + fixed_image_size=fixed_image_size, + index_size=index_size, + labeled=labeled, + batch_size=batch_size, + config=dict(), + ) + + def test_get_config(self, empty_model, labeled): + got = empty_model.get_config() + assert got == dict( + moving_image_size=moving_image_size, + fixed_image_size=fixed_image_size, + index_size=index_size, + labeled=labeled, + batch_size=batch_size, + config=dict(), + num_devices=1, + name="RegistrationModel", + ) + + def test_build_inputs(self, empty_model, labeled): + inputs = empty_model.build_inputs() + expected_inputs_len = 5 if labeled else 3 + assert len(inputs) == expected_inputs_len + + moving_image = inputs["moving_image"] + fixed_image = inputs["fixed_image"] + indices = inputs["indices"] + assert moving_image.shape == (batch_size, *moving_image_size) + assert fixed_image.shape == (batch_size, *fixed_image_size) + assert indices.shape == (batch_size, index_size) + + if labeled: + moving_label = inputs["moving_label"] + fixed_label = inputs["fixed_label"] + assert moving_label.shape == (batch_size, *moving_image_size) + assert fixed_label.shape == (batch_size, *fixed_image_size) + + def test_concat_images(self, empty_model, labeled): + inputs = empty_model.build_inputs() + moving_image = inputs["moving_image"] + fixed_image = inputs["fixed_image"] + if labeled: + moving_label = inputs["moving_label"] + images = empty_model.concat_images(moving_image, fixed_image, moving_label) + assert images.shape == (batch_size, *fixed_image_size, 3) + else: + images = empty_model.concat_images(moving_image, fixed_image) + assert images.shape == (batch_size, *fixed_image_size, 2) + + +class TestBuildLoss: + params = [ + dict(option=0, expected=2), + dict(option=1, expected=2), + dict(option=2, expected=3), + ] + + def test_no_image_loss(self, option: int, expected: int): + method = "ddf" + backbone = "local" + labeled = True + copied = deepcopy(config) + copied["method"] = method + copied["backbone"]["name"] = backbone + copied["backbone"] = {**backbone_args[backbone], **copied["backbone"]} + + if option == 0: + # remove image loss config, so loss is not used + copied["loss"].pop("image") + elif option == 1: + # set image loss weight to zero, so loss is not used + copied["loss"]["image"]["weight"] = 0.0 + else: + # remove image loss weight, so loss is used with default weight 1 + copied["loss"]["image"].pop("weight") + + ddf_model = REGISTRY.build_model( + config=dict( + name=method, # TODO we store method twice + moving_image_size=moving_image_size, + fixed_image_size=fixed_image_size, + index_size=index_size, + labeled=labeled, + batch_size=batch_size, + config=copied, + ) + ) + + assert len(ddf_model._model.losses) == expected + + +class TestDDFModel: + params = [ + dict(method=method, labeled=labeled, backbone=backbone) + for method, labeled, backbone in itertools.product( + ["ddf"], [True, False], ["local", "global", "unet"] + ) + ] + + def test_build_model(self, model, labeled, backbone): + expected_outputs_len = 3 if labeled else 2 + if backbone == "global": + expected_outputs_len += 1 + theta = model._outputs["theta"] + assert theta.shape == (batch_size, 4, 3) + assert len(model._outputs) == expected_outputs_len + + ddf = model._outputs["ddf"] + pred_fixed_image = model._outputs["pred_fixed_image"] + assert ddf.shape == (batch_size, *fixed_image_size, 3) + assert pred_fixed_image.shape == (batch_size, *fixed_image_size) + + if labeled: + pred_fixed_label = model._outputs["pred_fixed_label"] + assert pred_fixed_label.shape == (batch_size, *fixed_image_size) + + def test_build_loss(self, model, labeled, backbone): + expected = 3 if labeled else 2 + assert len(model._model.losses) == expected + + def test_postprocess(self, model, labeled, backbone): + indices, processed = model.postprocess( + inputs=model._inputs, outputs=model._outputs + ) + assert indices.shape == (batch_size, index_size) + expected = 7 if labeled else 4 + if backbone == "global": + expected += 1 + assert len(processed) == expected + + +class TestDVFModel: + params = [ + dict(method=method, labeled=labeled, backbone=backbone) + for method, labeled, backbone in itertools.product( + ["dvf"], [True, False], ["local", "unet"] + ) + ] + + def test_build_model(self, model, labeled, backbone): + expected_outputs_len = 4 if labeled else 3 + assert len(model._outputs) == expected_outputs_len + + dvf = model._outputs["dvf"] + ddf = model._outputs["ddf"] + pred_fixed_image = model._outputs["pred_fixed_image"] + assert dvf.shape == (batch_size, *fixed_image_size, 3) + assert ddf.shape == (batch_size, *fixed_image_size, 3) + assert pred_fixed_image.shape == (batch_size, *fixed_image_size) + + if labeled: + pred_fixed_label = model._outputs["pred_fixed_label"] + assert pred_fixed_label.shape == (batch_size, *fixed_image_size) + + def test_build_loss(self, model, labeled, backbone): + expected = 3 if labeled else 2 + assert len(model._model.losses) == expected + + def test_postprocess(self, model, labeled, backbone): + indices, processed = model.postprocess( + inputs=model._inputs, outputs=model._outputs + ) + assert indices.shape == (batch_size, index_size) + expected = 8 if labeled else 5 + assert len(processed) == expected + + +class TestConditionalModel: + params = [ + dict(method=method, labeled=labeled, backbone=backbone) + for method, labeled, backbone in itertools.product( + ["conditional"], [True], ["local", "unet"] + ) + ] + + def test_build_model(self, model, labeled, backbone): + assert len(model._outputs) == 1 + pred_fixed_label = model._outputs["pred_fixed_label"] + assert pred_fixed_label.shape == (batch_size, *fixed_image_size) + + def test_build_loss(self, model, labeled, backbone): + assert len(model._model.losses) == 1 + + def test_postprocess(self, model, labeled, backbone): + indices, processed = model.postprocess( + inputs=model._inputs, outputs=model._outputs + ) + assert indices.shape == (batch_size, index_size) + assert len(processed) == 5 diff --git a/test/unit/test_network_affine.py b/test/unit/test_network_affine.py deleted file mode 100644 index 1ff0394c8..000000000 --- a/test/unit/test_network_affine.py +++ /dev/null @@ -1,98 +0,0 @@ -# coding=utf-8 - -""" -Tests for deepreg/model/network/affine.py -""" -import tensorflow as tf - -from deepreg.model.network.affine import affine_forward, build_affine_model -from deepreg.model.network.util import build_backbone -from deepreg.registry import REGISTRY - - -def test_affine_forward(): - """ - Testing that affine_forward function returns the tensors with correct shapes - """ - - moving_image_size = (1, 3, 5) - fixed_image_size = (2, 4, 6) - batch_size = 1 - - global_net = build_backbone( - image_size=fixed_image_size, - out_channels=3, - config={ - "name": "global", - "num_channel_initial": 4, - "extract_levels": [1, 2, 3], - }, - method_name="affine", - registry=REGISTRY, - ) - - # Check conditional mode network output shapes - Pass - affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = affine_forward( - backbone=global_net, - moving_image=tf.ones((batch_size,) + moving_image_size), - fixed_image=tf.ones((batch_size,) + fixed_image_size), - moving_label=tf.ones((batch_size,) + moving_image_size), - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - ) - assert affine.shape == (batch_size,) + (4,) + (3,) - assert ddf.shape == (batch_size,) + fixed_image_size + (3,) - assert pred_fixed_image.shape == (batch_size,) + fixed_image_size - assert pred_fixed_label.shape == (batch_size,) + fixed_image_size - assert grid_fixed.shape == fixed_image_size + (3,) - - -def test_build_affine_model(): - """ - Testing that build_affine_model function returns the tensors with correct shapes - """ - moving_image_size = (1, 3, 5) - fixed_image_size = (2, 4, 6) - batch_size = 1 - - model = build_affine_model( - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - index_size=1, - labeled=True, - batch_size=batch_size, - train_config={ - "method": "affine", - "backbone": { - "name": "global", - "num_channel_initial": 4, - "extract_levels": [1, 2, 3], - }, - "loss": { - "image": {"name": "lncc", "weight": 0.1}, - "label": { - "name": "dice", - "weight": 1, - "scales": [0, 1, 2, 4, 8, 16, 32], - }, - "regularization": {"weight": 0.0, "name": "bending"}, - }, - }, - registry=REGISTRY, - ) - - inputs = { - "moving_image": tf.ones((batch_size,) + moving_image_size), - "fixed_image": tf.ones((batch_size,) + fixed_image_size), - "indices": 1, - "moving_label": tf.ones((batch_size,) + moving_image_size), - "fixed_label": tf.ones((batch_size,) + fixed_image_size), - } - - outputs = model(inputs) - - expected_outputs_keys = ["affine", "ddf", "pred_fixed_label"] - assert all(keys in expected_outputs_keys for keys in outputs) - assert outputs["pred_fixed_label"].shape == (batch_size,) + fixed_image_size - assert outputs["affine"].shape == (batch_size,) + (4,) + (3,) - assert outputs["ddf"].shape == (batch_size,) + fixed_image_size + (3,) diff --git a/test/unit/test_network_build.py b/test/unit/test_network_build.py deleted file mode 100644 index ee96ff647..000000000 --- a/test/unit/test_network_build.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest - -from deepreg.model.network.build import build_model -from deepreg.registry import REGISTRY - - -class TestBuildModel: - moving_image_size = (4, 8, 16) - fixed_image_size = (8, 16, 24) - index_size = 2 - batch_size = 2 - train_config = { - "backbone": { - "name": "local", - "num_channel_initial": 4, - "extract_levels": [1, 2, 3], - }, - "loss": { - "image": {"name": "lncc", "weight": 0.1}, - "label": { - "name": "dice", - "weight": 1, - "scales": [0, 1, 2, 4, 8, 16, 32], - }, - "regularization": {"weight": 0.0, "name": "bending"}, - }, - } - - @pytest.mark.parametrize( - "method,backbone", - [ - ("ddf", "local"), - ("dvf", "local"), - ("conditional", "local"), - ("affine", "global"), - ], - ) - def test_build(self, method, backbone): - train_config = self.train_config.copy() - train_config["method"] = method - train_config["backbone"]["name"] = backbone - build_model( - moving_image_size=self.moving_image_size, - fixed_image_size=self.fixed_image_size, - index_size=self.index_size, - labeled=True, - batch_size=self.batch_size, - train_config=train_config, - registry=REGISTRY, - ) - - def test_build_err(self): - train_config = self.train_config.copy() - train_config["method"] = "unknown" - with pytest.raises(ValueError) as err_info: - build_model( - moving_image_size=self.moving_image_size, - fixed_image_size=self.fixed_image_size, - index_size=self.index_size, - labeled=True, - batch_size=self.batch_size, - train_config=train_config, - registry=REGISTRY, - ) - assert "Unknown method" in str(err_info.value) diff --git a/test/unit/test_network_cond.py b/test/unit/test_network_cond.py deleted file mode 100644 index 9aba30978..000000000 --- a/test/unit/test_network_cond.py +++ /dev/null @@ -1,89 +0,0 @@ -# coding=utf-8 - -""" -Tests for deepreg/model/network/cond.py -""" -import tensorflow as tf - -from deepreg.model.network.cond import build_conditional_model, conditional_forward -from deepreg.model.network.util import build_backbone -from deepreg.registry import REGISTRY - - -def test_conditional_forward(): - """ - Testing that conditional_forward function returns the tensors with correct shapes - """ - - moving_image_size = (1, 3, 5) - fixed_image_size = (2, 4, 6) - batch_size = 1 - - local_net = build_backbone( - image_size=fixed_image_size, - out_channels=1, - config={ - "name": "local", - "num_channel_initial": 4, - "extract_levels": [1, 2, 3], - }, - method_name="conditional", - registry=REGISTRY, - ) - - # Check conditional mode network output shapes - Pass - pred_fixed_label, grid_fixed = conditional_forward( - backbone=local_net, - moving_image=tf.ones((batch_size,) + moving_image_size), - fixed_image=tf.ones((batch_size,) + fixed_image_size), - moving_label=tf.ones((batch_size,) + moving_image_size), - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - ) - assert pred_fixed_label.shape == (batch_size,) + fixed_image_size - assert grid_fixed.shape == fixed_image_size + (3,) - - -def test_build_conditional_model(): - moving_image_size = (1, 3, 5) - fixed_image_size = (2, 4, 6) - batch_size = 1 - - model = build_conditional_model( - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - index_size=1, - labeled=True, - batch_size=batch_size, - train_config={ - "method": "conditional", - "backbone": { - "name": "local", - "num_channel_initial": 4, - "extract_levels": [1, 2, 3], - }, - "loss": { - "image": {"name": "lncc", "weight": 0.0}, - "label": { - "name": "dice", - "weight": 1, - "scales": [0, 1, 2, 4, 8, 16, 32], - }, - "regularization": {"weight": 0.5, "name": "bending"}, - }, - }, - registry=REGISTRY, - ) - - inputs = { - "moving_image": tf.ones((batch_size,) + moving_image_size), - "fixed_image": tf.ones((batch_size,) + fixed_image_size), - "indices": 1, - "moving_label": tf.ones((batch_size,) + moving_image_size), - "fixed_label": tf.ones((batch_size,) + fixed_image_size), - } - outputs = model(inputs) - - expected_outputs_keys = ["pred_fixed_label"] - assert all(keys in expected_outputs_keys for keys in outputs) - assert outputs["pred_fixed_label"].shape == (batch_size,) + fixed_image_size diff --git a/test/unit/test_network_ddf_dvf.py b/test/unit/test_network_ddf_dvf.py deleted file mode 100644 index d9444e94d..000000000 --- a/test/unit/test_network_ddf_dvf.py +++ /dev/null @@ -1,135 +0,0 @@ -# coding=utf-8 - -""" -Tests for deepreg/model/network/ddf_dvf.py -""" -import pytest -import tensorflow as tf - -from deepreg.model.network.ddf_dvf import build_ddf_dvf_model, ddf_dvf_forward -from deepreg.model.network.util import build_backbone -from deepreg.registry import REGISTRY - - -def test_ddf_dvf_forward(): - """ - Testing that ddf_dvf_forward function returns the tensors with correct shapes - """ - - moving_image_size = (1, 3, 5) - fixed_image_size = (2, 4, 6) - batch_size = 1 - - local_net = build_backbone( - image_size=fixed_image_size, - out_channels=3, - config={ - "name": "local", - "num_channel_initial": 4, - "extract_levels": [1, 2, 3], - }, - method_name="ddf", - registry=REGISTRY, - ) - - # Check DDF mode network output shapes - Pass - dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = ddf_dvf_forward( - backbone=local_net, - moving_image=tf.ones((batch_size,) + moving_image_size), - fixed_image=tf.ones((batch_size,) + fixed_image_size), - moving_label=tf.ones((batch_size,) + moving_image_size), - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - output_dvf=False, - ) - assert dvf is None - assert ddf.shape == (batch_size,) + fixed_image_size + (3,) - assert pred_fixed_image.shape == (batch_size,) + fixed_image_size - assert pred_fixed_label.shape == (batch_size,) + fixed_image_size - assert grid_fixed.shape == fixed_image_size + (3,) - - # Check DVF mode network output shapes - Pass - dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = ddf_dvf_forward( - backbone=local_net, - moving_image=tf.ones((batch_size,) + moving_image_size), - fixed_image=tf.ones((batch_size,) + fixed_image_size), - moving_label=tf.ones((batch_size,) + moving_image_size), - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - output_dvf=True, - ) - assert dvf.shape == (batch_size,) + fixed_image_size + (3,) - assert ddf.shape == (batch_size,) + fixed_image_size + (3,) - assert pred_fixed_image.shape == (batch_size,) + fixed_image_size - assert pred_fixed_label.shape == (batch_size,) + fixed_image_size - assert grid_fixed.shape == fixed_image_size + (3,) - - -def test_build_ddf_dvf_model(): - """ - Testing that build_ddf_dvf_model function returns the tensors with correct shapes - """ - moving_image_size = (1, 3, 5) - fixed_image_size = (2, 4, 6) - batch_size = 1 - train_config = { - "method": "ddf", - "backbone": { - "name": "local", - "num_channel_initial": 4, - "extract_levels": [1, 2, 3], - }, - "loss": { - "image": {"name": "lncc", "weight": 0.1}, - "label": { - "name": "dice", - "weight": 1, - "scales": [0, 1, 2, 4, 8, 16, 32], - }, - "regularization": {"weight": 0.0, "name": "bending"}, - }, - } - - # Create DDF model - model_ddf = build_ddf_dvf_model( - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - index_size=1, - labeled=True, - batch_size=batch_size, - train_config=train_config, - registry=REGISTRY, - ) - - # Create DVF model - train_config["method"] = "dvf" - model_dvf = build_ddf_dvf_model( - moving_image_size=moving_image_size, - fixed_image_size=fixed_image_size, - index_size=1, - labeled=True, - batch_size=batch_size, - train_config=train_config, - registry=REGISTRY, - ) - inputs = { - "moving_image": tf.ones((batch_size,) + moving_image_size), - "fixed_image": tf.ones((batch_size,) + fixed_image_size), - "indices": 1, - "moving_label": tf.ones((batch_size,) + moving_image_size), - "fixed_label": tf.ones((batch_size,) + fixed_image_size), - } - outputs_ddf = model_ddf(inputs) - outputs_dvf = model_dvf(inputs) - - expected_outputs_keys = ["dvf", "ddf", "pred_fixed_label"] - assert all(keys in expected_outputs_keys for keys in outputs_ddf) - assert outputs_ddf["pred_fixed_label"].shape == (batch_size,) + fixed_image_size - assert outputs_ddf["ddf"].shape == (batch_size,) + fixed_image_size + (3,) - with pytest.raises(KeyError): - outputs_ddf["dvf"] - - assert all(keys in expected_outputs_keys for keys in outputs_dvf) - assert outputs_dvf["pred_fixed_label"].shape == (batch_size,) + fixed_image_size - assert outputs_dvf["dvf"].shape == (batch_size,) + fixed_image_size + (3,) - assert outputs_dvf["ddf"].shape == (batch_size,) + fixed_image_size + (3,) diff --git a/test/unit/test_network_util.py b/test/unit/test_network_util.py deleted file mode 100644 index bdf973770..000000000 --- a/test/unit/test_network_util.py +++ /dev/null @@ -1,98 +0,0 @@ -# coding=utf-8 - -""" -Tests for deepreg/model/network/util -""" -import pytest - -import deepreg.model.network.util as util -from deepreg.registry import REGISTRY - - -class TestBuildBackbone: - def test_wrong_image_size(self): - with pytest.raises(ValueError) as err_info: - util.build_backbone( - image_size=(1, 1, 1, 1), - out_channels=1, - config={}, - method_name="ddf", - registry=REGISTRY, - ) - assert "image_size must be tuple of length 3" in str(err_info.value) - - def test_wrong_method_name(self): - with pytest.raises(ValueError) as err_info: - util.build_backbone( - image_size=(1, 2, 3), - out_channels=1, - config={"backbone": "local"}, - method_name="wrong", - registry=REGISTRY, - ) - assert "method name has to be one of ddf/dvf/conditional/affine" in str( - err_info.value - ) - - @pytest.mark.parametrize("method_name", ["ddf", "dvf", "conditional", "affine"]) - @pytest.mark.parametrize("out_channels", [1, 2, 3]) - @pytest.mark.parametrize("backbone_name", ["local", "global"]) - def test_local_global_backbone(self, method_name, out_channels, backbone_name): - util.build_backbone( - image_size=(2, 3, 4), - out_channels=out_channels, - config={ - "name": backbone_name, - "num_channel_initial": 4, - "extract_levels": [1, 2, 3], - }, - method_name=method_name, - registry=REGISTRY, - ) - - @pytest.mark.parametrize("method_name", ["ddf", "dvf", "conditional", "affine"]) - @pytest.mark.parametrize("out_channels", [1, 2, 3]) - def test_unet_backbone(self, method_name, out_channels): - util.build_backbone( - image_size=(2, 3, 4), - out_channels=out_channels, - config={ - "name": "unet", - "num_channel_initial": 4, - "depth": 4, - }, - method_name=method_name, - registry=REGISTRY, - ) - - -class TestBuildInputs: - moving_image_size = (2, 3, 4) - fixed_image_size = (1, 2, 3) - index_size = 3 - batch_size = 2 - - @pytest.mark.parametrize("labeled", [True, False]) - def test_input_shape(self, labeled): - ( - moving_image, - fixed_image, - moving_label, - fixed_label, - indices, - ) = util.build_inputs( - moving_image_size=self.moving_image_size, - fixed_image_size=self.fixed_image_size, - index_size=self.index_size, - batch_size=self.batch_size, - labeled=labeled, - ) - assert moving_image.shape == (self.batch_size, *self.moving_image_size) - assert fixed_image.shape == (self.batch_size, *self.fixed_image_size) - assert indices.shape == (self.batch_size, self.index_size) - if labeled: - assert moving_label.shape == (self.batch_size, *self.moving_image_size) - assert fixed_label.shape == (self.batch_size, *self.fixed_image_size) - else: - assert moving_label is None - assert fixed_label is None diff --git a/test/unit/test_train.py b/test/unit/test_train.py index 1d0f554d4..5cfccbec9 100644 --- a/test/unit/test_train.py +++ b/test/unit/test_train.py @@ -6,6 +6,7 @@ """ import os +import shutil import pytest @@ -49,9 +50,18 @@ def test_max_epochs(self, max_epochs, expected_epochs, expected_save_period): assert got_config["train"]["save_period"] == expected_save_period -def test_train_and_predict_main(): +@pytest.mark.parametrize( + "config_paths", + [ + ["config/unpaired_labeled_ddf.yaml"], + ["config/unpaired_labeled_ddf.yaml", "config/test/affine.yaml"], + ], +) +def test_train_and_predict_main(config_paths): """ Test main in train and predict by checking it can run. + + :param config_paths: list of file paths for configuration. """ train_main( args=[ @@ -60,8 +70,8 @@ def test_train_and_predict_main(): "--log_dir", "test_train", "--config_path", - "config/unpaired_labeled_ddf.yaml", ] + + config_paths ) # check output folders @@ -92,3 +102,6 @@ def test_train_and_predict_main(): assert os.path.isfile("logs/test_predict/test/metrics.csv") assert os.path.isfile("logs/test_predict/test/metrics_stats_per_label.csv") assert os.path.isfile("logs/test_predict/test/metrics_stats_overall.csv") + + shutil.rmtree("logs/test_train") + shutil.rmtree("logs/test_predict")