This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
3 changes: 2 additions & 1 deletion generative/metrics/__init__.py
```diff
@@ -13,4 +13,5 @@
 
 from .fid import FID
 from .mmd import MMD
-from .ms_ssim import MSSSIM
+from .ms_ssim import MultiScaleSSIMMetric
+from .ssim import SSIMMetric
```
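For downstream users, the visible change is the rename: code that imported `MSSSIM` now imports `MultiScaleSSIMMetric`, and `SSIMMetric` is now exported from this package. A minimal sketch of the updated import, assuming the `generative` package from this repository is installed:

```python
# Before this PR (hypothetical downstream code):
# from generative.metrics import MSSSIM

# After this PR:
from generative.metrics import MultiScaleSSIMMetric, SSIMMetric
```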
199 changes: 106 additions & 93 deletions generative/metrics/ms_ssim.py
```diff
@@ -11,130 +11,143 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
+
 import torch
 import torch.nn.functional as F
-from monai.metrics import SSIMMetric
 from monai.metrics.regression import RegressionMetric
-from monai.utils import MetricReduction
+from monai.utils import MetricReduction, StrEnum, ensure_tuple_rep
+
+from generative.metrics.ssim import compute_ssim_and_cs
 
 
-class MSSSIM(RegressionMetric):
+class KernelType(StrEnum):
+    GAUSSIAN = "gaussian"
+    UNIFORM = "uniform"
+
+
+class MultiScaleSSIMMetric(RegressionMetric):
     """
-    Computes Multi-Scale Structural Similarity Index Measure.
+    Computes the Multi-Scale Structural Similarity Index Measure (MS-SSIM).
 
     [1] Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November.
     Multiscale structural similarity for image quality assessment.
-    In The Thrity-Seventh Asilomar Conference on Signals, Systems
+    In The Thirty-Seventh Asilomar Conference on Signals, Systems
     & Computers, 2003 (Vol. 2, pp. 1398-1402). Ieee.
 
     Args:
-        data_range: dynamic range of the data
-        win_size: gaussian weighting window size
+        spatial_dims: number of spatial dimensions of the input images.
+        data_range: value range of input images. (usually 1.0 or 255)
+        kernel_type: type of kernel, can be "gaussian" or "uniform".
+        kernel_size: size of kernel
+        kernel_sigma: standard deviation for Gaussian kernel.
         k1: stability constant used in the luminance denominator
         k2: stability constant used in the contrast denominator
-        spatial_dims: if 2, input shape is expected to be (B,C,W,H);
-            if 3, it is expected to be (B,C,W,H,D)
-        weights: parameters for image similarity and contrast sensitivity
-            at different resolution scores.
-        reduction: {``"none"``, ``"mean"``, ``"sum"``}
-            Specifies the reduction to apply to the output.
-            Defaults to ``"mean"``.
-            - ``"none"``: no reduction will be applied.
-            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
-            - ``"sum"``: the output will be summed.
+        weights: parameters for image similarity and contrast sensitivity at different resolution scores.
+        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
+            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
+            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction
+        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans)
     """
 
     def __init__(
         self,
-        data_range: torch.Tensor | float,
-        win_size: int = 7,
+        spatial_dims: int,
+        data_range: float = 1.0,
+        kernel_type: KernelType | str = KernelType.GAUSSIAN,
+        kernel_size: int | Sequence[int, ...] = 11,
+        kernel_sigma: float | Sequence[float, ...] = 1.5,
         k1: float = 0.01,
         k2: float = 0.03,
-        spatial_dims: int = 2,
-        weights: list | None = None,
+        weights: Sequence[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
         reduction: MetricReduction | str = MetricReduction.MEAN,
+        get_not_nans: bool = False,
     ) -> None:
-        super().__init__()
-
-        if not (win_size % 2 == 1):
-            raise ValueError("Window size should be odd.")
+        super().__init__(reduction=reduction, get_not_nans=get_not_nans)
 
-        self.data_range = data_range
-        self.win_size = win_size
-        self.k1, self.k2 = k1, k2
         self.spatial_dims = spatial_dims
-        self.weights = weights
-        self.reduction = reduction
+        self.data_range = data_range
+        self.kernel_type = kernel_type
 
-        self.SSIM = SSIMMetric(self.data_range, self.win_size, self.k1, self.k2, self.spatial_dims)
+        if not isinstance(kernel_size, Sequence):
+            kernel_size = ensure_tuple_rep(kernel_size, spatial_dims)
+        self.kernel_size = kernel_size
+
+        if not isinstance(kernel_sigma, Sequence):
+            kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)
+        self.kernel_sigma = kernel_sigma
+
+        self.k1 = k1
+        self.k2 = k2
+        self.weights = weights
 
-    def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            x: first sample (e.g., the reference image). Its shape is
-                (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.
-                A fastMRI sample should use the 2D format with C being
-                the number of slices.
-            y: second sample (e.g., the reconstructed image). It has similar
-                shape as x
-        """
-
-        if not x.shape == y.shape:
-            raise ValueError(f"Input images should have the same dimensions, but got {x.shape} and {y.shape}.")
-
-        for d in range(len(x.shape) - 1, 1, -1):
-            x = x.squeeze(dim=d)
-            y = y.squeeze(dim=d)
-
-        if len(x.shape) == 4:
-            avg_pool = F.avg_pool2d
-        elif len(x.shape) == 5:
-            avg_pool = F.avg_pool3d
-        else:
-            raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {x.shape}")
-
-        if self.weights is None:
-            # as per Ref 1 - Sec 3.2.
-            self.weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
-        self.weights = torch.tensor(self.weights)
-
-        divisible_by = 2 ** (len(self.weights) - 1)
-        bigger_than = (self.win_size + 2) * 2 ** (len(self.weights) - 1)
-        for idx, shape_size in enumerate(x.shape[2:]):
-            if shape_size % divisible_by != 0:
-                raise ValueError(
-                    f"Image size needs to be divisible by {divisible_by} but "
-                    f"dimension {idx + 2} has size {shape_size}"
-                )
-
-            if shape_size < bigger_than:
-                raise ValueError(
-                    f"Image size should be larger than {bigger_than} due to "
-                    f"the {len(self.weights) - 1} downsamplings in MS-SSIM."
-                )
+            y_pred: Predicted image.
+                It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].
+            y: Reference image.
+                It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].
+
+        Raises:
+            ValueError: when `y_pred` is not a 2D or 3D image.
+        """
+        dims = y_pred.ndimension()
+        if self.spatial_dims == 2 and dims != 4:
+            raise ValueError(
+                f"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} "
+                f"spatial dimensions, got {dims}."
+            )
+
+        if self.spatial_dims == 3 and dims != 5:
+            raise ValueError(
+                f"y_pred should have 5 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}"
+                f" spatial dimensions, got {dims}."
+            )
+
+        # check if image has enough size for the number of downsamplings and the size of the kernel
+        weights_div = max(1, (len(self.weights) - 1)) ** 2
+        y_pred_spatial_dims = y_pred.shape[2:]
+        for i in range(len(y_pred_spatial_dims)):
+            if y_pred_spatial_dims[i] // weights_div <= self.kernel_size[i] - 1:
+                raise ValueError(
+                    f"For a given number of `weights` parameters {len(self.weights)} and kernel size "
+                    f"{self.kernel_size[i]}, the image height must be larger than "
+                    f"{(self.kernel_size[i] - 1) * weights_div}."
+                )
 
-        levels = self.weights.shape[0]
-        mcs_list: list[torch.Tensor] = []
-        for i in range(levels):
-            ssim, cs = self.SSIM._compute_metric_and_contrast(x, y)
-
-            if i < levels - 1:
-                mcs_list.append(torch.relu(cs))
-                padding = [s % 2 for s in x.shape[2:]]
-                x = avg_pool(x, kernel_size=2, padding=padding)
-                y = avg_pool(y, kernel_size=2, padding=padding)
-
-        ssim = torch.relu(ssim)  # (batch, 1)
-        # (level, batch, 1)
-        mcs_and_ssim = torch.stack(mcs_list + [ssim], dim=0)
-        ms_ssim = torch.prod(mcs_and_ssim ** self.weights.view(-1, 1, 1), dim=0)
-
-        if self.reduction == MetricReduction.MEAN.value:
-            ms_ssim = ms_ssim.mean()
-        elif self.reduction == MetricReduction.SUM.value:
-            ms_ssim = ms_ssim.sum()
-        elif self.reduction == MetricReduction.NONE.value:
-            pass
-
-        return ms_ssim
+        weights = torch.tensor(self.weights, device=y_pred.device, dtype=torch.float)
+
+        avg_pool = getattr(F, f"avg_pool{self.spatial_dims}d")
+
+        multiscale_list: list[torch.Tensor] = []
+        for i in range(len(weights)):
+            ssim, cs = compute_ssim_and_cs(
+                y_pred=y_pred,
+                y=y,
+                spatial_dims=self.spatial_dims,
+                data_range=self.data_range,
+                kernel_type=self.kernel_type,
+                kernel_size=self.kernel_size,
+                kernel_sigma=self.kernel_sigma,
+                k1=self.k1,
+                k2=self.k2,
+            )
+
+            cs_per_batch = cs.view(cs.shape[0], -1).mean(1)
+
+            multiscale_list.append(torch.relu(cs_per_batch))
+            y_pred = avg_pool(y_pred, kernel_size=2)
+            y = avg_pool(y, kernel_size=2)
+
+        ssim = ssim.view(ssim.shape[0], -1).mean(1)
+        multiscale_list[-1] = torch.relu(ssim)
+        multiscale_list = torch.stack(multiscale_list)
+
+        ms_ssim_value_full_image = torch.prod(multiscale_list ** weights.view(-1, 1), dim=0)
+
+        ms_ssim_per_batch: torch.Tensor = ms_ssim_value_full_image.view(ms_ssim_value_full_image.shape[0], -1).mean(
+            1, keepdim=True
+        )
+
+        return ms_ssim_per_batch
```
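For reference, the loop in the new `_compute_metric` follows the weighted product from Wang et al. [1]: the contrast-sensitivity term `cs` is collected at each of the M scales (M = `len(weights)`, 5 by default), the last entry is replaced by the full SSIM at the coarsest scale, and the weighted product is taken per batch element. A sketch in the paper's notation, with the code's `weights` as the exponents w_j:

```latex
% MS-SSIM as computed by the loop: cs_j at scales 1..M-1,
% full SSIM at the coarsest scale M, weighted product per batch element.
\mathrm{MS\text{-}SSIM}(x, y) =
    \big[\mathrm{SSIM}_M(x, y)\big]^{w_M}
    \prod_{j=1}^{M-1} \big[\mathrm{cs}_j(x, y)\big]^{w_j}
```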
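A minimal usage sketch of the refactored metric (the tensors and sizes below are illustrative, not from the PR). Note the size check above: with the default five `weights` and `kernel_size=11`, each spatial dimension must satisfy `dim // 16 > 10`, i.e. be at least 176.

```python
import torch

from generative.metrics import MultiScaleSSIMMetric

# Illustrative inputs: 192 >= 176, so the image survives the four 2x
# downsamplings with enough room left for the 11x11 Gaussian kernel.
y_pred = torch.rand(2, 1, 192, 192)  # [B, C, H, W]
y = torch.rand(2, 1, 192, 192)

metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0)
metric(y_pred, y)          # buffers one (B, 1) MS-SSIM result per call
print(metric.aggregate())  # default "mean" reduction over buffered results
```

Because the class now passes `reduction` and `get_not_nans` to `RegressionMetric`, reduction happens in `aggregate()` rather than inside `_compute_metric`, which is why the method returns a per-batch tensor instead of reducing it itself as the old `MSSSIM` did.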