# Environment

In [None]:
import os

import cv2 as cv
import matplotlib.pylab as plt
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

In [None]:
# Stable Diffusion inpainting pipeline, loaded with the default (float32) weights.
# float16 weights (revision="fp16", torch_dtype=torch.float16) are GPU-only; on CPU they fail with
# "RuntimeError: LayerNormKernelImpl not implemented":
# https://stackoverflow.com/questions/75641074/i-run-stable-diffusion-its-wrong-runtimeerror-layernormkernelimpl-not-implem
# Choose the device at runtime so the notebook also runs without CUDA
# (on a CUDA machine this is identical to `.to("cuda")`).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
).to(DEVICE)

In [None]:
def show_result(
    patch: np.array, 
    background: np.array, 
    output: np.array,
    zoom: tuple[tuple[int, int]],
    method_name: str
):
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    axes = axes.flatten()
    axes[0].imshow(patch)
    axes[0].set_title("Patch")
    axes[1].imshow(background)
    axes[1].set_title("Background")
    axes[2].imshow(output)
    axes[2].set_title("Merged images")
    axes[3].imshow(output[
        zoom[1][0]:zoom[1][1],
        zoom[0][0]:zoom[0][1],
        :
    ])
    axes[3].set_title("Zoomed")
    # axes[3].set_xlim(zoom[0])
    # axes[3].set_ylim(zoom[1])
    plt.suptitle(method_name)

# Analysis

In [None]:
# Dataset root; a relative path, resolved from the kernel's working directory.
DATA_DIR = "../../data/postprocessed"

# Border width, in pixels.
# NOTE(review): not referenced by the cells shown here -- the mask cell sets its own `border = 10`.
BORDER = 20

In [None]:
# Input images from the training split: one patch (filename tagged "rumex")
# and two candidate backgrounds.
train_dir = os.path.join(DATA_DIR, "train")

patch_path = os.path.join(
    train_dir, "patches",
    "20220816_TaenikonWiese_S_xx_F_xx_O_sama_ID2_DJI_20220816121514_0132.0_2_rumex.png",
)
easy_background_path = os.path.join(
    train_dir, "images",
    "20230615_SchildDotnachtZaelgli_S_20_F_60_H_12_O_krma_ID1_DJI_20230615145252_0193.1_3.png",
)
difficult_background_path = os.path.join(
    train_dir, "images",
    "20230609_HerrenpuentSuedost_S_20_F_60_H_12_O_krma_ID1_DJI_20230609151113_0028.1_3.png",
)

In [None]:
# Background used by the rest of the notebook; switch to `easy_background_path` to try the "easy" one.
background_path = difficult_background_path

In [None]:
def _load_rgb(path: str) -> np.ndarray:
    """Read an image with OpenCV and return it in RGB channel order.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path``. ``cv.imread`` returns ``None``
            instead of raising, which would otherwise surface as an obscure ``cvtColor`` error.
    """
    image_bgr = cv.imread(path)
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV decodes to BGR; matplotlib and PIL expect RGB.
    return cv.cvtColor(image_bgr, cv.COLOR_BGR2RGB)


patch = _load_rgb(patch_path)
background = _load_rgb(background_path)
# Background must be a square for some of the methods we are trying
background = background[:512, :512, :]
# The crop is only square if the source image is at least 512x512.
assert background.shape[0] == background.shape[1], (
    f"Background crop is not square: {background.shape[:2]}"
)

In [None]:
# Simulate a mask: place the patch with its top-left corner at (row, col) = (x, y)
# and mark that rectangle with 1 in `bounding_box_mask`.
# Convention used throughout: `x` indexes rows (axis 0), `y` indexes columns (axis 1).
p_x, p_y, _ = patch.shape
x, y = 50, 50
x_min, x_max = x, x + p_x
y_min, y_max = y, y + p_y
# NumPy slicing silently truncates at the image edge, so fail loudly if the patch does not fit.
assert x_max <= background.shape[0] and y_max <= background.shape[1], (
    f"Patch of size {patch.shape[:2]} at ({x}, {y}) does not fit in background {background.shape[:2]}"
)
bg_patch = background[x_min:x_max, y_min:y_max, :]
bounding_box_mask = np.zeros(background.shape[:-1])
bounding_box_mask[x_min:x_max, y_min:y_max] = 1

In [None]:
def get_inpainting_mask(background, patch, bounding_box_mask, border: int = 10):
    """Build a mask covering a band along the edges of the pasted patch.

    The band is ``2 * border`` pixels wide on every side: ``border`` pixels outside the
    patch plus ``border`` pixels inside it, i.e. it straddles the seam.

    Args:
        background: Background image; only ``background.shape[:2]`` (rows, cols) is used.
        patch: Patch image; only ``patch.shape[:2]`` is used.
        bounding_box_mask: 2-D array equal to 1 where the patch is placed. The patch's
            top-left corner is the smallest row / smallest column equal to 1.
        border: Half-width of the band, in pixels. ``0`` gives an empty mask.

    Returns:
        ``uint8`` array of shape ``(rows, cols, 3)``: 255 inside the band, 0 elsewhere.
        Any part of the band that falls outside the background is clipped.

    Raises:
        ValueError: If ``bounding_box_mask`` has no pixel equal to 1.
    """
    p_x, p_y = patch.shape[:2]
    bg_x, bg_y = background.shape[:2]

    rows, cols = np.nonzero(bounding_box_mask == 1)
    if rows.size == 0:
        raise ValueError("bounding_box_mask has no pixel equal to 1")
    # Top left corner of bounding box
    x, y = int(rows.min()), int(cols.min())

    # Band in coordinates local to the box enlarged by `border` on every side.
    height, width = p_x + 2 * border, p_y + 2 * border
    band = 2 * border
    local = np.zeros((height, width))
    # Explicit start indices: `local[-0:]` would select everything when border == 0.
    local[:band, :] = 1
    local[max(height - band, 0):, :] = 1
    local[:, :band] = 1
    local[:, max(width - band, 0):] = 1

    # Place into a full-size (rows, cols) mask, clipping at the image borders so a box
    # near the edge neither wraps around (negative start) nor overflows the array.
    mask = np.zeros((bg_x, bg_y))
    top, left = x - border, y - border
    r0, r1 = max(top, 0), min(top + height, bg_x)
    c0, c1 = max(left, 0), min(left + width, bg_y)
    if r0 < r1 and c0 < c1:
        mask[r0:r1, c0:c1] = local[r0 - top:r1 - top, c0 - left:c1 - left]

    mask_3d = np.dstack((mask, mask, mask))
    return (mask_3d * 255).astype(np.uint8)


In [None]:
# Border (pixels) passed to get_inpainting_mask.
# NOTE(review): differs from the BORDER = 20 constant in the config cell -- confirm which is intended.
border = 10
inpainting_mask = get_inpainting_mask(
    background=background,
    patch=patch,
    bounding_box_mask=bounding_box_mask,
    border=border,
)

In [None]:
# The inpainting mask should extend `border` pixels beyond the bounding box on every side.
# Asserts (rather than a displayed boolean) make Restart & Run All fail if this breaks.
mask_rows, mask_cols = np.nonzero(inpainting_mask[:, :, 0] == 255)
box_rows, box_cols = np.nonzero(bounding_box_mask == 1)
assert mask_rows.max() - box_rows.max() == border, "bottom edge"
assert mask_cols.max() - box_cols.max() == border, "right edge"
assert box_rows.min() - mask_rows.min() == border, "top edge"
assert box_cols.min() - mask_cols.min() == border, "left edge"


In [None]:
# Side by side: where the patch sits vs. the band that will be inpainted.
fig, ax = plt.subplots(1, 2)
ax[0].imshow(bounding_box_mask, cmap="Greys_r")
ax[0].set_title("Bounding box mask")
ax[1].imshow(inpainting_mask)
ax[1].set_title("Inpainting mask");

In [None]:
def _paste_patch(background, patch, bounding_box_mask):
    """Return a copy of `background` with `patch` pasted at the bounding box's top-left corner (clipped to the image)."""
    rows, cols = np.nonzero(bounding_box_mask == 1)
    if rows.size == 0:
        raise ValueError("bounding_box_mask has no pixel equal to 1")
    x, y = int(rows.min()), int(cols.min())
    composite = background.copy()
    # Clip so a patch overhanging the bottom/right edge does not cause a shape mismatch.
    h = max(0, min(patch.shape[0], composite.shape[0] - x))
    w = max(0, min(patch.shape[1], composite.shape[1] - y))
    composite[x:x + h, y:y + w] = patch[:h, :w]
    return composite


def perform_overlay_inpainting(
    background,
    patch,
    bounding_box_mask,
    output_filepath: str,
    border: int = 1,
    **pipe_kwargs,
):
    """Paste `patch` onto `background`, then inpaint the seam with the Stable Diffusion pipeline.

    Args:
        background: RGB uint8 background image.
        patch: RGB uint8 patch with the same number of channels as `background`.
        bounding_box_mask: 2-D array equal to 1 where the patch goes; its top-left
            corner is used as the paste position.
        output_filepath: Where the inpainted image is saved.
        border: Border passed to `get_inpainting_mask` (band half-width in pixels).
        **pipe_kwargs: Extra keyword arguments forwarded to `pipe(...)`, e.g.
            `generator=` for reproducible results. Must not include `prompt`,
            `image` or `mask_image`.

    Returns:
        The inpainted PIL image (also written to `output_filepath`).
    """
    # Without this step the pipeline would only see the bare background and the
    # patch would never appear in the result.
    composite = _paste_patch(background, patch, bounding_box_mask)

    # Mask is RGB (actually B&W) [0, 255] with content to be inpainted denoted by the white part
    inpainting_mask_image = get_inpainting_mask(background, patch, bounding_box_mask, border)

    # Image and mask_image should be PIL images.
    # The mask structure is white for in-painting and black for keeping as is.
    # NOTE(review): height/width are not passed, so the output resolution is the
    # pipeline's default -- confirm it matches `background`.
    image = pipe(
        # Default performs content aware filling
        prompt="",
        image=Image.fromarray(composite),
        mask_image=Image.fromarray(inpainting_mask_image),
        **pipe_kwargs,
    ).images[0]

    image.save(output_filepath)
    return image

    

In [None]:
# Run the inpainting and save the result next to the notebook as test.jpg.
image = perform_overlay_inpainting(
    background,
    patch,
    bounding_box_mask,
    output_filepath="test.jpg",
    border=border,
)

In [None]:
# Display the inpainted PIL image returned above as the cell output.
image