In [2]:
import os
from PIL import Image
import torch
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance

def load_and_transform_image(image_path, transform):
    """Open an image file, convert it to RGB and apply ``transform``.

    Args:
        image_path: Path to the image file.
        transform: Callable applied to the RGB ``PIL.Image`` (e.g. a
            ``torchvision.transforms.Compose`` pipeline).

    Returns:
        Whatever ``transform`` returns for the RGB image.
    """
    # The context manager closes the underlying file handle deterministically.
    # convert('RGB') loads the pixel data and returns a new image, so the
    # result stays usable after the source file is closed.
    with Image.open(image_path) as img:
        rgb_img = img.convert('RGB')
    return transform(rgb_img)

def load_first_n_images_from_folder(folder, transform, num_images=1000, extensions=None):
    """Load the first ``num_images`` files of ``folder`` (sorted by path) as one batch.

    Args:
        folder: Directory to read from. Sub-directories are skipped.
        transform: Callable applied to each RGB image. It must return tensors
            of identical shape so they can be stacked.
        num_images: Maximum number of files to load. If the folder holds fewer
            matching files, all of them are loaded (no error). ``None`` loads
            every file.
        extensions: Optional iterable of filename suffixes to keep, e.g.
            ``('.png', '.jpg')``, compared case-insensitively. ``None`` (the
            default) keeps every regular file. In that case any non-image file
            in the folder (e.g. ``Thumbs.db``) makes ``Image.open`` fail.

    Returns:
        A tensor stacking the transformed images along a new leading dimension.

    Raises:
        ValueError: If no matching files are selected (empty folder, everything
            filtered out, or ``num_images == 0``).
    """
    filenames = sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    )
    if extensions is not None:
        allowed_suffixes = tuple(ext.lower() for ext in extensions)
        filenames = [path for path in filenames if path.lower().endswith(allowed_suffixes)]

    selected = filenames[:num_images]
    if not selected:
        # torch.stack would otherwise fail on the empty list with an opaque error.
        raise ValueError(f'No image files selected from folder {folder!r}')

    images = [load_and_transform_image(path, transform) for path in selected]
    return torch.stack(images)

# Preprocessing for torchmetrics' FrechetInceptionDistance, which (with its
# default normalize=False) expects uint8 image tensors with values in [0, 255].
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),  # float tensor (C, H, W) scaled to [0.0, 1.0]
    # Round before casting. ToTensor divides by 255 in floating point, so
    # x * 255 can come out a hair below the original integer, and .byte()
    # (which truncates) would then shift that pixel down by one level.
    transforms.Lambda(lambda x: (x * 255).round().byte()),
])

In [8]:
# Load the real and generated ("fake") image sets. Folder paths are relative to
# the notebook's working directory.
real_images_folder = r'real_images'
fake_images_folder = r'fake_images'
NUM_IMAGES = 800  # first N files (sorted by name) taken from each folder

real_images = load_first_n_images_from_folder(real_images_folder, transform, num_images=NUM_IMAGES)
fake_images = load_first_n_images_from_folder(fake_images_folder, transform, num_images=NUM_IMAGES)

# The loader silently returns fewer than NUM_IMAGES when a folder is short,
# so display the actual batch shapes.
real_images.shape, fake_images.shape

In [9]:
# FID on 64-dimensional Inception features. 2048 is torchmetrics' default and
# the feature size used by standard FID; smaller sizes come from earlier
# Inception layers, so this score is not comparable with published FID values.
fid = FrechetInceptionDistance(feature=64)
# update() accumulates statistics, so feeding mini-batches gives the same
# result as one big call (up to float rounding). It also keeps peak memory
# bounded instead of running Inception on every image at once.
for real_batch in real_images.split(100):
    fid.update(real_batch, real=True)
for fake_batch in fake_images.split(100):
    fid.update(fake_batch, real=False)
fid_score = fid.compute()
print(f'FID score: {fid_score}')

FID score: 0.050199978053569794


In [10]:
# FID on 192-dimensional Inception features. These come from an earlier layer
# than the standard 2048-dim FID features, so the score is not comparable with
# published FID values.
fid = FrechetInceptionDistance(feature=192)
# Mini-batched updates: same accumulated statistics as one big call (up to
# float rounding), with bounded peak memory.
for real_batch in real_images.split(100):
    fid.update(real_batch, real=True)
for fake_batch in fake_images.split(100):
    fid.update(fake_batch, real=False)
fid_score = fid.compute()
print(f'FID score: {fid_score}')

FID score: 0.2659568190574646


In [11]:
# FID on 768-dimensional Inception features. These come from an earlier layer
# than the standard 2048-dim FID features, so the score is not comparable with
# published FID values.
fid = FrechetInceptionDistance(feature=768)
# Mini-batched updates: same accumulated statistics as one big call (up to
# float rounding), with bounded peak memory.
for real_batch in real_images.split(100):
    fid.update(real_batch, real=True)
for fake_batch in fake_images.split(100):
    fid.update(fake_batch, real=False)
fid_score = fid.compute()
print(f'FID score: {fid_score}')

FID score: 0.15625116229057312


In [12]:
# Standard FID (2048-dimensional Inception features, torchmetrics' default).
# NOTE(review): with 800 images per set, fewer samples than feature
# dimensions, the covariance estimates are rank-deficient. Treat this score
# as biased/unstable; using at least 2048 images per set is advisable.
fid = FrechetInceptionDistance(feature=2048)
# Mini-batched updates: same accumulated statistics as one big call (up to
# float rounding), with bounded peak memory.
for real_batch in real_images.split(100):
    fid.update(real_batch, real=True)
for fake_batch in fake_images.split(100):
    fid.update(fake_batch, real=False)
fid_score = fid.compute()
print(f'FID score: {fid_score}')

FID score: 66.962646484375
