# Evaluation Metrics - Customized Image Generation

Tính các chỉ số đánh giá cho mô hình style transfer: content loss, style loss, LPIPS, SSIM và FID.


## Import và cấu hình


In [2]:
from pathlib import Path

import lpips
import torch
from PIL import Image

from src.utils import eval_utils
from src.utils.data_utils import StableDiffusionTransform


ModuleNotFoundError: No module named 'src'

## Chuẩn hoá và tạo dữ liệu mẫu


In [None]:
# Build a small sample batch for the metric smoke test: one content image,
# one style image, and a synthetic "generated" image derived from both.
project_root = Path("..").resolve()
image_dir = project_root / "results" / "eda"

content_path = image_dir / "coco_samples.png"
style_path = image_dir / "wikiart_samples.png"

transform = StableDiffusionTransform(size=512, center_crop=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_image_tensor(path: Path) -> torch.Tensor:
    """Load ``path`` as RGB, apply ``transform`` and return it batched on ``device``.

    The image is opened in a ``with`` block so the underlying file handle is
    released as soon as the pixels have been converted, rather than relying
    on Pillow's lazy cleanup.
    """
    with Image.open(path) as image:
        tensor = transform(image.convert("RGB"))
    # Add a leading batch dimension of size 1, then move to the compute device.
    return tensor.unsqueeze(0).to(device)


content_tensor = _load_image_tensor(content_path)
style_tensor = _load_image_tensor(style_path)

# Placeholder "generated" image: a fixed 70/30 blend of content and style,
# used only so the metrics below have something to score.
# NOTE(review): the clamp to [-1, 1] assumes StableDiffusionTransform
# normalises pixels to that range — confirm.
generated_tensor = torch.clamp(0.7 * content_tensor + 0.3 * style_tensor, -1.0, 1.0)


## Tính toán các chỉ số


In [None]:
# VGG layers exposed by the feature extractor. relu4_2 is used for the content
# loss below; compute_style_loss is called without explicit layers, so it
# presumably uses the relu*_1 layers by default — confirm in eval_utils.
layers = (
    "relu4_2",
    "relu1_1",
    "relu2_1",
    "relu3_1",
    "relu4_1",
    "relu5_1",
)
extractor = eval_utils.VGGFeatureExtractor(layers=layers, device=device)
lpips_model = lpips.LPIPS(net="vgg").to(device)
fid_metric = eval_utils.FIDEvaluator(device=device)

# These are evaluation-only metrics: disable autograd so no computation graph
# is recorded through the feature networks (saves memory and time).
with torch.no_grad():
    content_loss = eval_utils.compute_content_loss(
        generated_tensor,
        content_tensor,
        extractor=extractor,
        layers=("relu4_2",),
    )
    style_loss = eval_utils.compute_style_loss(
        generated_tensor,
        style_tensor,
        extractor=extractor,
    )
    # NOTE(review): LPIPS is measured against the *style* image here, while
    # SSIM and content loss use the *content* image — confirm the pairing.
    lpips_score = eval_utils.compute_lpips(
        generated_tensor,
        style_tensor,
        model=lpips_model,
    )
    # SSIM is converted to a Python float here; the other scores stay as
    # returned and are converted when printed.
    ssim_score = eval_utils.compute_ssim(
        generated_tensor,
        content_tensor,
    ).item()
    # NOTE(review): FID compares feature statistics of two image sets; with a
    # single real and a single fake image the covariance estimate is
    # degenerate, so treat this value as a smoke test only — confirm how
    # FIDEvaluator handles N=1.
    fid_metric.update_real(content_tensor)
    fid_metric.update_fake(generated_tensor)
    fid_score = fid_metric.compute()


## Kết quả


In [None]:
# Report each metric on its own line with four decimal places. The loss and
# LPIPS results are converted with .item() here; SSIM is already a float and
# FID is formatted as returned by the evaluator.
summary = (
    ("Content loss", content_loss.item()),
    ("Style loss", style_loss.item()),
    ("LPIPS", lpips_score.item()),
    ("SSIM", ssim_score),
    ("FID", fid_score),
)
for metric_name, metric_value in summary:
    print(f"{metric_name}: {metric_value:.4f}")
