In [1]:
pip install pytorch-fid

You should consider upgrading via the '/sciclone/data10/jrhee01/genVision-celebA/gen-env/bin/python -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import subprocess

In [3]:
# Base paths
real_source = "../celebA/celeba/img_align_celeba"
real_10k = "../celebA/celeba/real_10000"

generators = {
    "VAE": "../vae_outputs/generated",
    "GAN": "../gan_outputs/generated",
    "WGAN":"../wgan_outputs/generated"
}


In [4]:
import os
import zipfile

# Base paths
zip_path = "../celebA/celeba/img_align_celeba.zip"
real_10k = "../celebA/celeba/real_10000"

print("Extracting first 10k images from ZIP...")
os.makedirs(real_10k, exist_ok=True)

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    all_files = sorted([f for f in zip_ref.namelist() if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
    first_10k = all_files[:10000]
    for file in first_10k:
        zip_ref.extract(member=file, path=real_10k)

print(f"Extracted {len(first_10k)} images to {real_10k}\n")



Extracting first 10k images from ZIP...
Extracted 10000 images to ../celebA/celeba/real_10000



In [5]:
# Compute FID scores
fid_scores = {}
real_10k = "/sciclone/data10/jrhee01/genVision-celebA/celebA/celeba/real_10000/img_align_celeba"
for name, path in generators.items():
    print(path)
    print(f"Computing FID for {name}...")
    result = subprocess.run(
        ["python", "-m", "pytorch_fid",real_10k, path],
        capture_output=True,
        text=True
    )
    # Extract score from output
    for line in result.stdout.splitlines():
        if "FID:" in line:
            score = float(line.split("FID:")[-1].strip())
            fid_scores[name] = score
            print(f"{name} FID: {score:.2f}")
            break
    else:
        print(f"Could not extract FID for {name}. Raw output:\n{result.stdout}\n")

../vae_outputs/generated
Computing FID for VAE...
VAE FID: 136.62
../gan_outputs/generated
Computing FID for GAN...
GAN FID: 95.14
../wgan_outputs/generated
Computing FID for WGAN...
WGAN FID: 96.21


In [6]:
# Print final comparison
print("\n--- FID Comparison (Lower is Better) ---")
print(fid_scores)
for name, score in sorted(fid_scores.items(), key=lambda x: x[1]):
    print(f"{name}: {score:.2f}")



--- FID Comparison (Lower is Better) ---
{'VAE': 136.6219206039531, 'GAN': 95.14026159529655, 'WGAN': 96.21181184858494}
GAN: 95.14
WGAN: 96.21
VAE: 136.62
