In [19]:
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

In [37]:
def show_img(img, size=3, title=None) -> None:
    """Display a single image in its own square matplotlib figure, axes hidden.

    Args:
        img: Anything ``imshow`` accepts — e.g. a PIL image, an (H, W) or
            (H, W, C) NumPy array, or a CPU tensor in that layout
            (``show_tensor`` passes a permuted tensor here).
        size: Width and height of the figure, in inches.
        title: Optional figure title; no title is drawn when ``None``.
    """
    _, ax = plt.subplots(figsize=(size, size))
    ax.imshow(img)
    ax.axis('off')
    # Identity check: `!=` would misbehave for array-like titles and is not
    # the idiomatic way to test for None.
    if title is not None:
        ax.set_title(title)
    plt.show()

def show_tensor(ts: torch.Tensor, size = 3, title=None) -> None:
    """Display an image tensor, or a batch of them, via ``show_img``.

    Args:
        ts: (B, C, H, W) or (C, H, W) tensor. A batch is tiled into a single
            column (``nrow=1``) with no padding. Values are handed to
            matplotlib unchanged, so float data is presumably expected in
            [0, 1] — TODO confirm for callers.
        size: Figure width/height in inches, forwarded to ``show_img``.
        title: Optional figure title, forwarded to ``show_img``.

    Raises:
        ValueError: if ``ts`` is neither 3-D nor 4-D.
    """
    # matplotlib converts through NumPy, which fails on CUDA tensors and on
    # tensors that require grad; detach and move to CPU first.
    ts = ts.detach().cpu()
    if ts.dim() == 4:
        img = make_grid(ts, nrow=1, padding=0).permute(1, 2, 0)
    elif ts.dim() == 3:
        if ts.shape[0] == 1:
            # Replicate a single channel to RGB so it renders as grayscale,
            # matching make_grid's handling of single-channel batches above.
            ts = ts.repeat(3, 1, 1)
        img = ts.permute(1, 2, 0)
    else:
        raise ValueError(f"ts should be (B, C, H, W) or (C, H, W), but got {ts.shape}")
    show_img(img, size, title)

def tensor_to_image(ts: torch.Tensor) -> "Image.Image":
    """Convert an image tensor to an 8-bit PIL image.

    Args:
        ts: (C, H, W) image tensor, optionally with leading singleton batch
            dimensions such as (1, C, H, W). Float tensors are treated as
            values in [0, 1] (the range ``ToTensor`` produces) and clamped to
            it; ``uint8`` tensors are used as-is. C is presumably 1, 3 or 4 —
            whatever ``Image.fromarray`` accepts.

    Returns:
        A PIL image built from an (H, W) array for single-channel input, or
        an (H, W, C) array otherwise.

    Raises:
        ValueError: if ``ts`` is not 3-D after dropping leading singleton dims.
    """
    img_t = ts.detach().cpu()
    # Drop only *leading* singleton (batch) dims. A blanket squeeze() would
    # also remove a size-1 channel, height or width and break the permute.
    while img_t.dim() > 3 and img_t.shape[0] == 1:
        img_t = img_t[0]
    if img_t.dim() != 3:
        raise ValueError(
            f"expected a (C, H, W) or (1, C, H, W) tensor, got {tuple(ts.shape)}"
        )
    if img_t.dtype != torch.uint8:
        # Clamp so out-of-range values saturate instead of wrapping on the
        # uint8 cast. Round rather than truncate so a ToTensor value k / 255
        # maps back to k (in float32, k / 255 * 255 can land just below k).
        img_t = (img_t.float().clamp(0.0, 1.0) * 255).round().to(torch.uint8)
    img_np = img_t.permute(1, 2, 0).numpy()
    if img_np.shape[2] == 1:
        img_np = img_np[:, :, 0]  # single channel: PIL expects (H, W), not (H, W, 1)
    return Image.fromarray(img_np)

In [6]:
def load_images(image_folder, batch_size, transforms, shuffle=True, num_workers=0):
    """Create a DataLoader over the images under ``image_folder``.

    Args:
        image_folder: Root directory passed to torchvision's ``ImageFolder``.
            ImageFolder looks for images inside class sub-directories
            (``root/<class>/<image>``) and yields ``(image, class_index)``.
        batch_size: Number of images per batch.
        transforms: Transform applied to each loaded image.
        shuffle: Reshuffle the order every epoch. Defaults to True, which was
            the original hard-coded behavior. Order-independent statistics
            such as FID do not need shuffling; ``False`` gives a deterministic
            order without relying on the global torch seed.
        num_workers: Worker processes used for loading. 0, the DataLoader
            default, loads in the main process.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    dataset = ImageFolder(root=image_folder, transform=transforms)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

def calculate_fid(real_images_loader, generated_images_loader, fid):
    """Feed every real and generated batch into ``fid`` and return its score.

    Args:
        real_images_loader: Iterable of ``(images, labels)`` batches of
            reference images; labels are ignored.
        generated_images_loader: Same, for generated images.
        fid: Metric exposing ``update(images, real=...)`` and ``compute()``,
            e.g. torchmetrics' ``FrechetInceptionDistance``. If it has a
            ``device`` attribute, batches are moved there before ``update``.

    Returns:
        Whatever ``fid.compute()`` returns (a scalar tensor for
        ``FrechetInceptionDistance``).

    Note:
        This function never calls ``fid.reset()``. torchmetrics metrics
        accumulate state across ``update`` calls, so pass a freshly built (or
        reset) metric for each evaluation.
    """
    # A metric moved to GPU with `fid.to(...)` would otherwise receive CPU
    # batches and fail. Duck-typed metrics without `device` are left alone.
    device = getattr(fid, "device", None)

    def _feed(loader, real):
        # Push every image batch from `loader` into the metric as real/generated.
        for images, _ in loader:
            if device is not None:
                images = images.to(device)
            fid.update(images, real=real)

    _feed(real_images_loader, real=True)
    _feed(generated_images_loader, real=False)
    return fid.compute()

In [116]:
image_size = 299  # InceptionV3 input size
batch_size = 32
# Resize every image to a fixed size so differently sized images can be stacked
# into batches. NOTE(review): a (h, w) tuple ignores aspect ratio, so
# non-square images are stretched.
# No Normalize(...) step on purpose: the FID metric is constructed with
# normalize=True, which (per the torchmetrics docs) expects float images in
# [0, 1] — exactly what ToTensor() yields. ImageNet mean/std normalization here
# would feed the metric out-of-range inputs.
transforms = Compose([Resize((image_size, image_size)), ToTensor()])

real_path = "../output/images/real"
generated_path = "../output/images/generated"
real_images_loader = load_images(real_path, batch_size, transforms)
generated_images_loader = load_images(generated_path, batch_size, transforms)

In [117]:
# Seed the global torch RNG; the shuffled DataLoaders presumably draw their
# order from it, which keeps the accumulation order (and hence the printed
# float) reproducible across runs.
torch.manual_seed(0)

# NOTE(review): feature=64 selects a low-dimensional Inception feature layer.
# The standard FID reported in papers uses the 2048-d pool features, so this
# score is only comparable to other feature=64 runs.
fid = FrechetInceptionDistance(feature=64, normalize=True)
# Keep the metric's states in float64, presumably for a more stable
# covariance / matrix-square-root computation.
fid = fid.set_dtype(torch.float64)

fid_score = calculate_fid(
    real_images_loader,
    generated_images_loader,
    fid,
)
print(f"FID Score: {fid_score}")

FID Score: 4.100490124266012
