In [2]:
import cv2
import numpy as np
from matplotlib import pyplot as plt
from multiprocessing import Pool
import time
import dill

In [2]:
def show_img(img):
    cv2.imshow("img", img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [3]:
def show_plt_img(img_in):
    img = cv2.cvtColor(img_in, cv2.COLOR_BGR2RGB)
    plt.imshow(img)

In [4]:
alpha = 0.1
theta = 0.01
sigma = 0.01
B1 = 3
B2 = 0.1 

Input: specular highlight image I(x)

In [5]:
img = cv2.imread('fish.ppm')

In [6]:
# show_img(img)

Compute H(x); S(x); V(x)g by HSV transformation

In [6]:
hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) / 255
H, S, V = hsv_img[:, :, 0], hsv_img[:, :, 1], hsv_img[:, :, 2]

In [7]:
I_b, I_g, I_r = cv2.split(img / 255)

In [9]:
# show_img(hsv_img)

Compute highlight detection sets by (8) and (9)  

X_SV(x) = {x| S(x) < α; V(x) > 1 − α} (8) 

In [8]:
np.min(S), np.mean(S), np.max(S)

(0.0, 0.4376002221200981, 0.9647058823529412)

In [9]:
X_SV = np.array(S < alpha) & np.array(V > 1 - alpha)
# X_SV = 

X_G(x) = {x| G(x) ≥ τ} (9)

In [10]:
#gradient magnitude of dark-channel image
I_gmd = np.minimum(I_b, I_g, I_r)
I_gmd.shape

(480, 640)

In [11]:
_I_gmd = np.pad(I_gmd, ((1,1),(1,1)), 'constant', constant_values=((0,0),(0,0)))

In [12]:
G = cv2.Laplacian(I_gmd, cv2.CV_64F)
X_G = np.array(G >= theta)
# show_img(G)

Compute Hue estimate H* (11)


In [13]:
X_SVG = X_SV | X_G

In [14]:
def w_svg(x):
    global X_SVG
    if X_SVG[x]:
        return alpha
    return 1

In [18]:
W_svg_matrix = np.copy(X_SVG).astype('float64')
W_svg_matrix[W_svg_matrix == True] = alpha
W_svg_matrix[W_svg_matrix == False] = 1
W_svg_matrix

array([[1. , 1. , 1. , ..., 1. , 1. , 1. ],
       [1. , 1. , 1. , ..., 1. , 0.1, 1. ],
       [0.1, 1. , 1. , ..., 1. , 1. , 0.1],
       ...,
       [1. , 0.1, 1. , ..., 0.1, 1. , 0.1],
       [1. , 1. , 1. , ..., 1. , 1. , 1. ],
       [1. , 1. , 1. , ..., 1. , 0.1, 1. ]])

In [16]:
def w_hs(H_x, H_u, S_u):
    fst = np.exp(-np.power(H_x - H_u, 2) / np.power(sigma, 2))
#     mb mistake in paper
#     snd = np.exp(-(1 - np.power(S(u), 2)))
    snd = np.exp(-np.power(1 - S_u, 2))
    return fst * snd

def filter_w_hs(H_x, H_part, S_part):
    result = 0
    for H_u, S_u in zip(np.nditer(H_part), np.nditer(S_part)):
        result += w_hs(H_x, H_u, S_u)
    return result
        
def filter_w_hs_S(H_x, H_part, S_part):
    result = 0
    for H_u, S_u in zip(np.nditer(H_part), np.nditer(S_part)):
        result += w_hs(H_x, H_u, S_u) * S_u
    return result

In [21]:
def calc_size(point, shape, window_size):
    result = [0, 0, 0, 0]
    y, x = point
    edge = (window_size - 1) // 2
    result[0] = max(y - edge, 0) #y - edge if y - edge > 0 else 0 
    result[1] = min(shape[0], y + edge) #y + edge if y + edge < shape[0] else shape[0]
    
    result[2] = max(x - edge, 0) #x - edge if x - edge > 0 else 0 
    result[3] = min(x + edge, shape[1]) #x + edge if x + edge < shape[1] else shape[1]
    return result 

def get_whs(H_x, H_part, S_part):
    #     mb mistake in paper
    return np.exp(-(H_x - H_part) ** 2) * np.exp(-(1 - S_part) ** 2)
#     return np.exp(-(H_x - H_part) ** 2) * np.exp(-(1 - S_part ** 2) )


def get_part(size, Matrix):
    return Matrix[size[0] : size[1], size[2] : size[3]]

In [33]:
%%time
window_size = 7
print(X_SVG.shape)
H_11 = np.zeros(X_SVG.shape)
S_12 = np.zeros(X_SVG.shape)


for y in range(X_SVG.shape[0]):
    for x in range(X_SVG.shape[1]):
        if X_SVG[y, x]:
            size = calc_size((y, x), X_SVG.shape, window_size)
            H_part = get_part(size, H)
            S_part = get_part(size, S)        
           
            wsvg = get_part(size, W_svg_matrix)
            H_svg = H_part * wsvg
            H_11[y, x] = (1 / wsvg.sum()) * H_svg.sum()
            
            whs = get_whs(H[y, x], H_part, S_part)
            S_whs = whs * S_part
            S_12[y, x] = (1 / whs.sum()) * S_whs.sum()

(480, 640)
Wall time: 2.08 s


In [40]:
# show_img(S_12)
# show_img(H_11)

In [77]:
# hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) / 255
# H, S, V = hsv_img[:, :, 0], hsv_img[:, :, 1], hsv_img[:, :, 2]
ones = np.ones(X_SVG.shape)
corrected_hsv = np.ones(hsv_img.shape)
corrected_hsv[:, :, 0] = H_11
corrected_hsv[:, :, 1] = S_12
# corrected_hsv[:, :, 2] /= 255
hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
corrected_rgb = cv2.cvtColor(corrected_hsv.astype('uint8'), cv2.COLOR_HSV2RGB)

In [78]:
show_img(corrected_img)

Compute m∗d(x) by updating algorithm (25)
k - free color value 


In [79]:
def P(z, u):
    return np.argmin(np.linalg.norm(u - z))