# DWT (Discrete Wavelet Transform)

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import image
import DWT
import pywt
import distortion

In [None]:
# Wavelet used by every DWT.analyze*/DWT.synthesize* call in this notebook.
# "bior1.3" is a PyWavelets biorthogonal wavelet; the commented-out line
# selects the Haar wavelet instead.
#wavelet = pywt.Wavelet("Haar")
wavelet = pywt.Wavelet("bior1.3")

In [None]:
# Prefix handed to image.read() by read_image(). Presumably a directory of
# frames from which read_image() reads index 0 — confirm against image.read.
test_image = "../sequences/lena_color/"

In [None]:
# Number of DWT decomposition levels.
# NOTE(review): in the cells shown here this value only appears in plot
# titles; it is never passed to DWT.analyze(). Confirm it matches the number
# of levels DWT.analyze() actually uses.
N_DWT_levels = 5

In [None]:
# Quantization step sizes to sweep, coarsest first: 128, 64, 32, 16.
# Change the range stop from 3 to 0 to also sweep 8, 4 and 2.
Q_steps = [2**k for k in range(7, 3, -1)]

## First, some handy routines

In [None]:
def expand_to_three_channels(x):
    """Return *x* as an array with a trailing 3-channel axis.

    A 2-D (grayscale) array is copied into channel 0 of a zero-filled
    (rows, cols, 3) array of the SAME dtype as *x*; channels 1 and 2 stay
    zero. Any other array is returned unchanged.
    """
    if x.ndim != 2:
        return x
    # Keep the input dtype so grayscale and colour images come back with the
    # same element type (forcing uint16 made grayscale inputs behave
    # differently in later uint8 arithmetic and display).
    extended_x = np.zeros(shape=(x.shape[0], x.shape[1], 3), dtype=x.dtype)
    extended_x[..., 0] = x
    return extended_x


def read_image(prefix):
    """Read an image with image.read(prefix, 0) and ensure it has 3 channels.

    The second argument 0 presumably selects the first frame of the
    sequence — confirm against image.read. Grayscale images are expanded
    with expand_to_three_channels(); other images are returned as read.
    """
    return expand_to_three_channels(image.read(prefix, 0))


## Testing `DWT.analyze_step()` and `DWT.synthesize_step()`

In [None]:
# Load the test image (read_image pads grayscale inputs to 3 channels) and
# display it.
x = read_image(test_image)
image.show_RGB_image(x, title="Original")

In [None]:
# One analysis step. L is the low-pass subband (shown below as "LL"); H is
# indexed 0..2 and shown below as the LH, HL and HH subbands, in that order.
L,H = DWT.analyze_step(x, wavelet=wavelet)

In [None]:
# Display the four subbands of the single-step decomposition. Each one goes
# through image.normalize() and is scaled by 255 before display.
image.show_RGB_image(255*image.normalize(L), "LL DWT domain")
subbands = ("LH", "HL", "HH")
for i, sb in enumerate(subbands):
    detail = H[i]
    title = f"{sb} DWT domain"
    image.show_RGB_image(255*image.normalize(detail), title)

In [None]:
# Invert the single analysis step. The synthesis output is presumably
# floating point. A bare astype(np.uint8) truncates (254.9999 -> 254) and
# does not clamp values outside [0, 255], so round and clip first.
z = np.clip(np.rint(DWT.synthesize_step(L,H, wavelet=wavelet)), 0, 255).astype(np.uint8)

In [None]:
# Reconstruction error. z is uint8 and x is presumably an unsigned integer
# array too, so x - z would wrap around (e.g. 3 - 5 -> 254) instead of going
# negative. Promote x to a signed type before subtracting.
r = x.astype(np.int32) - z

In [None]:
# Show the error of ONE analysis/synthesis step. The previous title reported
# N_DWT_levels, which this single-step test does not use.
image.show_RGB_image(255*image.normalize(r), "DWT finite precision error (1 analysis/synthesis step)")

In [None]:
# Largest absolute reconstruction error. r.max() alone ignores negative
# errors; np.abs is a no-op when r is unsigned.
np.abs(r).max()

The DWT is not exactly reversible in finite-precision arithmetic, but it is almost.

In [None]:
# Display the image reconstructed by DWT.synthesize_step().
image.show_RGB_image(z, "Reconstructed image")

## Testing `DWT.analyze()` and `DWT.synthesize()`

In [None]:
# Full multilevel round trip. NOTE(review): N_DWT_levels is not passed here,
# so DWT.analyze() uses its own default number of levels — confirm.
y = DWT.analyze(x, wavelet=wavelet)
# Round and clip before casting: astype(np.uint8) alone truncates and does
# not clamp values outside [0, 255].
z = np.clip(np.rint(DWT.synthesize(y, wavelet=wavelet)), 0, 255).astype(np.uint8)

In [None]:
# Reconstruction error. z is uint8 and x is presumably an unsigned integer
# array too, so a plain x - z would wrap around instead of going negative.
# Promote x to a signed type before subtracting.
r = x.astype(np.int32) - z

In [None]:
# Largest absolute reconstruction error. r.max() alone ignores negative
# errors; np.abs is a no-op when r is unsigned.
np.abs(r).max()

In [None]:
# Display the multilevel round-trip error, normalized and scaled by 255.
# NOTE(review): the title reports N_DWT_levels, but that value is not passed
# to DWT.analyze(); the levels actually used are DWT.analyze()'s default —
# confirm they agree.
image.show_RGB_image(255*image.normalize(r), f"DWT finite precission error N_DWT_levels={N_DWT_levels}")

In [None]:
# Display the image reconstructed by the DWT.analyze()/DWT.synthesize()
# round trip.
image.show_RGB_image(z, "Reconstructed image")

## Orthogonality test

Orthogonality is necessary so that the quantization error generated in one subband does not affect the other subbands: the errors of different subbands are then orthogonal in the image domain, and their energies simply add up. This speeds up the RD optimization because the distortion can be measured directly in the DWT domain.

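This orthogonality test:
1. Computes the DWT of an image.
2. Sets all the subbands except one to zero.
3. Computes the inverse DWT.
4. Computes the DWT of that reconstruction again.
5. Tests whether this decomposition matches the one from step 2. If they match (up to a small numerical tolerance), the transform is considered orthogonal. Note that this is a necessary but not a sufficient check. Any perfectly reconstructing, critically sampled transform, including a biorthogonal one, satisfies analysis(synthesis(y)) = y. A stricter check is that the images synthesized from different subbands are mutually orthogonal (their inner product is zero).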

In [None]:
# Keep a single high-frequency subband, go back to the image domain, and
# re-analyze. If the transform behaves orthogonally, every subband of the
# re-analysis (LL included) matches the zeroed-out decomposition.
# NOTE(review): analyze(synthesize(y)) == y also holds for biorthogonal
# perfect-reconstruction transforms, so a pass is necessary but not
# sufficient evidence of orthogonality.
y = DWT.analyze(x, wavelet=wavelet)
# 0-based index into the high-frequency subbands in iteration order:
# y[1][0], y[1][1], y[1][2], y[2][0], ... The LL subband y[0] is always zeroed.
subband_to_keep = 5
# Count subbands from the decomposition itself. Comparing the index against
# the number of levels (as before) is wrong: each level holds several subbands.
N_subbands = sum(len(sr) for sr in y[1:])
if not 0 <= subband_to_keep < N_subbands:
    # Running the test with every subband zeroed would be meaningless.
    print("No way, José", f"(subband_to_keep must be in [0, {N_subbands - 1}])")
else:
    y[0][...] = 0.0
    counter = 0
    for sr in y[1:]:
        for sb in sr:
            if counter != subband_to_keep:
                sb[...] = 0.0
            counter += 1
    z = DWT.synthesize(y, wavelet=wavelet)
    y2 = DWT.analyze(z, wavelet=wavelet)
    # Compare ALL subbands, LL included: energy leaking into y2[0] must also
    # fail the test.
    orthogonal = np.allclose(y[0], y2[0]) and all(
        np.allclose(sb, sb2)
        for sr, sr2 in zip(y[1:], y2[1:])
        for sb, sb2 in zip(sr, sr2)
    )
    print("Orthogonal:", orthogonal)

## Subband gains
The gain of each subband can be computed by giving energy to that subband only (setting all its coefficients to 1 and all other coefficients to 0), performing the inverse transform, and measuring the energy of the reconstruction. The gains are important because the quantization error generated in a subband is multiplied by its gain in the reconstructed image. Notice that in the case of the DWT, the high-frequency subbands have a higher gain because they have more coefficients. In fact, the gain of a subband is proportional to the number of coefficients in that subband.

In [None]:
# Subband gains: light up one subband at a time (all of its coefficients set
# to 1, every other subband left at 0), synthesize, and record the energy of
# the reconstruction. The order is LL first, then the subbands of y[1],
# y[2], ... in the order DWT.analyze() returns them.
y = DWT.analyze(np.zeros_like(x), wavelet=wavelet)
all_subbands = [y[0]]
for sr in y[1:]:
    all_subbands.extend(sr)
gains = []
prev_sb = None
for sb in all_subbands:
    # Switch off the previously lit subband before lighting the next one.
    if prev_sb is not None:
        prev_sb[...] = 0.0
    sb[...] = 1.0
    z = DWT.synthesize(y, wavelet=wavelet)
    gains.append(distortion.energy(z))
    prev_sb = sb
print(gains)

## RD performance using "constant" quantization among subbands

In [None]:
xx = read_image(test_image).astype(np.int16) - 128
x = YUV.from_RGB(xx)

RD_points = []
for Q_step in Q_steps:
    y = DWT.analyze(x, wavelet=wavelet)
    for sr 