This tutorial will take an image intensity based loss function (mutual information) as an example to show how to add a new loss to DeepReg.
DeepReg supports three main types of loss function: intensity (image) based loss, label based loss, and deformation loss. See the docs for details. The corresponding source files for the losses are in deepreg/model/loss.
The first step is to add your own loss function, which should take at least two parameters: y_true for the ground truth and y_pred for the prediction. For example, in deepreg/model/loss/image.py the loss can be defined as:
def global_mutual_information(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """
    Differentiable global mutual information loss via the Parzen windowing method.
    reference: https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5
    :param y_true: shape = (batch, dim1, dim2, dim3, ch)
    :param y_pred: shape = (batch, dim1, dim2, dim3, ch)
    :return: shape = (batch,)
    """
    ...
    return tf.reduce_sum(pab * tf.math.log(pab / papb + eps), axis=[1, 2])
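The elided body has to produce the joint distribution pab and the product of the marginals papb. The following is a minimal NumPy sketch of how Parzen windowing can do this. It is not DeepReg's TensorFlow implementation: the function name parzen_mutual_information, the num_bins and sigma_ratio parameters, the assumption that intensities lie in [0, 1], and the placement of eps inside the ratio are all choices made here for illustration.

```python
import numpy as np


def parzen_mutual_information(y_true, y_pred, num_bins=16, sigma_ratio=0.5, eps=1e-6):
    """
    Illustrative (hypothetical) NumPy version of a Parzen-window mutual information.

    y_true, y_pred: arrays of shape (batch, dim1, dim2, dim3, ch), assumed in [0, 1].
    Returns an array of shape (batch,).
    """
    batch = y_true.shape[0]
    # Flatten every sample to a column of voxel intensities: (batch, n_voxels, 1).
    a = np.clip(y_true, 0.0, 1.0).reshape(batch, -1, 1)
    b = np.clip(y_pred, 0.0, 1.0).reshape(batch, -1, 1)
    n_voxels = a.shape[1]

    # Evenly spaced bin centres; the kernel width is a fraction of the bin spacing.
    centers = np.linspace(0.0, 1.0, num_bins)
    sigma = (centers[1] - centers[0]) * sigma_ratio

    def soft_assign(x):
        # Gaussian weight of every voxel for every bin, normalised per voxel.
        w = np.exp(-((x - centers) ** 2) / (2.0 * sigma**2))
        return w / w.sum(axis=-1, keepdims=True)  # (batch, n_voxels, num_bins)

    ia = soft_assign(a)
    ib = soft_assign(b)

    # Marginals: average bin membership over voxels, shape (batch, num_bins).
    pa = ia.mean(axis=1)
    pb = ib.mean(axis=1)

    # Joint distribution and product of marginals, shape (batch, num_bins, num_bins).
    pab = np.einsum("bni,bnj->bij", ia, ib) / n_voxels
    papb = pa[:, :, None] * pb[:, None, :]

    # Assumption: eps is added to numerator and denominator to avoid 0/0,
    # which differs slightly from the return line shown above.
    return np.sum(pab * np.log((pab + eps) / (papb + eps)), axis=(1, 2))
```

A convenient sanity check is that an image compared with itself should score higher than the same image compared with unrelated noise.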
To be compatible with the DeepReg pipeline, one more modification is needed in deepreg/model/loss/image.py: add an elif branch to the dissimilarity_fn function. Because mutual information is a similarity measure and dissimilarity_fn returns a dissimilarity, the new branch returns the negated value. The following code block shows the added elif branch, where "gmi" denotes the global mutual information:
def dissimilarity_fn(
    y_true: tf.Tensor, y_pred: tf.Tensor, name: str, **kwargs
) -> tf.Tensor:
    """
    :param y_true: fixed_image, shape = (batch, f_dim1, f_dim2, f_dim3)
    :param y_pred: warped_moving_image, shape = (batch, f_dim1, f_dim2, f_dim3)
    :param name: name of the dissimilarity function
    :param kwargs: absorb additional parameters
    :return: shape = (batch,)
    """
    assert name in ["lncc", "ssd", "gmi"]
    # shape = (batch, f_dim1, f_dim2, f_dim3, 1)
    y_true = tf.expand_dims(y_true, axis=4)
    y_pred = tf.expand_dims(y_pred, axis=4)
    if name == "lncc":
        return -local_normalized_cross_correlation(y_true, y_pred, **kwargs)
    elif name == "ssd":
        return ssd(y_true, y_pred)
    elif name == "gmi":
        return -global_mutual_information(y_true, y_pred)
    else:
        raise ValueError("Unknown loss type.")
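Note that the "gmi" branch above does not forward **kwargs, whereas the "lncc" branch does. If your own loss takes extra parameters, one option is to forward them in the same way. The toy sketch below illustrates that pattern and the sign convention with plain Python lists; toy_similarity, toy_dissimilarity_fn, the "toy" name and the scale parameter are hypothetical and are not part of DeepReg.

```python
def toy_similarity(y_true, y_pred, scale=1.0):
    """Hypothetical similarity measure: higher means the inputs agree more."""
    return -scale * sum((t - p) ** 2 for t, p in zip(y_true, y_pred))


def toy_dissimilarity_fn(y_true, y_pred, name, **kwargs):
    """Hypothetical stand-in that mirrors the if/elif dispatch shown above."""
    if name == "toy":
        # Forward extra parameters to the loss and negate the similarity,
        # so that a lower return value means a better match.
        return -toy_similarity(y_true, y_pred, **kwargs)
    else:
        raise ValueError("Unknown loss type.")


# Extra keyword arguments reach the underlying function through **kwargs.
print(toy_dissimilarity_fn([0.0, 1.0], [1.0, 1.0], name="toy", scale=2.0))
```

With this pattern, any additional keys passed alongside the loss name are handed to the loss function unchanged, so an unexpected key raises a TypeError instead of being silently ignored.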
Add corresponding unit tests for the newly added functions to deepreg/test/unit. This step is optional for users. Everyone is warmly welcome to contribute to DeepReg. Please follow our contribution guidelines.
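As a rough sketch of what such a test might check, the helper below asserts that a loss maps two inputs of shape (batch, dim1, dim2, dim3) to a finite output of shape (batch,), matching the docstring of dissimilarity_fn. It uses NumPy and a stand-in loss so that it runs on its own; check_loss_contract and stand_in_ssd are hypothetical names, and a real test would call the DeepReg function instead.

```python
import numpy as np


def check_loss_contract(loss_fn, batch=2, shape=(4, 4, 4), seed=0):
    """Assert that loss_fn returns one finite value per sample in the batch."""
    rng = np.random.default_rng(seed)
    y_true = rng.random((batch, *shape))
    y_pred = rng.random((batch, *shape))
    loss = np.asarray(loss_fn(y_true, y_pred))
    assert loss.shape == (batch,), f"expected shape {(batch,)}, got {loss.shape}"
    assert np.all(np.isfinite(loss)), "loss contains NaN or inf"
    return loss


def stand_in_ssd(y_true, y_pred):
    # Hypothetical stand-in for a real loss: mean squared difference per sample.
    return np.mean((y_true - y_pred) ** 2, axis=(1, 2, 3))


def test_stand_in_ssd_contract():
    check_loss_contract(stand_in_ssd)


def test_identical_inputs_give_zero():
    x = np.ones((2, 4, 4, 4))
    assert np.allclose(stand_in_ssd(x, x), 0.0)
```

Functions named test_* are collected automatically by pytest; whether the project's test suite follows exactly this layout is an assumption.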
We take the paired prostate MR and Ultrasound registration demo as an example. To use the newly added loss, all that is needed is to modify the loss configuration in the training configuration file paired_mrus_prostate_train.yaml (lines 10-15):
# define the loss function for training
loss:
  dissimilarity:
    image:
      name: "gmi"
      weight: 1.0
That's it. Follow the instructions in the demo to begin training with the newly added loss.