From f4f7cde0b9cf61ace22748f8c1a0c61224f2001a Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 27 Mar 2023 16:43:40 +0100 Subject: [PATCH 1/6] [WIP] Fix SSIM Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/ssim.py | 177 +++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 generative/metrics/ssim.py diff --git a/generative/metrics/ssim.py b/generative/metrics/ssim.py new file mode 100644 index 00000000..6a3df2d1 --- /dev/null +++ b/generative/metrics/ssim.py @@ -0,0 +1,177 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +from monai.utils import ensure_tuple_rep +from monai.utils import StrEnum +import torch +import torch.nn.functional as F + +from monai.utils.type_conversion import convert_to_dst_type +from monai.metrics.regression import RegressionMetric +from monai.utils import MetricReduction, convert_data_type + +class KernelType(StrEnum): + GAUSSIAN = "gaussian" + UNIFORM = "uniform" + + +class SSIMMetric(RegressionMetric): + r""" + Computes the Structural Similarity Index Measure (SSIM). + + .. math:: + \operatorname {SSIM}(x,y) =\frac {(2 \mu_x \mu_y + c_1)(2 \sigma_{xy} + c_2)}{((\mu_x^2 + \ + \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)} + + For more info, visit + https://vicuesoft.com/glossary/term/ssim-ms-ssim/ + + SSIM reference paper: + Wang, Zhou, et al. "Image quality assessment: from error visibility to structural + similarity." IEEE transactions on image processing 13.4 (2004): 600-612. + + Args: + spatial_dims: number of spatial dimensions of the input images. + data_range: value range of input images. (usually 1.0 or 255) + kernel_type: type of kernel, can be "gaussian" or "uniform". + kernel_size: size of kernel + kernel_sigma: standard deviation for Gaussian kernel. + k1: stability constant used in the luminance denominator + k2: stability constant used in the contrast denominator + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans) + """ + + def __init__( + self, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: int | Sequence[int, ...] 
= 11, + kernel_sigma: int = 1.5, + k1: float = 0.01, + k2: float = 0.03, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + ): + super().__init__(reduction=reduction, get_not_nans=get_not_nans) + + self.data_range = data_range + self.kernel_type = kernel_type + + if not isinstance(kernel_size, Sequence): + kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) + self.kernel_size = kernel_size + + if not isinstance(kernel_sigma, Sequence): + kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) + self.kernel_sigma = kernel_sigma + + self.k1 = k1 + self.k2 = k2 + self.spatial_dims = spatial_dims + + def _gaussian_kernel(self, channel: int) -> torch.Tensor: + """Computes 2D or 3D gaussian kernel. + + Args: + channel: number of channels in the image + """ + + def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: + """ Computes 1D gaussian kernel. + + Args: + kernel_size: size of the gaussian kernel + sigma: Standard deviation of the gaussian kernel + """ + dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1) + gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) + return (gauss / gauss.sum()).unsqueeze(dim=0) + + gaussian_kernel_x = gaussian_1d(self.kernel_size[0], self.kernel_sigma[0]) + gaussian_kernel_y = gaussian_1d(self.kernel_size[1], self.kernel_sigma[1]) + kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) + + kernel_dimensions = (channel, 1, self.kernel_size[0], self.kernel_size[1]) + + if self.spatial_dims == 3: + gaussian_kernel_z = gaussian_1d(self.kernel_size[2], self.kernel_sigma[2])[None,] + kernel = torch.mul( + kernel.unsqueeze(-1).repeat(1, 1, self.kernel_size[2]), + gaussian_kernel_z.expand(self.kernel_size[0], self.kernel_size[1], self.kernel_size[2]), + ) + kernel_dimensions = (channel, 1, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2]) + + return kernel.expand(kernel_dimensions) + + def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y_pred: Predicted image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + y: Reference image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + + Raises: + ValueError: when `y_pred` is not a 2D or 3D image. + """ + dims = y_pred.ndimension() + if self.spatial_dims == 2 and dims != 4: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} " + f"spatial dimensions, got {dims}." + ) + + if self.spatial_dims == 3 and dims != 5: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}" + f" spatial dimensions, got {dims}." 
+ ) + + y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] + y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] + + num_channels = y_pred.size(1) + + if self.kernel_type == KernelType.GAUSSIAN: + kernel = self._gaussian_kernel(num_channels) + elif self.kernel_type == KernelType.UNIFORM: + kernel = torch.ones((num_channels, 1, *self.kernel_size)) / torch.prod(torch.tensor(self.kernel_size)) + + kernel = convert_to_dst_type(src=kernel, dst=y_pred)[0] + + c1 = (self.k1 * self.data_range) ** 2 # stability constant for luminance + c2 = (self.k2 * self.data_range) ** 2 # stability constant for contrast + + conv_fn = getattr(F, f"conv{self.spatial_dims}d") + mu_x = conv_fn(y_pred, kernel, groups=num_channels) + mu_y = conv_fn(y, kernel, groups=num_channels) + mu_xx = conv_fn(y_pred * y_pred, kernel, groups=num_channels) + mu_yy = conv_fn(y * y, kernel, groups=num_channels) + mu_xy = conv_fn(y_pred * y, kernel, groups=num_channels) + + sigma_x = mu_xx - mu_x * mu_x + sigma_y = mu_yy - mu_y * mu_y + sigma_xy = mu_xy - mu_x * mu_y + + contrast_sensitivity = (2 * sigma_xy + c2) / (sigma_x + sigma_y + c2) + ssim_value_full_image = ((2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)) * contrast_sensitivity + + ssim_per_batch: torch.Tensor = ssim_value_full_image.view(ssim_value_full_image.shape[0], -1).mean(1, keepdim=True) + + return ssim_per_batch From 4486166a39a9ea20d28513d0475050405e78c7c3 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 27 Mar 2023 22:42:40 +0100 Subject: [PATCH 2/6] Fix SSIM and MS-SIM Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/__init__.py | 3 +- generative/metrics/ms_ssim.py | 197 ++++++++++++++++++--------------- generative/metrics/ssim.py | 177 ++++++++++++++++++----------- 3 files changed, 220 insertions(+), 157 deletions(-) diff --git a/generative/metrics/__init__.py b/generative/metrics/__init__.py index e05137d8..f846ba1b 100644 --- a/generative/metrics/__init__.py +++ b/generative/metrics/__init__.py @@ -13,4 +13,5 @@ from .fid import FID from .mmd import MMD -from .ms_ssim import MSSSIM +from .ms_ssim import MultiScaleSSIMMetric +from .ssim import SSIMMetric diff --git a/generative/metrics/ms_ssim.py b/generative/metrics/ms_ssim.py index 9cd6cabb..07c69040 100644 --- a/generative/metrics/ms_ssim.py +++ b/generative/metrics/ms_ssim.py @@ -11,16 +11,24 @@ from __future__ import annotations +from collections.abc import Sequence + import torch import torch.nn.functional as F -from monai.metrics import SSIMMetric from monai.metrics.regression import RegressionMetric -from monai.utils import MetricReduction +from monai.utils import MetricReduction, StrEnum, ensure_tuple_rep + +from generative.metrics.ssim import compute_ssim_and_cs + +class KernelType(StrEnum): + GAUSSIAN = "gaussian" + UNIFORM = "uniform" -class MSSSIM(RegressionMetric): + +class MultiScaleSSIMMetric(RegressionMetric): """ - Computes Multi-Scale Structural Similarity Index Measure. + Computes the Multi-Scale Structural Similarity Index Measure (MS-SSIM). [1] Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November. Multiscale structural similarity for image quality assessment. @@ -28,113 +36,118 @@ class MSSSIM(RegressionMetric): & Computers, 2003 (Vol. 2, pp. 1398-1402). Ieee. Args: - data_range: dynamic range of the data - win_size: gaussian weighting window size + spatial_dims: number of spatial dimensions of the input images. + data_range: value range of input images. 
(usually 1.0 or 255) + kernel_type: type of kernel, can be "gaussian" or "uniform". + kernel_size: size of kernel + kernel_sigma: standard deviation for Gaussian kernel. k1: stability constant used in the luminance denominator k2: stability constant used in the contrast denominator - spatial_dims: if 2, input shape is expected to be (B,C,W,H); - if 3, it is expected to be (B,C,W,H,D) - weights: parameters for image similarity and contrast sensitivity - at different resolution scores. - reduction: {``"none"``, ``"mean"``, ``"sum"``} - Specifies the reduction to apply to the output. - Defaults to ``"mean"``. - - ``"none"``: no reduction will be applied. - - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - - ``"sum"``: the output will be summed. + weights: parameters for image similarity and contrast sensitivity at different resolution scores. + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans) """ def __init__( self, - data_range: torch.Tensor | float, - win_size: int = 7, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: int | Sequence[int, ...] = 11, + kernel_sigma: int | Sequence[int, ...] = 1.5, k1: float = 0.01, k2: float = 0.03, - spatial_dims: int = 2, - weights: list | None = None, + weights: Sequence[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, ) -> None: - super().__init__() - - if not (win_size % 2 == 1): - raise ValueError("Window size should be odd.") + super().__init__(reduction=reduction, get_not_nans=get_not_nans) - self.data_range = data_range - self.win_size = win_size - self.k1, self.k2 = k1, k2 self.spatial_dims = spatial_dims - self.weights = weights - self.reduction = reduction + self.data_range = data_range + self.kernel_type = kernel_type + + if not isinstance(kernel_size, Sequence): + kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) + self.kernel_size = kernel_size - self.SSIM = SSIMMetric(self.data_range, self.win_size, self.k1, self.k2, self.spatial_dims) + if not isinstance(kernel_sigma, Sequence): + kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) + self.kernel_sigma = kernel_sigma + + self.k1 = k1 + self.k2 = k2 + self.weights = weights - def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Args: - x: first sample (e.g., the reference image). Its shape is - (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. - A fastMRI sample should use the 2D format with C being - the number of slices. - y: second sample (e.g., the reconstructed image). It has similar - shape as x + y_pred: Predicted image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + y: Reference image. + It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. + Raises: + ValueError: when `y_pred` is not a 2D or 3D image. 
""" - - if not x.shape == y.shape: - raise ValueError(f"Input images should have the same dimensions, but got {x.shape} and {y.shape}.") - - for d in range(len(x.shape) - 1, 1, -1): - x = x.squeeze(dim=d) - y = y.squeeze(dim=d) - - if len(x.shape) == 4: - avg_pool = F.avg_pool2d - elif len(x.shape) == 5: - avg_pool = F.avg_pool3d - else: - raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {x.shape}") - - if self.weights is None: - # as per Ref 1 - Sec 3.2. - self.weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] - self.weights = torch.tensor(self.weights) - - divisible_by = 2 ** (len(self.weights) - 1) - bigger_than = (self.win_size + 2) * 2 ** (len(self.weights) - 1) - for idx, shape_size in enumerate(x.shape[2:]): - if shape_size % divisible_by != 0: + dims = y_pred.ndimension() + if self.spatial_dims == 2 and dims != 4: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} " + f"spatial dimensions, got {dims}." + ) + + if self.spatial_dims == 3 and dims != 5: + raise ValueError( + f"y_pred should have 4 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}" + f" spatial dimensions, got {dims}." + ) + + # check if image have enough size for the number of downsamplings and the size of the kernel + weights_div = max(1, (len(self.weights) - 1)) ** 2 + y_pred_spatial_dims = y_pred.shape[2:] + for i in range(len(y_pred_spatial_dims)): + if y_pred_spatial_dims[i] // weights_div <= self.kernel_size[i] - 1: raise ValueError( - f"Image size needs to be divisible by {divisible_by} but " - f"dimension {idx + 2} has size {shape_size}" + f"For a given number of `weights` parameters {len(self.weights)} and kernel size " + f"{self.kernel_size[i]}, the image height must be larger than " + f"{(self.kernel_size[i] - 1) * weights_div}." ) - if shape_size < bigger_than: - raise ValueError( - f"Image size should be larger than {bigger_than} due to " - f"the {len(self.weights) - 1} downsamplings in MS-SSIM." 
- ) + weights = torch.tensor(self.weights, device=y_pred.device, dtype=torch.float) + + avg_pool = getattr(F, f"avg_pool{self.spatial_dims}d") + + multiscale_list: list[torch.Tensor] = [] + for i in range(len(weights)): + ssim, cs = compute_ssim_and_cs( + y_pred=y_pred, + y=y, + spatial_dims=self.spatial_dims, + data_range=self.data_range, + kernel_type=self.kernel_type, + kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, + k1=self.k1, + k2=self.k2, + ) + + cs_per_batch = cs.view(cs.shape[0], -1).mean(1) + + multiscale_list.append(torch.relu(cs_per_batch)) + y_pred = avg_pool(y_pred, kernel_size=2) + y = avg_pool(y, kernel_size=2) + + ssim = ssim.view(ssim.shape[0], -1).mean(1) + multiscale_list[-1] = torch.relu(ssim) + multiscale_list = torch.stack(multiscale_list) + + ms_ssim_value_full_image = torch.prod(multiscale_list ** weights.view(-1, 1), dim=0) + + ms_ssim_per_batch: torch.Tensor = ms_ssim_value_full_image.view(ms_ssim_value_full_image.shape[0], -1).mean( + 1, keepdim=True + ) - levels = self.weights.shape[0] - mcs_list: list[torch.Tensor] = [] - for i in range(levels): - ssim, cs = self.SSIM._compute_metric_and_contrast(x, y) - - if i < levels - 1: - mcs_list.append(torch.relu(cs)) - padding = [s % 2 for s in x.shape[2:]] - x = avg_pool(x, kernel_size=2, padding=padding) - y = avg_pool(y, kernel_size=2, padding=padding) - - ssim = torch.relu(ssim) # (batch, 1) - # (level, batch, 1) - mcs_and_ssim = torch.stack(mcs_list + [ssim], dim=0) - ms_ssim = torch.prod(mcs_and_ssim ** self.weights.view(-1, 1, 1), dim=0) - - if self.reduction == MetricReduction.MEAN.value: - ms_ssim = ms_ssim.mean() - elif self.reduction == MetricReduction.SUM.value: - ms_ssim = ms_ssim.sum() - elif self.reduction == MetricReduction.NONE.value: - pass - - return ms_ssim + return ms_ssim_per_batch diff --git a/generative/metrics/ssim.py b/generative/metrics/ssim.py index 6a3df2d1..ac148004 100644 --- a/generative/metrics/ssim.py +++ b/generative/metrics/ssim.py @@ -13,14 +13,12 @@ from collections.abc import Sequence -from monai.utils import ensure_tuple_rep -from monai.utils import StrEnum import torch import torch.nn.functional as F - -from monai.utils.type_conversion import convert_to_dst_type from monai.metrics.regression import RegressionMetric -from monai.utils import MetricReduction, convert_data_type +from monai.utils import MetricReduction, StrEnum, convert_data_type, ensure_tuple_rep +from monai.utils.type_conversion import convert_to_dst_type + class KernelType(StrEnum): GAUSSIAN = "gaussian" @@ -62,7 +60,7 @@ def __init__( data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, kernel_size: int | Sequence[int, ...] = 11, - kernel_sigma: int = 1.5, + kernel_sigma: int | Sequence[int, ...] = 1.5, k1: float = 0.01, k2: float = 0.03, reduction: MetricReduction | str = MetricReduction.MEAN, @@ -70,6 +68,7 @@ def __init__( ): super().__init__(reduction=reduction, get_not_nans=get_not_nans) + self.spatial_dims = spatial_dims self.data_range = data_range self.kernel_type = kernel_type @@ -83,41 +82,6 @@ def __init__( self.k1 = k1 self.k2 = k2 - self.spatial_dims = spatial_dims - - def _gaussian_kernel(self, channel: int) -> torch.Tensor: - """Computes 2D or 3D gaussian kernel. - - Args: - channel: number of channels in the image - """ - - def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: - """ Computes 1D gaussian kernel. 
- - Args: - kernel_size: size of the gaussian kernel - sigma: Standard deviation of the gaussian kernel - """ - dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1) - gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) - return (gauss / gauss.sum()).unsqueeze(dim=0) - - gaussian_kernel_x = gaussian_1d(self.kernel_size[0], self.kernel_sigma[0]) - gaussian_kernel_y = gaussian_1d(self.kernel_size[1], self.kernel_sigma[1]) - kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) - - kernel_dimensions = (channel, 1, self.kernel_size[0], self.kernel_size[1]) - - if self.spatial_dims == 3: - gaussian_kernel_z = gaussian_1d(self.kernel_size[2], self.kernel_sigma[2])[None,] - kernel = torch.mul( - kernel.unsqueeze(-1).repeat(1, 1, self.kernel_size[2]), - gaussian_kernel_z.expand(self.kernel_size[0], self.kernel_size[1], self.kernel_size[2]), - ) - kernel_dimensions = (channel, 1, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2]) - - return kernel.expand(kernel_dimensions) def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ @@ -143,35 +107,120 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor f" spatial dimensions, got {dims}." ) - y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] - y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] + ssim_value_full_image, _ = compute_ssim_and_cs( + y_pred=y_pred, + y=y, + spatial_dims=self.spatial_dims, + data_range=self.data_range, + kernel_type=self.kernel_type, + kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, + k1=self.k1, + k2=self.k2, + ) + + ssim_per_batch: torch.Tensor = ssim_value_full_image.view(ssim_value_full_image.shape[0], -1).mean( + 1, keepdim=True + ) - num_channels = y_pred.size(1) + return ssim_per_batch + + +def _gaussian_kernel(spatial_dims, channel: int, kernel_size, kernel_sigma) -> torch.Tensor: + """Computes 2D or 3D gaussian kernel. + + Args: + channel: number of channels in the image + """ + + def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: + """Computes 1D gaussian kernel. + + Args: + kernel_size: size of the gaussian kernel + sigma: Standard deviation of the gaussian kernel + """ + dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1) + gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) + return (gauss / gauss.sum()).unsqueeze(dim=0) + + gaussian_kernel_x = gaussian_1d(kernel_size[0], kernel_sigma[0]) + gaussian_kernel_y = gaussian_1d(kernel_size[1], kernel_sigma[1]) + kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) + + kernel_dimensions = (channel, 1, kernel_size[0], kernel_size[1]) + + if spatial_dims == 3: + gaussian_kernel_z = gaussian_1d(kernel_size[2], kernel_sigma[2])[None,] + kernel = torch.mul( + kernel.unsqueeze(-1).repeat(1, 1, kernel_size[2]), + gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]), + ) + kernel_dimensions = (channel, 1, kernel_size[0], kernel_size[1], kernel_size[2]) + + return kernel.expand(kernel_dimensions) + + +def compute_ssim_and_cs( + y_pred: torch.Tensor, + y: torch.Tensor, + spatial_dims: int, + data_range: float = 1.0, + kernel_type: KernelType | str = KernelType.GAUSSIAN, + kernel_size: Sequence[int, ...] = 11, + kernel_sigma: Sequence[int, ...] 
= 1.5, + k1: float = 0.01, + k2: float = 0.03, +): + """ + Function to compute the Structural Similarity Index Measure (SSIM) and Contrast Sensitivity (CS) for a batch + of images. + + Args: + y_pred: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) + y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) + spatial_dims: number of spatial dimensions of the images (2, 3) + data_range: the data range of the images. + kernel_type: the type of kernel to use for the SSIM computation. Can be either "gaussian" or "uniform". + kernel_size: the size of the kernel to use for the SSIM computation. + kernel_sigma: the standard deviation of the kernel to use for the SSIM computation. + k1: the first stability constant. + k2: the second stability constant. + + Returns: + ssim: the Structural Similarity Index Measure score for the batch of images. + cs: the Contrast Sensitivity for the batch of images. + """ + if y.shape != y_pred.shape: + raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") - if self.kernel_type == KernelType.GAUSSIAN: - kernel = self._gaussian_kernel(num_channels) - elif self.kernel_type == KernelType.UNIFORM: - kernel = torch.ones((num_channels, 1, *self.kernel_size)) / torch.prod(torch.tensor(self.kernel_size)) + y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] + y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] - kernel = convert_to_dst_type(src=kernel, dst=y_pred)[0] + num_channels = y_pred.size(1) - c1 = (self.k1 * self.data_range) ** 2 # stability constant for luminance - c2 = (self.k2 * self.data_range) ** 2 # stability constant for contrast + if kernel_type == KernelType.GAUSSIAN: + kernel = _gaussian_kernel(spatial_dims, num_channels, kernel_size, kernel_sigma) + elif kernel_type == KernelType.UNIFORM: + kernel = torch.ones((num_channels, 1, *kernel_size)) / torch.prod(torch.tensor(kernel_size)) - conv_fn = getattr(F, f"conv{self.spatial_dims}d") - mu_x = conv_fn(y_pred, kernel, groups=num_channels) - mu_y = conv_fn(y, kernel, groups=num_channels) - mu_xx = conv_fn(y_pred * y_pred, kernel, groups=num_channels) - mu_yy = conv_fn(y * y, kernel, groups=num_channels) - mu_xy = conv_fn(y_pred * y, kernel, groups=num_channels) + kernel = convert_to_dst_type(src=kernel, dst=y_pred)[0] - sigma_x = mu_xx - mu_x * mu_x - sigma_y = mu_yy - mu_y * mu_y - sigma_xy = mu_xy - mu_x * mu_y + c1 = (k1 * data_range) ** 2 # stability constant for luminance + c2 = (k2 * data_range) ** 2 # stability constant for contrast - contrast_sensitivity = (2 * sigma_xy + c2) / (sigma_x + sigma_y + c2) - ssim_value_full_image = ((2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)) * contrast_sensitivity + conv_fn = getattr(F, f"conv{spatial_dims}d") + mu_x = conv_fn(y_pred, kernel, groups=num_channels) + mu_y = conv_fn(y, kernel, groups=num_channels) + mu_xx = conv_fn(y_pred * y_pred, kernel, groups=num_channels) + mu_yy = conv_fn(y * y, kernel, groups=num_channels) + mu_xy = conv_fn(y_pred * y, kernel, groups=num_channels) - ssim_per_batch: torch.Tensor = ssim_value_full_image.view(ssim_value_full_image.shape[0], -1).mean(1, keepdim=True) + sigma_x = mu_xx - mu_x * mu_x + sigma_y = mu_yy - mu_y * mu_y + sigma_xy = mu_xy - mu_x * mu_y - return ssim_per_batch + contrast_sensitivity = (2 * sigma_xy + c2) / (sigma_x + sigma_y + c2) + ssim_value_full_image = ((2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)) * 
contrast_sensitivity + + return ssim_value_full_image, contrast_sensitivity From aeaaae967c61e283d40d08e667bdea8a61eae48d Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 27 Mar 2023 23:52:25 +0100 Subject: [PATCH 3/6] Fix SSIM data typing Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/ssim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generative/metrics/ssim.py b/generative/metrics/ssim.py index ac148004..9b2afdac 100644 --- a/generative/metrics/ssim.py +++ b/generative/metrics/ssim.py @@ -65,7 +65,7 @@ def __init__( k2: float = 0.03, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, - ): + ) -> None: super().__init__(reduction=reduction, get_not_nans=get_not_nans) self.spatial_dims = spatial_dims @@ -171,7 +171,7 @@ def compute_ssim_and_cs( kernel_sigma: Sequence[int, ...] = 1.5, k1: float = 0.01, k2: float = 0.03, -): +) -> tuple[torch.Tensor, torch.Tensor]: """ Function to compute the Structural Similarity Index Measure (SSIM) and Contrast Sensitivity (CS) for a batch of images. From 46d9cbb1ea4837e66008ec5bdaf577e2f6b6012b Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Tue, 28 Mar 2023 00:00:26 +0100 Subject: [PATCH 4/6] Fix SSIM data typing Signed-off-by: Walter Hugo Lopez Pinaya --- tests/test_compute_ms_ssim_metric.py | 64 ----------------- tests/test_compute_multiscalessim_metric.py | 78 +++++++++++++++++++++ 2 files changed, 78 insertions(+), 64 deletions(-) delete mode 100644 tests/test_compute_ms_ssim_metric.py create mode 100644 tests/test_compute_multiscalessim_metric.py diff --git a/tests/test_compute_ms_ssim_metric.py b/tests/test_compute_ms_ssim_metric.py deleted file mode 100644 index f648086c..00000000 --- a/tests/test_compute_ms_ssim_metric.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import unittest - -import numpy as np -import torch -from parameterized import parameterized - -from generative.metrics import MSSSIM - -TEST_CASES = [ - [ - {"data_range": torch.tensor(1.0)}, - {"x": torch.ones([3, 3, 144, 144]) / 2, "y": torch.ones([3, 3, 144, 144]) / 2}, - 1.0, - ], - [ - {"data_range": torch.tensor(1.0), "spatial_dims": 3}, - {"x": torch.ones([3, 3, 144, 144, 144]) / 2, "y": torch.ones([3, 3, 144, 144, 144]) / 2}, - 1.0, - ], -] - - -class TestMSSSIMMetric(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_results(self, input_param, input_data, expected_val): - results = MSSSIM(**input_param)._compute_metric(**input_data) - np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4) - - def test_win_size_not_odd(self): - with self.assertRaises(ValueError): - MSSSIM(data_range=1.0, win_size=8) - - def test_if_inputs_different_shapes(self): - with self.assertRaises(ValueError): - MSSSIM(data_range=1.0)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) - - def test_wrong_shape(self): - with self.assertRaises(ValueError): - MSSSIM(data_range=1.0)(torch.ones([3, 144, 144]), torch.ones([3, 144, 144])) - - def test_input_too_small(self): - with self.assertRaises(ValueError): - MSSSIM(data_range=1.0)(torch.ones([3, 3, 8, 8]), torch.ones([3, 3, 8, 8])) - - def test_input_non_divisible(self): - with self.assertRaises(ValueError): - MSSSIM(data_range=1.0)(torch.ones([3, 3, 149, 149]), torch.ones([3, 3, 149, 149])) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_compute_multiscalessim_metric.py b/tests/test_compute_multiscalessim_metric.py new file mode 100644 index 00000000..952bd920 --- /dev/null +++ b/tests/test_compute_multiscalessim_metric.py @@ -0,0 +1,78 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch + +from generative.metrics import MultiScaleSSIMMetric +from monai.utils import set_determinism + + +class TestMultiScaleSSIMMetric(unittest.TestCase): + def test2d_gaussian(self): + set_determinism(0) + preds = torch.abs(torch.randn(1, 1, 64, 64)) + target = torch.abs(torch.randn(1, 1, 64, 64)) + preds = preds / preds.max() + target = target / target.max() + + metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5]) + metric(preds, target) + result = metric.aggregate() + expected_value = 0.023176 + self.assertTrue(expected_value - result.item() < 0.000001) + + def test2d_uniform(self): + set_determinism(0) + preds = torch.abs(torch.randn(1, 1, 64, 64)) + target = torch.abs(torch.randn(1, 1, 64, 64)) + preds = preds / preds.max() + target = target / target.max() + + metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="uniform", weights=[0.5, 0.5]) + metric(preds, target) + result = metric.aggregate() + expected_value = 0.022655 + self.assertTrue(expected_value - result.item() < 0.000001) + + def test3d_gaussian(self): + set_determinism(0) + preds = torch.abs(torch.randn(1, 1, 64, 64, 64)) + target = torch.abs(torch.randn(1, 1, 64, 64, 64)) + preds = preds / preds.max() + target = target / target.max() + + metric = MultiScaleSSIMMetric(spatial_dims=3, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5]) + metric(preds, target) + result = metric.aggregate() + expected_value = 0.061796 + self.assertTrue(expected_value - result.item() < 0.000001) + + def input_ill_input_shape(self): + with self.assertRaises(ValueError): + metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5]) + metric(torch.randn(1, 1, 64, 64), torch.randn(1, 1, 64, 64)) + + with self.assertRaises(ValueError): + metric = MultiScaleSSIMMetric(spatial_dims=2, weights=[0.5, 0.5]) + metric(torch.randn(1, 1, 64, 64, 64), torch.randn(1, 1, 64, 64, 64)) + + def small_inputs(self): + with self.assertRaises(ValueError): + metric = MultiScaleSSIMMetric(spatial_dims=2) + metric(torch.randn(1, 1, 16, 16, 16), torch.randn(1, 1, 16, 16, 16)) + + +if __name__ == "__main__": + unittest.main() From ff66b0cacea9b398c10e034c73b4b5ac5670cfa5 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 29 Mar 2023 09:34:40 +0100 Subject: [PATCH 5/6] Update generative/metrics/ms_ssim.py Co-authored-by: Mark Graham --- generative/metrics/ms_ssim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generative/metrics/ms_ssim.py b/generative/metrics/ms_ssim.py index 07c69040..ac1b96fd 100644 --- a/generative/metrics/ms_ssim.py +++ b/generative/metrics/ms_ssim.py @@ -32,7 +32,7 @@ class MultiScaleSSIMMetric(RegressionMetric): [1] Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November. Multiscale structural similarity for image quality assessment. - In The Thrity-Seventh Asilomar Conference on Signals, Systems + In The Thirty-Seventh Asilomar Conference on Signals, Systems & Computers, 2003 (Vol. 2, pp. 1398-1402). Ieee. 
Args: From af60f830b2819e25d72716a6e13376cf93ff1f39 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 29 Mar 2023 11:06:05 +0100 Subject: [PATCH 6/6] Fix SSIM data typing Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/ms_ssim.py | 2 +- generative/metrics/ssim.py | 17 +++++++++++------ tests/test_compute_multiscalessim_metric.py | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/generative/metrics/ms_ssim.py b/generative/metrics/ms_ssim.py index 07c69040..3b404452 100644 --- a/generative/metrics/ms_ssim.py +++ b/generative/metrics/ms_ssim.py @@ -56,7 +56,7 @@ def __init__( data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, kernel_size: int | Sequence[int, ...] = 11, - kernel_sigma: int | Sequence[int, ...] = 1.5, + kernel_sigma: float | Sequence[float, ...] = 1.5, k1: float = 0.01, k2: float = 0.03, weights: Sequence[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), diff --git a/generative/metrics/ssim.py b/generative/metrics/ssim.py index 9b2afdac..07039309 100644 --- a/generative/metrics/ssim.py +++ b/generative/metrics/ssim.py @@ -60,7 +60,7 @@ def __init__( data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, kernel_size: int | Sequence[int, ...] = 11, - kernel_sigma: int | Sequence[int, ...] = 1.5, + kernel_sigma: float | Sequence[float, ...] = 1.5, k1: float = 0.01, k2: float = 0.03, reduction: MetricReduction | str = MetricReduction.MEAN, @@ -126,11 +126,16 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor return ssim_per_batch -def _gaussian_kernel(spatial_dims, channel: int, kernel_size, kernel_sigma) -> torch.Tensor: +def _gaussian_kernel( + spatial_dims: int, num_channels: int, kernel_size: Sequence[int, ...], kernel_sigma: Sequence[float, ...] +) -> torch.Tensor: """Computes 2D or 3D gaussian kernel. Args: - channel: number of channels in the image + spatial_dims: number of spatial dimensions of the input images. + num_channels: number of channels in the image + kernel_size: size of kernel + kernel_sigma: standard deviation for Gaussian kernel. """ def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: @@ -148,7 +153,7 @@ def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: gaussian_kernel_y = gaussian_1d(kernel_size[1], kernel_sigma[1]) kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) - kernel_dimensions = (channel, 1, kernel_size[0], kernel_size[1]) + kernel_dimensions = (num_channels, 1, kernel_size[0], kernel_size[1]) if spatial_dims == 3: gaussian_kernel_z = gaussian_1d(kernel_size[2], kernel_sigma[2])[None,] @@ -156,7 +161,7 @@ def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: kernel.unsqueeze(-1).repeat(1, 1, kernel_size[2]), gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]), ) - kernel_dimensions = (channel, 1, kernel_size[0], kernel_size[1], kernel_size[2]) + kernel_dimensions = (num_channels, 1, kernel_size[0], kernel_size[1], kernel_size[2]) return kernel.expand(kernel_dimensions) @@ -168,7 +173,7 @@ def compute_ssim_and_cs( data_range: float = 1.0, kernel_type: KernelType | str = KernelType.GAUSSIAN, kernel_size: Sequence[int, ...] = 11, - kernel_sigma: Sequence[int, ...] = 1.5, + kernel_sigma: Sequence[float, ...] 
= 1.5, k1: float = 0.01, k2: float = 0.03, ) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/tests/test_compute_multiscalessim_metric.py b/tests/test_compute_multiscalessim_metric.py index 952bd920..1f385fd4 100644 --- a/tests/test_compute_multiscalessim_metric.py +++ b/tests/test_compute_multiscalessim_metric.py @@ -14,9 +14,9 @@ import unittest import torch +from monai.utils import set_determinism from generative.metrics import MultiScaleSSIMMetric -from monai.utils import set_determinism class TestMultiScaleSSIMMetric(unittest.TestCase):
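
Usage note (not part of the patches above): a minimal sketch of how the `SSIMMetric`, `MultiScaleSSIMMetric` and `compute_ssim_and_cs` APIs introduced by this series are exercised, modelled on the calling pattern in tests/test_compute_multiscalessim_metric.py. The tensor shapes, values and printed labels are illustrative assumptions, not part of the committed code.

import torch

from generative.metrics import MultiScaleSSIMMetric, SSIMMetric
from generative.metrics.ssim import compute_ssim_and_cs

# Two batches of 2D images in [0, 1]: shape (batch, channel, height, width).
y_pred = torch.rand(2, 1, 64, 64)
y = torch.rand(2, 1, 64, 64)

# Per-batch SSIM: the metric object is callable (it is a RegressionMetric) and
# aggregate() applies the configured reduction (default: mean) over the
# accumulated per-sample scores.
ssim_metric = SSIMMetric(
    spatial_dims=2, data_range=1.0, kernel_type="gaussian", kernel_size=11, kernel_sigma=1.5
)
ssim_metric(y_pred, y)
print("SSIM:", ssim_metric.aggregate().item())

# MS-SSIM with two scales; 64x64 inputs are large enough for the single 2x
# average-pool downsampling given the default 11x11 kernel.
ms_ssim_metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, weights=[0.5, 0.5])
ms_ssim_metric(y_pred, y)
print("MS-SSIM:", ms_ssim_metric.aggregate().item())

# The functional helper returns the full SSIM and contrast-sensitivity maps,
# which MultiScaleSSIMMetric reuses at each scale; kernel_size and kernel_sigma
# are passed as per-dimension sequences here.
ssim_map, cs_map = compute_ssim_and_cs(
    y_pred=y_pred,
    y=y,
    spatial_dims=2,
    data_range=1.0,
    kernel_type="gaussian",
    kernel_size=(11, 11),
    kernel_sigma=(1.5, 1.5),
)
print(ssim_map.shape, cs_map.shape)  # valid-convolution maps, e.g. (2, 1, 54, 54)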