In [0]:
%%capture
!pip install --upgrade tensorflow

In [22]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [0]:
import os
os.chdir('/content/drive/My Drive/Colab Notebooks')

In [0]:
# example of calculating the frechet inception distance in Keras for cifar10
import numpy
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy import asarray
from numpy.random import shuffle
from scipy.linalg import sqrtm
from numpy.random import randn
from matplotlib import pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input
from tensorflow.keras.datasets.mnist import load_data
from skimage.transform import resize
from tensorflow.keras.backend import clear_session

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
    # generate points in the latent space
    x_input = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    x_input = x_input.reshape(n_samples, latent_dim)
    return x_input

# scale an array of images to a new size
def scale_images(images, new_shape):
    images_list = list()
    for image in images:
        # resize with nearest neighbor interpolation
        new_image = resize(image, new_shape, 0)
        # store
        images_list.append(new_image)
    return asarray(images_list)

# calculate frechet inception distance
def calculate_fid(model, images1, images2):
    # calculate activations
    act1 = model.predict(images1)
    act2 = model.predict(images2)
    # calculate mean and covariance statistics
    mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)
    # calculate sum squared difference between means
    ssdiff = numpy.sum((mu1 - mu2)**2.0)
    # calculate sqrt of product between cov
    covmean = sqrtm(sigma1.dot(sigma2))
    # check and correct imaginary numbers from sqrt
    if iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

# create and save a plot of generated images (reversed grayscale)
def show_plot(examples, n):
    # plot images
    for i in range(n * n):
        # define subplot
        pyplot.subplot(n, n, 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(examples[i, :, :, 0], cmap='gray_r')
    pyplot.show()

# clear all previously created graphs
clear_session()
# number of samples on which the distance will be calculated
n_samples = 1000
# prepare the inception v3 model
model = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))
# load mnist images
images1= load_data()[0][0][:n_samples]
# convert integer to floating point values
images1 = images1.astype('float32')
# resize images
images1 = scale_images(images1, (299,299,3))
# pre-process images
images1 = preprocess_input(images1)
# keep track of fid history
fid_hist = []
# iterate through the models we trained
for i in range(1, 101):
    # generate mnist numbers
    generator = load_model('models/MNIST/GAN_MNIST_model_%03d.h5' % i, compile=False)
    # generate images
    latent_points = generate_latent_points(100, n_samples)
    # generate images
    images2 = generator.predict(latent_points)
    # convert integer to floating point values
    images2 = images2.astype('float32')
    # resize images
    images2 = scale_images(images2, (299,299,3))
    # pre-process images
    images2 = preprocess_input(images2)
    # calculate fid
    fid = calculate_fid(model, images1, images2)
    # print('FID for model_%03d : %.3f' % (i, fid))
    fid_hist.append(fid)
    plt.plot(fid_hist)
    # plt.axis([0, 110, 0, 400])
    plt.ylabel('FID')
    plt.xlabel('Epochs')
    plt.savefig('plots/FID_%d_samples.png' % n_samples)
    plt.close()

In [32]:
print(fid_hist)

[389.45694582912495, 384.8925490520161, 385.2617214262209, 388.1819369414252, 388.1839745369258, 388.18454831763074, 388.18346718873966, 388.1423260139047, 340.06124635453216, 343.6196230649147, 352.931364721209, 358.516517257843, 360.2973640415116, 354.51914360667604, 353.3590365891872, 355.34431419178276, 349.53365770176543, 356.5207139594594, 347.10956018573904, 351.8141131395299, 353.01980428143054, 353.6966207749706, 349.89548367550896, 352.70888759194884, 350.3579765124342, 351.73816375542674, 349.3326634157017, 349.00561619479333, 350.0432122204528, 351.14297143799786, 349.21804996029607, 351.19728844377045, 353.6578360870109, 351.39712841480355, 349.2091314067069, 352.12134050921213, 348.5835871359563, 351.7883172402925, 349.7278358303629, 352.0752339585203, 353.13288451458664, 351.69666291264366, 352.07140798048704, 351.3678258340575, 349.21550157046397, 352.85670003303335, 349.586482359961, 352.8462614155872, 351.33360386989654, 350.9421817362412, 353.3704974327634, 353.04155