# An [RD](https://en.wikipedia.org/wiki/Rate%E2%80%93distortion_theory) comparison of DWTs

In [None]:
import sounddevice as sd
import pywt # pip install pywavelets
import math
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from scipy import signal
import zlib
import pylab

In [None]:
def plot(y, xlabel='', ylabel='', title=''):
    """Display the samples of `y` as small dots in a new figure.

    The horizontal axis holds the sample index (0 .. len(y)-1). The figure
    has a grid, the given title and the given axis labels, and it is shown
    without blocking.
    """
    sample_indexes = np.linspace(0, len(y)-1, len(y))
    figure, axes = plt.subplots()
    axes.plot(sample_indexes, y, '.', markersize=1)
    axes.grid()
    axes.set_title(title)
    axes.xaxis.set_label_text(xlabel)
    axes.yaxis.set_label_text(ylabel)
    plt.show(block=False)

## Produce some sound ... for example, speak for 5 seconds
The audio signal will be used to compare the RD performance of the transforms.

In [None]:
fs = 44100      # Sampling frequency (samples per second)
duration = 5.0  # Length of the recording (seconds)
# Prompt before starting the capture: sd.rec() returns immediately and
# records in the background.
print("Say something!")
x = sd.rec(int(duration * fs), samplerate=fs, channels=1, dtype=np.int16)
# sd.wait() blocks until the recording has finished, so one call is enough.
# Its return value only reports buffer over-/underruns; looping on it
# (`while sd.wait(): pass`) risks never terminating when such a status is
# reported.
status = sd.wait()
if status:
    print(f"Warning: recording status: {status}")
print("done")
x = x.flatten()  # (frames, 1) -> (frames,)

## This is your audio

In [None]:
%matplotlib inline
# Play the recording at the rate it was captured with. Without an explicit
# sample rate, playback depends on sounddevice's/the device's default, which
# may differ from fs and change the speed and pitch.
sd.play(x, samplerate=fs)
plot(x, "Sample", "Amplitude", "Audio Signal")

## Let's remove some samples from the beginning

In [None]:
# Discard the first 50000 samples (about 1.13 s at fs = 44100 Hz).
x = x[50000:]

## This is your definitive audio sequence

In [None]:
%matplotlib inline
# Play the trimmed sequence at the rate it was captured with. Without an
# explicit sample rate, playback depends on sounddevice's/the device's
# default, which may differ from fs.
sd.play(x, samplerate=fs)
plot(x, "Sample", "Amplitude", "Audio Signal")

### Configuration

In [None]:
levels = 5          # Number of levels of the DWT
# Wavelet family/order to evaluate: keep exactly one line uncommented.
#filters_name = "db5"
#filters_name = "haar"
filters_name = "db11"
#filters_name = "db20"
#filters_name = "bior2.2"
#filters_name = "bior3.5"
#filters_name = "rbio2.2"
wavelet = pywt.Wavelet(filters_name)
# Signal extension mode handed to pywt.wavedec()/pywt.waverec() by the
# functions below ("per" is PyWavelets' periodization mode).
signal_mode_extension = "per"
# Smallest quantization step size, also used as the increment between
# consecutive step sizes.
Delta = 32
# Quantization step sizes tested for each RD curve: 32, 64, ..., 992.
Q_steps = range(Delta, 1024, Delta)

In [None]:
def get_filter(wavelet, coef_index, N):
    """Synthesize the N-sample signal produced by one unit coefficient.

    A 1-level decomposition (mode "per") of N zeros provides the coefficient
    layout. The coefficient at position `coef_index` of the flattened
    coefficient array is set to 1 and the inverse transform of the result is
    returned.
    """
    layout_source = pywt.wavedec(np.zeros(N), wavelet=wavelet, level=1, mode="per")
    flat_coefficients, layout = pywt.coeffs_to_array(layout_source)
    flat_coefficients[coef_index] = 1
    unit_decomposition = pywt.array_to_coeffs(flat_coefficients, layout,
                                              output_format="wavedec")
    return pywt.waverec(unit_decomposition, wavelet=wavelet, mode="per")

## [Filter's response in the frequency domain](https://en.wikipedia.org/wiki/Filter_(signal_processing)#The_transfer_function)

In [None]:
%matplotlib inline
# Alternative (disabled): take the filters from wavelet.wavefun().
#K0, K1, __ = wavelet.wavefun(level=1) # For orthogonal transforms
#K0_basis, K0_dual, K1_basis, K1_dual__ = wavelet.wavefun(level=1) # For bi-orthogonal transforms
# Signals synthesized from a single unit coefficient of a 1-level transform
# of 32 samples: index 8 lies in the first half of the coefficient array and
# index 24 in the second half (presumably the low-pass and the high-pass
# subband, respectively -- see the plot title).
K0 = get_filter(wavelet, 8, 32)
K1 = get_filter(wavelet, 24, 32)
# Frequency responses of both signals.
# NOTE(review): the sampling frequency is hard-coded here instead of reusing
# the fs defined in the recording cell.
w0, h0 = signal.freqz(K0, fs=44100)
w1, h1 = signal.freqz(K1, fs=44100)

# Upper panel: magnitude response of K0 (linear gain, not dB).
plt.subplot(211)
plt.title(f'{filters_name} Frequency Response (low/high pass)')
#plt.plot(w1, 20 * np.log10(abs(h1)), 'b')
plt.plot(w0, abs(h0), 'b')
plt.xlabel('')
plt.ylabel('Gain')
# Lower panel: magnitude response of K1 (linear gain, not dB).
plt.subplot(212)
#plt.plot(w2, 20 * np.log10(abs(h2)), 'b')
plt.plot(w1, abs(h1), 'b')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Gain')
plt.show()

### Total gain of the filters

In [None]:
def average_complex_energy(x):
    """Return the mean of real**2 + imag**2 over the elements of `x`.

    Both parts are converted to double precision before squaring, and the
    sum is divided by len(x).
    """
    real_part = x.real.astype(np.double)
    imag_part = x.imag.astype(np.double)
    squared_magnitudes = real_part*real_part + imag_part*imag_part
    return np.sum(squared_magnitudes)/len(x)

In [None]:
# h0 and h1 are the frequency responses returned by signal.freqz() for K0
# and K1, so each printed value is the mean squared magnitude of a response.
print(f"Average energy of K0 = {average_complex_energy(h0)}")
print(f"Average energy of K1 = {average_complex_energy(h1)}")

### Conclusion

Orthogonal transforms implemented in PyWavelets have the same gain in both filters, but this is not true for bi-orthogonal transforms.

## Some RD stuff

In [None]:
def average_energy(x):
    """Return the mean of the squared elements of `x`.

    The samples are converted to double precision before squaring, so
    narrow integer inputs (e.g. int16) do not overflow. The sum of squares
    is divided by len(x).
    """
    samples = x.astype(np.double)
    total_energy = np.sum(samples*samples)
    return total_energy/len(x)

def RMSE(x, y):
    """Return the root mean square error between the signals `x` and `y`.

    Both inputs are converted to double precision *before* subtracting.
    Subtracting two narrow integer arrays directly (e.g. two int16 signals)
    wraps around when the difference does not fit in that type, which would
    silently produce a wrong error signal.
    """
    error_signal = np.asarray(x, dtype=np.double) - np.asarray(y, dtype=np.double)
    return math.sqrt(average_energy(error_signal))

# Based on https://stackoverflow.com/questions/15450192/fastest-way-to-compute-entropy-in-python
def entropy_in_bits_per_symbol(sequence_of_symbols):
    """Return the empirical (zero-order) entropy of the symbols, in bits.

    The probability of each distinct symbol is its relative frequency.
    np.unique() flattens its input, so the frequencies are normalized by the
    total number of counted symbols (not by len(), which is only the first
    dimension of a multi-dimensional array).

    Returns 0.0 when there are fewer than two distinct symbols (this includes
    an empty sequence).
    """
    _, counts = np.unique(sequence_of_symbols, return_counts=True)
    if len(counts) <= 1:
        return 0.0
    probs = counts / counts.sum()
    # Every symbol reported by np.unique() occurs at least once, so all the
    # probabilities are > 0 and the logarithm is always defined.
    return float(-np.sum(probs * np.log2(probs)))

def deadzone_quantize(x, quantization_step):
    """Return the int32 quantization indexes of the array `x`.

    Each element is divided by `quantization_step` and converted to int32,
    which discards the fractional part; every value whose magnitude is
    smaller than one step therefore maps to index 0.
    """
    scaled = x / quantization_step
    return scaled.astype(np.int32)

def deadzone_dequantize(k, quantization_step):
    """Return the values reconstructed from the quantization indexes `k`.

    Each index is multiplied by `quantization_step`.
    """
    return quantization_step * k

def deadzone_qdeq(x, quantization_step):
    """Quantize `x` and reconstruct it again with the same step size.

    Returns the pair (quantization indexes, reconstructed values).
    """
    indexes = deadzone_quantize(x, quantization_step)
    return indexes, deadzone_dequantize(indexes, quantization_step)

def RD_curve(data, wavelet, levels):
    """Return the (rate, distortion) points of `data`, one per step in Q_steps.

    `data` is transformed with a `levels`-level DWT (extension mode
    `signal_mode_extension`). For every quantization step in the global
    `Q_steps`, all the subbands are quantized with that step and:

    * rate: entropy, in bits per symbol, of all the quantization indexes
      taken together;
    * distortion: RMSE between `data` and the signal reconstructed from the
      dequantized coefficients.

    The entropy is counted once: `data` is a single-channel signal, and
    adding the same entropy twice doubled the reported rate.
    """
    # The decomposition does not depend on the quantization step, so it is
    # computed only once.
    decomposition = pywt.wavedec(data, wavelet=wavelet, level=levels,
                                 mode=signal_mode_extension)
    RD_points = []
    for q_step in Q_steps:
        quantized_decomposition = [deadzone_quantize(subband, q_step)
                                   for subband in decomposition]
        k = np.concatenate(quantized_decomposition)
        rate = entropy_in_bits_per_symbol(k)
        dequantized_decomposition = [deadzone_dequantize(subband, q_step)
                                     for subband in quantized_decomposition]
        reconstructed_data = pywt.waverec(dequantized_decomposition, wavelet=wavelet,
                                          mode=signal_mode_extension)
        distortion = RMSE(data, reconstructed_data)
        RD_points.append((rate, distortion))
    return RD_points

## RD curve

In [None]:
# Entropy-based RD curve of the whole recording, using the wavelet and the
# number of levels chosen in the configuration cell.
#RD_points = RD_curve(x[0:frames_per_chunk], wavelet, 3)
RD_points = RD_curve(x, wavelet, levels)

In [None]:
%matplotlib inline
# Scatter plot of the entropy-based RD points: rate on the horizontal axis,
# RMSE on the vertical axis.
plt.title("RD Tradeoff")
plt.xlabel("Bits per Sample")
plt.ylabel("RMSE")
#plt.xscale("log")
#plt.yscale("log")
# The label is only visible if the legend call below is enabled.
plt.scatter(*zip(*RD_points), c='b', marker="+", label=f'{wavelet}')
#plt.scatter(*zip(*KLT_RD_points), c='r', marker="x", label='KLT')
#plt.legend(loc='upper right')
plt.show()

## RD curve using DEFLATE

In [None]:
def analyze(chunk, wavelet, levels):
    """Return the `levels`-level DWT decomposition of `chunk`.

    The signal is extended using the global `signal_mode_extension`.
    """
    return pywt.wavedec(chunk, wavelet=wavelet, level=levels,
                        mode=signal_mode_extension)

def quantize(decomposition, q_steps):
    """Quantize every subband with its own quantization step size.

    Subbands and step sizes are paired in order; the pairing stops at the
    shorter of the two sequences. Returns the list of quantized subbands.
    """
    return [deadzone_quantize(subband, q_step)
            for subband, q_step in zip(decomposition, q_steps)]

def analyze_and_quantize(chunk, wavelet, levels, q_steps):
    """Transform `chunk` and quantize the resulting subbands.

    Returns the list of quantized subbands, one per entry of the
    decomposition.
    """
    return quantize(analyze(chunk, wavelet, levels), q_steps)

def dequantize(quantized_decomposition, q_steps):
    """Dequantize every subband with its own quantization step size.

    Subbands and step sizes are paired in order; the pairing stops at the
    shorter of the two sequences. Returns the list of dequantized subbands.
    """
    return [deadzone_dequantize(subband, q_step)
            for subband, q_step in zip(quantized_decomposition, q_steps)]

def synthesize(decomposition, wavelet):
    """Return the signal reconstructed from `decomposition` (inverse DWT).

    The signal is extended using the global `signal_mode_extension`.
    """
    return pywt.waverec(decomposition, wavelet=wavelet,
                        mode=signal_mode_extension)

def dequantize_and_synthesize(quantized_decomposition, wavelet, q_steps):
    """Dequantize the subbands and reconstruct the chunk from them."""
    return synthesize(dequantize(quantized_decomposition, q_steps), wavelet)

def RD_curve_DEFLATE(chunk, wavelet, levels):
    """Return the (rate, distortion) points of `chunk` using DEFLATE.

    For every quantization step in the global `Q_steps`, all the subbands of
    the `levels`-level DWT of `chunk` are quantized with that step and:

    * rate: size in bits of the zlib-compressed quantization indexes
      (stored as int32), divided by the number of samples of `chunk`;
    * distortion: RMSE between `chunk` and its reconstruction.
    """
    # The decomposition does not depend on the quantization step, so the
    # chunk is analyzed only once.
    decomposition = analyze(chunk, wavelet, levels)
    RD_points = []
    for q_step in Q_steps:
        q_steps = [q_step] * (levels + 1)  # Same step size for all subbands
        quantized_decomposition = quantize(decomposition, q_steps)
        k = np.concatenate(quantized_decomposition)
        rate = 8*len(zlib.compress(k.tobytes()))/len(chunk)
        reconstructed_chunk = dequantize_and_synthesize(quantized_decomposition, wavelet, q_steps)
        distortion = RMSE(chunk, reconstructed_chunk)
        RD_points.append((rate, distortion))
    return RD_points

In [None]:
# DEFLATE-based RD curve of the whole recording, with the same wavelet and
# number of levels used for the entropy-based curve.
RD_points_DEFLATE = RD_curve_DEFLATE(x, wavelet, levels)
#print(RD_points_DEFLATE)

In [None]:
%matplotlib inline
# Compare both rate estimations in the same axes: entropy-based points in
# blue ('+') and DEFLATE-based points in red ('x').
plt.title("RD Tradeoff")
plt.xlabel("Bits per Sample")
plt.ylabel("RMSE")
#plt.xscale("log")
#plt.yscale("log")
plt.scatter(*zip(*RD_points), c='b', marker="+", label='entropy')
plt.scatter(*zip(*RD_points_DEFLATE), c='r', marker="x", label='DEFLATE')
#plt.scatter(*zip(*KLT_RD_points), c='r', marker="x", label='KLT')
plt.legend(loc='upper right')
plt.show()

## Impact of each subband on the distortion of the reconstructed chunk

In [None]:
def subbands_DEFLATE_RD_curve(chunk, wavelet, levels):
    '''RD curves per subband.

    For each subband of the `levels`-level DWT of `chunk`, the chunk is
    reconstructed using only that subband (all the other subbands are zero)
    for every quantization step in the global `Q_steps`.

    Returns a list with one RD curve per subband. Each curve is a list of
    (rate, distortion) points stored from the largest to the smallest
    quantization step, where:

    * rate: size in bits of the zlib-compressed quantization indexes of the
      whole decomposition (int32, as in RD_curve_DEFLATE), divided by the
      number of samples of `chunk`;
    * distortion: RMSE between `chunk` and the reconstruction.
    '''
    decomposition = analyze(chunk, wavelet, levels)
    subbands_RD_points = []
    for l, subband in enumerate(decomposition):
        # All-zero decompositions in which only subband l will be filled in:
        # one holds the indexes to compress, the other one the dequantized
        # coefficients to synthesize.
        quantized_decomposition = [np.zeros(len(s), dtype=np.int32)
                                   for s in decomposition]
        dequantized_decomposition = [np.zeros_like(s) for s in decomposition]
        RD_points = []
        for q_step in Q_steps:
            # Always quantize the original coefficients. Re-quantizing the
            # output of the previous (smaller) step would accumulate the
            # error of all the steps tested so far.
            quantized_subband = deadzone_quantize(subband, q_step)
            quantized_decomposition[l] = quantized_subband
            k = np.concatenate(quantized_decomposition)
            rate = 8*len(zlib.compress(k.tobytes()))/len(chunk)
            dequantized_decomposition[l][:] = deadzone_dequantize(quantized_subband, q_step)
            reconstructed_chunk = synthesize(dequantized_decomposition, wavelet)
            distortion = RMSE(chunk, reconstructed_chunk)
            RD_points.append((rate, distortion))
        RD_points.reverse()
        subbands_RD_points.append(RD_points)

    return subbands_RD_points

In [None]:
# One DEFLATE-based RD curve per subband of the recording.
subband_DEFLATE_RD_points = subbands_DEFLATE_RD_curve(x, wavelet, levels)

In [None]:
%matplotlib inline
from matplotlib import cm
plt.title("RD using only one subband")
plt.xlabel("Bits per Sample")
plt.ylabel("RMSE")
plt.xscale("log")
#plt.yscale("log")
# Spread the subband indexes over the whole [0, 1] range of the colormap.
# A fixed divisor (4) pushes the last subbands beyond 1.0 when there are
# more than five curves, so they are all drawn with the same color.
last_index = max(len(subband_DEFLATE_RD_points) - 1, 1)
for i, s in enumerate(subband_DEFLATE_RD_points):
    plt.plot(*zip(*s), marker=".", label=f'Using only subband index {i}',
             color=cm.cool(i/last_index))
#plt.scatter(*zip(*KLT_RD_points), c='r', marker="x", label='KLT')
plt.legend(loc='best')
plt.show()
# Reference distortion: error obtained when every sample is replaced by 0.
print("RMSE (using only zeros) =", RMSE(x, np.zeros(x.shape)))

In [None]:
%matplotlib notebook
# Figure that hosts the animation created below.
fig, ax = plt.subplots()

def animate(frame):
    """Draw the first `frame` points of the RD curve of every subband.

    Called by FuncAnimation with frame = 0, 1, ..., len(Q_steps)-1. New
    lines are drawn on top of the ones of the previous frames.
    """
    #plt.xscale("log")
    #plt.yscale("log")
    plt.xscale("linear")
    plt.yscale("linear")
    plt.xlabel("Bits per Sample")
    plt.ylabel("RMSE")
    # Raw f-string: "\D" is not a valid escape sequence in a normal string.
    # NOTE(review): the curves are stored from the largest to the smallest
    # quantization step, so the value shown here presumably does not match
    # the step of the last plotted point -- confirm the intended label.
    plt.title(rf"$\Delta={frame*Delta}$")
    # Spread the subband indexes over the whole [0, 1] range of the colormap
    # (a fixed divisor of 4 gives the same color to the last subbands when
    # there are more than five curves).
    last_index = max(len(subband_DEFLATE_RD_points) - 1, 1)
    for i, s in enumerate(subband_DEFLATE_RD_points):
        plt.plot(*zip(*s[0:frame]), marker=".", color=cm.cool(i/last_index))
    #return plt,
    
def init():
    """Draw every complete RD curve in white before the animation starts."""
    for curve_points in subband_DEFLATE_RD_points:
        plt.plot(*zip(*curve_points), color='white', marker=".")
    #return plt,

# One frame per quantization step, a new frame every 1000 ms, restarting
# when the sequence ends. The result is kept in `ani` so that the animation
# object stays referenced while it is displayed.
ani = animation.FuncAnimation(
    fig, animate, init_func=init, interval=1000, frames=len(Q_steps), repeat=True)

# To save the animation, use e.g.
#
# ani.save("movie.mp4")
#
# or
#
# writer = animation.FFMpegWriter(
#     fps=15, metadata=dict(artist='Me'), bitrate=1800)
# ani.save("movie.mp4", writer=writer)

plt.show()

## Conclusions
1. Subband 0 is, by far, the most important subband.
2. Subband $l$ is, by far, the least important one.
3. If we use the same quantization step size for all the subbands, the quality of the reconstruction should be good because the *quality layers* are then used in the order given by their RD slopes.