In [6]:
import sys
import numpy as np
import cv2
import math
import tensorflow as tf
import os
import matplotlib
# matplotlib.use("Agg") #可以使matplotlib保存.png图到磁盘
import matplotlib.pyplot as plt
from PIL import Image
import time

from tensorflow.keras import backend as K
from modelutils import MaxoutConv2D,r2

In [7]:
%load_ext autoreload
# Automatically reload edited modules (e.g. modelutils) before running code,
# so changes to helper .py files are picked up without restarting the kernel.
%autoreload 2

# Enable inline plotting in Jupyter; per the original note, without this
# option the matplotlib figures are not shown.
%matplotlib inline

# 估计大气光值A 

In [8]:
# get the dark channel of im,and estimate AtmLight of im

# get the dark channel
def DarkChannel(im,sz):
    """Compute the dark channel of a 3-channel image.

    First the per-pixel minimum over the three colour channels is taken,
    then a grey-level erosion with an ``sz`` x ``sz`` rectangular
    structuring element replaces every pixel by the minimum of its
    neighbourhood.

    Args:
        im: image that ``cv2.split`` splits into exactly three channels.
        sz: side length of the square erosion window.

    Returns:
        Single-channel map with the same height/width as ``im``.
    """
    blue, green, red = cv2.split(im)

    # Pixel-wise minimum across the colour channels.
    per_pixel_min = cv2.min(cv2.min(red, green), blue)

    # Erosion with a rectangular element == sliding-window minimum.
    window = cv2.getStructuringElement(cv2.MORPH_RECT, (sz, sz))
    return cv2.erode(per_pixel_min, window)

# estimate the atmospheric light
def AtmLight(im,dark):
    """Estimate the global atmospheric light A from the brightest dark-channel pixels.

    The brightest 0.1% of pixels in the dark channel (at least one pixel)
    are selected, and their colours in ``im`` are averaged.

    Args:
        im: (h, w, 3) image.
        dark: (h, w) dark channel of ``im`` (e.g. from ``DarkChannel``).

    Returns:
        np.ndarray of shape (1, 3) holding the averaged colour; callers index
        it as ``A[0, c]``.
    """
    h, w = im.shape[:2]
    imsz = h * w

    # Top 0.1% of the pixels, but always at least one.
    numpx = max(imsz // 1000, 1)

    # Flatten to 1-D before sorting. Sorting an (imsz, 1) column with the
    # default axis=-1 sorts each length-1 row on its own and returns all
    # zeros, which made A collapse onto the top-left pixel.
    darkvec = dark.reshape(imsz)
    imvec = im.reshape(imsz, 3)

    # Indices of the numpx largest dark-channel values.
    indices = darkvec.argsort()[imsz - numpx:]

    # Average every selected pixel (the old loop started at index 1 and
    # silently dropped one of them while still dividing by numpx).
    A = imvec[indices].mean(axis=0, keepdims=True)

    return A

# 获得图片transmission map
网络对每个 PATCH_SIZE×PATCH_SIZE 图像块只输出一个透射率值，因此得到的是粗糙（分块）的 transmission map；直接用它复原的图像会有块效应，所以下一步用导向滤波对其进行细化。


In [86]:
def TransmissionEstimate(im,net,patch_size=None):
    """Predict a coarse, patch-wise transmission map with the network ``net``.

    The image is tiled into non-overlapping ``patch_size`` x ``patch_size``
    patches in row-major order. Each patch is scaled by 1/255, and all
    patches are fed to ``net`` as one float32 batch. Every predicted value
    is then written back over the pixels of its patch.

    Pixels in the right/bottom remainder (when the image size is not a
    multiple of ``patch_size``) copy the value of the nearest patch instead
    of being left uninitialised.

    Args:
        im: (h, w, c) image; assumed to be in 0..255 range since it is
            divided by 255 (TODO confirm for non-uint8 callers).
        net: callable taking a float32 batch of shape
            (n_patches, patch_size, patch_size, c) and returning one value
            per patch (a TF tensor or anything ``np.asarray`` accepts).
        patch_size: tile side length; defaults to the module-level
            ``PATCH_SIZE``.

    Returns:
        (h, w) float32 transmission map.

    Raises:
        ValueError: if the image is smaller than one patch, or ``net`` does
            not return exactly one value per patch.
    """
    if patch_size is None:
        patch_size = PATCH_SIZE

    height, width = im.shape[:2]
    channels = im.shape[2]
    num_h = height // patch_size
    num_w = width // patch_size

    if num_h == 0 or num_w == 0:
        raise ValueError(
            "image of size %dx%d is smaller than one %dx%d patch"
            % (height, width, patch_size, patch_size))

    # Cut the patch-aligned region into tiles ordered row-major
    # (patch k = i * num_w + j), matching the order of the network output.
    cropped = im[:num_h * patch_size, :num_w * patch_size]
    patches = (cropped
               .reshape(num_h, patch_size, num_w, patch_size, channels)
               .transpose(0, 2, 1, 3, 4)
               .reshape(-1, patch_size, patch_size, channels)
               .astype(np.float32) / 255.0)

    values = np.asarray(net(patches)).ravel()
    if values.size != num_h * num_w:
        raise ValueError(
            "net returned %d values for %d patches" % (values.size, num_h * num_w))

    # Expand each per-patch value into a patch_size x patch_size block.
    block_map = values.reshape(num_h, num_w).astype(np.float32)
    transmission = np.repeat(np.repeat(block_map, patch_size, axis=0),
                             patch_size, axis=1)

    # Extend the last patch row/column over any uncovered border pixels.
    transmission = np.pad(
        transmission,
        ((0, height - transmission.shape[0]), (0, width - transmission.shape[1])),
        mode='edge')

    return transmission

# 导向滤波，优化transmission map

In [10]:
def Guidedfilter(im,p,r,eps):
    """Guided filter of ``p`` using ``im`` as the guidance image.

    Within every r x r window, ``p`` is modelled as a linear function of the
    guide, q = a * im + b. The coefficients are obtained from local
    means, covariance and variance (regularised by ``eps``), then averaged
    over the windows.

    Args:
        im: single-channel guidance image.
        p: image to filter, same height/width as ``im``.
        r: box-filter window size.
        eps: regularisation added to the local variance of the guide.

    Returns:
        Filtered image as float64.
    """
    def box(x):
        # Window mean computed in float64.
        return cv2.boxFilter(x, cv2.CV_64F, (r, r))

    mean_guide = box(im)
    mean_src = box(p)

    # Local covariance of (guide, source) and local variance of the guide.
    covariance = box(im * p) - mean_guide * mean_src
    variance = box(im * im) - mean_guide * mean_guide

    # Per-window linear coefficients.
    a = covariance / (variance + eps)
    b = mean_src - a * mean_guide

    return box(a) * im + box(b)


def TransmissionRefine(im,et,r=60,eps=0.0001):
    """Refine a coarse transmission map with ``Guidedfilter``.

    The greyscale version of ``im``, scaled by 1/255, is used as the guidance
    image for filtering the patch-wise estimate ``et``.

    Args:
        im: BGR image (converted with ``cv2.COLOR_BGR2GRAY``); assumed to be
            in 0..255 range because it is divided by 255.
        et: coarse transmission map with the same height/width as ``im``.
        r: box-filter window size passed to ``Guidedfilter`` (default 60,
            the previously hard-coded value).
        eps: regularisation passed to ``Guidedfilter`` (default 1e-4, the
            previously hard-coded value).

    Returns:
        The refined transmission map as returned by ``Guidedfilter`` (float64).
    """
    gray = np.float64(cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)) / 255

    return Guidedfilter(gray, et, r, eps)

# 恢复得到无雾图

In [11]:
def Recover(im,t,A,tx = 0.1,tm = 1.0):
    """Recover the haze-free image by inverting J = (I - A) / t + A.

    Args:
        im: (h, w, 3) hazy image I.
        t: (h, w) transmission map. It is clamped to [tx, tm] so the division
            never uses values below ``tx``.
        A: atmospheric light with shape (1, 3), indexed as ``A[0, c]``.
        tx: lower bound applied to ``t``.
        tm: upper bound applied to ``t``.

    Returns:
        Recovered image with the same shape as ``im``; values are NOT clipped.
        Floating-point input keeps its dtype. Integer input is promoted to
        float64 instead of silently truncating/wrapping the result.
    """
    # Storing float results into an integer array would truncate or wrap them.
    if np.issubdtype(im.dtype, np.floating):
        out_dtype = im.dtype
    else:
        out_dtype = np.float64

    res = np.empty(im.shape, out_dtype)

    t = np.minimum(np.maximum(t, tx), tm)

    # Vectorised over the three colour channels; t is broadcast per channel.
    res[:, :, :3] = (im[:, :, :3] - A[0, :3]) / t[:, :, np.newaxis] + A[0, :3]

    return res

# 批量图片去雾

In [7]:
# py
def dehaze(hazy_img_path, dehazy_img_path, model=None):
    """Dehaze one image file and write the result to disk.

    The pipeline runs in five steps:
    1. Dark channel with a 16x16 window.
    2. Atmospheric light.
    3. Patch-wise transmission from the network.
    4. Guided-filter refinement.
    5. Inversion of the haze model with t clamped to at least 0.1.

    Args:
        hazy_img_path: path of the input image (read with ``cv2.imread``).
        dehazy_img_path: output path passed to ``cv2.imwrite``.
        model: transmission network; defaults to the module-level ``net``.

    Raises:
        FileNotFoundError: if the input image cannot be read.
        OSError: if the output image cannot be written.
    """
    if model is None:
        model = net

    src = cv2.imread(hazy_img_path)
    # cv2.imread signals failure by returning None rather than raising.
    if src is None:
        raise FileNotFoundError('Cannot read image: %s' % hazy_img_path)

    I = src / 255.0

    dark = DarkChannel(I, 16)
    A = AtmLight(I, dark)

    te = TransmissionEstimate(src, model)
    t = TransmissionRefine(src, te)

    J = Recover(I, t, A, 0.1)

    # Recover() does not bound its output, so clip to the 8-bit range and
    # convert explicitly instead of handing a float64 array to imwrite.
    out = np.clip(J * 255, 0, 255).round().astype(np.uint8)

    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(dehazy_img_path, out):
        raise OSError('Cannot write image: %s' % dehazy_img_path)

In [8]:
## preprocess img to suit with w%16 == h%16 == 0
def preprocess_img(img_dir_path, patch_size=None):
    """Crop every image in ``img_dir_path`` so both sides are multiples of ``patch_size``.

    Only the right/bottom remainder is removed. Files are overwritten in
    place, so keep a copy of the originals if they are needed.

    Files that ``cv2.imread`` cannot read (e.g. non-image files), and images
    smaller than one patch, are skipped with a message.

    Args:
        img_dir_path: directory containing the images.
        patch_size: tile side length; defaults to the module-level
            ``PATCH_SIZE``.
    """
    # The previous body referenced an undefined lowercase ``patch_size``
    # global and raised NameError; fall back to the notebook's PATCH_SIZE.
    if patch_size is None:
        patch_size = PATCH_SIZE

    img_list = os.listdir(img_dir_path)
    num_img = len(img_list)
    cutting_count = 0

    for img in img_list:
        img_path = os.path.join(img_dir_path, img)

        src = cv2.imread(img_path)    # src.shape == [h,w,c]
        if src is None:
            print("Skipping unreadable file: %s" % img_path)
            continue

        h, w = src.shape[0:2]
        new_h = h - h % patch_size
        new_w = w - w % patch_size

        if new_h == h and new_w == w:
            continue  # already aligned

        if new_h == 0 or new_w == 0:
            # Cropping would leave an empty image.
            print("Skipping %s: smaller than one %dx%d patch"
                  % (img_path, patch_size, patch_size))
            continue

        # Crop coordinates are [y0:y1, x0:x1].
        cv2.imwrite(img_path, src[0:new_h, 0:new_w])
        cutting_count += 1

    print("Preprocessing Done! Total %d images, where %d cut" %(num_img, cutting_count))


In [9]:
def batch_dehaze(hazy_dir_path, dehazy_dir_path, methods = 'proposed', preprocess = 1):
    """Dehaze every image file in ``hazy_dir_path`` into ``dehazy_dir_path``.

    Output files are named ``<stem>_Dehaze_<methods><ext>``. Inputs whose
    output already exists are skipped, so an interrupted run can be resumed.

    Args:
        hazy_dir_path: directory of input images.
        dehazy_dir_path: output directory (created if it does not exist).
        methods: tag inserted into the output file names.
        preprocess: when 1, crop the inputs in place with ``preprocess_img``
            before dehazing.
    """
    begin = time.time()

    # Make sure the output directory exists so listing it cannot fail.
    os.makedirs(dehazy_dir_path, exist_ok=True)
    dehazy_img_set = set(os.listdir(dehazy_dir_path))

    # Only regular files are candidates; sub-directories cannot be dehazed.
    hazy_img_list = [name for name in os.listdir(hazy_dir_path)
                     if os.path.isfile(os.path.join(hazy_dir_path, name))]
    num_img = len(hazy_img_list)

    print('Total %d images need to be dehazed...'%num_img)

    dehazing_count = 0   # inputs that have an output (this run or earlier)
    processed_now = 0    # inputs actually dehazed in this call (for timing)

    if preprocess == 1:
        print('Please waiting for preprocess...')
        preprocess_img(hazy_dir_path)
        preprocess_done = time.time()
        print('Preprocessing Done! Using time %s s'%(preprocess_done-begin))
        # Restart the clock so dehazing time excludes preprocessing.
        begin = time.time()

    print('Begin to dehaze...')

    for hazy_img in hazy_img_list:
        # splitext keeps multi-dot names intact ("a.b.png" -> "a.b", ".png");
        # splitting on '.' and taking piece [0] dropped the middle parts and
        # could make different inputs collide on one output name.
        img_name, img_ext = os.path.splitext(hazy_img)
        dehazy_img = img_name + '_Dehaze_' + methods + img_ext

        if dehazy_img in dehazy_img_set:
            dehazing_count += 1
            print("already dehazing %d images."%dehazing_count)
            continue

        dehaze(os.path.join(hazy_dir_path, hazy_img),
               os.path.join(dehazy_dir_path, dehazy_img))

        dehazing_count += 1
        processed_now += 1

        if dehazing_count%3==0:
            print('Please waiting for dehaze... Total %d hazy images, %d done...... %d/%d'
                  %(num_img, dehazing_count, dehazing_count, num_img))

    print("Work Done!")
    print("Total %d images, %d done!" %(num_img, dehazing_count))

    time_cost = time.time() - begin
    # Average over images dehazed in this call; avoids ZeroDivisionError when
    # the directory is empty or everything was already done.
    time_cost_per_img = time_cost / processed_now if processed_now else 0.0

    print("Total using %s s, %s per image." %(time_cost, time_cost_per_img))
    

In [10]:
# batch processing 
if __name__ == '__main__':

    # Patch side length used to tile images for the transmission network;
    # TransmissionEstimate reads this module-level name.
    PATCH_SIZE = 16

    # Paths can be overridden through environment variables instead of
    # editing the notebook; the defaults are the original local paths.
    model_dir = os.environ.get('DEHAZE_MODEL_DIR', './model')

    hazy_dir_path = os.environ.get(
        'DEHAZE_HAZY_DIR', r'D:\AProject\Graduation_project_defog\evaluate\hazy')

    dehazy_dir_path = os.environ.get(
        'DEHAZE_OUTPUT_DIR', r'D:\AProject\Graduation_project_defog\evaluate\proposed')

    methods = 'proposed'

    # Module-level `net` is the default model used by dehaze(); the custom
    # layer and metric must be registered for the HDF5 model to load.
    net = tf.keras.models.load_model(os.path.join(model_dir,'model.hdf5'),
                                     custom_objects={'MaxoutConv2D':MaxoutConv2D,
                                                     "r2":r2})

    batch_dehaze(hazy_dir_path, dehazy_dir_path, methods, preprocess = 0)

Total 6 images need to be dehazed...
Begin to dehaze...
Please waiting for dehaze... Total 6 hazy images, 3 done...... 3/6
Please waiting for dehaze... Total 6 hazy images, 6 done...... 6/6
Work Done!
Total 6 images, 6 done!
Total using 10.518873691558838 s, 1.7531456152598064 per image.
