Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
d4de169
Issue #557: add ddf, dvf models
mathpluscode Jan 8, 2021
8aff6bb
Issue #557: add conditional and affine model, fix predict
mathpluscode Jan 8, 2021
5c0a07c
Issue #557: rewrite tests for DDF DVF models
mathpluscode Jan 9, 2021
52e9569
Issue #557: add tests for conditional model
mathpluscode Jan 9, 2021
020a72b
Issue #557: remove outdated code
mathpluscode Jan 9, 2021
83d6db7
Issue #557: change models to use dict for input/output
mathpluscode Jan 9, 2021
d203343
Issue #557: fix pylint linting
mathpluscode Jan 9, 2021
18ba530
Issue #557: complete network docstrings
mathpluscode Jan 9, 2021
c4d7fb1
Issue #557: fix affine model and test it
mathpluscode Jan 9, 2021
0d94e67
Issue #557: cover get config and ignore call
mathpluscode Jan 9, 2021
442876d
Issue #468: support mixed precision
mathpluscode Jan 9, 2021
a12c0e0
Issue #468: default not use mixed precision
mathpluscode Jan 9, 2021
16a125a
Issue #557: change loss reduction strategy to fix multi-gpu training
mathpluscode Jan 9, 2021
c4cb904
Issue #557: fix unit tests linked to loss reduction strategy
mathpluscode Jan 9, 2021
824004f
Revert "Issue #468: default not use mixed precision"
mathpluscode Jan 9, 2021
2e42b23
Merge branch '615-resample-extrapolation' into 557-refactoring-models
mathpluscode Jan 10, 2021
f597085
Merge remote-tracking branch 'origin/main' into 557-refactoring-models
mathpluscode Jan 10, 2021
1f87e26
Issue #557: update grouped prostate ckpt
mathpluscode Jan 10, 2021
24d9017
Merge branch '617-refactoring-of-data-augmentation' into 557-refactor…
mathpluscode Jan 10, 2021
a136d59
Issue #557: do not reserve all GPU memory in demos
mathpluscode Jan 10, 2021
140c53a
Issue #557: update grouped mr heart ckpt
mathpluscode Jan 10, 2021
07c660a
Merge remote-tracking branch 'origin/main' into 557-refactoring-models
mathpluscode Jan 11, 2021
e052b99
Issue #557: update paired mrus prostate
mathpluscode Jan 11, 2021
4cb53d0
Merge remote-tracking branch 'origin/main' into 557-refactoring-models
mathpluscode Jan 11, 2021
95a61fd
Issue #557: update paired ct lung checkpoint
mathpluscode Jan 13, 2021
557a6dc
Revert "Issue #557: update paired ct lung checkpoint"
mathpluscode Jan 13, 2021
c38aa9c
Issue #557: update paired ct lung ckpt
mathpluscode Jan 13, 2021
c2adf2e
Issue #557: update unpaired us prostate
mathpluscode Jan 16, 2021
d8fee7b
Issue #557: update unpaired ct lung
mathpluscode Jan 16, 2021
f2376dc
Issue #557: update ckpt paths
mathpluscode Jan 16, 2021
98e5943
Merge remote-tracking branch 'origin/main' into 557-refactoring-models
mathpluscode Jan 16, 2021
32ff565
Issue #557: update demo visualization and add warning notes
mathpluscode Jan 16, 2021
8ff134d
Merge remote-tracking branch 'origin/main' into 557-refactoring-models
mathpluscode Jan 17, 2021
e9db939
Issue #557: make loss building more robust
mathpluscode Jan 18, 2021
9841cff
Issue #557: adjust config doc regarding changes
mathpluscode Jan 18, 2021
839c4ab
Issue #557: replace strange space tokens
mathpluscode Jan 18, 2021
c5a4419
Issue #557: use deepcopy instead of naive copy
mathpluscode Jan 18, 2021
f63dd8f
Issue #557: add tests for build loss
mathpluscode Jan 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions config/test/affine.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
train:
method: "ddf" # ddf / dvf / conditional
backbone:
name: "global"
num_channel_initial: 1
extract_levels: [0, 1, 2, 3, 4]
loss:
image:
name: "lncc"
weight: 0.1
label:
weight: 1.0
name: "dice"
scales: [0, 1, 2, 4, 8, 16, 32]
regularization:
weight: 0.5
name: "bending"
preprocess:
batch_size: 2
shuffle_buffer_num_batch: 1
optimizer:
name: "adam"
adam:
learning_rate: 1.0e-5
sgd:
learning_rate: 1.0e-4
momentum: 0.9
rms:
learning_rate: 1.0e-4
momentum: 0.9
epochs: 2
save_period: 2
2 changes: 1 addition & 1 deletion config/unpaired_labeled_ddf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dataset:

train:
# define neural network structure
method: "ddf" # options include "ddf", "dvf", "conditional" and "affine"
method: "ddf" # options include "ddf", "dvf", "conditional"
backbone:
name: "local" # options include "local", "unet" and "global" - use "global" when method=="affine"
num_channel_initial: 1 # number of initial channel in local net, controls the size of the network
Expand Down
3 changes: 2 additions & 1 deletion deepreg/dataset/loader/grouped_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Read https://deepreg.readthedocs.io/en/latest/api/loader.html#module-deepreg.dataset.loader.grouped_loader for more details.
"""
import random
from copy import deepcopy
from typing import List

from deepreg.dataset.loader.interface import (
Expand Down Expand Up @@ -264,7 +265,7 @@ def sample_index_generator(self):
else:
# sample indices are pre-calculated
assert self.sample_indices is not None
sample_indices = self.sample_indices.copy()
sample_indices = deepcopy(self.sample_indices)
rnd.shuffle(sample_indices) # shuffle in place
for sample_index in sample_indices:
group_index1, image_index1, group_index2, image_index2 = sample_index
Expand Down
1 change: 1 addition & 0 deletions deepreg/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa
import deepreg.model.backbone
import deepreg.model.loss
import deepreg.model.network
14 changes: 8 additions & 6 deletions deepreg/model/backbone/global_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def __init__(
units=12, bias_initializer=self.transform_initial
)

def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
def call(
self, inputs: tf.Tensor, training=None, mask=None
) -> (tf.Tensor, tf.Tensor):
"""
Build GlobalNet graph based on built layers.

Expand All @@ -88,10 +90,10 @@ def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
) # level E of encoding

# predict affine parameters theta of shape = [batch, 4, 3]
self.theta = self._dense_layer(h_out)
self.theta = tf.reshape(self.theta, shape=(-1, 4, 3))
theta = self._dense_layer(h_out)
theta = tf.reshape(theta, shape=(-1, 4, 3))

# warp the reference grid with affine parameters to output a ddf
grid_warped = layer_util.warp_grid(self.reference_grid, self.theta)
output = grid_warped - self.reference_grid
return output
grid_warped = layer_util.warp_grid(self.reference_grid, theta)
ddf = grid_warped - self.reference_grid
return ddf, theta
12 changes: 6 additions & 6 deletions deepreg/model/loss/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ class SumSquaredDifference(tf.keras.losses.Loss):

def __init__(
self,
reduction: str = tf.keras.losses.Reduction.AUTO,
reduction: str = tf.keras.losses.Reduction.SUM,
Comment thread
YipengHu marked this conversation as resolved.
name: str = "SumSquaredDifference",
):
"""
Init.

:param reduction: using AUTO reduction,
:param reduction: using SUM reduction over batch axis,
calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
:param name: name of the loss
"""
Expand Down Expand Up @@ -57,15 +57,15 @@ def __init__(
self,
num_bins: int = 23,
sigma_ratio: float = 0.5,
reduction: str = tf.keras.losses.Reduction.AUTO,
reduction: str = tf.keras.losses.Reduction.SUM,
name: str = "GlobalMutualInformation",
):
"""
Init.

:param num_bins: number of bins for intensity, the default value is empirical.
:param sigma_ratio: a hyper param for gaussian function
:param reduction: using AUTO reduction,
:param reduction: using SUM reduction over batch axis,
calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
:param name: name of the loss
"""
Expand Down Expand Up @@ -248,15 +248,15 @@ def __init__(
self,
kernel_size: int = 9,
kernel_type: str = "rectangular",
reduction: str = tf.keras.losses.Reduction.AUTO,
reduction: str = tf.keras.losses.Reduction.SUM,
name: str = "LocalNormalizedCrossCorrelation",
):
"""
Init.

:param kernel_size: int. Kernel size or kernel sigma for kernel_type='gauss'.
:param kernel_type: str, rectangular, triangular or gaussian
:param reduction: using AUTO reduction,
:param reduction: using SUM reduction over batch axis,
calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
:param name: name of the loss
"""
Expand Down
16 changes: 8 additions & 8 deletions deepreg/model/loss/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ def __init__(
self,
scales: Optional[List] = None,
kernel: str = "gaussian",
reduction: str = tf.keras.losses.Reduction.AUTO,
reduction: str = tf.keras.losses.Reduction.SUM,
name: str = "MultiScaleLoss",
):
"""
Init.

:param scales: list of scalars or None, if None, do not apply any scaling.
:param kernel: gaussian or cauchy.
:param reduction: using AUTO reduction,
:param reduction: using SUM reduction over batch axis,
calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
:param name: str, name of the loss.
"""
Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(
neg_weight: float = 0.0,
scales: Optional[List] = None,
kernel: str = "gaussian",
reduction: str = tf.keras.losses.Reduction.AUTO,
reduction: str = tf.keras.losses.Reduction.SUM,
name: str = "DiceScore",
):
"""
Expand All @@ -154,7 +154,7 @@ def __init__(
:param neg_weight: weight for negative class.
:param scales: list of scalars or None, if None, do not apply any scaling.
:param kernel: gaussian or cauchy.
:param reduction: using AUTO reduction,
:param reduction: using SUM reduction over batch axis,
calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
:param name: str, name of the loss.
"""
Expand Down Expand Up @@ -214,7 +214,7 @@ def __init__(
neg_weight: float = 0.0,
scales: Optional[List] = None,
kernel: str = "gaussian",
reduction: str = tf.keras.losses.Reduction.AUTO,
reduction: str = tf.keras.losses.Reduction.SUM,
name: str = "CrossEntropy",
):
"""
Expand All @@ -224,7 +224,7 @@ def __init__(
:param neg_weight: weight for negative class
:param scales: list of scalars or None, if None, do not apply any scaling.
:param kernel: gaussian or cauchy.
:param reduction: using AUTO reduction,
:param reduction: using SUM reduction over batch axis,
calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
:param name: str, name of the loss.
"""
Expand Down Expand Up @@ -276,7 +276,7 @@ def __init__(
binary: bool = False,
scales: Optional[List] = None,
kernel: str = "gaussian",
reduction: str = tf.keras.losses.Reduction.AUTO,
reduction: str = tf.keras.losses.Reduction.SUM,
name: str = "JaccardIndex",
):
"""
Expand All @@ -285,7 +285,7 @@ def __init__(
:param binary: if True, project y_true, y_pred to 0 or 1.
:param scales: list of scalars or None, if None, do not apply any scaling.
:param kernel: gaussian or cauchy.
:param reduction: using AUTO reduction,
:param reduction: using SUM reduction over batch axis,
calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
:param name: str, name of the loss.
"""
Expand Down
Loading