In [1]:
from PIL import Image
import numpy as np
from torchvision import models, transforms
import torch
from scipy.linalg import sqrtm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Paths to folders containing images:
#   real_images_folder -> ground-truth images
#   fake_images_folder -> generated outputs compared against them
# Each can be overridden with the FID_REAL_DIR / FID_FAKE_DIR environment
# variables so the notebook runs on another machine without editing it;
# the defaults are the original local paths.
real_images_folder = os.environ.get(
    'FID_REAL_DIR',
    'C:/Users/Arjit/OneDrive/Desktop/MTP/CODE/Color Diffusion/Color-diffusion/test_images/100_gt',
)
fake_images_folder = os.environ.get(
    'FID_FAKE_DIR',
    'C:/Users/Arjit/OneDrive/Desktop/MTP/CODE/Color Diffusion/Color-diffusion/results/100_output',
)

In [3]:
# Load images and preprocess them
def load_images_from_folder(folder_path, model):
    """Extract one feature vector per image file in ``folder_path``.

    Each processed image goes through the same pipeline:

    * It is converted to RGB.
    * It is resized to 299x299.
    * It is normalised with the per-channel mean/std below (the ImageNet
      statistics used by torchvision models).
    * It is passed through ``model``.

    Files are processed in sorted filename order, so the row order of the
    result is reproducible across runs and platforms.

    Parameters
    ----------
    folder_path : str
        Directory containing the images. Sub-directories and files whose
        extension PIL does not recognise (e.g. ``desktop.ini``, ``Thumbs.db``)
        are skipped instead of crashing ``Image.open``.
    model : torch.nn.Module
        Feature extractor. As a side effect it is put in evaluation mode
        (``model.eval()``). Input tensors are moved to the device of the
        model's first parameter (CPU if it has none).

    Returns
    -------
    numpy.ndarray
        One row per image. Each row is that image's model output flattened
        to 1-D.

    Raises
    ------
    ValueError
        If ``folder_path`` contains no loadable image files.
    """
    preprocess = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Extensions PIL can open (e.g. '.png', '.jpg'); used to skip stray files.
    image_extensions = {ext.lower() for ext in Image.registered_extensions()}

    # Run inference on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Evaluation mode once, instead of re-setting it for every image.
    model.eval()

    features_list = []
    for filename in sorted(os.listdir(folder_path)):
        image_path = os.path.join(folder_path, filename)
        if not os.path.isfile(image_path):
            continue
        if os.path.splitext(filename)[1].lower() not in image_extensions:
            continue

        # The context manager closes the file handle; convert() has already
        # loaded the pixel data into a new image by then.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image_tensor = preprocess(image).unsqueeze(0).to(device)  # Add batch dimension

        with torch.no_grad():
            features = model(image_tensor)
        # Batch size is 1, so reshape(-1) yields this image's 1-D feature vector
        # (unlike squeeze(), it never collapses a length-1 vector to a scalar).
        features_list.append(features.reshape(-1).cpu().numpy())

    if not features_list:
        raise ValueError(f"No image files found in {folder_path!r}")

    return np.array(features_list)

In [4]:
# Calculate FID
def calculate_fid(real_features, fake_features, eps=1e-6):
    """Compute the Frechet distance between two sets of feature vectors.

    Each set is summarised by its mean ``mu`` and covariance ``sigma``. The
    result is::

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Parameters
    ----------
    real_features, fake_features : array_like, shape (n_samples, n_features)
        Feature matrices with one row per image. Both need the same number of
        columns and at least two rows, because a covariance needs >= 2 samples.
    eps : float, optional
        Diagonal offset added to both covariances, but only when the first
        matrix square root is not finite. That typically happens when the
        covariances are singular, e.g. fewer samples than feature dimensions.

    Returns
    -------
    numpy.float64
        The distance; 0 for identical feature statistics, larger otherwise.

    Raises
    ------
    ValueError
        If either input is not 2-D, the feature dimensions differ, or either
        set has fewer than two samples.
    """
    import warnings

    # float64 throughout: model features are typically float32, and computing
    # the means in float32 loses precision.
    real = np.asarray(real_features, dtype=np.float64)
    fake = np.asarray(fake_features, dtype=np.float64)
    if real.ndim != 2 or fake.ndim != 2:
        raise ValueError(
            f"Expected 2-D feature arrays, got shapes {real.shape} and {fake.shape}"
        )
    if real.shape[1] != fake.shape[1]:
        raise ValueError(
            f"Feature dimensions differ: {real.shape[1]} vs {fake.shape[1]}"
        )
    if real.shape[0] < 2 or fake.shape[0] < 2:
        raise ValueError(
            f"Need at least 2 samples per set, got {real.shape[0]} and {fake.shape[0]}"
        )

    mu1, mu2 = real.mean(axis=0), fake.mean(axis=0)
    # np.cov returns a 0-d array for a single feature; sqrtm needs a matrix.
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake, rowvar=False))

    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))

    # A (near-)singular covariance product can make sqrtm return inf/nan.
    # Regularise both covariances slightly and retry.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Round-off can leave a small imaginary component. Drop it, but warn when
    # it is large enough to make the result suspect.
    if np.iscomplexobj(covmean):
        max_imag = np.max(np.abs(np.diagonal(covmean).imag))
        if max_imag > 1e-3:
            warnings.warn(
                f"sqrtm produced a large imaginary component ({max_imag:.3g}); "
                "FID may be inaccurate",
                RuntimeWarning,
            )
        covmean = covmean.real

    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

In [5]:
# Load model.
# torchvision >= 0.13 selects pretrained weights with a `weights` enum and
# deprecates `pretrained=True`. Older versions lack the enum, so fall back to
# `pretrained=True` there.
# NOTE(review): torchvision's ImageNet Inception weights differ from the
# original TF FID Inception weights. Scores may not be directly comparable to
# published FID numbers — confirm before reporting.
_inception_weights_enum = getattr(models, 'Inception_V3_Weights', None)
if _inception_weights_enum is not None:
    inception_model = models.inception_v3(
        weights=_inception_weights_enum.IMAGENET1K_V1, transform_input=False
    )
else:
    inception_model = models.inception_v3(pretrained=True, transform_input=False)

# Remove the last classification layer: the model now returns the input of
# the former fc layer, which is used as the FID feature vector.
inception_model.fc = torch.nn.Identity()
# Remove the auxiliary classification head, if this build has one.
if getattr(inception_model, 'AuxLogits', None) is not None:
    inception_model.AuxLogits.fc = torch.nn.Identity()
# Inference only: put the model in evaluation mode right after loading.
inception_model.eval()

In [8]:
# Extract features for both image sets with the same model:
# first the ground-truth ("real") folder, then the generated ("fake") folder.
real_features, fake_features = (
    load_images_from_folder(folder, inception_model)
    for folder in (real_images_folder, fake_images_folder)
)

In [9]:
# Compute the FID between the two feature sets (lower = generated images'
# feature statistics are closer to the real ones) and report it.
fid_score = calculate_fid(real_features=real_features, fake_features=fake_features)
print("FID Score:", fid_score)

FID Score: 152.4346608544729
