In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import numpy as np
from scipy.linalg import sqrtm

In [None]:
# Minimal map-style dataset over already-transformed image tensors
class CustomImageDataset(Dataset):
    """Map-style dataset over a pre-built collection of images.

    Every item is an ``(image, 0)`` pair. The constant ``0`` is a placeholder
    label so the dataset works with loops that unpack ``(inputs, targets)``.

    Args:
        image_tensors: Any sized, indexable collection of images (for example a
            stacked tensor or a list of tensors). It is stored as-is, not copied.
    """

    def __init__(self, image_tensors):
        self.image_tensors = image_tensors

    def __len__(self):
        """Return the number of images in the wrapped collection."""
        count = len(self.image_tensors)
        return count

    def __getitem__(self, idx):
        """Return ``(image, dummy_label)`` for position ``idx``."""
        image = self.image_tensors[idx]
        dummy_label = 0
        return image, dummy_label

# Function to extract features using InceptionV3
def get_features(dataloader, model, device):
    model.eval()
    features = []
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            pred = model(images)[0]  
            features.append(pred.cpu().numpy())
    features = np.concatenate(features, axis=0)
    return features

# Function to calculate FID
def calculate_fid(real_features, generated_features):
    mu1, sigma1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)
    
    ssdiff = np.sum((mu1 - mu2)**2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    
    if np.iscomplexobj(covmean):
        covmean = covmean.real
        
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid


In [None]:
# Inception v3 is built for 299x299 inputs (the standard FID resolution).
# At 64x64 its stride-2 stages shrink the map to 2x2 before Mixed_7a's 3x3
# stride-2 convolution, so the forward pass raises.
IMAGE_SIZE = 299
BATCH_SIZE = 32

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    # Maps [0, 1] pixel values to [-1, 1]; transform_input=False below means the
    # model receives these values unchanged.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load real images and generated images (replace with your actual images).
# NOTE(review): Normalize above assumes 3-channel (RGB) images — confirm inputs.
real_images = ...  # List or tensor of real images
generated_images = ...  # List or tensor of generated images
if real_images is Ellipsis or generated_images is Ellipsis:
    raise ValueError(
        "Replace the `...` placeholders for real_images and generated_images "
        "with your image collections before running this cell."
    )

# Transform and create datasets. All transformed images are held in memory at
# IMAGE_SIZE x IMAGE_SIZE, so very large sets may need a lazily-transforming dataset.
real_images = torch.stack([transform(img) for img in real_images])
generated_images = torch.stack([transform(img) for img in generated_images])

real_dataset = CustomImageDataset(real_images)
generated_dataset = CustomImageDataset(generated_images)

real_loader = DataLoader(real_dataset, batch_size=BATCH_SIZE, shuffle=False)
generated_loader = DataLoader(generated_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Load pre-trained InceptionV3 model.
# NOTE(review): newer torchvision deprecates `pretrained=` in favour of `weights=`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inception_model = models.inception_v3(pretrained=True, transform_input=False)
inception_model.fc = torch.nn.Identity()  # Output pooled features instead of class logits
inception_model = inception_model.to(device)

# Extract features
real_features = get_features(real_loader, inception_model, device)
generated_features = get_features(generated_loader, inception_model, device)

# Calculate FID. With fewer images per set than the feature dimension the
# covariance estimates are rank-deficient and the score is unreliable.
fid_score = calculate_fid(real_features, generated_features)
print(f"FID Score: {fid_score}")