From 05d570e137cfc5a683447f97ed42d5410acce5ea Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 8 Apr 2023 10:41:12 +0100 Subject: [PATCH 1/2] Fix MMD Metric Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/__init__.py | 2 +- generative/metrics/mmd.py | 19 ++++--------------- tests/test_compute_mmd_metric.py | 7 ++++--- 3 files changed, 9 insertions(+), 19 deletions(-) diff --git a/generative/metrics/__init__.py b/generative/metrics/__init__.py index 87e2d949..35af1fe4 100644 --- a/generative/metrics/__init__.py +++ b/generative/metrics/__init__.py @@ -12,6 +12,6 @@ from __future__ import annotations from .fid import FIDMetric -from .mmd import MMD +from .mmd import MMDMetric from .ms_ssim import MultiScaleSSIMMetric from .ssim import SSIMMetric diff --git a/generative/metrics/mmd.py b/generative/metrics/mmd.py index bba93141..7ece6019 100644 --- a/generative/metrics/mmd.py +++ b/generative/metrics/mmd.py @@ -14,11 +14,10 @@ from collections.abc import Callable import torch -from monai.metrics.regression import RegressionMetric -from monai.utils import MetricReduction +from monai.metrics.metric import Metric -class MMD(RegressionMetric): +class MMDMetric(Metric): """ Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two distributions. It is a non-negative metric where a smaller value indicates a closer match between the two @@ -31,29 +30,19 @@ class MMD(RegressionMetric): filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a feature extractor or an Identity function. y_pred_transform: Callable to transform the y_pred tensor before computing the metric. - reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, available - reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, - `"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. This parameter is ignored due to - the mathematical formulation of MMD. - get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here - `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. - This parameter is ignored due to the mathematical formulation of MMD. - """ def __init__( self, y_transform: Callable | None = None, y_pred_transform: Callable | None = None, - reduction: MetricReduction | str = MetricReduction.MEAN, - get_not_nans: bool = False, ) -> None: - super().__init__(reduction=reduction, get_not_nans=get_not_nans) + super().__init__() self.y_transform = y_transform self.y_pred_transform = y_pred_transform - def _compute_metric(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + def __call__(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: """ Args: y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. diff --git a/tests/test_compute_mmd_metric.py b/tests/test_compute_mmd_metric.py index ab016653..a888c1f3 100644 --- a/tests/test_compute_mmd_metric.py +++ b/tests/test_compute_mmd_metric.py @@ -17,7 +17,7 @@ import torch from parameterized import parameterized -from generative.metrics import MMD +from generative.metrics import MMDMetric TEST_CASES = [ [ @@ -36,12 +36,13 @@ class TestMMDMetric(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_results(self, input_param, input_data, expected_val): - results = MMD(**input_param)._compute_metric(**input_data) + metric = MMDMetric(**input_param) + results = metric(**input_data) np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4) def test_if_inputs_different_shapes(self): with self.assertRaises(ValueError): - MMD()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) + MMDMetric()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) if __name__ == "__main__": From aa330831f38cdcfcc98f2f5484fcb4b597ab52cc Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 8 Apr 2023 10:45:47 +0100 Subject: [PATCH 2/2] Fix MMD Metric Signed-off-by: Walter Hugo Lopez Pinaya --- generative/metrics/mmd.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/generative/metrics/mmd.py b/generative/metrics/mmd.py index 7ece6019..416e1dfa 100644 --- a/generative/metrics/mmd.py +++ b/generative/metrics/mmd.py @@ -32,11 +32,7 @@ class MMDMetric(Metric): y_pred_transform: Callable to transform the y_pred tensor before computing the metric. """ - def __init__( - self, - y_transform: Callable | None = None, - y_pred_transform: Callable | None = None, - ) -> None: + def __init__(self, y_transform: Callable | None = None, y_pred_transform: Callable | None = None) -> None: super().__init__() self.y_transform = y_transform