In [16]:
import torch
import torchvision.transforms as transforms
from torchvision.models import inception_v3
import numpy as np
from PIL import Image
import os
from scipy.linalg import sqrtm

In [17]:
def get_features(directory, model, device, transform):
    features = []
    
    for image_name in os.listdir(directory):
        image_path = os.path.join(directory, image_name)
        if os.path.isfile(image_path): 
            try:
                image = Image.open(image_path).convert("RGB")
                image = transform(image).unsqueeze(0).to(device)
                with torch.no_grad():
                    pred = model(image)
                    # Make sure to extract the feature vector before any final classification layers
                    if pred.shape[1] != 2048:  # Check if the model outputs the expected 2048-d feature vector
                        pred = pred[:, :2048]
                    features.append(pred.cpu().numpy().flatten())  # Flatten to avoid shape issues
            except Exception as e:
                print(f"Failed to process {image_name}: {e}")
    return np.array(features)

def calculate_fid(real_features, fake_features):
    if real_features.size == 0 or fake_features.size == 0:
        raise ValueError("One of the feature arrays is empty. Check your data loading and feature extraction steps.")
    mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)
    
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    # covmean = np.linalg.sqrtm(sigma1.dot(sigma2))
    
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

In [None]:
def main():
    # Paths to your real and fake images
    real_images_path = '/home/idu675/projects/Thesis/Dreambooth/temp_instance'
    fake_images_path = '/home/idu675/projects/Thesis/Dreambooth/outputs_filtered'
    
    # Image transformations
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Load Inception model
    model = inception_v3(pretrained=True, transform_input=False)
    model.fc = torch.nn.Identity()  # Modify the model to return features directly
    model.to(device)
    model.eval()

    # Get features
    real_features = get_features(real_images_path, model, device, transform)
    fake_features = get_features(fake_images_path, model, device, transform)
    
    # Calculate FID
    # Calculate FID
    try:
        fid_value = calculate_fid(real_features, fake_features)
        print(f'FID score: {fid_value}')
    except ValueError as e:
        print(e)

if __name__ == "__main__":
    main()

Using device: cuda
FID score: 312.1872743845984


In [None]:
# outputs2_filtered: FID score: 309.0241913849678
# outputs_filtered: 312.1872743845984
