In [7]:
import torch
import numpy as np
import math
# from diffusers import StableDiffusionInpaintingPipeline
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
import sys
sys.path.insert(0, "Image_Inpainting")
from data import InpaintingDataset, find_images_recursively
from evaluate_models import evaluate_model, calculate_psnr
from evaluate_diffusion import save_combined_images, save_grid
from utils import make_mask
from diffusers import StableDiffusionInpaintPipeline, DiffusionPipeline

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Load the Stable Diffusion 2 inpainting checkpoint onto `device`.
# Half precision is only requested on GPU: float16 inference on CPU is poorly
# supported in PyTorch (several kernels have no Half implementation), so when
# `device` fell back to CPU we keep the weights in float32.
pipe_dtype = torch.float16 if device.type == "cuda" else torch.float32
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=pipe_dtype,
).to(device)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [3]:
def losses(image1, image2):
    """Return the mean L1 and mean squared (L2) error between two tensors.

    Both losses use PyTorch's default ``reduction='mean'``, so they are averaged
    over every element. Both are symmetric, so argument order does not change
    the values.

    Args:
        image1: Reference tensor (in this notebook, the ground-truth pixels).
        image2: Prediction tensor; expected to have the same shape as ``image1``.

    Returns:
        Tuple ``(l1, l2)`` of 0-dim tensors: mean absolute error and mean
        squared error.
    """
    l1 = nn.functional.l1_loss(image2, image1)
    l2 = nn.functional.mse_loss(image2, image1)
    return l1, l2

In [8]:
# Dataset root to inpaint. Alternative used previously: "data/cityscapes/val/img".
CELEBA_DIR = "data/celeba_filtered"
img_paths = find_images_recursively(CELEBA_DIR)

def run_inpainting(paths, dataset, save_dir="celeba_filtered_outputs", prompt="", stop_at=3,
                   pipeline=None, seed=None, image_size=128, mask_size=64):
    """Inpaint the first ``stop_at`` images of ``paths`` and print reconstruction metrics.

    For each image, the function:
    - loads it as RGB and resizes it to ``image_size`` x ``image_size``;
    - builds a square mask with ``make_mask`` and runs the inpainting pipeline;
    - prints PSNR / L1 / L2 between the original and the inpainted pixels inside the mask;
    - writes the inpainted, masked, original and mask images to ``save_dir``.

    Finally it calls ``save_combined_images`` and ``save_grid`` to assemble the summary figure.

    Args:
        paths: Iterable of image file paths.
        dataset: Dataset tag forwarded to ``save_combined_images`` / ``save_grid``
            (e.g. ``'celeba'``).
        save_dir: Output directory; created if it does not exist.
        prompt: Text prompt passed to the pipeline.
        stop_at: Maximum number of images to process. Must be >= 0.
        pipeline: Inpainting pipeline to call; defaults to the module-level ``pipe``.
        seed: Optional RNG seed for reproducible sampling. ``None`` keeps the
            pipeline's default (unseeded) behaviour.
        image_size: Side length the input images are resized to; also passed as the
            first argument of ``make_mask``.
        mask_size: Second argument of ``make_mask`` (presumably the side length of
            the square hole — TODO confirm against ``utils.make_mask``).

    Returns:
        Whatever ``save_grid`` returns (a file path in the recorded run).

    Raises:
        ValueError: if ``stop_at`` is negative.
    """
    from itertools import islice

    if stop_at < 0:
        raise ValueError(f"stop_at must be >= 0, got {stop_at}")
    if pipeline is None:
        pipeline = pipe  # module-level pipeline loaded in an earlier cell

    os.makedirs(save_dir, exist_ok=True)
    to_tensor = transforms.ToTensor()

    # Only pass a generator when a seed was requested, so custom pipelines without
    # a `generator` kwarg keep working.
    call_kwargs = {}
    if seed is not None:
        call_kwargs["generator"] = torch.Generator(device="cpu").manual_seed(seed)

    # islice caps the work at stop_at images; a manual `i == stop_at - 1` break
    # never fires for stop_at <= 0 and would silently process every path.
    for i, path in enumerate(islice(paths, stop_at)):
        # The context manager releases the file handle; convert("RGB") guarantees
        # 3 channels so grayscale/RGBA/palette files don't break the tensor math.
        with Image.open(path) as src:
            img = src.convert("RGB").resize((image_size, image_size))
        mask = make_mask(image_size, mask_size, mask_type='square')
        result = pipeline(prompt=prompt, image=img, mask_image=mask, **call_kwargs).images[0]

        result.save(os.path.join(save_dir, f"inpainted_{i:04d}.png"))
        # The pipeline may return a different resolution; align before comparing.
        result = result.resize(img.size)

        original = to_tensor(img)
        masked_image = original * (1 - mask)
        pred_mask = to_tensor(result) * mask
        og_mask = original * mask

        # NOTE(review): everything outside the mask is zero in both tensors. If
        # calculate_psnr / the losses average over all pixels, the hole error is
        # diluted by the unmasked area (PSNR inflated vs. a hole-only metric).
        # Kept as-is for comparability with earlier runs.
        print("PSNR:", calculate_psnr(og_mask, pred_mask))
        l1, l2 = losses(og_mask, pred_mask)
        print("L1 loss:", l1)
        print("L2 loss:", l2)

        save_image(masked_image, os.path.join(save_dir, f"masked_{i:04d}.png"))
        img.save(os.path.join(save_dir, f"original_{i:04d}.png"))
        save_image(mask, os.path.join(save_dir, f"mask_{i:04d}.png"))

        print(f"[{i}] Inpainted image saved.")

    save_combined_images(save_dir, dataset, stop_at)
    grid_path = save_grid(dataset, stop_at)

    return grid_path

# Inpaint the first three CelebA images, then report where the comparison grid was written.
grd_path = run_inpainting(img_paths, dataset='celeba', stop_at=3)
print(grd_path)

  0%|          | 0/50 [00:00<?, ?it/s]

PSNR: 18.33785629272461
L1 loss: tensor(0.0403)
L2 loss: tensor(0.0147)
[0] Inpainted image saved.


  0%|          | 0/50 [00:00<?, ?it/s]

PSNR: 26.537643432617188
L1 loss: tensor(0.0152)
L2 loss: tensor(0.0022)
[1] Inpainted image saved.


  0%|          | 0/50 [00:00<?, ?it/s]

PSNR: 15.030433654785156
L1 loss: tensor(0.0580)
L2 loss: tensor(0.0314)
[2] Inpainted image saved.
gen_outputs/diffusion_celeba_grid.png


In [5]:
# def run_inpainting_old(val_loader, save_dir="celeba_filtered_outputs", prompt="", stop_at=5):
#     os.makedirs(save_dir, exist_ok=True)
#     cnt = 0
    
#     for i, (masked_image, img, mask) in enumerate(val_loader):
#         cnt += 1
#         # print(masked_image.shape, img.shape, mask.shape)
#         # result = pipe(prompt=prompt, image=masked_image, mask_image=mask).images[0]
#         result = pipe(prompt=prompt, image=img, mask_image=mask).images[0]

#         result.save(os.path.join(save_dir, f"inpainted_{i+5:04d}.png"))
#         save_image(masked_image[0], os.path.join(save_dir, f"masked_{i+5:04d}.png"))
#         save_image(img[0], os.path.join(save_dir, f"original_{i+5:04d}.png"))
#         save_image(mask[0], os.path.join(save_dir, f"mask_{i+5:04d}.png"))
        
#         # img.save(os.path.join(save_dir, f"original_{i:04d}.png"))

#         if i % 10 == 0:
#             print(f"[{i}] Inpainted image saved.")

#         if cnt == 5:
#             break