# Evaluation Metrics - LoRA vs Baseline (FINAL)

Đánh giá LoRA (img2img) bằng bộ metrics CLIP cuối cùng để đo content preservation và style quality.

**Metrics (FINAL):**
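1. **CLIP-content (chính)**: `1 - cos_sim(clip(output), clip(content))` – khoảng cách cosine giữa output và ảnh content, dùng để đo semantic content preservation (càng thấp thì nội dung càng được giữ nguyên).
2. **Style Strength Score (phụ)**: `CLIP-content / baseline_CLIP-content` – đo mức độ thay đổi style so với baseline (baseline = 1.0; lớn hơn 1 nghĩa là output lệch khỏi content nhiều hơn baseline).
3. **CLIP-style (chính)**: `1 - cos_sim(clip(output), clip(style_reference))` – khoảng cách cosine tới style reference, đo mức độ giống style reference (càng thấp thì càng giống).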

**Models đánh giá:**
- Baseline: SD v1.5 gốc (img2img)
- LoRA: Action_painting, Analytical_Cubism, Contemporary_Realism, New_Realism, Synthetic_Cubism (5 styles)

Notebook này chỉ chạy LoRA để tránh OOM. Chạy `04b_Evaluation_Metrics_DreamBooth_TI_FINAL.ipynb` (sẽ tạo sau) để đánh giá DreamBooth và TI.


## Setup


In [None]:
%pip install -q lpips torchmetrics[image] torch-fidelity protobuf==3.20.3 ftfy regex git+https://github.com/openai/CLIP.git


In [None]:
import os
import warnings
# Silence FutureWarning and the cuFFT / cuDNN / cuBLAS messages. The filters
# are registered in the same order as before.
warnings.filterwarnings("ignore", category=FutureWarning)
for _noisy_pattern in (".*cuFFT.*", ".*cuDNN.*", ".*cuBLAS.*"):
    warnings.filterwarnings("ignore", message=_noisy_pattern)

import torch

# Report the GPU this session sees, or warn when none is available.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("WARNING: No GPU detected!")


In [None]:
import json
import random
from pathlib import Path
from typing import Dict, List, Optional
import pandas as pd
import numpy as np
from tqdm import tqdm

from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
from torchvision import transforms as T
import clip

# Styles evaluated by this notebook (one LoRA per style).
LORA_STYLES = [
    "Action_painting",
    "Analytical_Cubism",
    "Contemporary_Realism",
    "New_Realism",
    "Synthetic_Cubism",
]

RESOLUTION = 256          # side length (px) that content/style images are resized to
NUM_SAMPLES = 20          # NOTE(review): not referenced by the cells visible here
IMG2IMG_STRENGTH = 0.5    # passed as `strength` to the img2img pipeline
IMG2IMG_GUIDANCE = 7.5    # passed as `guidance_scale` to the img2img pipeline
MAX_CONTENT_SAMPLES = 8   # number of content images requested
NUM_STYLE_IMAGES = 10     # number of reference images requested per style
SEED = 42                 # base seed for every RNG seeded below

# Seed the Python, NumPy and torch RNGs (and all CUDA devices when present).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

class StableDiffusionTransform:
    """Callable image transform: resize, crop, convert to tensor, normalise.

    Args:
        size: Value handed to both the resize and the crop step.
        center_crop: Use ``T.CenterCrop`` when true, ``T.RandomCrop`` otherwise.
    """

    def __init__(self, size=256, center_crop=True):
        crop_cls = T.CenterCrop if center_crop else T.RandomCrop
        steps = [
            T.Resize(size),
            crop_cls(size),
            T.ToTensor(),
            # Normalise with mean 0.5 and std 0.5.
            T.Normalize([0.5], [0.5]),
        ]
        self.transform = T.Compose(steps)

    def __call__(self, img):
        """Apply the composed transform to ``img`` and return the result."""
        return self.transform(img)

def load_style_images(style_name: str, num_images: int = NUM_STYLE_IMAGES) -> List[Image.Image]:
    """Load up to ``num_images`` reference images for one style.

    Looks in ``WIKIART_DIR / style_name`` for ``*.jpg`` and ``*.png`` files,
    takes the first ones in sorted order, converts each to RGB and resizes it
    to ``RESOLUTION x RESOLUTION`` with bilinear filtering.

    Returns:
        The loaded images, or an empty list (after printing a warning) when
        the directory is missing or holds no matching files.
    """
    folder = WIKIART_DIR / style_name
    if not folder.exists():
        print(f"Warning: Style directory not found: {folder}")
        return []

    candidates = sorted([*folder.glob("*.jpg"), *folder.glob("*.png")])
    if not candidates:
        print(f"Warning: No images found in {folder}")
        return []

    target_size = (RESOLUTION, RESOLUTION)
    loaded = []
    for image_path in candidates[:min(num_images, len(candidates))]:
        rgb = Image.open(image_path).convert("RGB")
        loaded.append(rgb.resize(target_size, Image.BILINEAR))
    return loaded

def compute_clip_distance(img_a, img_b):
    """Return the CLIP cosine distance ``1 - cos_sim`` between two images.

    Both images are passed through the module-level ``clip_preprocess`` and
    encoded by ``clip_model`` in a single batched forward pass on ``device``.

    Args:
        img_a: First image, in whatever form ``clip_preprocess`` accepts.
        img_b: Second image, same form.

    Returns:
        A Python float, never below 0.0.
    """
    with torch.no_grad():
        # Assumes clip_preprocess yields equally-shaped tensors so they can be
        # stacked into one batch -- TODO confirm if the preprocess is swapped.
        batch = torch.stack([clip_preprocess(img_a), clip_preprocess(img_b)]).to(device)

        # Upcast before normalising: the encoder output may be half precision
        # on GPU, which is too coarse for a metric reported to 4 decimals.
        feats = clip_model.encode_image(batch).float()
        feats = feats / feats.norm(dim=-1, keepdim=True)

        cos_sim = (feats[0] * feats[1]).sum(dim=-1).item()

    # Rounding can push cos_sim marginally above 1 for identical images;
    # clamp so the distance is never negative.
    return max(0.0, 1.0 - cos_sim)


In [None]:
# Pick the compute device once; half precision is used only when CUDA exists.
_HAS_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if _HAS_CUDA else "cpu")
print(f"Device: {device}")
TORCH_DTYPE = torch.float16 if _HAS_CUDA else torch.float32

def setup_pipeline(pipeline):
    """Prepare a pipeline for inference and return it.

    When CUDA is available the pipeline is moved to the module-level
    ``device``; attention slicing is enabled in every case.
    """
    prepared = pipeline.to(device) if torch.cuda.is_available() else pipeline
    prepared.enable_attention_slicing()
    return prepared

# Directory for metric outputs; created up front.
RESULTS_DIR = Path("/kaggle/working/results/metrics")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Path lists shipped with the input dataset ...
DATASET_CACHE_DIR = Path("/kaggle/input/styleandcontent")
DATASET_CONTENT_PATH = DATASET_CACHE_DIR.joinpath("content_paths.json")
DATASET_STYLE_PATH = DATASET_CACHE_DIR.joinpath("style_paths.json")

# ... and their counterparts inside the results directory.
LOCAL_CONTENT_CACHE = RESULTS_DIR.joinpath("content_paths.json")
LOCAL_STYLE_CACHE = RESULTS_DIR.joinpath("style_paths.json")

# Read from the dataset copy when it exists, otherwise from the local cache.
if DATASET_CONTENT_PATH.exists():
    CONTENT_PATHS_FILE = DATASET_CONTENT_PATH
else:
    CONTENT_PATHS_FILE = LOCAL_CONTENT_CACHE

if DATASET_STYLE_PATH.exists():
    STYLE_PATHS_FILE = DATASET_STYLE_PATH
else:
    STYLE_PATHS_FILE = LOCAL_STYLE_CACHE

# Writes always go to the local cache.
CONTENT_PATHS_WRITE_TARGET = LOCAL_CONTENT_CACHE
STYLE_PATHS_WRITE_TARGET = LOCAL_STYLE_CACHE

# Style name -> dataset directory holding that style's LoRA weights.
LORA_DATASET_MAP = {
    "Action_painting": "/kaggle/input/dts-lora-actionpainting",
    "Analytical_Cubism": "/kaggle/input/dts-lora-analyticalcubism",
    "Contemporary_Realism": "/kaggle/input/dts-lora-contemporaryrealism",
    "New_Realism": "/kaggle/input/dts-lora-newrealism",
    "Synthetic_Cubism": "/kaggle/input/dts-lora-syntheticcubism",
}

# Candidate content-image directories, tried in this order.
CONTENT_IMAGES_DIRS = [
    "/kaggle/input/coco-2017-dataset/coco2017/val2017",
    "/kaggle/input/coco2017/val2017",
]

# WikiArt root: the first location if present, else the alternative mount.
_wikiart_primary = Path("/kaggle/input/wikiart")
WIKIART_DIR = (
    _wikiart_primary
    if _wikiart_primary.exists()
    else Path("/kaggle/input/steubk/wikiart")
)


## Load Models


In [None]:
print("Loading CLIP model...")
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
clip_model.eval()
print("CLIP loaded")

# Arguments shared by every Stable Diffusion img2img pipeline created below.
_SD_MODEL_ID = "runwayml/stable-diffusion-v1-5"
_SD_LOAD_KWARGS = {
    "torch_dtype": TORCH_DTYPE,
    "safety_checker": None,
    "requires_safety_checker": False,
}

print("Loading baseline img2img model...")
baseline_pipeline = setup_pipeline(
    StableDiffusionImg2ImgPipeline.from_pretrained(_SD_MODEL_ID, **_SD_LOAD_KWARGS)
)
print("Baseline loaded")

# One pipeline per style whose LoRA weights could be found and loaded.
lora_pipelines = {}

for style_name in LORA_STYLES:
    if style_name not in LORA_DATASET_MAP:
        print(f"No LoRA dataset path mapped for {style_name}")
        continue

    lora_base = Path(LORA_DATASET_MAP[style_name])
    lora_weights_path = lora_base / "lora_models" / style_name / "pytorch_lora_weights.safetensors"

    if not lora_base.exists():
        print(f"LoRA dataset not found: {lora_base}")
        continue

    if not lora_weights_path.exists():
        print(f"LoRA weights not found: {lora_weights_path}")
        continue

    print(f"Loading LoRA {style_name} from {lora_base}...")
    try:
        candidate = StableDiffusionImg2ImgPipeline.from_pretrained(
            _SD_MODEL_ID, **_SD_LOAD_KWARGS
        )
        # The directory containing the weights file is what gets passed in.
        candidate.load_lora_weights(str(lora_weights_path.parent))
        lora_pipelines[style_name] = setup_pipeline(candidate)
        print(f"LoRA {style_name} loaded")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        # Best effort: a style that fails to load is reported and skipped.
        print(f"Error loading LoRA {style_name}: {e}")

# Registry of every loaded model, keyed "baseline" or "lora_<style, lowercased, no underscores>".
models = {"baseline": baseline_pipeline}
for _loaded_style in LORA_STYLES:
    if _loaded_style in lora_pipelines:
        models[f"lora_{_loaded_style.lower().replace('_', '')}"] = lora_pipelines[_loaded_style]
print(f"\nModels loaded: {list(models.keys())}")


In [None]:
def _load_or_create_content_paths(max_items: int) -> List[str]:
    """Return ``max_items`` content-image paths, reusing the stored list if usable.

    The list stored in ``CONTENT_PATHS_FILE`` (key ``"content_paths"``) is used
    when it can be read and still has at least ``max_items`` existing paths.
    Otherwise the first directory in ``CONTENT_IMAGES_DIRS`` that holds enough
    ``*.jpg`` / ``*.png`` files supplies the first ``max_items`` of them in
    sorted order, and that selection is written to
    ``CONTENT_PATHS_WRITE_TARGET``.

    Raises:
        RuntimeError: if no directory can supply ``max_items`` images.
    """
    cached: List[str] = []
    if CONTENT_PATHS_FILE.exists():
        try:
            with open(CONTENT_PATHS_FILE, "r", encoding="utf-8") as handle:
                payload = json.load(handle)
            cached = [
                entry
                for entry in payload.get("content_paths", [])
                if Path(entry).exists()
            ]
            if len(cached) < max_items:
                print("Stored content paths are incomplete. Regenerating.")
                cached = []
        except Exception as exc:
            print(f"Failed to read {CONTENT_PATHS_FILE}: {exc}. Regenerating content paths.")
            cached = []

    if cached:
        return cached[:max_items]

    print("Generating deterministic content paths...")
    fresh: List[str] = []
    for base_dir in CONTENT_IMAGES_DIRS:
        root = Path(base_dir)
        if not root.exists() or not root.is_dir():
            continue
        found = sorted([*root.glob("*.jpg"), *root.glob("*.png")])
        if len(found) < max_items:
            # Not enough images here; try the next candidate directory.
            continue
        fresh = [str(item.resolve()) for item in found[:max_items]]
        with open(CONTENT_PATHS_WRITE_TARGET, "w", encoding="utf-8") as handle:
            json.dump({"content_paths": fresh}, handle, indent=2)
        print(f"Saved {len(fresh)} content paths to {CONTENT_PATHS_WRITE_TARGET}")
        break

    if len(fresh) < max_items:
        raise RuntimeError("Unable to prepare deterministic content paths. Check CONTENT_IMAGES_DIRS.")
    return fresh[:max_items]

# --- Content images ---------------------------------------------------------
content_paths = _load_or_create_content_paths(MAX_CONTENT_SAMPLES)
content_images = [Image.open(Path(p)).convert("RGB") for p in content_paths]
transform = StableDiffusionTransform(size=RESOLUTION, center_crop=True)
_target_size = (RESOLUTION, RESOLUTION)
content_subset = [img.resize(_target_size, Image.BILINEAR) for img in content_images]

print(f"Using {len(content_subset)} content images from deterministic list.")

# --- Style images -----------------------------------------------------------
# Start from the stored per-style path lists when the file can be read.
style_paths_data: Dict[str, List[str]] = {}
if STYLE_PATHS_FILE.exists():
    try:
        with open(STYLE_PATHS_FILE, "r", encoding="utf-8") as handle:
            stored_styles = json.load(handle)
        style_paths_data = stored_styles.get("styles", {})
    except Exception as exc:
        print(f"Failed to read {STYLE_PATHS_FILE}: {exc}. Regenerating style paths.")
        style_paths_data = {}

style_paths_updated = False
style_images_map: Dict[str, List[Image.Image]] = {}
print("\nLoading style images from WikiArt (deterministic)...")
for style_name in LORA_STYLES:
    stored_paths = [p for p in style_paths_data.get(style_name, []) if Path(p).exists()]
    if len(stored_paths) < NUM_STYLE_IMAGES:
        # Stored list is missing or too short: rescan the style directory.
        style_dir = WIKIART_DIR / style_name
        if not style_dir.exists():
            print(f"  {style_name}: directory not found {style_dir}")
            continue
        files = sorted([*style_dir.glob("*.jpg"), *style_dir.glob("*.png")])
        if not files:
            print(f"  {style_name}: no images found in {style_dir}")
            continue
        selected_paths = [str(p.resolve()) for p in files[:NUM_STYLE_IMAGES]]
        style_paths_data[style_name] = selected_paths
        style_paths_updated = True
    else:
        selected_paths = stored_paths[:NUM_STYLE_IMAGES]

    style_images = []
    for p in selected_paths:
        rgb = Image.open(Path(p)).convert("RGB")
        style_images.append(rgb.resize(_target_size, Image.BILINEAR))
    style_images_map[style_name] = style_images
    print(f"  {style_name}: {len(style_images)} images")

# Persist the selection only when at least one style had to be rescanned.
if style_paths_updated:
    with open(STYLE_PATHS_WRITE_TARGET, "w", encoding="utf-8") as handle:
        json.dump({"styles": style_paths_data}, handle, indent=2)

if not style_images_map:
    print("Warning: No style images loaded. CLIP-style metrics will be skipped.")


## Generate Samples


In [None]:
def run_img2img(pipeline, prompt, inputs):
    """Run ``pipeline`` once per input image and return the generated images.

    Every call uses the same ``prompt``, the module-level strength and
    guidance settings, 50 inference steps, and a fresh generator on ``device``
    seeded with ``SEED + index`` of the input.

    Returns:
        A list with the first output image of each call, in input order.
    """

    def _render(position, source_image):
        rng = torch.Generator(device=device).manual_seed(SEED + position)
        outcome = pipeline(
            prompt=prompt,
            image=source_image,
            strength=IMG2IMG_STRENGTH,
            guidance_scale=IMG2IMG_GUIDANCE,
            num_inference_steps=50,
            generator=rng,
        )
        return outcome.images[0]

    return [_render(position, image) for position, image in enumerate(inputs)]

# Prompt per model key: a fixed one for the baseline, one derived from the
# style name for each LoRA.
style_prompts = {"baseline": "a realistic depiction of the same scene"}
style_prompts.update(
    {
        f"lora_{name.lower().replace('_', '')}": f"a {name.replace('_', ' ').lower()} painting of the scene"
        for name in LORA_STYLES
    }
)

generated_samples = {}

print("Generating baseline img2img samples...")
baseline_samples = run_img2img(baseline_pipeline, style_prompts["baseline"], content_subset)
generated_samples["baseline"] = baseline_samples
print(f"Baseline: {len(baseline_samples)} samples")

if torch.cuda.is_available():
    torch.cuda.empty_cache()

for style_name in LORA_STYLES:
    model_key = f"lora_{style_name.lower().replace('_', '')}"
    if style_name not in lora_pipelines:
        # No pipeline was loaded for this style; nothing to generate.
        continue
    pipeline = lora_pipelines[style_name]
    print(f"\nGenerating samples from LoRA {style_name}...")
    style_samples = run_img2img(pipeline, style_prompts[model_key], content_subset)
    generated_samples[model_key] = style_samples
    print(f"LoRA {style_name}: {len(style_samples)} samples")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

_total_generated = 0
for _batch in generated_samples.values():
    _total_generated += len(_batch)
print(f"\nTotal generated samples: {_total_generated}")


## Compute Metrics


In [None]:
def compute_metrics_for_model(
    model_name: str,
    generated_images: "List[Image.Image]",
    content_images_list: "List[Image.Image]",
    style_images_list: "Optional[List[Image.Image]]" = None,
    distance_fn=None,
) -> Dict[str, float]:
    """Compute CLIP-content and, when references are given, CLIP-style.

    CLIP-content pairs each generated image with the content image at the same
    position (extra items on either side are ignored, as ``zip`` does).

    CLIP-style scores every generated image against *all* style references and
    averages those distances per generated image. Pairing image ``i`` with
    reference ``i`` would judge each output against a single arbitrary
    reference and silently drop outputs when fewer references than outputs
    are available.

    Args:
        model_name: Label of the model being scored; not used in the maths.
        generated_images: Images produced by the model.
        content_images_list: Source images, aligned by index with
            ``generated_images``.
        style_images_list: Style reference images. ``None`` or empty skips the
            style metric.
        distance_fn: Callable ``(img_a, img_b) -> float``. Defaults to the
            module-level ``compute_clip_distance``.

    Returns:
        Dict with ``clip_content`` and ``clip_content_std`` and, when style
        references and generated images are both present, ``clip_style`` and
        ``clip_style_std``. Values are plain Python floats; the content
        entries are NaN when there is nothing to score.
    """
    if distance_fn is None:
        # Resolved at call time so the default follows the notebook's helper.
        distance_fn = compute_clip_distance

    def _mean_and_std(values):
        # Avoid numpy's empty-slice RuntimeWarning; NaN marks "no data".
        if not values:
            return float("nan"), float("nan")
        return float(np.mean(values)), float(np.std(values))

    content_scores = [
        distance_fn(generated, content)
        for generated, content in zip(generated_images, content_images_list)
    ]
    content_mean, content_std = _mean_and_std(content_scores)
    metrics = {
        "clip_content": content_mean,
        "clip_content_std": content_std,
    }

    if style_images_list:
        # One score per generated image: its mean distance to every reference.
        style_scores = [
            float(np.mean([distance_fn(generated, ref) for ref in style_images_list]))
            for generated in generated_images
        ]
        if style_scores:
            metrics["clip_style"], metrics["clip_style_std"] = _mean_and_std(style_scores)

    return metrics

# Metrics per model key, filled in the loop below.
all_metrics = {}

for model_name, samples in generated_samples.items():
    print(f"\nComputing metrics for {model_name}...")

    # The baseline has no style references; a LoRA model uses those of the
    # first style whose compact lowercase name occurs in the model key.
    style_images_for_model = None
    if model_name != "baseline":
        lowered_key = model_name.lower()
        style_name = next(
            (s for s in LORA_STYLES if s.lower().replace('_', '') in lowered_key),
            None,
        )
        if style_name and style_name in style_images_map:
            style_images_for_model = style_images_map[style_name]

    metrics = compute_metrics_for_model(
        model_name, samples, content_subset, style_images_for_model
    )
    all_metrics[model_name] = metrics

    print(f"  CLIP-content: {metrics['clip_content']:.4f} ± {metrics['clip_content_std']:.4f}")
    if 'clip_style' in metrics:
        print(f"  CLIP-style: {metrics['clip_style']:.4f} ± {metrics['clip_style_std']:.4f}")

In [None]:
# Style Strength = a model's CLIP-content divided by the baseline's.
# Without a positive baseline value the ratio is undefined and stored as NaN.
baseline_clip_value = all_metrics.get("baseline", {}).get("clip_content", None)
_baseline_usable = baseline_clip_value is not None and baseline_clip_value > 0

for metrics in all_metrics.values():
    if _baseline_usable:
        metrics["style_strength"] = metrics["clip_content"] / baseline_clip_value
    else:
        metrics["style_strength"] = np.nan

# The baseline is pinned to exactly 1.0 whenever it was evaluated.
if "baseline" in all_metrics:
    all_metrics["baseline"]["style_strength"] = 1.0


In [None]:
# One row per model, with the model key as a regular "model" column.
results_df = pd.DataFrame(all_metrics).T.rename_axis("model").reset_index()

_banner = "=" * 60
print("\n" + _banner)
print("EVALUATION RESULTS - LoRA (FINAL)")
print(_banner)
print(results_df.to_string(index=False))
print("\n" + _banner)

# Save the table as CSV and the raw metric dict as JSON.
results_path = RESULTS_DIR / "lora_metrics_final.csv"
results_df.to_csv(results_path, index=False)
print(f"\nResults saved to: {results_path}")

json_path = RESULTS_DIR / "lora_metrics_final.json"
with open(json_path, "w", encoding="utf-8") as handle:
    json.dump(all_metrics, handle, indent=2)
print(f"Results saved to: {json_path}")

print("\nComparison with Baseline:")
# Falls back to 0 when no baseline entry exists, so deltas are then absolute.
baseline_clip = all_metrics.get("baseline", {}).get("clip_content", 0)
for model_name, metrics in all_metrics.items():
    if model_name == "baseline":
        continue
    clip_delta = metrics["clip_content"] - baseline_clip
    print(f"\n{model_name}:")
    print(f"  CLIP-content: {metrics['clip_content']:.4f} (Δ vs baseline: {clip_delta:+.4f})")
    if 'clip_style' in metrics:
        print(f"  CLIP-style: {metrics['clip_style']:.4f}")
    print(f"  Style Strength: {metrics['style_strength']:.4f}")
