Skip to content
201 changes: 103 additions & 98 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
ensure_tuple_size,
fall_back_tuple,
)
from monai.utils.deprecated import deprecated_arg
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_to_tensor, get_equivalent_dtype

__all__ = [
Expand Down Expand Up @@ -1271,68 +1273,6 @@ def __call__(self, img: np.ndarray) -> np.ndarray:
)


class RandGibbsNoise(RandomizableTransform):
"""
Naturalistic image augmentation via Gibbs artifacts. The transform
randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts
are one of the common type of type artifacts appearing in MRI scans.

The transform is applied to all the channels in the data.

For general information on Gibbs artifacts, please refer to:
https://pubs.rsna.org/doi/full/10.1148/rg.313105115
https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949


Args:
prob (float): probability of applying the transform.
alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
If a length-2 list is given as [a,b] then the value of alpha will be
sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1.
as_tensor_output: if true return torch.Tensor, else return np.array. default: True.
"""

def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None:

if len(alpha) != 2:
raise ValueError("alpha length must be 2.")
if alpha[1] > 1 or alpha[0] < 0:
raise ValueError("alpha must take values in the interval [0,1]")
if alpha[0] > alpha[1]:
raise ValueError("When alpha = [a,b] we need a < b.")

self.alpha = alpha
self.sampled_alpha = -1.0 # stores last alpha sampled by randomize()
self.as_tensor_output = as_tensor_output

RandomizableTransform.__init__(self, prob=prob)

def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]:

# randomize application and possibly alpha
self._randomize(None)

if self._do_transform:
# apply transform
transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output)
img = transform(img)
else:
if isinstance(img, np.ndarray) and self.as_tensor_output:
img = torch.Tensor(img)
elif isinstance(img, torch.Tensor) and not self.as_tensor_output:
img = img.detach().cpu().numpy()
return img

def _randomize(self, _: Any) -> None:
"""
(1) Set random variable to apply the transform.
(2) Get alpha from uniform distribution.
"""
super().randomize(None)
self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1])


class GibbsNoise(Transform, Fourier):
"""
The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts
Expand All @@ -1351,35 +1291,34 @@ class GibbsNoise(Transform, Fourier):
Args:
alpha: Parametrizes the intensity of the Gibbs noise filter applied. Takes
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
as_tensor_output: if true return torch.Tensor, else return np.array. Default: True.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

@deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None:

if alpha > 1 or alpha < 0:
raise ValueError("alpha must take values in the interval [0,1].")
self.alpha = alpha
self.as_tensor_output = as_tensor_output

def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
n_dims = len(img.shape[1:])

if isinstance(img, np.ndarray):
img = torch.Tensor(img)
# FT
k = self.shift_fourier(img, n_dims)
# build and apply mask
k = self._apply_mask(k)
# map back
img = self.inv_shift_fourier(k, n_dims)

return img if self.as_tensor_output else img.cpu().detach().numpy() # type: ignore
return img

def _apply_mask(self, k: torch.Tensor) -> torch.Tensor:
def _apply_mask(self, k: NdarrayOrTensor) -> NdarrayOrTensor:
"""Builds and applies a mask on the spatial dimensions.

Args:
k (np.ndarray): k-space version of the image.
k: k-space version of the image.
Returns:
masked version of the k-space image.
"""
Expand All @@ -1400,11 +1339,73 @@ def _apply_mask(self, k: torch.Tensor) -> torch.Tensor:
# add channel dimension into mask
mask = np.repeat(mask[None], k.shape[0], axis=0)

if isinstance(k, torch.Tensor):
mask, *_ = convert_data_type(mask, torch.Tensor, device=k.device)

# apply binary mask
k_masked = k * torch.tensor(mask, device=k.device)
k_masked: NdarrayOrTensor
k_masked = k * mask
return k_masked


class RandGibbsNoise(RandomizableTransform):
"""
Naturalistic image augmentation via Gibbs artifacts. The transform
randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts
are one of the common type of type artifacts appearing in MRI scans.

The transform is applied to all the channels in the data.

For general information on Gibbs artifacts, please refer to:
https://pubs.rsna.org/doi/full/10.1148/rg.313105115
https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949


Args:
prob (float): probability of applying the transform.
alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
If a length-2 list is given as [a,b] then the value of alpha will be
sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1.
"""

backend = GibbsNoise.backend

@deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None:

if len(alpha) != 2:
raise ValueError("alpha length must be 2.")
if alpha[1] > 1 or alpha[0] < 0:
raise ValueError("alpha must take values in the interval [0,1]")
if alpha[0] > alpha[1]:
raise ValueError("When alpha = [a,b] we need a < b.")

self.alpha = alpha
self.sampled_alpha = -1.0 # stores last alpha sampled by randomize()

RandomizableTransform.__init__(self, prob=prob)

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:

# randomize application and possibly alpha
self._randomize(None)

if self._do_transform:
# apply transform
transform = GibbsNoise(self.sampled_alpha)
img = transform(img)
return img

def _randomize(self, _: Any) -> None:
"""
(1) Set random variable to apply the transform.
(2) Get alpha from uniform distribution.
"""
super().randomize(None)
self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1])


class KSpaceSpikeNoise(Transform, Fourier):
"""
Apply localized spikes in `k`-space at the given locations and intensities.
Expand Down Expand Up @@ -1432,8 +1433,6 @@ class KSpaceSpikeNoise(Transform, Fourier):
receive a sequence of intensities. This value should be tested as it is
data-dependent. The default values are the 2.5 the mean of the
log-intensity for each channel.
as_tensor_output: if ``True`` return torch.Tensor, else return np.array.
Default: ``True``.

Example:
When working with 4D data, ``KSpaceSpikeNoise(loc = ((3,60,64,32), (64,60,32)), k_intensity = (13,14))``
Expand All @@ -1442,6 +1441,9 @@ class KSpaceSpikeNoise(Transform, Fourier):
with `log-intensity = 14`.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

@deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
loc: Union[Tuple, Sequence[Tuple]],
Expand All @@ -1450,7 +1452,6 @@ def __init__(
):

self.loc = ensure_tuple(loc)
self.as_tensor_output = as_tensor_output
self.k_intensity = k_intensity

# assert one-to-one relationship between factors and locations
Expand All @@ -1464,7 +1465,7 @@ def __init__(
if isinstance(self.loc[0], Sequence) and k_intensity is not None and not isinstance(self.k_intensity, Sequence):
raise ValueError("There must be one intensity_factor value for each tuple of indices in loc.")

def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: image with dimensions (C, H, W) or (C, H, W, D)
Expand All @@ -1481,17 +1482,17 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,

n_dims = len(img.shape[1:])

if isinstance(img, np.ndarray):
img = torch.Tensor(img)
lib = np if isinstance(img, np.ndarray) else torch

# FT
k = self.shift_fourier(img, n_dims)
log_abs = torch.log(torch.absolute(k) + 1e-10)
phase = torch.angle(k)
log_abs = lib.log(lib.absolute(k) + 1e-10) # type: ignore
phase = lib.angle(k) # type: ignore

k_intensity = self.k_intensity
# default log intensity
if k_intensity is None:
k_intensity = tuple(torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5)
k_intensity = tuple(lib.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5) # type: ignore

# highlight
if isinstance(self.loc[0], Sequence):
Expand All @@ -1500,10 +1501,15 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
else:
self._set_spike(log_abs, self.loc, k_intensity)
# map back
k = torch.exp(log_abs) * torch.exp(1j * phase)
# complex exponential not implemented for older pytorch
if isinstance(phase, torch.Tensor) and not is_module_ver_at_least(torch, (1, 6, 0)):
phase = phase.cpu()
k = torch.exp(log_abs) * torch.exp(1j * phase.cpu()).to(log_abs.device)
else:
k = lib.exp(log_abs) * lib.exp(1j * phase) # type: ignore
img = self.inv_shift_fourier(k, n_dims)

return img if self.as_tensor_output else img.cpu().detach().numpy() # type: ignore
return img

def _check_indices(self, img) -> None:
"""Helper method to check consistency of self.loc and input image.
Expand All @@ -1523,7 +1529,7 @@ def _check_indices(self, img) -> None:
f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image."
)

def _set_spike(self, k: torch.Tensor, idx: Tuple, val: Union[Sequence[float], float]):
def _set_spike(self, k: NdarrayOrTensor, idx: Tuple, val: Union[Sequence[float], float]):
"""
Helper function to introduce a given intensity at given location.

Expand Down Expand Up @@ -1569,8 +1575,6 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier):
log-intensity for each channel.
channel_wise: treat each channel independently. True by
default.
as_tensor_output: if True return torch.Tensor, else
return np.array. default: True.

Example:
To apply `k`-space spikes randomly with probability 0.5, and
Expand All @@ -1579,6 +1583,9 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier):
``RandKSpaceSpikeNoise(prob=0.5, intensity_range=(11, 12), channel_wise=True)``
"""

backend = KSpaceSpikeNoise.backend

@deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
prob: float = 0.1,
Expand All @@ -1589,7 +1596,6 @@ def __init__(

self.intensity_range = intensity_range
self.channel_wise = channel_wise
self.as_tensor_output = as_tensor_output
self.sampled_k_intensity: List = []
self.sampled_locs: List[Tuple] = []

Expand All @@ -1598,7 +1604,7 @@ def __init__(

super().__init__(prob)

def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply transform to `img`. Assumes data is in channel-first form.

Expand All @@ -1617,20 +1623,18 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
self.sampled_k_intensity = []
self.sampled_locs = []

if not isinstance(img, torch.Tensor):
img = torch.Tensor(img)

intensity_range = self._make_sequence(img)
self._randomize(img, intensity_range)

# build/appy transform only if there are spike locations
if self.sampled_locs:
transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output)
return transform(img)
transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity)
out: NdarrayOrTensor = transform(img)
return out

return img if self.as_tensor_output else img.detach().numpy()
return img

def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float]]) -> None:
def _randomize(self, img: NdarrayOrTensor, intensity_range: Sequence[Sequence[float]]) -> None:
"""
Helper method to sample both the location and intensity of the spikes.
When not working channel wise (channel_wise=False) it use the random
Expand Down Expand Up @@ -1658,7 +1662,7 @@ def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float
else:
self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img)

def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]:
def _make_sequence(self, x: NdarrayOrTensor) -> Sequence[Sequence[float]]:
"""
Formats the sequence of intensities ranges to Sequence[Sequence[float]].
"""
Expand All @@ -1670,7 +1674,7 @@ def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]:
return (ensure_tuple(self.intensity_range),) * x.shape[0]
return ensure_tuple(self.intensity_range)

def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]:
def _set_default_range(self, img: NdarrayOrTensor) -> Sequence[Sequence[float]]:
"""
Sets default intensity ranges to be sampled.

Expand All @@ -1680,8 +1684,9 @@ def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]:
n_dims = len(img.shape[1:])

k = self.shift_fourier(img, n_dims)
log_abs = torch.log(torch.absolute(k) + 1e-10)
shifted_means = torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5
mod = torch if isinstance(img, torch.Tensor) else np
log_abs = mod.log(mod.absolute(k) + 1e-10) # type: ignore
shifted_means = mod.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5 # type: ignore
return tuple((i * 0.95, i * 1.1) for i in shifted_means)


Expand Down
Loading