In [1]:
import cv2
import numpy as np
import time
from scipy.sparse import linalg 
from scipy.sparse import lil_matrix 
from glob import glob
import matplotlib.pyplot as plt

In [2]:
#確定index在區域之內或之外
def in_region(index, mask):
    return mask[index] == 1

#上下左右index
def cruciform(index):
    x, y = index
    return [(x+1, y), (x-1, y), (x, y+1), (x, y-1)]

#確定index有無在邊界上
def boundary(index, mask):
    if in_region(index, mask) == False:
        return False
    for p in cruciform(index):
        if in_region(p, mask) == False:
            return True
    return False    


In [3]:
def pointpos(index, mask):
    if in_region(index, mask) == False:
        return 2
    if boundary(index, mask) == True:
        return 1
    return 0

def laplacian_(src, index):
    x, y = index
    value = (4 * src[x, y]) - (1 * src[x+1, y]) - (1 * src[x-1, y]) - (1 * src[x, y+1]) - (1 * src[x, y-1])
    return value

def region_mask(mask):
    nonzero = mask.nonzero()
    return list(zip(nonzero[0], nonzero[1]))

##============create sparse matrix===============##
def sparse_mat(p):
    num = len(p)  #mask中點的數量
    MatA = lil_matrix((num, num))
    for i,index in enumerate(p):
        MatA[i, i] = 4
        for m in cruciform(index):
            if m not in p:
                continue
            j = p.index(m)
            MatA[i, j] = -1
    return MatA

In [4]:
def solver(src, dst, mask):
    pointer = region_mask(mask)
    num = len(pointer)
    
##=============create A matrix=================##    
    MatA = sparse_mat(pointer)
##=============create b matrix=================##
    Matb = np.zeros(num)
    
    for i, index in enumerate(pointer):
        Matb[i] = laplacian_(src, index) 
        
        if pointpos(index, mask) == 1:
            for p in cruciform(index):
                if in_region(p, mask) == False:
                    Matb[i] += dst[p]
##=============solve x===================##                    
    Matx = linalg.cg(MatA, Matb)
    copydst = np.copy(dst).astype(int)
    
    for i, index in enumerate(pointer):
        copydst[index] = Matx[0][i]
    return copydst    

In [5]:
start = time.time()

##=============load image =================##

imgsrc = cv2.imread("img/wsr.png", cv2.IMREAD_COLOR)
imgdst = cv2.imread("img/wta.png", cv2.IMREAD_COLOR)
imgmask = cv2.imread("img/wma.png", cv2.IMREAD_GRAYSCALE)

#normalize
mask = np.atleast_3d(imgmask).astype(np.float) / 255.
mask[mask != 1] = 0

#one channel 
mask = mask[:, :, 0]
channel = imgsrc.shape[-1]

#每個channel都解
result_stack = [solver(imgsrc[:, :, i], imgdst[:, :, i], mask) for i in range(channel)]

result = cv2.merge(result_stack)


end = time.time()

print("Execution time: {0}[secs]".format(round(end - start, 2)))
cv2.imwrite("output/result4.png", result)


Execution time: 1408.78[secs]


True

In [19]:
#======================= One ==========================#
onestart = time.time()

onesrc = cv2.imread("img/einsteinSample.bmp")
onemask = cv2.imread("img/einsteinMask.bmp",0)

onedst = cv2.inpaint(onesrc, onemask, 3, cv2.INPAINT_TELEA)

oneend = time.time()

print("Execution time: {0}[secs]".format(round(oneend - onestart, 2)))

cv2.imwrite("output/oneresult.png", onedst)

Execution time: 0.53[secs]


True