# Rate-distortion (RD) performance of some spatial DWTs

In [None]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.axes as ax
import math
import numpy as np
from scipy import signal
import cv2
import os
import pywt
import pylab

In [None]:
def quantizer(x, quantization_step):
    """Convert the samples of `x` into integer quantization indexes.

    Every sample is divided by `quantization_step` and the quotient is
    cast to a 16-bit signed integer, which drops its fractional part.

    Parameters
    ----------
    x : numpy.ndarray
        Samples to quantize.
    quantization_step : number
        Size of the quantization interval.

    Returns
    -------
    numpy.ndarray
        Array of dtype int16 with the same shape as `x`.
    """
    scaled = x / quantization_step
    return scaled.astype(np.int16)

def dequantizer(k, quantization_step):
    """Rebuild sample values from quantization indexes.

    Parameters
    ----------
    k : numpy.ndarray
        Quantization indexes.
    quantization_step : number
        Size of the quantization interval used to produce `k`.

    Returns
    -------
    numpy.ndarray
        The indexes multiplied by `quantization_step`.
    """
    return quantization_step * k

def q_deq(x, quantization_step):
    """Quantize `x` and immediately dequantize the result.

    Parameters
    ----------
    x : numpy.ndarray
        Samples to quantize.
    quantization_step : number
        Size of the quantization interval.

    Returns
    -------
    tuple
        `(k, y)`, where `k` holds the quantization indexes returned by
        quantizer() and `y` the values rebuilt from them by dequantizer().
    """
    indexes = quantizer(x, quantization_step)
    return indexes, dequantizer(indexes, quantization_step)

In [None]:
def load_frame(prefix):
    """Read `<prefix>.png` and return it as a float32 array.

    The file is decoded with OpenCV without any depth/channel conversion,
    its channels are reordered from BGR to RGB, and 32768 is subtracted
    from every sample (write_frame() adds the same offset back).

    Parameters
    ----------
    prefix : str
        Path of the image without the ".png" extension.

    Returns
    -------
    numpy.ndarray
        float32 array indexed as [rows, columns, components].

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the file. `cv2.imread` signals this by
        returning None instead of raising, so it is checked explicitly.
    """
    fn = f"{prefix}.png"
    print(fn)
    frame = cv2.imread(fn, cv2.IMREAD_UNCHANGED) # [rows, columns, components]
    if frame is None:
        raise FileNotFoundError(f"Unable to read image file: {fn}")
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # astype() already returns a new array, so no extra copy is needed.
    return frame.astype(np.float32) - 32768.0

def write_frame(frame, prefix):
    """Store `frame` as the 16-bit PNG file `<prefix>.png`.

    The samples are converted to float32, shifted by +32768 (the inverse
    of the offset applied by load_frame()) and cast to uint16 before
    being written with OpenCV.

    Parameters
    ----------
    frame : numpy.ndarray
        Samples to store.
    prefix : str
        Path of the output file without the ".png" extension.

    Raises
    ------
    OSError
        If OpenCV reports that the file could not be written.
    """
    frame = frame.astype(np.float32)
    frame += 32768.0
    frame = frame.astype(np.uint16)
    fn = f"{prefix}.png"
    # cv2.imwrite reports failure through its return value; without this
    # check a failed write would go unnoticed.
    if not cv2.imwrite(fn, frame):
        raise OSError(f"Unable to write image file: {fn}")

In [None]:
def load_indexes(prefix):
    """Read the quantization indexes stored in `<prefix>.png`.

    Thin wrapper around load_frame(); the loaded array is returned to the
    caller (previously it was discarded and None was returned).

    Parameters
    ----------
    prefix : str
        Path of the image without the ".png" extension.

    Returns
    -------
    numpy.ndarray
        The array produced by load_frame().
    """
    return load_frame(prefix)
    
def write_indexes(indexes, prefix):
    """Store quantization indexes in `<prefix>.png`.

    Thin wrapper around write_frame(). The previous one-parameter
    signature could not work, because write_frame() needs both the data
    and the path prefix; calling it always raised TypeError.

    Parameters
    ----------
    indexes : numpy.ndarray
        Quantization indexes to store.
    prefix : str
        Path of the output file without the ".png" extension.
    """
    write_frame(indexes, prefix)

In [None]:
# Path prefix of the frame analysed in this notebook; load_frame() appends
# the ".png" extension.
fn = "/home/vruiz/MRVC/sequences/stockholm/000"
RGB_frame = load_frame(fn)

In [None]:
def normalize(img):
    """Linearly rescale `img` so that its values span the range [0, 1].

    The minimum of `img` is mapped to 0 and the maximum to 1. A constant
    image (maximum equal to minimum) is mapped to all zeros instead of
    producing a division by zero.

    Parameters
    ----------
    img : numpy.ndarray
        Array to rescale.

    Returns
    -------
    numpy.ndarray
        Rescaled array with the same shape as `img`.
    """
    min_component = np.min(img)
    max_min_component = np.max(img) - min_component
    if max_min_component == 0:
        # Flat image: every (img - min) is 0, so dividing by 1 yields zeros
        # rather than the NaNs produced by 0/0.
        max_min_component = 1
    return (img - min_component) / max_min_component

def show_frame(frame, prefix):
    """Display `frame` in a new 10x10-inch figure titled `prefix`.

    The frame is passed through normalize() before it is handed to
    `plt.imshow`.

    Parameters
    ----------
    frame : numpy.ndarray
        Image to display.
    prefix : str
        Text used as the figure title.
    """
    displayable = normalize(frame)
    plt.figure(figsize=(10, 10))
    plt.title(prefix, fontsize=20)
    plt.imshow(displayable)

In [None]:
# Display the loaded frame, using its path prefix as the figure title.
show_frame(RGB_frame, fn)

In [None]:
def average_energy(x):
    """Return the mean of the squared samples of `x`.

    The sum of squares is divided by the total number of samples
    (`x.size`). Dividing by `len(x)` would only count the first axis and
    give an energy per row for multidimensional arrays.

    Parameters
    ----------
    x : numpy.ndarray
        Signal of any dimensionality.

    Returns
    -------
    numpy.float64
        Sum of the squared samples divided by the number of samples.
    """
    samples = x.astype(np.double)  # Accumulate in double precision.
    return np.sum(samples * samples) / x.size

def RMSE(x, y):
    """Return the square root of average_energy() applied to `x - y`.

    Parameters
    ----------
    x, y : numpy.ndarray
        Signals to compare.

    Returns
    -------
    float
    """
    return math.sqrt(average_energy(x - y))

In [None]:
def bytes_per_frame(frame):
    """Return the size in bytes of `frame` once stored as a PNG file.

    The frame is written with write_frame() to "/tmp/frame.png" and the
    size of that file is reported.

    Parameters
    ----------
    frame : numpy.ndarray
        Data to store.

    Returns
    -------
    int
        Length of the written file, in bytes.
    """
    write_frame(frame, "/tmp/frame")
    return os.path.getsize("/tmp/frame.png")

In [None]:
def only_Q_RD_curve(frame):
    """Compute an RD curve by quantizing `frame` directly.

    For each quantization step 1, 2, 4, ..., 128 the frame is quantized
    and dequantized with q_deq(). The rate is the PNG size of the
    quantization indexes and the distortion is the RMSE between `frame`
    and its dequantized version. One line is printed per step.

    Parameters
    ----------
    frame : numpy.ndarray
        Frame to quantize.

    Returns
    -------
    list of tuple
        One `(rate, distortion)` pair per quantization step.
    """
    rd_points = []
    for q_step in range(8):
        indexes, reconstruction = q_deq(frame, 1 << q_step)
        rate = bytes_per_frame(indexes)
        distortion = RMSE(frame, reconstruction)
        print(f"q_step={1<<q_step:>3}, rate={rate:>7} bytes, distortion={distortion:>6.1f}")
        rd_points.append((rate, distortion))
    return rd_points

# RD curve obtained by quantizing the RGB samples directly.
RGB_points = only_Q_RD_curve(RGB_frame)

In [None]:
def RGB_to_YCoCg(RGB_frame):
    """Transform an RGB frame into the YCoCg colour space.

    The components are computed as::

        Y  =  R/4 + G/2 + B/4
        Co =  R/2       - B/2
        Cg = -R/4 + G/2 - B/4

    Parameters
    ----------
    RGB_frame : numpy.ndarray
        Array indexed as [rows, columns, components] with R, G and B in
        components 0, 1 and 2.

    Returns
    -------
    numpy.ndarray
        Array with the shape and dtype of `RGB_frame` holding Y, Co and
        Cg in components 0, 1 and 2.
    """
    red, green, blue = (RGB_frame[:, :, c] for c in range(3))
    transformed = np.empty_like(RGB_frame)
    transformed[:, :, 0] = red/4 + green/2 + blue/4
    transformed[:, :, 1] = red/2 - blue/2
    transformed[:, :, 2] = -red/4 + green/2 - blue/4
    return transformed

def YCoCg_to_RGB(YCoCg_frame):
    """Transform a YCoCg frame back into the RGB colour space.

    Inverse of the transform Y = R/4 + G/2 + B/4, Co = R/2 - B/2,
    Cg = -R/4 + G/2 - B/4::

        R = Y + Co - Cg
        G = Y      + Cg
        B = Y - Co - Cg

    Parameters
    ----------
    YCoCg_frame : numpy.ndarray
        Array indexed as [rows, columns, components] with Y, Co and Cg in
        components 0, 1 and 2.

    Returns
    -------
    numpy.ndarray
        Array with the shape and dtype of `YCoCg_frame` holding R, G and
        B in components 0, 1 and 2.
    """
    Y, Co, Cg = YCoCg_frame[:,:,0], YCoCg_frame[:,:,1], YCoCg_frame[:,:,2]
    RGB_frame = np.empty_like(YCoCg_frame)
    RGB_frame[:,:,0] = Y + Co - Cg
    # Y + Cg = (R/4 + G/2 + B/4) + (-R/4 + G/2 - B/4) = G.
    # (Y - Cg would give R/2 + B/2 instead of the green component.)
    RGB_frame[:,:,1] = Y + Cg
    RGB_frame[:,:,2] = Y - Co - Cg
    return RGB_frame

In [None]:
def YCoCg_RD_curve(RGB_frame):
    """Compute an RD curve by quantizing the frame in the YCoCg domain.

    For each quantization step 1, 2, 4, ..., 128 the YCoCg version of the
    frame is quantized and dequantized with q_deq(). The rate is the PNG
    size of the quantization indexes and the distortion is the RMSE,
    measured in the RGB domain, between `RGB_frame` and the frame rebuilt
    from the dequantized YCoCg samples. One line is printed per step.

    Parameters
    ----------
    RGB_frame : numpy.ndarray
        Frame indexed as [rows, columns, components].

    Returns
    -------
    list of tuple
        One `(rate, distortion)` pair per quantization step.
    """
    # The colour transform does not depend on the quantization step, so it
    # is computed once instead of once per iteration.
    YCoCg_frame = RGB_to_YCoCg(RGB_frame)
    RD_points = []
    for q_step in range(0, 8):
        k, dequantized_YCoCg_frame = q_deq(YCoCg_frame, 1<<q_step)
        rate = bytes_per_frame(k)
        dequantized_RGB_frame = YCoCg_to_RGB(dequantized_YCoCg_frame)
        distortion = RMSE(RGB_frame, dequantized_RGB_frame)
        print(f"q_step={1<<q_step:>3}, rate={rate:>7} bytes, distortion={distortion:>6.1f}")
        RD_points.append((rate, distortion))
    return RD_points

# RD curve obtained by quantizing the YCoCg samples. The input is the loaded
# RGB_frame; no top-level variable named `frame` is defined in this notebook.
YCoCg_points = YCoCg_RD_curve(RGB_frame)

In [None]:
# Default wavelet and default number of decomposition levels used by
# color_DWT_analyze() and color_DWT_synthesize().
WAVELET = pywt.Wavelet("db5")
LEVELS = 1

def color_DWT_analyze(color_frame, wavelet=WAVELET, n_levels=LEVELS):
    """Apply a 2D multilevel DWT to every channel of a colour frame.

    Each channel is decomposed independently with `pywt.wavedec2` using
    periodic signal extension (mode 'per').

    Parameters
    ----------
    color_frame : numpy.ndarray
        Frame indexed as [rows, columns, components].
    wavelet : pywt.Wavelet
        Wavelet used for the decomposition.
    n_levels : int
        Number of decomposition levels.

    Returns
    -------
    list
        One `pywt.wavedec2` decomposition per channel, in channel order.
    """
    return [
        pywt.wavedec2(data=color_frame[:, :, c], wavelet=wavelet, mode='per', level=n_levels)
        for c in range(color_frame.shape[2])
    ]

def color_DWT_synthesize(color_decomposition, wavelet=WAVELET):
    """Rebuild a colour frame from its per-channel DWT decompositions.

    Each element of `color_decomposition` is the decomposition of one
    channel, laid out as returned by `pywt.wavedec2`:

    * element 0 is cAn, the lowest-frequency subband;
    * every following element is a tuple (cH, cV, cD) with the
      high-frequency subbands of one resolution level.

    See https://pywavelets.readthedocs.io/en/latest/ref/2d-dwt-and-idwt.html#d-multilevel-decomposition-using-wavedec2

    Parameters
    ----------
    color_decomposition : list
        One decomposition per channel.
    wavelet : pywt.Wavelet
        Wavelet used for the reconstruction.

    Returns
    -------
    numpy.ndarray
        float64 array indexed as [rows, columns, components].
    """
    planes = [
        pywt.waverec2(channel_decomposition, wavelet=wavelet, mode='per')
        for channel_decomposition in color_decomposition
    ]
    # All channels are stored in one array whose size is taken from the
    # first reconstructed channel.
    n_rows, n_columns = planes[0].shape[0], planes[0].shape[1]
    color_frame = np.empty((n_rows, n_columns, len(planes)), np.float64)
    for c, plane in enumerate(planes):
        color_frame[:, :, c] = plane
    return color_frame

In [None]:
# Sanity check: analysing and then synthesizing the frame must give it back.
RGB_decomposition = color_DWT_analyze(RGB_frame)
RGB_reconstructed_frame = color_DWT_synthesize(RGB_decomposition)
# Compare element-wise: `a.all() == b.all()` would only compare two booleans.
# An absolute tolerance below half a grey level absorbs floating-point
# round-off while still detecting any sample that changed value.
assert np.allclose(RGB_frame, RGB_reconstructed_frame, rtol=0, atol=0.5)

In [None]:
# Round trip through the colour transform and the DWT:
# RGB -> YCoCg -> DWT analysis -> DWT synthesis -> RGB.
YCoCg_frame = RGB_to_YCoCg(RGB_frame)
YCoCg_decomposition = color_DWT_analyze(YCoCg_frame)
reconstructed_YCoCg_frame = color_DWT_synthesize(YCoCg_decomposition)
reconstructed_RGB_frame = YCoCg_to_RGB(reconstructed_YCoCg_frame)
# NOTE(review): `.all()` collapses each array to a single boolean, so this
# assertion compares two truth values rather than the samples; it does not
# verify that the reconstruction matches the original frame.
assert RGB_frame.all() == reconstructed_RGB_frame.all()

In [None]:
def DWT_RD_curve(RGB_frame):
    """Compute an RD curve by quantizing the DWT subbands of the YCoCg frame.

    The frame is converted to YCoCg and every channel is decomposed with
    color_DWT_analyze(). For each quantization step 1, 2, 4, ..., 128 all
    the subbands of all the channels are quantized and dequantized with
    the same step. The rate is the sum of the PNG sizes of the quantized
    subbands and the distortion is the RMSE, measured in the RGB domain,
    between `RGB_frame` and the frame rebuilt from the dequantized
    subbands. One line is printed per step.

    Parameters
    ----------
    RGB_frame : numpy.ndarray
        Frame indexed as [rows, columns, components].

    Returns
    -------
    list of tuple
        One `(rate, distortion)` pair per quantization step.
    """
    n_channels = RGB_frame.shape[2]
    # Neither the colour transform nor the DWT depends on the quantization
    # step, so both are computed once instead of once per iteration.
    YCoCg_frame = RGB_to_YCoCg(RGB_frame)
    YCoCg_decomposition = color_DWT_analyze(YCoCg_frame)
    # NOTE(review): quantizer() stores the indexes as int16 and write_frame()
    # as uint16; confirm that the approximation (cAn) coefficients stay
    # inside those ranges, otherwise wrapped values would be measured.
    RD_points = []
    for q_step in range(0, 8):
        dequantized_YCoCg_decomposition = []
        rate = 0
        for channel in range(n_channels):
            # In a channel there is a decomposition
            decomposition = YCoCg_decomposition[channel]
            cAn = decomposition[0]
            k, dequantized_cAn = q_deq(cAn, 1<<q_step)
            dequantized_decomposition = [dequantized_cAn]
            rate += bytes_per_frame(k)
            rest_of_resolutions = decomposition[1:]
            for resolution in rest_of_resolutions:
                # In a resolution there is/are one/three subbands
                dequantized_resolution = []
                for subband in resolution:
                    k, dequantized_subband = q_deq(subband, 1<<q_step)
                    rate += bytes_per_frame(k)
                    dequantized_resolution.append(dequantized_subband)
                dequantized_decomposition.append(tuple(dequantized_resolution))
            dequantized_YCoCg_decomposition.append(dequantized_decomposition)
        reconstructed_YCoCg_frame = color_DWT_synthesize(dequantized_YCoCg_decomposition)
        reconstructed_RGB_frame = YCoCg_to_RGB(reconstructed_YCoCg_frame)
        distortion = RMSE(RGB_frame, reconstructed_RGB_frame)
        print(f"q_step={1<<q_step:>3}, rate={rate:>7} bytes, distortion={distortion:>6.1f}")
        RD_points.append((rate, distortion))
    return RD_points

# RD curve obtained by quantizing the DWT subbands of the YCoCg frame.
DWT_points = DWT_RD_curve(RGB_frame)

In [None]:
# Plot the three RD curves (rate on the x axis, RMSE on the y axis).
# The labels are raw strings: they contain LaTeX commands such as \Delta and
# \mathrm, and "\D" / "\m" are invalid escape sequences in ordinary string
# literals. The resulting label text is unchanged.
pylab.figure(dpi=150)
pylab.plot(*zip(*RGB_points), c='b', marker="x",
           label=r'$\Delta_{\mathrm{R}}=\Delta_{\mathrm{G}}=\Delta_{\mathrm{B}}$')
pylab.plot(*zip(*YCoCg_points), c='g', marker="x",
           label=r'$\Delta_{\mathrm{Y}}=\Delta_{\mathrm{Co}}=\Delta_{\mathrm{Cg}}$')
pylab.plot(*zip(*DWT_points), c='r', marker="x",
           label=r'DWT (same $\Delta$ all subbands)')
pylab.title("Performance of Different Quantization Schemes and Domains")
pylab.xlabel("Bytes/Frame")
pylab.ylabel("RMSE")
plt.legend(loc='upper right')
pylab.show()