# Assessing realism of generated images

In [1]:
# pip install torch torchvision numpy scipy torch-fidelity 

import torch
import torchvision.transforms as transforms
import numpy as np
import os
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.models import inception_v3
from scipy.linalg import sqrtm
import pandas as pd

class ImageDataset(Dataset):
    """Dataset over the image files stored directly inside one folder.

    Only regular files whose name ends in .png, .jpg or .jpeg (matched
    case-insensitively) are used; sub-directories are not searched. Paths
    are sorted so the iteration order is deterministic across runs and
    platforms.

    Args:
        image_folder: directory containing the images.
        transform: optional callable applied to each RGB PIL image before it
            is returned.
    """

    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.transform = transform
        self.image_paths = sorted(
            os.path.join(image_folder, name)
            for name in os.listdir(image_folder)
            # lower(): accept ".JPG"/".PNG"; the leading dot rejects names like "fakepng".
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(image_folder, name))
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the underlying file handle promptly;
        # convert() returns a new, fully loaded image that no longer needs it.
        with Image.open(self.image_paths[idx]) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img

# Feature extractors already built, keyed by str(device). Repeated calls
# reuse one model instead of rebuilding Inception-v3 and reloading its
# pretrained weights every time.
_INCEPTION_MODELS = {}


def _load_inception(device):
    """Return an eval-mode Inception-v3 with its classifier removed, built once per device."""
    key = str(device)
    model = _INCEPTION_MODELS.get(key)
    if model is None:
        try:
            # torchvision >= 0.13 API; `pretrained=True` is deprecated there.
            model = inception_v3(weights="IMAGENET1K_V1", transform_input=False)
        except TypeError:
            # Older torchvision versions have no `weights` keyword.
            model = inception_v3(pretrained=True, transform_input=False)
        model.fc = torch.nn.Identity()  # remove last layer -> output the pooled features
        model = model.to(device)
        model.eval()
        _INCEPTION_MODELS[key] = model
    return model


def get_inception_features(dataloader, device):
    """Extract deep features from Inception-v3 for every batch in ``dataloader``.

    Inception-v3 is the classic network used for FID score calculation. Its
    final fully connected layer is replaced by an identity, so each image is
    mapped to the feature vector that would have been fed to that layer.

    NOTE(review): these are torchvision's ImageNet weights, not the
    TensorFlow FID Inception weights used by reference FID implementations
    (e.g. torch-fidelity). Compare scores only with others computed by this
    same code, not with published FID values.

    Args:
        dataloader: iterable yielding image tensors. Presumably each has
            shape (batch, 3, 299, 299), already resized and normalized for
            the model — confirm against the caller's transform.
        device: torch device (or device string) to run the model on.

    Returns:
        np.ndarray with one row of features per image, in dataloader order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model = _load_inception(device)

    features = []
    with torch.no_grad():
        for batch in dataloader:
            preds = model(batch.to(device))
            features.append(preds.cpu().numpy())

    if not features:
        raise ValueError("dataloader yielded no batches; cannot extract Inception features")
    return np.concatenate(features, axis=0)

def calculate_fid(real_features, generated_features, eps=1e-6):
    """Compute the Fréchet Inception Distance between two sets of feature vectors.

    FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2)),
    where mu and S are the per-set feature mean and covariance. Lower values
    mean the two feature distributions are closer.

    Args:
        real_features: 2-D array, one feature vector per row.
        generated_features: 2-D array with the same number of columns.
        eps: value added to the covariance diagonals when the matrix square
            root of their product is not finite. This can happen when the
            covariances are singular, e.g. with fewer samples than feature
            dimensions.

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: if either input is not 2-D, has fewer than 2 rows
            (a covariance needs at least two samples), or the feature
            dimensions differ.
    """
    real_features = np.asarray(real_features, dtype=np.float64)
    generated_features = np.asarray(generated_features, dtype=np.float64)
    for name, feats in (("real_features", real_features),
                        ("generated_features", generated_features)):
        if feats.ndim != 2 or feats.shape[0] < 2:
            raise ValueError(
                f"{name} must be a 2-D array with at least 2 rows, got shape {feats.shape}"
            )
    if real_features.shape[1] != generated_features.shape[1]:
        raise ValueError(
            f"feature dimensions differ: {real_features.shape[1]} vs {generated_features.shape[1]}"
        )

    mu_real = real_features.mean(axis=0)
    mu_gen = generated_features.mean(axis=0)
    # np.cov squeezes a single feature dimension down to a 0-d array; keep it 2-D for matmul.
    sigma_real = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma_gen = np.atleast_2d(np.cov(generated_features, rowvar=False))

    diff = mu_real - mu_gen
    # No `disp=`: that argument is deprecated in recent SciPy, and omitting it
    # returns just the matrix in every version.
    cov_sqrt = sqrtm(sigma_real @ sigma_gen)
    if not np.isfinite(cov_sqrt).all():
        offset = np.eye(sigma_real.shape[0]) * eps
        cov_sqrt = sqrtm((sigma_real + offset) @ (sigma_gen + offset))

    # Numerical error can leave a tiny imaginary part; keep the real part.
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real

    fid = diff @ diff + np.trace(sigma_real) + np.trace(sigma_gen) - 2 * np.trace(cov_sqrt)
    return float(fid)

def compute_fid_for_categories(real_root, generated_root, categories, batch_size=32, num_workers=2):
    """Compute FID separately for each category sub-directory.

    For every name in ``categories`` the images in ``real_root/<name>`` are
    compared with those in ``generated_root/<name>``. A category is skipped,
    with a printed message, when either directory is missing or when either
    side has fewer than two images (a covariance needs at least two samples).
    Skipped categories are absent from the result.

    Args:
        real_root: directory whose sub-directories hold the real images.
        generated_root: directory whose sub-directories hold the generated images.
        categories: iterable of sub-directory names to evaluate.
        batch_size: images per forward pass.
        num_workers: DataLoader worker processes. Use 0 to load images in the
            main process, e.g. if worker processes fail to start when running
            inside a notebook.

    Returns:
        dict mapping category name -> FID score.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The pretrained Inception-v3 model expects 299x299 pixel inputs.
    # Normalize(0.5, 0.5) maps ToTensor's [0, 1] range to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    fid_scores = {}

    for category in categories:
        real_path = os.path.join(real_root, category)
        generated_path = os.path.join(generated_root, category)

        # isdir (not exists): a plain file here would crash os.listdir later.
        if not (os.path.isdir(real_path) and os.path.isdir(generated_path)):
            print(f"Skipping category {category}: Missing directory.")
            continue

        real_dataset = ImageDataset(real_path, transform=transform)
        generated_dataset = ImageDataset(generated_path, transform=transform)

        if len(real_dataset) < 2 or len(generated_dataset) < 2:
            print(
                f"Skipping category {category}: need at least 2 images on each side "
                f"(real={len(real_dataset)}, generated={len(generated_dataset)})."
            )
            continue

        real_loader = DataLoader(real_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        generated_loader = DataLoader(generated_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        real_features = get_inception_features(real_loader, device)
        generated_features = get_inception_features(generated_loader, device)

        fid_scores[category] = calculate_fid(real_features, generated_features)

    return fid_scores


In [None]:
# Per augmentation category:

# TODO: change these paths. FID compares two *different* image sets:
# real_root must contain ONLY real images and generated_root ONLY generated
# (augmented) images, each split into the category sub-directories below.
# With the current dataset layout this may require separating the two first.
real_root = "dataset/train"  # path to ONLY real images
generated_root = "dataset/train"  # path to ONLY generated augmented images

# Sub-directory names expected under both roots
categories = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

augmentation_pipeline = "data_manipulation"  # label used in the output file name; change for each method

# Comparing a folder with itself gives FID ~0 for every category, which would
# silently look like a perfect result.
if os.path.abspath(real_root) == os.path.abspath(generated_root):
    print(
        f"WARNING: real_root and generated_root both point to {real_root!r}; "
        "the FID scores below will be ~0 and meaningless."
    )

# Run FID calculation for each category
fid_results = compute_fid_for_categories(real_root, generated_root, categories)

# Convert results to a DataFrame, save it, and display it
fid_df = pd.DataFrame(list(fid_results.items()), columns=["Category", "FID Score"])
fid_df.to_csv(f"fid_scores_{augmentation_pipeline}.csv", index=False)
print(fid_df)