In [None]:
import numpy as np
import cv2
from matplotlib import pyplot as plt

%matplotlib inline


In [None]:
def readRawImage(url, rows, cols, color=False):
    """Read a headerless 8-bit raw image file into a numpy array.

    Parameters
    ----------
    url : path of the raw file (anything accepted by ``open``).
    rows, cols : image height and width in pixels.
    color : when truthy the file is read as 3 interleaved bytes per pixel
        and the result has shape ``(rows, cols, 3)``; otherwise the result
        has shape ``(rows, cols)``.

    Returns
    -------
    numpy.ndarray of dtype ``uint8``.

    Raises
    ------
    ValueError
        If the file holds fewer bytes than the requested shape needs.
    """
    shape = (rows, cols, 3) if color else (rows, cols)
    count = rows * cols * (3 if color else 1)

    # The context manager closes the handle even when reading fails.
    with open(url, 'rb') as fd:
        data = np.fromfile(fd, dtype=np.uint8, count=count)

    # np.fromfile silently returns fewer items when the file is too short;
    # report that clearly instead of failing inside reshape.
    if data.size != count:
        raise ValueError(
            "raw image %r holds %d bytes, expected %d for shape %r"
            % (url, data.size, count, shape))
    return data.reshape(shape)


def showScatter(x, y):
    """Draw a scatter plot of ``y`` against ``x`` with a 0.1 margin on every side."""
    margin = 0.1
    xLow, xHigh = np.min(x) - margin, np.max(x) + margin
    yLow, yHigh = np.min(y) - margin, np.max(y) + margin

    plt.figure(dpi=188)
    # Axis limits are fixed before plotting: [xmin, xmax, ymin, ymax].
    plt.axis([xLow, xHigh, yLow, yHigh])
    plt.scatter(x, y, s=1)
    plt.show()


def showImage(image, title="", cmap=None):
    """Display an image with axes hidden.

    Parameters
    ----------
    image : array passed to ``plt.imshow``.
    title : text shown above the image.
    cmap : colormap forwarded to ``plt.imshow`` (e.g. ``"gray"`` for
        single-channel images); ``None`` uses matplotlib's default.

    The display range is fixed to 0..255 via ``vmin``/``vmax``.
    """
    plt.figure(dpi=188)
    # Forward the caller's colormap; it used to be discarded because the
    # call hard-coded cmap=None.
    plt.imshow(image, cmap=cmap, vmin=0, vmax=255)
    plt.axis("off")
    plt.title(title)
    plt.show()


def showImageHistogram(image, title=""):
    """Plot a 256-bin histogram of all values in ``image`` over the range 0..255."""
    flatValues = image.ravel()

    plt.figure(dpi=188)
    plt.hist(flatValues, bins=256, range=(0, 255))
    plt.title(title)
    plt.show()


def showImageCDF(image, title=""):
    """Plot the cumulative distribution of ``image`` values and suggest a threshold.

    A 256-bin histogram over 0..255 is accumulated and normalized by the
    number of counted samples. The suggested threshold appended to the title
    is the first bin whose cumulative share exceeds 0.70.

    Parameters
    ----------
    image : array of intensity values.
    title : text shown above the plot; the suggestion is added on a new line.
    """
    hist, _ = np.histogram(image, bins=256, range=(0, 255))

    # Normalize by the actual sample count rather than assuming a 256x256
    # image, so the CDF reaches 1.0 for any image size.
    total = hist.sum()
    if total > 0:
        cdf = np.cumsum(hist) / total
        suggestion = str(np.where(cdf > 0.70)[0][0])
    else:
        # No samples inside the histogram range: nothing to suggest.
        cdf = np.zeros(hist.shape, dtype=float)
        suggestion = "n/a"

    plt.figure(dpi=188)
    plt.plot(cdf)
    plt.title(title + "\nSuggest threshold: " + suggestion)
    plt.show()


In [None]:
# Lookup table of 3x3 color-space conversion matrices, keyed '<src>2<dst>'.
# Each reverse entry is the numerical inverse of its forward entry.
#
# Notes:
# - An alternative RGB -> XYZ matrix, based on
#   https://docs.opencv.org/3.3.0/de/d25/imgproc_color_conversions.html
#   (not used here):
#       [[0.412453, 0.357580, 0.180423],
#        [0.212671, 0.715160, 0.072169],
#        [0.019334, 0.119193, 0.950227]]
# - 'lms2lab' is the product
#       diag(1/sqrt(3), 1/sqrt(6), 1/sqrt(2)) . [[1, 1, 1], [1, 1, -2], [1, -1, 0]]
cvtMatrices = {
    'rgb2xyz': np.array([[0.5141, 0.3239, 0.1604],
                         [0.2651, 0.6702, 0.0641],
                         [0.0241, 0.1228, 0.8444]]),
    'xyz2rgb': np.array([[2.5655219, -1.16682231, -0.3987641],
                         [-1.02201355, 1.97795944, 0.04398836],
                         [0.07540761, -0.25434984, 1.1892568]]),

    'xyz2lms': np.array([[0.3897, 0.6890, -0.0787],
                         [-0.2298, 1.1834, 0.0464],
                         [0.0000, 0.0000, 1.0000]]),
    'lms2xyz': np.array([[1.91024040e+00, -1.11218154e+00, 2.01941143e-01],
                         [3.70942406e-01, 6.29052461e-01, 5.13314556e-06],
                         [0.00000000e+00, 0.00000000e+00, 1.00000000e+00]]),

    'lms2lab': np.array([[0.57735027, 0.57735027, 0.57735027],
                         [0.40824829, 0.40824829, -0.81649658],
                         [0.70710678, -0.70710678, 0.]]),
    'lab2lms': np.array([[5.77350269e-01, 4.08248290e-01, 7.07106781e-01],
                         [5.77350269e-01, 4.08248290e-01, -7.07106781e-01],
                         [5.77350269e-01, -8.16496581e-01, 1.14046500e-16]]),
}


In [None]:
def Normalize(image):
    """Cast ``image`` to float32 and divide by 255."""
    asFloat = image.astype('float32')
    scale = np.float32(255)
    return asFloat / scale


def Unnormalize(image):
    """Cast ``image`` to float32 and multiply by 255 (no rounding or clipping)."""
    asFloat = image.astype('float32')
    scale = np.float32(255)
    return asFloat * scale


def gammaCorrection(image, gamma=1.0):
    """Raise every element of ``image`` to the power ``1 / gamma``."""
    exponent = 1.0 / gamma
    return np.power(image, exponent)


def convertRGB2LAB(image):
    """Chain the 'rgb2xyz', 'xyz2lms' and 'lms2lab' conversions, in that order."""
    for step in ('rgb2xyz', 'xyz2lms', 'lms2lab'):
        image = convertColorSpace(image, step)
    return image


def convertLAB2RGB(image):
    """Chain the 'lab2lms', 'lms2xyz' and 'xyz2rgb' conversions, in that order."""
    for step in ('lab2lms', 'lms2xyz', 'xyz2rgb'):
        image = convertColorSpace(image, step)
    return image


def convertColorSpace(image, flag):
    """Apply one color-space conversion step to ``image``.

    Parameters
    ----------
    image : array whose last axis holds the 3 color components.
    flag : key into ``cvtMatrices`` selecting the conversion.
        'lms2lab' takes log10 of the input before the matrix product;
        'lab2lms' raises 10 to the matrix product; every other key is a
        plain matrix product.

    Returns
    -------
    A new array; the input is not modified.

    Raises
    ------
    KeyError
        If ``flag`` is not a key of ``cvtMatrices``.
    """
    matrixT = cvtMatrices[flag].T

    if flag == 'lms2lab':
        # Floor the input at a tiny positive value before the logarithm.
        # Guarding only exact zeros let negative values through (NaN) and
        # had no effect on integer arrays (-inf).
        floor = np.float32(0.00000000001)
        return np.log10(np.maximum(image, floor)).dot(matrixT)
    if flag == 'lab2lms':
        return np.power(10, image.dot(matrixT))
    # dot() already allocates a fresh result, so no defensive copy is needed.
    return image.dot(matrixT)


def convertGrayWorld(channel):
    """Shift ``channel`` so that its median becomes zero."""
    center = np.median(channel)
    return channel - center


def splitChannels(image):
    """Split ``image`` along its third axis into 3 equal sub-arrays.

    Each returned piece keeps the third axis, so an input of shape
    (H, W, 3) yields three arrays of shape (H, W, 1).
    """
    pieces = np.dsplit(image, 3)
    return pieces


def stackChannels(channels):
    """Stack a sequence of arrays along the third axis (depth-wise)."""
    stacked = np.dstack(channels)
    return stacked


def clipChannel(channel):
    """Limit every value of ``channel`` to the closed interval [0, 255]."""
    lower, upper = 0, 255
    return np.clip(channel, lower, upper)


In [None]:
def correctColor(number, extension, gamma=0.2):
    """Color-correct ``./input/<number>.<extension>``.

    Pipeline: load as RGB, scale to 0..1, apply ``gammaCorrection`` with
    ``gamma``, convert with ``convertRGB2LAB``, subtract the median from the
    second and third channels (``convertGrayWorld``), convert back with
    ``convertLAB2RGB``, undo the gamma step and rescale to 0..255.

    Parameters
    ----------
    number : file name without extension, relative to ``./input/``.
    extension : file extension without the leading dot.
    gamma : value passed to the forward ``gammaCorrection`` call; its
        reciprocal is used for the inverse step. Defaults to 0.2.

    Returns
    -------
    float32 RGB array with values clipped to [0, 255].

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the input file.
    """
    path = './input/' + number + '.' + extension
    image = cv2.imread(path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError("Could not read image: " + path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image = Normalize(image)
    image = gammaCorrection(image, gamma=gamma)
    image = convertRGB2LAB(image)

    lChan, aChan, bChan = splitChannels(image)
    aChan, bChan = convertGrayWorld(aChan), convertGrayWorld(bChan)
    image = stackChannels((lChan, aChan, bChan))

    image = convertLAB2RGB(image)
    # The back-conversion matrices have negative coefficients, so RGB values
    # can dip below zero; a fractional power of a negative number is NaN.
    image = np.clip(image, 0.0, None)
    image = gammaCorrection(image, gamma=1/gamma)
    image = Unnormalize(image)

    # fit 0 ~ 255
    return clipChannel(image)


In [None]:
def loadImgSrc(path='./input/imgSrc.json'):
    """Load the image-source description from a JSON file.

    Parameters
    ----------
    path : location of the JSON file; defaults to ``./input/imgSrc.json``.

    Returns
    -------
    The decoded JSON document.
    """
    import json
    with open(path, 'r') as file:
        # Parse straight from the file object.
        return json.load(file)


imgSrc = loadImgSrc()


# Correct every listed image and save the result as ./output/<name>.png.
for name, info in imgSrc.items():
    image = correctColor(name, info['ext'])
    # showImage(np.uint8(image))  # uncomment to preview the result

    # Convert explicitly to 8-bit instead of relying on cv2.imwrite's
    # handling of float input: drop NaNs, round, and clamp to 0..255.
    image = np.nan_to_num(image, nan=0.0)
    image = np.clip(np.rint(image), 0, 255).astype(np.uint8)

    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    outPath = './output/' + name + '.png'
    # cv2.imwrite returns False on failure (e.g. the output directory is
    # missing); fail loudly rather than silently losing the image.
    if not cv2.imwrite(outPath, image):
        raise IOError("Could not write image: " + outPath)
