In [None]:
import os

import cv2
import numpy as np
from matplotlib import pyplot as plt

%matplotlib inline


In [None]:
# Colour-space matrices for the RGB -> XYZ -> LMS -> lαβ pipeline used below.
# The LMS and lαβ matrices are taken from "the paper"; their values match
# Reinhard et al., "Color Transfer between Images" (2001).
#
# The paper's own RGB -> XYZ matrix, kept for reference only (unused):
#   [[0.5141, 0.3239, 0.1604],
#    [0.2651, 0.6702, 0.0641],
#    [0.0241, 0.1228, 0.8444]]

# RGB -> XYZ, from the OpenCV colour-conversion reference:
# https://docs.opencv.org/3.3.0/de/d25/imgproc_color_conversions.html
RGB2XYZ_Matrix = np.array([[0.412453, 0.357580, 0.180423],
                           [0.212671, 0.715160, 0.072169],
                           [0.019334, 0.119193, 0.950227]])

# XYZ -> LMS (cone response).
XYZ2LMS_Matrix = np.array([[0.3897, 0.6890, -0.0787],
                           [-0.2298, 1.1834, 0.0464],
                           [0.0000, 0.0000, 1.0000]])

# LMS -> lαβ is a diagonal scaling applied after a channel-mixing matrix;
# convertLMS2LAB applies the product to log(LMS).
LMS2LAB_Matrix1 = np.diag([1 / np.sqrt(3), 1 / np.sqrt(6), 1 / np.sqrt(2)])
LMS2LAB_Matrix2 = np.array([[1.0, 1.0, 1.0],
                            [1.0, 1.0, -2.0],
                            [1.0, -1.0, 0.0]])
LMS2LAB_Matrix = np.dot(LMS2LAB_Matrix1, LMS2LAB_Matrix2)

In [None]:
def readRawImage(url, rows, cols, color=False):
    """Read a headerless 8-bit raw image file into a NumPy array.

    Args:
        url: Path of the file to open (it is opened as a local file).
        rows: Image height in pixels.
        cols: Image width in pixels.
        color: If truthy, read ``rows * cols * 3`` bytes and return shape
            ``(rows, cols, 3)`` (pixel-interleaved channels); otherwise read
            ``rows * cols`` bytes and return shape ``(rows, cols)``.

    Returns:
        ``np.ndarray`` of dtype ``uint8``.

    Raises:
        ValueError: If the file holds fewer bytes than the requested shape needs.
    """
    shape = (rows, cols, 3) if color else (rows, cols)
    count = rows * cols * (3 if color else 1)

    # `with` guarantees the handle is closed even if np.fromfile raises.
    with open(url, 'rb') as fd:
        pixels = np.fromfile(fd, dtype=np.uint8, count=count)

    # np.fromfile silently returns fewer items on a short file; report it clearly
    # instead of letting reshape fail with an opaque message.
    if pixels.size != count:
        raise ValueError(
            "raw file %r holds %d bytes, expected %d for shape %r"
            % (url, pixels.size, count, shape))
    return pixels.reshape(shape)


def showScatter(x, y, title="", xlabel="", ylabel=""):
    """Draw ``y`` against ``x`` as a scatter plot in a new high-DPI figure.

    Args:
        x: Horizontal coordinates (array-like).
        y: Vertical coordinates (array-like, same number of elements as ``x``).
        title: Optional figure title; empty by default.
        xlabel: Optional x-axis label; empty by default.
        ylabel: Optional y-axis label; empty by default.
    """
    plt.figure(dpi=188)
    # s=1 keeps markers tiny so dense point clouds stay readable.
    plt.scatter(x, y, s=1)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)


def showImage(image, title=""):
    """Show a single image in its own figure with a fixed 0-255 display range.

    The gray colormap applies to single-channel input; axes are hidden.
    """
    display_range = {"vmin": 0, "vmax": 255}
    plt.figure(dpi=188)
    plt.imshow(image, cmap=plt.cm.gray, **display_range)
    plt.axis("off")
    plt.title(title)


def showImages(datas, cols):
    """Tile several images in a grid ``cols`` wide, one titled subplot each.

    Args:
        datas: Sequence of dicts with keys ``"image"`` and ``"title"``.
        cols: Number of grid columns; the row count is ``ceil(len(datas) / cols)``.
    """
    grid_rows = (len(datas) - 1) // cols + 1
    # Width is fixed at 18 in; height grows with the number of rows.
    plt.figure(num=None, figsize=(18, 18 * grid_rows / cols), dpi=94)
    for position, entry in enumerate(datas, start=1):
        plt.subplot(grid_rows, cols, position)
        plt.imshow(entry["image"], cmap=plt.cm.gray, vmin=0, vmax=255)
        plt.axis("off")
        plt.title(entry["title"])


def showImagesHistogram(datas, cols):
    """Plot a 256-bin intensity histogram for each image in a grid ``cols`` wide.

    Args:
        datas: Sequence of dicts with keys ``"image"`` (array) and ``"title"`` (str).
        cols: Number of grid columns.
    """
    # Integer ceil(len / cols). The previous float `len(datas)/cols + 1` is
    # rejected by plt.subplot in recent matplotlib and adds a spurious row when
    # len(datas) is an exact multiple of cols.
    rows = (len(datas) - 1) // cols + 1
    # 4 in per row; max() keeps the figure valid when datas is empty.
    plt.figure(num=None, figsize=(18, 4 * max(rows, 1)), dpi=94)
    for position, data in enumerate(datas, start=1):
        plt.subplot(rows, cols, position)
        plt.hist(data["image"].ravel(), bins=256, range=(0, 255))
        plt.title(data["title"])


def showImagesCDF(datas, cols, quantile=0.70):
    """Plot each image's cumulative intensity distribution in a grid ``cols`` wide.

    Each subplot title also reports a suggested threshold: the first histogram
    bin where the CDF exceeds ``quantile``. With 256 bins over [0, 255], the
    bin index equals the pixel value for 8-bit integer images.

    Args:
        datas: Sequence of dicts with keys ``"image"`` (array) and ``"title"`` (str).
        cols: Number of grid columns.
        quantile: CDF level used for the suggested threshold (default 0.70).
    """
    # Integer ceil(len / cols); a float row count breaks plt.subplot.
    rows = (len(datas) - 1) // cols + 1
    plt.figure(num=None, figsize=(18, 4 * max(rows, 1)), dpi=94)
    for position, data in enumerate(datas, start=1):
        hist, _ = np.histogram(data["image"], bins=256, range=(0, 255))
        total = hist.sum()
        # Normalize by the real pixel count. The old `/256/256` was only correct
        # for 256x256 single-channel images; otherwise the CDF never reached 1
        # and the threshold lookup could raise IndexError.
        cdf = np.cumsum(hist) / total if total else np.zeros(hist.shape)
        above = np.flatnonzero(cdf > quantile)
        threshold = str(above[0]) if above.size else "n/a"
        plt.subplot(rows, cols, position)
        plt.plot(cdf)
        plt.title(data["title"] + "\nSuggest threshold: " + threshold)


def Normalize(image):
    """Return a new float32 array equal to ``image / 255`` (0-255 maps to 0-1)."""
    scale = np.float32(255)
    as_float = image.astype(np.float32)
    return np.true_divide(as_float, scale)

def Unnormalize(image):
    """Inverse of Normalize: return a new float32 array equal to ``image * 255``.

    No clipping or rounding is applied.
    """
    scale = np.float32(255)
    as_float = image.astype(np.float32)
    return np.multiply(as_float, scale)


def convertRGB2LAB(image, gamma=1.0):
    """Convert a 0-255 RGB image to lαβ.

    Steps: scale to 0-1 (Normalize), raise to ``1 / gamma`` (gammaCorrection),
    then RGB -> XYZ -> LMS -> log + lαβ mixing. Channels must be in R, G, B
    order to match RGB2XYZ_Matrix.

    Args:
        image: Array whose last axis holds (R, G, B) values in 0-255.
        gamma: Passed to gammaCorrection; 1.0 leaves values unchanged.
    """
    result = gammaCorrection(Normalize(image), gamma=gamma)
    for stage in (convertRGB2XYZ, convertXYZ2LMS, convertLMS2LAB):
        result = stage(result)
    return result


def convertLAB2RGB(image, gamma=1.0):
    """Convert lαβ back to 0-255 RGB; the inverse of convertRGB2LAB.

    Args:
        image: Array whose last axis holds (l, α, β).
        gamma: The value that was given to convertRGB2LAB. Its ``1 / gamma``
            power is undone here so the two functions round-trip.

    Returns:
        float32 array scaled to 0-255. Values are NOT clipped; out-of-gamut
        colours can fall outside [0, 255].
    """
    image = convertLAB2LMS(image)
    image = convertLMS2XYZ(image)
    image = convertXYZ2RGB(image)
    if gamma != 1.0:
        # convertRGB2LAB raised values to 1/gamma; raising to gamma undoes it.
        # Clamp negatives first: a fractional power of a negative number is NaN.
        image = gammaCorrection(np.clip(image, 0.0, None), gamma=1.0 / gamma)
    return Unnormalize(image)


def gammaCorrection(image, gamma=1.0):
    """Raise every element of ``image`` to the power ``1 / gamma``.

    gamma=1.0 is the identity; gamma > 1 brightens values in (0, 1).
    """
    exponent = 1.0 / gamma
    return np.power(image, exponent)


def convertRGB2XYZ(image):
    """Map the last axis of ``image`` from (R, G, B) to (X, Y, Z) via RGB2XYZ_Matrix."""
    per_pixel_weights = RGB2XYZ_Matrix.T
    return image.dot(per_pixel_weights)


def convertXYZ2RGB(image):
    """Map the last axis of ``image`` from (X, Y, Z) back to (R, G, B).

    Uses the matrix inverse of RGB2XYZ_Matrix, recomputed on each call.
    """
    xyz_to_rgb = np.linalg.inv(RGB2XYZ_Matrix).T
    return image.dot(xyz_to_rgb)


def convertXYZ2LMS(image):
    """Map the last axis of ``image`` from (X, Y, Z) to (L, M, S) via XYZ2LMS_Matrix."""
    per_pixel_weights = XYZ2LMS_Matrix.T
    return image.dot(per_pixel_weights)


def convertLMS2XYZ(image):
    """Map the last axis of ``image`` from (L, M, S) back to (X, Y, Z).

    Uses the matrix inverse of XYZ2LMS_Matrix, recomputed on each call.
    """
    lms_to_xyz = np.linalg.inv(XYZ2LMS_Matrix).T
    return image.dot(lms_to_xyz)


def convertLMS2LAB(image, epsilon=np.float32(0.00001)):
    """Map LMS values to lαβ: take the natural log, then apply LMS2LAB_Matrix.

    Args:
        image: Array whose last axis holds (L, M, S). It is not modified.
        epsilon: Substitute for non-positive entries so the log stays finite
            (same value the original hard-coded).

    Returns:
        New array with (l, α, β) on the last axis.
    """
    # Build a clamped copy rather than assigning into `image`. This leaves the
    # caller's array untouched, treats negatives (log -> NaN) like zeros, and
    # avoids epsilon being truncated to 0 when `image` has an integer dtype.
    safe = np.where(image <= 0.0, epsilon, image)
    return np.log(safe).dot(LMS2LAB_Matrix.T)


def convertLAB2LMS(image):
    """Inverse of convertLMS2LAB: undo the lαβ mixing matrix, then exponentiate."""
    lab_to_log_lms = np.linalg.inv(LMS2LAB_Matrix).T
    log_lms = image.dot(lab_to_log_lms)
    return np.exp(log_lms)


def splitChannels(image):
    """Split ``image`` along its third (depth) axis into a list of 3 equal parts.

    For an (H, W, 3) image each part has shape (H, W, 1).
    """
    channel_count = 3
    return np.dsplit(image, channel_count)


def stackChannels(channels):
    """Join a sequence of arrays along the third (depth) axis; inverse of splitChannels."""
    stacked = np.dstack(channels)
    return stacked


def convertGrayWorld(channel):
    """Shift ``channel`` so its mean becomes zero (gray-world assumption on chroma)."""
    centre = np.mean(channel)
    return channel - centre


In [None]:
def correctColor(number, extension, input_dir='./input', gamma=1.0):
    """Apply gray-world colour correction to one image in lαβ space.

    Loads ``<input_dir>/<number>.<extension>``, shifts the α and β channels so
    each has zero mean, and converts back to RGB. It then shows the corrected
    image and a scatter plot of the centred (α, β) values.

    Args:
        number: File stem, e.g. ``'0002'``.
        extension: File extension without the dot, e.g. ``'png'``.
        input_dir: Directory holding the input images (default ``'./input'``).
        gamma: Passed to both convertRGB2LAB and convertLAB2RGB.

    Returns:
        The corrected image as an RGB ``uint8`` array.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    path = os.path.join(input_dir, number + '.' + extension)
    bgr = cv2.imread(path)
    if bgr is None:  # cv2.imread signals failure by returning None, not raising
        raise FileNotFoundError('Could not read image: ' + path)
    # OpenCV loads channels as B, G, R; RGB2XYZ_Matrix expects R, G, B.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    lum, alpha, beta = splitChannels(convertRGB2LAB(rgb, gamma=gamma))
    alpha = convertGrayWorld(alpha)
    beta = convertGrayWorld(beta)

    restored = convertLAB2RGB(stackChannels((lum, alpha, beta)), gamma=gamma)
    # Round and clip before casting: out-of-range floats do not convert safely to uint8.
    corrected = np.clip(np.rint(restored), 0, 255).astype(np.uint8)

    showImage(corrected, title=number + ' (gray-world corrected)')
    showScatter(alpha, beta)
    return corrected


In [None]:
correctColor(number='0002', extension='png');
