In [1]:
from utils.fid import FrechetInceptionDistance
from models.conv_generator import ConvGenerator
from data_loaders.mnist import MnistDataLoaderFactory
from dotmap import DotMap
from tqdm.notebook import tqdm
import os
import torch
from torchvision.models.inception import inception_v3

In [2]:
config = DotMap()
config.data = DotMap(root="/tmp/data", batch_size=128)
config.device = "cuda"

# Seed torch's RNG (all devices) so generator sampling and any data-loader
# shuffling are repeatable; otherwise the FID scores below change on every run.
SEED = 0
torch.manual_seed(SEED)

In [3]:
# FID accumulator used by the scoring cells below (add_batch(...) then calculate()).
# Constructed on config.device; presumably loads an Inception network there — TODO confirm in utils.fid.
fid = FrechetInceptionDistance(config.device)
# MNIST data loader; presumably configured from config.data (root, batch_size) — verify in data_loaders.mnist.
dl = MnistDataLoaderFactory.get_data_loader(config)

In [4]:
MODEL_DIR = "trained_models"
# Checkpoints to compare: three standard generators and three "large" variants.
model_files = [f"gen{i}.p" for i in range(1, 4)] + [f"gen_large{i}.p" for i in [1, 2, 4]]


def load_generator(fname):
    """Build a ConvGenerator from `config`, load `MODEL_DIR/fname`, and prepare it for inference.

    Returns the model in eval mode on `config.device` (the same device used for
    FID and for `generate_batch`, instead of a hardcoded `.cuda()`).
    """
    model = ConvGenerator(config)
    # NOTE(review): `load` is a project method; presumably restores saved weights — confirm.
    model.load(os.path.join(MODEL_DIR, fname))
    model.eval()
    model.to(config.device)
    return model


models = [load_generator(fname) for fname in model_files]

Cache 30 real batches (30 × 128 = 3840 samples) up front so that every generator is scored against exactly the same real images.

In [6]:
# Draw a fixed set of real batches once; the scoring and sanity-check cells reuse
# this list so all FID values are computed against identical real samples.
N_EVAL_BATCHES = 30
data_iter = iter(dl)
batches = [next(data_iter) for _ in range(N_EVAL_BATCHES)]
# Display how many real samples were actually cached (batch_size * N_EVAL_BATCHES
# unless the loader yields a short batch).
n_real_samples = sum(X.shape[0] for X in batches)
n_real_samples

In [7]:
# FID per checkpoint, keyed by model file name (previously `scores` was never filled).
scores = {}
# Inference only: avoid building autograd graphs for generated samples.
with torch.no_grad():
    for model, path in zip(models, model_files):
        for X_true in tqdm(batches, desc=path):
            # Generate as many fakes as there are real samples in this batch.
            X_fake = model.generate_batch(X_true.shape[0], config.device)
            # Clone defensively in case add_batch modifies its inputs in place,
            # which would corrupt the cached real batches.
            fid.add_batch(X_true.clone(), X_fake.clone())
        # NOTE(review): assumes calculate() resets the accumulated statistics; otherwise
        # later scores would mix in earlier generators' samples — confirm in utils.fid.
        score = fid.calculate()
        scores[path] = score
        print(f"{path}: {score:.4f}")

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


gen1.p: 339.3789


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


gen2.p: 331.7495


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


gen3.p: 354.6193


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


gen_large1.p: 351.5028


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


gen_large2.p: 329.3862


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


gen_large4.p: 337.5021


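Sanity check: compute the FID between the real batches and pure Gaussian noise. The result should be much worse (higher) than any trained generator's score.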

In [9]:
# Baseline: score standard-normal noise (same shape as each real batch) against the real data.
with torch.no_grad():
    for X_true in batches:
        # Clone like the scoring loop does, in case add_batch modifies its inputs in place;
        # this keeps `batches` intact if cells are re-run afterwards.
        fid.add_batch(X_true.clone(), torch.randn_like(X_true))
# Left as the cell's final top-level expression so the score is displayed.
fid.calculate()

552.9860000841278