In [1]:
import numpy
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy import asarray
from numpy.random import shuffle
from scipy.linalg import sqrtm
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from keras.datasets.mnist import load_data
from skimage.transform import resize
from keras.datasets import cifar10

In [2]:
# scale an array of images to a new size
def scale_images(images, new_shape):
    images_list = list()
    for image in images:
    # resize with nearest neighbor interpolation
        new_image = resize(image, new_shape, 0)
    # store
        images_list.append(new_image)
    return asarray(images_list)

In [3]:
# calculate frechet inception distance
def calculate_fid(model, images1, images2):
 # calculate activations
    act1 = model.predict(images1)
    act2 = model.predict(images2)
    # calculate mean and covariance statistics
    mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)
    # calculate sum squared difference between means
    ssdiff = numpy.sum((mu1 - mu2)**2.0)
    # calculate sqrt of product between cov
    covmean = sqrtm(sigma1.dot(sigma2))
    # check and correct imaginary numbers from sqrt
    if iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

In [4]:
# prepare the inception v3 model
model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))

In [5]:
# load cifar10 images
(images1, _), (images2, _) = cifar10.load_data()
# shuffle(images1)
images1 = images1[:100]
images2 = images2[:100]
print('Loaded', images1.shape, images2.shape)


Loaded (100, 32, 32, 3) (100, 32, 32, 3)


In [6]:
# convert integer to floating point values
images1 = images1.astype('float32')
images2 = images2.astype('float32')
# resize images
images1 = scale_images(images1, (299,299,3))
images2 = scale_images(images2, (299,299,3))
print('Scaled', images1.shape, images2.shape)
# pre-process images
images1 = preprocess_input(images1)
images2 = preprocess_input(images2)
# calculate fid
fid = calculate_fid(model, images1, images2)
print('FID: %.3f' % fid)

Scaled (100, 299, 299, 3) (100, 299, 299, 3)
FID: -50195695344141386936283010094984474693163556275381340093755477783822163555845052925539609043009536.000


In [7]:
# load cifar10 images
(images3, _), (images4, _) = cifar10.load_data()
shuffle(images3)
images3 = images1[:100]
images4 = images2[:100]
print('Loaded', images3.shape, images4.shape)

Loaded (100, 299, 299, 3) (100, 299, 299, 3)


In [8]:
images3 = images3.astype('float32')
images4 = images4.astype('float32')
# resize images
images3 = scale_images(images3, (299,299,3))
images3 = scale_images(images4, (299,299,3))
print('Scaled', images3.shape, images3.shape)
# pre-process images
images3 = preprocess_input(images3)
images4 = preprocess_input(images4)
# calculate fid
fid2 = calculate_fid(model, images3, images4)
print('FID: %.3f' % fid2)

Scaled (100, 299, 299, 3) (100, 299, 299, 3)
FID: -3911109074562213285635415721746006069218819315838597601221389662760412095545606475872122961920.000


In [9]:
print(fid2-fid)

5.0191784235066825e+97
