# Evaluation Metrics - DreamBooth vs Baseline

Tính FID, LPIPS, SSIM cho DreamBooth models và so sánh với baseline.

**Models đánh giá:**
- Baseline: SD v1.5 gốc
- DreamBooth Contemporary_Realism
- DreamBooth New_Realism

**Metrics:**
- FID: Generated images vs Style images (WikiArt)
- LPIPS: Generated images vs Style images (perceptual distance; lower = more similar)
- SSIM: Generated images vs Content images (structure preservation)


## Setup


In [None]:
import os
import torch

# Report which accelerator (if any) this session can use before loading models.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("WARNING: No GPU detected!")


In [None]:
import json
from pathlib import Path
from typing import Dict, List
import pandas as pd
import numpy as np
from tqdm import tqdm

import lpips
from PIL import Image
from diffusers import StableDiffusionPipeline
from torchvision import transforms as T
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.functional.image.ssim import structural_similarity_index_measure

class StableDiffusionTransform:
    """Callable preprocessing pipeline: resize, crop, to-tensor, normalise.

    The final ``Normalize([0.5], [0.5])`` step maps ``ToTensor`` output from
    [0, 1] to [-1, 1].

    Args:
        size: Target size passed to both the resize and the crop step.
        center_crop: Use a centre crop when true, a random crop otherwise.
    """

    def __init__(self, size=256, center_crop=True):
        if center_crop:
            crop_step = T.CenterCrop(size)
        else:
            crop_step = T.RandomCrop(size)

        steps = [
            T.Resize(size),
            crop_step,
            T.ToTensor(),
            T.Normalize([0.5], [0.5]),
        ]
        self.transform = T.Compose(steps)

    def __call__(self, img):
        """Apply the composed transform to ``img`` and return the result."""
        pipeline = self.transform
        return pipeline(img)

class FIDEvaluator:
    """Thin wrapper around ``FrechetInceptionDistance`` for [-1, 1] image batches.

    Every batch is rescaled with ``x * 0.5 + 0.5`` and clamped to [0, 1]
    before it is handed to the underlying metric.

    Args:
        device: Device the metric lives on; batches are moved there on update.
    """

    def __init__(self, device="cpu"):
        self.device = device
        self.metric = FrechetInceptionDistance(feature=2048, normalize=True).to(device)

    def _feed(self, images, is_real):
        """Rescale ``images`` to [0, 1] and add them to the real or fake set."""
        rescaled = (images * 0.5 + 0.5).clamp(0, 1)
        self.metric.update(rescaled.to(self.device), real=is_real)

    def update_real(self, images):
        """Add a batch of reference ("real") images."""
        self._feed(images, True)

    def update_fake(self, images):
        """Add a batch of generated ("fake") images."""
        self._feed(images, False)

    def compute(self):
        """Return the accumulated FID as a Python float."""
        score = self.metric.compute()
        return float(score)

def compute_lpips(img1, img2, model):
    """Return the mean LPIPS distance between two image batches.

    Both inputs are expected in the [-1, 1] range, which is what
    ``StableDiffusionTransform`` produces through ``Normalize([0.5], [0.5])``.
    LPIPS networks take [-1, 1] inputs by default, so the tensors are only
    clamped to that range and passed straight through. Rescaling them to
    [0, 1] first (without also asking the model to re-normalise) would feed
    the network wrongly scaled data and distort the distance.

    Args:
        img1: First image batch, values nominally in [-1, 1].
        img2: Second image batch, same shape and range as ``img1``.
        model: Callable taking two batches and returning per-sample
            distances (an ``lpips.LPIPS`` instance in this notebook).

    Returns:
        A scalar tensor: the mean of the distances returned by ``model``.
    """
    # Clamp only guards against values that drifted outside the valid range.
    img1 = img1.clamp(-1, 1)
    img2 = img2.clamp(-1, 1)
    return model(img1, img2).mean()

def compute_ssim(img1, img2):
    """Return the SSIM between two image batches given in [-1, 1].

    Each batch is rescaled with ``x * 0.5 + 0.5`` and clamped to [0, 1], so
    the metric is evaluated with ``data_range=1.0``.
    """
    first, second = ((batch * 0.5 + 0.5).clamp(0, 1) for batch in (img1, img2))
    return structural_similarity_index_measure(first, second, data_range=1.0)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

# Metric tables (CSV/JSON) are written here; create the folder up front.
RESULTS_DIR = Path("/kaggle/working/results/metrics")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Styles under evaluation and the generation settings shared by all models.
STYLES = ["Contemporary_Realism", "New_Realism"]
UNIQUE_TOKEN = "sks"
RESOLUTION = 256
NUM_SAMPLES = 20

# Style name -> dataset folder holding that style's DreamBooth checkpoint.
STYLE_CHECKPOINT_MAP = {
    "New_Realism": "/kaggle/input/real1k/dreambooth_checkpoints",
    "Contemporary_Realism": "/kaggle/input/priorimages/dreambooth_checkpoints",
}

# Candidate folders with style reference images (matched to a style by name).
STYLE_IMAGES_DIRS = [
    "/kaggle/input/real1k/dreambooth/New_Realism/instance_images",
    "/kaggle/input/priorimages/dreambooth/Contemporary_Realism/instance_images",
]

# Candidate folders with content images; the first usable one is taken.
CONTENT_IMAGES_DIRS = [
    "/kaggle/input/coco-2017-dataset/coco2017/val2017",
    "/kaggle/input/coco2017/val2017",
]


ModuleNotFoundError: No module named 'src'

## Load Models


In [None]:
# Keyword arguments shared by every pipeline load: half precision on GPU,
# full precision on CPU, and the safety checker disabled.
_pipeline_kwargs = {
    "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
    "safety_checker": None,
    "requires_safety_checker": False,
}

print("Loading baseline model...")
baseline_pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    **_pipeline_kwargs,
)
if torch.cuda.is_available():
    baseline_pipeline = baseline_pipeline.to(device)
print("Baseline loaded")

# Load one fine-tuned pipeline per style; a style whose checkpoint is missing
# or fails to load is reported and skipped.
dreambooth_pipelines = {}
for style_name in STYLES:
    if style_name not in STYLE_CHECKPOINT_MAP:
        print(f"No checkpoint path mapped for {style_name}")
        continue

    checkpoint_base = Path(STYLE_CHECKPOINT_MAP[style_name])
    checkpoint_path = checkpoint_base / style_name

    # Report the outermost missing directory first.
    problem = None
    if not checkpoint_base.exists():
        problem = f"Dataset not found: {checkpoint_base}"
    elif not checkpoint_path.exists():
        problem = f"Checkpoint not found: {checkpoint_path}"
    if problem is not None:
        print(problem)
        continue

    print(f"Loading {style_name} from {checkpoint_base}...")
    try:
        pipeline = StableDiffusionPipeline.from_pretrained(
            str(checkpoint_path),
            **_pipeline_kwargs,
        )
        if torch.cuda.is_available():
            pipeline = pipeline.to(device)
        dreambooth_pipelines[style_name] = pipeline
        print(f"{style_name} loaded")
    except Exception as e:
        print(f"Error loading {style_name}: {e}")

# Registry of everything that loaded: the baseline first, then the styles in
# STYLES order.
models = {"baseline": baseline_pipeline}
models.update(
    (f"dreambooth_{name.lower().replace('_', '')}", dreambooth_pipelines[name])
    for name in STYLES
    if name in dreambooth_pipelines
)
print(f"\nModels loaded: {list(models.keys())}")


## Load Reference Images


In [None]:
def _load_rgb_images(directory, patterns, limit):
    """Load up to ``limit`` images from ``directory`` as RGB ``PIL`` images.

    The glob patterns are tried in order and the first one that matches any
    file wins. Matches are sorted by path before truncating, so the same
    files are selected on every run.

    Args:
        directory: ``Path`` of the folder to scan (non-recursive).
        patterns: Glob patterns to try, in priority order.
        limit: Maximum number of images to load.

    Returns:
        A list of RGB images; empty when no pattern matched any file.
    """
    for pattern in patterns:
        files = sorted(directory.glob(pattern))[:limit]
        if not files:
            continue
        images = []
        for file_path in files:
            # convert() returns an independent, fully loaded image, so the
            # underlying file can be closed right away.
            with Image.open(file_path) as img:
                images.append(img.convert("RGB"))
        return images
    return []


style_images = {}
content_images = []

# For each style, take the first configured directory that exists, mentions
# the style name in its path and actually contains images (.jpg preferred,
# .png as a fallback).
for style_name in STYLES:
    for style_dir_path in STYLE_IMAGES_DIRS:
        style_dir = Path(style_dir_path)
        if not (style_dir.exists() and style_name in str(style_dir)):
            continue
        loaded = _load_rgb_images(style_dir, ("*.jpg", "*.png"), NUM_SAMPLES)
        if loaded:
            style_images[style_name] = loaded
            print(f"Loaded {len(style_images[style_name])} style images for {style_name} from {style_dir}")
            break
    else:
        # Make a missing reference set visible instead of skipping silently.
        print(f"WARNING: no style images found for {style_name}")

# Content images come from the first candidate directory that has .jpg files.
for base_dir in CONTENT_IMAGES_DIRS:
    content_dir = Path(base_dir)
    if not content_dir.is_dir():
        continue
    loaded = _load_rgb_images(content_dir, ("*.jpg",), NUM_SAMPLES)
    if loaded:
        content_images = loaded
        print(f"Loaded {len(content_images)} content images from {content_dir}")
        break

transform = StableDiffusionTransform(size=RESOLUTION, center_crop=True)

print(f"\nStyle images: {sum(len(v) for v in style_images.values())} total")
print(f"Content images: {len(content_images)}")


## Generate Samples


In [None]:
# Prompts for the fine-tuned models: they include the DreamBooth token.
test_prompts = [
    f"a {UNIQUE_TOKEN} style painting of a cat",
    f"a {UNIQUE_TOKEN} style painting of a landscape",
    f"a {UNIQUE_TOKEN} style painting of a portrait",
    f"a {UNIQUE_TOKEN} style painting of a cityscape",
    f"a {UNIQUE_TOKEN} style painting of flowers",
]

# Same subjects for the baseline, without the token it was never trained on.
baseline_prompts = [
    "a painting of a cat",
    "a painting of a landscape",
    "a painting of a portrait",
    "a painting of a cityscape",
    "a painting of flowers",
]

# Base seed for sampling. Prompt i always uses GENERATION_SEED + i, for every
# model, so runs are repeatable and models are compared on the same noise.
GENERATION_SEED = 0


def _generate_samples(pipe, prompts, base_seed=GENERATION_SEED):
    """Generate one image per prompt with a fixed, per-prompt seed.

    Args:
        pipe: A loaded ``StableDiffusionPipeline``.
        prompts: Text prompts; one image is produced for each.
        base_seed: Seed of the first prompt; prompt ``i`` uses ``base_seed + i``.

    Returns:
        The generated images, in prompt order.
    """
    images = []
    for offset, prompt in enumerate(prompts):
        # A CPU generator keeps the seed independent of the pipeline's device.
        generator = torch.Generator(device="cpu").manual_seed(base_seed + offset)
        result = pipe(
            prompt,
            num_inference_steps=50,
            guidance_scale=7.5,
            height=RESOLUTION,
            width=RESOLUTION,
            generator=generator,
        )
        images.append(result.images[0])
    return images


generated_samples = {}

print("Generating samples from baseline...")
baseline_samples = _generate_samples(baseline_pipeline, baseline_prompts)
generated_samples["baseline"] = baseline_samples
print(f"Baseline: {len(baseline_samples)} samples")

for style_name in STYLES:
    if style_name not in dreambooth_pipelines:
        continue
    model_key = f"dreambooth_{style_name.lower().replace('_', '')}"
    pipeline = dreambooth_pipelines[style_name]
    print(f"\nGenerating samples from {style_name}...")
    style_samples = _generate_samples(pipeline, test_prompts)
    generated_samples[model_key] = style_samples
    print(f"{style_name}: {len(style_samples)} samples")

print(f"\nTotal generated samples: {sum(len(v) for v in generated_samples.values())}")


## Compute Metrics


In [None]:
lpips_model = lpips.LPIPS(net="vgg").to(device)


def compute_metrics_for_model(
    model_name: str,
    generated_images: List[Image.Image],
    style_images_list: List[Image.Image],
    content_images_list: List[Image.Image],
) -> Dict[str, float]:
    """Compute FID, LPIPS and SSIM for one model's generated images.

    FID and LPIPS compare the generated images with the style references;
    SSIM compares them with the content references. Reference lists are
    truncated to the number of generated images, and LPIPS/SSIM are computed
    pairwise in list order.

    A metric whose reference set is missing is reported as NaN. Comparing the
    generated batch with itself instead would yield a perfect-looking score
    that says nothing about the model.

    Args:
        model_name: Label of the model; used only in warning messages.
        generated_images: Images produced by the model.
        style_images_list: Style reference images (may be empty).
        content_images_list: Content reference images (may be empty).

    Returns:
        Dict with the keys ``fid``, ``lpips``, ``ssim``, ``lpips_std`` and
        ``ssim_std``, all plain Python floats (NaN when not computable).
    """
    transform = StableDiffusionTransform(size=RESOLUTION, center_crop=True)
    nan = float("nan")

    def to_batch(images):
        # Preprocess every image and stack into one (N, C, H, W) batch.
        return torch.stack([transform(img).to(device) for img in images])

    def mean_and_std(scores):
        if not scores:
            return nan, nan
        return float(np.mean(scores)), float(np.std(scores))

    num_generated = len(generated_images)
    generated_tensors = to_batch(generated_images)

    style_refs = style_images_list[:num_generated]
    content_refs = content_images_list[:num_generated]

    fid_score = nan
    lpips_scores = []
    if style_refs:
        style_tensors = to_batch(style_refs)

        # A covariance estimate needs at least two images on each side.
        if num_generated >= 2 and len(style_refs) >= 2:
            fid_metric = FIDEvaluator(device=device)
            fid_metric.update_real(style_tensors)
            fid_metric.update_fake(generated_tensors)
            fid_score = fid_metric.compute()
        else:
            print(f"  [{model_name}] fewer than 2 images per set; FID reported as NaN")

        lpips_scores = [
            compute_lpips(
                gen.unsqueeze(0),
                style.unsqueeze(0),
                model=lpips_model,
            ).item()
            for gen, style in zip(generated_tensors, style_tensors)
        ]
    else:
        print(f"  [{model_name}] no style reference images; FID/LPIPS reported as NaN")

    ssim_scores = []
    if content_refs:
        content_tensors = to_batch(content_refs)
        ssim_scores = [
            compute_ssim(gen.unsqueeze(0), content.unsqueeze(0)).item()
            for gen, content in zip(generated_tensors, content_tensors)
        ]
    else:
        print(f"  [{model_name}] no content reference images; SSIM reported as NaN")

    lpips_mean, lpips_std = mean_and_std(lpips_scores)
    ssim_mean, ssim_std = mean_and_std(ssim_scores)

    return {
        "fid": fid_score,
        "lpips": lpips_mean,
        "ssim": ssim_mean,
        "lpips_std": lpips_std,
        "ssim_std": ssim_std,
    }

all_metrics = {}

# Map each generated-sample key back to the style it belongs to, using the
# same key scheme as the generation cell ("dreambooth_<style, lowercased,
# underscores removed>"). Each fine-tuned model must be scored against its
# own style, not against whichever style happens to be listed first.
style_by_model_key = {
    f"dreambooth_{style.lower().replace('_', '')}": style for style in STYLES
}

# The baseline is not tied to a single style, so it is scored against a mix
# of all available styles. The references are interleaved (style A, style B,
# style A, ...) because the metric function only keeps the first
# len(samples) references; interleaving keeps every style represented.
available_style_sets = [style_images[s] for s in STYLES if s in style_images]
baseline_style_ref = [
    img for group in zip(*available_style_sets) for img in group
]

for model_name, samples in generated_samples.items():
    print(f"\nComputing metrics for {model_name}...")

    if model_name in style_by_model_key:
        style_ref = style_images.get(style_by_model_key[model_name], [])
    else:
        style_ref = baseline_style_ref

    metrics = compute_metrics_for_model(
        model_name,
        samples,
        style_ref,
        content_images,
    )
    all_metrics[model_name] = metrics
    print(f"  FID: {metrics['fid']:.4f}")
    print(f"  LPIPS: {metrics['lpips']:.4f} ± {metrics['lpips_std']:.4f}")
    print(f"  SSIM: {metrics['ssim']:.4f} ± {metrics['ssim_std']:.4f}")


## Results and Comparison


In [None]:
# One row per model, one column per metric, with the model name as a column.
results_df = pd.DataFrame(all_metrics).T.rename_axis("model").reset_index()

rule = "="*60
print("\n" + rule)
print("EVALUATION RESULTS")
print(rule)
print(results_df.to_string(index=False))
print("\n" + rule)

# Persist the table as CSV and the raw metric dict as JSON.
results_path = RESULTS_DIR / "dreambooth_metrics.csv"
results_df.to_csv(results_path, index=False)
print(f"\nResults saved to: {results_path}")

json_path = RESULTS_DIR / "dreambooth_metrics.json"
with json_path.open("w", encoding="utf-8") as handle:
    json.dump(all_metrics, handle, indent=2)
print(f"Results saved to: {json_path}")

# Deltas against the baseline. FID and LPIPS are reported as
# (baseline - model), SSIM as (model - baseline). Missing baseline values
# default to 0.
print("\nComparison with Baseline:")
baseline_metrics = all_metrics.get("baseline", {})
baseline_fid = baseline_metrics.get("fid", 0)
baseline_lpips = baseline_metrics.get("lpips", 0)
baseline_ssim = baseline_metrics.get("ssim", 0)

for model_name, metrics in all_metrics.items():
    if model_name == "baseline":
        continue

    fid_improvement = baseline_fid - metrics["fid"]
    lpips_improvement = baseline_lpips - metrics["lpips"]
    ssim_change = metrics["ssim"] - baseline_ssim

    print(f"\n{model_name}:")
    print(f"  FID: {metrics['fid']:.4f} (vs baseline: {fid_improvement:+.4f})")
    print(f"  LPIPS: {metrics['lpips']:.4f} (vs baseline: {lpips_improvement:+.4f})")
    print(f"  SSIM: {metrics['ssim']:.4f} (vs baseline: {ssim_change:+.4f})")
