# In[ ]:
# Author: Jonas Rashidi

import numpy as np
import matplotlib.image as mpimg
from scipy import spatial
from tqdm import tqdm
import cv2  
from pyinpaint.utils import *

class Inpainting:
    """Patch-based inpainting of the masked pixels of an image.

    Pixels where the mask equals 0 are zeroed out and then filled sweep by
    sweep from the hole boundary inward.

    In each sweep, every unknown pixel with at least one known pixel among its
    ``k_boundary`` spatially nearest pixels is handled. Among its ``k_search``
    nearest known pixels, the ``k_patch`` ones whose patches best match the
    pixel's own patch are found, and the pixel receives the mean of their
    texture values. Only the already-known entries of the target patch count
    toward the match.

    Args:
        org_img: Path to the input image (read with ``matplotlib.image.imread``).
        mask: Path to a mask image, read as single-channel grayscale; pixels
            equal to 0 are the holes to fill.
        ps: Side length of the square patches passed to ``create_patches``.
    """

    def __init__(self, org_img, mask, ps):
        self.org_img = org_img
        self.mask = mask
        self.ps = ps

    def __call__(self, k_boundary=4, k_search=1000, k_patch=5):
        """Inpaint the image and return it as a ``uint8`` array (see ``forward``)."""
        return self.forward(k_boundary, k_search, k_patch)

    @staticmethod
    def _to_unit_float(img):
        """Return *img* as float32 scaled to [0, 1].

        ``matplotlib.image.imread`` returns integer arrays for JPEG but float
        arrays already in [0, 1] for PNG, so only integer data is rescaled
        (by the maximum of its dtype).
        """
        if np.issubdtype(img.dtype, np.integer):
            return img.astype(np.float32) / np.iinfo(img.dtype).max
        return img.astype(np.float32)

    @staticmethod
    def _normalize(x):
        """Min-max scale *x* to [0, 1]; a constant array maps to all zeros."""
        shifted = x - np.min(x)
        span = np.max(shifted)
        if span == 0:
            # Constant input (e.g. a fully masked image) would otherwise give 0/0 = NaN.
            return np.zeros_like(shifted, dtype=np.result_type(shifted, np.float32))
        return shifted / span

    @staticmethod
    def _valid_indices(indices, n):
        """Return *indices* as a 1-D array without cKDTree's missing-neighbour marker.

        ``cKDTree.query`` returns a scalar when k == 1, and pads with index
        ``n`` (one past the last point) when fewer than k points exist.
        """
        indices = np.atleast_1d(indices)
        return indices[indices < n]

    def preprocess(self):
        """Load image and mask, zero the holes, and build the feature matrices.

        Sets ``_shape`` (image shape), ``_position`` and ``_texture`` (both
        min-max normalized to [0, 1]) and ``_patches``.

        Raises:
            FileNotFoundError: if the mask file cannot be read.
        """
        img = self._to_unit_float(mpimg.imread(self.org_img))
        mask = cv2.imread(self.mask, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals failure by returning None instead of raising;
            # without this check nothing would be masked and nothing inpainted.
            raise FileNotFoundError(f"Could not read mask image: {self.mask}")

        self._shape = img.shape
        img[mask == 0] = 0

        # pmat/fmat come from pyinpaint.utils; presumably one row per pixel
        # (coordinates and channel values respectively) — TODO confirm.
        self._position = self._normalize(pmat(self._shape))
        self._texture = self._normalize(fmat(img))

        self._patches = create_patches(img, (self.ps, self.ps))

    def postprocess(self, fmat):
        """Convert a feature matrix back to a ``uint8`` image of ``self._shape``.

        Note: the parameter name shadows the module-level ``fmat`` helper; it
        is kept for backward compatibility with keyword callers.
        """
        img_out = np.clip(to_img(fmat, self._shape), 0, 1)
        # Round instead of truncating, so values a hair below an integer are
        # not pushed down a level.
        return np.rint(img_out * 255).astype(np.uint8)

    def forward(self, k_boundary, k_search, k_patch):
        """Fill every unknown pixel and return the result as a ``uint8`` image.

        Args:
            k_boundary: Neighbours checked to decide whether a pixel is on the
                current fill front.
            k_search: Spatially nearest pixels considered as patch candidates.
            k_patch: Best-matching candidates averaged into the filled value.

        Raises:
            RuntimeError: if a whole sweep fills no pixel, which would
                otherwise loop forever.
        """
        self.preprocess()
        kdt = spatial.cKDTree(self._position)
        # A pixel is unknown while its whole feature row is zero.
        # NOTE(review): genuinely black pixels are indistinguishable from holes
        # here and get inpainted as well.
        unknown = ~self._texture.any(axis=1)
        remaining = int(np.count_nonzero(unknown))

        pbar = tqdm(
            desc=f"# of pixels to be inpainted are {remaining}",
            total=remaining,
            bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt}',
        )
        try:
            while remaining > 0:
                filled = []
                for i in np.flatnonzero(unknown):
                    _, nbrs = kdt.query(self._position[i], k_boundary)
                    nbrs = self._valid_indices(nbrs, kdt.n)
                    # Only pixels with at least one known neighbour are on the front.
                    if unknown[nbrs].all():
                        continue

                    patch = self._patches[i].flatten()
                    # Zero the still-unknown entries of the target patch in every
                    # candidate, so matching uses the known entries only.
                    weights = (patch != 0).astype("int")
                    _, cand = kdt.query(self._position[i], k_search)
                    cand = self._valid_indices(cand, kdt.n)
                    cand = cand[~unknown[cand]]
                    if cand.size == 0:
                        continue

                    cand_patches = weights * self._patches[cand]
                    k = min(k_patch, cand.size)
                    _, best = spatial.cKDTree(cand_patches).query(patch, k)
                    # atleast_1d: with k == 1 query returns a scalar, and the mean
                    # must still be taken over a (k, channels) block.
                    ids = cand[np.atleast_1d(best)]
                    self._texture[i] = self._texture[ids].mean(axis=0)
                    filled.append(i)

                if not filled:
                    raise RuntimeError(
                        f"Inpainting stalled with {remaining} pixels left: no unknown "
                        "pixel has a known neighbour or candidate patch."
                    )

                # Membership and patches are refreshed only after the sweep, so
                # pixels filled in this sweep are not used as sources until the next.
                self._patches = create_patches(
                    to_img(self._texture, self._shape), (self.ps, self.ps)
                )
                unknown[filled] = False
                remaining -= len(filled)
                pbar.update(len(filled))
        finally:
            pbar.close()
        return self.postprocess(self._texture)


def create_strict_red_line_mask(img, *, hue_tol=10, min_sat=30, min_val=30):
    """Return a 0/1 mask that is 0 on red pixels and 1 everywhere else.

    Red wraps around OpenCV's 8-bit hue range (0..180), so two hue bands are
    tested: ``[0, hue_tol]`` and ``[180 - hue_tol, 180]``. A pixel counts as
    red only if its saturation is at least ``min_sat`` and its value is at
    least ``min_val``.

    Args:
        img: RGB image; assumed ``uint8`` so hue is on the 0..180 scale —
            TODO confirm for other dtypes.
        hue_tol: Width of each hue band around 0/180 (default matches the
            original thresholds).
        min_sat: Minimum saturation (0..255) for a pixel to count as red.
        min_val: Minimum value/brightness (0..255) for a pixel to count as red.

    Returns:
        ``uint8`` array of img's height and width: 0 where red, 1 elsewhere.
    """
    hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    low_band = cv2.inRange(
        hsv_img, np.array([0, min_sat, min_val]), np.array([hue_tol, 255, 255])
    )
    high_band = cv2.inRange(
        hsv_img, np.array([180 - hue_tol, min_sat, min_val]), np.array([180, 255, 255])
    )
    red_mask = cv2.bitwise_or(low_band, high_band)
    # inRange yields 0/255; map "not red" to 1 and "red" to 0.
    return (red_mask == 0).astype(np.uint8)


def smooth_inpainted_image(img):
    """Blur *img* with a 5x5 Gaussian kernel to soften inpainting seams.

    A sigma of 0 lets OpenCV compute the standard deviation from the kernel size.
    """
    kernel = (5, 5)
    sigma = 0
    return cv2.GaussianBlur(img, kernel, sigma)


def sharpen_image(img, sigma=1.0, alpha=1.5):
    """Sharpen an image with unsharp masking.

    Computes ``(1 + alpha) * img - alpha * blur(img)`` on a [0, 1] float copy.
    ``blur`` is a Gaussian with standard deviation *sigma*, with the kernel size
    derived by OpenCV. The result is clipped and converted back to ``uint8``.

    Args:
        img: Image on the 0..255 scale (assumed ``uint8``).
        sigma: Gaussian standard deviation in pixels.
        alpha: Sharpening strength; 0 returns the input unchanged.

    Returns:
        The sharpened image as ``uint8``.
    """
    img_float = img.astype(np.float32) / 255.0
    blurred = cv2.GaussianBlur(img_float, (0, 0), sigma)
    sharpened = cv2.addWeighted(img_float, 1 + alpha, blurred, -alpha, 0)
    sharpened = np.clip(sharpened, 0, 1)
    # Round instead of truncating: after the /255 * 255 float round-trip a
    # value can land slightly below its integer and would otherwise drop a level.
    return np.rint(sharpened * 255).astype(np.uint8)


def check_args(cls):
    """Wrap *cls* in a factory that type-checks its constructor arguments.

    The returned callable takes ``(org_img, mask, ps=11)``. It checks that both
    paths are ``str`` and ``ps`` is an ``int``, then returns
    ``cls(org_img, mask, ps)``.

    Raises:
        TypeError: from the returned callable when an argument has the wrong type.
    """
    def correct_args(org_img, mask, ps=11):
        # bool is a subclass of int, but True/False are not meaningful patch sizes.
        ps_ok = isinstance(ps, int) and not isinstance(ps, bool)
        if isinstance(org_img, str) and isinstance(mask, str) and ps_ok:
            return cls(org_img, mask, ps)
        raise TypeError(
            "expected arg[0]:str, arg[1]:str, arg[2]:int; got "
            f"{type(org_img).__name__}, {type(mask).__name__}, {type(ps).__name__}"
        )
    return correct_args


Inpaint = check_args(Inpainting)

image_path = r"C:\Users\User\Downloads\Input_image.jpg"
# Save the mask losslessly: it holds only the values 0 and 1, and JPEG
# compression would smear them across 8x8 blocks, corrupting the hole outline.
mask_path = "strict_red_line_mask.png"

img = cv2.imread(image_path)
if img is None:
    # cv2.imread returns None instead of raising when the file is missing/unreadable.
    raise FileNotFoundError(f"Could not read input image: {image_path}")
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

red_line_mask = create_strict_red_line_mask(img_rgb)
if not cv2.imwrite(mask_path, red_line_mask):
    raise OSError(f"Could not write mask image: {mask_path}")

# Inpaint validates the argument types before constructing Inpainting.
inpainting = Inpaint(image_path, mask_path, ps=11)
inpainted_img = inpainting()

smoothed_inpainted_img = smooth_inpainted_image(inpainted_img)
sharpened_inpainted_img = sharpen_image(smoothed_inpainted_img)

# Inpainting loads the image with matplotlib, so the result is presumably in
# RGB order, while cv2.imwrite expects BGR; convert so red and blue are not
# swapped on disk.
output_img = sharpened_inpainted_img
if output_img.ndim == 3 and output_img.shape[2] == 3:
    output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)
cv2.imwrite("sharpened_inpainted_image 1.jpg", output_img)
