In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
from scipy.optimize import curve_fit

In [2]:
# Mask of the region inside the circle; `new` only modifies pixels where it is True.
# It must be boolean: a 0/1 (or 0/255) integer array would make `band[circle]`
# perform integer fancy indexing and silently select the wrong pixels.
CIRCLE_MASK_PATH = 'circle.npy'
circle = np.load(CIRCLE_MASK_PATH)
assert circle.dtype == bool, f'{CIRCLE_MASK_PATH} must hold a boolean mask, got dtype {circle.dtype}'

def norm_hist(A):
    '''Calculate the histogram of a single-band 8-bit image on a [0, 1] scale.

    Pixel values are divided by 255 so that 256 equal-width bins span [0, 1].
    The returned counts are raw pixel counts (they are not divided by the
    number of pixels).

    input:
        A (np array): single-band image or array of pixel values (any shape,
            it is flattened); values expected in [0, 255]
    return:
        h: histogram counts, float array of length 256
        b: bin edges, float array of length 257 spanning [0, 1]
    '''
    # np.histogram produces the same counts/edges as plt.hist but does not draw
    # on (and then close) the current matplotlib figure, which could destroy an
    # unrelated figure the user had open, and it is much cheaper.
    h, b = np.histogram(np.asarray(A).ravel() / 255, bins=256, range=(0, 1))
    # plt.hist returned the counts as floats; keep that dtype for callers.
    return h.astype(float), b

def gauss(x, A, mu, sigma):
    '''Gaussian bell curve evaluated at ``x``.

    Parameters
        x: point(s) at which the curve is evaluated (scalar or np array)
        A: peak height (amplitude)
        mu: centre of the curve (mean)
        sigma: width of the curve (standard deviation)
    Returns
        A * exp(-(x - mu)^2 / (2 * sigma^2))
    '''
    squared_offset = (x - mu) ** 2
    spread = 2. * sigma ** 2
    return A * np.exp(-squared_offset / spread)

def gauss_fit(h, b, default_k=0.5, default_g=7.5):
    '''Fit a Gaussian to a histogram and derive the ASF parameters.

    input:
        h: histogram counts (length n)
        b: bin edges (length n + 1)
        default_k, default_g: values returned when no usable fit is found
            (empty histogram, optimiser failure, or non-finite/zero-width fit)
    return: parameters for ASF
        k: cut-off (mean of the fitted Gaussian)
        g: gain (inverse of the fitted standard deviation)
    '''
    h = np.asarray(h, dtype=float)
    b = np.asarray(b, dtype=float)

    # Bin centres are the x-coordinates of the histogram samples.
    c = (b[:-1] + b[1:]) / 2

    total = h.sum()
    # Nothing to fit for an empty histogram; fewer than 3 points cannot
    # determine the 3 Gaussian parameters (curve_fit would raise TypeError).
    if total <= 0 or h.size < 3:
        return default_k, default_g

    # Seed the optimiser with the histogram's own moments. curve_fit's default
    # start (all parameters = 1) is far from typical histograms, whose peak
    # counts are large while mean/std lie within the bin range, and can fail
    # to converge from there.
    mean0 = np.dot(c, h) / total
    std0 = np.sqrt(np.dot((c - mean0) ** 2, h) / total)
    if std0 == 0:
        # All mass in a single bin: use the bin width as a nominal spread.
        std0 = b[1] - b[0]
    p0 = [h.max(), mean0, std0]

    try:
        coeff, _ = curve_fit(gauss, c, h, p0=p0)
    except (RuntimeError, ValueError):
        # RuntimeError: least-squares did not converge; ValueError: NaN/inf in data.
        return default_k, default_g

    _, k, sigma = coeff
    if not (np.isfinite(k) and np.isfinite(sigma)) or sigma == 0:
        return default_k, default_g
    # sigma enters the Gaussian only squared, so the fit may return it negative.
    g = abs(1 / sigma)
    return k, g

def ASF(A, g, k):
    '''Adaptive sigmoid function: a logistic curve centred on the cut-off.

        A: input value(s), scalar or np array (callers here pass pixels / 255)
        g: gain, i.e. steepness of the curve
        k: cut-off, the input value that is mapped to 0.5
    In this notebook g and k are chosen by a Gaussian fit on the histogram.
    '''
    exponent = g * (k - A)
    return 1 / (1 + np.exp(exponent))

def new(band, mask=None):
    '''Create a contrast-enhanced copy of a single band.

    Only the pixels selected by ``mask`` are modified; all other pixels are
    copied unchanged. Pipeline on the masked pixels:
    histogram -> Gaussian fit (cut-off k, gain g) -> adaptive sigmoid (ASF)
    -> histogram equalisation.

    input:
        band (2d np array): single-band image; values are expected in
            [0, 255] (the /255 scaling relies on this) — presumably uint8
        mask (optional): pixels to process, used as ``band[mask]``; defaults
            to the module-level ``circle`` mask
    return:
        band_new (2d np array): copy of ``band`` with the masked pixels replaced
    '''
    if mask is None:
        mask = circle

    band_new = band.copy()
    pixels = band[mask]
    if pixels.size == 0:
        # Nothing inside the mask; cv2.equalizeHist would fail on empty input.
        return band_new

    # Cut-off and gain from a Gaussian fit on the histogram of the masked pixels.
    k, g = gauss_fit(*norm_hist(pixels))

    # Apply ASF on a [0, 1] scale and bring the result back to 8-bit. Round to
    # nearest (instead of truncating) so the quantisation is not biased downward.
    sig = ASF(pixels / 255, g, k) * 255
    band_sig = np.clip(np.rint(sig), 0, 255).astype(np.uint8)

    # equalizeHist needs an 8-bit single-channel 2-D image: pass the masked
    # pixels as one column and flatten the result back.
    band_new[mask] = cv2.equalizeHist(band_sig.reshape(-1, 1)).ravel()
    return band_new

In [13]:
im_dir = 'data/test/'
# Earlier output locations (kept for reference):
# rgb_out_dir = 'Results/test/RGB/'
# lab_out_dir = 'Results/test/LAB/'
rgb_out_dir = 'Results/test/presentation/'
lab_out_dir = 'Results/test/presentation/'

# plt.imsave does not create missing directories.
os.makedirs(rgb_out_dir, exist_ok=True)
os.makedirs(lab_out_dir, exist_ok=True)


def _to_uint8_rgb(im):
    '''Return the first three channels of an image as contiguous uint8 RGB.

    plt.imread returns PNGs as floats in [0, 1] (possibly with an alpha
    channel) and other formats as integer arrays, whereas the enhancement
    pipeline and cv2's 8-bit Lab conversion expect uint8 RGB in [0, 255].
    '''
    if im.ndim != 3 or im.shape[2] < 3:
        raise ValueError(f'expected an RGB(A) image, got shape {im.shape}')
    rgb = im[:, :, :3]  # drop alpha: COLOR_RGB2Lab needs exactly 3 channels
    if np.issubdtype(rgb.dtype, np.floating):
        rgb = np.clip(np.rint(rgb * 255), 0, 255)
    return np.ascontiguousarray(rgb, dtype=np.uint8)


for fname in sorted(os.listdir(im_dir)):
    path = os.path.join(im_dir, fname)
    # Skip sub-directories (e.g. .ipynb_checkpoints) and hidden files (e.g. .DS_Store).
    if fname.startswith('.') or not os.path.isfile(path):
        continue
    im = _to_uint8_rgb(plt.imread(path))

    # Enhance each RGB channel independently.
    R, G, B = (im[:, :, i] for i in range(3))
    im_rgb_new = np.dstack([new(R), new(G), new(B)])
    plt.imsave(os.path.join(rgb_out_dir, fname + '_rgb.png'), im_rgb_new)

    # Enhance only the L channel in Lab space; a and b are passed through unchanged.
    im_lab = cv2.cvtColor(im, cv2.COLOR_RGB2Lab)
    L, a, b = (im_lab[:, :, i] for i in range(3))
    im_lab_new = cv2.cvtColor(np.dstack([new(L), a, b]), cv2.COLOR_Lab2RGB)
    plt.imsave(os.path.join(lab_out_dir, fname + '_lab.png'), im_lab_new)
    