This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
95 changes: 27 additions & 68 deletions generative/metrics/fid.py
@@ -8,31 +8,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
-# =========================================================================
-# Adapted from https://github.com/photosynthesis-team/piq
-# which has the following license:
-# https://github.com/photosynthesis-team/piq/blob/master/LICENSE
-#
-# Copyright 2023 photosynthesis-team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =========================================================================
 
 from __future__ import annotations
 
+import numpy as np
 import torch
 from monai.metrics.metric import Metric
+from scipy import linalg
 
 
 class FIDMetric(Metric):
@@ -70,77 +53,53 @@ def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y)
 
 
-def _cov(m: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
+def _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
     """
     Estimate a covariance matrix of the variables.
 
     Args:
-        m: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable,
+        input_data: A 1-D or 2-D array containing multiple variables and observations. Each row of `input_data` represents a variable,
             and each column a single observation of all those variables.
         rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns.
             Otherwise, the relationship is transposed: each column represents a variable, while the rows contain
            observations.
     """
-    if m.dim() < 2:
-        m = m.view(1, -1)
+    if input_data.dim() < 2:
+        input_data = input_data.view(1, -1)
 
-    if not rowvar and m.size(0) != 1:
-        m = m.t()
+    if not rowvar and input_data.size(0) != 1:
+        input_data = input_data.t()
 
-    fact = 1.0 / (m.size(1) - 1)
-    m = m - torch.mean(m, dim=1, keepdim=True)
-    mt = m.t()
-    return fact * m.matmul(mt).squeeze()
+    factor = 1.0 / (input_data.size(1) - 1)
+    input_data = input_data - torch.mean(input_data, dim=1, keepdim=True)
+    return factor * input_data.matmul(input_data.t()).squeeze()


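The rename from `m` to `input_data` is behavior-preserving: with `rowvar=False` the helper computes the usual unbiased covariance of an (observations x variables) matrix. A minimal sanity check against numpy; the private import path is assumed from this repository's layout and is not part of the PR:

    import numpy as np
    import torch

    from generative.metrics.fid import _cov  # assumed module path

    # Three observations of two variables; variables along columns.
    x = torch.tensor([[2.0, 2.0], [1.0, 2.0], [1.0, 2.0]])

    ours = _cov(x, rowvar=False)  # 2 x 2 feature covariance
    reference = torch.from_numpy(np.cov(x.numpy(), rowvar=False))
    print(torch.allclose(ours.double(), reference))  # True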
-def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Square root of matrix using Newton-Schulz Iterative method. Based on:
-    https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py. Benchmark shown in:
-    https://github.com/photosynthesis-team/piq/issues/190#issuecomment-742039303
-
-    Args:
-        matrix: matrix or batch of matrices
-        num_iters: Number of iterations of the method
-
-    """
-    dim = matrix.size(0)
-    norm_of_matrix = matrix.norm(p="fro")
-    y_matrix = matrix.div(norm_of_matrix)
-    i_matrix = torch.eye(dim, dim, device=matrix.device, dtype=matrix.dtype)
-    z_matrix = torch.eye(dim, dim, device=matrix.device, dtype=matrix.dtype)
-
-    s_matrix = torch.empty_like(matrix)
-    error = torch.empty(1, device=matrix.device, dtype=matrix.dtype)
-
-    for _ in range(num_iters):
-        t = 0.5 * (3.0 * i_matrix - z_matrix.mm(y_matrix))
-        y_matrix = y_matrix.mm(t)
-        z_matrix = t.mm(z_matrix)
-
-        s_matrix = y_matrix * torch.sqrt(norm_of_matrix)
-
-        norm_of_matrix = torch.norm(matrix)
-        error = matrix - torch.mm(s_matrix, s_matrix)
-        error = torch.norm(error) / norm_of_matrix
-
-        if torch.isclose(error, torch.tensor([0.0], device=error.device, dtype=error.dtype), atol=1e-5):
-            break
-
-    return s_matrix, error
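One motivation for removing this helper: the iteration normalizes the input by its Frobenius norm, so a singular (here, zero) product of covariances immediately yields NaNs and forces the epsilon fallback in `compute_frechet_distance`. A minimal illustration, not part of the PR:

    import torch

    # Dividing the zero matrix by its zero Frobenius norm gives 0 / 0 = NaN,
    # which is how the removed iteration failed on singular products.
    matrix = torch.zeros(2, 2)
    y_matrix = matrix.div(matrix.norm(p="fro"))
    print(torch.isfinite(y_matrix).all())  # tensor(False)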
+def _sqrtm(input_data: torch.Tensor) -> torch.Tensor:
+    """Compute the square root of a matrix."""
+    scipy_res, _ = linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float_), disp=False)
+    return torch.from_numpy(scipy_res)
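Compared with the removed Newton-Schulz approximation, `scipy.linalg.sqrtm` computes an essentially exact square root (via a Schur decomposition) at the cost of a CPU round-trip, and it can return a complex result, which the new imaginary-component check in `compute_frechet_distance` below accounts for. One caveat: `np.float_` was removed in NumPy 2.0, where `np.float64` is the equivalent spelling. A small check of the scipy behavior, not part of the PR:

    import numpy as np
    from scipy import linalg

    # With disp=False, sqrtm returns (root, error_estimate).
    a = np.diag([4.0, 9.0])
    root, _ = linalg.sqrtm(a, disp=False)
    print(np.allclose(root @ root, a))  # True: root is diag(2, 3)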


 def compute_frechet_distance(
     mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6
 ) -> torch.Tensor:
     """The Frechet distance between multivariate normal distributions."""
     diff = mu_x - mu_y
-    covmean, _ = _sqrtm_newton_schulz(sigma_x.mm(sigma_y))
-
-    # If calculation produces singular product, epsilon is added to diagonal of cov estimates
+    covmean = _sqrtm(sigma_x.mm(sigma_y))
+
+    # Product might be almost singular
     if not torch.isfinite(covmean).all():
+        print(f"FID calculation produces singular product; adding {epsilon} to diagonal of covariance estimates")
         offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon
-        covmean, _ = _sqrtm_newton_schulz((sigma_x + offset).mm(sigma_y + offset))
+        covmean = _sqrtm((sigma_x + offset).mm(sigma_y + offset))
 
+    # Numerical error might give slight imaginary component
+    if torch.is_complex(covmean):
+        if not torch.allclose(torch.diagonal(covmean).imag, torch.tensor(0, dtype=torch.double), atol=1e-3):
+            raise ValueError(f"Imaginary component {torch.max(torch.abs(covmean.imag))} too high.")
+        covmean = covmean.real
+
     tr_covmean = torch.trace(covmean)
     return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean
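`compute_frechet_distance` evaluates the closed form for two Gaussians, d^2 = ||mu_x - mu_y||^2 + Tr(sigma_x) + Tr(sigma_y) - 2 * Tr((sigma_x sigma_y)^(1/2)). In use, the metric is fed two batches of feature vectors, typically activations from a pretrained network; a sketch, assuming the repository's public import path (not part of the PR):

    import torch
    from generative.metrics import FIDMetric  # assumed import path

    # Stand-ins for features of synthetic and real images (batch x feature_dim);
    # in practice these come from a pretrained feature extractor.
    y_pred = torch.randn(64, 128)
    y = torch.randn(64, 128)

    fid = FIDMetric()(y_pred, y)
    print(float(fid))  # non-negative scalar; lower means closer distributions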
2 changes: 1 addition & 1 deletion tests/test_compute_fid_metric.py
@@ -24,7 +24,7 @@ def test_results(self):
         x = torch.Tensor([[1, 2], [1, 2], [1, 2]])
         y = torch.Tensor([[2, 2], [1, 2], [1, 2]])
         results = FIDMetric()(x, y)
-        np.testing.assert_allclose(results.cpu().numpy(), 0.4433, atol=1e-4)
+        np.testing.assert_allclose(results.cpu().numpy(), 0.4444, atol=1e-4)
 
     def test_input_dimensions(self):
         with self.assertRaises(ValueError):
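The expected value above changes because the two square-root implementations treat this degenerate input differently. Here sigma_x is the zero matrix, so sigma_x @ sigma_y is zero and the scipy path returns a zero square root, giving the exact value ||mu_x - mu_y||^2 + Tr(sigma_y) = 1/9 + 1/3 = 4/9 ≈ 0.4444. The Newton-Schulz version divided by the zero Frobenius norm, produced non-finite values, and fell back to the epsilon-offset covariances, which plausibly explains the slightly smaller 0.4433 (the offset square root contributes a trace of about sqrt(1e-6 / 3) ≈ 5.8e-4, subtracted twice). A hand check of the new value, not part of the PR:

    import numpy as np

    mu_x, mu_y = np.array([1.0, 2.0]), np.array([4 / 3, 2.0])
    sigma_x = np.zeros((2, 2))        # both features of x are constant
    sigma_y = np.diag([1 / 3, 0.0])   # unbiased covariance of y's features

    # sqrtm(sigma_x @ sigma_y) is the zero matrix, so the cross term vanishes.
    fid = np.sum((mu_x - mu_y) ** 2) + np.trace(sigma_x) + np.trace(sigma_y)
    print(round(fid, 4))  # 0.4444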