In [81]:
from __future__ import division

import numpy as np
import scipy as scp
import pylab as pyl
import time
import matplotlib.pyplot as plt
from numpy import random
import copy
from PIL import Image

from nt_toolbox.general import *
from nt_toolbox.signal import *

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [82]:
class NNF():
    """Nearest-Neighbor Field (NNF) from the pixels of `image` to those of `output`.

    `field[x][y]` holds three integers for the input pixel (x, y):
    `[x_target, y_target, distance_scaled]`, where the distance is the value
    returned by `MaskedImage.distance` for the patch centred on (x, y) in
    `image` and the patch centred on (x_target, y_target) in `output`.
    """

    def __init__(self, image, output, patchsize):
        """Store both images, the patch size and the image dimensions.

        image, output: objects exposing `W` and `H` attributes; they are also
            handed to `MaskedImage.distance`.
        patchsize: forwarded as the `S` argument of `MaskedImage.distance`.
        """
        self.image = image
        self.output = output
        self.S = patchsize
        self.inputH = image.H
        self.inputW = image.W
        self.outputH = output.H
        self.outputW = output.W

    def randomize(self):
        """Initialize the field with random target coordinates, then score them."""
        # Integer dtype: the stored coordinates are later used as array indices,
        # and NumPy rejects floating-point indices.
        self.field = np.zeros([self.inputW, self.inputH, 3], dtype=int)

        for y in range(self.inputH):
            for x in range(self.inputW):
                self.field[x][y][0] = random.randint(self.outputW)
                self.field[x][y][1] = random.randint(self.outputH)
                self.field[x][y][2] = MaskedImage.DSCALE
        self.initialize()

    def initializeFrom(self, nnf):
        """Initialize the field from an existing (possibly smaller) NNF.

        Each pixel copies the link of the corresponding pixel of `nnf`, with the
        target coordinates multiplied by the size ratio between the two fields.
        """
        self.field = np.zeros([self.inputW, self.inputH, 3], dtype=int)

        # Scale factors between this field and the one we initialize from.
        fx = self.inputW / nnf.inputW
        fy = self.inputH / nnf.inputH
        for y in range(self.inputH):
            for x in range(self.inputW):
                # x/fx is a float under true division: truncate before indexing.
                xlow = min(int(x / fx), nnf.inputW - 1)
                ylow = min(int(y / fy), nnf.inputH - 1)
                self.field[x][y][0] = int(nnf.field[xlow][ylow][0] * fx)
                self.field[x][y][1] = int(nnf.field[xlow][ylow][1] * fy)
                self.field[x][y][2] = MaskedImage.DSCALE
        self.initialize()

    def initialize(self):
        """Compute the initial value of the distance term of every link."""
        maxretry = 20
        for y in range(self.inputH):
            for x in range(self.inputW):
                link = self.field[x][y]
                link[2] = self.distance(x, y, link[0], link[1])

                # If the distance is the maximum (all pixels masked?), draw new
                # random targets, at most `maxretry` times, to find a better link.
                iteration = 0
                while link[2] == MaskedImage.DSCALE and iteration < maxretry:
                    link[0] = random.randint(self.outputW)
                    link[1] = random.randint(self.outputH)
                    link[2] = self.distance(x, y, link[0], link[1])
                    iteration += 1

    def minimize(self, passes):
        """Multi-pass NN-field minimization (see "PatchMatch" - page 4).

        Every pass visits all pixels in scanline order, then again in reverse
        scanline order. Links whose distance is already 0 are skipped.
        """
        min_x = 0
        min_y = 0
        max_x = self.inputW - 1
        max_y = self.inputH - 1

        for i in range(passes):
            print(".")

            # Scanline order. The upper bounds are inclusive so that the last
            # row and column are visited, as in the reverse scan below.
            for y in range(min_y, max_y + 1):
                for x in range(min_x, max_x + 1):
                    if self.field[x][y][2] > 0:
                        self.minimizeLink(x, y, +1)

            # Reverse scanline order.
            for y in range(max_y, min_y - 1, -1):
                for x in range(max_x, min_x - 1, -1):
                    if self.field[x][y][2] > 0:
                        self.minimizeLink(x, y, -1)

    def _tryCandidate(self, x, y, xp, yp):
        """Adopt (xp, yp) as the link of (x, y) if it has a strictly lower distance."""
        dp = self.distance(x, y, xp, yp)
        link = self.field[x][y]
        if dp < link[2]:
            link[0] = xp
            link[1] = yp
            link[2] = dp

    def minimizeLink(self, x, y, direction):
        """Minimize a single link (see "PatchMatch" - page 4).

        direction: +1 during the scanline pass, -1 during the reverse pass; it
            selects which horizontal and vertical neighbours are propagated.
        """
        # Propagation Left/Right. Index 0 is a valid neighbour.
        xn = x - direction
        if 0 <= xn < self.inputW:
            neighbour = self.field[xn][y]
            self._tryCandidate(x, y, neighbour[0] + direction, neighbour[1])

        # Propagation Up/Down. Index 0 is a valid neighbour.
        yn = y - direction
        if 0 <= yn < self.inputH:
            neighbour = self.field[x][yn]
            self._tryCandidate(x, y, neighbour[0], neighbour[1] + direction)

        # Random search around the link found so far; the window is halved at
        # every step and the candidates are clamped to the output bounds.
        wi = self.outputW
        xpi = self.field[x][y][0]
        ypi = self.field[x][y][1]

        while wi > 0:
            xp = xpi + random.randint(2 * wi) - wi
            yp = ypi + random.randint(2 * wi) - wi
            xp = max(0, min(self.outputW - 1, xp))
            yp = max(0, min(self.outputH - 1, yp))
            self._tryCandidate(x, y, xp, yp)
            wi //= 2

    def distance(self, x, y, xp, yp):
        """Return the distance between patch (x, y) of `image` and (xp, yp) of `output`."""
        return MaskedImage.distance(self.image, x, y, self.output, xp, yp, self.S)

    def getField(self):
        """Return the field array."""
        return self.field


In [83]:
class MaskedImage():
    """An image paired with a per-pixel mask.

    `image` is accessed through `size`, `getpixel`, `putpixel`, `copy` and
    `resize` (a PIL image in `initEmpty`); `mask` is indexed as `mask[x][y]`
    and a truthy entry marks the pixel as masked.
    """

    # the maximum value returned by MaskedImage.distance()
    DSCALE = 65535

    def __init__(self, image, mask):
        """Construct from an existing image and mask."""
        self.image = image
        self.W = image.size[0]
        self.H = image.size[1]
        self.mask = mask

    @classmethod
    def initEmpty(cls, width, height):
        """Construct a new 'RGB' image of the given size with an all-zero mask."""
        image = Image.new('RGB', (int(width), int(height)))
        mask = np.zeros([int(width), int(height)])

        return cls(image, mask)

    @staticmethod
    def similarity(i, sigma2 = 0.05):
        """Map a distance `i` to a weight `exp(-t / (2*sigma2))`, t = i/(DSCALE+1)."""
        # DSCALE is a class attribute: it must be qualified inside a staticmethod.
        t = i/(MaskedImage.DSCALE + 1)
        return np.exp(-t/(2*sigma2))

    def getImage(self):
        """Return the underlying image object."""
        return self.image

    def getSample(self, x, y, band):
        """Return channel `band` of the pixel at (x, y); coordinates are cast to int."""
        return self.image.getpixel((int(x), int(y)))[band]

    def setSample(self, x, y, band, value):
        """Set channel `band` of the pixel at (x, y) to `value` (cast to int)."""
        position = (int(x), int(y))
        pixel = list(self.image.getpixel(position))
        pixel[band] = value
        self.image.putpixel(position, tuple(int(c) for c in pixel))

    def isMasked(self, x, y):
        """Return the mask entry at (x, y); coordinates are cast to int like getSample."""
        return self.mask[int(x)][int(y)]

    def setMask(self, x, y, value):
        """Set the mask entry at (x, y); coordinates are cast to int like getSample."""
        self.mask[int(x)][int(y)] = value

    def countMasked(self):
        """Return the number of truthy mask entries within the W x H image area."""
        return sum(1 for y in range(self.H) for x in range(self.W) if self.mask[x][y])

    def containsMasked(self, x, y, S):
        """Return True if the patch of radius S centred on (x, y) contains a masked pixel.

        Patch positions falling outside the image are ignored.
        """
        for dy in range(-S, S+1):
            ys = y+dy
            if ys < 0 or ys >= self.H:
                continue
            for dx in range(-S, S+1):
                xs = x+dx
                if xs < 0 or xs >= self.W:
                    continue
                if self.mask[xs][ys]:
                    return True

        return False

    @staticmethod
    def distance(source, xs, ys, target, xt, yt, S):
        """Distance between the patch (xs, ys) of `source` and (xt, yt) of `target`.

        Patches have radius S. A patch position that is out of bounds or masked
        in either image counts for the maximum per-pixel penalty. The result is
        an int scaled to the range [0, DSCALE].
        """
        distance = 0
        wsum = 0
        # Maximum weighted SSD of one pixel: a difference of 255 on every
        # channel, with channel weights summing to 10.
        ssdmax = 10*255*255
        weights = [3, 6, 1]  # weight for R/G/B components

        for dy in range(-S, S+1):
            for dx in range(-S, S+1):

                wsum += ssdmax

                # pixel in the source patch
                xks = xs+dx
                yks = ys+dy
                source_ok = (0 <= xks < source.W) and (0 <= yks < source.H)
                # cannot use masked pixels as a valid source of information
                if not source_ok or source.isMasked(xks, yks):
                    distance += ssdmax
                    continue

                # corresponding pixel in the target patch
                xkt = xt+dx
                ykt = yt+dy
                target_ok = (0 <= xkt < target.W) and (0 <= ykt < target.H)
                if not target_ok or target.isMasked(xkt, ykt):
                    distance += ssdmax
                    continue

                # weighted SSD between the two pixels
                ssd = 0
                for band in range(3):
                    diff = source.getSample(xks, yks, band) - target.getSample(xkt, ykt, band)
                    ssd += weights[band]*diff**2

                # add pixel distance to global patch distance
                distance += ssd

        return int(MaskedImage.DSCALE*distance/wsum)

    def resize(self, image, newwidth, newheight):
        """Helper for image resize: return `image` resized to (newwidth, newheight)."""
        out = image.resize((newwidth, newheight))
        return out

    def copy(self):
        """Return a copy of the image together with a deep copy of the mask."""
        newimage = self.image.copy()
        newmask = copy.deepcopy(self.mask)

        return MaskedImage(newimage, newmask)

    def downsample(self):
        """Return a downsampled image (factor 1/2).

        Each output pixel is a kernel-weighted average of the unmasked input
        pixels around it, and is marked masked when more than 75% of the
        in-bounds pixels under the kernel are masked.
        """
        newW = self.W // 2
        newH = self.H // 2

        # Binomial coefficient kernels
        kernelEven = [1, 5, 10, 10, 5, 1]
        kernelOdd = [1, 4, 6, 4, 1]

        kernelx = kernelEven if self.W % 2 == 0 else kernelOdd
        kernely = kernelEven if self.H % 2 == 0 else kernelOdd

        newimage = MaskedImage.initEmpty(newW, newH)

        for ny, y in enumerate(range(0, self.H-1, 2)):
            for nx, x in enumerate(range(0, self.W-1, 2)):

                r = 0
                g = 0
                b = 0
                ksum = 0
                masked = 0
                total = 0

                for dy in range(len(kernely)):
                    yk = y+dy-2
                    if yk < 0 or yk >= self.H:
                        continue

                    for dx in range(len(kernelx)):
                        xk = x+dx-2
                        if xk < 0 or xk >= self.W:
                            continue

                        total += 1

                        # masked pixels do not contribute to the average
                        if self.mask[xk][yk]:
                            masked += 1
                            continue

                        k = kernelx[dx]*kernely[dy]
                        r += k*self.getSample(xk, yk, 0)
                        g += k*self.getSample(xk, yk, 1)
                        b += k*self.getSample(xk, yk, 2)
                        ksum += k

                if ksum > 0:
                    newimage.setSample(nx, ny, 0, int(r/ksum+0.5))
                    newimage.setSample(nx, ny, 1, int(g/ksum+0.5))
                    newimage.setSample(nx, ny, 2, int(b/ksum+0.5))

                # The mask flag depends only on the share of masked pixels.
                newimage.setMask(nx, ny, masked > 0.75*total)

        return newimage

    def upscale(self, newW, newH):
        """Return an upscaled image of size (newW, newH) with an all-zero mask."""
        newimage = MaskedImage.initEmpty(newW, newH)
        newimage.image = self.resize(self.image, newW, newH)
        return newimage

In [84]:
class Inpaint():
    """Inpainting driver: builds an image pyramid and rebuilds the target level by level."""

    def inpaint(self, image, mask, radius):
        """Inpaint the masked pixels of `image` and return the resulting image.

        image, mask: passed to `MaskedImage(image, mask)`.
        radius: patch radius, used as the NNF patch size and as the lower bound
            on the pyramid image dimensions.
        """
        # initial image
        self.initial = MaskedImage(image, mask)

        # patch radius
        self.radius = radius

        # working copy
        source = copy.deepcopy(self.initial)

        print("build pyramid of images...")

        # Build the pyramid of downscaled images; stop once a level has no
        # masked pixel left or is no larger than the patch radius.
        self.pyramid = [source]
        while source.W > self.radius and source.H > self.radius:
            if source.countMasked() == 0:
                break
            source = source.downsample()
            self.pyramid.append(source)

        maxlevel = len(self.pyramid)

        # The initial target is the same as the smallest source.
        # We consider that this target contains no masked pixels.
        target = source.copy()
        for y in range(target.H):
            for x in range(target.W):
                target.setMask(x, y, False)

        # for each level of the pyramid, from the smallest image down to level 1
        for level in range(maxlevel-1, 0, -1):
            print("\n*** Processing -  Zoom 1:"+str(1<<level)+" ***")

            # create the Nearest-Neighbor Field from target to source
            source = self.pyramid[level]

            print("initialize NNF...")
            if level == maxlevel-1:
                # at first, use random data as initial guess
                self.nnf_TargetToSource = NNF(target, source, self.radius)
                self.nnf_TargetToSource.randomize()
            else:
                # then, use the rebuilt (upscaled) target
                # and reuse the previous NNF as initial guess
                new_nnf = NNF(target, source, self.radius)
                new_nnf.initializeFrom(self.nnf_TargetToSource)
                self.nnf_TargetToSource = new_nnf

            # Build an upscaled target by EM-like algorithm (see "PatchMatch" - page 6)
            target = self.ExpectationMaximization(level)
            target.getImage().show()

        return target.getImage()

    def ExpectationMaximization(self, level):
        """EM-like algorithm (see "PatchMatch" - page 6).

        Runs several NNF-minimization / rebuild iterations at pyramid `level`.
        The last iteration rebuilds the target at the size of
        `self.pyramid[level-1]`, and that upscaled target is returned.
        """
        iterEM = min(2*level, 4)
        iterNNF = min(5, level)

        target = self.nnf_TargetToSource.image
        newtarget = None

        print("EM loop (em =",iterEM,",nnf =",iterNNF,") :")

        # EM Loop
        for emloop in range(1, 1+iterEM):

            print(" "+str(1+iterEM-emloop))

            # set the new target as current target
            if newtarget is not None:
                self.nnf_TargetToSource.image = copy.deepcopy(newtarget)
                target = copy.deepcopy(newtarget)
                newtarget = None

            # -- minimize the NNF
            self.nnf_TargetToSource.minimize(iterNNF)

            # -- Now we rebuild the target using best patches from source.
            # Instead of upsizing the final target, we build the last target from
            # the next level source image, so the final target is less blurry
            # (see "Space-Time Video Completion" - page 5).
            upscaled = (level >= 1 and emloop == iterEM)
            if upscaled:
                newsource = self.pyramid[level-1]
                newtarget = target.upscale(newsource.W, newsource.H)
            else:
                newsource = self.pyramid[level]
                newtarget = target.copy()

            # --- EXPECTATION/MAXIMIZATION step ---
            self.EM_Step(newsource, newtarget, self.nnf_TargetToSource, upscaled)

        return newtarget

    def EM_Step(self, source, target, nnf, upscaled):
        """Expectation-Maximization step.

        For each pixel of `target`, every overlapping target patch votes with the
        source pixel its NNF link points to; the votes are accumulated in
        per-channel weighted histograms and the pixel is set to the weighted mean
        of the votes lying around the median (between 40% and 60% of the CDF).

        upscaled: when true, `source` and `target` are at twice the resolution of
            the NNF, so field coordinates and the patch radius are doubled.
        """
        field = nnf.getField()
        R = nnf.S

        # resolution ratio between `target` and the NNF
        scale = 2 if upscaled else 1
        if upscaled:
            R *= 2

        # for each pixel in the target image
        for y in range(target.H):

            for x in range(target.W):

                # clear histograms
                histo = np.zeros([3, 256])
                wsum = 0

                # **** ESTIMATION STEP ****

                # for all target patches containing the pixel
                for dy in range(-R, 1+R):
                    for dx in range(-R, 1+R):

                        # (xpt,ypt) = center pixel of the target patch
                        xpt = x+dx
                        ypt = y+dy
                        if xpt < 0 or xpt >= scale*nnf.image.W:
                            continue
                        if ypt < 0 or ypt >= scale*nnf.image.H:
                            continue

                        # Best corresponding source patch from the NNF. The field
                        # entries are cast to int because they are used as indices.
                        link = field[xpt//scale][ypt//scale]
                        xst = scale*int(link[0]) + (xpt % scale)
                        yst = scale*int(link[1]) + (ypt % scale)
                        w = MaskedImage.similarity(link[2])

                        # get pixel corresponding to (x,y) in the source patch
                        xs = xst-dx
                        ys = yst-dy
                        if xs < 0 or xs >= source.W:
                            continue
                        if ys < 0 or ys >= source.H:
                            continue

                        # masked source pixels do not contribute
                        if source.isMasked(xs, ys):
                            continue

                        # add contribution of the source pixel
                        for band in range(3):
                            histo[band][source.getSample(xs, ys, band)] += w
                        wsum += w

                # no significant contribution : conserve the values from previous target
                if wsum < 1:
                    continue

                # **** MAXIMIZATION STEP ****

                # average the contributions of significant pixels (near the median)
                lowth = 0.40*wsum   # low threshold in the CDF
                highth = 0.60*wsum  # high threshold in the CDF
                for band in range(3):

                    cdf = 0
                    contrib = 0
                    wcontrib = 0

                    for i in range(256):

                        cdf += histo[band][i]

                        if cdf < lowth:
                            continue

                        contrib += i*histo[band][i]
                        wcontrib += histo[band][i]

                        if cdf > highth:
                            break

                    value = int(contrib/wcontrib)
                    target.setSample(x, y, band, value)

In [91]:
class Demo():
    """Small driver that loads an image and a mask image from disk and runs Inpaint."""

    @staticmethod
    def display(bimg):
        """Display widget: show `bimg` by calling its `show()` method."""
        bimg.show()

    def loadImage(self, filename):
        """Open `filename` and return it converted to 'RGB' mode."""
        rgb = Image.open(filename).convert('RGB')
        return rgb

    def main(self):
        """Load the demo image and mask, paint the mask in red, inpaint and show the result."""
        image = self.loadImage("lakeandballoon.jpg")
        maskimage = self.loadImage("randomMask.jpg")

        # generate the mask array from the mask image: a pixel is masked when
        # the sum of its channels is below 10
        mask_w, mask_h = maskimage.size[0], maskimage.size[1]
        mask = np.zeros([mask_w, mask_h])
        for row in range(mask_h):
            for col in range(mask_w):
                channels = maskimage.getpixel((col, row))
                mask[col][row] = (sum(channels) < 10)

        # overwrite image, to see the mask in RED
        img_w, img_h = image.size[0], image.size[1]
        red = (255, 0, 0)
        for row in range(img_h):
            for col in range(img_w):
                if mask[col][row]:
                    image.putpixel((col, row), red)

        image.show()
        result = Inpaint().inpaint(image, mask, 2)
        result.show()

        print("\nDONE.")

In [None]:
# Run the full inpainting demo and print how long it took
# (difference between the two time.time() readings).
start = time.time()
Demo().main()
end = time.time()
print(end - start)

build pyramid of images...

*** Processing -  Zoom 1:8 ***
initialize NNF...




EM loop (em = 4 ,nnf = 3 ) :
 4
.
.
.
 3
.
.
.
 2
.
.
