<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Calculating_the_Fr%C3%A9chet_Inception_Distance_(FID)_score_for_a_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
from scipy.linalg import sqrtm
from tqdm import tqdm

# Select the compute device: use CUDA when torch reports it available, else the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define FID calculation function
def calculate_fid(real_features, fake_features, eps=1e-6):
    """Return the Frechet distance between two sets of feature vectors.

    Each set is summarised by its mean vector and covariance matrix, and the
    result is ``||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))``.

    Args:
        real_features: 2-D array-like of shape (n_samples, n_features).
        fake_features: 2-D array-like with the same number of columns.
        eps: Value added to the diagonal of both covariance matrices when the
            matrix square root of their product is not finite.

    Returns:
        The distance as a NumPy floating-point scalar.

    Raises:
        ValueError: If an input is not 2-D, if the feature dimensions differ,
            or if either set has fewer than two samples (the covariance
            estimate is undefined in that case).
    """
    # Work in float64 so the statistics are not accumulated in float32.
    real = np.asarray(real_features, dtype=np.float64)
    fake = np.asarray(fake_features, dtype=np.float64)

    if real.ndim != 2 or fake.ndim != 2:
        raise ValueError(
            "features must be 2-D arrays of shape (n_samples, n_features)"
        )
    if real.shape[1] != fake.shape[1]:
        raise ValueError(
            f"feature dimensions differ: {real.shape[1]} vs {fake.shape[1]}"
        )
    if real.shape[0] < 2 or fake.shape[0] < 2:
        raise ValueError("at least two samples per set are required")

    mu1, mu2 = real.mean(axis=0), fake.mean(axis=0)
    # np.cov returns a 0-d array for a single feature column; sqrtm needs 2-D.
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake, rowvar=False))

    # Compute sum of squared differences between means
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # Compute sqrt of product of covariance matrices
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # The product can be singular (e.g. fewer samples than features);
        # nudge both covariances away from singularity and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        # Imaginary parts are treated as numerical noise and discarded.
        covmean = covmean.real

    # Final FID computation
    fid = ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)
    return fid

# Use InceptionV3 to extract features
class InceptionV3FeatureExtractor(nn.Module):
    """InceptionV3 with its classification layer replaced by an identity.

    The wrapped network is loaded with torchvision's default pretrained
    weights, and its final ``fc`` layer is swapped for ``nn.Identity`` so the
    forward pass yields the features that would otherwise feed the classifier.
    """

    def __init__(self):
        super().__init__()
        self.inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
        self.inception.fc = nn.Identity()  # Remove the final classification layer

    def forward(self, x):
        """Return the feature tensor for the image batch ``x``."""
        out = self.inception(x)
        # In training mode torchvision's Inception3 returns a namedtuple
        # (main output, auxiliary output); keep only the main output so
        # callers always receive a plain tensor.
        if isinstance(out, tuple):
            out = out[0]
        return out

# Build the extractor, place it on the selected device, and put it in eval mode
feature_extractor = InceptionV3FeatureExtractor()
feature_extractor.to(device)
feature_extractor.eval()

# Preprocessing applied to every real image before feature extraction:
# resize to 299x299, expand grayscale to three channels, convert to a tensor,
# then normalise per channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics - confirm.
transform = transforms.Compose([
    transforms.Resize(size=(299, 299)),
    transforms.Grayscale(num_output_channels=3),  # MNIST digits are single-channel
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Real images: the MNIST split selected by train=False, downloaded on first use
real_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
real_loader = DataLoader(dataset=real_dataset, batch_size=32, shuffle=False)

# Placeholder "generated" images: Gaussian noise shaped like MNIST digits,
# copied across three channels and upsampled to 299x299.
# Replace the torch.randn call with real generator output.
# NOTE(review): these tensors do not go through the Normalize step applied to
# the real images, and interpolate is called with its default mode - confirm
# both image sets receive matching preprocessing before trusting the score.
fake_images = torch.nn.functional.interpolate(
    torch.randn(1000, 1, 28, 28).repeat(1, 3, 1, 1),
    size=(299, 299),
)
fake_dataset = TensorDataset(fake_images)
fake_loader = DataLoader(dataset=fake_dataset, batch_size=32, shuffle=False)

# Sanity check: report the tensor shapes once before running feature extraction
print("Fake images shape:", fake_images.size())
print("Real loader batch shape:", next(iter(real_loader))[0].size())
print("Fake loader batch shape:", next(iter(fake_loader))[0].size())

# Extract features from DataLoader
def extract_features(loader, model):
    """Run every batch from ``loader`` through ``model`` and stack the outputs.

    Args:
        loader: Iterable of batches; the first element of each batch is taken
            as the image tensor (any further elements, such as labels, are
            ignored).
        model: Callable applied to each image batch after it has been moved
            to the module-level ``device``.

    Returns:
        A NumPy array holding all model outputs concatenated along axis 0.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    features = []
    # Gradients are never needed here, so disable them once for the whole loop.
    with torch.no_grad():
        for batch in tqdm(loader, desc="Extracting features"):
            images = batch[0].to(device)
            features.append(model(images).cpu().numpy())

    if not features:
        raise ValueError("loader yielded no batches; cannot extract features")
    return np.concatenate(features, axis=0)

# Push both image sets through the same feature extractor
real_features = extract_features(loader=real_loader, model=feature_extractor)
fake_features = extract_features(loader=fake_loader, model=feature_extractor)

# Compare the two feature sets and report the score
fid_score = calculate_fid(real_features=real_features, fake_features=fake_features)
print(f"FID score: {fid_score}")