Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import math
from copy import deepcopy
from typing import List, Sequence, Union

import torch
Expand All @@ -24,10 +25,10 @@
ChannelMatching,
InvalidPyTorchVersionError,
SkipMode,
ensure_tuple_rep,
look_up_option,
optional_import,
)
from monai.utils.misc import issequenceiterable

_C, _ = optional_import("monai._C")
if not PT_BEFORE_1_7:
Expand Down Expand Up @@ -393,13 +394,18 @@ def __init__(
(for example `parameters()` iterator could be used to get the parameters);
otherwise this module will fix the kernels using `sigma` as the std.
"""
if issequenceiterable(sigma):
if len(sigma) != spatial_dims: # type: ignore
raise ValueError
else:
sigma = [deepcopy(sigma) for _ in range(spatial_dims)] # type: ignore
super().__init__()
self.sigma = [
torch.nn.Parameter(
torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None),
requires_grad=requires_grad,
)
for s in ensure_tuple_rep(sigma, int(spatial_dims))
for s in sigma # type: ignore
]
self.truncated = truncated
self.approx = approx
Expand Down
53 changes: 36 additions & 17 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,15 +1030,24 @@ class GaussianSmooth(Transform):

"""

backend = [TransformBackends.TORCH]

def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "erf") -> None:
    """Store the smoothing configuration.

    Args:
        sigma: std of the Gaussian kernel; a single value or one value per spatial dimension.
        approx: discretization scheme, forwarded to ``GaussianFilter`` at call time.
    """
    self.approx = approx
    self.sigma = sigma

def __call__(self, img: np.ndarray):
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx)
input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0)
return gaussian_filter(input_data).squeeze(0).detach().numpy()
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
    """Smooth ``img`` with a Gaussian kernel and return the result as the input's container type."""
    # work on a float torch tensor regardless of the input container type
    x: torch.Tensor
    x, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)  # type: ignore
    # move the sigma value(s) onto the image device before building the filter
    sigma_dev: Union[Sequence[torch.Tensor], torch.Tensor]
    if not isinstance(self.sigma, Sequence):
        sigma_dev = torch.as_tensor(self.sigma, device=x.device)
    else:
        sigma_dev = [torch.as_tensor(s, device=x.device) for s in self.sigma]
    # filter expects a batch dim: add it, smooth, then drop it again
    smoothed = GaussianFilter(x.ndim - 1, sigma_dev, approx=self.approx)(x.unsqueeze(0)).squeeze(0)
    # convert back to the caller's array type (and original device for tensor inputs)
    target_device = img.device if isinstance(img, torch.Tensor) else None
    out, *_ = convert_data_type(smoothed, type(img), device=target_device)
    return out


class RandGaussianSmooth(RandomizableTransform):
Expand Down Expand Up @@ -1079,10 +1088,10 @@ def randomize(self, data: Optional[Any] = None) -> None:
self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1])
self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1])

def __call__(self, img: np.ndarray):
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
    """Randomly apply Gaussian smoothing; when not applying, still cast the image to float."""
    self.randomize()
    if self._do_transform:
        # one sigma per spatial dimension, values drawn in randomize()
        sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=img.ndim - 1)
        return GaussianSmooth(sigma=sigma, approx=self.approx)(img)
    # not transforming: keep the output dtype consistent with the transformed path
    img, *_ = convert_data_type(img, dtype=torch.float)
    return img
Expand Down Expand Up @@ -1115,6 +1124,8 @@ class GaussianSharpen(Transform):

"""

backend = [TransformBackends.TORCH]

def __init__(
self,
sigma1: Union[Sequence[float], float] = 3.0,
Expand All @@ -1127,14 +1138,19 @@ def __init__(
self.alpha = alpha
self.approx = approx

def __call__(self, img: np.ndarray):
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
gaussian_filter1 = GaussianFilter(img.ndim - 1, self.sigma1, approx=self.approx)
gaussian_filter2 = GaussianFilter(img.ndim - 1, self.sigma2, approx=self.approx)
input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0)
blurred_f = gaussian_filter1(input_data)
filter_blurred_f = gaussian_filter2(blurred_f)
return (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0).detach().numpy()
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
    """Apply unsharp masking: ``blurred + alpha * (blurred - reblurred)``.

    Returns the sharpened image in the same container type as the input.
    """
    # work on a float32 torch tensor regardless of the input container type
    x: torch.Tensor
    x, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32)  # type: ignore

    # two Gaussian filters on the image device: sigma1 blurs, sigma2 re-blurs the blurred image
    filter1 = GaussianFilter(x.ndim - 1, self.sigma1, approx=self.approx).to(x.device)
    filter2 = GaussianFilter(x.ndim - 1, self.sigma2, approx=self.approx).to(x.device)
    blurred = filter1(x.unsqueeze(0))
    reblurred = filter2(blurred)
    sharpened = (blurred + self.alpha * (blurred - reblurred)).squeeze(0)
    # convert back to the caller's array type (and original device for tensor inputs)
    out, *_ = convert_data_type(sharpened, type(img), device=img.device if isinstance(img, torch.Tensor) else None)
    return out


class RandGaussianSharpen(RandomizableTransform):
Expand All @@ -1159,6 +1175,8 @@ class RandGaussianSharpen(RandomizableTransform):

"""

backend = GaussianSharpen.backend

def __init__(
self,
sigma1_x: Tuple[float, float] = (0.5, 1.0),
Expand Down Expand Up @@ -1194,10 +1212,11 @@ def randomize(self, data: Optional[Any] = None) -> None:
self.z2 = self.R.uniform(low=sigma2_z[0], high=sigma2_z[1])
self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1])

def __call__(self, img: np.ndarray):
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
self.randomize()
# if not doing, just need to convert to tensor
if not self._do_transform:
img, *_ = convert_data_type(img, dtype=torch.float32)
return img
sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=img.ndim - 1)
sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=img.ndim - 1)
Expand Down
38 changes: 25 additions & 13 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.utils import is_positive
from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size
from monai.utils.type_conversion import convert_data_type

__all__ = [
"RandGaussianNoised",
Expand Down Expand Up @@ -897,6 +898,8 @@ class GaussianSmoothd(MapTransform):

"""

backend = GaussianSmooth.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -907,7 +910,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.converter = GaussianSmooth(sigma, approx=approx)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
Expand All @@ -931,6 +934,8 @@ class RandGaussianSmoothd(RandomizableTransform, MapTransform):

"""

backend = GaussianSmooth.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -954,14 +959,15 @@ def randomize(self, data: Optional[Any] = None) -> None:
self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1])
self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1])

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
self.randomize()
if not self._do_transform:
return d
for key in self.key_iterator(d):
sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=d[key].ndim - 1)
d[key] = GaussianSmooth(sigma=sigma, approx=self.approx)(d[key])
if self._do_transform:
sigma = ensure_tuple_size(tup=(self.x, self.y, self.z), dim=d[key].ndim - 1)
d[key] = GaussianSmooth(sigma=sigma, approx=self.approx)(d[key])
else:
d[key], *_ = convert_data_type(d[key], torch.Tensor, dtype=torch.float)
return d


Expand All @@ -985,6 +991,8 @@ class GaussianSharpend(MapTransform):

"""

backend = GaussianSharpen.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -997,7 +1005,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.converter = GaussianSharpen(sigma1, sigma2, alpha, approx=approx)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
Expand Down Expand Up @@ -1028,6 +1036,8 @@ class RandGaussianSharpend(RandomizableTransform, MapTransform):

"""

backend = GaussianSharpen.backend

def __init__(
self,
keys: KeysCollection,
Expand Down Expand Up @@ -1066,15 +1076,17 @@ def randomize(self, data: Optional[Any] = None) -> None:
self.z2 = self.R.uniform(low=sigma2_z[0], high=sigma2_z[1])
self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1])

def __call__(self, data):
def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
self.randomize()
if not self._do_transform:
return d
for key in self.key_iterator(d):
sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=d[key].ndim - 1)
sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=d[key].ndim - 1)
d[key] = GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(d[key])
if self._do_transform:
sigma1 = ensure_tuple_size(tup=(self.x1, self.y1, self.z1), dim=d[key].ndim - 1)
sigma2 = ensure_tuple_size(tup=(self.x2, self.y2, self.z2), dim=d[key].ndim - 1)
d[key] = GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(d[key])
else:
# if not doing the transform, convert to torch
d[key], *_ = convert_data_type(d[key], torch.Tensor, dtype=torch.float32)
return d


Expand Down
83 changes: 56 additions & 27 deletions tests/test_gaussian_sharpen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,50 +11,79 @@

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import GaussianSharpen
from tests.utils import TEST_NDARRAYS, assert_allclose

TEST_CASE_1 = [
{},
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
np.array(
TESTS = []

for p in TEST_NDARRAYS:
TESTS.append(
[
[[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]],
[[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]],
{},
p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
p(
[
[
[4.1081963, 3.4950666, 4.1081963],
[3.7239995, 2.8491793, 3.7239995],
[4.569839, 3.9529324, 4.569839],
],
[[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]],
]
),
]
),
]
)

TEST_CASE_2 = [
{"sigma1": 1.0, "sigma2": 0.75, "alpha": 20},
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
np.array(
TESTS.append(
[
[[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]],
[[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]],
{"sigma1": 1.0, "sigma2": 0.75, "alpha": 20},
p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
p(
[
[
[4.513644, 4.869134, 4.513644],
[8.467242, 9.4004135, 8.467242],
[10.416813, 12.0653515, 10.416813],
],
[
[15.711488, 17.569994, 15.711488],
[21.16811, 23.501041, 21.16811],
[21.614658, 24.766209, 21.614658],
],
]
),
]
),
]
)

TEST_CASE_3 = [
{"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20},
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
np.array(
TESTS.append(
[
[[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]],
[[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]],
{"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20},
p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
p(
[
[
[3.3324685, 3.335536, 3.3324673],
[7.7666636, 8.16056, 7.7666636],
[12.662973, 14.317837, 12.6629715],
],
[
[15.329051, 16.57557, 15.329051],
[19.41665, 20.40139, 19.416655],
[24.659554, 27.557873, 24.659554],
],
]
),
]
),
]
)


class TestGaussianSharpen(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
@parameterized.expand(TESTS)
def test_value(self, argments, image, expected_data):
result = GaussianSharpen(**argments)(image)
np.testing.assert_allclose(result, expected_data, rtol=1e-4)
assert_allclose(result, expected_data, atol=0, rtol=1e-4, type_test=False)


if __name__ == "__main__":
Expand Down
Loading