# Quantization of block DWT-transformed audio signals

In [None]:
import sounddevice as sd
import pywt
import math
import numpy as np
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from scipy import signal
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import pylab

## Capture an audio sequence

In [None]:
def plot(y, xlabel='', ylabel='', title='', marker='.'):
    """Draw the sequence `y` against its sample index on a new figure.

    y       -- sequence of values to draw (only len(y) and the values are used).
    xlabel  -- text for the horizontal axis label.
    ylabel  -- text for the vertical axis label.
    title   -- text shown above the axes.
    marker  -- matplotlib format string forwarded to Axes.plot().

    The figure is shown with a non-blocking plt.show(); nothing is returned.
    """
    figure = plt.figure()
    axes = figure.add_subplot(111)

    # Decorations: title, grid and both axis labels.
    axes.set_title(title)
    axes.grid()
    axes.xaxis.set_label_text(xlabel)
    axes.yaxis.set_label_text(ylabel)

    # Horizontal coordinates are simply 0, 1, ..., len(y) - 1.
    sample_count = len(y)
    sample_index = np.linspace(0, sample_count - 1, num=sample_count)
    axes.plot(sample_index, y, marker, markersize=1)

    plt.show(block=False)

In [None]:
fs = 44100  # sampling rate, passed to sd.rec() as `samplerate`
duration = 80000/44100  # seconds
# Start recording int(duration * fs) frames of one int16 channel.
# NOTE(review): this rebinds `signal`, hiding the `scipy.signal` module
# imported at the top of the notebook for every later cell.
signal = sd.rec(int(duration * fs), samplerate=fs, channels=1, dtype=np.int16)
print("Say something!")
# Keep calling sd.wait() for as long as it returns a truthy value.
# Presumably sd.wait() blocks until the recording has finished -- confirm
# against the sounddevice documentation.
while sd.wait():
    pass
print("done")
# Collapse the recorded array to one dimension.
signal = signal.flatten()

In [None]:
# Show the captured samples.
plot(signal, "sample", "amplitude", "original")

## Select a number of levels of the DWT

In [None]:
# Passed as `level` to pywt.wavedec() whenever a chunk is decomposed.
levels = 3

## Filters selection

In [None]:
# Wavelet name handed to pywt.Wavelet(); leave exactly one assignment active.
#wavelet_name = "haar"
wavelet_name = "db5"
#wavelet_name = "db20"
#wavelet_name = "bior2.2"
#wavelet_name = "rbio2.2"
wavelet = pywt.Wavelet(wavelet_name)
# Signal-extension mode passed to pywt.wavedec() and pywt.waverec().
#mode = "zero"
mode = "per"

## Define a dead-zone quantizer

In [None]:
def deadzone_quantizer(x, quantization_step):
    """Map values to dead-zone quantization indexes.

    Every value is divided by `quantization_step` and the quotient is
    truncated toward zero, so any input whose magnitude is smaller than one
    step maps to index 0.

    x                  -- array-like of numbers to quantize.
    quantization_step  -- non-zero step size.

    Returns an integer NumPy array (or NumPy integer scalar) of indexes.
    """
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the type it
    # used to alias, so the resulting dtype is unchanged.
    k = (np.asarray(x) / quantization_step).astype(int)
    return k

def deadzone_dequantizer(k, quantization_step):
    """Turn quantization indexes back into values.

    k                  -- quantization indexes (scalar or array).
    quantization_step  -- step size the indexes are scaled by.

    Returns `quantization_step * k`.
    """
    return quantization_step * k

## Extract 3 consecutive chunks from the audio sequence

In [None]:
chunk_size = 128
chunk_number = 15

# Sample offsets at which the previous, current and next chunk begin.
left_start = chunk_size * (chunk_number - 1)
center_start = chunk_size * chunk_number
right_start = chunk_size * (chunk_number + 1)
right_stop = chunk_size * (chunk_number + 2)

chunk_left = signal[left_start:center_start]
chunk_center = signal[center_start:right_start]
chunk_right = signal[right_start:right_stop]

# The three chunks glued back together, in temporal order.
chunks = np.concatenate((chunk_left, chunk_center, chunk_right))

In [None]:
# Plot the concatenated chunks; '-' is forwarded to plot() as the format string.
plot(chunks, "sample", "amplitude", "3 consecutive chunks", '-')

In [None]:
# Expected length of `chunks`: three chunks of 128 samples (the value of chunk_size).
128*3

In [None]:
# Actual number of samples in the concatenation.
len(chunks)

## Quantize the chunks in the DWT domain
Each chunk is transformed independently.

In [None]:
# Step size read by transform_and_quantize() and dequantize_and_detransform().
quantization_step = 128

def transform_and_quantize(chunk):
    """Decompose `chunk` with the DWT and quantize the coefficients.

    Uses the notebook-level `wavelet`, `levels`, `mode` and
    `quantization_step` settings.

    chunk -- 1-D sequence of samples.

    Returns the quantization indexes of all coefficients, packed into the
    single array produced by pywt.coeffs_to_array().
    """
    subbands = pywt.wavedec(chunk, wavelet=wavelet, level=levels, mode=mode)
    # Only the packed coefficient array is needed here, not the slice layout.
    packed_coefficients = pywt.coeffs_to_array(subbands)[0]
    return deadzone_quantizer(packed_coefficients, quantization_step)
    
def dequantize_and_detransform(quantization_indexes):
    """Rebuild a chunk from the quantization indexes of its DWT coefficients.

    Uses the notebook-level `wavelet`, `levels`, `mode` and
    `quantization_step` settings, which must match those used when the
    indexes were produced by transform_and_quantize().

    quantization_indexes -- packed array returned by transform_and_quantize().

    Returns the reconstructed chunk as produced by pywt.waverec().
    """
    # Decompose a throw-away buffer of the same shape purely to recover the
    # `slices` layout needed to unpack the coefficient array. The buffer is
    # really zero-filled: the original np.empty_like() handed uninitialised
    # memory to the transform.
    placeholder = np.zeros_like(quantization_indexes, dtype=np.float64)
    placeholder_decomposition = pywt.wavedec(
        placeholder, wavelet=wavelet, level=levels, mode=mode)
    _, slices = pywt.coeffs_to_array(placeholder_decomposition)

    quantized_coeffs = deadzone_dequantizer(quantization_indexes, quantization_step)
    decomposition = pywt.array_to_coeffs(quantized_coeffs, slices, output_format="wavedec")
    reconstructed_chunk = pywt.waverec(decomposition, wavelet=wavelet, mode=mode)
    return reconstructed_chunk

def process_chunk(chunk):
    """Run `chunk` through the full quantization round trip.

    The chunk is transformed and quantized, then dequantized and inverse
    transformed.

    chunk -- 1-D sequence of samples.

    Returns the reconstructed chunk.
    """
    return dequantize_and_detransform(transform_and_quantize(chunk))
    
# Every chunk goes through the round trip on its own, left to right.
rchunk_left, rchunk_center, rchunk_right = map(
    process_chunk, (chunk_left, chunk_center, chunk_right))

## Concatenation of the reconstructed chunks

In [None]:
# Glue the independently reconstructed chunks together and draw them as a line.
rchunks = np.concatenate((rchunk_left, rchunk_center, rchunk_right))
plot(
    rchunks,
    xlabel="sample",
    ylabel="amplitude",
    title="concatenation of the reconstructed chunks",
    marker='-',
)

Signal discontinuities happen between chunks :-/

## Reconstruction of the concatenated chunks
This is the ideal reconstruction (ignoring the first and the last samples of the concatenation).

In [None]:
# Run the whole three-chunk concatenation through process_chunk() in one piece,
# then plot the result as a line.
ideal_chunks = process_chunk(chunks)
plot(ideal_chunks, "sample", "amplitude", "reconstructed concatenated chunks", '-')

## A solution: use the neighbor samples between chunks

In [None]:
# Overlap length: the smallest power of two that is >= wavelet.dec_len * levels.
# Computed with integer arithmetic; the floating-point log(x) / log(2) quotient
# can land just above an integer for exact powers of two and make ceil()
# overshoot by one bit.
number_of_overlaped_samples = 1 << (int(wavelet.dec_len * levels) - 1).bit_length()
print("number_of_overlaped_samples =", number_of_overlaped_samples)

### Create an extended chunk that overlaps with the previous and the next one

In [None]:
# Tail of the previous chunk: everything from index
# chunk_size - number_of_overlaped_samples onwards.
last_samples_left_chunk = chunk_left[chunk_size - number_of_overlaped_samples :]
# Head of the next chunk: its first number_of_overlaped_samples samples.
first_samples_right_chunk = chunk_right[: number_of_overlaped_samples]
# Current chunk with the neighbouring samples attached on both sides.
extended_chunk = np.concatenate([last_samples_left_chunk, chunk_center, first_samples_right_chunk])
# Report the length of each piece and of the extended chunk.
print("number of samples overlaped with the previous chunk=", len(last_samples_left_chunk))
print("number of samples in the current chunk =", len(chunk_center))
print("number of samples overlaped with the next chunk =", len(first_samples_right_chunk))
print("length of the extended chunk =", len(extended_chunk))
plot(extended_chunk, "sample", "amplitude", "extended chunk", '-')

### Reconstruction of the extended chunk

In [None]:
# Quantization round trip applied to the extended chunk as a single block.
rextended_chunk = process_chunk(extended_chunk)

In [None]:
# Bare expression: shows the shape of the extended chunk as the cell output.
extended_chunk.shape

In [None]:
# Plot the reconstruction of the extended chunk, overlap included.
plot(rextended_chunk, "sample", "amplitude", "reconstructed extended chunk", '-')

### Extract the chunk from the extended chunk

In [None]:
# Drop number_of_overlaped_samples reconstructed samples at each end,
# keeping only the central part.
rchunk = rextended_chunk[number_of_overlaped_samples : -number_of_overlaped_samples]

In [None]:
# Plot the central part cut out of the reconstructed extended chunk.
plot(rchunk, "sample", "amplitude", "reconstructed chunk with overlaping", '-')

### Reconstruction of the chunk without the overlapped coeffs

In [None]:
# Strip number_of_overlaped_samples samples from each end of the (unprocessed)
# extended chunk, leaving only its central part.
central_coeffs = extended_chunk[number_of_overlaped_samples: -number_of_overlaped_samples]
# Bare expression: shows the shape as the cell output.
central_coeffs.shape

In [None]:
# Quantization round trip applied to the central part alone (no neighbouring samples).
recons_central_coeffs = process_chunk(central_coeffs)

In [None]:
# This figure shows the central part processed WITHOUT its neighbouring samples.
# The title was copy-pasted from the extended-chunk plot and mislabelled it.
plot(recons_central_coeffs, "sample", "amplitude", "reconstructed chunk without the overlapped samples", '-')

### Supposing that the exterior coeffs are zero

In [None]:
# Same length as the extended chunk, but with the neighbouring samples on both
# sides replaced by zeros.
zeroed_extended_chunk = np.concatenate(
    [np.zeros(number_of_overlaped_samples),
    extended_chunk[number_of_overlaped_samples: -number_of_overlaped_samples],
    np.zeros(number_of_overlaped_samples)]
)
# Process the zero-padded block and keep only its central part.
recons_zeroed_extended_chunk = process_chunk(zeroed_extended_chunk)[number_of_overlaped_samples: -number_of_overlaped_samples]
# The title was copy-pasted from the extended-chunk plot; it now names what is
# actually drawn.
plot(recons_zeroed_extended_chunk, "sample", "amplitude", "reconstructed chunk with zeroed overlapped samples", '-')

### The transformed and quantized (non extended) chunk (without overlapping)

In [None]:
# For comparison: the reconstruction of chunk_center computed earlier by
# processing that chunk on its own.
plot(rchunk_center, "sample", "amplitude", "reconstructed chunk without overlaping", '-')