From 86b2fcb002483fe7a60963d5358d557095ef0b45 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 17 Sep 2021 18:05:45 +0100 Subject: [PATCH 1/7] GibbsNoise, RandGibbsNoise, KSpaceSpikeNoise, RandKSpaceSpikeNoise Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 191 +++++++++++------------ monai/transforms/intensity/dictionary.py | 66 +++----- monai/transforms/utils.py | 24 +-- tests/test_gibbs_noise.py | 41 ++--- tests/test_gibbs_noised.py | 52 +++--- tests/test_k_space_spike_noise.py | 43 ++--- tests/test_k_space_spike_noised.py | 59 ++++--- 7 files changed, 235 insertions(+), 241 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index a8babfc659..3285c306e6 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1252,68 +1252,6 @@ def __call__(self, img: np.ndarray) -> np.ndarray: ) -class RandGibbsNoise(RandomizableTransform): - """ - Naturalistic image augmentation via Gibbs artifacts. The transform - randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts - are one of the common type of type artifacts appearing in MRI scans. - - The transform is applied to all the channels in the data. - - For general information on Gibbs artifacts, please refer to: - https://pubs.rsna.org/doi/full/10.1148/rg.313105115 - https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949 - - - Args: - prob (float): probability of applying the transform. - alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes - values in the interval [0,1] with alpha = 0 acting as the identity mapping. - If a length-2 list is given as [a,b] then the value of alpha will be - sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. - """ - - def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None: - - if len(alpha) != 2: - raise ValueError("alpha length must be 2.") - if alpha[1] > 1 or alpha[0] < 0: - raise ValueError("alpha must take values in the interval [0,1]") - if alpha[0] > alpha[1]: - raise ValueError("When alpha = [a,b] we need a < b.") - - self.alpha = alpha - self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() - self.as_tensor_output = as_tensor_output - - RandomizableTransform.__init__(self, prob=prob) - - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: - - # randomize application and possibly alpha - self._randomize(None) - - if self._do_transform: - # apply transform - transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output) - img = transform(img) - else: - if isinstance(img, np.ndarray) and self.as_tensor_output: - img = torch.Tensor(img) - elif isinstance(img, torch.Tensor) and not self.as_tensor_output: - img = img.detach().cpu().numpy() - return img - - def _randomize(self, _: Any) -> None: - """ - (1) Set random variable to apply the transform. - (2) Get alpha from uniform distribution. - """ - super().randomize(None) - self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) - - class GibbsNoise(Transform, Fourier): """ The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts @@ -1332,21 +1270,19 @@ class GibbsNoise(Transform, Fourier): Args: alpha: Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. Default: True. """ - def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, alpha: float = 0.5) -> None: if alpha > 1 or alpha < 0: raise ValueError("alpha must take values in the interval [0,1].") self.alpha = alpha - self.as_tensor_output = as_tensor_output - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: n_dims = len(img.shape[1:]) - if isinstance(img, np.ndarray): - img = torch.Tensor(img) # FT k = self.shift_fourier(img, n_dims) # build and apply mask @@ -1354,13 +1290,13 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, # map back img = self.inv_shift_fourier(k, n_dims) - return img if self.as_tensor_output else img.cpu().detach().numpy() # type: ignore + return img - def _apply_mask(self, k: torch.Tensor) -> torch.Tensor: + def _apply_mask(self, k: NdarrayOrTensor) -> NdarrayOrTensor: """Builds and applies a mask on the spatial dimensions. Args: - k (np.ndarray): k-space version of the image. + k: k-space version of the image. Returns: masked version of the k-space image. """ @@ -1381,11 +1317,72 @@ def _apply_mask(self, k: torch.Tensor) -> torch.Tensor: # add channel dimension into mask mask = np.repeat(mask[None], k.shape[0], axis=0) + if isinstance(k, torch.Tensor): + mask, *_ = convert_data_type(mask, torch.Tensor, device=k.device) + # apply binary mask - k_masked = k * torch.tensor(mask, device=k.device) + k_masked: NdarrayOrTensor + k_masked = k * mask return k_masked +class RandGibbsNoise(RandomizableTransform): + """ + Naturalistic image augmentation via Gibbs artifacts. The transform + randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts + are one of the common type of type artifacts appearing in MRI scans. + + The transform is applied to all the channels in the data. + + For general information on Gibbs artifacts, please refer to: + https://pubs.rsna.org/doi/full/10.1148/rg.313105115 + https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949 + + + Args: + prob (float): probability of applying the transform. + alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes + values in the interval [0,1] with alpha = 0 acting as the identity mapping. + If a length-2 list is given as [a,b] then the value of alpha will be + sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1. + """ + + backend = GibbsNoise.backend + + def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0)) -> None: + + if len(alpha) != 2: + raise ValueError("alpha length must be 2.") + if alpha[1] > 1 or alpha[0] < 0: + raise ValueError("alpha must take values in the interval [0,1]") + if alpha[0] > alpha[1]: + raise ValueError("When alpha = [a,b] we need a < b.") + + self.alpha = alpha + self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() + + RandomizableTransform.__init__(self, prob=prob) + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + + # randomize application and possibly alpha + self._randomize(None) + + if self._do_transform: + # apply transform + transform = GibbsNoise(self.sampled_alpha) + img = transform(img) + return img + + def _randomize(self, _: Any) -> None: + """ + (1) Set random variable to apply the transform. + (2) Get alpha from uniform distribution. + """ + super().randomize(None) + self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) + + class KSpaceSpikeNoise(Transform, Fourier): """ Apply localized spikes in `k`-space at the given locations and intensities. @@ -1413,8 +1410,6 @@ class KSpaceSpikeNoise(Transform, Fourier): receive a sequence of intensities. This value should be tested as it is data-dependent. The default values are the 2.5 the mean of the log-intensity for each channel. - as_tensor_output: if ``True`` return torch.Tensor, else return np.array. - Default: ``True``. Example: When working with 4D data, ``KSpaceSpikeNoise(loc = ((3,60,64,32), (64,60,32)), k_intensity = (13,14))`` @@ -1423,15 +1418,15 @@ class KSpaceSpikeNoise(Transform, Fourier): with `log-intensity = 14`. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None, - as_tensor_output: bool = True, ): self.loc = ensure_tuple(loc) - self.as_tensor_output = as_tensor_output self.k_intensity = k_intensity # assert one-to-one relationship between factors and locations @@ -1445,7 +1440,7 @@ def __init__( if isinstance(self.loc[0], Sequence) and k_intensity is not None and not isinstance(self.k_intensity, Sequence): raise ValueError("There must be one intensity_factor value for each tuple of indices in loc.") - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: image with dimensions (C, H, W) or (C, H, W, D) @@ -1462,17 +1457,17 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, n_dims = len(img.shape[1:]) - if isinstance(img, np.ndarray): - img = torch.Tensor(img) + lib = np if isinstance(img, np.ndarray) else torch + # FT k = self.shift_fourier(img, n_dims) - log_abs = torch.log(torch.absolute(k) + 1e-10) - phase = torch.angle(k) + log_abs = lib.log(lib.absolute(k) + 1e-10) # type: ignore + phase = lib.angle(k) # type: ignore k_intensity = self.k_intensity # default log intensity if k_intensity is None: - k_intensity = tuple(torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5) + k_intensity = tuple(lib.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5) # type: ignore # highlight if isinstance(self.loc[0], Sequence): @@ -1481,10 +1476,10 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, else: self._set_spike(log_abs, self.loc, k_intensity) # map back - k = torch.exp(log_abs) * torch.exp(1j * phase) + k = lib.exp(log_abs) * lib.exp(1j * phase) # type: ignore img = self.inv_shift_fourier(k, n_dims) - return img if self.as_tensor_output else img.cpu().detach().numpy() # type: ignore + return img def _check_indices(self, img) -> None: """Helper method to check consistency of self.loc and input image. @@ -1504,7 +1499,7 @@ def _check_indices(self, img) -> None: f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image." ) - def _set_spike(self, k: torch.Tensor, idx: Tuple, val: Union[Sequence[float], float]): + def _set_spike(self, k: NdarrayOrTensor, idx: Tuple, val: Union[Sequence[float], float]): """ Helper function to introduce a given intensity at given location. @@ -1550,8 +1545,6 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): log-intensity for each channel. channel_wise: treat each channel independently. True by default. - as_tensor_output: if True return torch.Tensor, else - return np.array. default: True. Example: To apply `k`-space spikes randomly with probability 0.5, and @@ -1560,17 +1553,17 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): ``RandKSpaceSpikeNoise(prob=0.5, intensity_range=(11, 12), channel_wise=True)`` """ + backend = KSpaceSpikeNoise.backend + def __init__( self, prob: float = 0.1, intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, channel_wise=True, - as_tensor_output: bool = True, ): self.intensity_range = intensity_range self.channel_wise = channel_wise - self.as_tensor_output = as_tensor_output self.sampled_k_intensity: List = [] self.sampled_locs: List[Tuple] = [] @@ -1579,7 +1572,7 @@ def __init__( super().__init__(prob) - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply transform to `img`. Assumes data is in channel-first form. @@ -1598,20 +1591,17 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, self.sampled_k_intensity = [] self.sampled_locs = [] - if not isinstance(img, torch.Tensor): - img = torch.Tensor(img) - intensity_range = self._make_sequence(img) self._randomize(img, intensity_range) # build/appy transform only if there are spike locations if self.sampled_locs: - transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output) + transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity) return transform(img) - return img if self.as_tensor_output else img.detach().numpy() + return img - def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float]]) -> None: + def _randomize(self, img: NdarrayOrTensor, intensity_range: Sequence[Sequence[float]]) -> None: """ Helper method to sample both the location and intensity of the spikes. When not working channel wise (channel_wise=False) it use the random @@ -1639,7 +1629,7 @@ def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float else: self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) - def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: + def _make_sequence(self, x: NdarrayOrTensor) -> Sequence[Sequence[float]]: """ Formats the sequence of intensities ranges to Sequence[Sequence[float]]. """ @@ -1651,7 +1641,7 @@ def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]: return (ensure_tuple(self.intensity_range),) * x.shape[0] return ensure_tuple(self.intensity_range) - def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]: + def _set_default_range(self, img: NdarrayOrTensor) -> Sequence[Sequence[float]]: """ Sets default intensity ranges to be sampled. @@ -1661,8 +1651,9 @@ def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]: n_dims = len(img.shape[1:]) k = self.shift_fourier(img, n_dims) - log_abs = torch.log(torch.absolute(k) + 1e-10) - shifted_means = torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5 + mod = torch if isinstance(img, torch.Tensor) else np + log_abs = mod.log(mod.absolute(k) + 1e-10) # type: ignore + shifted_means = mod.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5 # type: ignore return tuple((i * 0.95, i * 1.1) for i in shifted_means) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 07a6045870..9111699051 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1159,16 +1159,16 @@ class RandGibbsNoised(RandomizableTransform, MapTransform): values in the interval [0,1] with alpha = 0 acting as the identity mapping. If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. allow_missing_keys: do not raise exception if key is missing. """ + backend = GibbsNoise.backend + def __init__( self, keys: KeysCollection, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), - as_tensor_output: bool = True, allow_missing_keys: bool = False, ) -> None: @@ -1176,11 +1176,8 @@ def __init__( RandomizableTransform.__init__(self, prob=prob) self.alpha = alpha self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() - self.as_tensor_output = as_tensor_output - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) self._randomize(None) @@ -1188,13 +1185,8 @@ def __call__( for i, key in enumerate(self.key_iterator(d)): if self._do_transform: if i == 0: - transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output) + transform = GibbsNoise(self.sampled_alpha) d[key] = transform(d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) return d def _randomize(self, _: Any) -> None: @@ -1205,11 +1197,6 @@ def _randomize(self, _: Any) -> None: super().randomize(None) self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) - def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: - if isinstance(d, torch.Tensor): - d_numpy: np.ndarray = d.cpu().detach().numpy() - return d_numpy - class GibbsNoised(MapTransform): """ @@ -1227,20 +1214,17 @@ class GibbsNoised(MapTransform): you need to transform. alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. allow_missing_keys: do not raise exception if key is missing. """ - def __init__( - self, keys: KeysCollection, alpha: float = 0.5, as_tensor_output: bool = True, allow_missing_keys: bool = False - ) -> None: + backend = GibbsNoise.backend + + def __init__(self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - self.transform = GibbsNoise(alpha, as_tensor_output) + self.transform = GibbsNoise(alpha) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): @@ -1279,8 +1263,6 @@ class KSpaceSpikeNoised(MapTransform): receive a sequence of intensities. This value should be tested as it is data-dependent. The default values are the 2.5 the mean of the log-intensity for each channel. - as_tensor_output: if ``True`` return torch.Tensor, else return np.array. - Default: ``True``. allow_missing_keys: do not raise exception if key is missing. Example: @@ -1291,21 +1273,20 @@ class KSpaceSpikeNoised(MapTransform): with `log-intensity = 14`. """ + backend = KSpaceSpikeNoise.backend + def __init__( self, keys: KeysCollection, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None, - as_tensor_output: bool = True, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.transform = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + self.transform = KSpaceSpikeNoise(loc, k_intensity) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: """ Args: data: Expects image/label to have dimensions (C, H, W) or @@ -1352,8 +1333,6 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): common_sampling: If ``True`` same values for location and log-intensity will be sampled for the image and label. common_seed: Seed to be used in case ``common_sampling = True``. - as_tensor_output: if ``True`` return torch.Tensor, else return - np.array. Default: ``True``. allow_missing_keys: do not raise exception if key is missing. Example: @@ -1363,6 +1342,8 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): ``RandKSpaceSpikeNoised("image", prob=0.5, intensity_ranges={"image":(13,15)}, channel_wise=True)``. """ + backend = KSpaceSpikeNoise.backend + def __init__( self, keys: KeysCollection, @@ -1372,7 +1353,6 @@ def __init__( channel_wise: bool = True, common_sampling: bool = False, common_seed: int = 42, - as_tensor_output: bool = True, allow_missing_keys: bool = False, ): @@ -1381,21 +1361,16 @@ def __init__( self.common_sampling = common_sampling self.common_seed = common_seed - self.as_tensor_output = as_tensor_output # the spikes artifact is amplitude dependent so we instantiate one per key self.transforms = {} if isinstance(intensity_ranges, Mapping): for k in self.keys: - self.transforms[k] = RandKSpaceSpikeNoise( - prob, intensity_ranges[k], channel_wise, self.as_tensor_output - ) + self.transforms[k] = RandKSpaceSpikeNoise(prob, intensity_ranges[k], channel_wise) else: for k in self.keys: - self.transforms[k] = RandKSpaceSpikeNoise(prob, None, channel_wise, self.as_tensor_output) + self.transforms[k] = RandKSpaceSpikeNoise(prob, None, channel_wise) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: """ Args: data: Expects image/label to have dimensions (C, H, W) or @@ -1412,11 +1387,6 @@ def __call__( for key, t in self.key_iterator(d, self.transforms): if self._do_transform: d[key] = self.transforms[t](d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) return d def set_rand_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> None: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 05e45bf26f..9c7a64eac2 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1104,7 +1104,7 @@ class Fourier: @deprecated_arg( name="n_dims", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." ) - def shift_fourier(x: torch.Tensor, spatial_dims: int, n_dims: Optional[int] = None) -> torch.Tensor: + def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] = None) -> NdarrayOrTensor: """ Applies fourier transform and shifts the zero-frequency component to the center of the spectrum. Only the spatial dimensions get transformed. @@ -1121,16 +1121,19 @@ def shift_fourier(x: torch.Tensor, spatial_dims: int, n_dims: Optional[int] = No """ if n_dims is not None: spatial_dims = n_dims - k: torch.Tensor = torch.fft.fftshift( - torch.fft.fftn(x, dim=tuple(range(-spatial_dims, 0))), dim=tuple(range(-spatial_dims, 0)) - ) + dims = tuple(range(-spatial_dims, 0)) + k: NdarrayOrTensor + if isinstance(x, torch.Tensor): + k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims) + else: + k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims) return k @staticmethod @deprecated_arg( name="n_dims", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." ) - def inv_shift_fourier(k: torch.Tensor, spatial_dims: int, n_dims: Optional[int] = None) -> torch.Tensor: + def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] = None) -> NdarrayOrTensor: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. @@ -1147,10 +1150,13 @@ def inv_shift_fourier(k: torch.Tensor, spatial_dims: int, n_dims: Optional[int] """ if n_dims is not None: spatial_dims = n_dims - x: torch.Tensor = torch.fft.ifftn( - torch.fft.ifftshift(k, dim=tuple(range(-spatial_dims, 0))), dim=tuple(range(-spatial_dims, 0)) - ).real - return x + dims = tuple(range(-spatial_dims, 0)) + out: NdarrayOrTensor + if isinstance(k, torch.Tensor): + out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward") + else: + out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims) + return out.real # type: ignore def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Optional[Hashable] = None) -> int: diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index 264e2e630a..2c5e117eaf 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -19,17 +19,17 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestGibbsNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -39,36 +39,39 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return input_type(im) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 0.8 - t = GibbsNoise(alpha, as_tensor_output) + t = GibbsNoise(alpha) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) - np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) + self.assertEqual(type(out1), type(im)) + if isinstance(out1, torch.Tensor): + self.assertEqual(out1.device, im.device) + torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) + self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoise(alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(im, out, atol=1e-2) + torch.testing.assert_allclose(im, out, atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoise(alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(0 * im, out) + torch.testing.assert_allclose(0 * im, out, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index 558556489a..f02052818f 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -19,19 +19,18 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) - + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) KEYS = ["im", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestGibbsNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -41,49 +40,56 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) + return {k: input_type(deepcopy(v)) for k, v in zip(KEYS, ims)} @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 0.8 - t = GibbsNoised(KEYS, alpha, as_tensor_output) + t = GibbsNoised(KEYS, alpha) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0) + self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() np.testing.assert_allclose(data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(0 * data[k], out[k]) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(0.0 * data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_dict_matches(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) + torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_k_space_spike_noise.py b/tests/test_k_space_spike_noise.py index bb6d05e676..66763f286f 100644 --- a/tests/test_k_space_spike_noise.py +++ b/tests/test_k_space_spike_noise.py @@ -20,17 +20,14 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -40,34 +37,44 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d - im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + im, _ = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) + return im_type(im[None]) - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): - im = self.get_data(im_shape, as_tensor_input) + im = self.get_data(im_shape, im_type) loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0] k_intensity = 10 - t = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoise(loc, k_intensity) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(im.device, out1.device) + out1 = out1.cpu() + out2 = out2.cpu() + np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_highlighted_kspace_pixel(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_highlighted_kspace_pixel(self, im_shape, as_tensor_input): im = self.get_data(im_shape, as_tensor_input) loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0] k_intensity = 10 - t = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoise(loc, k_intensity) out = t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(im.device, out.device) + out = out.cpu() + n_dims = len(im_shape) out_k = fftshift(fftn(out, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) log_mag = np.log(np.absolute(out_k)) diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py index 616662b3cd..3fa6a394f3 100644 --- a/tests/test_k_space_spike_noised.py +++ b/tests/test_k_space_spike_noised.py @@ -20,19 +20,16 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) KEYS = ["image", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -42,55 +39,69 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [im[None] for im in ims] - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) + ims = [im_type(im[None]) for im in ims] + return {k: v for k, v in zip(KEYS, ims)} - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_highlighted_kspace_pixel(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_highlighted_kspace_pixel(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out = t(data) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + n_dims = len(im_shape) out_k = fftshift(fftn(out[k], axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) log_mag = np.log(np.absolute(out_k)) np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-1) - @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_dict_matches(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out = t(deepcopy(data)) + for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) From 3a597cfbd2b022a10b20470f596c03028371c5a6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 17 Sep 2021 18:28:14 +0100 Subject: [PATCH 2/7] update tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_rand_gibbs_noise.py | 50 +++++++++---------- tests/test_rand_gibbs_noised.py | 66 ++++++++++++++----------- tests/test_rand_k_space_spike_noise.py | 62 +++++++++++++---------- tests/test_rand_k_space_spike_noised.py | 66 ++++++++++++++----------- 4 files changed, 133 insertions(+), 111 deletions(-) diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index a0701d09c3..15cadea0e2 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -19,17 +19,17 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestRandGibbsNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -39,50 +39,50 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return input_type(im) @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_0_prob(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] - t = RandGibbsNoise(0.0, alpha, as_tensor_output) + t = RandGibbsNoise(0.0, alpha) out = t(im) - np.testing.assert_allclose(im, out) + torch.testing.assert_allclose(im, out, rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] - t = RandGibbsNoise(1.0, alpha, as_tensor_output) + t = RandGibbsNoise(1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) - np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) + self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = RandGibbsNoise(1.0, alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(im, out, atol=1e-2) + torch.testing.assert_allclose(im, out, atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoise(1.0, alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(0 * im, out) + torch.testing.assert_allclose(0 * im, out, rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoise(1.0, alpha) _ = t(deepcopy(im)) diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index b778bffdda..b8bac67b81 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -19,19 +19,19 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) KEYS = ["im", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestRandGibbsNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -41,70 +41,76 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) + return {k: input_type(v) for k, v in zip(KEYS, ims)} @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_0_prob(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] - t = RandGibbsNoised(KEYS, 0.0, alpha, as_tensor_output) + t = RandGibbsNoised(KEYS, 0.0, alpha) out = t(data) for k in KEYS: - np.testing.assert_allclose(data[k], out[k]) + torch.testing.assert_allclose(data[k], out[k], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] - t = RandGibbsNoised(KEYS, 1.0, alpha, as_tensor_output) + t = RandGibbsNoised(KEYS, 1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(data)) t.set_random_state(42) out2 = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) + torch.testing.assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0) + self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() np.testing.assert_allclose(data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(0 * data[k], out[k]) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(0.0 * data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_dict_matches(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = [0.5, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) + torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoised(KEYS, 1.0, alpha) _ = t(deepcopy(data)) - self.assertGreaterEqual(t.sampled_alpha, 0.5) - self.assertLessEqual(t.sampled_alpha, 0.51) + self.assertTrue(0.5 <= t.sampled_alpha <= 0.51) if __name__ == "__main__": diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py index 71f7e36d9b..1c9ca9c1d5 100644 --- a/tests/test_rand_k_space_spike_noise.py +++ b/tests/test_rand_k_space_spike_noise.py @@ -19,18 +19,15 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise, RandKSpaceSpikeNoise from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - for channel_wise in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input, channel_wise)) + for p in TEST_NDARRAYS: + for channel_wise in (True, False): + TESTS.append((shape, p, channel_wise)) -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestRandKSpaceSpikeNoise(unittest.TestCase): def setUp(self): set_determinism(0) @@ -40,44 +37,55 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return im_type(im) - @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_0_prob(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 15] - t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise) out = t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(out.device, im.device) + im, out = im.cpu(), out.cpu() np.testing.assert_allclose(im, out) - @parameterized.expand(TEST_CASES) - def test_1_prob(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_1_prob(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 14] - t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise) out = t(im) - base_t = KSpaceSpikeNoise(t.sampled_locs, [14], as_tensor_output) + base_t = KSpaceSpikeNoise(t.sampled_locs, [14]) out = out - base_t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(out.device, im.device) + im, out = im.cpu(), out.cpu() np.testing.assert_allclose(out, im * 0) - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 15] - t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(out1.device, im.device) + out1, out2 = out1.cpu(), out2.cpu() np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_intensity(self, im_shape, _, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_intensity(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 14.1] t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise) _ = t(deepcopy(im)) diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 1056ebf163..869aa50872 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -19,19 +19,16 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandKSpaceSpikeNoised from monai.utils.misc import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) KEYS = ["image", "label"] -@SkipIfBeforePyTorchVersion((1, 8)) -@SkipIfNoModule("torch.fft") class TestKSpaceSpikeNoised(unittest.TestCase): def setUp(self): set_determinism(0) @@ -41,17 +38,16 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [im[None] for im in ims] - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return dict(zip(KEYS, ims)) + ims = [im_type(im[None]) for im in ims] + return {k: v for k, v in zip(KEYS, ims)} - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) intensity_ranges = {"image": (13, 15), "label": (13, 15)} t = RandKSpaceSpikeNoised( @@ -60,7 +56,6 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): prob=1.0, intensity_ranges=intensity_ranges, channel_wise=True, - as_tensor_output=as_tensor_output, ) t.set_rand_state(42) out1 = t(deepcopy(data)) @@ -69,12 +64,16 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): out2 = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() np.testing.assert_allclose(out1[k], out2[k], atol=1e-10) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_0_prob(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) intensity_ranges = {"image": (13, 15), "label": (13, 15)} t1 = RandKSpaceSpikeNoised( KEYS, @@ -82,7 +81,6 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): prob=1.0, intensity_ranges=intensity_ranges, channel_wise=True, - as_tensor_output=as_tensor_output, ) t2 = RandKSpaceSpikeNoised( @@ -91,19 +89,25 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): prob=1.0, intensity_ranges=intensity_ranges, channel_wise=True, - as_tensor_output=as_tensor_output, ) out1 = t1(data) out2 = t2(data) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() + data[k] = data[k].cpu() + np.testing.assert_allclose(data[k], out1[k]) np.testing.assert_allclose(data[k], out2[k]) - @parameterized.expand(TEST_CASES) - def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_intensity(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) intensity_ranges = {"image": (13, 13.1), "label": (13, 13.1)} t = RandKSpaceSpikeNoised( KEYS, @@ -111,7 +115,6 @@ def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): prob=1.0, intensity_ranges=intensity_ranges, channel_wise=True, - as_tensor_output=True, ) _ = t(data) @@ -120,9 +123,9 @@ def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): self.assertGreaterEqual(t.transforms["label"].sampled_k_intensity[0], 13) self.assertLessEqual(t.transforms["label"].sampled_k_intensity[0], 13.1) - @parameterized.expand(TEST_CASES) - def test_same_transformation(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_same_transformation(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} @@ -135,11 +138,16 @@ def test_same_transformation(self, im_shape, _, as_tensor_input): intensity_ranges=intensity_ranges, channel_wise=True, common_sampling=True, - as_tensor_output=True, ) out = t(deepcopy(data)) + for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) From f2e98646862970d0a5842459543ac51b15f6f5fb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 22 Sep 2021 11:45:50 +0100 Subject: [PATCH 3/7] pytorch 1.6 compatible Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 3196d05cd7..e61acadc45 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -43,7 +43,7 @@ optional_import, ) from monai.utils.enums import TransformBackends -from monai.utils.type_conversion import convert_data_type +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type measure, _ = optional_import("skimage.measure", "0.14.2", min_version) ndimage, _ = optional_import("scipy.ndimage") @@ -1273,7 +1273,11 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] = dims = tuple(range(-spatial_dims, 0)) k: NdarrayOrTensor if isinstance(x, torch.Tensor): - k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims) + if hasattr(torch.fft, "fftshift"): + k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims) + else: + k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims) + k, *_ = convert_to_dst_type(k, x) else: k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims) return k @@ -1302,10 +1306,14 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[in dims = tuple(range(-spatial_dims, 0)) out: NdarrayOrTensor if isinstance(k, torch.Tensor): - out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward") + if hasattr(torch.fft, "ifftshift"): + out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward").real + else: + out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real + out, *_ = convert_to_dst_type(out, k) else: - out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims) - return out.real # type: ignore + out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real + return out def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Optional[Hashable] = None) -> int: From 1c4528751011d18fddd1783e54161a731cc2f476 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 23 Sep 2021 14:12:56 +0100 Subject: [PATCH 4/7] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 393dd89071..83fdc81a64 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -16,6 +16,7 @@ from abc import abstractmethod from collections.abc import Iterable from functools import partial +from monai.utils.misc import is_module_ver_at_least from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from warnings import warn @@ -1476,7 +1477,12 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: else: self._set_spike(log_abs, self.loc, k_intensity) # map back - k = lib.exp(log_abs) * lib.exp(1j * phase) # type: ignore + # complex exponential not implemented for older pytorch + if isinstance(phase, torch.Tensor) and not is_module_ver_at_least(torch, (1, 6, 0)): + phase = phase.cpu() + k = torch.exp(log_abs) * torch.exp(1j * phase.cpu()).to(log_abs.device) + else: + k = lib.exp(log_abs) * lib.exp(1j * phase) # type: ignore img = self.inv_shift_fourier(k, n_dims) return img From 15e25fe970b6a2607880cab0f04be2c81787a799 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 23 Sep 2021 14:13:50 +0100 Subject: [PATCH 5/7] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 83fdc81a64..817c62f59e 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -16,7 +16,6 @@ from abc import abstractmethod from collections.abc import Iterable from functools import partial -from monai.utils.misc import is_module_ver_at_least from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from warnings import warn @@ -41,6 +40,7 @@ fall_back_tuple, ) from monai.utils.enums import TransformBackends +from monai.utils.misc import is_module_ver_at_least from monai.utils.type_conversion import convert_to_tensor, get_equivalent_dtype __all__ = [ From 28b31c683aeaba79ef5800c814506761c562a0b3 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 27 Sep 2021 11:41:42 +0100 Subject: [PATCH 6/7] as_tensor_output deprecated Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 11 +++++++++-- monai/transforms/intensity/dictionary.py | 12 +++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index e53af4c833..98cbc68afc 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -39,6 +39,7 @@ ensure_tuple_size, fall_back_tuple, ) +from monai.utils.deprecated import deprecated_arg from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least from monai.utils.type_conversion import convert_to_tensor, get_equivalent_dtype @@ -1294,7 +1295,8 @@ class GibbsNoise(Transform, Fourier): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, alpha: float = 0.5) -> None: + @deprecated_arg(name="as_tensor_output", since="0.6") + def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: if alpha > 1 or alpha < 0: raise ValueError("alpha must take values in the interval [0,1].") @@ -1369,7 +1371,8 @@ class RandGibbsNoise(RandomizableTransform): backend = GibbsNoise.backend - def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0)) -> None: + @deprecated_arg(name="as_tensor_output", since="0.6") + def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None: if len(alpha) != 2: raise ValueError("alpha length must be 2.") @@ -1440,10 +1443,12 @@ class KSpaceSpikeNoise(Transform, Fourier): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None, + as_tensor_output: bool = True, ): self.loc = ensure_tuple(loc) @@ -1580,11 +1585,13 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier): backend = KSpaceSpikeNoise.backend + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, prob: float = 0.1, intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, channel_wise=True, + as_tensor_output: bool = True, ): self.intensity_range = intensity_range diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 2f8c097ec0..204b3d7c37 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -48,6 +48,7 @@ from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utils import is_positive from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size +from monai.utils.deprecated import deprecated_arg from monai.utils.type_conversion import convert_data_type __all__ = [ @@ -1176,12 +1177,14 @@ class RandGibbsNoised(RandomizableTransform, MapTransform): backend = GibbsNoise.backend + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), allow_missing_keys: bool = False, + as_tensor_output: bool = True, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1231,7 +1234,10 @@ class GibbsNoised(MapTransform): backend = GibbsNoise.backend - def __init__(self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False) -> None: + @deprecated_arg(name="as_tensor_output", since="0.6") + def __init__( + self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False, as_tensor_output: bool = True + ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.transform = GibbsNoise(alpha) @@ -1287,12 +1293,14 @@ class KSpaceSpikeNoised(MapTransform): backend = KSpaceSpikeNoise.backend + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None, allow_missing_keys: bool = False, + as_tensor_output: bool = True, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1356,6 +1364,7 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): backend = KSpaceSpikeNoise.backend + @deprecated_arg(name="as_tensor_output", since="0.6") def __init__( self, keys: KeysCollection, @@ -1366,6 +1375,7 @@ def __init__( common_sampling: bool = False, common_seed: int = 42, allow_missing_keys: bool = False, + as_tensor_output: bool = True, ): MapTransform.__init__(self, keys, allow_missing_keys) From e0f08d5ac0151293d46a1a99b5500d523152380a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 29 Sep 2021 10:58:46 +0100 Subject: [PATCH 7/7] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 98cbc68afc..f9fc1ad6ab 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1629,7 +1629,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: # build/appy transform only if there are spike locations if self.sampled_locs: transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity) - return transform(img) + out: NdarrayOrTensor = transform(img) + return out return img