In [1]:
import numpy as np
import cv2
import math
import tensorflow as tf
import os
import time

from tensorflow.keras import backend as K
from modelutils import MaxoutConv2D,r2

In [2]:
%load_ext autoreload
%autoreload 2

#jupyter 中开启该选项，否则不执行
%matplotlib inline

# 估计大气光值A 

In [3]:
# get the dark channel of im,and estimate AtmLight of im

# get the dark channel
def DarkChannel(im,sz):
    
    b,g,r = cv2.split(im)
    
    dc = cv2.min(cv2.min(r,g),b)
    
    # the next step's func just like kernel stride of dc,and get the min
    # value in the szXsz patch 
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(sz,sz))
    
    dark = cv2.erode(dc,kernel)    
    
    return dark

# estimate the atmospheric light
def AtmLight(im,dark):
    
    [h,w] = im.shape[:2]
    
    imsz = h*w
    
    numpx = int(max(math.floor(imsz/1000),1))
    
    darkvec = dark.reshape(imsz,1)
    
    imvec = im.reshape(imsz,3)
    
    indices = darkvec.argsort()
    
    indices = indices[imsz-numpx::]
    
    atmsum = np.zeros([1,3])
    
    for idx in range(1,numpx):
        
        atmsum = atmsum + imvec[indices[idx]]
        
    A = atmsum / numpx
    
    return A

# 获得图片transmission map
输出粗糙的transmission map 复原的图像会有 块效应


In [4]:
def TransmissionEstimate(im,net):

    
    patch = []
    
    width = im.shape[1]
    
    height = im.shape[0]
    
    transmission = np.empty((height,width),np.float32)
    
    num_w = int(width / PATCH_SIZE)
    
    num_h = int(height / PATCH_SIZE)
     
    for i in range(num_h):
    
        for j in range(num_w):#    h,                                        w,                              c
            
                hazy_patch = im[0+i*PATCH_SIZE:PATCH_SIZE+i*PATCH_SIZE, 0+j*PATCH_SIZE:PATCH_SIZE+j*PATCH_SIZE, :] 
                
                patch.append(hazy_patch)
                
    patch = np.array(patch)
    
    patch = patch.astype(np.float32)
    
    patch = patch/255.0
    
    trans = net(patch)
    
    temp = trans
    
    temp = temp.numpy().ravel()
    
    for i in range(num_h):
        
            for j in range(num_w):
                
                #f_value = temp[i*PATCH_SIZE+j]
                f_value = temp[i*num_w+j]
                
                transmission[0+i*PATCH_SIZE:PATCH_SIZE+i*PATCH_SIZE, 0+j*PATCH_SIZE:PATCH_SIZE+j*PATCH_SIZE] = f_value
    
    return transmission

# 导向滤波，优化transmission map

In [5]:
def Guidedfilter(im,p,r,eps):
    
    mean_I = cv2.boxFilter(im,cv2.CV_64F,(r,r))
    
    mean_p = cv2.boxFilter(p, cv2.CV_64F,(r,r))
    
    mean_Ip = cv2.boxFilter(im*p,cv2.CV_64F,(r,r))
    
    cov_Ip = mean_Ip - mean_I*mean_p
    
    mean_II = cv2.boxFilter(im*im,cv2.CV_64F,(r,r))
    
    var_I   = mean_II - mean_I*mean_I
    
    a = cov_Ip/(var_I + eps)
    
    b = mean_p - a*mean_I
    
    mean_a = cv2.boxFilter(a,cv2.CV_64F,(r,r))
    
    mean_b = cv2.boxFilter(b,cv2.CV_64F,(r,r))
    
    q = mean_a*im + mean_b
    
    return q


def TransmissionRefine(im,et):
    
    gray = cv2.cvtColor(im,cv2.COLOR_BGR2GRAY)
    
    gray = np.float64(gray)/255
    
    r = 60
    
    eps = 0.0001
    
    t = Guidedfilter(gray,et,r,eps)
    
    return t

# 恢复得到无雾图

In [6]:
def Recover(im,t,A,tx = 0.1,tm = 1.0):
    
    res = np.empty(im.shape,im.dtype)
    
    t = cv2.max(t,tx)
    
    t = cv2.min(t,tm)
    
    for idx in range(0,3):
        
        res[:,:,idx] = (im[:,:,idx]-A[0,idx])/t + A[0,idx]
        
    return res

# 单幅图片测试

In [7]:
PATCH_SIZE = 16
model_dir = './model'
im_path = './demo/1.png'

In [8]:
if __name__ == "__main__":
    K.clear_session()
    t1 = time.time()
    src = cv2.imread(im_path)    #src.shape == [h,w,c]
    (h,w) = src.shape[0:2]
    net = tf.keras.models.load_model(os.path.join(model_dir,'model.hdf5'), 
                                    custom_objects={'MaxoutConv2D':MaxoutConv2D,                                                   
                                                   "r2":r2})
    
    if w%PATCH_SIZE != 0 or h%PATCH_SIZE != 0:
        if w%PATCH_SIZE != 0:
            w = w - (w%PATCH_SIZE)
        if h%PATCH_SIZE != 0:
            h = h - (h%PATCH_SIZE)
        src = src[0:h, 0:w]
        cv2.imwrite(im_path,src)   
        src = cv2.imread(im_path)  
    t_src = time.time()
    print("image reading cost: ", t_src-t1)
    height = src.shape[0]
    width =src.shape[1]
    I = src/255.0
    dark = DarkChannel(I,16)
    A = AtmLight(I,dark)
    print("A:",A)
    t_A = time.time()
    print("A estimate cost:", t_A-t_src)
    te = TransmissionEstimate(src,net)
    t = TransmissionRefine(src,te)
    t_t = time.time()
    print("tmap estimate cost:", t_t-t_A)
    J = Recover(I,t,A,0.1)
    suffix = im_path.rpartition('.')[-1]
    prefix = im_path.rpartition('.')[0]
    save_path = prefix + '_Dehaze_proposed.' + suffix
    cv2.imwrite(save_path,J*255)    
    t2 = time.time()
    print("time cost:",t2-t1)
    print('Work Done! Dehazed image im path:',save_path)

image reading cost:  10.770394325256348
A: [[0.89048302 0.83189861 0.82018173]]
A estimate cost: 0.05343365669250488
tmap estimate cost: 4.438108682632446
time cost: 15.352979898452759
Work Done! Dehazed image im path: ./demo/1_Dehaze_proposed.png
