In [1]:
import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt
import pywt
from scipy.optimize import newton
import matplotlib

In [2]:
# Show the wavelet names known to the installed PyWavelets; the list is
# displayed as this cell's output.
pywt.wavelist()

['bior1.1',
 'bior1.3',
 'bior1.5',
 'bior2.2',
 'bior2.4',
 'bior2.6',
 'bior2.8',
 'bior3.1',
 'bior3.3',
 'bior3.5',
 'bior3.7',
 'bior3.9',
 'bior4.4',
 'bior5.5',
 'bior6.8',
 'cgau1',
 'cgau2',
 'cgau3',
 'cgau4',
 'cgau5',
 'cgau6',
 'cgau7',
 'cgau8',
 'cmor',
 'coif1',
 'coif2',
 'coif3',
 'coif4',
 'coif5',
 'coif6',
 'coif7',
 'coif8',
 'coif9',
 'coif10',
 'coif11',
 'coif12',
 'coif13',
 'coif14',
 'coif15',
 'coif16',
 'coif17',
 'db1',
 'db2',
 'db3',
 'db4',
 'db5',
 'db6',
 'db7',
 'db8',
 'db9',
 'db10',
 'db11',
 'db12',
 'db13',
 'db14',
 'db15',
 'db16',
 'db17',
 'db18',
 'db19',
 'db20',
 'db21',
 'db22',
 'db23',
 'db24',
 'db25',
 'db26',
 'db27',
 'db28',
 'db29',
 'db30',
 'db31',
 'db32',
 'db33',
 'db34',
 'db35',
 'db36',
 'db37',
 'db38',
 'dmey',
 'fbsp',
 'gaus1',
 'gaus2',
 'gaus3',
 'gaus4',
 'gaus5',
 'gaus6',
 'gaus7',
 'gaus8',
 'haar',
 'mexh',
 'morl',
 'rbio1.1',
 'rbio1.3',
 'rbio1.5',
 'rbio2.2',
 'rbio2.4',
 'rbio2.6',
 'rbio2.8',
 'rbio3.1',

In [2]:
# Turn on matplotlib's interactive mode (plt is matplotlib.pyplot).
plt.ion()

In [25]:
def convert_to_uint(x):
    """Fold an array into the 0..255 integer-valued range, in place.

    Every element is replaced by its absolute value, truncated toward zero
    and capped at 255. The operation is vectorized and works for arrays of
    any number of dimensions.

    Parameters
    ----------
    x : numpy.ndarray
        Array to convert. It is modified in place.

    Returns
    -------
    numpy.ndarray
        The same array object `x`. Its dtype is NOT changed (despite the
        function name); callers that need a real ``uint8`` array must still
        cast the result themselves.
    """
    # abs() already guarantees non-negative values, so only the upper bound
    # can actually clip; the lower bound is kept for clarity.
    x[...] = np.clip(np.trunc(np.abs(x)), 0, 255)
    return x


def adaptation_average_grey(a_j,alpha, average):
    """Blend coefficients toward a global average value.

    Computes ``alpha*average + (1-alpha)*a_j`` element-wise, so ``alpha=0``
    returns the coefficients unchanged and ``alpha=1`` replaces every
    coefficient by `average`.

    Parameters
    ----------
    a_j : array_like
        Coefficients to adapt (any shape). The input is not modified.
    alpha : float
        Blend weight given to `average`.
    average : float
        Target value the coefficients are pulled toward.

    Returns
    -------
    numpy.ndarray
        New float64 array with the same shape as `a_j`.
    """
    coeffs = np.asarray(a_j, dtype=np.float64)
    # One vectorized expression instead of a Python loop over every element.
    return alpha*average + (1 - alpha)*coeffs

def local_contrast_enhancement(d_jk,a_jk,w_j):
    """Enhance the large detail coefficients of one sub-band.

    Coefficients that are at least one tenth of the sub-band maximum are
    replaced by a root ``x`` of ``x - d - w_j*a/x = 0`` (``d`` and ``a`` being
    the detail and approximation coefficient at that position), found with
    SciPy's `newton` started at ``d``. All other coefficients, and any
    coefficient for which the solver raises ``RuntimeError``, are copied
    through unchanged.

    Parameters
    ----------
    d_jk : 2-D numpy.ndarray
        Detail coefficients.
    a_jk : 2-D array
        Approximation coefficients, indexed at the same positions as `d_jk`.
    w_j : float
        Weight of the ``a/x`` term.

    Returns
    -------
    numpy.ndarray
        New float64 array with the shape of `d_jk`.
    """
    rows, cols = d_jk.shape
    cutoff = d_jk.max()/10

    def residual(x, detail, gain):
        # Function whose root is the enhanced coefficient.
        return x - detail - gain/x

    # Start from a copy of the input; only qualifying entries are overwritten.
    enhanced = np.zeros((rows, cols))
    enhanced[:, :] = d_jk

    for r in range(rows):
        for c in range(cols):
            detail = d_jk[r][c]
            if not detail >= cutoff:
                continue
            try:
                enhanced[r][c] = newton(
                    func=residual, x0=detail, args=(detail, w_j*a_jk[r][c])
                )
            except RuntimeError:
                # Solver did not converge: keep the original coefficient.
                pass

    return enhanced
    

    
def enhance_image(img, wavelet='sym8', mode='symmetric'):
    """Load an image file and enhance its first three channels independently.

    Each of the first three channels is decomposed into a 2-D wavelet packet
    tree and passed to `provenzi_caselles` together with the mean of the
    coarsest approximation coefficients taken over all three channels.

    Parameters
    ----------
    img : str or file-like
        Image location, passed straight to ``matplotlib.image.imread``. The
        decoded array must have at least three channels on its last axis.
    wavelet : str, optional
        Wavelet name used for the decomposition (default ``'sym8'``).
    mode : str, optional
        Signal extension mode used for the decomposition
        (default ``'symmetric'``).

    Returns
    -------
    tuple of numpy.ndarray
        ``(new_r, new_g, new_b)``: the enhanced channels as returned by
        `provenzi_caselles`.
    """
    pixels = mpimg.imread(img)
    # imread yields floats in [0, 1] for PNG files but integers for other
    # formats. The pipeline ends with a truncation/clamp to 0..255, so float
    # input must be brought to that scale or the result is almost black.
    if np.issubdtype(pixels.dtype, np.floating):
        pixels = pixels * 255.0

    packets = [
        pywt.WaveletPacket2D(data=pixels[:, :, channel], wavelet=wavelet, mode=mode)
        for channel in range(3)
    ]

    # All channels share one shape, so the first tree's depth applies to all.
    scale = packets[0].maxlevel
    coarsest = 'a'*(scale-1) + 'a'

    # Mean coarsest-level approximation coefficient across the three channels.
    m, n = packets[0][coarsest].data.shape
    total = sum(wp[coarsest].data.sum() for wp in packets)
    average = total/(3*m*n)

    new_r, new_g, new_b = (provenzi_caselles(wp, average) for wp in packets)
    return new_r, new_g, new_b
    
def provenzi_caselles(wp, average, alpha=0.1, w=0.5):
    """Enhance one channel's wavelet packet tree from coarsest to finest level.

    At the deepest level the approximation coefficients are blended toward
    `average`. Then, level by level, the three detail sub-bands are enhanced
    against the current approximation and the parent node is reconstructed;
    that reconstruction becomes the approximation for the next finer level.

    Parameters
    ----------
    wp : pywt.WaveletPacket2D
        Decomposition of a single channel. Its nodes are modified in place.
    average : float
        Global average passed to `adaptation_average_grey`.
    alpha : float, optional
        Blend weight for `adaptation_average_grey` (default 0.1).
    w : float, optional
        Weight passed to `local_contrast_enhancement` for every sub-band
        (default 0.5).

    Returns
    -------
    numpy.ndarray
        The reconstructed channel after `convert_to_uint`.
    """
    scale = wp.maxlevel
    scale_index = 'a'*(scale-1)

    a_j = adaptation_average_grey(wp[scale_index+'a'].data, alpha, average)
    while True:
        wp[scale_index+'a'] = a_j
        for band in 'hvd':
            wp[scale_index+band] = local_contrast_enhancement(
                wp[scale_index+band].data, a_j, w
            )

        # Ask for update=True explicitly and use the returned array: for
        # non-root nodes reconstruct() defaults to update=False, which leaves
        # node.data holding the original, un-enhanced approximation.
        a_j = wp[scale_index].reconstruct(update=True)

        if not scale_index:
            # The empty path is the root, so the whole image is rebuilt.
            break
        scale_index = scale_index[:-1]

    return convert_to_uint(wp.data)
    
    

In [26]:
# Relative location of the input image; used by the cells that follow both to
# run the enhancement and to display the unmodified picture.
path = './misc/4.1.06.tiff'

In [27]:
# Enhance the image, then merge the three channel planes into a single
# (rows, cols, 3) uint8 array.
r, g, b = enhance_image(path)
rgb = np.stack([r, g, b], axis=-1).astype(np.uint8)

In [28]:
# Display the enhanced image on a fresh single-axes figure.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.imshow(rgb, cmap='gray')

<matplotlib.image.AxesImage at 0x7f4c047db700>

In [29]:
# Display the unmodified input image for comparison.
img = mpimg.imread(path)
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.imshow(img, cmap='gray')

<matplotlib.image.AxesImage at 0x7f4c04748700>