In [None]:
import torch
import torch.nn as nn
from pytorch_fid.inception import InceptionV3
from ignite.metrics import FID

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
from ignite.contrib.metrics.regression import *
from ignite.contrib.metrics import *

# wrapper class as feature_extractor
class WrapperInceptionV3(nn.Module):
    """Adapter that turns a feature-map model into a flat feature extractor.

    The wrapped module is expected to return an indexable collection whose
    first element is a 4-D ``(N, C, H, W)`` tensor. This wrapper reduces that
    tensor to ``(N, C)`` so it can be passed as ``feature_extractor`` to the
    FID metric.

    Args:
        fid_incv3: the module to wrap (here, a ``pytorch_fid`` InceptionV3).
    """

    def __init__(self, fid_incv3):
        super().__init__()
        self.fid_incv3 = fid_incv3

    @torch.no_grad()
    def forward(self, x):
        """Return one feature vector per sample.

        Args:
            x: input batch forwarded unchanged to the wrapped module.

        Returns:
            Tensor of shape ``(N, C)`` computed without gradient tracking.
        """
        y = self.fid_incv3(x)
        # Only the first returned feature map is used.
        y = y[0]
        # When the selected block does not end in a 1x1 map, average over the
        # spatial dimensions instead of silently keeping only the top-left
        # activation. A 1x1 map skips this branch and is returned as before.
        if y.shape[2] != 1 or y.shape[3] != 1:
            y = nn.functional.adaptive_avg_pool2d(y, output_size=(1, 1))
        # Drop the two singleton spatial dimensions: (N, C, 1, 1) -> (N, C).
        y = y[:, :, 0, 0]
        return y

# use cpu rather than cuda to get comparable results
device = "cpu"

# pytorch_fid model: pick the InceptionV3 block that yields `dims` features
dims = 2048
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx]).to(device)

# wrapper model to pytorch_fid model, switched to inference mode
wrapper_model = WrapperInceptionV3(model)
wrapper_model.eval()

# comparable metric; `device` is forwarded so the metric and the feature
# extractor stay on the same device if `device` above is ever changed
pytorch_fid_metric = FID(num_features=dims, feature_extractor=wrapper_model, device=device)

In [None]:
def eval_step(engine, batch):
    """Identity process function: return the incoming batch untouched.

    Args:
        engine: the calling engine (unused).
        batch: the current item from the data passed to ``run``.

    Returns:
        The very same ``batch`` object.
    """
    output = batch
    return output
# Evaluator engine built around eval_step, which returns each batch as-is.
default_evaluator = Engine(eval_step)

In [None]:
# Two identical all-ones image batches: reference and prediction.
y_true = torch.ones(64, 3, 128, 128)
y_pred = torch.ones_like(y_true)

# Register the metric under the key "fid", then run the evaluator over a
# single [y_pred, y_true] batch.
pytorch_fid_metric.attach(default_evaluator, "fid")
state = default_evaluator.run([[y_pred, y_true]])

print(state.metrics["fid"])

-2.2289691212538154e-13
