# DWT (Discrete Wavelet Transform)

How to compress images using the DWT.

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pylab
%matplotlib inline
import image
import DWT
import pywt
import distortion
import YCoCg as YUV
import deadzone as Q
import math
import information

In [None]:
# Wavelet used by the rest of the notebook: leave exactly one choice uncommented.
#wavelet = pywt.Wavelet("Haar")
#wavelet = pywt.Wavelet("db1")
#wavelet = pywt.Wavelet("db5")
wavelet = pywt.Wavelet("bior3.1")
#wavelet = pywt.Wavelet("bior3.3")
print(wavelet)

In [None]:
# Path prefix of the test image; read_image() passes it to image.read().
test_image = "../sequences/lena_color/"

In [None]:
# Number of DWT levels used by DWT.analyze()/DWT.synthesize() below.
N_levels = 5

In [None]:
# Quantization steps swept to build the RD curves, from coarsest to finest.
Q_steps = [128, 64, 32, 16, 8, 4, 2, 1]

## First ... some handy routines

In [None]:
def read_image(prefix):
    """Read frame 0 of *prefix* and make sure it has three components.

    A 2-D (single-component) frame is promoted to a (rows, cols, 3) uint16
    array whose component 0 holds the frame. The other two components are
    zero. Frames that are not 2-D are returned untouched.
    """
    frame = image.read(prefix, 0)
    if frame.ndim != 2:
        return frame
    promoted = np.zeros((*frame.shape, 3), dtype=np.uint16)
    promoted[..., 0] = frame
    return promoted

def write_compact_decomposition(decom, prefix, image_number):
    """Pack a decomposition into one mosaic image and write it with image.write.

    Layout: LL in the top-left corner. For every resolution level, LH goes to
    the right of the coarser area, HL below it and HH diagonally. Coefficients
    are offset by 32768 and stored as uint16; values outside the int16 range
    wrap around.

    Args:
        decom: [LL, (LH, HL, HH), ...], coarsest level first. Every subband is
            indexed as (rows, cols, components).
        prefix, image_number: forwarded to image.write.

    Returns:
        Whatever image.write returns. Callers in this notebook multiply it by 8
        and divide by the number of samples, i.e. they treat it as a byte count.

    NOTE(review): the layout assumes each level's subbands are exactly half the
    size of the next finer level's. Otherwise, subbands of consecutive levels
    may overlap in the mosaic.
    """
    def to_uint16(subband):
        # Shift the signed coefficients to the unsigned storage range.
        return (subband.astype(np.int32) + 32768).astype(np.uint16)

    finest_rows, finest_cols = decom[-1][0].shape[:2]
    n_components = decom[0].shape[2]
    # Zero-initialized: any area not covered by a subband is written
    # deterministically instead of containing uninitialized memory.
    view = np.zeros((finest_rows*2, finest_cols*2, n_components), np.uint16)

    LL = decom[0]
    view[0:LL.shape[0], 0:LL.shape[1]] = to_uint16(LL)

    for LH, HL, HH in decom[1:]:
        view[0:LH.shape[0], LH.shape[1]:LH.shape[1]*2] = to_uint16(LH)
        view[HL.shape[0]:HL.shape[0]*2, 0:HL.shape[1]] = to_uint16(HL)
        view[HH.shape[0]:HH.shape[0]*2, HH.shape[1]:HH.shape[1]*2] = to_uint16(HH)

    return image.write(view, prefix, image_number)
    
def read_compact_decomposition(prefix, image_number, N_levels, wavelet=None):
    """Read a mosaic written by write_compact_decomposition back into a decomposition.

    Args:
        prefix, image_number: forwarded to image.read.
        N_levels: number of DWT levels of the stored decomposition.
        wavelet: pywt.Wavelet used only to obtain the subband shapes, through
            DWT.analyze of an all-zero image with the mosaic's shape. Defaults
            to the Haar wavelet.

    Returns:
        The decomposition [LL, (LH, HL, HH), ...], coarsest level first, with
        the +32768 storage offset removed.
    """
    if wavelet is None:
        wavelet = pywt.Wavelet("Haar")
    view = image.read(prefix, image_number)
    # Template with the right subband shapes; its contents are overwritten below.
    decom = DWT.analyze(np.zeros_like(view), wavelet, N_levels)
    # Remove the storage offset in a signed type: subtracting 32768 directly
    # from a uint16 array wraps around instead of producing negative values.
    signed_view = view.astype(np.int32) - 32768

    LL = decom[0]
    LL[...] = signed_view[0:LL.shape[0], 0:LL.shape[1]]

    # Fill the detail arrays in place: the decompositions in this notebook
    # store each level as a tuple, which does not support item assignment.
    for LH, HL, HH in decom[1:]:
        LH[...] = signed_view[0:LH.shape[0], LH.shape[1]:LH.shape[1]*2]
        HL[...] = signed_view[HL.shape[0]:HL.shape[0]*2, 0:HL.shape[1]]
        HH[...] = signed_view[HH.shape[0]:HH.shape[0]*2, HH.shape[1]:HH.shape[1]*2]

    return decom

def entropy(decomposition):
    """Return the average entropy per coefficient of *decomposition*.

    The entropy of every subband (its coefficients cast to int16) is weighted
    by the subband's number of coefficients. The sum is divided by the total
    number of coefficients.

    Args:
        decomposition: [LL, (sb, ...), ...], as returned by DWT.analyze.
    """
    subbands = [decomposition[0]] + [sb for sr in decomposition[1:] for sb in sr]
    accumulated_entropy = 0.0
    n_coefficients = 0
    for sb in subbands:
        accumulated_entropy += information.entropy(sb.flatten().astype(np.int16)) * sb.size
        n_coefficients += sb.size
    return accumulated_entropy / n_coefficients

## Testing `DWT.analyze_step()` and `DWT.synthesize_step()`

In [None]:
# Load the test image and show it.
x = read_image(test_image)
image.show_RGB_image(x, title="Original")

In [None]:
# One level of 2-D DWT: L is the LL subband, H holds the LH, HL and HH subbands.
L, H = DWT.analyze_step(x, wavelet)

In [None]:
# Display the four subbands. Each one is normalized and scaled by 255 for
# display (image.normalize presumably maps to [0, 1] -- TODO confirm).
image.show_RGB_image(255*image.normalize(L), "LL DWT domain")
subbands = ("LH", "HL", "HH")
for i, sb in enumerate(subbands):
    image.show_RGB_image(255*image.normalize(H[i]), f"{sb} DWT domain")

In [None]:
# Invert the single analysis step.
# NOTE(review): astype(np.uint8) truncates (no rounding, no clipping), so part
# of the "finite precision error" shown below comes from this cast.
z = DWT.synthesize_step(L, H, wavelet).astype(np.uint8)

In [None]:
# Signed reconstruction error. Widen to int32 first: subtracting unsigned
# arrays wraps negative differences around (e.g. -1 becomes 255).
r = x.astype(np.int32) - z.astype(np.int32)

In [None]:
# Only one analysis/synthesis step (a single DWT level) was applied in this section.
image.show_RGB_image(255*image.normalize(r), "DWT finite precision error (1 DWT level)")

In [None]:
# Largest reconstruction error of the single-level transform.
r.max()

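The DWT is not fully reversible here, but it is almost. (Part of the error comes from converting the reconstruction to `uint8`, which truncates instead of rounding.)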

In [None]:
# Show the single-level reconstruction.
image.show_RGB_image(z, "Reconstructed image")

## Testing `DWT.analyze()` and `DWT.synthesize()`

In [None]:
# Full N_levels-level analysis followed by synthesis.
y = DWT.analyze(x, wavelet, N_levels)
# NOTE(review): astype(np.uint8) truncates (no rounding, no clipping) the
# reconstruction, which contributes to the error measured below.
z = DWT.synthesize(y, wavelet, N_levels).astype(np.uint8)

In [None]:
# Signed reconstruction error. Widen to int32 first: subtracting unsigned
# arrays wraps negative differences around (e.g. -1 becomes 255).
r = x.astype(np.int32) - z.astype(np.int32)

In [None]:
# Range of the reconstruction error after N_levels levels.
print(r.max(), r.min())

In [None]:
# Display the (normalized) reconstruction error of the N_levels-level transform.
image.show_RGB_image(255*image.normalize(r), f"DWT finite precision error N_DWT_levels={N_levels}")

In [None]:
# Mean squared error between the original and the reconstructed image.
distortion.MSE(x, z)

In [None]:
# Show the N_levels-level reconstruction.
image.show_RGB_image(z, "Reconstructed image")

## Subbands information

In [None]:
# Per-subband statistics of the DWT of the zero-centered image.
x = read_image(test_image).astype(np.int16) - 128
y = DWT.analyze(x, wavelet, N_levels)
print("sb maximum minimum average std-dev entropy        energy  avg-energy")
# Subband 0 is LL; the detail subbands follow in decomposition order.
all_subbands = [y[0]] + [sb for sr in y[1:] for sb in sr]
accumulated_entropy = 0
for sb_index, sb in enumerate(all_subbands):
    entro = information.entropy(sb.flatten().astype(np.int16))
    accumulated_entropy += entro * sb.size
    sb_energy = information.energy(sb)  # computed once, printed twice
    print(f"{sb_index:2d} {sb.max():7.1f} {sb.min():7.1f} {np.average(sb):7.1f} "
          f"{math.sqrt(np.var(sb)):7.1f} {entro:7.1f} {sb_energy:13.1f} {sb_energy/sb.size:10.1f}")
# Total coefficient entropy divided by the number of image samples.
avg_entropy = accumulated_entropy / x.size
print("Average entropy in the wavelet domain:", avg_entropy)
# Casting the zero-centered samples to uint8 is a one-to-one relabeling, so it
# does not change the entropy.
print("Entropy in the image domain:", information.entropy(x.flatten().astype(np.uint8)))

As can be seen, most of the energy (and information) is concentrated in the low-frequency subbands. It is also worth noting that the high-frequency subbands are potentially more compressible. Finally, the wavelet domain is potentially more compressible than the image domain.

## RD performance using "constant" quantization among subbands

We compute the RD curve obtained with uniform quantization in the wavelet domain, for different quantization steps. To measure the distortion we have two alternatives:
1. Assuming that the transform is (bi)orthogonal and, therefore, that the distortions of different subbands are uncorrelated, we can measure the quantization error in the wavelet domain, inside each quantized subband, taking into account the inverse-transform gain of that subband.
2. We can measure the distortion in the image domain, after inversely transforming the quantized decomposition. Obviously, this alternative is slower. However, it is the only choice if the transform is not (bi)orthogonal.

In [None]:
xx = read_image(test_image)
# Color-transform the image after shifting it by -128.
x = YUV.from_RGB(xx.astype(np.int16) - 128)
# The decomposition does not depend on Q_step: compute it only once.
y = DWT.analyze(x, wavelet, N_levels)

DWT_points = []
for Q_step in Q_steps:
    # Quantize (indices) and dequantize (reconstructions) every subband with the same step.
    LL_k = Q.quantize(y[0], Q_step)
    y_k = [LL_k]
    y_dQ = [Q.dequantize(LL_k, Q_step)]
    for sr in y[1:]:
        sr_k = tuple(Q.quantize(sb, Q_step) for sb in sr)
        y_k.append(sr_k)
        y_dQ.append(tuple(Q.dequantize(sb_k, Q_step) for sb_k in sr_k))
    # Rate: size of the stored indices (treated as bytes) in bits per element of x.
    BPP = (write_compact_decomposition(y_k, f"/tmp/constant_{Q_step}_", 0)*8)/x.size
    # Distortion is measured in the image domain, so the transform need not be orthogonal.
    z_dQ = DWT.synthesize(y_dQ, wavelet, N_levels)
    zz_dQ = np.clip(YUV.to_RGB(z_dQ), a_min=-128, a_max=127) + 128
    MSE = distortion.MSE(xx, zz_dQ)
    print(f"{Q_step} {BPP} {MSE}")
    DWT_points.append((BPP, MSE))
    image.show_RGB_image(zz_dQ, f"Reconstructed image (Q_step={Q_step})")

In [None]:
DCT_points = []
with open("DCT.txt", 'r') as f:
    for line in f:
        rate, _distortion = line.split('\t')
        DCT_points.append((float(rate), float(_distortion)))

In [None]:
# Compare the DCT RD curve with the constant-Q DWT one (pylab re-exports these pyplot functions).
plt.figure(dpi=150)
for _curve_points, _curve_label in ((DCT_points, "DCT optimal Q"),
                                    (DWT_points, "DWT constant Q")):
    plt.plot(*zip(*_curve_points), label=_curve_label)
plt.title(test_image)
plt.xlabel("Bits/Pixel")
plt.ylabel("MSE")
plt.legend(loc='upper right')
plt.show()

## Uniform quantization vs optimal quantization

As we did with the DCT, let's compare both types of quantization in the RD space. Steps:

1. Compute the RD slope of each subband for a set of quantization steps. For this we will suppose that the subbands are independent (the DWT is orthogonal), measuring the distortion in the wavelet domain.
2. Sort the slopes. This determines the optimal progression of quantization steps (i.e., which subband should contribute more data to the code-stream next).
3. Compute the distortion in the image domain for each bit-rate. Notice that this information should match the one provided by step 1 (measuring the distortion in the wavelet domain). However, we prefer to compute the distortion in the image domain because then the transform does not need to be orthogonal.

 Read the image, move to the YUV domain, and compute the DWT.

In [None]:
# Raw image, its color transform (after a -128 shift), and the DWT of the latter.
xx = read_image(test_image)
x = YUV.from_RGB(xx.astype(np.int16) - 128)
y = DWT.analyze(x, wavelet, N_levels)

For each subband, we populate:
1. A list of RD points, and
2. A list of RD slopes between these points, also indicating the corresponding quantization step and subband.

Remember that we have an RD point for each quantization step of each subband. The first dimension of these lists is indexed by the subband, and the second dimension is indexed by the quantization step.

In [None]:
# RD_points[s] / RD_slopes[s] belong to subband s: s = 0 is LL, followed by the
# detail subbands in decomposition order.
# For BPP=0, the MSE is the energy of the subband. No slope can be computed for the first point.
RD_points = [[(0, information.energy(y[0]) / y[0].size)]] # Work with MSE's that are average distortions
RD_slopes = [[]]
for sr in y[1:]:
    for sb in sr:
        sb_avg_energy = information.energy(sb) / sb.size
        # The first point of each RD curve has a maximum distortion equal
        # to the energy of the subband and a rate = 0
        RD_points.append([(0, sb_avg_energy)])
        RD_slopes.append([])

# Dump the initial (zero-rate) points and the still-empty slope lists.
for i,j in enumerate(RD_points):
    print(i,j)
    
for i,j in enumerate(RD_slopes):
    print(i,j)

In [None]:
# Now populate the rest of the RD points (and slopes) of the LL subband (subband 0).
sb_number = 0
sb = y[0]
for Q_step in Q_steps:
    sb_k = Q.quantize(sb, Q_step)
    sb_dQ = Q.dequantize(sb_k, Q_step)
    sb_MSE = distortion.MSE(sb, sb_dQ)
    # Rate: PNG-coded size of the quantization indices, offset to uint16.
    sb_BPP = information.PNG_BPP((sb_k.astype(np.int32) + 32768).astype(np.uint16), "/tmp/BPP_")[0]
    # Slope with respect to the previous point of this subband's RD curve.
    prev_BPP, prev_MSE = RD_points[sb_number][-1]
    RD_points[sb_number].append((sb_BPP, sb_MSE))
    delta_BPP = sb_BPP - prev_BPP
    delta_MSE = prev_MSE - sb_MSE
    # A step that does not increase the rate gets slope 0.
    slope = delta_MSE/delta_BPP if delta_BPP > 0 else 0
    RD_slopes[sb_number].append((Q_step, slope, sb_number))

for i, points in enumerate(RD_points):
    print(i, "---", points)

for i, slopes in enumerate(RD_slopes):
    print(i, "---", slopes)

In [None]:
# RD points and slopes of the detail subbands (1, 2, ...), in decomposition order.
sb_number = 1
for sr in y[1:]:
    for sb in sr:
        for Q_step in Q_steps:
            sb_k = Q.quantize(sb, Q_step)
            sb_dQ = Q.dequantize(sb_k, Q_step)
            sb_MSE = distortion.MSE(sb, sb_dQ)
            # Rate: PNG-coded size of the quantization indices, offset to uint16.
            sb_BPP = information.PNG_BPP((sb_k.astype(np.int32) + 32768).astype(np.uint16), "/tmp/BPP_")[0]
            # Slope with respect to the previous point of this subband's RD curve.
            prev_BPP, prev_MSE = RD_points[sb_number][-1]
            RD_points[sb_number].append((sb_BPP, sb_MSE))
            delta_BPP = sb_BPP - prev_BPP
            delta_MSE = prev_MSE - sb_MSE
            # A step that does not increase the rate gets slope 0, as for the LL subband.
            slope = delta_MSE/delta_BPP if delta_BPP > 0 else 0
            RD_slopes[sb_number].append((Q_step, slope, sb_number))
        sb_number += 1

for i, points in enumerate(RD_points):
    print(i, "---", points)

for i, slopes in enumerate(RD_slopes):
    print(i, "---", slopes)

In [None]:
# Inspect the (Q_step, slope, sb_number) entries of every subband.
RD_slopes

In [None]:
# One RD curve per subband. Each curve uses one of matplotlib's integer
# markers, so the plot is drawn only for fewer than 12 subbands.
n_subbands = 1 + sum(len(sr) for sr in y[1:])
if n_subbands < 12:
    pylab.figure(dpi=150)
    for sb_number in range(n_subbands):
        pylab.plot(*zip(*RD_points[sb_number]), label=f"{sb_number}", marker=sb_number)
    pylab.title("RD curves of the subbands")
    pylab.xlabel("Bits/Pixel")
    pylab.ylabel("MSE")
    plt.legend(loc="best")
    pylab.show()
else:
    print(f"Plot skipped: {n_subbands} subbands (only fewer than 12 are plotted)")

In [None]:
# Same RD curves with a logarithmic MSE axis. Each curve uses one of
# matplotlib's integer markers, so the plot is drawn only for fewer than 12 subbands.
n_subbands = 1 + sum(len(sr) for sr in y[1:])
if n_subbands < 12:
    pylab.figure(dpi=150)
    for sb_number in range(n_subbands):
        pylab.plot(*zip(*RD_points[sb_number]), label=f"{sb_number}", marker=sb_number)
    pylab.title("RD curves of the subbands")
    pylab.xlabel("Bits/Pixel")
    pylab.ylabel("MSE")
    pylab.yscale("log")
    plt.legend(loc="best")
    pylab.show()
else:
    print(f"Plot skipped: {n_subbands} subbands (only fewer than 12 are plotted)")

In [None]:
# Keep only (Q_step, slope) from each (Q_step, slope, sb_number) entry, so the
# slopes can be plotted against the quantization step.
n_subbands = 1 + sum(len(sr) for sr in y[1:])
RD_slopes_without_sb_index = [
    [entry[0:2] for entry in RD_slopes[sb][:len(Q_steps)]]
    for sb in range(n_subbands)
]
print(RD_slopes_without_sb_index[0])
# Each curve uses one of matplotlib's integer markers: plot only fewer than 12 subbands.
if n_subbands < 12:
    pylab.figure(dpi=150)
    for sb_number in range(n_subbands):
        pylab.plot(*zip(*RD_slopes_without_sb_index[sb_number]), label=f"{sb_number}", marker=sb_number)
    pylab.title("Slopes of the RD curves of the subbands")
    pylab.xlabel("Q_step")
    pylab.ylabel("Slope")
    plt.legend(loc="best")
    pylab.show()
else:
    print(f"Plot skipped: {n_subbands} subbands (only fewer than 12 are plotted)")

In [None]:
# Same slopes with a logarithmic axis. Each curve uses one of matplotlib's
# integer markers, so the plot is drawn only for fewer than 12 subbands.
n_subbands = len(RD_slopes_without_sb_index)
if n_subbands < 12:
    pylab.figure(dpi=150)
    for sb_number in range(n_subbands):
        pylab.plot(*zip(*RD_slopes_without_sb_index[sb_number]), label=f"{sb_number}", marker=sb_number)
    pylab.title("Slopes of the RD curves of the subbands")
    pylab.xlabel("Q_step")
    pylab.ylabel("Slope")
    pylab.yscale("log")
    plt.legend(loc="best")
    pylab.show()
else:
    print(f"Plot skipped: {n_subbands} subbands (only fewer than 12 are plotted)")

It can be seen that the slopes of the curves are quite similar, but the LL subband is somewhat steeper.

Let's sort the slopes.

In [None]:
# Flatten the (Q_step, slope, sb_number) entries of all subbands into one list.
n_subbands = 1 + sum(len(sr) for sr in y[1:])
single_list = [tuple(entry)
               for slopes in RD_slopes[:n_subbands]
               for entry in slopes[:len(Q_steps)]]
# Steepest slope first. Sorting ascending and then reversing (instead of
# reverse=True) also reverses the relative order of entries with equal slopes.
sorted_slopes = sorted(single_list, key=lambda entry: entry[1])[::-1]

In [None]:
# (Q_step, slope, sb_number) entries, from the steepest to the flattest slope.
sorted_slopes

In [None]:
def quantize(decomposition, Q_steps):
    """Quantize every subband of *decomposition* with its own step.

    Q_steps[0] is the step of the LL subband. Q_steps[1:] are the steps of
    the detail subbands, in decomposition order (level by level, and in the
    order each level's tuple stores its subbands).

    Returns a list [LL_k, (sb_k, ...), ...] that mirrors the input structure.
    """
    quantized = [Q.quantize(decomposition[0], Q_steps[0])]
    offset = 1
    for resolution in decomposition[1:]:
        quantized.append(tuple(Q.quantize(subband, Q_steps[offset + position])
                               for position, subband in enumerate(resolution)))
        offset += len(resolution)
    return quantized

def dequantize(decomposition_k, Q_steps):
    """Dequantize every subband of *decomposition_k* with its own step.

    This is the counterpart of quantize(). Q_steps[0] goes with the LL
    indices and Q_steps[1:] with the detail subbands, in decomposition order.

    Returns a list [LL_dQ, (sb_dQ, ...), ...] with the same structure as the input.
    """
    decomposition_dQ = [Q.dequantize(decomposition_k[0], Q_steps[0])]
    next_step = 1
    for level_k in decomposition_k[1:]:
        level_dQ = []
        for subband_k in level_k:
            level_dQ.append(Q.dequantize(subband_k, Q_steps[next_step]))
            next_step += 1
        decomposition_dQ.append(tuple(level_dQ))
    return decomposition_dQ

def resolution_level(sb_number):
    '''Resolution level in decomposition.

    Subband 0 (LL) and non-positive numbers map to level 0. Detail subbands
    come in groups of three per level: 1-3 -> 1, 4-6 -> 2, and so on.
    '''
    return (sb_number + 2) // 3 if sb_number > 0 else 0
    
def subband_index(sb_number):
    '''Subband index in resolution level.

    Returns 0, 1 or 2. Subband 0 (LL) maps to 0. Detail subbands are numbered
    from 1 in decomposition order, three per resolution level:
    1, 4, 7, ... -> 0;  2, 5, 8, ... -> 1;  3, 6, 9, ... -> 2.
    '''
    if sb_number > 0:
        return (sb_number - 1) % 3
    return 0

In [None]:
optimal_RD_points = []
# Start with every subband "switched off": a huge step quantizes all of its coefficients to 0.
n_subbands = 1 + sum(len(sr) for sr in y[1:])
Q_steps_by_subband = [9**9] * n_subbands
for slope_index, (sb_Q_step, _slope, sb_number) in enumerate(sorted_slopes):
    # Refine subband sb_number to the step of the next steepest slope.
    Q_steps_by_subband[sb_number] = sb_Q_step
    print("Q_steps_by_subband", Q_steps_by_subband)
    y_k = quantize(y, Q_steps_by_subband)
    BPP = (write_compact_decomposition(y_k, f"/tmp/optimal_{slope_index}_", 0)*8)/xx.size
    y_dQ = dequantize(y_k, Q_steps_by_subband)
    z_dQ = DWT.synthesize(y_dQ, wavelet, N_levels)
    zz_dQ = np.clip(YUV.to_RGB(z_dQ), a_min=-128, a_max=127) + 128
    MSE = distortion.MSE(xx, zz_dQ)
    # Report the step applied in this iteration.
    print(f"{sb_Q_step} {BPP} {MSE}")
    optimal_RD_points.append((BPP, MSE))

In [None]:
# (BPP, MSE) points of the optimal quantization progression.
optimal_RD_points

In [None]:
# Compare the RD curves of uniform and optimal (slope-sorted) quantization
# (pylab re-exports these pyplot functions).
plt.figure(dpi=150)
for _curve_points, _curve_label in ((DWT_points, "Uniform quantization"),
                                    (optimal_RD_points, "Optimal quantization")):
    plt.plot(*zip(*_curve_points), label=_curve_label)
plt.title(test_image)
plt.xlabel("Bits/Pixel")
plt.ylabel("MSE")
plt.legend(loc="best")
plt.show()

# .......

THIS STEP IS NOT NECESSARY: Compute the average energy of the image and of the decomposition.

In [None]:
def energy(decomposition):
    """Return the total energy of all subbands of *decomposition*.

    Args:
        decomposition: [LL, (sb, ...), ...], as returned by DWT.analyze.

    Returns:
        The sum of information.energy over the LL subband and every detail subband.
    """
    accumulated_energy = information.energy(decomposition[0])
    for sr in decomposition[1:]:
        for sb in sr:
            accumulated_energy += information.energy(sb)
    return accumulated_energy

In [None]:
# Compare the average energy of the color-transformed, -128-shifted image with
# that of its decomposition.
xx = read_image(test_image).astype(np.int16) - 128
#xx = np.full(shape=(512, 512, 3), fill_value=100) - 128
x = YUV.from_RGB(xx)
y = DWT.analyze(x, wavelet, N_levels)
image_energy = information.average_energy(x)
print(image_energy)
# Total decomposition energy divided by the number of image samples.
print(energy(y)/x.size)

## Transform gains

This information measures whether the transform amplifies or attenuates the signal. If the forward transform amplifies the signal, the energy of the decomposition will be larger than the energy of the original signal, and vice versa. The same idea can be applied to the inverse transform.

In [None]:
# Energy of a constant test image, of its decomposition, and of the reconstruction.
x = np.full(shape=(512, 512, 3), fill_value=1)
x_energy = information.energy(x)
y = DWT.analyze(x, wavelet, N_levels)
decom_energy = information.energy(y[0])
z = DWT.synthesize(y, wavelet, N_levels)
# Add the energy of every detail subband to that of LL.
for sr in y[1:]:
    for sb in sr:
        decom_energy += information.energy(sb)
z_energy = information.energy(z)
print(wavelet)
print("Energy of the original image:", x_energy)
print("Energy of the decomposition:", decom_energy)
print("Energy of the reconstucted image:", z_energy)
# Same quantities divided by the number of image samples.
print("Average energy of the original image", x_energy / x.size)
print("Average energy of the decomposition:", decom_energy / x.size)
print("Average energy of the reconstructed image:", z_energy / x.size)

As can be seen, for this test image the transform preserves the energy. If the transform is orthogonal, this means that the distortion generated by quantization is the same in the image and in the wavelet domains.

## Subband gains

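The orthogonal wavelet transforms implemented by PyWavelets (such as Haar or the Daubechies family) are unitary (they preserve the energy), whereas biorthogonal ones in general are not. Moreover, even for a unitary transform, this does not mean that all the subbands have the same gain. We can determine the subband gains of the inverse transform by feeding each subband, in turn, with a constant signal (the coefficients of every probed subband add up to the same value) and computing the energy of the inverse transform of the decomposition. Finally, considering that the inverse transform has a gain of one, the gains are scaled to sum to 1.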

We are specially interested in the subband gains, considering the inverse transform, because in the compression process the subbands are quantized, and the quantization error is scaled by the gain of the subbands.

In [None]:
# Inverse-transform gain of every subband: activate one subband at a time,
# synthesize, and measure the energy of the result.
gains = []
x = np.zeros(shape=(512, 512, 3))
y = DWT.analyze(x, wavelet, N_levels)
# Each probed subband is filled with coeff_value/size, so its coefficients sum
# to coeff_value.
# NOTE(review): the probe's energy is coeff_value**2/size, which is not the
# same for all subbands -- confirm this is the intended normalization.
coeff_value = y[0].size
y[0][...] = coeff_value/y[0].size
z = DWT.synthesize(y, wavelet, N_levels)
gains.append(distortion.energy(z))
prev_sb = y[0]
for sr in y[1:]:
    for sb in sr:
        # Switch off the previously probed subband and probe the current one.
        prev_sb[...] = 0.0
        sb[...] = coeff_value/sb.size
        z = DWT.synthesize(y, wavelet, N_levels)
        gains.append(distortion.energy(z))
        prev_sb = sb
        
# Energy of the synthesis with all subbands probed at once, used to normalize.
# Only the shape of x matters: every subband is overwritten below.
x = np.empty(shape=(512, 512, 3))
y = DWT.analyze(x, wavelet, N_levels)
coeff_value = y[0].size
y[0][...] = coeff_value/y[0].size
for sr in y[1:]:
    for sb in sr:
        sb[...] = coeff_value/sb.size
z = DWT.synthesize(y, wavelet, N_levels)
z_energy = distortion.energy(z)

# NOTE(review): the normalized gains sum to 1 only if the subband contributions
# add up in energy (no cross terms), which presumably requires an orthogonal
# transform -- the assertion below checks it.
gains = [gain/z_energy for gain in gains]
print("Unitary (normalized) inverse transform subband gains:", gains)
np.testing.assert_almost_equal(sum(gains), 1.0)

## RD performance considering (and not) the subband gains

We compute the RD curve obtained using scalar quantization when:
1. All subbands are quantized using the same quantization step.
2. The quantization step used in a subband is divided by the subband gain.

In [None]:
xx = read_image(test_image).astype(np.int16) - 128
x = YUV.from_RGB(xx)
# The decomposition does not depend on Q_step: compute it only once.
y = DWT.analyze(x, wavelet, N_levels)

constant_Q_points = []
for Q_step in Q_steps:
    # Same quantization step for every subband.
    LL_k = Q.quantize(y[0], Q_step)
    y_k = [LL_k]
    y_dQ = [Q.dequantize(LL_k, Q_step)]
    for sr in y[1:]:
        sr_k = tuple(Q.quantize(sb, Q_step) for sb in sr)
        y_k.append(sr_k)
        y_dQ.append(tuple(Q.dequantize(sb_k, Q_step) for sb_k in sr_k))
    BPP = (write_compact_decomposition(y_k, f"/tmp/constant_{Q_step}", 0)*8)/x.size
    z_dQ = DWT.synthesize(y_dQ, wavelet, N_levels)
    # Both xx and the reconstruction stay shifted by -128 here.
    zz_dQ = np.clip(YUV.to_RGB(z_dQ), a_min=-128, a_max=127)
    MSE = distortion.MSE(xx, zz_dQ)
    print(f"{Q_step} {BPP} {MSE}")
    constant_Q_points.append((BPP, MSE))

Let's suppose that the slope of the subband is proportional to the subband gain.

In [None]:
xx = read_image(test_image).astype(np.int16) - 128
x = YUV.from_RGB(xx)
y = DWT.analyze(x, wavelet, N_levels)

# Gains relative to the gain of the last subband. `gains` (computed above) is
# ordered like the decomposition: LL first, then the detail subbands level by
# level. It presumably comes from the same wavelet and N_levels -- TODO confirm.
relative_gains = [gain/gains[-1] for gain in gains]
print(relative_gains)
gains_Q_points = []
for Q_step in Q_steps:
    # Subbands with a larger gain get a proportionally finer step.
    sb_Q_steps = [Q_step / relative_gain for relative_gain in relative_gains]
    print("Q_steps by subband:", sb_Q_steps)
    LL_k = Q.quantize(y[0], sb_Q_steps[0])
    y_k = [LL_k]
    y_dQ = [Q.dequantize(LL_k, sb_Q_steps[0])]
    sb_number = 1
    for sr in y[1:]:
        sr_k = []
        sr_dQ = []
        for sb in sr:
            sb_k = Q.quantize(sb, sb_Q_steps[sb_number])
            sr_k.append(sb_k)
            sr_dQ.append(Q.dequantize(sb_k, sb_Q_steps[sb_number]))
            sb_number += 1
        y_k.append(tuple(sr_k))
        y_dQ.append(tuple(sr_dQ))
    BPP = (write_compact_decomposition(y_k, f"/tmp/gains_{Q_step}", 0)*8)/x.size
    z_dQ = DWT.synthesize(y_dQ, wavelet, N_levels)
    zz_dQ = np.clip(YUV.to_RGB(z_dQ), a_min=-128, a_max=127)
    MSE = distortion.MSE(xx, zz_dQ)
    print(f"{Q_step} {BPP} {MSE}")
    gains_Q_points.append((BPP, MSE))
    image.show_RGB_image(zz_dQ + 128, f"Reconstructed image (Q_step={Q_step})")

In [None]:
# Compare constant-step and gain-scaled quantization (pylab re-exports these pyplot functions).
plt.figure(dpi=150)
for _curve_points, _curve_label in ((constant_Q_points, "Constant Q"),
                                    (gains_Q_points, "Gains Q")):
    plt.plot(*zip(*_curve_points), label=_curve_label)
plt.title(test_image)
plt.xlabel("Bits/Pixel")
plt.ylabel("MSE")
plt.legend(loc='upper right')
plt.show()

## Optimal quantization progression

The previous quantization is not (usually) optimal, because the RD contribution of each subband is not constant.
Let's now use a different quantization step for each subband, so that all subbands operate at (approximately) the same RD slope.

## An example of uniform quantization

We will measure also the distortion in both domains.

In [None]:
# Quantize all subbands with one step and compare the MSE measured in the
# wavelet domain with the MSE measured in the image domain.
Q_step = 128
x = read_image(test_image).astype(np.int16) #- 128
#x = YUV.from_RGB(xx)
y = DWT.analyze(x, wavelet, N_levels)

LL = y[0]
LL_k = Q.quantize(LL, Q_step)
LL_dQ = Q.dequantize(LL_k, Q_step)
dist = distortion.MSE(LL, LL_dQ)
# Weight each subband's MSE by its share of the image samples.
subband_ratio = LL.size / x.size
print(subband_ratio)
MSE_wavelet_domain = dist * subband_ratio #* gains[0]
counter = 1
y_dQ = [LL_dQ]
for sr in y[1:]:
    sr_dQ = []
    for sb in sr:
        sb_k = Q.quantize(sb, Q_step)
        sb_dQ = Q.dequantize(sb_k, Q_step)
        sr_dQ.append(sb_dQ)
        dist = distortion.MSE(sb, sb_dQ)
        subband_ratio = sb.size / x.size
        print(subband_ratio)
        MSE_wavelet_domain += dist * subband_ratio #* gains[counter]
        counter += 1
    y_dQ.append(tuple(sr_dQ))

# Reconstruct, clip to the 8-bit range, and compare both distortion measures.
# NOTE(review): the two values are expected to match only for an orthonormal transform.
z_dQ = DWT.synthesize(y_dQ, wavelet, N_levels)
cz_dQ = np.clip(z_dQ, a_min=0, a_max=255)
#zz_dQ = np.clip( YUV.to_RGB(z_dQ) + 128, a_min=0, a_max=255)
#zz_dQ = YUV.to_RGB(z_dQ) #+ 128
#print("Distortion in the image domain:", distortion.MSE(xx + 128, zz_dQ))
print("Distortion in the image domain:", distortion.MSE(x, cz_dQ))
print("Distortion in the wavelet domain:", MSE_wavelet_domain)
image.show_RGB_image(cz_dQ, f"Reconstructed image (Q_step={Q_step})")

## Optimal quantization progression

The previous quantization is not (usually) optimal, because the RD contribution of each subband is not constant.
Let's now use a different quantization step for each subband, so that all subbands operate at (approximately) the same RD slope.

## An orthogonality test

Orthogonality ensures that the quantization error generated in one subband does not affect the rest of the subbands. This speeds up the RD optimization because the distortion can then be measured in the DWT domain.

This orthogonality test does:
1. Compute the DWT of an image.
2. Set to zero all the subbands except one.
3. Compute the inverse DWT.
4. Compute the DWT again of the previous reconstruction.
5. Test if the decomposition matches the one generated in the step 2.  If matches (with some maximum error), the transform is orthogonal.

In [None]:
# Keep a single detail subband, synthesize, analyze again, and check that the
# decomposition is unchanged (no energy leaks into the other subbands).
# NOTE(review): for any non-redundant perfect-reconstruction transform,
# analyze(synthesize(y)) == y, so this may not distinguish orthogonal from
# biorthogonal wavelets -- confirm what the test is meant to show.
y = DWT.analyze(x, wavelet, N_levels)
subband_to_keep = 5  # index among the detail subbands; 0 is the first subband of y[1]
n_detail_subbands = sum(len(sr) for sr in y[1:])
if not 0 <= subband_to_keep < n_detail_subbands:
    raise ValueError(
        f"subband_to_keep must be in [0, {n_detail_subbands}), got {subband_to_keep}")
y[0][...] = 0.0
counter = 0
for sr in y[1:]:
    for sb in sr:
        if counter != subband_to_keep:
            sb[...] = 0.0
        counter += 1
z = DWT.synthesize(y, wavelet, N_levels)
y2 = DWT.analyze(z, wavelet, N_levels)
# Compare every subband, including LL: leakage into any of them fails the test.
orthogonal = np.allclose(y[0], y2[0]) and all(
    np.allclose(sb, sb2)
    for sr, sr2 in zip(y[1:], y2[1:])
    for sb, sb2 in zip(sr, sr2))
print("Orthogonal:", orthogonal)

Another way to find out whether the transform is orthogonal is to compute the quantization distortion in the wavelet domain and check whether it is the same as the distortion in the image domain.

## Optimal quantization progression

The previous quantization is not (usually) optimal, because the RD contribution of each subband is not constant.
Let's now use a different quantization step for each subband, so that all subbands operate at (approximately) the same RD slope.

This information is important to know whether the transform is unitary or not (usually, biorthogonal transforms are not unitary, i.e., the energy of the decomposition is different from (usually larger than) the energy of the image). Notice that if the transform is not unitary, the distortion is measured differently in the image and in the transform domains. For example, if the gain is larger than 1, then the overall distortion should be divided by the gain.

In [None]:
# Forward transform gain: energy of the decomposition / energy of the image.
# Work in float64 so the energies do not depend on how distortion.energy
# handles integer inputs (presumably by squaring the samples -- TODO confirm).
x = read_image(test_image).astype(np.float64)
y = DWT.analyze(x, wavelet, N_levels)
image_energy = distortion.energy(x)
image_average_energy = image_energy / x.size
print("Image average energy:", image_average_energy)
decom_energy = distortion.energy(y[0]) + sum(
    distortion.energy(sb) for sr in y[1:] for sb in sr)
print("Decomposition energy", decom_energy)
decom_average_energy = decom_energy / x.size
print("Decomposition average energy", decom_average_energy)
forward_transform_gain = decom_energy/image_energy
print("Forward transform gain:", forward_transform_gain)
# Same tolerance as np.testing.assert_almost_equal with its default decimal=7.
is_unitary = np.isclose(forward_transform_gain, 1.0, rtol=0.0, atol=1.5e-7)
print("The transform is", "unitary" if is_unitary else "not unitary")