In [None]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
from scipy import linalg
from tqdm import tqdm
from pytorch_fid import fid_score
from torchvision import transforms
from lpips import LPIPS
import os
from PIL import Image

class FaceGenerationBenchmark:
    """Benchmark GAN- and SDXL-generated faces against real faces with FID and LPIPS.

    Images are loaded from three directories, resized to 299x299 and
    normalised with ImageNet channel statistics (the input convention of
    torchvision's Inception-v3). FID is computed on pooled Inception
    features; LPIPS is computed on index-aligned image pairs.

    NOTE(review): the Inception weights come from torchvision, not from the
    TensorFlow checkpoint used by reference FID implementations, so absolute
    FID values are presumably not comparable with published numbers -- only
    with each other.
    """

    # Channel statistics applied by ``self.transform``. Kept as class
    # attributes so the LPIPS path can undo the normalisation.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, real_images_path, gan_generated_path, sdxl_generated_path,
                 batch_size=32, device=None):
        """
        Initialize the benchmarking framework

        Args:
            real_images_path: Path to directory containing real face images
            gan_generated_path: Path to directory containing GAN-generated faces
            sdxl_generated_path: Path to directory containing SDXL-generated faces
            batch_size: Batch size for processing images (must be >= 1)
            device: Optional torch device (or device string). Defaults to CUDA
                when available, otherwise CPU.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.real_images_path = real_images_path
        self.gan_generated_path = gan_generated_path
        self.sdxl_generated_path = sdxl_generated_path
        self.batch_size = batch_size

        # Fall back to CPU instead of crashing on machines without a GPU.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        # Initialize LPIPS model
        self.lpips_model = LPIPS(net='alex').to(self.device)

        # Define image transforms
        self.transform = transforms.Compose([
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(self._IMAGENET_MEAN),
                                 std=list(self._IMAGENET_STD)),
        ])

    def load_images(self, path):
        """Load and preprocess every PNG/JPEG image found directly in ``path``.

        File names are sorted so that the i-th image of each set is the same
        on every run; ``calculate_lpips`` pairs images by index.

        Args:
            path: Directory to scan (not recursive).

        Returns:
            A tensor of stacked, transformed images, one entry per file.

        Raises:
            ValueError: If the directory contains no supported image files.
        """
        names = sorted(
            name for name in os.listdir(path)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )
        if not names:
            raise ValueError(f"No .png/.jpg/.jpeg images found in {path!r}")

        images = []
        for name in names:
            # Context manager closes the underlying file handle promptly.
            with Image.open(os.path.join(path, name)) as img:
                images.append(self.transform(img.convert('RGB')))
        return torch.stack(images)

    def calculate_fid(self, real_features, generated_features):
        """Calculate FID score between real and generated image features.

        Args:
            real_features: 2-D array-like, one row of features per real image.
            generated_features: 2-D array-like, one row per generated image,
                with the same number of columns as ``real_features``.

        Returns:
            The Frechet distance between Gaussians fitted to the two sets.

        Raises:
            ValueError: If either input is not 2-D, has fewer than two rows
                (a covariance cannot be estimated), or the feature
                dimensions differ.
        """
        real_features = np.asarray(real_features, dtype=np.float64)
        generated_features = np.asarray(generated_features, dtype=np.float64)

        for label, feats in (("real", real_features), ("generated", generated_features)):
            if feats.ndim != 2 or feats.shape[0] < 2:
                raise ValueError(
                    f"{label} features must be a 2-D array with at least 2 rows, "
                    f"got shape {feats.shape}"
                )
        if real_features.shape[1] != generated_features.shape[1]:
            raise ValueError(
                "Feature dimensions differ: "
                f"{real_features.shape[1]} vs {generated_features.shape[1]}"
            )

        mu1 = real_features.mean(axis=0)
        mu2 = generated_features.mean(axis=0)
        # atleast_2d keeps the single-feature case (0-d covariance) usable.
        sigma1 = np.atleast_2d(np.cov(real_features, rowvar=False))
        sigma2 = np.atleast_2d(np.cov(generated_features, rowvar=False))

        ssdiff = np.sum((mu1 - mu2) ** 2)
        covmean = linalg.sqrtm(sigma1.dot(sigma2))

        # A (near-)singular product can yield non-finite entries; retry with a
        # small diagonal offset for numerical stability.
        if not np.isfinite(covmean).all():
            offset = np.eye(sigma1.shape[0]) * 1e-6
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error can leave a tiny imaginary component; discard it.
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = ssdiff + np.trace(sigma1 + sigma2 - 2 * covmean)
        return fid

    def calculate_lpips(self, images1, images2):
        """Calculate average LPIPS distance between index-aligned image pairs.

        Only the first ``min(len(images1), len(images2))`` images of each set
        are compared. Inputs are expected to be ImageNet-normalised tensors as
        produced by ``load_images``; they are mapped back to [-1, 1] before
        being passed to the LPIPS model.

        Args:
            images1: Tensor of images, batch dimension first.
            images2: Tensor of images, batch dimension first.

        Returns:
            Mean LPIPS distance over the compared pairs.

        Raises:
            ValueError: If either set is empty.
        """
        num_pairs = min(len(images1), len(images2))
        if num_pairs == 0:
            raise ValueError("LPIPS needs at least one image in each set")

        total_distance = 0.0
        with torch.no_grad():
            for start in range(0, num_pairs, self.batch_size):
                # Clamp to num_pairs so both batches always have equal length,
                # even when the two sets differ in size.
                stop = min(start + self.batch_size, num_pairs)
                batch1 = self._to_lpips_range(images1[start:stop].to(self.device))
                batch2 = self._to_lpips_range(images2[start:stop].to(self.device))
                distance = self.lpips_model(batch1, batch2)
                total_distance += distance.sum().item()

        return total_distance / num_pairs

    def _to_lpips_range(self, batch):
        """Undo the ImageNet normalisation and rescale a batch to [-1, 1]."""
        mean = batch.new_tensor(self._IMAGENET_MEAN).view(1, -1, 1, 1)
        std = batch.new_tensor(self._IMAGENET_STD).view(1, -1, 1, 1)
        return (batch * std + mean) * 2.0 - 1.0

    def _extract_features(self, model, images):
        """Run ``model`` over ``images`` in batches; return one feature row per image."""
        if len(images) == 0:
            raise ValueError("Cannot extract features from an empty image set")

        chunks = []
        with torch.no_grad():
            for start in range(0, len(images), self.batch_size):
                batch = images[start:start + self.batch_size].to(self.device)
                feats = model(batch)
                # Flatten per sample rather than squeeze(), so a final batch of
                # size 1 still contributes a (1, D) row.
                chunks.append(feats.reshape(feats.shape[0], -1).cpu().numpy())
        return np.concatenate(chunks, axis=0)

    def run_benchmark(self):
        """Run complete benchmarking suite.

        Returns:
            A dict of the form ``{'GAN': {'FID': ..., 'LPIPS': ...},
            'SDXL': {'FID': ..., 'LPIPS': ...}}``.
        """
        print("Loading images...")
        real_images = self.load_images(self.real_images_path)
        gan_images = self.load_images(self.gan_generated_path)
        sdxl_images = self.load_images(self.sdxl_generated_path)

        # Calculate FID scores
        print("Calculating FID scores...")
        # transform_input=False because load_images already applies ImageNet
        # normalisation.
        inception = torchvision.models.inception_v3(pretrained=True, transform_input=False)
        # Drop the classifier head so the network emits the 2048-d pooled
        # features that FID is defined on, instead of class logits.
        inception.fc = nn.Identity()
        inception = inception.to(self.device)
        inception.eval()

        real_features = self._extract_features(inception, real_images)
        gan_features = self._extract_features(inception, gan_images)
        sdxl_features = self._extract_features(inception, sdxl_images)

        gan_fid = self.calculate_fid(real_features, gan_features)
        sdxl_fid = self.calculate_fid(real_features, sdxl_features)

        # Calculate LPIPS scores
        print("Calculating LPIPS scores...")
        gan_lpips = self.calculate_lpips(real_images, gan_images)
        sdxl_lpips = self.calculate_lpips(real_images, sdxl_images)

        results = {
            'GAN': {
                'FID': gan_fid,
                'LPIPS': gan_lpips
            },
            'SDXL': {
                'FID': sdxl_fid,
                'LPIPS': sdxl_lpips
            }
        }

        return results

def generate_benchmark_report(results):
    """Render benchmark results as a plain-text report.

    Args:
        results: Mapping of model name to a mapping with 'FID' and 'LPIPS'
            numeric entries.

    Returns:
        The report text: a title block, one section per model in mapping
        order (scores shown to four decimals), and a short interpretation.
    """
    title_block = (
        "Face Generation Model Benchmark Results\n"
        "====================================\n\n"
    )

    # One section per model, in the mapping's iteration order.
    model_sections = [
        f"{name} Model:\n"
        f"  FID Score: {scores['FID']:.4f}\n"
        f"  LPIPS Score: {scores['LPIPS']:.4f}\n\n"
        for name, scores in results.items()
    ]

    # Add interpretation
    interpretation = (
        "Interpretation:\n"
        "- Lower scores are better for both metrics\n"
        "- FID measures overall distribution similarity\n"
        "- LPIPS measures perceptual similarity\n"
    )

    return "".join([title_block, *model_sections, interpretation])