diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 3256eacfe6..49ed4c9e6c 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -744,6 +744,14 @@ Smooth Field :members: :special-members: __call__ +`RandSmoothDeform` +"""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothDeform.png + :alt: example of RandSmoothDeform +.. autoclass:: RandSmoothDeform + :members: + :special-members: __call__ + Utility ^^^^^^^ @@ -1553,6 +1561,14 @@ Smooth Field (Dict) :members: :special-members: __call__ +`RandSmoothDeformd` +""""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RandSmoothDeformd.png + :alt: example of RandSmoothDeformd +.. autoclass:: RandSmoothDeformd + :members: + :special-members: __call__ + Utility (Dict) ^^^^^^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 8db0a2cf7a..c9cd1b4e0d 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -277,8 +277,13 @@ VoteEnsembled, VoteEnsembleDict, ) -from .smooth_field.array import RandSmoothFieldAdjustContrast, RandSmoothFieldAdjustIntensity, SmoothField -from .smooth_field.dictionary import RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd +from .smooth_field.array import ( + RandSmoothDeform, + RandSmoothFieldAdjustContrast, + RandSmoothFieldAdjustIntensity, + SmoothField, +) +from .smooth_field.dictionary import RandSmoothDeformd, RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd from .spatial.array import ( Affine, AffineGrid, diff --git a/monai/transforms/smooth_field/array.py b/monai/transforms/smooth_field/array.py index 50758bfba2..31ce76e5b5 100644 --- a/monai/transforms/smooth_field/array.py +++ b/monai/transforms/smooth_field/array.py @@ -14,90 +14,163 @@ from typing import Any, Optional, Sequence, Union import numpy as np +import torch +from torch.nn.functional import grid_sample, interpolate import monai -from monai.transforms.spatial.array import Resize -from monai.transforms.transform import Randomizable, RandomizableTransform, Transform -from monai.transforms.utils import rescale_array -from monai.utils import InterpolateMode, ensure_tuple +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import Randomizable, RandomizableTransform +from monai.transforms.utils_pytorch_numpy_unification import moveaxis +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode from monai.utils.enums import TransformBackends -from monai.utils.type_conversion import convert_to_dst_type +from monai.utils.module import look_up_option, pytorch_after +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor -__all__ = ["SmoothField", "RandSmoothFieldAdjustContrast", "RandSmoothFieldAdjustIntensity"] +__all__ = ["SmoothField", "RandSmoothFieldAdjustContrast", "RandSmoothFieldAdjustIntensity", "RandSmoothDeform"] class SmoothField(Randomizable): """ - Generate a smooth field array by defining a smaller randomized field and then resizing to the desired size. This - exploits interpolation to create a smoothly varying field used for other applications. + Generate a smooth field array by defining a smaller randomized field and then reinterpolating to the desired size. + + This exploits interpolation to create a smoothly varying field used for other applications. An initial randomized + field is defined with `rand_size` dimensions with `pad` number of values padding it along each dimension using + `pad_val` as the value. If `spatial_size` is given this is interpolated to that size, otherwise if None the random + array is produced uninterpolated. The output is always a Pytorch tensor allocated on the specified device. Args: - spatial_size: final output size of the array rand_size: size of the randomized field to start from - padder: optional transform to add padding to the randomized field - mode: interpolation mode to use when upsampling - align_corners: if True align the corners when upsampling field + pad: number of pixels/voxels along the edges of the field to pad with `pad_val` + pad_val: value with which to pad field edges low: low value for randomized field high: high value for randomized field channels: number of channels of final output + spatial_size: final output size of the array, None to produce original uninterpolated field + mode: interpolation mode for resizing the field + align_corners: if True align the corners when upsampling field + device: Pytorch device to define field on """ def __init__( self, - spatial_size: Union[Sequence[int], int], - rand_size: Union[Sequence[int], int], - padder: Optional[Transform] = None, - mode: Union[InterpolateMode, str] = InterpolateMode.AREA, - align_corners: Optional[bool] = None, + rand_size: Sequence[int], + pad: int = 0, + pad_val: float = 0, low: float = -1.0, high: float = 1.0, channels: int = 1, + spatial_size: Optional[Sequence[int]] = None, + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + device: Optional[torch.device] = None, ): - self.resizer: Transform = Resize(spatial_size, mode=mode, align_corners=align_corners) - self.rand_size: tuple = ensure_tuple(rand_size) - self.padder: Optional[Transform] = padder - self.field: Optional[np.ndarray] = None - self.low: float = low - self.high: float = high - self.channels: int = channels + self.rand_size = tuple(rand_size) + self.pad = pad + self.low = low + self.high = high + self.channels = channels + self.mode = mode + self.align_corners = align_corners + self.device = device + + self.spatial_size: Optional[Sequence[int]] = None + self.spatial_zoom: Optional[Sequence[float]] = None + + if low >= high: + raise ValueError("Value for `low` must be less than `high` otherwise field will be zeros") + + self.total_rand_size = tuple(rs + self.pad * 2 for rs in self.rand_size) + + self.field = torch.ones((1, self.channels) + self.total_rand_size, device=self.device) * pad_val + + self.crand_size = (self.channels,) + self.rand_size + + pad_slice = slice(None) if self.pad == 0 else slice(self.pad, -self.pad) + self.rand_slices = (0, slice(None)) + (pad_slice,) * len(self.rand_size) + + self.set_spatial_size(spatial_size) def randomize(self, data: Optional[Any] = None) -> None: - self.field = self.R.uniform(self.low, self.high, (self.channels,) + self.rand_size) # type: ignore - if self.padder is not None: - self.field = self.padder(self.field) + self.field[self.rand_slices] = torch.from_numpy(self.R.uniform(self.low, self.high, self.crand_size)) + + def set_spatial_size(self, spatial_size: Optional[Sequence[int]]) -> None: + """ + Set the `spatial_size` and `spatial_zoom` attributes used for interpolating the field to the given + dimension, or not interpolate at all if None. + + Args: + spatial_size: new size to interpolate to, or None to not interpolate + """ + if spatial_size is None: + self.spatial_size = None + self.spatial_zoom = None + else: + self.spatial_size = tuple(spatial_size) + self.spatial_zoom = tuple(s / f for s, f in zip(self.spatial_size, self.total_rand_size)) + + def set_mode(self, mode: Union[monai.utils.InterpolateMode, str]) -> None: + self.mode = mode + + def __call__(self, randomize=False) -> torch.Tensor: + if randomize: + self.randomize() + + field = self.field.clone() - def __call__(self): - resized_field = self.resizer(self.field) + if self.spatial_zoom is not None: + resized_field = interpolate( # type: ignore + input=field, # type: ignore + scale_factor=self.spatial_zoom, + mode=look_up_option(self.mode, InterpolateMode).value, + align_corners=self.align_corners, + recompute_scale_factor=False, + ) - return rescale_array(resized_field, self.field.min(), self.field.max()) + mina = resized_field.min() + maxa = resized_field.max() + minv = self.field.min() + maxv = self.field.max() + + # faster than rescale_array, this uses in-place operations and doesn't perform unneeded range checks + norm_field = (resized_field.squeeze(0) - mina).div_(maxa - mina) + field = norm_field.mul_(maxv - minv).add_(minv) + + return field class RandSmoothFieldAdjustContrast(RandomizableTransform): """ - Randomly adjust the contrast of input images by calculating a randomized smooth field for each invocation. This - uses SmoothFieldAdjustContrast and SmoothField internally. + Randomly adjust the contrast of input images by calculating a randomized smooth field for each invocation. + + This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the + edges of the input volume of that width will be mostly unchanged. Contrast is changed by raising input + values by the power of the smooth field so the range of values given by `gamma` should be chosen with this + in mind. For example, a minimum value of 0 in `gamma` will produce white areas so this should be avoided. + Afte the contrast is adjusted the values of the result are rescaled to the range of the original input. Args: spatial_size: size of input array's spatial dimensions rand_size: size of the randomized field to start from - padder: optional transform to add padding to the randomized field + pad: number of pixels/voxels along the edges of the field to pad with 1 mode: interpolation mode to use when upsampling align_corners: if True align the corners when upsampling field prob: probability transform is applied gamma: (min, max) range for exponential field + device: Pytorch device to define field on """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, - spatial_size: Union[Sequence[int], int], - rand_size: Union[Sequence[int], int], - padder: Optional[Transform] = None, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, mode: Union[InterpolateMode, str] = InterpolateMode.AREA, align_corners: Optional[bool] = None, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0.5, 4.5), + device: Optional[torch.device] = None, ): super().__init__(prob) @@ -109,7 +182,18 @@ def __init__( self.gamma = (min(gamma), max(gamma)) - self.sfield = SmoothField(spatial_size, rand_size, padder, mode, align_corners, self.gamma[0], self.gamma[1]) + self.sfield = SmoothField( + rand_size=rand_size, + pad=pad, + pad_val=1, + low=self.gamma[0], + high=self.gamma[1], + channels=1, + spatial_size=spatial_size, + mode=mode, + align_corners=align_corners, + device=device, + ) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -124,7 +208,10 @@ def randomize(self, data: Optional[Any] = None) -> None: if self._do_transform: self.sfield.randomize() - def __call__(self, img: np.ndarray, randomize: bool = True): + def set_mode(self, mode: Union[monai.utils.InterpolateMode, str]) -> None: + self.sfield.set_mode(mode) + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous. """ @@ -139,44 +226,51 @@ def __call__(self, img: np.ndarray, randomize: bool = True): img_rng = img_max - img_min field = self.sfield() - field, *_ = convert_to_dst_type(field, img) + rfield, *_ = convert_to_dst_type(field, img) - img = (img - img_min) / max(img_rng, 1e-10) # rescale to unit values - img = img ** field # contrast is changed by raising image data to a power, in this case the field + # everything below here is to be computed using the destination type (numpy, tensor, etc.) - out = (img * img_rng) + img_min # rescale back to the original image value range + img = (img - img_min) / (img_rng + 1e-10) # rescale to unit values + img = img ** rfield # contrast is changed by raising image data to a power, in this case the field - out, *_ = convert_to_dst_type(out, img, img.dtype) + out = (img * img_rng) + img_min # rescale back to the original image value range return out class RandSmoothFieldAdjustIntensity(RandomizableTransform): """ - Randomly adjust the intensity of input images by calculating a randomized smooth field for each invocation. This - uses SmoothField internally. + Randomly adjust the intensity of input images by calculating a randomized smooth field for each invocation. + + This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the + edges of the input volume of that width will be mostly unchanged. Intensity is changed by multiplying the + inputs by the smooth field, so the values of `gamma` should be chosen with this in mind. The default values + of `(0.1, 1.0)` are sensible in that values will not be zeroed out by the field nor multiplied greater than + the original value range. Args: spatial_size: size of input array rand_size: size of the randomized field to start from - padder: optional transform to add padding to the randomized field + pad: number of pixels/voxels along the edges of the field to pad with 1 mode: interpolation mode to use when upsampling align_corners: if True align the corners when upsampling field prob: probability transform is applied gamma: (min, max) range of intensity multipliers + device: Pytorch device to define field on """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__( self, - spatial_size: Union[Sequence[int], int], - rand_size: Union[Sequence[int], int], - padder: Optional[Transform] = None, - mode: Union[monai.utils.InterpolateMode, str] = monai.utils.InterpolateMode.AREA, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateMode, str] = InterpolateMode.AREA, align_corners: Optional[bool] = None, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0.1, 1.0), + device: Optional[torch.device] = None, ): super().__init__(prob) @@ -188,7 +282,18 @@ def __init__( self.gamma = (min(gamma), max(gamma)) - self.sfield = SmoothField(spatial_size, rand_size, padder, mode, align_corners, self.gamma[0], self.gamma[1]) + self.sfield = SmoothField( + rand_size=rand_size, + pad=pad, + pad_val=1, + low=self.gamma[0], + high=self.gamma[1], + channels=1, + spatial_size=spatial_size, + mode=mode, + align_corners=align_corners, + device=device, + ) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -203,7 +308,10 @@ def randomize(self, data: Optional[Any] = None) -> None: if self._do_transform: self.sfield.randomize() - def __call__(self, img: np.ndarray, randomize: bool = True): + def set_mode(self, mode: Union[InterpolateMode, str]) -> None: + self.sfield.set_mode(mode) + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous. """ @@ -217,7 +325,137 @@ def __call__(self, img: np.ndarray, randomize: bool = True): field = self.sfield() rfield, *_ = convert_to_dst_type(field, img) + # everything below here is to be computed using the destination type (numpy, tensor, etc.) + out = img * rfield - out, *_ = convert_to_dst_type(out, img, img.dtype) return out + + +class RandSmoothDeform(RandomizableTransform): + """ + Deform an image using a random smooth field and Pytorch's grid_sample. + + The amount of deformation is given by `def_range` in fractions of the size of the image. The size of each dimension + of the input image is always defined as 2 regardless of actual image voxel dimensions, that is the coordinates in + every dimension range from -1 to 1. A value of 0.1 means pixels/voxels can be moved by up to 5% of the image's size. + + Args: + spatial_size: input array size to which deformation grid is interpolated + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 0 + field_mode: interpolation mode to use when upsampling the deformation field + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + def_range: value of the deformation range in image size fractions, single min/max value or min/max pair + grid_dtype: type for the deformation grid calculated from the field + grid_mode: interpolation mode used for sampling input using deformation grid + grid_padding_mode: padding mode used for sampling input using deformation grid + grid_align_corners: if True align the corners when sampling the deformation grid + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH] + + def __init__( + self, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + field_mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + def_range: Union[Sequence[float], float] = 1.0, + grid_dtype=torch.float32, + grid_mode: Union[GridSampleMode, str] = GridSampleMode.NEAREST, + grid_padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + grid_align_corners: Optional[bool] = False, + device: Optional[torch.device] = None, + ): + super().__init__(prob) + + self.grid_dtype = grid_dtype + self.grid_mode = grid_mode + self.def_range = def_range + self.device = device + self.grid_align_corners = grid_align_corners + self.grid_padding_mode = grid_padding_mode + + if isinstance(def_range, (int, float)): + self.def_range = (-def_range, def_range) + else: + if len(def_range) != 2: + raise ValueError("Argument `def_range` should be a number or pair of numbers.") + + self.def_range = (min(def_range), max(def_range)) + + self.sfield = SmoothField( + spatial_size=spatial_size, + rand_size=rand_size, + pad=pad, + low=self.def_range[0], + high=self.def_range[1], + channels=len(rand_size), + mode=field_mode, + align_corners=align_corners, + device=device, + ) + + grid_space = spatial_size if spatial_size is not None else self.sfield.field.shape[2:] + grid_ranges = [torch.linspace(-1, 1, d) for d in grid_space] + + if pytorch_after(1, 10): + grid = torch.meshgrid(*grid_ranges, indexing="ij") + else: + grid = torch.meshgrid(*grid_ranges) + + self.grid = torch.stack(grid).unsqueeze(0).to(self.device, self.grid_dtype) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "Randomizable": + super().set_random_state(seed, state) + self.sfield.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + + if self._do_transform: + self.sfield.randomize() + + def set_field_mode(self, mode: Union[monai.utils.InterpolateMode, str]) -> None: + self.sfield.set_mode(mode) + + def set_grid_mode(self, mode: Union[monai.utils.GridSampleMode, str]) -> None: + self.grid_mode = mode + + def __call__( + self, img: NdarrayOrTensor, randomize: bool = True, device: Optional[torch.device] = None + ) -> NdarrayOrTensor: + if randomize: + self.randomize() + + if not self._do_transform: + return img + + device = device if device is not None else self.device + + field = self.sfield() + + dgrid = self.grid + field.to(self.grid_dtype) + dgrid = moveaxis(dgrid, 1, -1) # type: ignore + + img_t = convert_to_tensor(img[None], torch.float32, device) + + out = grid_sample( + input=img_t, + grid=dgrid, + mode=look_up_option(self.grid_mode, GridSampleMode).value, + align_corners=self.grid_align_corners, + padding_mode=look_up_option(self.grid_padding_mode, GridSamplePadMode).value, + ) + + out_t, *_ = convert_to_dst_type(out.squeeze(0), img) + + return out_t diff --git a/monai/transforms/smooth_field/dictionary.py b/monai/transforms/smooth_field/dictionary.py index b14bde6c9f..4eca541fcc 100644 --- a/monai/transforms/smooth_field/dictionary.py +++ b/monai/transforms/smooth_field/dictionary.py @@ -13,31 +13,45 @@ from typing import Any, Hashable, Mapping, Optional, Sequence, Union import numpy as np +import torch from monai.config import KeysCollection -from monai.transforms.smooth_field.array import RandSmoothFieldAdjustContrast, RandSmoothFieldAdjustIntensity -from monai.transforms.transform import MapTransform, RandomizableTransform, Transform -from monai.utils import InterpolateMode +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.smooth_field.array import ( + RandSmoothDeform, + RandSmoothFieldAdjustContrast, + RandSmoothFieldAdjustIntensity, +) +from monai.transforms.transform import MapTransform, RandomizableTransform +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple_rep from monai.utils.enums import TransformBackends -__all__ = ["RandSmoothFieldAdjustContrastd", "RandSmoothFieldAdjustIntensityd"] +__all__ = ["RandSmoothFieldAdjustContrastd", "RandSmoothFieldAdjustIntensityd", "RandSmoothDeformd"] + + +InterpolateModeType = Union[InterpolateMode, str] +GridSampleModeType = Union[GridSampleMode, str] class RandSmoothFieldAdjustContrastd(RandomizableTransform, MapTransform): """ - Dictionary version of RandSmoothFieldAdjustContrast. The field is randomized once per invocation by default so the - same field is applied to every selected key. + Dictionary version of RandSmoothFieldAdjustContrast. + + The field is randomized once per invocation by default so the same field is applied to every selected key. The + `mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values with + one for each key in `keys`. Args: keys: key names to apply the augment to spatial_size: size of input arrays, all arrays stated in `keys` must have same dimensions rand_size: size of the randomized field to start from - padder: optional transform to add padding to the randomized field + pad: number of pixels/voxels along the edges of the field to pad with 0 mode: interpolation mode to use when upsampling align_corners: if True align the corners when upsampling field prob: probability transform is applied gamma: (min, max) range for exponential field apply_same_field: if True, apply the same field to each key, otherwise randomize individually + device: Pytorch device to define field on """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -45,28 +59,32 @@ class RandSmoothFieldAdjustContrastd(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - spatial_size: Union[Sequence[int], int], - rand_size: Union[Sequence[int], int], - padder: Optional[Transform] = None, - mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateModeType, Sequence[InterpolateModeType]] = InterpolateMode.AREA, align_corners: Optional[bool] = None, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0.5, 4.5), apply_same_field: bool = True, + device: Optional[torch.device] = None, ): RandomizableTransform.__init__(self, prob) MapTransform.__init__(self, keys) + self.apply_same_field = apply_same_field + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.trans = RandSmoothFieldAdjustContrast( spatial_size=spatial_size, rand_size=rand_size, - padder=padder, - mode=mode, + pad=pad, + mode=self.mode[0], align_corners=align_corners, prob=1.0, gamma=gamma, + device=device, ) - self.apply_same_field = apply_same_field def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -81,7 +99,7 @@ def randomize(self, data: Optional[Any] = None) -> None: if self._do_transform: self.trans.randomize() - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: self.randomize() if not self._do_transform: @@ -89,10 +107,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. d = dict(data) - for key in self.key_iterator(d): + for idx, key in enumerate(self.key_iterator(d)): if not self.apply_same_field: self.randomize() # new field for every key + self.trans.set_mode(self.mode[idx % len(self.mode)]) d[key] = self.trans(d[key], False) return d @@ -100,19 +119,23 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. class RandSmoothFieldAdjustIntensityd(RandomizableTransform, MapTransform): """ - Dictionary version of RandSmoothFieldAdjustIntensity. The field is randomized once per invocation by default so - the same field is applied to every selected key. + Dictionary version of RandSmoothFieldAdjustIntensity. + + The field is randomized once per invocation by default so the same field is applied to every selected key. The + `mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values with + one for each key in `keys`. Args: keys: key names to apply the augment to spatial_size: size of input arrays, all arrays stated in `keys` must have same dimensions rand_size: size of the randomized field to start from - padder: optional transform to add padding to the randomized field + pad: number of pixels/voxels along the edges of the field to pad with 0 mode: interpolation mode to use when upsampling align_corners: if True align the corners when upsampling field prob: probability transform is applied gamma: (min, max) range of intensity multipliers apply_same_field: if True, apply the same field to each key, otherwise randomize individually + device: Pytorch device to define field on """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -120,28 +143,32 @@ class RandSmoothFieldAdjustIntensityd(RandomizableTransform, MapTransform): def __init__( self, keys: KeysCollection, - spatial_size: Union[Sequence[int], int], - rand_size: Union[Sequence[int], int], - padder: Optional[Transform] = None, - mode: Union[InterpolateMode, str] = InterpolateMode.AREA, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + mode: Union[InterpolateModeType, Sequence[InterpolateModeType]] = InterpolateMode.AREA, align_corners: Optional[bool] = None, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0.1, 1.0), apply_same_field: bool = True, + device: Optional[torch.device] = None, ): RandomizableTransform.__init__(self, prob) MapTransform.__init__(self, keys) + self.apply_same_field = apply_same_field + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.trans = RandSmoothFieldAdjustIntensity( spatial_size=spatial_size, rand_size=rand_size, - padder=padder, - mode=mode, + pad=pad, + mode=self.mode[0], align_corners=align_corners, prob=1.0, gamma=gamma, + device=device, ) - self.apply_same_field = apply_same_field def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -154,7 +181,7 @@ def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self.trans.randomize() - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: self.randomize() if not self._do_transform: @@ -162,10 +189,108 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. d = dict(data) - for key in self.key_iterator(d): + for idx, key in enumerate(self.key_iterator(d)): if not self.apply_same_field: self.randomize() # new field for every key + self.trans.set_mode(self.mode[idx % len(self.mode)]) d[key] = self.trans(d[key], False) return d + + +class RandSmoothDeformd(RandomizableTransform, MapTransform): + """ + Dictionary version of RandSmoothDeform. + + The field is randomized once per invocation by default so the same field is applied to every selected key. The + `field_mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values + with one for each key in `keys`. Similarly the `grid_mode` parameter can be one value or one per key. + + Args: + keys: key names to apply the augment to + spatial_size: input array size to which deformation grid is interpolated + rand_size: size of the randomized field to start from + pad: number of pixels/voxels along the edges of the field to pad with 0 + field_mode: interpolation mode to use when upsampling the deformation field + align_corners: if True align the corners when upsampling field + prob: probability transform is applied + def_range: value of the deformation range in image size fractions + grid_dtype: type for the deformation grid calculated from the field + grid_mode: interpolation mode used for sampling input using deformation grid + grid_padding_mode: padding mode used for sampling input using deformation grid + grid_align_corners: if True align the corners when sampling the deformation grid + apply_same_field: if True, apply the same field to each key, otherwise randomize individually + device: Pytorch device to define field on + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + keys: KeysCollection, + spatial_size: Sequence[int], + rand_size: Sequence[int], + pad: int = 0, + field_mode: Union[InterpolateModeType, Sequence[InterpolateModeType]] = InterpolateMode.AREA, + align_corners: Optional[bool] = None, + prob: float = 0.1, + def_range: Union[Sequence[float], float] = 1.0, + grid_dtype=torch.float32, + grid_mode: Union[GridSampleModeType, Sequence[GridSampleModeType]] = GridSampleMode.NEAREST, + grid_padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + grid_align_corners: Optional[bool] = False, + apply_same_field: bool = True, + device: Optional[torch.device] = None, + ): + RandomizableTransform.__init__(self, prob) + MapTransform.__init__(self, keys) + + self.field_mode = ensure_tuple_rep(field_mode, len(self.keys)) + self.grid_mode = ensure_tuple_rep(grid_mode, len(self.keys)) + self.apply_same_field = apply_same_field + + self.trans = RandSmoothDeform( + rand_size=rand_size, + spatial_size=spatial_size, + pad=pad, + field_mode=self.field_mode[0], + align_corners=align_corners, + prob=1.0, + def_range=def_range, + grid_dtype=grid_dtype, + grid_mode=self.grid_mode[0], + grid_padding_mode=grid_padding_mode, + grid_align_corners=grid_align_corners, + device=device, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandSmoothDeformd": + super().set_random_state(seed, state) + self.trans.set_random_state(seed, state) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + super().randomize(None) + self.trans.randomize() + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + self.randomize() + + if not self._do_transform: + return data + + d = dict(data) + + for idx, key in enumerate(self.key_iterator(d)): + if not self.apply_same_field: + self.randomize() # new field for every key + + self.trans.set_field_mode(self.field_mode[idx % len(self.field_mode)]) + self.trans.set_grid_mode(self.grid_mode[idx % len(self.grid_mode)]) + + d[key] = self.trans(d[key], False, self.trans.device) + + return d diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py index 64d4e345c1..ab282d5332 100644 --- a/monai/transforms/utils_create_transform_ims.py +++ b/monai/transforms/utils_create_transform_ims.py @@ -145,8 +145,16 @@ ) from monai.transforms.post.array import KeepLargestConnectedComponent, LabelFilter, LabelToContour from monai.transforms.post.dictionary import AsDiscreted, KeepLargestConnectedComponentd, LabelFilterd, LabelToContourd -from monai.transforms.smooth_field.array import RandSmoothFieldAdjustContrast, RandSmoothFieldAdjustIntensity -from monai.transforms.smooth_field.dictionary import RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd +from monai.transforms.smooth_field.array import ( + RandSmoothDeform, + RandSmoothFieldAdjustContrast, + RandSmoothFieldAdjustIntensity, +) +from monai.transforms.smooth_field.dictionary import ( + RandSmoothDeformd, + RandSmoothFieldAdjustContrastd, + RandSmoothFieldAdjustIntensityd, +) from monai.transforms.spatial.array import ( GridDistortion, Rand2DElastic, @@ -689,20 +697,38 @@ def create_transform_im( data, ) create_transform_im( - RandSmoothFieldAdjustContrast, dict(spatial_size=(217, 217, 217), rand_size=(100, 100, 100), prob=1.0), data + RandSmoothFieldAdjustContrast, dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0), data ) create_transform_im( RandSmoothFieldAdjustContrastd, - dict(keys=keys, spatial_size=(217, 217, 217), rand_size=(100, 100, 100), prob=1.0), + dict(keys=keys, spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0), data, ) create_transform_im( RandSmoothFieldAdjustIntensity, - dict(spatial_size=(217, 217, 217), rand_size=(100, 100, 100), prob=1.0, gamma=(0.5, 4.5)), + dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, gamma=(0.5, 4.5)), data, ) create_transform_im( RandSmoothFieldAdjustIntensityd, - dict(keys=keys, spatial_size=(217, 217, 217), rand_size=(100, 100, 100), prob=1.0, gamma=(0.5, 4.5)), + dict(keys=keys, spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, gamma=(0.5, 4.5)), + data, + ) + + create_transform_im( + RandSmoothDeform, + dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, def_range=0.05, grid_mode="blinear"), + data, + ) + create_transform_im( + RandSmoothDeformd, + dict( + keys=keys, + spatial_size=(217, 217, 217), + rand_size=(10, 10, 10), + prob=1.0, + def_range=0.05, + grid_mode="blinear", + ), data, ) diff --git a/tests/test_smooth_field.py b/tests/test_smooth_field.py index 11343dcb9a..5849b96167 100644 --- a/tests/test_smooth_field.py +++ b/tests/test_smooth_field.py @@ -10,47 +10,72 @@ # limitations under the License. import unittest +from itertools import product import numpy as np +import torch from parameterized import parameterized -from monai.transforms import RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd +from monai.transforms import RandSmoothDeformd, RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env _rtol = 5e-3 if is_tf32_env() else 1e-4 -INPUT_SHAPE1 = (1, 8, 8) -INPUT_SHAPE2 = (2, 8, 8) +INPUT_SHAPES = ((1, 8, 8), (2, 8, 8), (1, 8, 8, 8)) TESTS_CONTRAST = [] TESTS_INTENSITY = [] +TESTS_DEFORM = [] -for p in TEST_NDARRAYS: - TESTS_CONTRAST += [ - ( - {"keys": ("test",), "spatial_size": INPUT_SHAPE1[1:], "rand_size": (4, 4), "prob": 1.0}, - {"test": p(np.ones(INPUT_SHAPE1, np.float32))}, - {"test": p(np.ones(INPUT_SHAPE1, np.float32))}, - ), +KEY = "test" + +for arr_type, shape in product(TEST_NDARRAYS, INPUT_SHAPES): + in_arr = arr_type(np.ones(shape, np.float32)) + exp_arr = arr_type(np.ones(shape, np.float32)) + rand_size = (4,) * (len(shape) - 1) + + device = torch.device("cpu") + + if isinstance(in_arr, torch.Tensor) and in_arr.get_device() >= 0: + device = torch.device(in_arr.get_device()) + + TESTS_CONTRAST.append( ( - {"keys": ("test",), "spatial_size": INPUT_SHAPE2[1:], "rand_size": (4, 4), "prob": 1.0}, - {"test": p(np.ones(INPUT_SHAPE2, np.float32))}, - {"test": p(np.ones(INPUT_SHAPE2, np.float32))}, - ), - ] + {"keys": (KEY,), "spatial_size": shape[1:], "rand_size": rand_size, "prob": 1.0, "device": device}, + {KEY: in_arr}, + {KEY: exp_arr}, + ) + ) - TESTS_INTENSITY += [ + TESTS_INTENSITY.append( ( - {"keys": ("test",), "spatial_size": INPUT_SHAPE1[1:], "rand_size": (4, 4), "prob": 1.0, "gamma": (1, 1)}, - {"test": p(np.ones(INPUT_SHAPE1, np.float32))}, - {"test": p(np.ones(INPUT_SHAPE1, np.float32))}, - ), + { + "keys": (KEY,), + "spatial_size": shape[1:], + "rand_size": rand_size, + "prob": 1.0, + "device": device, + "gamma": (0.9, 1), + }, + {KEY: in_arr}, + {KEY: exp_arr}, + ) + ) + + TESTS_DEFORM.append( ( - {"keys": ("test",), "spatial_size": INPUT_SHAPE2[1:], "rand_size": (4, 4), "prob": 1.0, "gamma": (1, 1)}, - {"test": p(np.ones(INPUT_SHAPE2, np.float32))}, - {"test": p(np.ones(INPUT_SHAPE2, np.float32))}, - ), - ] + { + "keys": (KEY,), + "spatial_size": shape[1:], + "rand_size": rand_size, + "prob": 1.0, + "device": device, + "def_range": 0.1, + }, + {KEY: in_arr}, + {KEY: exp_arr}, + ) + ) class TestSmoothField(unittest.TestCase): @@ -62,7 +87,18 @@ def test_rand_smooth_field_adjust_contrastd(self, input_param, input_data, expec res = g(input_data) for key, result in res.items(): expected = expected_val[key] - assert_allclose(result, expected, rtol=_rtol, atol=5e-3) + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + def test_rand_smooth_field_adjust_contrastd_pad(self): + input_param, input_data, expected_val = TESTS_CONTRAST[0] + + g = RandSmoothFieldAdjustContrastd(pad=1, **input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) @parameterized.expand(TESTS_INTENSITY) def test_rand_smooth_field_adjust_intensityd(self, input_param, input_data, expected_val): @@ -72,4 +108,36 @@ def test_rand_smooth_field_adjust_intensityd(self, input_param, input_data, expe res = g(input_data) for key, result in res.items(): expected = expected_val[key] - assert_allclose(result, expected, rtol=_rtol, atol=5e-3) + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + def test_rand_smooth_field_adjust_intensityd_pad(self): + input_param, input_data, expected_val = TESTS_INTENSITY[0] + + g = RandSmoothFieldAdjustIntensityd(pad=1, **input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + @parameterized.expand(TESTS_DEFORM) + def test_rand_smooth_deformd(self, input_param, input_data, expected_val): + g = RandSmoothDeformd(**input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1) + + def test_rand_smooth_deformd_pad(self): + input_param, input_data, expected_val = TESTS_DEFORM[0] + + g = RandSmoothDeformd(pad=1, **input_param) + g.set_random_state(123) + + res = g(input_data) + for key, result in res.items(): + expected = expected_val[key] + assert_allclose(result, expected, rtol=_rtol, atol=1e-1)