In [2]:
import torch
import numpy as np
from scipy.linalg import sqrtm
from torchvision import models, transforms
from torch.utils.data import DataLoader
from PIL import Image
import os

In [3]:
def list_all_img_files(directory):
    file_paths = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if "png" in str(file).lower() or "jpg" in str(file).lower():
                file_paths.append(os.path.join(root, file))
    return file_paths

In [4]:
def list_all_img_files(folder_path, real_postfix="_real_B.png", fake_postfix="_fake_B.png"):
    real_image_paths = []
    generated_image_paths = []
    
    for file in os.listdir(folder_path):
        if file.endswith(real_postfix):
            real_image_paths.append(os.path.join(folder_path, file))
        elif file.endswith(fake_postfix):
            generated_image_paths.append(os.path.join(folder_path, file))
    
    return sorted(real_image_paths), sorted(generated_image_paths)

In [5]:
def calculate_fid(real_images, generated_images, batch_size=50, device='cpu'):
    # Load pre-trained Inception v3 model
    inception = models.inception_v3(pretrained=True, transform_input=False).to(device)
    inception.eval()

    # Resize and normalize images to match Inception v3 requirements
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    def get_features(images):
        features = []
        with torch.no_grad():
            for img in images:
                img = transform(img).unsqueeze(0).to(device)
                feat = inception(img)[0].view(img.size(0), -1).cpu().numpy()
                features.append(feat)
        return np.vstack(features)
    
    # Extract features for real and generated images
    real_features = get_features(real_images)
    generated_features = get_features(generated_images)
    
    # Calculate mean and covariance for real and generated features
    mu_real, sigma_real = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu_generated, sigma_generated = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)
    
    # Calculate FID using the formula:
    # FID = ||mu_real - mu_generated||^2 + Tr(sigma_real + sigma_generated - 2 * sqrt(sigma_real * sigma_generated))
    diff = mu_real - mu_generated
    covmean, _ = sqrtm(sigma_real.dot(sigma_generated), disp=False)
    
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma_real + sigma_generated - 2 * covmean)
    return fid

## Map2sat

In [26]:
result_path = "./results/map2sat_pretrained_pix2pix/test_latest/images"

real_image_paths, generated_image_paths = list_all_img_files(result_path, real_postfix="_real_B.png", fake_postfix="_fake_B.png")

real_images = [Image.open(image_path) for image_path in real_image_paths]
generated_images = [Image.open(image_path) for image_path in generated_image_paths]

# Calculate FID score
fid_score = calculate_fid(real_images, generated_images)
print(f'FID score: {fid_score:.4f}')

FID score: 1296.8016


In [27]:
result_path = "./results/map2sat_pretrained_CycleGAN/test_latest/images"

real_image_paths, generated_image_paths = list_all_img_files(result_path, real_postfix="_A_real.png", fake_postfix="_A_fake.png")

real_images = [Image.open(image_path) for image_path in real_image_paths]
generated_images = [Image.open(image_path) for image_path in generated_image_paths]

# Calculate FID score
fid_score = calculate_fid(real_images, generated_images)
print(f'FID score: {fid_score:.4f}')

FID score: 1008.9865


In [28]:
result_path = "./results/map2sat_pureUNET/test_latest/images"

real_image_paths, generated_image_paths = list_all_img_files(result_path, real_postfix="_real_B.png", fake_postfix="_fake_B.png")

real_images = [Image.open(image_path) for image_path in real_image_paths]
generated_images = [Image.open(image_path) for image_path in generated_image_paths]

# Calculate FID score
fid_score = calculate_fid(real_images, generated_images)
print(f'FID score: {fid_score:.4f}')

FID score: 1547.2033


## facades_label2photo

In [29]:
result_path = "./results/facades_label2photo_pretrained_pix2pix/test_latest/images"

real_image_paths, generated_image_paths = list_all_img_files(result_path, real_postfix="_real_B.png", fake_postfix="_fake_B.png")

real_images = [Image.open(image_path) for image_path in real_image_paths]
generated_images = [Image.open(image_path) for image_path in generated_image_paths]

# Calculate FID score
fid_score = calculate_fid(real_images, generated_images)
print(f'FID score: {fid_score:.4f}')

FID score: 532.8256


In [30]:
result_path = "./results/facades_label2photo_pretrained_CycleGAN/test_latest/images"

real_image_paths, generated_image_paths = list_all_img_files(result_path, real_postfix="_real.png", fake_postfix="_fake.png")

real_images = [Image.open(image_path) for image_path in real_image_paths]
generated_images = [Image.open(image_path) for image_path in generated_image_paths]

# Calculate FID score
fid_score = calculate_fid(real_images, generated_images)
print(f'FID score: {fid_score:.4f}')

FID score: 1043.1319


In [31]:
result_path = "./results/facades_label2photo_pureUNET/test_latest/images"

real_image_paths, generated_image_paths = list_all_img_files(result_path, real_postfix="_real_B.png", fake_postfix="_fake_B.png")

real_images = [Image.open(image_path) for image_path in real_image_paths]
generated_images = [Image.open(image_path) for image_path in generated_image_paths]

# Calculate FID score
fid_score = calculate_fid(real_images, generated_images)
print(f'FID score: {fid_score:.4f}')

FID score: 799.9782


## Windtunnel

In [37]:
result_path = "./results/windtunnel_pix2pix/test_latest/images"

real_image_paths, generated_image_paths = list_all_img_files(result_path, real_postfix="_real_B.png", fake_postfix="_fake_B.png")

real_images = [Image.open(image_path) for image_path in real_image_paths]
generated_images = [Image.open(image_path) for image_path in generated_image_paths]

# Calculate FID score
fid_score = calculate_fid(real_images, generated_images)
print(f'FID score: {fid_score:.4f}')

FID score: 612.1227


In [38]:
result_path = "./results/windtunnel_CycleGAN/test_latest/images"

real_image_paths, generated_image_paths = list_all_img_files(result_path, real_postfix="_real_A.png", fake_postfix="_fake_A.png")

real_images = [Image.open(image_path) for image_path in real_image_paths]
generated_images = [Image.open(image_path) for image_path in generated_image_paths]

# Calculate FID score
fid_score = calculate_fid(real_images, generated_images)
print(f'FID score: {fid_score:.4f}')

FID score: 655.1222


In [40]:
result_path = "./results/windtunnel_pureUNET/test_latest/images"

real_image_paths, generated_image_paths = list_all_img_files(result_path, real_postfix="_real_A.png", fake_postfix="_fake_B.png")

real_images = [Image.open(image_path) for image_path in real_image_paths]
generated_images = [Image.open(image_path) for image_path in generated_image_paths]

# Calculate FID score
fid_score = calculate_fid(real_images, generated_images)
print(f'FID score: {fid_score:.4f}')

FID score: 963.1893


## Experimental

In [6]:
# result_path = "./results/windtunnel_pureUNET/test_latest/images"
# real_image_paths, generated_image_paths = list_all_img_files(result_path, real_postfix="_real_B.png", fake_postfix="_fake_B.png")

real_image_paths = ["experimental/windtunnel_20240415_125028_real_B.png"]
generated_image_paths = ["experimental/windtunnel_20240415_125028_real_B_gray.png"]

generated_image_paths = real_image_paths[::2]

real_images = [Image.open(image_path) for image_path in real_image_paths]
generated_images = [Image.open(image_path) for image_path in generated_image_paths]

# Calculate FID score
fid_score = calculate_fid(real_images, generated_images)
print(f'FID score: {fid_score:.4f}')



ValueError: Non-matrix input to matrix function.