# NumPy implementation of Reinhard et al. color transfer

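This notebook implements the color transfer technique described by [Reinhard et al.](https://ieeexplore.ieee.org/document/946629). Both images are converted from RGB to the lαβ color space. Each lαβ channel of the source image is then shifted and scaled so that its mean and standard deviation match those of the target image, and the result is converted back to RGB.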


In [1]:
# Copyright 2020 Filippo Aleotti
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import numpy as np
import urllib.request as urllib
from matplotlib import pyplot as plt

In [2]:
def std(x: np.ndarray) -> np.ndarray:
    """Get the per-column std of x, with shape NxC (C = 3 for color images).

    The values are assumed to be already centered on their mean, so this is
    the root mean square of each column, i.e. the population std (ddof=0).

    Params:
        x: NxC array, containing the difference wrt the mean value
    Return:
        Length-C array with the std of each column.
    """
    x = np.asarray(x)
    # One vectorized reduction over the rows instead of a loop per channel;
    # works for any number of columns, not only 3.
    return np.sqrt(np.mean(x ** 2, axis=0))


def mean(x: np.ndarray) -> np.ndarray:
    """Get the mean along each column of x, with shape NxC (C = 3 for colors).

    Params:
        x: NxC array
    Return:
        Length-C array with the arithmetic mean of each column.
    """
    # Single vectorized reduction over the rows; any number of columns works.
    return np.mean(np.asarray(x), axis=0)


def to_LAB(rgb: np.ndarray) -> np.ndarray:
    """From RGB space to the lαβ ("LAB") space used by Reinhard et al.

    Three steps: RGB -> LMS (linear map), LMS -> log10(LMS), and
    log10(LMS) -> lαβ (linear map).

    Params:
        rgb: array whose last axis holds R, G, B (e.g. Nx3 or HxWx3).
            Integer inputs are converted to float64.
    Return:
        Float array with the same shape as rgb, holding l, alpha, beta on
        the last axis.
    """
    to_lms_matrix = np.array(
        [[0.3811, 0.5783, 0.0402], [0.1967, 0.7244, 0.0782], [0.0241, 0.1288, 0.8444]]
    )
    # diag(1/sqrt(3), 1/sqrt(6), 1/sqrt(2)) @ [[1, 1, 1], [1, 1, -2], [1, -1, 0]],
    # folded into one matrix.
    to_lab_first_matrix = np.diag(np.array([0.57735027, 0.40824829, 0.70710678]))
    to_lab_second_matrix = np.array(
        [[1.0, 1.0, 1.0], [1.0, 1.0, -2.0], [1.0, -1.0, 0.0]]
    )
    to_lab_matrix = to_lab_first_matrix @ to_lab_second_matrix

    rgb = np.asarray(rgb, dtype=np.float64)
    # Applying M to every pixel p (M @ p) is rgb @ M.T in row-vector form:
    # one matrix product instead of a Python-level loop over pixels.
    lms = rgb @ to_lms_matrix.T
    # eps keeps log10 finite for black pixels (LMS == 0).
    lms_log = np.log10(lms + np.finfo(np.float32).eps)
    return lms_log @ to_lab_matrix.T


def to_RGB(lab: np.ndarray) -> np.ndarray:
    """From lαβ ("LAB") space back to RGB; inverse of to_LAB.

    Three steps: lαβ -> log10(LMS) (linear map), log10(LMS) -> LMS, and
    LMS -> RGB (linear map).

    Params:
        lab: array whose last axis holds l, alpha, beta (e.g. Nx3 or HxWx3).
    Return:
        Float array with the same shape as lab, holding R, G, B on the last
        axis. Values are not clipped to any range.
    """
    # Same RGB -> LMS matrix as in to_LAB. Its inverse is computed numerically
    # because the 4-decimal inverse [[4.4679, -3.5873, 0.1193],
    # [-1.2186, 2.3809, -0.1624], [0.0497, -0.2439, 1.2045]] is not accurate:
    # its product with this matrix has a 0.0072 off-diagonal entry, so
    # to_RGB(to_LAB(x)) would not return x.
    to_lms_matrix = np.array(
        [[0.3811, 0.5783, 0.0402], [0.1967, 0.7244, 0.0782], [0.0241, 0.1288, 0.8444]]
    )
    to_rgb_matrix = np.linalg.inv(to_lms_matrix)
    # [[1, 1, 1], [1, 1, -1], [1, -2, 0]] @ diag(sqrt(3)/3, sqrt(6)/6, sqrt(2)/2),
    # folded into one matrix.
    to_log_lms_first_matrix = np.array(
        [[1.0, 1.0, 1.0], [1.0, 1.0, -1.0], [1.0, -2.0, 0.0]]
    )
    to_log_lms_second_matrix = np.diag(np.array([0.57735027, 0.40824829, 0.70710678]))
    to_log_lms_matrix = to_log_lms_first_matrix @ to_log_lms_second_matrix

    lab = np.asarray(lab, dtype=np.float64)
    # Row-vector form of applying a 3x3 matrix to every pixel, in one call.
    log_lms = lab @ to_log_lms_matrix.T
    # to_LAB adds float32 eps before log10; remove it so the two functions
    # invert each other.
    lms = np.power(10.0, log_lms) - np.finfo(np.float32).eps
    return lms @ to_rgb_matrix.T


def color_transfer(src: np.ndarray, tgt: np.ndarray) -> np.ndarray:
    """Apply color transfer from tgt to src image.

    In lαβ space, each channel of src is centered, scaled by the ratio of
    the tgt std to the src std, and shifted to the tgt mean.

    Params:
        src: HxWx3 RGB array with src image
        tgt: RGB array with tgt image (e.g. HxWx3). Its statistics are
            computed independently of src, so its size may differ.
    Return:
        HxWx3 uint8 RGB image with the shape of src. Pixels from src have
        colors aligned with tgt.
    Raises:
        ValueError: if src is not HxWx3 or tgt does not have 3 channels.
    """
    if src.ndim != 3 or src.shape[2] != 3 or tgt.shape[-1] != 3:
        raise ValueError(
            f"expected an HxWx3 src and a ...x3 tgt, got {src.shape} and {tgt.shape}"
        )
    h, w = src.shape[:2]

    src_lab = to_LAB(np.reshape(src, (-1, 3)))
    tgt_lab = to_LAB(np.reshape(tgt, (-1, 3)))

    src_mean_lab = mean(src_lab)
    tgt_mean_lab = mean(tgt_lab)

    src_lab_star = src_lab - src_mean_lab
    tgt_lab_star = tgt_lab - tgt_mean_lab

    src_lab_std = std(src_lab_star)
    tgt_lab_std = std(tgt_lab_star)

    # A channel that is constant in src (std == 0, e.g. a single-color
    # image) has nothing to scale: its centered values are all zero. Use a
    # factor of 1 there instead of dividing by zero, which would fill the
    # channel with NaN (0 * inf).
    scale = np.divide(
        tgt_lab_std,
        src_lab_std,
        out=np.ones_like(tgt_lab_std),
        where=src_lab_std > 0,
    )
    new_src = src_lab_star * scale + tgt_mean_lab

    transformed = np.reshape(to_RGB(new_src), (h, w, 3))
    # Round to the nearest integer (a bare astype would truncate, biasing
    # every value down), then clip to the valid uint8 range.
    return np.clip(np.rint(transformed), 0.0, 255.0).astype(np.uint8)

In [None]:
def _read_image_from_url(url: str) -> np.ndarray:
    """Download url and decode it with OpenCV as a 3-channel BGR image.

    Raises:
        ValueError: if the downloaded bytes cannot be decoded as an image.
    """
    # The context manager closes the HTTP response once the payload is read.
    with urllib.urlopen(url) as resp:
        image = np.asarray(bytearray(resp.read()), dtype=np.uint8)
    decoded = cv2.imdecode(image, cv2.IMREAD_COLOR)
    # cv2.imdecode reports failure by returning None rather than raising.
    if decoded is None:
        raise ValueError(f"could not decode an image from {url}")
    return decoded


url_src = "https://encrypted-tbn0.gstatic.com/images?q=tbn%3AANd9GcQOoM4koKErzN12EKS0BCoK9UW9waVgM5vIDA&usqp=CAU"
src = _read_image_from_url(url_src)

url_tgt = "https://encrypted-tbn0.gstatic.com/images?q=tbn%3AANd9GcQDFXLZOhschjWJGvwe4hbdkdNQ2RjNWojE6A&usqp=CAU"
tgt = _read_image_from_url(url_tgt)

# The side-by-side np.hstack plots need images of equal height, so bring
# tgt to the src resolution (cv2.resize takes (width, height)).
if src.shape != tgt.shape:
    tgt = cv2.resize(tgt, (src.shape[1], src.shape[0]))

# OpenCV decodes images in BGR channel order, while plt (and
# color_transfer) expect RGB.
src_rgb = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
tgt_rgb = cv2.cvtColor(tgt, cv2.COLOR_BGR2RGB)

grid = np.hstack((src_rgb, tgt_rgb))
plt.imshow(grid)

<matplotlib.image.AxesImage at 0x7f828118a828>

We aim to change the colors of the source image (left) using the colors of the target image (right) as the reference.

In [None]:
# Recolor the source image so its color statistics follow the target image.
transformed = color_transfer(src=src_rgb, tgt=tgt_rgb)

Let's have a look at the result (from left to right: source, target, and source after color transfer).

In [None]:
# Left to right: source, target, and the recolored source.
panels = [src_rgb, tgt_rgb, transformed]
grid = np.hstack(panels)
plt.imshow(grid)