Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,7 +1815,9 @@ def _set_default_range(self, img: NdarrayOrTensor) -> Sequence[Sequence[float]]:
k = self.shift_fourier(img, n_dims)
mod = torch if isinstance(k, torch.Tensor) else np
log_abs = mod.log(mod.absolute(k) + 1e-10)
shifted_means = mod.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5
shifted_means = mod.mean(log_abs, tuple(range(-n_dims, 0))) * 2.5
if isinstance(shifted_means, torch.Tensor):
shifted_means = shifted_means.to("cpu")
return tuple((i * 0.95, i * 1.1) for i in shifted_means)


Expand Down
18 changes: 9 additions & 9 deletions tests/test_k_space_spike_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
TESTS = []
for shape in ((128, 64), (64, 48, 80)):
for p in TEST_NDARRAYS:
TESTS.append((shape, p))
for intensity in [10, None]:
TESTS.append((shape, p, intensity))


class TestKSpaceSpikeNoise(unittest.TestCase):
Expand All @@ -43,11 +44,10 @@ def get_data(im_shape, im_type):
return im_type(im[None])

@parameterized.expand(TESTS)
def test_same_result(self, im_shape, im_type):
def test_same_result(self, im_shape, im_type, k_intensity):

im = self.get_data(im_shape, im_type)
loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0]
k_intensity = 10
t = KSpaceSpikeNoise(loc, k_intensity)

out1 = t(deepcopy(im))
Expand All @@ -62,11 +62,10 @@ def test_same_result(self, im_shape, im_type):
np.testing.assert_allclose(out1, out2)

@parameterized.expand(TESTS)
def test_highlighted_kspace_pixel(self, im_shape, as_tensor_input):
def test_highlighted_kspace_pixel(self, im_shape, as_tensor_input, k_intensity):

im = self.get_data(im_shape, as_tensor_input)
loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0]
k_intensity = 10
t = KSpaceSpikeNoise(loc, k_intensity)
out = t(im)

Expand All @@ -75,10 +74,11 @@ def test_highlighted_kspace_pixel(self, im_shape, as_tensor_input):
self.assertEqual(im.device, out.device)
out = out.cpu()

n_dims = len(im_shape)
out_k = fftshift(fftn(out, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))
log_mag = np.log(np.absolute(out_k))
np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-4)
if k_intensity is not None:
n_dims = len(im_shape)
out_k = fftshift(fftn(out, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))
log_mag = np.log(np.absolute(out_k))
np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-4)


if __name__ == "__main__":
Expand Down
7 changes: 7 additions & 0 deletions tests/test_rand_k_space_spike_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ def test_intensity(self, im_shape, im_type, channel_wise):
self.assertGreaterEqual(t.sampled_k_intensity[0], 14)
self.assertLessEqual(t.sampled_k_intensity[0], 14.1)

@parameterized.expand(TESTS)
def test_default_intensity(self, im_shape, im_type, channel_wise):
im = self.get_data(im_shape, im_type)
t = RandKSpaceSpikeNoise(1.0, intensity_range=None, channel_wise=channel_wise)
out = t(deepcopy(im))
self.assertEqual(out.shape, im.shape)


if __name__ == "__main__":
unittest.main()