In [4]:
import os
import glob
import shutil
import csv
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from diffusers import ControlNetModel, StableDiffusionInpaintPipeline
from skimage.metrics import peak_signal_noise_ratio as compute_psnr
from skimage.metrics import structural_similarity as compute_ssim
import lpips  # pip install lpips

# ---------------------------
# Inpainting functions
# ---------------------------
def run_lama_inpaint(image_path, mask_path, output_path):
    from simple_lama_inpainting import SimpleLama  # import here in case it's not global
    simple_lama = SimpleLama()
    image = Image.open(image_path)
    mask = Image.open(mask_path).convert('L')
    result = simple_lama(image, mask)
    result.save(output_path)
    print(f"SimpleLama inpaint saved: {output_path}")

# ---------------------------
# ControlNet helper functions
# ---------------------------
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

def clean_huggingface_cache(model_path):
    """Remove unnecessary Hugging Face cache directories and .lock files."""
    for root, dirs, files in os.walk(model_path, topdown=False):
        for name in files:
            if name.endswith(".lock"):
                os.remove(os.path.join(root, name))
        for name in dirs:
            if name.startswith("models--") or name == "temp":
                shutil.rmtree(os.path.join(root, name), ignore_errors=True)

def get_latest_snapshot(model_path):
    """Find and move the correct snapshot folder for a downloaded model."""
    if os.path.exists(model_path):
        for subdir in os.listdir(model_path):
            snapshot_path = os.path.join(model_path, subdir, "snapshots")
            if os.path.exists(snapshot_path):
                snapshots = sorted(os.listdir(snapshot_path), reverse=True)
                if snapshots:
                    latest_snapshot = os.path.join(snapshot_path, snapshots[0])
                    for file_name in os.listdir(latest_snapshot):
                        src = os.path.join(latest_snapshot, file_name)
                        dest = os.path.join(model_path, file_name)
                        if not os.path.exists(dest):
                            shutil.move(src, dest)
                    shutil.rmtree(os.path.dirname(latest_snapshot), ignore_errors=True)
                    return model_path
    return model_path

def check_and_download_model(model_name, model_path, is_controlnet=False):
    """Check if the model exists; if not, download and move it to the correct directory."""
    if is_controlnet:
        model_path = os.path.join(model_path, "controlnet")
    else:
        model_path = os.path.join(model_path, "stable-diffusion")

    if os.path.exists(model_path) and os.listdir(model_path):
        print(f"{model_name} already exists. Skipping download.")
        return

    print(f"{model_name} not found. Downloading...")
    temp_dir = os.path.join("models", "temp")

    if is_controlnet:
        ControlNetModel.from_pretrained(model_name, cache_dir=temp_dir)
    else:
        StableDiffusionInpaintPipeline.from_pretrained(model_name, cache_dir=temp_dir)

    correct_model_path = get_latest_snapshot(temp_dir)
    os.makedirs(model_path, exist_ok=True)
    for file_name in os.listdir(correct_model_path):
        src = os.path.join(correct_model_path, file_name)
        dest = os.path.join(model_path, file_name)
        if not os.path.exists(dest):
            shutil.move(src, dest)
    shutil.rmtree(temp_dir, ignore_errors=True)
    print(f"{model_name} downloaded and saved in {model_path}")

def load_controlnet():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    models_dir = "models"
    os.makedirs(models_dir, exist_ok=True)
    controlnet_dir = os.path.join(models_dir, "controlnet")
    stable_diffusion_dir = os.path.join(models_dir, "stable-diffusion")
    os.makedirs(controlnet_dir, exist_ok=True)
    os.makedirs(stable_diffusion_dir, exist_ok=True)

    check_and_download_model("stabilityai/stable-diffusion-2-inpainting", models_dir, is_controlnet=False)
    check_and_download_model("lllyasviel/control_v11p_sd15_inpaint", models_dir, is_controlnet=True)

    clean_huggingface_cache(models_dir)

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        stable_diffusion_dir, torch_dtype=torch_dtype, local_files_only=True
    ).to(device, dtype=torch_dtype)
    # If needed, integrate the controlnet model into the pipeline.
    return pipe

def make_divisible_by_8(size):
    """Ensure both width and height are divisible by 8."""
    width, height = size
    width = (width // 8) * 8
    height = (height // 8) * 8
    return width, height

def run_controlnet_inpaint(image_path, mask_path, pipe, reference_images, prompt, output_path):
    # Open image and mask
    image = Image.open(image_path).convert("RGB")
    mask = Image.open(mask_path).convert("L")
    original_size = image.size
    adjusted_size = make_divisible_by_8(original_size)

    conditioning = None
    if reference_images:
        conditioning = [
            img.resize(adjusted_size, Image.Resampling.LANCZOS)
            for img in reference_images
        ]

    result = pipe(
        prompt=prompt,
        image=image.resize(adjusted_size, Image.Resampling.LANCZOS),
        mask_image=mask.resize(adjusted_size, Image.Resampling.LANCZOS),
        conditioning_image=conditioning,
        height=adjusted_size[1],
        width=adjusted_size[0]
    ).images[0]
    result = result.resize(original_size, Image.Resampling.LANCZOS)
    result.save(output_path)
    print(f"ControlNet inpaint saved: {output_path}")

# ---------------------------
# LPIPS model loading
# ---------------------------
def load_lpips_model(model_dir="models/lpips"):
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "lpips_alex.pth")
    model = lpips.LPIPS(net='alex')
    if os.path.exists(model_path):
        # Load the saved state dictionary
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        print(f"Loaded LPIPS model from {model_path}")
    else:
        # Save the model state dictionary to the file for future use
        torch.save(model.state_dict(), model_path)
        print(f"Saved LPIPS model to {model_path}")
    model.eval()
    if torch.cuda.is_available():
        model.cuda()
    return model

# Load LPIPS model (cached in models folder)
lpips_model = load_lpips_model()

# ---------------------------
# Evaluation functions
# ---------------------------
def prepare_for_lpips(pil_image):
    # Convert image to tensor in range [0,1] then normalize to [-1,1]
    tensor = transforms.ToTensor()(pil_image).unsqueeze(0)
    tensor = tensor * 2 - 1
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor

def evaluate_metrics(gt_img, inpaint_img):
    # Convert images to numpy arrays (normalized to [0,1])
    gt_np = np.array(gt_img).astype(np.float32) / 255.0
    inpaint_np = np.array(inpaint_img).astype(np.float32) / 255.0

    # Resize if needed
    if gt_np.shape != inpaint_np.shape:
        inpaint_img = inpaint_img.resize(gt_img.size, Image.Resampling.LANCZOS)
        inpaint_np = np.array(inpaint_img).astype(np.float32) / 255.0

    psnr = compute_psnr(gt_np, inpaint_np, data_range=1.0)
    ssim = compute_ssim(gt_np, inpaint_np, multichannel=True, data_range=1.0)

    gt_tensor = prepare_for_lpips(gt_img)
    inpaint_tensor = prepare_for_lpips(inpaint_img)
    with torch.no_grad():
        lpips_distance = lpips_model(gt_tensor, inpaint_tensor).item()

    return psnr, ssim, lpips_distance

# ---------------------------
# Main combined evaluation
# ---------------------------
if __name__ == "__main__":
    # Directories for input and output
    image_dir = "DUT-OMRON-image"   # Ground truth images (JPEG)
    mask_dir = "DUT-OMRON-mask"     # Masks (PNG)
    results_dir = "results"
    lama_dir = os.path.join(results_dir, "lama")
    controlnet_dir = os.path.join(results_dir, "controlnet")
    os.makedirs(lama_dir, exist_ok=True)
    os.makedirs(controlnet_dir, exist_ok=True)

    # Load ControlNet pipeline (and optionally reference images)
    pipe = load_controlnet()
    reference_images_dir = "images/reference_images"
    reference_images = []
    if os.path.exists(reference_images_dir):
        reference_images = [
            Image.open(img).convert("RGB")
            for img in glob.glob(os.path.join(reference_images_dir, "*.*"))
        ]

    # Define a prompt for ControlNet
    prompt = (
        "Restore missing areas by seamlessly extending the surroundings. "
        "Maintain consistency in color, texture, landmarks, and lighting."
    )

    # List to store evaluation results
    evaluation_results = []
    
    # Process each ground truth image (JPEG) in DUT-OMRON-image
    for filename in os.listdir(image_dir):
        if filename.lower().endswith((".jpg", ".jpeg")):
            gt_path = os.path.join(image_dir, filename)
            gt_image = Image.open(gt_path).convert("RGB")
            
            # Get corresponding mask using the same base name with .png extension
            base_name = os.path.splitext(filename)[0]
            mask_filename = base_name + ".png"
            mask_path = os.path.join(mask_dir, mask_filename)
            if not os.path.exists(mask_path):
                print(f"Mask for {filename} not found as {mask_filename}. Skipping.")
                continue

            # Define output paths for both inpainting methods
            out_lama = os.path.join(lama_dir, filename)
            out_controlnet = os.path.join(controlnet_dir, filename)

            # Run SimpleLama inpainting
            try:
                run_lama_inpaint(gt_path, mask_path, out_lama)
            except Exception as e:
                print(f"Error in SimpleLama for {filename}: {e}")
                continue

            # Run ControlNet inpainting
            try:
                run_controlnet_inpaint(gt_path, mask_path, pipe, reference_images, prompt, out_controlnet)
            except Exception as e:
                print(f"Error in ControlNet for {filename}: {e}")
                continue

            # Load the inpainted results for evaluation
            lama_result = Image.open(out_lama).convert("RGB")
            controlnet_result = Image.open(out_controlnet).convert("RGB")
            
            # Evaluate metrics against the ground truth
            lama_psnr, lama_ssim, lama_lpips = evaluate_metrics(gt_image, lama_result)
            controlnet_psnr, controlnet_ssim, controlnet_lpips = evaluate_metrics(gt_image, controlnet_result)
            
            print(f"Evaluated {filename}:")
            print(f"  SimpleLama -> PSNR: {lama_psnr:.2f}, SSIM: {lama_ssim:.4f}, LPIPS: {lama_lpips:.4f}")
            print(f"  ControlNet -> PSNR: {controlnet_psnr:.2f}, SSIM: {controlnet_ssim:.4f}, LPIPS: {controlnet_lpips:.4f}")
            
            # Store results for CSV
            evaluation_results.append({
                'filename': filename,
                'lama_PSNR': lama_psnr,
                'lama_SSIM': lama_ssim,
                'lama_LPIPS': lama_lpips,
                'controlnet_PSNR': controlnet_psnr,
                'controlnet_SSIM': controlnet_ssim,
                'controlnet_LPIPS': controlnet_lpips
            })

    # Write the evaluation metrics to a CSV file
    csv_file_path = "evaluation_results.csv"
    csv_fields = ['filename', 
                  'lama_PSNR', 'lama_SSIM', 'lama_LPIPS', 
                  'controlnet_PSNR', 'controlnet_SSIM', 'controlnet_LPIPS']
    with open(csv_file_path, mode='w', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=csv_fields)
        writer.writeheader()
        for row in evaluation_results:
            writer.writerow(row)

    print(f"\nEvaluation complete. Results saved to {csv_file_path}")


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: C:\Users\ng_mi\AppData\Roaming\Python\Python312\site-packages\lpips\weights\v0.1\alex.pth
Loaded LPIPS model from models/lpips\lpips_alex.pth
stabilityai/stable-diffusion-2-inpainting already exists. Skipping download.
lllyasviel/control_v11p_sd15_inpaint already exists. Skipping download.


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

SimpleLama inpaint saved: results\lama\im005.jpg


  0%|          | 0/50 [00:00<?, ?it/s]

ControlNet inpaint saved: results\controlnet\im005.jpg


ValueError: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If your images are multichannel (with color channels), set channel_axis to the axis number corresponding to the channels.