# Compressing images in the YCoCg domain

Compare the performance of compressing images in the RGB and YCoCg domains.

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pylab
%matplotlib inline
import image
import YCoCg as YUV
import deadzone as Q
import distortion
import information

## Global parameters of the notebook

In [None]:
# Prefix of the image to process; it is passed to image.read() together
# with an index. Uncomment one of the alternatives to switch images.
#test_image = "../sequences/stockholm/"
test_image = "../sequences/lena_color/"
#test_image = "../sequences/lena_bw/"

# Quantization steps swept by the experiments below, listed from the
# largest step to the smallest one.
Q_steps = [128, 64, 32, 16, 8, 4, 2]

In [None]:
# Rate-distortion sweep with the image quantized directly in the RGB domain.
x = image.read(test_image, 0)

RGB_points = []
for Q_step in Q_steps:
    prefix = f"/tmp/RGB_{Q_step}_"
    x_k = Q.quantize(x, Q_step)

    # Rate: the value returned by image.write() (presumably the number of
    # bytes written -- TODO confirm) times 8, divided by the samples in x.
    written = image.write(x_k.astype(np.uint8), prefix, 0)
    BPP = written*8/x.size

    # Read the file back and check that the indexes survived the round trip.
    __ = image.read(prefix, 0)
    assert (x_k == __).all()

    # Distortion between the original and the dequantized image.
    x_dQ = Q.dequantize(x_k, Q_step)
    MSE = distortion.MSE(x, x_dQ)

    point = (BPP, MSE)
    print(point)
    RGB_points.append(point)

In [None]:
# Rate-distortion sweep with the image quantized in the YCoCg domain.
x = image.read(test_image, 0)
# The image is widened to int16 before the color transform.
xx = YUV.from_RGB(x.astype(np.int16))

YUV_points = []
for Q_step in Q_steps:
    xx_k = Q.quantize(xx, Q_step)
    xx_dQ = Q.dequantize(xx_k, Q_step)
    print(xx_k.dtype, xx_k.max(), xx_k.min())

    # The indexes are stored as uint8 after adding an offset of 128, and the
    # offset is removed again after reading the file back. If the indexes do
    # not fit in 8 bits, the same scheme can be used with uint16 and an
    # offset of 32768.
    BPP = image.write((xx_k.astype(np.int16) + 128).astype(np.uint8), f"/tmp/YUV_{Q_step}_", 0)*8/x.size
    __ = image.read(f"/tmp/YUV_{Q_step}_", 0).astype(np.int32) - 128

    # Report the first position where the round trip differs (if any) so
    # that the assertion below is easier to diagnose.
    mismatches = np.argwhere(xx_k != __)
    if mismatches.size > 0:
        i, j = mismatches[0][:2]
        print(Q_step, i, j, xx_k[i, j], __[i, j])
    assert (xx_k == __).all()

    # The distortion is measured in the RGB domain, against the original.
    x_dQ = YUV.to_RGB(xx_dQ)
    MSE = distortion.MSE(x, x_dQ)
    point = (BPP, MSE)
    print(point)
    YUV_points.append(point)

In [None]:
# Draw both RD curves (rate on the x axis, distortion on the y axis).
pylab.figure(dpi=150)
for points, name in ((RGB_points, "RGB"), (YUV_points, "YUV")):
    pylab.plot(*zip(*points), label=name)
pylab.xlabel("Bits/Pixel")
pylab.ylabel("MSE")
pylab.title("Which domain color is better?")
plt.legend(loc="best")
pylab.show()

## Let's optimize the quantization step

In the previous experiment we used the same quantization step for the three color components. However, although this is the fastest quantization strategy, it is not necessarily the optimal one from an RD (rate-distortion) perspective.

Let's compute the RD contribution (a slope in the RD curve) of each component for each quantization step, and define a quantization algorithm in which we progressively select smaller contributions, starting with the highest one. We will suppose that the transform is orthogonal and that, therefore, we can estimate the distortion of the reconstructed image (which is obviously in the RGB domain) in the YUV domain. However, we will also compare with the more general option, in which the color transform does not need to be orthogonal.

In [None]:
# Read the image (index 0 of the test_image prefix) and move it to the YUV
# domain. The samples are widened to int16 before the color transform.
x = image.read(test_image, 0)
xx = YUV.from_RGB(x.astype(np.int16))

In [None]:
# Print the maximum and the minimum value of each transformed component.
for c in range(3):
    plane = xx[..., c]
    print(plane.max(), plane.min())

In [None]:
# One RD curve and one list of RD slopes per component.
# Every curve starts at rate 0, where the distortion is taken to be the
# energy of the component. The slope lists start empty.
RD_points = [[(0, information.energy(xx[..., i]))] for i in range(3)]
RD_slopes = [[] for _ in range(3)]

In [None]:
# Fill in the remaining points of every component curve, measuring the
# distortion in the YUV domain (component against dequantized component).
# NOTE: this cell and the RGB-domain variant both append to RD_points and
# RD_slopes; run only one of them after (re)creating those lists.
for i in range(3):
    comp = xx[..., i]
    for previous, Q_step in enumerate(Q_steps):
        print(Q_step)
        comp_k = Q.quantize(comp, Q_step)
        # Indexes are written as uint8 after adding an offset of 128.
        BPP = image.write((comp_k + 128).astype(np.uint8), f"/tmp/{i}_{Q_step}_", 0)*8/x.size
        MSE = distortion.MSE(comp, Q.dequantize(comp_k, Q_step))
        RD_points[i].append((BPP, MSE))
        # The slope is the distortion decrease per unit of rate increase
        # with respect to the point stored at index `previous`.
        previous_BPP, previous_MSE = RD_points[i][previous]
        delta_BPP = BPP - previous_BPP
        delta_MSE = previous_MSE - MSE
        slope = delta_MSE/delta_BPP if delta_BPP > 0 else 0
        RD_slopes[i].append((Q_step, slope, i))

In [None]:
# Fill in the remaining points of every component curve, measuring the
# distortion in the RGB domain: only component i is quantized, the image is
# transformed back, and the result is compared with the original image.
# NOTE: this cell and the YUV-domain variant both append to RD_points and
# RD_slopes; run only one of them after (re)creating those lists.
for i in range(3):
    for step_index, Q_step in enumerate(Q_steps):
        print(Q_step)
        # Work on a fresh copy so the other two components stay untouched.
        xx_ = xx.copy()
        comp_k = Q.quantize(xx_[..., i], Q_step)
        xx_[..., i] = Q.dequantize(comp_k, Q_step)
        MSE = distortion.MSE(x, YUV.to_RGB(xx_))
        # Indexes are written as uint8 after adding an offset of 128.
        BPP = image.write((comp_k + 128).astype(np.uint8), f"/tmp/{i}_{Q_step}_", 0)*8/x.size
        RD_points[i].append((BPP, MSE))
        # Slope with respect to the point stored at index `step_index`.
        reference = RD_points[i][step_index]
        delta_BPP = BPP - reference[0]
        delta_MSE = reference[1] - MSE
        if delta_BPP > 0:
            slope = delta_MSE/delta_BPP
        else:
            slope = 0
        RD_slopes[i].append((Q_step, slope, i))

In [None]:
# Keep only the first two fields (Q_step, slope) of the first len(Q_steps)
# entries of every component, so that they can be plotted directly.
RD_slopes_without_sb_index = [
    [RD_slopes[i][k][0:2] for k in range(len(Q_steps))]
    for i in range(3)
]

# One curve per component, labelled with the component index.
pylab.figure(dpi=150)
for i, pairs in enumerate(RD_slopes_without_sb_index):
    pylab.plot(*zip(*pairs), label=f"{i}")
pylab.xlabel("Q_step")
pylab.ylabel("Slope")
pylab.title("Slopes of the RD curves of the components")
plt.legend(loc="best")
pylab.show()

In [None]:
# Sort the slopes
# First flatten the per-component lists into a single list of
# (Q_step, slope, component) tuples.
single_list = []
for i in range(3):
    # NOTE: here Q_step is an index into Q_steps, not a step size, and the
    # name stays bound to the last index after the loop finishes.
    for Q_step in range(len(Q_steps)):
        single_list.append(tuple(RD_slopes[i][Q_step]))
# Ascending sort on the slope (field 1) followed by a reversal, so the
# largest slope comes first. The reversal also reverses the relative order
# of entries with equal slopes.
sorted_slopes = sorted(single_list, key=lambda x: x[1])[::-1]

In [None]:
# Display the sorted (Q_step, slope, component) tuples.
sorted_slopes

In [None]:
def quantize(x, Q_steps):
    """Quantize each component of `x` with its own quantization step.

    Component `i` is the slice `x[..., i]`, for `i` in `range(x.shape[2])`,
    and it is quantized with `Q.quantize` using `Q_steps[i]`.

    Returns an array created with `np.empty_like(x)` that holds the
    quantization indexes of every component.
    """
    indexes = np.empty_like(x)
    for component in range(x.shape[2]):
        step = Q_steps[component]
        plane = x[..., component]
        indexes[..., component] = Q.quantize(plane, step)
    return indexes

def dequantize(x_k, Q_steps):
    """Dequantize each component of `x_k` with its own quantization step.

    Component `i` is the slice `x_k[..., i]`, for `i` in
    `range(x_k.shape[2])`, and it is dequantized with `Q.dequantize` using
    `Q_steps[i]`.

    Returns an array created with `np.empty_like(x_k)` that holds the
    dequantized components.
    """
    x_dQ = np.empty_like(x_k)
    # The number of components comes from the argument itself, not from a
    # global image that happens to be defined in the notebook.
    for i in range(x_k.shape[2]):
        x_dQ[..., i] = Q.dequantize(x_k[..., i], Q_steps[i])
    return x_dQ

In [None]:
# Find the optimal RD curve
# The slopes are visited in the order given by sorted_slopes. Each entry
# selects a quantization step for one component; the image built so far is
# then quantized, written, and reconstructed to obtain one RD point.
optimal_RD_points = []
# Components that have not been selected yet stay at zero.
zz = np.zeros_like(xx)
# Placeholder step for components that have not been selected yet.
Q_steps_combination = np.full(shape=(3,), fill_value=99999999)
for Q_step, _slope, component_number in sorted_slopes:
    Q_steps_combination[component_number] = Q_step
    print(component_number, Q_steps_combination[component_number])
    zz[..., component_number] = xx[..., component_number]
    zz_k = quantize(zz, Q_steps_combination)
    zz_dQ = dequantize(zz_k, Q_steps_combination)
    z_dQ = YUV.to_RGB(zz_dQ)
    # The distortion is measured in the RGB domain, against the original
    # image. This does not rely on the color transform being orthogonal and
    # makes the points comparable with curves measured the same way.
    MSE = distortion.MSE(x, z_dQ)
    # The output prefix is built from the step selected in this iteration.
    BPP = image.write((zz_k + 128).astype(np.uint8), f"/tmp/{component_number}_{Q_step}_", 0)*8/x.size
    optimal_RD_points.append((BPP, MSE))

In [None]:
# Compare the constant-step curve with the curve built from the sorted slopes.
pylab.figure(dpi=150)
for points, name in (
    (YUV_points, "YUV Constant quantization"),
    (optimal_RD_points, "YUV Optimal quantization"),
):
    pylab.plot(*zip(*points), label=name)
pylab.xlabel("Bits/Pixel")
pylab.ylabel("MSE")
pylab.title("RD optimization in the YUV domain")
plt.legend(loc="best")
pylab.show()

In [None]:
def normalize(img):
    """Rescale `img` linearly so that its values span [0, 1].

    The minimum of `img` is mapped to 0 and the maximum to 1. When every
    value is the same there is no range to stretch, and an all-zero float
    array of the same shape is returned instead of dividing by zero.
    """
    img = np.asarray(img)
    min_component = np.min(img)
    max_min_component = np.max(img) - min_component
    if max_min_component == 0:
        # Constant image: avoid the 0/0 division (which would produce NaN).
        return np.zeros(img.shape, dtype=float)
    return (img - min_component)/max_min_component

def show_frame(frame, prefix=''):
    """Show `frame` in a new 10x10 matplotlib figure titled `prefix`.

    The frame is passed to `plt.imshow` as given; it is not normalized.
    """
    figure_size = (10, 10)
    plt.figure(figsize=figure_size)
    plt.imshow(frame)
    plt.title(prefix, fontsize=20)

In [None]:
# Read image 0 of the "stockholm" sequence. The reader module is imported
# in this notebook as `image`.
x = image.read("../sequences/stockholm/", 0)

In [None]:
# Largest sample value of the image that was read.
x.max()

In [None]:
# Smallest sample value of the image that was read.
x.min()

In [None]:
# Display the image that was read, cast to uint8.
show_frame(x.astype(np.uint8))

In [None]:
# Forward color transform. The YCoCg module is imported in this notebook
# as `YUV`, and the image is widened to int16 first, as in the other cells.
y = YUV.from_RGB(x.astype(np.int16))

In [None]:
# Display the transformed image, cast to uint8.
show_frame(y.astype(np.uint8))

In [None]:
# Inverse color transform. The YCoCg module is imported in this notebook
# as `YUV`.
z = YUV.to_RGB(y)

In [None]:
# Largest sample value of the reconstructed image.
z.max()

In [None]:
# Smallest sample value of the reconstructed image.
z.min()

In [None]:
# Display the reconstructed image, cast to uint8.
show_frame(z.astype(np.uint8))

In [None]:
# Print whether the reconstructed image equals the original one, i.e.
# whether the forward + inverse transform round trip was lossless here.
print(np.array_equal(x, z))