In [1]:
import numpy as np
from skimage import img_as_ubyte
from scipy.ndimage import gaussian_filter
from scipy.ndimage import rotate
from skimage.transform import rescale
from skimage import color
from skimage.segmentation import active_contour
from skimage.feature import canny
from skimage.transform import resize
from skimage.util import img_as_float
from skimage import exposure
from skimage.exposure import match_histograms
import imageio 
from scipy.spatial.distance import cdist
import cv2
import matplotlib.pyplot as plt

In [2]:
def load_square_image(path, size=400):
    """Read an image with OpenCV and resize it to ``size`` x ``size`` pixels.

    The returned array is whatever ``cv2.imread`` produces (BGR channel order),
    resized with ``cv2.resize``.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read ``path``; ``cv2.imread`` signals this by returning
        ``None`` rather than raising, which would otherwise surface later as an
        opaque error inside ``cv2.resize``.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {path!r}")
    return cv2.resize(img, (size, size))


style = load_square_image("Starry.jpg")
content = load_square_image("amsterdam.jpg")


In [None]:
def segment(C, scale, center=None, radius=None, n_points=400):
    """Fit an active contour ("snake") to the edges of an image.

    Pipeline: 3x3 high-boost sharpening -> grayscale (only if the input has 3
    channels) -> Canny edge map (sigma=3) -> ``active_contour`` initialised as a
    circle.

    Parameters
    ----------
    C : ndarray
        BGR image of shape (h, w, 3) (as from ``cv2.imread``) or an already
        grayscale image of shape (h, w).
    scale : float
        Factor applied to the returned contour coordinates.
    center : (float, float), optional
        (row, col) centre of the initial circle; defaults to the image centre.
    radius : float, optional
        Radius of the initial circle; defaults to a quarter of the shorter image
        side. For a 400x400 image the defaults equal the previous hard-coded
        circle centred at (200, 200) with radius 100.
    n_points : int, optional
        Number of contour vertices (default 400).

    Returns
    -------
    ndarray, shape (n_points, 2)
        Contour vertices as (row, col) pairs multiplied by ``scale``. This is a
        polygon, not a mask; rasterise it (e.g. ``cv2.fillPoly``) if a mask is
        needed.
    """
    # High-boost kernel: centre weight 9, neighbours -1 (identity + Laplacian).
    kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
    sharpened = cv2.filter2D(C, -1, kernel)

    if sharpened.ndim == 3:
        # cv2.cvtColor does not accept float64 input; float32 is supported.
        if sharpened.dtype == np.float64:
            sharpened = sharpened.astype(np.float32)
        gray = cv2.cvtColor(sharpened, cv2.COLOR_BGR2GRAY)
    else:
        # Already single-channel: BGR2GRAY would raise on this input.
        gray = sharpened

    edges = canny(gray, sigma=3)

    h, w = gray.shape[:2]
    if center is None:
        center = (h / 2, w / 2)
    if radius is None:
        radius = min(h, w) / 4

    # Initial contour: a circle in (row, col) coordinates.
    s = np.linspace(0, 2 * np.pi, n_points)
    init = np.column_stack([center[0] + radius * np.sin(s),
                            center[1] + radius * np.cos(s)])

    snake = active_contour(edges, init, alpha=0.015, beta=10, gamma=0.001,
                           w_line=0, w_edge=1, max_iterations=2500)
    return snake * scale

In [1]:
def nearest_n(R, X, Q_size, S, h, w, c, Pp, Vp, Pstride, mp, L, gap):
    """Find the style patch (and rotation slot) closest to a masked window of X.

    The window ``X[R]`` is centred by ``mp`` and projected onto the basis ``Vp``.
    It is then compared (squared Euclidean distance) against every column of the
    projected style-patch matrix ``Pp``.

    Parameters
    ----------
    R : array, shape (h*w*c,)
        Mask (non-zero = selected) of one Q_size x Q_size x c window, flattened in
        row-major (h, w, c) order.
    X : array, shape (h*w*c,)
        Current image estimate, flattened like ``R``.
    Q_size : int
        Patch side length.
    S : array with h*w*c elements
        Style image; reshaped to (h, w, c).
    h, w, c : int
        Image height, width and channel count.
    Pp : array, shape (k, n_rows*n_cols*4)
        Projected style patches. Column ``patch_index*4 + rot`` holds rotation
        slot ``rot`` (0..3) of the patch at grid position
        ``unravel_index(patch_index, (n_rows, n_cols))``.
    Vp : array, shape (c*Q_size**2, k)
        Projection basis (one basis vector per column).
    Pstride : int
        Stride of the grid the style patches were extracted on.
    mp : array, shape (c*Q_size**2,)
        Mean patch subtracted before projection.
    L, gap : unused
        Kept for interface compatibility.

    Returns
    -------
    ks, ls : int
        Column and row of the top-left corner of the matched style patch.
    z : ndarray, shape (Q_size*Q_size*c,)
        The *unrotated* matched style patch, flattened in (Q_size, Q_size, c)
        order.
    ang : int
        ``90 * rot`` for the matched rotation slot. The caller must apply the
        same rotation convention that was used to build ``Pp``.
    """
    S_img = np.asarray(S).reshape((h, w, c))
    query = np.asarray(X)[np.asarray(R).astype(bool)]

    query_proj = Vp.T @ (query - mp)
    # Broadcast the (k,) query against all (k, N) candidate columns.
    sq_dist = np.sum((Pp - query_proj[:, None]) ** 2, axis=0)
    # Jitter of up to 10% of the best distance, so near-ties do not always pick
    # the same patch (zero when an exact match exists).
    sq_dist = sq_dist + 0.1 * np.min(sq_dist) * np.random.rand(sq_dist.shape[0])
    best = int(np.argmin(sq_dist))

    # Columns are grouped in fours: (patch index, rotation slot).
    patch_idx, rot = divmod(best, 4)
    n_rows = (h - Q_size) // Pstride + 1
    n_cols = (w - Q_size) // Pstride + 1
    row, col = np.unravel_index(patch_idx, (n_rows, n_cols))
    ls = int(row) * Pstride
    ks = int(col) * Pstride

    z = S_img[ls:ls + Q_size, ks:ks + Q_size, :].ravel()
    return ks, ls, z, rot * 90


In [None]:
def irls(R, X, z, n_iter=5, r=0.8, eps=1e-6):
    """Blend overlapping patch targets into one image by IRLS.

    Approximately minimises ``sum_j ||X[R_j] - z_j||_2 ** r`` with iteratively
    reweighted least squares. Each patch gets the weight
    ``(||X[R_j] - z_j||^2 + eps) ** ((r - 2) / 2)``, so with ``r < 2`` badly
    matching patches count less. Pixels that no patch covers keep their value
    from ``X``.

    Parameters
    ----------
    R : array, shape (n_pixels, n_patches)
        Column j is a mask (non-zero = covered) of the pixels patch j covers.
        It is not modified.
    X : array, shape (n_pixels,)
        Starting estimate; also supplies the values of uncovered pixels.
    z : array, shape (patch_len, n_patches)
        Column j holds the target values for the pixels selected by
        ``R[:, j]``, in increasing pixel-index order.
    n_iter : int, optional
        Number of reweighting iterations (default 5).
    r : float, optional
        Norm exponent (default 0.8); ``r = 2`` gives plain per-pixel averaging.
    eps : float, optional
        Regulariser that keeps the weights finite when a residual is zero.

    Returns
    -------
    ndarray, shape (n_pixels,)
        The aggregated estimate.
    """
    masks = np.asarray(R).astype(bool)
    targets = np.asarray(z, dtype=float)
    Xk = np.array(X, dtype=float)  # private copy: the caller's X is never touched
    unsampled = (~masks.any(axis=1)).astype(float)
    exponent = (r - 2) / 2

    for _ in range(n_iter):
        # Uncovered pixels contribute weight 1 on their current value, so they stay fixed.
        A = unsampled.copy()
        B = Xk * unsampled
        for j in range(masks.shape[1]):
            sel = masks[:, j]
            zj = targets[:, j]
            weight = (np.sum((Xk[sel] - zj) ** 2) + eps) ** exponent
            A[sel] += weight
            B[sel] += weight * zj
        # A > 0 everywhere: uncovered pixels have A = 1 and covered ones get weight > 0.
        Xk = B / A

    return Xk

In [1]:
def _resize(img, h, w):
    """Resize an array to h rows x w columns (cv2.resize takes dsize as (width, height))."""
    return cv2.resize(img, (w, h))


def _build_patch_matrix(S_img, Q_size, stride):
    """Stack every Q_size x Q_size style patch on a `stride` grid, with its 4 rotations, as columns.

    Column ``(r * n_cols + q) * 4 + k`` holds ``np.rot90(patch, k).ravel()``
    (k quarter turns) for the patch whose top-left corner is
    (r * stride, q * stride).
    """
    h, w, _ = S_img.shape
    n_rows = (h - Q_size) // stride + 1
    n_cols = (w - Q_size) // stride + 1
    P = np.empty((S_img.shape[2] * Q_size * Q_size, n_rows * n_cols * 4))
    for r in range(n_rows):
        for q in range(n_cols):
            y, x = r * stride, q * stride
            block = S_img[y:y + Q_size, x:x + Q_size, :]
            base = (r * n_cols + q) * 4
            for k in range(4):
                P[:, base + k] = np.rot90(block, k).ravel()
    return P


def _pca_basis(P, energy_kept=0.95):
    """Return (mean, basis U, projected P) keeping enough components for `energy_kept` of the energy.

    The basis is taken from the eigen-decomposition of the (d x d) scatter matrix
    instead of an SVD of P. This gives the same principal subspace but avoids
    materialising the (d x N) right singular vectors, which is large when there
    are many patches.
    """
    mean = P.mean(axis=1)
    centered = P - mean[:, None]
    evals, evecs = np.linalg.eigh(centered @ centered.T)
    order = np.argsort(evals)[::-1]
    power = np.clip(evals[order], 0, None)
    U = evecs[:, order]
    total = power.sum()
    if total > 0:
        # First component index whose cumulative energy exceeds the threshold, inclusive.
        n_keep = int(np.searchsorted(np.cumsum(power) / total, energy_kept, side="right")) + 1
    else:
        n_keep = 1  # constant patches: any single direction will do
    U = U[:, :min(n_keep, U.shape[1])]
    return mean, U, U.T @ centered


def style_transfer(content, style, hall, mask0, h_coeff, w_coeff, patch_sizes, scales, imsize=400):
    """Multi-scale, patch-based style transfer.

    The content is first colour-matched to the style and perturbed with uniform
    noise. Then, for each scale ``L`` (images downsized by ``L``) and each patch
    size, three rounds of the following run:

    1. style fusion with the hallucination ``hall`` (weight ``h_coeff``);
    2. nearest style-patch matching (PCA space, 4 rotations);
    3. IRLS aggregation of the matched patches;
    4. content fusion weighted by ``w_coeff * mask``;
    5. per-channel histogram matching to the style;
    6. bilateral denoising.

    Parameters
    ----------
    content, style : array with imsize*imsize*3 elements
        Content and style images; reshaped to (imsize, imsize, 3).
    hall : array with imsize*imsize*3 elements
        Hallucinated image blended in at step 1.
    mask0 : scalar or array with imsize*imsize elements
        Per-pixel content weight; normalised by its maximum (all-zero -> no
        content fusion).
    h_coeff : float
        Hallucination blending weight in step 1.
    w_coeff : float
        Content fusion strength in step 4.
    patch_sizes : list of int
        Patch sizes; entry i uses the i-th overlap gap from ``[28, 18, 9, 6]``.
        Size 13 is only used at scale 1.
    scales : list of number
        Downscale factors, processed in order.
    imsize : int
        Side length of the square working images.

    Returns
    -------
    ndarray, shape (imsize, imsize, 3)
        The stylised image.

    Notes
    -----
    ``sigma_r = 0.2`` for the bilateral filter looks tuned for intensities in
    [0, 1]; presumably inputs are expected in that range (confirm).
    """
    sigma_s = 5
    sigma_r = 0.2
    h0 = w0 = imsize
    c = 3
    pstride = 4
    gap_sizes = [28, 18, 9, 6]

    C0 = np.array(content, dtype=float).reshape((h0, w0, c))
    S0 = np.array(style, dtype=float).reshape((h0, w0, c))
    hall0 = np.ascontiguousarray(np.asarray(hall, dtype=float).reshape((h0, w0, c)))
    mask_arr = np.asarray(mask0, dtype=float)
    if mask_arr.size == 1:
        mask_full = np.full((h0, w0), float(mask_arr.ravel()[0]))
    else:
        mask_full = np.ascontiguousarray(mask_arr.reshape((h0, w0)))

    # Start from the content with the style's per-channel colour distribution.
    C0 = match_histograms(C0, S0, channel_axis=-1)

    # Random initialisation: colour-matched content plus uniform noise.
    X_full = C0 + np.max(C0) * np.random.rand(*C0.shape)

    for L in scales:
        h, w = int(h0 / L), int(w0 / L)
        C = _resize(C0, h, w)
        S_img = _resize(S0, h, w)
        S_flat = S_img.ravel()
        halls = _resize(hall0, h, w)
        mask_scale = _resize(mask_full, h, w)
        X = _resize(X_full, h, w)

        # Content-fusion weights, one per pixel, repeated over the colour channels.
        peak = np.max(mask_scale)
        W_pix = w_coeff * mask_scale / peak if peak > 0 else np.zeros_like(mask_scale)
        W = np.repeat(W_pix[:, :, None], c, axis=2)

        for p_idx, Q_size in enumerate(patch_sizes):
            if L > 1 and Q_size == 13:
                continue  # the smallest patch size is only used at full resolution
            if Q_size > min(h, w):
                continue  # patch does not fit at this scale

            gap = gap_sizes[p_idx]

            P = _build_patch_matrix(S_img, Q_size, pstride)
            P_mean, U, P_hat = _pca_basis(P)

            # Overlapping target windows on a `gap` grid. These are the same for
            # every iteration, so build them once. NOTE: dense (h*w*c, n_windows)
            # bool matrix, roughly 1 GB at 400x400 with 13-px patches.
            n_rows = (h - Q_size) // gap + 1
            n_cols = (w - Q_size) // gap + 1
            n_windows = n_rows * n_cols
            Rall = np.zeros((h * w * c, n_windows), dtype=bool)
            for r in range(n_rows):
                for q in range(n_cols):
                    R = np.zeros((h, w, c), dtype=bool)
                    R[r * gap:r * gap + Q_size, q * gap:q * gap + Q_size, :] = True
                    Rall[:, r * n_cols + q] = R.ravel()

            for _ in range(3):
                # 1. style fusion with the hallucination
                X = h_coeff * halls + (1 - h_coeff) * X
                X_flat = X.ravel()

                # 2. patch matching; re-apply the matched rotation slot, using the
                #    same np.rot90 convention as _build_patch_matrix.
                z = np.empty((c * Q_size * Q_size, n_windows))
                for j in range(n_windows):
                    _, _, zij, ang = nearest_n(Rall[:, j], X_flat, Q_size, S_flat, h, w, c,
                                               P_hat, U, pstride, P_mean, L, gap)
                    z[:, j] = np.rot90(zij.reshape((Q_size, Q_size, c)), int(ang) // 90).ravel()

                # 3. style synthesis (robust aggregation of overlapping patches)
                Xtilde = irls(Rall, X_flat, z).reshape((h, w, c))

                # 4. content fusion
                Xhat = (Xtilde + W * C) / (1 + W)

                # 5. colour transfer (per channel)
                X = match_histograms(Xhat, S_img, channel_axis=-1)

                # 6. denoise (bilateralFilter accepts only uint8 / float32)
                X = cv2.bilateralFilter(X.astype(np.float32), 5,
                                        sigmaColor=sigma_r, sigmaSpace=sigma_s).astype(float)

        # Carry the result to full resolution for the next scale / the output.
        X_full = X if (h, w) == (h0, w0) else _resize(X, h0, w0)

    return X_full

## Main: hallucination, then estimation (grayscale-content mask)

In [None]:
# Work in RGB floats in [0, 1]: cv2 loads BGR uint8, while matplotlib expects RGB.
content_rgb = img_as_float(cv2.cvtColor(content, cv2.COLOR_BGR2RGB))
style_rgb = img_as_float(cv2.cvtColor(style, cv2.COLOR_BGR2RGB))

# Blurry content image: blur spatially only. A scalar sigma would also blur
# across the 3 colour channels.
content_blur = gaussian_filter(content_rgb, sigma=(100, 100, 0))

# Pass 1 (hallucination): h_coeff = w_coeff = 0, so neither the hallucination
# input nor the content mask influences the result.
hallucination = style_transfer(content_blur,
                               style_rgb,
                               hall=np.ones((400, 400, 3)),
                               mask0=np.ones((400, 400)),
                               h_coeff=0,
                               w_coeff=0,
                               patch_sizes=[36, 22],
                               scales=[4, 2, 1])
hallucination = np.clip(np.reshape(hallucination, (400, 400, 3)), 0, 1)

cv2.imwrite('hallucination.jpg', cv2.cvtColor(img_as_ubyte(hallucination), cv2.COLOR_RGB2BGR))
plt.imshow(hallucination)
plt.show()

# Pass 2 (estimation): blend in the hallucination and fuse content, weighted
# by the grayscale content image.
estimation = style_transfer(content_rgb,
                            style_rgb,
                            hall=hallucination,
                            mask0=cv2.cvtColor(content, cv2.COLOR_BGR2GRAY),
                            h_coeff=0.25,
                            w_coeff=1.5,
                            patch_sizes=[36, 22],
                            scales=[4, 2, 1])
estimation = np.clip(np.reshape(estimation, (400, 400, 3)), 0, 1)

cv2.imwrite('estimation.jpg', cv2.cvtColor(img_as_ubyte(estimation), cv2.COLOR_RGB2BGR))
plt.imshow(estimation)
plt.show()

### Main: hallucination, then estimation (snake-segmentation mask)

In [None]:
imsize = 400

# Work in RGB floats in [0, 1]. cv2 loads BGR uint8, plt.imsave expects RGB
# floats in [0, 1], and sigma_r = 0.2 inside style_transfer looks tuned for
# that range. New names keep `content`/`style` unchanged, so re-running this
# cell is safe.
content_rgb = img_as_float(cv2.cvtColor(content, cv2.COLOR_BGR2RGB))
style_rgb = img_as_float(cv2.cvtColor(style, cv2.COLOR_BGR2RGB))
# Blur spatially only; a scalar sigma would also mix the colour channels.
content_blur = gaussian_filter(content_rgb, sigma=(100, 100, 0))

hall = style_transfer(content_blur,
                      style_rgb,
                      np.ones((imsize, imsize, 3)),
                      np.ones((imsize, imsize)),
                      0,
                      0,
                      [36, 22],
                      [4, 2, 1],
                      imsize=imsize)
hall = np.clip(np.reshape(hall, (imsize, imsize, 3)), 0, 1)

# segment() returns the snake's (row, col) vertices, not a mask. Rasterise the
# closed contour; cv2.fillPoly expects (x, y) = (col, row) integer points.
snake = segment(content, 1)
fg_mask = np.zeros((imsize, imsize), dtype=np.uint8)
cv2.fillPoly(fg_mask, [np.round(snake[:, ::-1]).astype(np.int32)], 1)

est_img = style_transfer(content_rgb,
                         style_rgb,
                         hall,
                         fg_mask.astype(float),
                         0.25,
                         1.5,
                         [36, 22, 13],
                         [4, 2, 1],
                         imsize=imsize)
est_img = np.clip(np.reshape(est_img, (imsize, imsize, 3)), 0, 1)

plt.imsave('hallucination.jpg', hall)
plt.imsave('estimation.jpg', est_img)

