In [12]:
import torch
from torchvision import datasets, transforms
import numpy as np
from scipy.linalg import sqrtm

# Helper function to calculate FID score
def calculate_fid_score(images_real, images_fake, batch_size=64):
    mu_real = np.mean(images_real, axis=0)
    sigma_real = np.cov(images_real, rowvar=False)

    mu_fake = np.mean(images_fake, axis=0)
    sigma_fake = np.cov(images_fake, rowvar=False)

    # Calculate the square root of the product of covariances
    cov_sqrt_real_fake = sqrtm(sigma_real @ sigma_fake)

    # Calculate the FID score
    fid = np.linalg.norm(mu_real - mu_fake) + np.trace(sigma_real + sigma_fake - 2 * cov_sqrt_real_fake)
    return fid

# Download and load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
mnist_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=64, shuffle=True)

# Generate a reference dataset (e.g., using a GAN or any other method)
# Here, we'll just use random noise as an example.
fake_images = np.random.rand(len(mnist_dataset), 28 * 28)

# Flatten the real images before computing FID
real_images_flattened = mnist_dataset.data.numpy().reshape(len(mnist_dataset), -1)

# Calculate FID score
fid_score = calculate_fid_score(real_images_flattened, real_images_flattened)
print(f'FID Score: {fid_score:.2f}')


FID Score: -0.00-0.00j
