# Model evaluation using Frechet Inception Distance and pixelwise similarity

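This notebook evaluates the results of our different GANs using the Frechet Inception Distance (FID) and a simple pixelwise similarity measure. FID compares the activations that real images produce in a pretrained network (here, Inception v3) with the activations produced by generated images. A lower score means the generated images are statistically closer to the real ones, which is read as higher quality of the generated data.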

This notebook draws on code from `https://machinelearningmastery.com/how-to-implement-the-frechet-inception-distance-fid-from-scratch/`

## Import libraries

In [1]:
import tensorflow as tf
import glob
import numpy as np

## Define directory paths for model results

In [32]:
def load(image_file, real=False):
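    '''
    Load a JPEG file and return it as a float32 tensor scaled to [-1, 1]
    and resized to 299x299.

    If real is True, the file holds an input/target pair side by side and
    only the left half of the image is kept.
    '''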
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)
    image = tf.cast(image, tf.float32)

    # our test targets are stored as input/target pairs
    if real:
        w = image.shape[1] // 2
        image = image[:,:w,:]

    # normalize the pixel values
    image = (image / 127.5) - 1 

    # resize to 299,299,3 for inception net
    image = tf.image.resize(image, [299, 299])

    return image
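
The cropping and scaling steps in `load` can be checked without TensorFlow. The following is a minimal sketch using NumPy only; the helper name `crop_and_scale` and the tiny 4x8 test array are hypothetical stand-ins for a real paired JPEG, and the resize step is left out.

```python
import numpy as np

def crop_and_scale(image, real=False):
    """Hypothetical NumPy stand-in for the crop and scaling steps of load()."""
    image = image.astype(np.float32)
    if real:
        # paired files hold two images side by side; keep the left half
        w = image.shape[1] // 2
        image = image[:, :w, :]
    # map pixel values from [0, 255] to [-1, 1]
    return image / 127.5 - 1

# hypothetical 4x8 RGB "paired" image with values in [0, 95]
paired = (np.arange(4 * 8 * 3) % 256).reshape(4, 8, 3)
scaled = crop_and_scale(paired, real=True)

# the endpoints and midpoint of the pixel range land on -1, 0 and 1
endpoints = np.array([0.0, 127.5, 255.0]) / 127.5 - 1
```

With `real=True` the width halves from 8 to 4, and every value stays inside [-1, 1], which matches the input range Inception v3 is usually fed.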

In [3]:
def load_dataset(path, real=False):
    out = []
    for img in glob.glob(path):
        out.append(load(img, real))
    return np.asarray(out)

In [23]:
real_images = load_dataset('../data/sketch_data/test/*.jpg',real=True)

In [38]:
results = {}
results['conditional_gan_256'] = load_dataset('./predictions/cgan*.jpg')
results['cyclegan_256'] = load_dataset('./predictions/cyclegan*.jpg')
results['resnet_scaled'] = load_dataset('./predictions/resnet128*.jpg')

In [25]:
real_images.shape

(500, 299, 299, 3)

In [39]:
len(results['resnet_scaled'])

366

## From the paper "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium" the equation for FID score is:

$$ d^2 = \lVert \mu_1 - \mu_2 \rVert^2 + \mathrm{Tr}\left(C_1 + C_2 - 2\sqrt{C_1 C_2}\right) $$
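
In this formula, mu_1 and mu_2 are the mean activation vectors of the two image sets, and C_1 and C_2 are the corresponding covariance matrices. To get a feel for the two terms, here is a small sketch on hand-picked 2-D Gaussians rather than Inception activations. It is not the notebook's implementation: instead of `scipy.linalg.sqrtm` it assumes that the trace of the matrix square root equals the sum of the square roots of the eigenvalues of C_1 C_2, which holds for the positive semi-definite covariances used here. The function name `frechet_distance` is hypothetical.

```python
import numpy as np

def frechet_distance(mu1, cov1, mu2, cov2):
    """Squared Frechet distance between N(mu1, cov1) and N(mu2, cov2)."""
    # first term: squared Euclidean distance between the means
    mean_term = np.sum((mu1 - mu2) ** 2)
    # Tr(sqrt(C1 C2)) computed as the sum of square roots of the eigenvalues
    eigvals = np.linalg.eigvals(cov1 @ cov2)
    # eigenvalues are real and non-negative in theory; clip round-off noise
    trace_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0, None)))
    return mean_term + np.trace(cov1) + np.trace(cov2) - 2 * trace_sqrt

mu_a, cov_a = np.zeros(2), np.eye(2)

# identical distributions: both terms vanish
d_same = frechet_distance(mu_a, cov_a, mu_a, cov_a)

# same covariance, mean shifted by (3, 4): only the mean term contributes, 9 + 16
d_shift = frechet_distance(mu_a, cov_a, np.array([3.0, 4.0]), cov_a)

# same mean, covariance scaled by 4: Tr(I) + Tr(4I) - 2 * Tr(2I) = 2 + 8 - 8
d_scale = frechet_distance(mu_a, cov_a, mu_a, 4 * np.eye(2))
```

The three cases give 0, 25 and 2 respectively, so the distance grows both when the means move apart and when the spread of the two distributions differs.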

In [49]:
# load the pretrained Inception v3 network used for the FID score
from keras.applications.inception_v3 import InceptionV3
inception = InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3))

from scipy.linalg import sqrtm

def frechet_inception_distance(dist1, dist2):
    '''
    calculates the FID between two sets of images,
    each an array of shape (n, 299, 299, 3)
    '''
    # get the average-pooled (2048-dimensional) activations from the inception net
    activations1 = inception.predict(dist1)
    activations2 = inception.predict(dist2)
    
    # get the sum of squared differences between the mean activations
    activation_diff = (activations1.mean(axis = 0) - activations2.mean(axis = 0)) ** 2
    activation_diff = np.sum(activation_diff)
    
    # get covariances of activation layers
    covariance1 = np.cov(activations1,rowvar=False)
    covariance2 = np.cov(activations2,rowvar=False)
    
    covariance_term = covariance1 + covariance2 - 2*sqrtm(covariance1.dot(covariance2))
    # get sum along diagonals
    covariance_term = np.trace(covariance_term)
    
    return activation_diff + covariance_term

In [None]:
def pixel_similarity(dist1,dist2):
    '''
    computes the mean absolute difference between the average images of
    two sets; 0 means identical average images, so lower is more similar
    '''
    return abs(dist1.mean(axis=0)-dist2.mean(axis=0)).mean()
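
Because `pixel_similarity` first averages each set into a single mean image, it compares the sets as a whole rather than image by image. The sketch below illustrates this on tiny made-up arrays (`set_a`, `set_b` and `set_c` are hypothetical); the function is restated with `np.abs` only so the example runs on its own.

```python
import numpy as np

def pixel_similarity(dist1, dist2):
    # same computation as the notebook function, restated for a standalone run
    return np.abs(dist1.mean(axis=0) - dist2.mean(axis=0)).mean()

# hypothetical sets of two 2x2 single-channel images, values in [-1, 1]
set_a = np.stack([np.full((2, 2, 1), -1.0), np.full((2, 2, 1), 1.0)])
set_b = np.zeros((2, 2, 2, 1))
set_c = np.ones((2, 2, 2, 1))

# set_a and set_b contain different images but share the same all-zero mean image
same_mean = pixel_similarity(set_a, set_b)

# the mean image of set_c is all ones, so every pixel differs by 1
shifted = pixel_similarity(set_a, set_c)
```

Here `same_mean` is 0.0 and `shifted` is 1.0. With pixel values scaled to [-1, 1] the measure ranges from 0 to 2, and a value of 0 does not imply that the individual images match.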

## We will calculate the FID score and pixel similarity between each set of generated images and the 500 original test images

In [44]:
metrics = {}

In [50]:
for name, generated_images in results.items():
    scores = {}
    scores['pixel_similarity'] = pixel_similarity(real_images,generated_images)
    scores['FID'] = frechet_inception_distance(real_images, generated_images)
    metrics[name] = scores
    print("Pixel similarity for",name,':',scores['pixel_similarity'])
    print("FID score for",name,":",scores['FID'])
    print("************")

Pixel similarity for conditional_gan_256 : 0.90040755
FID score for conditional_gan_256 : (142.43949031085748-1.7850260437925806e-06j)
************
Pixel similarity for cyclegan_256 : 0.9593999
FID score for cyclegan_256 : (143.1787594240364-1.1949177051162446e-06j)
************
Pixel similarity for resnet_scaled : 0.90431577
FID score for resnet_scaled : (153.3962406367699-1.4728668388100988e-06j)
************
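
The FID values printed above carry an imaginary component on the order of 1e-6. This is most likely numerical round-off from the matrix square root rather than a meaningful part of the score, so one common approach is to keep only the real part. The following sketch shows one cautious way to do that; the helper `real_fid` and its tolerance `atol` are assumptions for illustration and not part of the original notebook.

```python
def real_fid(score, atol=1e-3):
    """Return the real part of an FID score whose imaginary part is negligible.

    Hypothetical helper: atol is an assumed threshold, chosen to be well above
    the ~1e-6 imaginary parts seen in the printed results.
    """
    score = complex(score)
    if abs(score.imag) > atol:
        # a large imaginary part suggests a real numerical problem; do not hide it
        raise ValueError(f"imaginary part {score.imag} exceeds tolerance {atol}")
    return score.real

# value copied from the conditional_gan_256 output above
fid_cgan = real_fid(142.43949031085748 - 1.7850260437925806e-06j)
```

Applied to the printed results, this leaves the reported real parts unchanged (for example 142.43949031085748 for `conditional_gan_256`) while raising an error if the imaginary component were ever large enough to matter.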
