# DWT (Discrete Wavelet Transform)

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pylab
%matplotlib inline
import image
import DWT
import pywt
import distortion
import YCoCg as YUV
import deadzone as Q

In [None]:
# Wavelet passed to the DWT.analyze*()/DWT.synthesize*() calls in the cells
# below. To try another filter, comment out the active line and uncomment one
# of the alternatives.
#wavelet = pywt.Wavelet("Haar")
#wavelet = pywt.Wavelet("db5")
wavelet = pywt.Wavelet("bior3.1")

In [None]:
# Prefix of the test image; it reaches image.read() through read_image() and
# is also used as the title of the final RD plot.
test_image = "../sequences/lena_color/"

In [None]:
# Number of decomposition levels passed to DWT.analyze() and DWT.synthesize().
N_levels = 5

In [None]:
# Quantization step sizes visited, in this order, by the RD experiment below.
Q_steps = [128, 64, 32, 16, 8, 4, 2, 1]

## First ... some handy routines

In [None]:
def read_image(prefix):
    """Load image number 0 of the sequence `prefix`.

    A two-dimensional (single-component) image is copied into component 0 of
    a zero-filled `uint16` array with three components. Any other image is
    returned exactly as `image.read()` delivered it.
    """
    img = image.read(prefix, 0)
    if len(img.shape) != 2:
        return img
    height, width = img.shape[0], img.shape[1]
    padded = np.zeros(shape=(height, width, 3), dtype=np.uint16)
    padded[..., 0] = img
    return padded

def write_compact_decomposition(decom, prefix, image_number):
    """Pack a multiresolution decomposition into one image and write it.

    Parameters:
        decom: sequence whose first element is the LL subband (an array with
            three axes) and whose remaining elements are, per resolution
            level, an indexable holding the LH, HL and HH subbands.
        prefix: forwarded unchanged to `image.write()`.
        image_number: forwarded unchanged to `image.write()`.

    Returns:
        Whatever `image.write()` returns (the RD cell of this notebook
        multiplies it by 8 to obtain a bit count).

    Every coefficient is stored with an offset of 32768 in a `uint16` image
    whose size is twice the size of the last level's first subband. For a
    subband of shape (h, w, ...), LH goes to rows 0:h / columns w:2w, HL to
    rows h:2h / columns 0:w and HH to rows h:2h / columns w:2w; LL occupies
    the top-left corner.
    """
    last_level = decom[len(decom)-1]
    rows = last_level[0].shape[0]*2
    cols = last_level[0].shape[1]*2
    coms = decom[0].shape[2]
    # Pre-fill with 32768, the stored value of a zero coefficient. An
    # uninitialized buffer (np.empty) would leak arbitrary memory into every
    # position that the subbands do not cover, making the written file (and
    # the rate measured from it) non-deterministic.
    view = np.full((rows, cols, coms), 32768, dtype=np.uint16)

    # LL subband
    LL = decom[0]
    view[0:LL.shape[0], 0:LL.shape[1]] = LL + 32768

    for level in decom[1:]:
        LH, HL, HH = level[0], level[1], level[2]

        # LH
        view[0:LH.shape[0],
             LH.shape[1]:LH.shape[1]*2] = LH + 32768

        # HL
        view[HL.shape[0]:HL.shape[0]*2,
             0:HL.shape[1]] = HL + 32768

        # HH
        view[HH.shape[0]:HH.shape[0]*2,
             HH.shape[1]:HH.shape[1]*2] = HH + 32768

    return image.write(view, prefix, image_number)
    
def read_compact_decomposition(prefix, image_number, N_levels):
    """Read an image written by `write_compact_decomposition()` back into a
    decomposition.

    Parameters:
        prefix: forwarded unchanged to `image.read()`.
        image_number: forwarded unchanged to `image.read()`.
        N_levels: number of decomposition levels (an integer).

    Returns:
        The structure produced by `DWT.analyze()`, with every subband
        overwritten by the coefficients found in the image (stored value
        minus 32768).

    The subband shapes are taken from a Haar analysis of an all-zero image of
    the same size.
    NOTE(review): this presumes that the shapes do not depend on the wavelet
    that produced the file -- confirm against DWT.analyze().
    """
    view = image.read(prefix, image_number)
    wavelet = pywt.Wavelet("Haar")
    decom = DWT.analyze(np.zeros_like(view), wavelet, N_levels)

    # Remove the offset in a signed type: in unsigned arithmetic every stored
    # value below 32768 (i.e. every negative coefficient) would wrap around.
    signed_view = view.astype(np.int32)

    # LL subband
    LL = decom[0]
    LL[...] = signed_view[0:LL.shape[0],
                          0:LL.shape[1]] - 32768

    # N_levels is an integer, so iterate over range(N_levels) directly.
    for l in range(N_levels):
        # Fill the existing arrays in place: this also works when a
        # resolution level is a tuple, which does not allow item assignment.
        LH, HL, HH = decom[l+1][0], decom[l+1][1], decom[l+1][2]

        # LH
        LH[...] = signed_view[0:LH.shape[0],
                              LH.shape[1]:LH.shape[1]*2] - 32768

        # HL
        HL[...] = signed_view[HL.shape[0]:HL.shape[0]*2,
                              0:HL.shape[1]] - 32768

        # HH
        HH[...] = signed_view[HH.shape[0]:HH.shape[0]*2,
                              HH.shape[1]:HH.shape[1]*2] - 32768

    return decom

## Testing `DWT.analyze_step()` and `DWT.synthesize_step()`

In [None]:
# Load the test image through read_image() and display it.
x = read_image(test_image)
image.show_RGB_image(x, title="Original")

In [None]:
# One analysis step. L is displayed below as the "LL" subband; H is indexed
# below as H[0], H[1] and H[2] for the "LH", "HL" and "HH" subbands.
L,H = DWT.analyze_step(x, wavelet)

In [None]:
def _show_subband(coefficients, name):
    """Display one subband, normalized and scaled by 255, titled with `name`."""
    image.show_RGB_image(255*image.normalize(coefficients), f"{name} DWT domain")


_show_subband(L, "LL")
subbands = ("LH", "HL", "HH")
for index, name in enumerate(subbands):
    _show_subband(H[index], name)

In [None]:
# Invert the single analysis step and narrow the result to 8 bits.
# NOTE(review): astype(np.uint8) does not round or clip; if synthesize_step()
# returns floating-point values, part of the error measured below may come
# from this conversion rather than from the transform -- confirm.
z = DWT.synthesize_step(L, H, wavelet).astype(np.uint8)

In [None]:
# Reconstruction error. The operands are widened to a signed type first: z is
# uint8, and subtracting unsigned integers wraps around, so an error of -1
# would otherwise be reported as the largest unsigned value.
r = x.astype(np.int32) - z.astype(np.int32)

In [None]:
# Display the error r after image.normalize() and a scaling by 255.
# NOTE(review): r comes from a single analyze_step()/synthesize_step() pair,
# yet the title reports N_levels -- confirm that this is intended.
image.show_RGB_image(255*image.normalize(r), f"DWT finite precission error N_DWT_levels={N_levels}")

In [None]:
# Largest value found in the single-step reconstruction error.
r.max()

The DWT is not fully reversible, but it almost is.

In [None]:
# Display the image recovered by the single synthesis step.
image.show_RGB_image(z, "Reconstructed image")

## Testing `DWT.analyze()` and `DWT.synthesize()`

In [None]:
# Full N_levels decomposition followed by its inverse, narrowed to 8 bits.
# NOTE(review): astype(np.uint8) does not round or clip; if synthesize()
# returns floating-point values, part of the error measured below may come
# from this conversion rather than from the transform -- confirm.
y = DWT.analyze(x, wavelet, N_levels)
z = DWT.synthesize(y, wavelet, N_levels).astype(np.uint8)

In [None]:
# Reconstruction error. The operands are widened to a signed type first: z is
# uint8, and subtracting unsigned integers wraps around, so an error of -1
# would otherwise be reported as the largest unsigned value.
r = x.astype(np.int32) - z.astype(np.int32)

In [None]:
# Largest value found in the multi-level reconstruction error.
r.max()

In [None]:
# Display the error r after image.normalize() and a scaling by 255.
image.show_RGB_image(255*image.normalize(r), f"DWT finite precission error N_DWT_levels={N_levels}")

In [None]:
# Display the image recovered by the multi-level synthesis.
image.show_RGB_image(z, "Reconstructed image")

## Orthogonality test

Orthogonality is necessary so that the quantization error generated in a subband does not affect the rest of the subbands. This speeds up the RD optimization because the distortion can then be measured in the DWT domain.

This orthogonality test does:
1. Compute the DWT of an image.
2. Set to zero all the subbands except one.
3. Compute the inverse DWT.
4. Compute the DWT again of the previous reconstruction.
5. Test if the decomposition matches the one generated in step 2. If it matches (within some maximum error), the transform is orthogonal.

In [None]:
# Orthogonality test: keep a single detail subband, zero everything else,
# synthesize, analyze again and compare both decompositions.
y = DWT.analyze(x, wavelet, N_levels)
subband_to_keep = 5

# Detail subbands are numbered 0, 1, 2, ... in the order in which the loop
# below visits them, so the valid indexes are bounded by how many there are
# (not by the number of levels).
N_detail_subbands = sum(len(sr) for sr in y[1:])
if not 0 <= subband_to_keep < N_detail_subbands:
    print("No way, José")

y[0][...] = 0.0
counter = 0
for sr in y[1:]:
    for sb in sr:
        if counter != subband_to_keep:
            sb[...] = 0.0
        counter += 1

z = DWT.synthesize(y, wavelet, N_levels)
y2 = DWT.analyze(z, wavelet, N_levels)

# The LL subband takes part in the comparison too: energy leaking from the
# kept subband into LL also means that the decompositions do not match.
orthogonal = bool(np.allclose(y[0], y2[0]))
for sr, sr2 in zip(y[1:], y2[1:]):
    for sb, sb2 in zip(sr, sr2):
        if not np.allclose(sb, sb2):
            orthogonal = False
print("Orthogonal:", orthogonal)

## Subband gains
The gain of each subband can be computed by giving energy to that subband alone, performing the inverse transform, and dividing the energy of the reconstruction by the number of coefficients in the subband. The gains are important because the quantization error generated in a subband is multiplied by its gain in the reconstructed image. Notice that, in the case of the DWT, the high-frequency subbands have more coefficients than the low-frequency subbands.

In [None]:
gains = []
y = DWT.analyze(np.zeros_like(x), wavelet, N_levels)
y[0][...] = 1.0 
z = DWT.synthesize(y, wavelet, N_levels)
gains.append(distortion.energy(z)/np.size(y[0]))
prev_sb = y[0]
for sr in y[1:]:
    for sb in sr:
        prev_sb[...] = 0.0
        sb[...] = 1.0 
        z = DWT.synthesize(y, wavelet, N_levels)
        gains.append(distortion.energy(z)/np.size(sb))
        prev_sb = sb
print(gains)

## An example of quantization

In [None]:
# Quantize and dequantize every subband with the same step, then reconstruct
# and display the result.
xx = read_image(test_image).astype(np.int16) - 128
x = YUV.from_RGB(xx)
y = DWT.analyze(x, wavelet, N_levels)
Q_step = 128


def _quantize_dequantize(sb, Q_step):
    """Return `sb` after quantization and dequantization with step `Q_step`."""
    return Q.dequantize(Q.quantize(sb, Q_step), Q_step)


# Same layout as y: the LL subband followed by one tuple of subbands per level.
y_dQ = [_quantize_dequantize(y[0], Q_step)]
y_dQ.extend(tuple(_quantize_dequantize(sb, Q_step) for sb in sr) for sr in y[1:])

z_dQ = DWT.synthesize(y_dQ, wavelet, N_levels)
# Undo the color transform and the -128 shift, and keep the result in [0, 255].
zz_dQ = np.clip( YUV.to_RGB(z_dQ) + 128, a_min=0, a_max=255)
image.show_RGB_image(zz_dQ, f"Reconstructed image (Q_step={Q_step})")

## RD performance using "constant" quantization among subbands

In [None]:
# RD curve obtained with the same quantization step in every subband. The
# distortion is estimated in the DWT domain: the MSE of each subband is
# weighted by its gain and by its share of the total number of samples.
xx = read_image(test_image).astype(np.int16) - 128
x = YUV.from_RGB(xx)


def _quantize_subband(sb, gain, Q_step, N_samples):
    """Return the quantization indexes of `sb` and its distortion contribution."""
    sb_k = Q.quantize(sb, Q_step)
    sb_dQ = Q.dequantize(sb_k, Q_step)
    dist = distortion.MSE(sb, sb_dQ)
    return sb_k, (gain * (dist * sb.size))/N_samples


DWT_points = []
for Q_step in Q_steps:
    y = DWT.analyze(x, wavelet, N_levels)

    # LL subband (gain number 0)
    LL_k, MSE = _quantize_subband(y[0], gains[0], Q_step, x.size)
    y_k = [LL_k]

    # Detail subbands (gains number 1, 2, ...)
    counter = 1
    for sr in y[1:]:
        sr_k = []
        for sb in sr:
            sb_k, contribution = _quantize_subband(sb, gains[counter], Q_step, x.size)
            sr_k.append(sb_k)
            MSE += contribution
            counter += 1
        y_k.append(tuple(sr_k))

    # Rate: the value returned by the writer, times 8, per sample of x.
    BPP = (write_compact_decomposition(y_k, f"/tmp/constant_{Q_step}", 0)*8)/x.size
    print(f"{Q_step} {BPP} {MSE}")
    DWT_points.append((BPP, MSE))

In [None]:
# Load the (rate, distortion) points of the DCT experiment: one point per
# line, with the two values separated by a tab.
DCT_points = []
with open("DCT.txt", 'r') as f:
    for line in f:
        # A blank line (typically a trailing one) holds no point and would
        # break the two-field unpacking below.
        if not line.strip():
            continue
        rate, _distortion = line.split('\t')
        DCT_points.append((float(rate), float(_distortion)))

In [None]:
# Draw both RD curves (rate on the horizontal axis, MSE on the vertical one).
plt.figure(dpi=150)
for points, curve_label in ((DCT_points, "DCT optimal Q"),
                            (DWT_points, "DWT constant Q")):
    plt.plot(*zip(*points), label=curve_label)
plt.title(test_image)
plt.xlabel("Bits/Pixel")
plt.ylabel("MSE")
plt.legend(loc='upper right')
plt.show()