/
fid.py
300 lines (241 loc) · 13.1 KB
/
fid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Any, List, Optional, Union
import numpy as np
import torch
from torch import Tensor
from torch.autograd import Function
from torch.nn import Module
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_info
from torchmetrics.utilities.imports import _SCIPY_AVAILABLE, _TORCH_FIDELITY_AVAILABLE
# Use the InceptionV3 feature extractor from torch-fidelity when it is installed. Otherwise declare an empty
# placeholder base class so that ``NoTrainInceptionV3`` can still be defined when this module is imported.
# ``FrechetInceptionDistance.__init__`` refuses integer ``feature`` arguments when torch-fidelity is missing,
# so the placeholder is never actually instantiated through that path.
if _TORCH_FIDELITY_AVAILABLE:
    from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
else:

    class FeatureExtractorInceptionV3(Module):  # type: ignore
        """Placeholder base class used only when ``torch-fidelity`` is not installed."""

        pass


# NOTE(review): presumably consumed by the doctest plugin (e.g. pytest-doctestplus) to skip the examples of
# these names -- confirm against the test configuration.
__doctest_skip__ = ["FrechetInceptionDistance", "FID"]

# scipy is only needed for the matrix square root / Sylvester solve in ``MatrixSquareRoot``; import it only
# when it is available so this module can still be imported without scipy.
if _SCIPY_AVAILABLE:
    import scipy
class NoTrainInceptionV3(FeatureExtractorInceptionV3):
    """InceptionV3 feature extractor that is locked in evaluation mode.

    Thin wrapper around ``FeatureExtractorInceptionV3`` that switches the network to ``eval`` mode on construction,
    refuses to leave it, and flattens the extracted features to a ``(batch, -1)`` matrix.

    Args:
        name: name of the feature extractor, passed through to the parent constructor
        features_list: feature-layer identifiers to extract, passed through to the parent constructor
        feature_extractor_weights_path: optional weights path, passed through to the parent constructor
    """

    def __init__(
        self,
        name: str,
        features_list: List[str],
        feature_extractor_weights_path: Optional[str] = None,
    ) -> None:
        super().__init__(name, features_list, feature_extractor_weights_path)
        # put into evaluation mode
        self.eval()

    def train(self, mode: bool = True) -> "NoTrainInceptionV3":
        """Keep the network in evaluation mode; the requested ``mode`` is ignored.

        ``mode`` defaults to ``True`` to match the signature of ``torch.nn.Module.train``. Without the default,
        a bare ``.train()`` call on this module would raise ``TypeError``.
        """
        return super().train(False)

    def forward(self, x: Tensor) -> Tensor:
        """Extract features for the batch ``x`` and flatten them to shape ``(x.shape[0], -1)``."""
        out = super().forward(x)
        # NOTE(review): only ``out[0]`` is used -- presumably the parent returns one tensor per entry of
        # ``features_list`` and a single layer is always requested; confirm.
        return out[0].reshape(x.shape[0], -1)
class MatrixSquareRoot(Function):
    """Square root of a positive definite matrix.

    All credit to `Square Root of a Positive Definite Matrix`_

    The forward pass calls :func:`scipy.linalg.sqrtm` on a float64 NumPy copy of the input and keeps only the real
    part of the result. The backward pass solves a Sylvester equation with :func:`scipy.linalg.solve_sylvester`.
    Both passes round-trip the data through the CPU and NumPy.
    """

    @staticmethod
    def forward(ctx: Any, input_data: Tensor) -> Tensor:
        """Return the matrix square root of the square matrix ``input_data``.

        The result is the real part of the scipy result, moved back to ``input_data``'s dtype and device.
        """
        # TODO: update whenever pytorch gets an matrix square root function
        # Issue: https://github.com/pytorch/pytorch/issues/9983
        # ``np.float64`` rather than the ``np.float_`` alias, which was removed in NumPy 2.0.
        m = input_data.detach().cpu().numpy().astype(np.float64)
        scipy_res, _ = scipy.linalg.sqrtm(m, disp=False)
        root = torch.from_numpy(scipy_res.real).to(input_data)
        ctx.save_for_backward(root)
        return root

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Optional[Tensor]:
        """Propagate ``grad_output`` through the matrix square root; returns ``None`` if no input grad is needed."""
        grad_input = None
        if ctx.needs_input_grad[0]:
            (root,) = ctx.saved_tensors
            root_np = root.data.cpu().numpy().astype(np.float64)
            gm = grad_output.data.cpu().numpy().astype(np.float64)
            # Given a positive semi-definite matrix X, since X = X^{1/2} X^{1/2}, differentiating gives
            #   dX = d(X^{1/2}) X^{1/2} + X^{1/2} d(X^{1/2}),
            # a Sylvester equation A Y + Y B = Q with A = B = X^{1/2}, which ``solve_sylvester`` solves for Y.
            grad_sqrtm = scipy.linalg.solve_sylvester(root_np, root_np, gm)
            grad_input = torch.from_numpy(grad_sqrtm).to(grad_output)
        return grad_input


sqrtm = MatrixSquareRoot.apply
def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor, eps: float = 1e-6) -> Tensor:
    r"""Compute the squared Fréchet distance between two multivariate Gaussians.

    Adjusted version of `Fid Score`_. For :math:`X_x \sim \mathcal{N}(\mu_1, \Sigma_1)` and
    :math:`X_y \sim \mathcal{N}(\mu_2, \Sigma_2)` the distance is

    .. math::
        d^2 = \|\mu_1 - \mu_2\|^2 + \operatorname{Tr}\left(\Sigma_1 + \Sigma_2 - 2\sqrt{\Sigma_1 \Sigma_2}\right)

    Args:
        mu1: mean of activations calculated on predicted (x) samples
        sigma1: covariance matrix over activations calculated on predicted (x) samples
        mu2: mean of activations calculated on target (y) samples
        sigma2: covariance matrix over activations calculated on target (y) samples
        eps: jitter added to the diagonal of both covariances when the square root of their product
            contains non-finite values (e.g. because the product is singular)

    Returns:
        Scalar tensor with the distance between the two distributions.
    """
    root_of_product = sqrtm(sigma1.mm(sigma2))

    # A (nearly) singular covariance product can give a non-finite square root; retry once with the
    # diagonals of both covariance estimates nudged by ``eps``.
    if not torch.isfinite(root_of_product).all():
        rank_zero_info(f"FID calculation produces singular product; adding {eps} to diagonal of covariance estimates")
        jitter = torch.eye(sigma1.size(0), device=mu1.device, dtype=mu1.dtype) * eps
        root_of_product = sqrtm((sigma1 + jitter).mm(sigma2 + jitter))

    mean_delta = mu1 - mu2
    trace_of_root = torch.trace(root_of_product)
    return mean_delta.dot(mean_delta) + torch.trace(sigma1) + torch.trace(sigma2) - 2 * trace_of_root
class FrechetInceptionDistance(Metric):
    r"""Calculates Fréchet inception distance (FID_) which is used to assess the quality of generated images.

    Given by

    .. math::
        FID = \|\mu - \mu_w\|^2 + tr(\Sigma + \Sigma_w - 2(\Sigma \Sigma_w)^{\frac{1}{2}})

    where :math:`\mathcal{N}(\mu, \Sigma)` is the multivariate normal distribution estimated from Inception v3
    (`fid ref1`_) features calculated on real life images and :math:`\mathcal{N}(\mu_w, \Sigma_w)` is the
    multivariate normal distribution estimated from Inception v3 features calculated on generated (fake) images.
    The metric was originally proposed in `fid ref1`_.

    Using the default feature extraction (Inception v3 using the original weights from `fid ref2`_), the input is
    expected to be mini-batches of 3-channel RGB images of shape ``(3 x H x W)``. If argument ``normalize``
    is ``True`` images are expected to be dtype ``float`` and have values in the ``[0, 1]`` range, else if
    ``normalize`` is set to ``False`` images are expected to have dtype ``uint8`` and take values in the ``[0, 255]``
    range. All images will be resized to 299 x 299 which is the size of the original training data. The boolean
    flag ``real`` determines if the images should update the statistics of the real distribution or the
    fake distribution.

    .. note:: using this metric requires you to have ``scipy`` installed. Either install as ``pip install
        torchmetrics[image]`` or ``pip install scipy``

    .. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
        is installed. Either install as ``pip install torchmetrics[image]`` or
        ``pip install torch-fidelity``

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``imgs`` (:class:`~torch.Tensor`): tensor with images fed to the feature extractor
    - ``real`` (:class:`~bool`): bool indicating if ``imgs`` belong to the real or the fake distribution

    As output of ``forward`` and ``compute`` the metric returns the following output

    - ``fid`` (:class:`~torch.Tensor`): float scalar tensor with mean FID value over samples

    Args:
        feature:
            Either an integer or ``nn.Module``:

            - an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following:
              64, 192, 768, 2048
            - an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns
              an ``(N,d)`` matrix where ``N`` is the batch size and ``d`` is the feature size. The module is
              called once during construction on a dummy ``uint8`` batch of shape ``(1, 3, 299, 299)`` to
              infer ``d``.
        reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
            change, the features can be cached to avoid recomputing them, which is costly. Set this to ``False`` if
            your dataset does not change.
        normalize: If ``True``, input images are expected to be ``float`` in ``[0, 1]`` and are converted with
            ``(imgs * 255).byte()`` before feature extraction. If ``False``, images are passed to the feature
            extractor unchanged and are expected to be ``uint8`` in ``[0, 255]``.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ModuleNotFoundError:
            If ``feature`` is set to an ``int`` (default settings) and ``torch-fidelity`` is not installed
        ValueError:
            If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048]
        TypeError:
            If ``feature`` is not an ``int`` or ``torch.nn.Module``
        ValueError:
            If ``reset_real_features`` is not a ``bool``
        ValueError:
            If ``normalize`` is not a ``bool``

    Example:
        >>> import torch
        >>> _ = torch.manual_seed(123)
        >>> from torchmetrics.image.fid import FrechetInceptionDistance
        >>> fid = FrechetInceptionDistance(feature=64)
        >>> # generate two slightly overlapping image intensity distributions
        >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
        >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
        >>> fid.update(imgs_dist1, real=True)
        >>> fid.update(imgs_dist2, real=False)
        >>> fid.compute()
        tensor(12.7202)
    """

    higher_is_better: bool = False
    is_differentiable: bool = False
    full_state_update: bool = False

    # Running sufficient statistics per distribution: sum of features, sum of outer products, sample count.
    real_features_sum: Tensor
    real_features_cov_sum: Tensor
    real_features_num_samples: Tensor

    fake_features_sum: Tensor
    fake_features_cov_sum: Tensor
    fake_features_num_samples: Tensor

    def __init__(
        self,
        feature: Union[int, Module] = 2048,
        reset_real_features: bool = True,
        normalize: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        if isinstance(feature, int):
            num_features = feature
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "FrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )
            valid_int_input = [64, 192, 768, 2048]
            if feature not in valid_int_input:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
                )

            self.inception = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
        elif isinstance(feature, Module):
            self.inception = feature
            # Run the custom extractor once on a dummy batch to discover its output feature dimension.
            dummy_image = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8)
            num_features = self.inception(dummy_image).shape[-1]
        else:
            raise TypeError("Got unknown input to argument `feature`")

        if not isinstance(reset_real_features, bool):
            raise ValueError("Argument `reset_real_features` expected to be a bool")
        self.reset_real_features = reset_real_features

        if not isinstance(normalize, bool):
            raise ValueError("Argument `normalize` expected to be a bool")
        self.normalize = normalize

        # States are kept in float64: the covariance is later formed as a difference of large sums
        # (see ``compute``), which is sensitive to round-off.
        mx_nb_feets = (num_features, num_features)
        self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
        self.add_state("real_features_cov_sum", torch.zeros(mx_nb_feets).double(), dist_reduce_fx="sum")
        self.add_state("real_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum")

        self.add_state("fake_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum")
        self.add_state("fake_features_cov_sum", torch.zeros(mx_nb_feets).double(), dist_reduce_fx="sum")
        self.add_state("fake_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum")

    def update(self, imgs: Tensor, real: bool) -> None:  # type: ignore
        """Update the state with extracted features.

        Args:
            imgs: batch of images; ``float`` in ``[0, 1]`` if ``normalize=True``, otherwise ``uint8`` in ``[0, 255]``
            real: if ``True`` the batch updates the real-distribution statistics, otherwise the fake ones
        """
        imgs = (imgs * 255).byte() if self.normalize else imgs
        features = self.inception(imgs)
        # Remember the extractor's dtype so ``compute`` can cast the float64 result back to it.
        self.orig_dtype = features.dtype
        features = features.double()

        # A single feature vector is treated as a batch of one.
        if features.dim() == 1:
            features = features.unsqueeze(0)
        if real:
            self.real_features_sum += features.sum(dim=0)
            self.real_features_cov_sum += features.t().mm(features)
            self.real_features_num_samples += imgs.shape[0]
        else:
            self.fake_features_sum += features.sum(dim=0)
            self.fake_features_cov_sum += features.t().mm(features)
            self.fake_features_num_samples += imgs.shape[0]

    def compute(self) -> Tensor:
        """Calculate FID score based on accumulated extracted features from the two distributions.

        Returns:
            Scalar tensor with the FID, cast to the dtype of the features produced by the last ``update``.

        Raises:
            RuntimeError:
                If fewer than two samples have been accumulated for either the real or the fake distribution,
                since the unbiased covariance estimate divides by ``num_samples - 1``.
        """
        if self.real_features_num_samples < 2 or self.fake_features_num_samples < 2:
            raise RuntimeError("More than one sample is required for both the real and fake distributions to compute FID")

        mean_real = (self.real_features_sum / self.real_features_num_samples).unsqueeze(0)
        mean_fake = (self.fake_features_sum / self.fake_features_num_samples).unsqueeze(0)

        # Unbiased covariance: (sum_i x_i x_i^T - n * mu mu^T) / (n - 1).
        cov_real_num = self.real_features_cov_sum - self.real_features_num_samples * mean_real.t().mm(mean_real)
        cov_real = cov_real_num / (self.real_features_num_samples - 1)
        cov_fake_num = self.fake_features_cov_sum - self.fake_features_num_samples * mean_fake.t().mm(mean_fake)
        cov_fake = cov_fake_num / (self.fake_features_num_samples - 1)
        return _compute_fid(mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake).to(self.orig_dtype)

    def reset(self) -> None:
        """Reset the metric states, keeping the real-distribution statistics when ``reset_real_features=False``."""
        if not self.reset_real_features:
            # Snapshot the real statistics, reset everything, then restore the snapshot.
            real_features_sum = deepcopy(self.real_features_sum)
            real_features_cov_sum = deepcopy(self.real_features_cov_sum)
            real_features_num_samples = deepcopy(self.real_features_num_samples)
            super().reset()
            self.real_features_sum = real_features_sum
            self.real_features_cov_sum = real_features_cov_sum
            self.real_features_num_samples = real_features_num_samples
        else:
            super().reset()