diff --git a/generative/metrics/fid.py b/generative/metrics/fid.py
index dade6d43..8dc7b154 100644
--- a/generative/metrics/fid.py
+++ b/generative/metrics/fid.py
@@ -8,31 +8,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
-# =========================================================================
-# Adapted from https://github.com/photosynthesis-team/piq
-# which has the following license:
-# https://github.com/photosynthesis-team/piq/blob/master/LICENSE
-#
-# Copyright 2023 photosynthesis-team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =========================================================================
+
 from __future__ import annotations

+import numpy as np
 import torch

 from monai.metrics.metric import Metric
+from scipy import linalg


 class FIDMetric(Metric):
@@ -70,64 +53,32 @@ def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y)


-def _cov(m: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
+def _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
     """
     Estimate a covariance matrix of the variables.

     Args:
-        m: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable,
+        input_data: A 1-D or 2-D array containing multiple variables and observations. Each row of `input_data` represents a variable,
             and each column a single observation of all those variables.
         rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns.
             Otherwise, the relationship is transposed: each column represents a variable, while the rows contain
             observations.
     """
-    if m.dim() < 2:
-        m = m.view(1, -1)
+    if input_data.dim() < 2:
+        input_data = input_data.view(1, -1)

-    if not rowvar and m.size(0) != 1:
-        m = m.t()
+    if not rowvar and input_data.size(0) != 1:
+        input_data = input_data.t()

-    fact = 1.0 / (m.size(1) - 1)
-    m = m - torch.mean(m, dim=1, keepdim=True)
-    mt = m.t()
-    return fact * m.matmul(mt).squeeze()
+    factor = 1.0 / (input_data.size(1) - 1)
+    input_data = input_data - torch.mean(input_data, dim=1, keepdim=True)
+    return factor * input_data.matmul(input_data.t()).squeeze()


-def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Square root of matrix using Newton-Schulz Iterative method. Based on:
-    https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py. Bechmark shown in:
-    https://github.com/photosynthesis-team/piq/issues/190#issuecomment-742039303
-
-    Args:
-        matrix: matrix or batch of matrices
-        num_iters: Number of iteration of the method
-
-    """
-    dim = matrix.size(0)
-    norm_of_matrix = matrix.norm(p="fro")
-    y_matrix = matrix.div(norm_of_matrix)
-    i_matrix = torch.eye(dim, dim, device=matrix.device, dtype=matrix.dtype)
-    z_matrix = torch.eye(dim, dim, device=matrix.device, dtype=matrix.dtype)
-
-    s_matrix = torch.empty_like(matrix)
-    error = torch.empty(1, device=matrix.device, dtype=matrix.dtype)
-
-    for _ in range(num_iters):
-        t = 0.5 * (3.0 * i_matrix - z_matrix.mm(y_matrix))
-        y_matrix = y_matrix.mm(t)
-        z_matrix = t.mm(z_matrix)
-
-        s_matrix = y_matrix * torch.sqrt(norm_of_matrix)
-
-        norm_of_matrix = torch.norm(matrix)
-        error = matrix - torch.mm(s_matrix, s_matrix)
-        error = torch.norm(error) / norm_of_matrix
-
-        if torch.isclose(error, torch.tensor([0.0], device=error.device, dtype=error.dtype), atol=1e-5):
-            break
-
-    return s_matrix, error
+def _sqrtm(input_data: torch.Tensor) -> torch.Tensor:
+    """Compute the square root of a matrix."""
+    scipy_res, _ = linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float_), disp=False)
+    return torch.from_numpy(scipy_res)


 def compute_frechet_distance(
@@ -135,12 +86,20 @@ def compute_frechet_distance(
 ) -> torch.Tensor:
     """The Frechet distance between multivariate normal distributions."""
     diff = mu_x - mu_y
-    covmean, _ = _sqrtm_newton_schulz(sigma_x.mm(sigma_y))
-    # If calculation produces singular product, epsilon is added to diagonal of cov estimates
+    covmean = _sqrtm(sigma_x.mm(sigma_y))
+
+    # Product might be almost singular
     if not torch.isfinite(covmean).all():
+        print(f"FID calculation produces singular product; adding {epsilon} to diagonal of covariance estimates")
         offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon
-        covmean, _ = _sqrtm_newton_schulz((sigma_x + offset).mm(sigma_y + offset))
+        covmean = _sqrtm((sigma_x + offset).mm(sigma_y + offset))
+
+    # Numerical error might give slight imaginary component
+    if torch.is_complex(covmean):
+        if not torch.allclose(torch.diagonal(covmean).imag, torch.tensor(0, dtype=torch.double), atol=1e-3):
+            raise ValueError(f"Imaginary component {torch.max(torch.abs(covmean.imag))} too high.")
+        covmean = covmean.real

     tr_covmean = torch.trace(covmean)

     return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean
diff --git a/tests/test_compute_fid_metric.py b/tests/test_compute_fid_metric.py
index 3323066f..318b9d25 100644
--- a/tests/test_compute_fid_metric.py
+++ b/tests/test_compute_fid_metric.py
@@ -24,7 +24,7 @@ def test_results(self):
         x = torch.Tensor([[1, 2], [1, 2], [1, 2]])
         y = torch.Tensor([[2, 2], [1, 2], [1, 2]])
         results = FIDMetric()(x, y)
-        np.testing.assert_allclose(results.cpu().numpy(), 0.4433, atol=1e-4)
+        np.testing.assert_allclose(results.cpu().numpy(), 0.4444, atol=1e-4)

     def test_input_dimensions(self):
         with self.assertRaises(ValueError):
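
Note on the updated test expectation: for the toy inputs in test_results, the Frechet distance works out to exactly 4/9 ~= 0.4444 under the scipy-based computation. The sketch below is a standalone check in plain NumPy/SciPy (it does not use the MONAI API; it only mirrors the mean/covariance/matrix-square-root steps of the patched compute_frechet_distance):

# Standalone sanity check of the new expected value (4/9 ~= 0.4444), assuming the metric
# fits a Gaussian to each feature set and computes the standard Frechet distance.
import numpy as np
from scipy import linalg

x = np.array([[1, 2], [1, 2], [1, 2]], dtype=np.float64)  # y_pred features from the test
y = np.array([[2, 2], [1, 2], [1, 2]], dtype=np.float64)  # y features from the test

# Mean over observations (rows) and unbiased covariance over features (columns),
# matching _cov(..., rowvar=False) in the patch.
mu_x, sigma_x = x.mean(axis=0), np.cov(x, rowvar=False)
mu_y, sigma_y = y.mean(axis=0), np.cov(y, rowvar=False)

diff = mu_x - mu_y
# Principal matrix square root of the covariance product (here the product is the
# zero matrix, whose square root is zero, so the trace term vanishes).
covmean, _ = linalg.sqrtm(sigma_x.dot(sigma_y), disp=False)
covmean = np.real(covmean)  # drop any negligible imaginary part from numerical error

fid = diff.dot(diff) + np.trace(sigma_x) + np.trace(sigma_y) - 2 * np.trace(covmean)
print(fid)  # ~0.4444 (exactly 4/9), matching the updated assertion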