# RD (rate-distortion) curve obtained by quantizing a frame

In [None]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.axes as ax
import math
import numpy as np
from scipy import signal
import cv2
import os

In [None]:
def quantizer(x, quantization_step):
    """Map the samples of `x` to quantization indexes.

    Each sample is divided by `quantization_step` and the quotient is
    converted to int16, which discards the fractional part.

    Args:
        x: NumPy array with the samples to quantize.
        quantization_step: Width of the quantization interval.

    Returns:
        An int16 array of quantization indexes with the shape of `x`.
    """
    scaled = x / quantization_step
    return scaled.astype(np.int16)

def dequantizer(k, quantization_step):
    """Reconstruct sample values from quantization indexes.

    Args:
        k: Quantization indexes.
        quantization_step: Step that was used to produce `k`.

    Returns:
        The indexes scaled by `quantization_step`.
    """
    return k * quantization_step

def q_deq(x, quantization_step):
    """Quantize `x` and immediately reconstruct it.

    Args:
        x: Samples to quantize.
        quantization_step: Quantization step shared by both stages.

    Returns:
        A `(k, y)` tuple: the quantization indexes and the values
        reconstructed from them.
    """
    indexes = quantizer(x, quantization_step)
    return indexes, dequantizer(indexes, quantization_step)

In [None]:
def load_frame(prefix):
    """Load "<prefix>.png" and return it as a float32 RGB frame.

    The file name is printed before reading. The decoded image is converted
    from OpenCV's BGR channel order to RGB, cast to float32, and shifted by
    -32768.0 (presumably the file holds unsigned 16-bit samples, so that the
    result is zero-centered -- confirm against the data set).

    Args:
        prefix: Path of the image without the ".png" extension.

    Returns:
        A float32 array indexed as [rows, columns, components].

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    fn = f"{prefix}.png"
    print(fn)
    frame = cv2.imread(fn, cv2.IMREAD_UNCHANGED) # [rows, columns, components]
    if frame is None:
        # cv2.imread() reports a missing/unreadable file by returning None;
        # fail here with a clear message instead of inside cvtColor().
        raise FileNotFoundError(f"Unable to read image file: {fn}")
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return frame.astype(np.float32) - 32768.0

def write_frame(frame, prefix):
    """Write `frame` to "<prefix>.png" as unsigned 16-bit samples.

    The samples are converted to float32, shifted by +32768.0 and cast to
    uint16 before being handed to OpenCV. The input array is not modified.

    NOTE(review): load_frame() converts BGR to RGB after reading, but no
    inverse conversion is applied here -- confirm whether the channel order
    of the written file matters.

    Args:
        frame: NumPy array with the samples to store.
        prefix: Path of the output image without the ".png" extension.

    Raises:
        OSError: If OpenCV reports that the file could not be written.
    """
    fn = f"{prefix}.png"
    shifted = frame.astype(np.float32) + 32768.0
    # cv2.imwrite() returns False on failure; do not let that pass silently,
    # otherwise callers may measure a stale or missing file.
    if not cv2.imwrite(fn, shifted.astype(np.uint16)):
        raise OSError(f"Unable to write image file: {fn}")

In [None]:
def load_indexes(prefix):
    """Load quantization indexes stored in "<prefix>.png".

    Thin wrapper around load_frame(); the loaded array is returned to the
    caller (it used to be discarded).

    Args:
        prefix: Path of the image without the ".png" extension.

    Returns:
        Whatever load_frame() returns for `prefix`.
    """
    return load_frame(prefix)
    
def write_indexes(indexes, prefix):
    """Store quantization indexes in "<prefix>.png".

    Thin wrapper around write_frame(). The previous one-argument form
    forwarded only `prefix` to write_frame(), which needs the data as well,
    so every call failed with a TypeError; the data is now an explicit
    parameter, mirroring write_frame(frame, prefix).

    Args:
        indexes: NumPy array with the quantization indexes to store.
        prefix: Path of the output image without the ".png" extension.
    """
    write_frame(indexes, prefix)

In [None]:
# Path prefix of the frame to analyze; load_frame() appends ".png" to it.
fn = "/home/vruiz/MRVC/sequences/stockholm/000"
# Decode the frame once; the cells below reuse `frame` and `fn`.
frame = load_frame(fn)

In [None]:
def normalize(img):
    """Linearly rescale `img` so that its values span [0, 1].

    The minimum of `img` maps to 0 and the maximum to 1. A constant image
    (maximum == minimum) maps to all zeros instead of producing NaNs from a
    0/0 division.

    Args:
        img: NumPy array with at least one element.

    Returns:
        A floating-point array with the shape of `img`.
    """
    min_component = np.min(img)
    value_range = np.max(img) - min_component
    # With a zero range every shifted sample is 0, so dividing by 1 yields
    # zeros while keeping the same result dtype as the regular path.
    divisor = value_range if value_range != 0 else 1
    return (img - min_component) / divisor

def show_frame(frame, prefix):
    """Display `frame` in a new 10x10-inch figure titled `prefix`.

    The frame is rescaled with normalize() before being shown.

    Args:
        frame: Image array to display.
        prefix: Text used as the figure title.
    """
    displayable = normalize(frame)
    plt.figure(figsize=(10, 10))
    plt.title(label=prefix, fontsize=20)
    plt.imshow(X=displayable)

In [None]:
# Display the loaded frame, using its path prefix as the figure title.
show_frame(frame, fn)

In [None]:
def average_energy(x):
    """Return the mean of the squared samples of `x`.

    The sum of squares is divided by the total number of samples
    (`x.size`), not by the length of the first axis, so the result is a true
    per-sample average for arrays of any dimensionality.

    Args:
        x: Array-like of numeric samples.

    Returns:
        The average energy as a double-precision float.
    """
    # Convert once to double so the squares neither overflow narrow integer
    # types nor lose precision.
    samples = np.asarray(x, dtype=np.double)
    return np.sum(samples * samples) / samples.size

def RMSE(x, y):
    """Return the root of the average energy of the difference `x - y`.

    Args:
        x: Reference signal.
        y: Signal compared against `x`.

    Returns:
        The square root of average_energy(x - y) as a Python float.
    """
    return math.sqrt(average_energy(x - y))

In [None]:
def byte_rate(frame):
    """Return the size in bytes of `frame` once stored as a PNG file.

    The frame is written with write_frame() to "/tmp/frame.png" (the file is
    left in place) and the size of that file is reported.

    Args:
        frame: Array accepted by write_frame().

    Returns:
        The length of the written file in bytes.
    """
    write_frame(frame, "/tmp/frame")
    return os.stat("/tmp/frame.png").st_size

In [None]:
import pywt  # PyWavelets: provides Wavelet, wavedec2 and waverec2 used by the DWT helpers

# Default wavelet and number of decomposition levels for the DWT helpers.
# The class is spelled `pywt.Wavelet`; `pywt.wavelet` does not exist.
WAVELET = pywt.Wavelet("db5")
LEVELS = 3

def DWT_analyze(frame, wavelet=WAVELET, n_levels=LEVELS):
    """Compute the 2D DWT of every component of `frame`.

    Args:
        frame: Array indexed as [rows, columns, components].
        wavelet: Wavelet passed to pywt.wavedec2().
        n_levels: Number of decomposition levels.

    Returns:
        A list with one pywt.wavedec2() decomposition per component, in
        component order.
    """
    n_components = frame.shape[2]
    # 'per' selects the periodization signal-extension mode.
    return [
        pywt.wavedec2(data=frame[:, :, c], wavelet=wavelet, mode='per', level=n_levels)
        for c in range(n_components)
    ]

def DWT_synthesize(color_decomposition, wavelet=WAVELET):
    """Rebuild a frame from its per-component 2D DWT decompositions.

    Args:
        color_decomposition: Sequence with one pywt.wavedec2()-style
            decomposition per component (as returned by DWT_analyze()).
        wavelet: Wavelet passed to pywt.waverec2().

    Returns:
        A float64 array indexed as [rows, columns, components].
    """
    # Reconstruct first and take the frame size from the result: this avoids
    # guessing it from the detail subbands (whose last entry is a tuple of
    # arrays, not an array) and holds for any signal-extension mode.
    components = [
        pywt.waverec2(decomposition, wavelet=wavelet, mode='per')
        for decomposition in color_decomposition
    ]
    return np.stack(components, axis=-1).astype(np.float64)

In [None]:
def RD_curve(x, q_steps=range(1, 32768, 32)):
    """Compute rate/distortion points of `x` for a set of quantization steps.

    For every step, `x` is quantized and reconstructed with q_deq(); the
    rate is the byte_rate() of the quantization indexes and the distortion
    is the RMSE() between `x` and its reconstruction. Each step is printed
    as it is processed.

    Args:
        x: Frame to quantize.
        q_steps: Iterable of quantization steps to evaluate. Defaults to
            1, 33, 65, ... (every 32nd step below 32768).

    Returns:
        A list of `(rate, distortion)` tuples, one per quantization step,
        in the order given by `q_steps`.
    """
    points = []
    for q_step in q_steps:
        print(q_step, end=' ')
        k, y = q_deq(x, q_step)
        points.append((byte_rate(k), RMSE(x, y)))
    return points

# (rate, distortion) pairs of the loaded frame, one per quantization step.
RD_points = RD_curve(frame)

In [None]:
# Scatter plot of the RD points: rate (bytes/frame) on X, distortion (RMSE) on Y.
plt.title("RD Tradeoff")
plt.xlabel("Bytes/frame")
plt.ylabel("RMSE")
# zip(*RD_points) splits the (rate, distortion) pairs into the X and Y sequences.
plt.scatter(*zip(*RD_points), s=2, c='b', marker="o")
plt.show()

In [None]:
# Quantize and reconstruct the frame with a quantization step of 32 ...
k, y = q_deq(frame, 32)
# ... and display the reconstruction (with an empty title).
show_frame(y, "")