Based on the algorithm from the article "SIMPLE EFFECTIVE IMAGE AND VIDEO COLOR CORRECTION USING QUATERNION DISTANCE METRIC" by Soo-Chang Pei and Yu-Zhe Hsiao.

In [26]:
from PIL import Image
from tqdm import tqdm
from itertools import product

import numpy as np
import os
from random import shuffle

Vector and Quaternion classes

In [27]:
class Vector:
    """Minimal 3-component vector with overloaded arithmetic operators.

    Supported operations:
        ``a + b``   component-wise sum
        ``a * k``   scaling by an ``int`` or ``float`` (also ``k * a``)
        ``a * b``   scalar (dot) product when ``b`` is not a plain number
        ``a ** b``  vector (cross) product of two ``Vector`` instances
    """

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    def __add__(self, other):
        """Return the component-wise sum as a new Vector."""
        return Vector(self.x + other.x, self.y + other.y, self.z + other.z)

    def __mul__(self, other):
        """Scale by a number, or return the scalar product with another vector."""
        if isinstance(other, (float, int)):
            return Vector(self.x * other, self.y * other, self.z * other)
        return self.x * other.x + self.y * other.y + self.z * other.z  # scalar product

    def __rmul__(self, other):
        """Support ``number * vector`` by delegating to ``vector * number``."""
        return self * other

    def __pow__(self, other):  # vector product
        """Return the cross product ``self x other``.

        Returns ``NotImplemented`` for non-Vector operands so that Python
        raises ``TypeError`` instead of silently producing ``None``.
        """
        if not isinstance(other, Vector):
            return NotImplemented
        return Vector(
            self.y * other.z - self.z * other.y,
            self.z * other.x - self.x * other.z,
            self.x * other.y - self.y * other.x
        )

    def __str__(self):
        return f"Vec({self.x} {self.y} {self.z})"

    def __repr__(self):
        return str(self)

In [28]:
class Quaternion:
    """Quaternion stored as a real part ``r`` and an imaginary part ``u``.

    The arithmetic below relies on ``u`` supporting ``+`` (sum), ``*``
    (scalar product with another ``u`` / scaling by a number) and ``**``
    (vector product).

    Operators:
        ``p + q``  component-wise sum
        ``p * q``  quaternion product when ``q`` is a Quaternion
        ``p * k``  scaling of both parts by ``k`` otherwise (also ``k * p``)
        ``abs(p)`` norm
        ``+p``     conjugate (imaginary part negated)
    """

    def __init__(self, r, u):
        self.r, self.u = r, u

    def __add__(self, other):
        """Return the sum of two quaternions."""
        real = self.r + other.r
        imag = self.u + other.u
        return Quaternion(real, imag)

    def __mul__(self, other):
        """Return the quaternion product, or scale both parts by ``other``."""
        if not isinstance(other, Quaternion):
            return Quaternion(self.r * other, self.u * other)
        real = self.r * other.r - self.u * other.u
        imag = self.r * other.u + other.r * self.u + self.u ** other.u
        return Quaternion(real, imag)

    def __abs__(self):
        """Return the norm: square root of r^2 plus the scalar product u*u."""
        squared_norm = self.r ** 2 + self.u * self.u
        return squared_norm ** 0.5

    def __pos__(self):  # conjugate operation
        """Return the conjugate: same real part, negated imaginary part."""
        negated = self.u * (-1)
        return Quaternion(self.r, negated)

    def __rmul__(self, other):
        """Support ``other * quaternion`` by delegating to ``__mul__``."""
        return self * other

    def __str__(self):
        return f"{self.r} + {self.u}"

    def __repr__(self):
        return str(self)

Image representation as Quaternion array

In [29]:
# Transformation axis: pure quaternion along (1, 1, 1), normalised to unit length.
# q3 * q * (+q3) scales q by |q3|**2, so the axis must have norm 1 for the
# sandwich product to be a pure rotation; the previous factor 3**.5 (instead of
# 1 / 3**.5) gave norm 3 and therefore multiplied every transformed colour by 9.
# NOTE(review): THRESOLD_VALUE was chosen with the old scaling and may need retuning.
q3 = Quaternion(0, Vector(1, 1, 1)) * (1 / 3**.5)  # transformation axis

In [30]:
def mean_pixel_img(targ_img):
    """Return per-block mean quaternions for ``targ_img`` combined with ``ref_img``.

    The image area given by the module-level ``RESIZE`` is split into
    ``MASK_SIZE`` x ``MASK_SIZE`` blocks. For every pixel position in a block the
    value ``ref + q3 * targ * conj(q3)`` is formed, where ``ref`` is the pixel of
    the module-level ``ref_img`` and ``targ`` the pixel of ``targ_img`` at the same
    position, both as pure quaternions. These values are summed and divided by
    the number of pixels in the block.

    Returns:
        A list indexed ``[block_column][block_row]`` of ``Quaternion`` means.
    """
    width, height = RESIZE
    n_cols = width // MASK_SIZE
    n_rows = height // MASK_SIZE

    res = []
    for i in range(n_cols):
        x0 = i * MASK_SIZE
        column = []
        for j in range(n_rows):
            y0 = j * MASK_SIZE
            acc = Quaternion(0, Vector(0, 0, 0))
            # Visit block pixels with the x offset as the outer loop.
            for dx in range(MASK_SIZE):
                for dy in range(MASK_SIZE):
                    xy = (x0 + dx, y0 + dy)
                    q2 = Quaternion(0, Vector(*ref_img.getpixel(xy)))
                    q1 = Quaternion(0, Vector(*targ_img.getpixel(xy)))
                    acc = acc + (q2 + q3 * q1 * (+q3))
            column.append(acc * (1 / MASK_SIZE**2))
        res.append(column)

    return res

Color distance calculation

In [31]:
def func_CD(q1, q2):  # color distance
    """Return the weighted colour distance between quaternions ``q1`` and ``q2``.

    The result blends the norm of the chromatic difference (``func_Q``) and the
    absolute luminance difference (``func_I``); ``DISTANCE_WEIGHTING`` is the
    weight of the chromatic term and ``1 - DISTANCE_WEIGHTING`` that of the
    luminance term.
    """
    chroma_term = abs(func_Q(q1, q2))
    luma_term = abs(func_I(q1, q2))
    return DISTANCE_WEIGHTING * chroma_term + (1 - DISTANCE_WEIGHTING) * luma_term

In [32]:
def func_I(q1, q2):  # luminance distance
    """Return the signed mean of the component differences ``q2.u - q1.u``.

    Both arguments must expose ``.u.x``, ``.u.y`` and ``.u.z``.
    """
    first, second = q1.u, q2.u
    total = second.x - first.x + second.y - first.y + second.z - first.z
    return total / 3

In [33]:
def func_Q(q1, q2):  # chromatic difference
    """Return the chromatic part of ``q2 + q3 * q1 * conj(q3)`` as a pure quaternion.

    The three imaginary components of the combined quaternion have their mean
    subtracted, so the components of the result sum to zero (up to rounding).
    """
    combined = q2 + q3 * q1 * (+q3)
    vec = combined.u
    mean = (vec.x + vec.y + vec.z) / 3
    return Quaternion(0, Vector(vec.x - mean, vec.y - mean, vec.z - mean))

In [34]:
def calc_CD(pixel):
    """Find the entry of ``mean_pixel_ref_img`` closest to ``pixel``.

    Every element of the module-level ``mean_pixel_ref_img`` grid is compared
    with ``pixel`` using ``func_CD``; on ties the first element encountered wins.

    Returns:
        ``[distance, i, j]`` for the best match, or ``[inf, -1, -1]`` when no
        element has a distance below infinity (for example an empty grid).
    """
    best_dist, best_i, best_j = float("inf"), -1, -1
    for i, row in enumerate(mean_pixel_ref_img):
        for j, el in enumerate(row):
            dist = func_CD(pixel, el)
            if not dist < best_dist:
                continue
            best_dist, best_i, best_j = dist, i, j
    return [best_dist, best_i, best_j]

Image Conversion

In [35]:
def change_image(targ_img, matrix, filename_to_save):
    """Apply a linear colour transform to every pixel of ``targ_img`` and save it.

    Args:
        targ_img: image object providing ``size``, ``load()`` and ``save()``;
            it is modified in place.
        matrix: 3x3 array-like; each pixel ``p`` is replaced by ``matrix @ p``.
        filename_to_save: destination path; the image is written as JPEG.

    The loop bounds come from ``targ_img.size`` rather than the global
    ``RESIZE``, so the function works for an image of any size. Transformed
    channel values are truncated toward zero and clamped to 0..255.
    """
    transform = np.asarray(matrix)

    def process_func(pixel):
        mapped = np.matmul(transform, np.array(pixel))
        # Clamp explicitly: the linear fit can push channels outside the 8-bit range.
        return tuple(min(255, max(0, int(channel))) for channel in mapped)

    pixel_map = targ_img.load()
    width, height = targ_img.size
    for i in range(width):
        for j in range(height):
            pixel_map[i, j] = process_func(pixel_map[i, j])

    targ_img.save(filename_to_save, format="jpeg")

In [44]:
# Folder the batch cells read images from (dir_in) and write converted images to (dir_out).
# Both must end with a path separator: file paths are built by plain string
# concatenation (dir_in + file, dir_out + file).
# NOTE(review): machine-specific absolute paths — adjust before running elsewhere.
dir_in = "C:\\Users\\809210\\Desktop\\l____l\\CITEC\\train-org-img\\"
dir_out = "C:\\Users\\809210\\Desktop\\l____l\\CITEC\\dataset\\"

In [38]:
# Geometry shared by all cells. Every image is resized to RESIZE (width, height)
# and then processed in square MASK_SIZE x MASK_SIZE blocks.
RESIZE = (640, 480)  # size of image after resize
MASK_SIZE = 10  # side length of one averaging block, in pixels
N_PIXELS = RESIZE[0] * RESIZE[1]  # total pixel count of a resized image

In [4]:
# Despite the name, this is the path of a single image file (it is passed to Image.open),
# not a directory.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
ref_dir = "C:\\Users\\809210\\Desktop\\l____l\\CITEC\\ColorChecker_FoggyNight.jpg"  # reference image for conversion

In [39]:
# DISTANCE_WEIGHTING is the weight of the chromatic term in func_CD; the luminance
# term gets (1 - DISTANCE_WEIGHTING).
DISTANCE_WEIGHTING = .2  # from 0 to 1
# Block pairs whose colour distance is below this value are used to fit the
# conversion matrix. The name is misspelled ("THRESOLD"); it is kept because
# other cells reference it.
THRESOLD_VALUE = 75  # maximum distance for close colors

In [40]:
def converse_image(targ_img, filename_to_save):
    """Fit a 3x3 colour-conversion matrix for ``targ_img`` and save the converted image.

    Steps:
        1. Compute block means of ``targ_img`` with ``mean_pixel_img``.
        2. For every block, find the closest block of ``mean_pixel_ref_img`` with
           ``calc_CD`` and keep the pair when its distance is below ``THRESOLD_VALUE``.
        3. Sample the centre pixel of each matched block in ``ref_img`` and
           ``targ_img`` and solve, in the least-squares sense, for the matrix that
           maps target colours onto reference colours.
        4. Apply the matrix with ``change_image``, which also writes the file.

    Reads the module-level ``ref_img`` and ``mean_pixel_ref_img``. Returns ``None``;
    when fewer than 3 matching block pairs are found nothing is converted or saved.
    """
    def find_pixel(block):
        # Block index (column, row) -> pixel coordinate at the centre of that block.
        return tuple(c * MASK_SIZE + MASK_SIZE // 2 for c in block)

    mean_pixel_targ_img = mean_pixel_img(targ_img)

    close_colors = []
    with tqdm(total=N_PIXELS // MASK_SIZE**2, position=0, leave=False) as pbar2:
        for i, row in enumerate(mean_pixel_targ_img):
            for j, el in enumerate(row):
                dist, ref_i, ref_j = calc_CD(el)
                if dist < THRESOLD_VALUE:
                    close_colors.append(((i, j), (ref_i, ref_j)))
                pbar2.set_description(f"({i} {j}: {dist})")
                pbar2.update()

    if len(close_colors) < 3:
        return

    # Convert block indices to pixel coordinates BEFORE sampling. Previously
    # find_pixel was applied to the sampled colour, so the colours were read at
    # the raw block indices and then scaled as if they were coordinates.
    a = [ref_img.getpixel(find_pixel(ref_block)) for _, ref_block in close_colors]
    b = [targ_img.getpixel(find_pixel(targ_block)) for targ_block, _ in close_colors]

    # Least-squares solution M of b @ M ~= a; transposed so that M.T @ pixel
    # maps a target colour (column vector) onto the reference colour.
    converse_matrix = np.matmul(np.linalg.pinv(np.array(b)), np.array(a)).transpose()
    change_image(targ_img, converse_matrix, filename_to_save)

In [41]:
# Convert up to `limit` images from dir_in, starting at index `start`, towards
# the colours of the reference image and write them to dir_out.

# Take only the top-level (dirpath, dirnames, filenames) triple. The previous
# tuple(*os.walk(...)) raised TypeError when dir_in contained sub-directories
# and IndexError when it did not exist.
filenames = next(os.walk(dir_in), (dir_in, [], []))
limit = 20
start = 20
counter = 0

with Image.open(ref_dir) as ref_src:
    # Force 3-channel RGB: the quaternion code unpacks every pixel into 3 components.
    ref_img = ref_src.convert("RGB").resize(RESIZE)
    mean_pixel_ref_img = mean_pixel_img(ref_img)

    batch = filenames[2][start:start + limit]
    with tqdm(total=len(batch), leave=True) as pbar:
        for file in batch:
            with Image.open(dir_in + file) as targ_src:
                targ_img = targ_src.convert("RGB").resize(RESIZE)
                converse_image(targ_img, dir_out + file)
            pbar.update()
            counter += 1

In [47]:
# Convert `limit` randomly chosen images from dir_in towards the colours of the
# selected ColorChecker reference and write them to dir_out/<color_map_ref>/.
color_map_ref = "CrispWinter"
ref_dir = "C:\\Users\\809210\\Desktop\\l____l\\CITEC\\ColorChecker_" + color_map_ref + ".jpg"

# Take only the top-level (dirpath, dirnames, filenames) triple. The previous
# list(*os.walk(...)) raised TypeError when dir_in contained sub-directories
# and IndexError when it did not exist.
filenames = list(next(os.walk(dir_in), (dir_in, [], [])))
shuffle(filenames[2])  # unseeded: a different selection on every run
limit = 20
counter = 0

# Make sure the per-reference output folder exists before the first save.
out_dir = dir_out + color_map_ref + '\\'
os.makedirs(out_dir, exist_ok=True)

with Image.open(ref_dir) as ref_src:
    # Force 3-channel RGB: the quaternion code unpacks every pixel into 3 components.
    ref_img = ref_src.convert("RGB").resize(RESIZE)
    mean_pixel_ref_img = mean_pixel_img(ref_img)

    batch = filenames[2][:limit]
    with tqdm(total=len(batch), leave=True) as pbar:
        for file in batch:
            with Image.open(dir_in + file) as targ_src:
                targ_img = targ_src.convert("RGB").resize(RESIZE)
                converse_image(targ_img, out_dir + file)
            pbar.update()
            counter += 1

 20%|████████████████                                                                | 4/20 [20:55<1:23:42, 313.91s/it]


KeyboardInterrupt: 