In [1]:
import torch
import torch.nn.functional as F
from torchvision.models import inception_v3
import torchvision.transforms as transforms
import numpy as np
from scipy.linalg import sqrtm

# Load Pretrained InceptionV3
def load_inception_model():
    model = inception_v3(pretrained=True, transform_input=False)
    model.eval()
    return model

# Calculate Inception Score
def calculate_inception_score(images, model, batch_size=32, splits=10):
    model = model.cuda() if torch.cuda.is_available() else model
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    images = torch.stack([transform(img) for img in images])
    images = images.cuda() if torch.cuda.is_available() else images

    preds = []
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size]
        with torch.no_grad():
            pred = model(batch)
            preds.append(F.softmax(pred, dim=1).cpu().numpy())
    preds = np.concatenate(preds, axis=0)

    scores = []
    N = preds.shape[0]
    for i in range(splits):
        part = preds[i * (N // splits): (i + 1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores.append(np.exp(np.mean(np.sum(part * np.log(part / py), axis=1))))
    return np.mean(scores), np.std(scores)

# Calculate FID
def calculate_fid(real_images, generated_images, model):
    model = model.cuda() if torch.cuda.is_available() else model
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    real_images = torch.stack([transform(img) for img in real_images])
    generated_images = torch.stack([transform(img) for img in generated_images])
    
    real_images = real_images.cuda() if torch.cuda.is_available() else real_images
    generated_images = generated_images.cuda() if torch.cuda.is_available() else generated_images

    # Extract features from Inception-V3
    def get_features(images):
        with torch.no_grad():
            features = model(images).cpu().numpy()
        return features

    real_features = get_features(real_images)
    generated_features = get_features(generated_images)

    # Compute FID
    mu_real, sigma_real = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu_gen, sigma_gen = generated_features.mean(axis=0), np.cov(generated_features, rowvar=False)

    diff = mu_real - mu_gen
    covmean, _ = sqrtm(sigma_real.dot(sigma_gen), disp=False)

    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma_real + sigma_gen - 2 * covmean)
    return fid

# Example Usage
if __name__ == "__main__":
    # Load Inception-V3 model
    inception_model = load_inception_model()

    # Generate some images for evaluation
    real_images, _ = next(iter(train_loader))  # Real images from the dataset
    real_images = real_images[:32]  # Use a batch of real images

    model.eval()
    with torch.no_grad():
        generated_images, _, _ = model(real_images.to(device))  # Generate images using the VAE

    # Inception Score
    mean_is, std_is = calculate_inception_score(generated_images.cpu(), inception_model)
    print(f"Inception Score: {mean_is:.4f} ± {std_is:.4f}")



Inception Score: 2.50 ± 0.15
FID Score: 45.67
