In [27]:
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
import math
from tqdm import tqdm
import numpy as np
import scipy

In [52]:
batch_size = 30

real_images = image_dataset_from_directory(
    directory="/weka/proj-fmri/shared/coco/sampled_imgs",
    label_mode=None,  # since you're not doing classification
    shuffle=False,
    batch_size=batch_size)

generated_images = image_dataset_from_directory(
    directory="/weka/proj-fmri/shared/coco/sampled_imgs",
    label_mode=None,
    shuffle=False,
    batch_size=batch_size)

Found 30000 files belonging to 1 classes.
Found 30000 files belonging to 1 classes.


In [51]:
tf.keras.backend.set_floatx('float64')
inception_model = tf.keras.applications.InceptionV3(include_top=False, 
                              weights="imagenet", 
                              pooling='avg')

In [55]:
def compute_embeddings(dataloader, count):
    image_embeddings = []
    for _ in range(count):
        images = next(iter(dataloader))
        embeddings = inception_model.predict(images)
        image_embeddings.extend(embeddings)
    return np.array(image_embeddings)
    
count = math.ceil(10000/batch_size)



real_image_embeddings = compute_embeddings(real_images, count)


# compute embeddings for generated images
generated_image_embeddings = compute_embeddings(generated_images, count)
real_image_embeddings = tf.cast(real_image_embeddings, tf.float64)
generated_image_embeddings = tf.cast(generated_image_embeddings, tf.float64)

real_image_embeddings.shape, generated_image_embeddings.shape




(TensorShape([10020, 2048]), TensorShape([10020, 2048]))

In [65]:
def calculate_fid(real_embeddings, generated_embeddings):

    # Calculate mean and covariance statistics using TensorFlow operations
    mu1, mu2 = tf.reduce_mean(real_embeddings, axis=0), tf.reduce_mean(generated_embeddings, axis=0)

    # Convert TensorFlow tensors to NumPy arrays for covariance calculation
    real_embeddings_np = real_embeddings.numpy()
    generated_embeddings_np = generated_embeddings.numpy()
    sigma1, sigma2 = np.cov(real_embeddings_np, rowvar=False), np.cov(generated_embeddings_np, rowvar=False)

    # Calculate sum of squared differences between means
    ssdiff = np.sum((mu1 - mu2)**2.0)

    # Calculate sqrt of product between covariances
    covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)

    # Check and correct imaginary numbers from sqrt
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Ensure positive semi-definite property for numerical stability
    covmean = (covmean + covmean.T) / 2

    # Calculate FID
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid


In [66]:
fid = calculate_fid(real_image_embeddings, generated_image_embeddings)

In [67]:
fid

-0.09471308621463312