diff --git a/generative/metrics/__init__.py b/generative/metrics/__init__.py index e05137d8..f846ba1b 100644 --- a/generative/metrics/__init__.py +++ b/generative/metrics/__init__.py @@ -13,4 +13,5 @@ from .fid import FID from .mmd import MMD -from .ms_ssim import MSSSIM +from .ms_ssim import MultiScaleSSIMMetric +from .ssim import SSIMMetric diff --git a/generative/metrics/ms_ssim.py b/generative/metrics/ms_ssim.py index 9cd6cabb..ff6c4f1c 100644 --- a/generative/metrics/ms_ssim.py +++ b/generative/metrics/ms_ssim.py @@ -11,130 +11,143 @@ from __future__ import annotations +from collections.abc import Sequence + import torch import torch.nn.functional as F -from monai.metrics import SSIMMetric from monai.metrics.regression import RegressionMetric -from monai.utils import MetricReduction +from monai.utils import MetricReduction, StrEnum, ensure_tuple_rep + +from generative.metrics.ssim import compute_ssim_and_cs + +class KernelType(StrEnum): + GAUSSIAN = "gaussian" + UNIFORM = "uniform" -class MSSSIM(RegressionMetric): + +class MultiScaleSSIMMetric(RegressionMetric): """ - Computes Multi-Scale Structural Similarity Index Measure. + Computes the Multi-Scale Structural Similarity Index Measure (MS-SSIM). [1] Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November. Multiscale structural similarity for image quality assessment. - In The Thrity-Seventh Asilomar Conference on Signals, Systems + In The Thirty-Seventh Asilomar Conference on Signals, Systems & Computers, 2003 (Vol. 2, pp. 1398-1402). Ieee. Args: - data_range: dynamic range of the data - win_size: gaussian weighting window size + spatial_dims: number of spatial dimensions of the input images. + data_range: value range of input images. (usually 1.0 or 255) + kernel_type: type of kernel, can be "gaussian" or "uniform". + kernel_size: size of kernel + kernel_sigma: standard deviation for Gaussian kernel. k1: stability constant used in the luminance denominator k2: stability constant used in the contrast denominator - spatial_dims: if 2, input shape is expected to be (B,C,W,H); - if 3, it is expected to be (B,C,W,H,D) - weights: parameters for image similarity and contrast sensitivity - at different resolution scores. - reduction: {``"none"``, ``"mean"``, ``"sum"``} - Specifies the reduction to apply to the output. - Defaults to ``"mean"``. - - ``"none"``: no reduction will be applied. - - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - - ``"sum"``: the output will be summed. + weights: parameters for image similarity and contrast sensitivity at different resolution scores. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans) """ def __init__( self, - data_range: torch.Tensor | float, - win_size: int = 7, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: int | Sequence[int, ...] = 11, + kernel_sigma: float | Sequence[float, ...] = 1.5, k1: float = 0.01, k2: float = 0.03, - spatial_dims: int = 2, - weights: list | None = None, + weights: Sequence[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, ) -> None: - super().__init__() - - if not (win_size % 2 == 1): - raise ValueError("Window size should be odd.") + super().__init__(reduction=reduction, get_not_nans=get_not_nans) - self.data_range = data_range - self.win_size = win_size - self.k1, self.k2 = k1, k2 self.spatial_dims = spatial_dims - self.weights = weights - self.reduction = reduction + self.data_range = data_range + self.kernel_type = kernel_type + + if not isinstance(kernel_size, Sequence): + kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) + self.kernel_size = kernel_size - self.SSIM = SSIMMetric(self.data_range, self.win_size, self.k1, self.k2, self.spatial_dims) + if not isinstance(kernel_sigma, Sequence): + kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) + self.kernel_sigma = kernel_sigma + + self.k1 = k1 + self.k2 = k2 + self.weights = weights - def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Args: - x: first sample (e.g., the reference image). Its shape is - (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. - A fastMRI sample should use the 2D format with C being - the number of slices. - y: second sample (e.g., the reconstructed image). It has similar - shape as x + y_pred: Predicted image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + y: Reference image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + Raises: + ValueError: when `y_pred` is not a 2D or 3D image. """ - - if not x.shape == y.shape: - raise ValueError(f"Input images should have the same dimensions, but got {x.shape} and {y.shape}.") - - for d in range(len(x.shape) - 1, 1, -1): - x = x.squeeze(dim=d) - y = y.squeeze(dim=d) - - if len(x.shape) == 4: - avg_pool = F.avg_pool2d - elif len(x.shape) == 5: - avg_pool = F.avg_pool3d - else: - raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {x.shape}") - - if self.weights is None: - # as per Ref 1 - Sec 3.2. - self.weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] - self.weights = torch.tensor(self.weights) - - divisible_by = 2 ** (len(self.weights) - 1) - bigger_than = (self.win_size + 2) * 2 ** (len(self.weights) - 1) - for idx, shape_size in enumerate(x.shape[2:]): - if shape_size % divisible_by != 0: + dims = y_pred.ndimension() + if self.spatial_dims == 2 and dims != 4: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} " + f"spatial dimensions, got {dims}." + ) + + if self.spatial_dims == 3 and dims != 5: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}" + f" spatial dimensions, got {dims}." + ) + + # check if image have enough size for the number of downsamplings and the size of the kernel + weights_div = max(1, (len(self.weights) - 1)) ** 2 + y_pred_spatial_dims = y_pred.shape[2:] + for i in range(len(y_pred_spatial_dims)): + if y_pred_spatial_dims[i] // weights_div <= self.kernel_size[i] - 1: raise ValueError( - f"Image size needs to be divisible by {divisible_by} but " - f"dimension {idx + 2} has size {shape_size}" + f"For a given number of `weights` parameters {len(self.weights)} and kernel size " + f"{self.kernel_size[i]}, the image height must be larger than " + f"{(self.kernel_size[i] - 1) * weights_div}." ) - if shape_size < bigger_than: - raise ValueError( - f"Image size should be larger than {bigger_than} due to " - f"the {len(self.weights) - 1} downsamplings in MS-SSIM." - ) + weights = torch.tensor(self.weights, device=y_pred.device, dtype=torch.float) + + avg_pool = getattr(F, f"avg_pool{self.spatial_dims}d") + + multiscale_list: list[torch.Tensor] = [] + for i in range(len(weights)): + ssim, cs = compute_ssim_and_cs( + y_pred=y_pred, + y=y, + spatial_dims=self.spatial_dims, + data_range=self.data_range, + kernel_type=self.kernel_type, + kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, + k1=self.k1, + k2=self.k2, + ) + + cs_per_batch = cs.view(cs.shape[0], -1).mean(1) + + multiscale_list.append(torch.relu(cs_per_batch)) + y_pred = avg_pool(y_pred, kernel_size=2) + y = avg_pool(y, kernel_size=2) + + ssim = ssim.view(ssim.shape[0], -1).mean(1) + multiscale_list[-1] = torch.relu(ssim) + multiscale_list = torch.stack(multiscale_list) + + ms_ssim_value_full_image = torch.prod(multiscale_list ** weights.view(-1, 1), dim=0) + + ms_ssim_per_batch: torch.Tensor = ms_ssim_value_full_image.view(ms_ssim_value_full_image.shape[0], -1).mean( + 1, keepdim=True + ) - levels = self.weights.shape[0] - mcs_list: list[torch.Tensor] = [] - for i in range(levels): - ssim, cs = self.SSIM._compute_metric_and_contrast(x, y) - - if i < levels - 1: - mcs_list.append(torch.relu(cs)) - padding = [s % 2 for s in x.shape[2:]] - x = avg_pool(x, kernel_size=2, padding=padding) - y = avg_pool(y, kernel_size=2, padding=padding) - - ssim = torch.relu(ssim) # (batch, 1) - # (level, batch, 1) - mcs_and_ssim = torch.stack(mcs_list + [ssim], dim=0) - ms_ssim = torch.prod(mcs_and_ssim ** self.weights.view(-1, 1, 1), dim=0) - - if self.reduction == MetricReduction.MEAN.value: - ms_ssim = ms_ssim.mean() - elif self.reduction == MetricReduction.SUM.value: - ms_ssim = ms_ssim.sum() - elif self.reduction == MetricReduction.NONE.value: - pass - - return ms_ssim + return ms_ssim_per_batch diff --git a/generative/metrics/ssim.py b/generative/metrics/ssim.py new file mode 100644 index 00000000..07039309 --- /dev/null +++ b/generative/metrics/ssim.py @@ -0,0 +1,231 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from monai.metrics.regression import RegressionMetric +from monai.utils import MetricReduction, StrEnum, convert_data_type, ensure_tuple_rep +from monai.utils.type_conversion import convert_to_dst_type + + +class KernelType(StrEnum): + GAUSSIAN = "gaussian" + UNIFORM = "uniform" + + +class SSIMMetric(RegressionMetric): + r""" + Computes the Structural Similarity Index Measure (SSIM). + + .. math:: + \operatorname {SSIM}(x,y) =\frac {(2 \mu_x \mu_y + c_1)(2 \sigma_{xy} + c_2)}{((\mu_x^2 + \ + \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)} + + For more info, visit + https://vicuesoft.com/glossary/term/ssim-ms-ssim/ + + SSIM reference paper: + Wang, Zhou, et al. "Image quality assessment: from error visibility to structural + similarity." IEEE transactions on image processing 13.4 (2004): 600-612. + + Args: + spatial_dims: number of spatial dimensions of the input images. + data_range: value range of input images. (usually 1.0 or 255) + kernel_type: type of kernel, can be "gaussian" or "uniform". + kernel_size: size of kernel + kernel_sigma: standard deviation for Gaussian kernel. + k1: stability constant used in the luminance denominator + k2: stability constant used in the contrast denominator + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans) + """ + + def __init__( + self, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: int | Sequence[int, ...] = 11, + kernel_sigma: float | Sequence[float, ...] = 1.5, + k1: float = 0.01, + k2: float = 0.03, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + ) -> None: + super().__init__(reduction=reduction, get_not_nans=get_not_nans) + + self.spatial_dims = spatial_dims + self.data_range = data_range + self.kernel_type = kernel_type + + if not isinstance(kernel_size, Sequence): + kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) + self.kernel_size = kernel_size + + if not isinstance(kernel_sigma, Sequence): + kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) + self.kernel_sigma = kernel_sigma + + self.k1 = k1 + self.k2 = k2 + + def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y_pred: Predicted image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + y: Reference image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + + Raises: + ValueError: when `y_pred` is not a 2D or 3D image. + """ + dims = y_pred.ndimension() + if self.spatial_dims == 2 and dims != 4: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} " + f"spatial dimensions, got {dims}." + ) + + if self.spatial_dims == 3 and dims != 5: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}" + f" spatial dimensions, got {dims}." + ) + + ssim_value_full_image, _ = compute_ssim_and_cs( + y_pred=y_pred, + y=y, + spatial_dims=self.spatial_dims, + data_range=self.data_range, + kernel_type=self.kernel_type, + kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, + k1=self.k1, + k2=self.k2, + ) + + ssim_per_batch: torch.Tensor = ssim_value_full_image.view(ssim_value_full_image.shape[0], -1).mean( + 1, keepdim=True + ) + + return ssim_per_batch + + +def _gaussian_kernel( + spatial_dims: int, num_channels: int, kernel_size: Sequence[int, ...], kernel_sigma: Sequence[float, ...] +) -> torch.Tensor: + """Computes 2D or 3D gaussian kernel. + + Args: + spatial_dims: number of spatial dimensions of the input images. + num_channels: number of channels in the image + kernel_size: size of kernel + kernel_sigma: standard deviation for Gaussian kernel. + """ + + def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: + """Computes 1D gaussian kernel. + + Args: + kernel_size: size of the gaussian kernel + sigma: Standard deviation of the gaussian kernel + """ + dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1) + gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) + return (gauss / gauss.sum()).unsqueeze(dim=0) + + gaussian_kernel_x = gaussian_1d(kernel_size[0], kernel_sigma[0]) + gaussian_kernel_y = gaussian_1d(kernel_size[1], kernel_sigma[1]) + kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) + + kernel_dimensions = (num_channels, 1, kernel_size[0], kernel_size[1]) + + if spatial_dims == 3: + gaussian_kernel_z = gaussian_1d(kernel_size[2], kernel_sigma[2])[None,] + kernel = torch.mul( + kernel.unsqueeze(-1).repeat(1, 1, kernel_size[2]), + gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]), + ) + kernel_dimensions = (num_channels, 1, kernel_size[0], kernel_size[1], kernel_size[2]) + + return kernel.expand(kernel_dimensions) + + +def compute_ssim_and_cs( + y_pred: torch.Tensor, + y: torch.Tensor, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: Sequence[int, ...] = 11, + kernel_sigma: Sequence[float, ...] = 1.5, + k1: float = 0.01, + k2: float = 0.03, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Function to compute the Structural Similarity Index Measure (SSIM) and Contrast Sensitivity (CS) for a batch + of images. + + Args: + y_pred: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) + y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) + spatial_dims: number of spatial dimensions of the images (2, 3) + data_range: the data range of the images. + kernel_type: the type of kernel to use for the SSIM computation. Can be either "gaussian" or "uniform". + kernel_size: the size of the kernel to use for the SSIM computation. + kernel_sigma: the standard deviation of the kernel to use for the SSIM computation. + k1: the first stability constant. + k2: the second stability constant. + + Returns: + ssim: the Structural Similarity Index Measure score for the batch of images. + cs: the Contrast Sensitivity for the batch of images. + """ + if y.shape != y_pred.shape: + raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") + + y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] + y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] + + num_channels = y_pred.size(1) + + if kernel_type == KernelType.GAUSSIAN: + kernel = _gaussian_kernel(spatial_dims, num_channels, kernel_size, kernel_sigma) + elif kernel_type == KernelType.UNIFORM: + kernel = torch.ones((num_channels, 1, *kernel_size)) / torch.prod(torch.tensor(kernel_size)) + + kernel = convert_to_dst_type(src=kernel, dst=y_pred)[0] + + c1 = (k1 * data_range) ** 2 # stability constant for luminance + c2 = (k2 * data_range) ** 2 # stability constant for contrast + + conv_fn = getattr(F, f"conv{spatial_dims}d") + mu_x = conv_fn(y_pred, kernel, groups=num_channels) + mu_y = conv_fn(y, kernel, groups=num_channels) + mu_xx = conv_fn(y_pred * y_pred, kernel, groups=num_channels) + mu_yy = conv_fn(y * y, kernel, groups=num_channels) + mu_xy = conv_fn(y_pred * y, kernel, groups=num_channels) + + sigma_x = mu_xx - mu_x * mu_x + sigma_y = mu_yy - mu_y * mu_y + sigma_xy = mu_xy - mu_x * mu_y + + contrast_sensitivity = (2 * sigma_xy + c2) / (sigma_x + sigma_y + c2) + ssim_value_full_image = ((2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)) * contrast_sensitivity + + return ssim_value_full_image, contrast_sensitivity diff --git a/tests/test_compute_ms_ssim_metric.py b/tests/test_compute_ms_ssim_metric.py deleted file mode 100644 index f648086c..00000000 --- a/tests/test_compute_ms_ssim_metric.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import numpy as np -import torch -from parameterized import parameterized - -from generative.metrics import MSSSIM - -TEST_CASES = [ - [ - {"data_range": torch.tensor(1.0)}, - {"x": torch.ones([3, 3, 144, 144]) / 2, "y": torch.ones([3, 3, 144, 144]) / 2}, - 1.0, - ], - [ - {"data_range": torch.tensor(1.0), "spatial_dims": 3}, - {"x": torch.ones([3, 3, 144, 144, 144]) / 2, "y": torch.ones([3, 3, 144, 144, 144]) / 2}, - 1.0, - ], -] - - -class TestMSSSIMMetric(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_results(self, input_param, input_data, expected_val): - results = MSSSIM(**input_param)._compute_metric(**input_data) - np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4) - - def test_win_size_not_odd(self): - with self.assertRaises(ValueError): - MSSSIM(data_range=1.0, win_size=8) - - def test_if_inputs_different_shapes(self): - with self.assertRaises(ValueError): - MSSSIM(data_range=1.0)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) - - def test_wrong_shape(self): - with self.assertRaises(ValueError): - MSSSIM(data_range=1.0)(torch.ones([3, 144, 144]), torch.ones([3, 144, 144])) - - def test_input_too_small(self): - with self.assertRaises(ValueError): - MSSSIM(data_range=1.0)(torch.ones([3, 3, 8, 8]), torch.ones([3, 3, 8, 8])) - - def test_input_non_divisible(self): - with self.assertRaises(ValueError): - MSSSIM(data_range=1.0)(torch.ones([3, 3, 149, 149]), torch.ones([3, 3, 149, 149])) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_compute_multiscalessim_metric.py b/tests/test_compute_multiscalessim_metric.py new file mode 100644 index 00000000..1f385fd4 --- /dev/null +++ b/tests/test_compute_multiscalessim_metric.py @@ -0,0 +1,78 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from monai.utils import set_determinism + +from generative.metrics import MultiScaleSSIMMetric + + +class TestMultiScaleSSIMMetric(unittest.TestCase): + def test2d_gaussian(self): + set_determinism(0) + preds = torch.abs(torch.randn(1, 1, 64, 64)) + target = torch.abs(torch.randn(1, 1, 64, 64)) + preds = preds / preds.max() + target = target / target.max() + + metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5]) + metric(preds, target) + result = metric.aggregate() + expected_value = 0.023176 + self.assertTrue(expected_value - result.item() < 0.000001) + + def test2d_uniform(self): + set_determinism(0) + preds = torch.abs(torch.randn(1, 1, 64, 64)) + target = torch.abs(torch.randn(1, 1, 64, 64)) + preds = preds / preds.max() + target = target / target.max() + + metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="uniform", weights=[0.5, 0.5]) + metric(preds, target) + result = metric.aggregate() + expected_value = 0.022655 + self.assertTrue(expected_value - result.item() < 0.000001) + + def test3d_gaussian(self): + set_determinism(0) + preds = torch.abs(torch.randn(1, 1, 64, 64, 64)) + target = torch.abs(torch.randn(1, 1, 64, 64, 64)) + preds = preds / preds.max() + target = target / target.max() + + metric = MultiScaleSSIMMetric(spatial_dims=3, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5]) + metric(preds, target) + result = metric.aggregate() + expected_value = 0.061796 + self.assertTrue(expected_value - result.item() < 0.000001) + + def input_ill_input_shape(self): + with self.assertRaises(ValueError): + metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5]) + metric(torch.randn(1, 1, 64, 64), torch.randn(1, 1, 64, 64)) + + with self.assertRaises(ValueError): + metric = MultiScaleSSIMMetric(spatial_dims=2, weights=[0.5, 0.5]) + metric(torch.randn(1, 1, 64, 64, 64), torch.randn(1, 1, 64, 64, 64)) + + def small_inputs(self): + with self.assertRaises(ValueError): + metric = MultiScaleSSIMMetric(spatial_dims=2) + metric(torch.randn(1, 1, 16, 16, 16), torch.randn(1, 1, 16, 16, 16)) + + +if __name__ == "__main__": + unittest.main()