In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal
from skimage import color, data, feature, filters, img_as_float, io

In [None]:
def inpainting(working_image, working_mask, patch_size):
    """Fill the masked region of an image by greedy exemplar-based patch copying.

    At each step the fill-front pixel with the highest priority
    (confidence * data term) is chosen. The most similar patch that contains no
    masked pixels is copied into its unknown pixels. Then the confidence map and
    the mask are updated, until no masked pixel remains.

    Parameters
    ----------
    working_image : ndarray
        Colour image to inpaint (the helpers convert it with rgb2lab/rgb2gray).
        It is not modified; the function works on a copy.
    working_mask : ndarray
        2-D mask with 1 on pixels to fill and 0 elsewhere. It is not modified.
    patch_size : int
        Side length of the square patches (odd values give a centred patch).

    Returns
    -------
    ndarray
        The inpainted copy of ``working_image``.

    Raises
    ------
    ValueError
        If no patch free of masked pixels exists to copy from.
    """
    # Work on copies so the caller's arrays (and any cell that displayed them)
    # stay untouched and re-running the notebook cell gives the same result.
    working_image = np.copy(working_image)
    working_mask = np.copy(working_mask)

    confidence = 1 - working_mask
    center = patch_size // 2

    # Checking first means an already-empty mask returns immediately.
    while not isEmpty(working_mask, working_image):
        front = feature.canny(working_mask)
        if not np.any(front):
            # canny can report no edges while masked pixels remain (e.g. what is
            # left of the hole sits on the image border or is too small to pass
            # its thresholds). Without a fallback, argmax would pick (0, 0)
            # forever. Use the remaining masked pixels as the front instead.
            front = working_mask == 1

        confidence = compute_C(confidence, front, center)
        data_term = compute_D(working_image, working_mask)
        priorities = confidence * data_term * front

        target_pixel = np.unravel_index(np.argmax(priorities), priorities.shape)

        similar_patch = find_similar_source_patch(working_image, target_pixel, center, working_mask)
        if similar_patch is None:
            raise ValueError(
                'no patch of size {} free of masked pixels to copy from'.format(patch_size))

        working_image = update_img(working_image, working_mask, target_pixel, similar_patch, center)
        # update_C must see the mask before update_mask clears the patch: it
        # identifies the newly filled pixels by working_mask == 1.
        confidence = update_C(confidence, target_pixel, center, working_mask)
        # update_mask zeroes the filled patch in place. Its return value is
        # deliberately not used so the loop always keeps operating on its own copy.
        update_mask(working_mask, target_pixel, center)

    return working_image


def update_mask(source_mask, target_pixel, center):
    """Mark the patch centred on ``target_pixel`` as filled in the mask.

    Parameters
    ----------
    source_mask : ndarray
        2-D mask with 1 on pixels still to fill. It is modified in place.
    target_pixel : tuple of int
        (row, col) centre of the patch that was just filled.
    center : int
        Patch half-size; the patch is clipped at the mask borders.

    Returns
    -------
    ndarray
        ``source_mask`` itself, after the update.
    """
    source_mask[
        start(target_pixel[0], center): end(target_pixel[0], center, source_mask.shape[0]),
        start(target_pixel[1], center): end(target_pixel[1], center, source_mask.shape[1]),
    ] = 0

    # Return the argument, not the notebook-global `working_mask`, which only
    # happened to be the same object when called from the original notebook cell.
    return source_mask

def update_C(confidence, target_pixel, center, source_mask):
    """Propagate the target pixel's confidence to the pixels being filled.

    Every pixel of the patch around ``target_pixel`` (clipped at the borders)
    that is still masked (``source_mask == 1``) gets the value
    ``confidence[target_pixel]``. Because the filled pixels are found through
    the mask, this must be called before the mask is cleared.

    Returns
    -------
    ndarray
        ``confidence``, modified in place.
    """
    row, col = target_pixel[0], target_pixel[1]
    row_lo = start(row, center)
    col_lo = start(col, center)
    row_hi = end(row, center, source_mask.shape[0])
    col_hi = end(col, center, source_mask.shape[1])

    # Read the value before writing, in case the target itself is being filled.
    inherited = confidence[row, col]
    being_filled = source_mask[row_lo:row_hi, col_lo:col_hi] == 1
    # Basic slicing yields a view, so the boolean assignment writes into confidence.
    confidence[row_lo:row_hi, col_lo:col_hi][being_filled] = inherited
    return confidence

def update_img(source_image, source_mask, target_pixel, pixels, center):
    """Copy source-patch pixels into the unknown pixels of the target patch.

    Target-patch pixels with ``source_mask == 1`` take the value from the
    source patch. Pixels with ``source_mask == 0`` keep their current value.

    Parameters
    ----------
    source_image : ndarray
        Image being inpainted. It is modified in place.
    source_mask : ndarray
        2-D mask, 1 on pixels still to fill. Assumes the same height/width as
        ``source_image``.
    target_pixel : tuple of int
        (row, col) centre of the patch to fill (clipped at the borders).
    pixels : sequence of int
        (row, col) top-left corner of the source patch, as returned by
        ``find_similar_source_patch``.
    center : int
        Patch half-size.

    Returns
    -------
    ndarray
        ``source_image`` after the update.
    """
    rows = slice(start(target_pixel[0], center), end(target_pixel[0], center, source_image.shape[0]))
    cols = slice(start(target_pixel[1], center), end(target_pixel[1], center, source_image.shape[1]))

    # Copy only the two patches rather than the whole image. Both are read
    # before anything is written back, so an overlap between them is harmless.
    target_patch = np.copy(source_image[rows, cols])
    mask_patch = source_mask[rows, cols]
    height, width = target_patch.shape[0], target_patch.shape[1]
    similar_patch = np.copy(source_image[pixels[0]:pixels[0] + height,
                                         pixels[1]:pixels[1] + width])

    # Combine the two: source pixels where masked, existing pixels where known.
    similar_patch[mask_patch == 0] = 0
    target_patch[mask_patch == 1] = 0
    source_image[rows, cols] = similar_patch + target_patch

    return source_image


def find_similar_source_patch(source_image, target_pixel, center, input_mask):
    """Find the patch free of masked pixels that best matches the target patch.

    The match cost is the sum of squared differences in CIELAB space. It is
    computed only over the target patch's known pixels (``input_mask != 1``).
    Candidates containing any masked pixel are skipped. On ties the first
    candidate in row-major scan order wins.

    Parameters
    ----------
    source_image : ndarray
        RGB image being inpainted (converted with ``color.rgb2lab``).
    target_pixel : tuple of int
        (row, col) centre of the target patch (clipped at the borders).
    center : int
        Patch half-size.
    input_mask : ndarray
        2-D mask, 1 on pixels still to fill.

    Returns
    -------
    list of int or None
        ``[row, col]`` of the best patch's top-left corner, or ``None`` if no
        candidate free of masked pixels exists.
    """
    lab_image = color.rgb2lab(source_image)

    rows = slice(start(target_pixel[0], center), end(target_pixel[0], center, lab_image.shape[0]))
    cols = slice(start(target_pixel[1], center), end(target_pixel[1], center, lab_image.shape[1]))
    target_patch = lab_image[rows, cols]
    patch_height, patch_width = target_patch.shape[0], target_patch.shape[1]

    # Loop-invariant: which target pixels are known. Only those take part in the
    # comparison, so the unknown ones never need to be zeroed per candidate.
    known = input_mask[rows, cols] != 1
    target_known = target_patch[known]

    best_match = None
    best_match_difference = 0

    # NOTE(review): brute-force O(H * W * patch) search; this dominates run time.
    # A correlation-based SSD (e.g. scipy.signal) could vectorise it.
    for y in range(source_image.shape[0] - patch_height + 1):
        for x in range(source_image.shape[1] - patch_width + 1):
            if np.any(input_mask[y:y + patch_height, x:x + patch_width] == 1):
                continue  # the source must be entirely known

            candidate_known = lab_image[y:y + patch_height, x:x + patch_width][known]
            difference = np.sum((target_known - candidate_known) ** 2)
            if best_match is None or difference < best_match_difference:
                best_match = [y, x]
                best_match_difference = difference

    return best_match

    
def isEmpty(working_mask, working_image):
    """Report progress and tell whether no masked pixel is left.

    Prints "<done> of <total> completed", where total is the image's pixel
    count and done is total minus the sum of the mask.

    Returns
    -------
    bool-like
        True when ``working_mask`` sums to zero.
    """
    rows, cols = working_image.shape[:2]
    pixel_count = rows * cols
    still_masked = working_mask.sum()
    print('{} of {} completed'.format(pixel_count - still_masked, pixel_count))
    return still_masked == 0

def compute_C(confidence, front, center):
    """Recompute the confidence term on the fill front.

    For every front pixel, the new confidence is the mean of the current
    confidence values over the patch centred on it (clipped at the borders).
    Pixels off the front keep their value. The values are read from the
    unmodified input, so updates do not cascade between front pixels.

    Returns
    -------
    ndarray
        A new confidence map; ``confidence`` itself is left unchanged.
    """
    height, width = confidence.shape[0], confidence.shape[1]
    updated = np.copy(confidence)
    for row, col in np.argwhere(front == 1):
        window = confidence[start(row, center):end(row, center, height),
                            start(col, center):end(col, center, width)]
        updated[row, col] = np.sum(window) / (window.shape[0] * window.shape[1])
    return updated
    
def compute_D(source_image, source_mask, alpha=255.0, eps=0.001):
    """Compute the data term D(p) = |isophote(p) . n(p)| / alpha + eps.

    The isophote is the image gradient rotated by 90 degrees. n(p) is the
    per-pixel unit normal to the fill front, estimated from the gradient of
    the mask.

    Parameters
    ----------
    source_image : ndarray
        RGB image being inpainted.
    source_mask : ndarray
        2-D mask, 1 on pixels still to fill.
    alpha : float, optional
        Normalisation factor (default 255, the value used previously).
    eps : float, optional
        Floor added so priorities do not collapse to zero (default 0.001).

    Returns
    -------
    ndarray
        2-D float array with the same height/width as ``source_mask``.
    """
    # rgb2gray is the supported name; the rgb2grey alias is gone from recent
    # scikit-image releases.
    gray_image = img_as_float(color.rgb2gray(source_image))

    # Exclude the target region: Sobel responses whose 3x3 stencil touches a
    # NaN become NaN and are then zeroed by nan_to_num.
    gray_image[source_mask == 1] = np.nan
    grad_rows = np.nan_to_num(filters.sobel_h(gray_image))
    grad_cols = np.nan_to_num(filters.sobel_v(gray_image))
    # NOTE(review): this zeroes the isophote on every pixel adjacent to the
    # hole, which probably covers most of the front returned by canny.
    # Criminisi-style implementations usually take the strongest known
    # gradient inside the patch instead. Confirm before relying on D.

    # Isophote: the gradient rotated by 90 degrees.
    isophote_rows = -grad_cols
    isophote_cols = grad_rows

    # Unit normal to the front, normalised per pixel (not by a global norm).
    mask = np.asarray(source_mask, dtype=float)
    normal_rows = filters.sobel_h(mask)
    normal_cols = filters.sobel_v(mask)
    normal_length = np.hypot(normal_rows, normal_cols)
    normal_length[normal_length == 0] = 1.0  # no front here: avoid 0/0, D = eps
    normal_rows = normal_rows / normal_length
    normal_cols = normal_cols / normal_length

    # The data term is the dot product of isophote and normal, not a product of magnitudes.
    return np.abs(isophote_rows * normal_rows + isophote_cols * normal_cols) / alpha + eps
        
def start(target_pixel, center):
    """Inclusive first index of a patch of half-size ``center`` around ``target_pixel``, clipped at 0."""
    first = target_pixel - center
    return first if first > 0 else 0

def end(target_pixel, center, dimension):
    """Exclusive stop index of a patch of half-size ``center`` around ``target_pixel``, clipped to ``dimension``."""
    stop = target_pixel + center + 1
    return dimension if dimension < stop else stop
 

In [None]:
mpl.rcParams['figure.dpi'] = 400

# skimage.io.imread is the supported reader; skimage.data does not expose
# imread in recent scikit-image releases.
working_image = io.imread('images/image8.png')
# Binarise the mask: 1 = pixel to inpaint, 0 = known pixel.
working_mask = io.imread('masks/mask8.png', as_gray=True).round()

# NOTE(review): PNGs may carry an alpha channel; the rgb2lab/rgb2gray calls in
# the inpainting helpers expect 3 channels, so drop it if present.
if working_image.ndim == 3 and working_image.shape[2] == 4:
    working_image = working_image[..., :3]

assert working_mask.shape == working_image.shape[:2], 'mask and image sizes differ'

fig, (ax_image, ax_mask) = plt.subplots(1, 2)
ax_image.imshow(working_image)
ax_image.set_title('Image')
ax_mask.imshow(working_mask)
ax_mask.set_title('Mask');

In [None]:
patch_size = 9  # side length of the square patches copied during filling
output_image = inpainting(working_image, working_mask, patch_size)

fig, ax = plt.subplots()
ax.imshow(output_image)
ax.set_title('Inpainted image (patch size {})'.format(patch_size))
ax.set_axis_off();