In [2]:
%matplotlib inline
import numpy as np

import torch
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from fjd.fjd_metric import FJDMetric
from fjd.embeddings import OneHotEmbedding, InceptionEmbedding

In [4]:
def get_dataloaders():

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean=(0.5, 0.5, 0.5), 
                              std=(0.5, 0.5, 0.5))])

    train_set = CIFAR10(root='./datasets/cifar10',
                        train=True,
                        download=True,
                        transform=transform)

    test_set = CIFAR10(root='./datasets/cifar10',
                       train=False,
                       download=True,
                       transform=transform)

    train_loader = DataLoader(train_set,
                              batch_size=128,
                              shuffle=True,
                              drop_last=False)

    test_loader = DataLoader(test_set,
                             batch_size=128,
                             shuffle=False,
                             drop_last=False)

    return train_loader, test_loader

In [7]:
class SuspiciouslyGoodGAN(torch.nn.Module):
    def __init__(self):
        super(SuspiciouslyGoodGAN, self).__init__()

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5))])

        test_set = CIFAR10(root='./datasets/cifar10',
                           train=False,
                           download=True,
                           transform=transform)

        test_loader = DataLoader(test_set,
                                 batch_size=128,
                                 shuffle=True,
                                 drop_last=False,
                                 num_workers=2)

        self.test_loader = test_loader
        self.data_iter = iter(test_loader)

    def forward(self, z, y):
        # Normally a GAN would actually do something with z and y, but for this fake GAN we ignore them
        try:
            samples, _ = next(self.data_iter)
        except StopIteration:
            # Reset dataloader if it runs out of samples
            self.data_iter = iter(self.test_loader)
            samples, _ = next(self.data_iter)
        samples = samples.cuda()
        return samples
    
class GANWrapper:
    def __init__(self, model, model_checkpoint=None):
        self.model = model
        
        if model_checkpoint is not None:
            self.model_checkpoint = model_checkpoint
            self.load_model()

    def load_model(self):
        # self.model.eval()  # uncomment to put in eval mode if desired
        self.model = self.model.cuda()

        state_dict = torch.load(self.model_checkpoint)
        self.model.load_state_dict(state_dict)

    def get_noise(self, batch_size):
        # change the noise dimension as required
        z = torch.cuda.FloatTensor(batch_size, 128).normal_()
        return z

    def __call__(self, y):
        batch_size = y.size(0)
        z = self.get_noise(batch_size)
        samples = self.model(z, y)
        return samples

In [8]:
num_exp = 10
fid_holder = []
fjd_holder = []

for exp in range(num_exp):
    train_loader, test_loader = get_dataloaders()
    inception_embedding = InceptionEmbedding(parallel=False)
    onehot_embedding = OneHotEmbedding(num_classes=10)
    gan = SuspiciouslyGoodGAN()
    gan = GANWrapper(gan)

    fjd_metric = FJDMetric(gan=gan,
                           reference_loader=train_loader,
                           condition_loader=test_loader,
                           image_embedding=inception_embedding,
                           condition_embedding=onehot_embedding,
                           reference_stats_path='datasets/cifar_train_stats.npz',
                           save_reference_stats=True,
                           samples_per_condition=1,
                           cuda=True)

    fid = fjd_metric.get_fid()
    fjd = fjd_metric.get_fjd()
    print('FID: ', fid)
    print('FJD: ', fjd)
    
    fid_holder.append(fid)
    fjd_holder.append(fjd)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Computing generated distribution:   0%|                                                         | 0/79 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:26<00:00,  2.98it/s]
Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:26<00:00,  3.01it/s]


FID:  3.0738064945389283
FJD:  33.965788149474065
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Computing generated distribution:   0%|                                                         | 0/79 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:24<00:00,  3.23it/s]
Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:26<00:00,  3.00it/s]


FID:  3.073805515196341
FJD:  33.50166298711497
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Computing generated distribution:   0%|                                                         | 0/79 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:26<00:00,  3.02it/s]
Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.05it/s]


FID:  3.0738055765946797
FJD:  33.340849819642244
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Computing generated distribution:   0%|                                                         | 0/79 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.13it/s]
Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:27<00:00,  2.86it/s]


FID:  3.073808072888255
FJD:  33.536421016947315
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Computing generated distribution:   0%|                                                         | 0/79 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:24<00:00,  3.25it/s]
Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:26<00:00,  2.95it/s]


FID:  3.0738064424569984
FJD:  33.88407957157597
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Computing generated distribution:   0%|                                                         | 0/79 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:27<00:00,  2.83it/s]
Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:26<00:00,  3.01it/s]


FID:  3.0738070825161685
FJD:  33.55603697452034
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Computing generated distribution:   0%|                                                         | 0/79 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:24<00:00,  3.18it/s]
Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.06it/s]


FID:  3.0738062021222845
FJD:  33.88526554185387
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Computing generated distribution:   0%|                                                         | 0/79 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.12it/s]
Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:25<00:00,  3.06it/s]


FID:  3.0738067278910535
FJD:  34.31001823168526
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Computing generated distribution:   0%|                                                         | 0/79 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:24<00:00,  3.23it/s]
Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:26<00:00,  3.01it/s]


FID:  3.0738061789299422
FJD:  33.85215219951624
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Computing generated distribution:   0%|                                                         | 0/79 [00:00<?, ?it/s]

Loading reference statistics from datasets/cifar_train_stats.npz


Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:24<00:00,  3.23it/s]
Computing generated distribution: 100%|████████████████████████████████████████████████| 79/79 [00:29<00:00,  2.66it/s]


FID:  3.073806576938182
FJD:  33.47814025618459


In [9]:
import statistics
fid_mean = sum(fid_holder) / len(fid_holder)
fid_var = sum((x-fid_mean)**2 for x in fid_holder) / len(fid_holder)
fid_stdev = fid_var**0.5
fid_se = fid_stdev / len(fid_holder)
print('fid mean and se', fid_mean, '+-', fid_se)

import statistics
fid_mean = sum(fjd_holder) / len(fjd_holder)
fid_var = sum((x-fid_mean)**2 for x in fjd_holder) / len(fjd_holder)
fid_stdev = fid_var**0.5
fid_se = fid_stdev / len(fjd_holder)
print('fjd mean and se', fid_mean, '+-', fid_se)

fid mean and se 3.073806487007283 +- 6.990809700766761e-08
fjd mean and se 33.73104147485149 +- 0.028098404907149994
