In [1]:
import os
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch_fidelity
from torchmetrics.image.fid import FrechetInceptionDistance

In [2]:
def save_cifar10_test_images(
    root_dir="cifar10_test_images",
    resize_to=None,
    data_root="./data",
):
    """
    Download the CIFAR-10 test split and save every image as a separate PNG.

    Images are written directly from the PIL objects the dataset yields, so
    there is no lossy float round-trip (ToTensor -> to_pil_image). With
    `resize_to=None` the saved pixels are exactly the dataset's pixels.

    Files are named "img_{index:05d}_label_{label}.png" inside `root_dir`.

    Args:
        root_dir: output folder for the PNGs; created if it does not exist.
        resize_to: optional side length; if given, each image is resized to
            (resize_to, resize_to) before saving.
        data_root: folder where torchvision downloads / caches the raw
            CIFAR-10 archive (kept separate from the PNG output folder).
    """
    os.makedirs(root_dir, exist_ok=True)

    # Keep samples as PIL images; only insert a Resize when one is requested.
    transform = (
        transforms.Resize((resize_to, resize_to)) if resize_to is not None else None
    )

    cifar10_test = torchvision.datasets.CIFAR10(
        root=data_root,
        train=False,
        download=True,
        transform=transform,
    )
    print("CIFAR-10 test set downloaded.")

    for idx, (img_pil, label) in enumerate(cifar10_test):
        filename = f"img_{idx:05d}_label_{label}.png"
        img_pil.save(os.path.join(root_dir, filename))

        # Report the number of images actually written so far (idx is 0-based).
        saved = idx + 1
        if saved % 1000 == 0:
            print(f"Saved {saved} images...")

    print("All CIFAR-10 test images saved!")


In [3]:
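# One-off export of the CIFAR-10 test split to PNGs in `cifar10_test_images/`.
# That folder is later scored against the test split as an FID sanity check.
# Uncomment and run once if the folder does not exist yet.
# save_cifar10_test_images(root_dir="cifar10_test_images", resize_to=32)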

In [4]:
class GeneratedDataset(Dataset):
    """
    A simple Dataset that loads images from a flat folder (e.g. generated samples).

    Files are matched by extension (case-insensitive) and kept in sorted
    filename order, so iteration order is deterministic. Subdirectories are
    ignored even if their names end in an image extension.

    Args:
        folder: directory containing the image files.
        transform: optional callable applied to each RGB PIL image.
        extensions: filename suffix or tuple of suffixes to include
            (default: .png, .jpg, .jpeg).

    Each item is the RGB PIL image, or whatever `transform` returns for it.
    """
    def __init__(self, folder, transform=None, extensions=(".png", ".jpg", ".jpeg")):
        super().__init__()
        self.folder = folder
        self.transform = transform
        # Accept a single suffix string as well as a tuple/list of suffixes.
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        self.img_files = sorted(
            name
            for name in os.listdir(folder)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(folder, name))
        )

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.folder, self.img_files[idx])
        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives it.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

In [5]:
def _to_uint8(images):
    """Map float image tensors in [0, 1] (as produced by ToTensor) to uint8 in [0, 255]."""
    # Round before casting: `.to(torch.uint8)` truncates, so float error from
    # ToTensor's /255 followed by *255 (e.g. 254.99998) would otherwise drop a level.
    return images.mul(255).round().clamp(0, 255).to(torch.uint8)


def compute_fid_for_cifar10_test(gen_folder, batch_size, data_root="./data", image_size=32):
    """
    Compute the FID between the CIFAR-10 test split and a folder of images.

    Both sets are resized to (image_size, image_size) and converted to uint8
    tensors. They are then fed to torchmetrics' FrechetInceptionDistance using
    2048-d Inception features. The score is printed and also returned.

    Args:
        gen_folder: folder of generated images, loaded via GeneratedDataset.
        batch_size: number of images per update for both the real and generated sets.
        data_root: folder where torchvision downloads / caches the raw CIFAR-10
            archive. Keep it separate from any folder of PNGs being scored.
        image_size: side length both sets are resized to (CIFAR-10 is 32x32).

    Returns:
        The FID score as a Python float.

    Raises:
        ValueError: if `gen_folder` contains no matching image files.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])

    real_dataset = torchvision.datasets.CIFAR10(
        root=data_root,
        train=False,
        download=True,
        transform=transform,
    )
    real_loader = DataLoader(real_dataset, batch_size=batch_size, shuffle=False)

    gen_dataset = GeneratedDataset(folder=gen_folder, transform=transform)
    if len(gen_dataset) == 0:
        raise ValueError(f"No images found in generated folder {gen_folder!r}")
    gen_loader = DataLoader(gen_dataset, batch_size=batch_size, shuffle=False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        print("running on cuda")
    fid = FrechetInceptionDistance(feature=2048).to(device)

    for real_imgs, _ in real_loader:
        fid.update(_to_uint8(real_imgs).to(device), real=True)

    for fake_imgs in gen_loader:
        fid.update(_to_uint8(fake_imgs).to(device), real=False)

    score = fid.compute().item()
    print(f"FID score: {score:.4f}")
    return score



In [6]:
# Sanity check: score the exported CIFAR-10 test PNGs against the test split itself.
# Matching Inception feature statistics should give FID ~ 0 (a value such as
# -0.0000 is floating-point noise). FID = 0 means the feature mean/covariance
# match, not that two image sets must be pixel-identical.
gen_folder_path = "cifar10_test_images"
compute_fid_for_cifar10_test(gen_folder=gen_folder_path, batch_size=2048)

Files already downloaded and verified
running on cuda
FID score: -0.0000


In [7]:
# Images generated at epochs 0-10 are still almost pure noise, so this should
# score far worse than the sanity check (~458 in the recorded run).
gen_folder_path = "noise"
compute_fid_for_cifar10_test(gen_folder=gen_folder_path, batch_size=2048)

Files already downloaded and verified
running on cuda
FID score: 458.4973


In [8]:
# Score the samples in `cifar10_samples/`. A value around 61 (recorded run) is
# middling: not good, not bad. FID depends on the number of generated images,
# so compare only against runs that use the same sample count.
gen_folder_path = "cifar10_samples"
compute_fid_for_cifar10_test(gen_folder=gen_folder_path, batch_size=2048)

Files already downloaded and verified
running on cuda
FID score: 61.7527
