# Representation of signals & inverse problems - G1-G2
---
## Lab 2: Wavelet transform and image processing

---
## Guidelines (read carefully before starting)

**Objective**: this lab explores some aspects of orthogonal wavelet transforms in 2D for image processing, with applications to denoising and filtering. All the numerical illustrations covered in this session rely on the Python module PyWavelets (`pywt`).

Each wavelet is determined by its low-pass filter, associated with the scaling function (in French, _fonction d'échelle_). The longer this filter (*i.e.*, the more coefficients it contains), the larger the number of vanishing (null) moments the wavelet can have. For instance, a Daubechies wavelet with $k$ vanishing moments has a filter of $p=2k$ coefficients.

A larger number of vanishing moments allows regular (smooth) signals/images to be compressed more efficiently. However, the extended support of such a wavelet can be problematic for representing highly irregular regions of the signal, such as discontinuities.

The choice of a wavelet is a matter of compromise, depending on the type of structures expected to be observed in the image of interest.
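
As a small numerical check of the link between filter length and vanishing (null) moments, the sketch below builds the 4-coefficient Daubechies low-pass filter ($k=2$, called `db2` in PyWavelets) from its closed-form expression, derives the high-pass filter with the quadrature-mirror relation $g[n] = (-1)^n h[p-1-n]$, and evaluates the discrete moments $\sum_n n^q g[n]$. Only NumPy is used; the ordering and sign conventions may differ from those of `pywt.Wavelet("db2")`, which does not affect which moments vanish.

```python
import numpy as np

# Closed-form Daubechies low-pass filter with p = 4 coefficients (k = 2 vanishing moments)
s3 = np.sqrt(3.0)
h = np.array([1 + s3, 3 + s3, 3 - s3, 1 - s3]) / (4 * np.sqrt(2.0))

# Quadrature-mirror high-pass (wavelet) filter: g[n] = (-1)^n h[p - 1 - n]
p = len(h)
idx = np.arange(p)
g = (-1.0) ** idx * h[::-1]

# Discrete moments of the wavelet filter for q = 0, 1, 2
moments = [float(np.sum(idx**q * g)) for q in range(3)]

for q, m in enumerate(moments):
    print(f"moment {q}: {m:+.3e}")
```

The first two moments vanish up to rounding errors while the third one does not: the filter has exactly $k=2$ vanishing moments, consistent with $p = 2k = 4$.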

**Guidelines**: after retrieving the resources for the lab on Moodle:
- place the .zip archive in a local folder (Computer -> Documents/Python/);
- unzip the archive .zip;
- rename the folder with the convention lab2_Name1_Name2;
- duplicate the notebook file and rename it lab2_Name1_Name2.ipynb;
- at the end of the session, do not forget to transfer your work to your own network space if you are working on a machine from the school (do not leave your work on the C: drive).

**Assessment** &#8594; global grade from F to A (A+)

Assessment based on your answers to the exercises reported in the notebook and on any additional `.py` file produced. Custom code should be commented whenever appropriate. Figures should be clearly annotated (axes, title).

1. Numerical correctness
2. Implementation clarity (documentation, relevance of the comments)
3. Answers to the questions and overall presentation of the Jupyter notebook.

## Configuration

In [None]:
# make sure the notebook reloads the module each time we modify it
%load_ext autoreload
%autoreload 2

# Uncomment the next line if you want to be able to zoom on plots (one of the options below)
# %matplotlib widget
# %matplotlib notebook
%matplotlib inline

In [None]:
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pywt
import scipy.signal as sg
from IPython.display import Audio
from lib.plotwavelet import plot_wavelet, psnr
from matplotlib.colors import LogNorm  # for Log normalization

SMALL_SIZE = 16
MEDIUM_SIZE = 20
BIGGER_SIZE = 24

plt.rc("font", size=SMALL_SIZE)  # controls default text sizes
plt.rc("axes", titlesize=SMALL_SIZE)  # fontsize of the axes title
plt.rc("axes", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels
plt.rc("xtick", labelsize=SMALL_SIZE)  # fontsize of the tick labels
plt.rc("ytick", labelsize=SMALL_SIZE)  # fontsize of the tick labels
plt.rc("legend", fontsize=SMALL_SIZE)  # legend fontsize
plt.rc("figure", titlesize=BIGGER_SIZE)  # fontsize of the figure title

## Contents <a id="content"></a>

1. [1D discrete wavelet transform](#section1)
   - [Exercise 1](#ex1)
2. [2D discrete wavelet transform](#section2)
   - [Exercise 2](#ex2) 
3. [Wavelet filtering](#section3)
   - [Exercise 3](#ex3) 
4. [Wavelet denoising](#section4)
   - [Exercise 4](#ex4) 
   - [Bonus: Exercise 5](#ex5)

---
## 1D discrete wavelet transform <a id="section1"></a> [(&#8593;)](#content)

[PyWavelets](https://pywavelets.readthedocs.io/en/latest/ref/index.html) (`pywt`) is an open source wavelet transform package for Python. It combines a simple high level interface with low level C and Cython performance. 

The Discrete Wavelet Transform (DWT) of a signal can be computed with the `pywt.dwt` and `pywt.wavedec` functions (see the examples below). `pywt.dwt` returns a tuple `(cA, cD)` for a single level, while `pywt.wavedec` returns a list of arrays: the approximation at the coarsest level, followed by the details of each octave separately.

An example of application to a chirp signal is reported below, illustrating the multiresolution analysis of a wideband signal.

In [None]:
N = 2**13     # 8192
Fs = 1024
t = np.arange(0, N, 1) / Fs
x = sg.chirp(t, f0=60, f1=2, t1=10, method="hyperbolic")

Audio(x, rate=44 * Fs)  # artificially change Fs for listening purposes

In [None]:
plt.figure(figsize=(12, 6))
plt.plot(t, x)
plt.xlabel("Time [s]")
plt.ylabel(r"Amplitude")
plt.grid()
plt.show()

Different wavelets are available. Use the `pywt.wavelist()` command to display the list of the available built-in wavelets.

In [None]:
#pywt.wavelist() # uncomment to run the command

print(pywt.wavelist(kind='discrete'))

#### Computing your first orthogonal wavelet transform
The function `pywt.dwt` computes the wavelet transform for a single decomposition level, returning the corresponding approximation coefficients and details.

In [None]:
cA, cD = pywt.dwt(x, "sym8")

plt.figure(figsize=(15, 10))

plt.subplot(211)
plt.plot(cA)
plt.title("Approximation")
plt.grid()

plt.subplot(212)
plt.plot(cD)
plt.title("Details")
plt.grid()
plt.show()
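
To make the single-level transform concrete, here is a minimal NumPy sketch of one level of the orthonormal Haar DWT and its inverse, i.e. the roles played by `pywt.dwt` and `pywt.idwt` with the `"haar"` wavelet. The helper names `haar_dwt` and `haar_idwt` are hypothetical, the sign convention of the details may differ from PyWavelets, and an even-length input is assumed.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_dwt(x):
    """One level of the orthonormal Haar DWT of an even-length 1D signal."""
    x = np.asarray(x, dtype=float)
    if x.size % 2:
        raise ValueError("haar_dwt expects an even-length signal")
    even, odd = x[0::2], x[1::2]
    cA = (even + odd) / SQRT2  # local averages (low-pass filtering + decimation)
    cD = (even - odd) / SQRT2  # local differences (high-pass filtering + decimation)
    return cA, cD


def haar_idwt(cA, cD):
    """Inverse of haar_dwt: rebuild and interleave the even and odd samples."""
    x = np.empty(2 * len(cA))
    x[0::2] = (cA + cD) / SQRT2
    x[1::2] = (cA - cD) / SQRT2
    return x


x_demo = np.array([4.0, 6.0, 10.0, 12.0, 8.0, 6.0, 5.0, 5.0])
cA_demo, cD_demo = haar_dwt(x_demo)
x_rec = haar_idwt(cA_demo, cD_demo)

print(cA_demo * SQRT2)  # pairwise sums
print(cD_demo * SQRT2)  # pairwise differences
print(np.allclose(x_rec, x_demo))  # perfect reconstruction
```

Each output has half the length of the input, and since the transform is orthonormal the energy of the signal is split between `cA_demo` and `cD_demo` without loss.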

The `pywt.wavedec` function computes a multi-level wavelet transform, where `level` denotes the number of decomposition levels. It returns an ordered list of coefficient arrays, stored here in `coeff`. The first element `coeff[0]` contains the approximation coefficients at the coarsest level. The following elements `coeff[1]`, `coeff[2]`, ... contain the detail coefficients, ordered from the coarsest scale (`coeff[1]`, level `level`) to the finest one (`coeff[-1]`, level 1).

In [None]:
level = 7
coeff = pywt.wavedec(x, "sym8", level=level)

In [None]:
plt.figure(figsize=(8, 15))
plt.subplot(311)
plt.plot(coeff[0])
plt.ylabel("Approximation (j={})".format(level))
plt.grid()
plt.subplot(312)
plt.plot(coeff[1])
plt.ylabel("Details (j={})".format(level))
plt.grid()
plt.subplot(313)
plt.plot(coeff[2])
plt.ylabel("Details (j={})".format(level - 1))
plt.grid()
plt.show()
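
The ordering of the list returned by `pywt.wavedec` can be reproduced with a hand-written multilevel Haar transform: the approximation is split again at each level, and the new details are placed in front of the finer ones, so that the list reads `[cA_J, cD_J, ..., cD_1]`. This NumPy-only sketch uses hypothetical helper names (`haar_wavedec`, `haar_waverec`) and assumes a signal length divisible by $2^J$, whereas PyWavelets handles arbitrary lengths through its signal extension modes.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_wavedec(x, level):
    """Multilevel Haar DWT returning [cA_J, cD_J, cD_{J-1}, ..., cD_1]."""
    cA = np.asarray(x, dtype=float)
    if cA.size % 2**level:
        raise ValueError("signal length must be divisible by 2**level")
    details = []
    for _ in range(level):
        cA, cD = (cA[0::2] + cA[1::2]) / SQRT2, (cA[0::2] - cA[1::2]) / SQRT2
        details.insert(0, cD)  # coarser details go in front of finer ones
    return [cA] + details


def haar_waverec(coeffs):
    """Inverse of haar_wavedec: undo the levels from the coarsest to the finest."""
    cA = coeffs[0]
    for cD in coeffs[1:]:
        x = np.empty(2 * len(cA))
        x[0::2] = (cA + cD) / SQRT2
        x[1::2] = (cA - cD) / SQRT2
        cA = x
    return cA


rng_demo = np.random.default_rng(0)
sig_demo = rng_demo.standard_normal(64)
coeffs_demo = haar_wavedec(sig_demo, level=3)

print([len(c) for c in coeffs_demo])  # [8, 8, 16, 32]: approximation, then coarse to fine
print(np.allclose(haar_waverec(coeffs_demo), sig_demo))
```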

### Exercise 1 <a id="ex1"></a> [(&#8593;)](#content)
1. Read the documentation of the functions `pywt.dwt`, `pywt.idwt` and `pywt.wavedec`. What are these functions used for? What is the difference between these functions?
2. The smallest scales can be filtered out by setting the associated wavelet coefficients to zero. Illustrate the effect of setting to 0 coefficients of the details (or the approximation) for a given set of octaves.
_Indication: to do this, you need to compute the wavelet transform, set some coefficients to zero and then reconstruct the signal by using the inverse wavelet transform._
3. Try this filtering operation with several wavelets of your choice and compare the results. Observe and comment.
4. Use at least one other 1D signal using `pywt.data`, and repeat the operations described in 2. and 3.
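
As an illustration of the indication in question 2, the sketch below applies one level of the orthonormal Haar transform (NumPy only, hypothetical helper names) to a toy signal made of a slowly varying part, constant on each pair of samples, plus a fast alternating component $0.5\,(-1)^n$. With the Haar wavelet, this alternating component lives entirely in the finest details `cD1`, so setting them to zero and inverting the transform removes it exactly. With other wavelets or real signals, the separation is only approximate.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_dwt(x):
    """One level of the orthonormal Haar DWT (even-length input)."""
    x = np.asarray(x, dtype=float)
    return (x[0::2] + x[1::2]) / SQRT2, (x[0::2] - x[1::2]) / SQRT2


def haar_idwt(cA, cD):
    """Inverse of haar_dwt."""
    x = np.empty(2 * len(cA))
    x[0::2] = (cA + cD) / SQRT2
    x[1::2] = (cA - cD) / SQRT2
    return x


# Slow part: constant on each pair of samples; fast part: alternating +/-0.5
slow = np.repeat(np.sin(np.linspace(0, 2 * np.pi, 32)), 2)
fast = 0.5 * (-1.0) ** np.arange(64)
x_toy = slow + fast

cA1, cD1 = haar_dwt(x_toy)                 # 1) analysis
cD1_filtered = np.zeros_like(cD1)          # 2) set the finest details to zero
x_filtered = haar_idwt(cA1, cD1_filtered)  # 3) synthesis

print(np.allclose(x_filtered, slow))  # the alternating component is gone
```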

Exercise 1.1

`pywt.dwt`: performs a single-level discrete wavelet transform. Its inputs are the signal (array-like data) to be transformed and the wavelet, and it returns the approximation and detail coefficients as two separate arrays `(cA, cD)`.

`pywt.idwt`: performs a single-level inverse discrete wavelet transform. Its inputs are the approximation and detail coefficients (and the wavelet), and it returns the signal reconstructed from one decomposition level.

`pywt.wavedec`: performs a multilevel discrete wavelet decomposition. It is analogous to `pywt.dwt`, but takes the number of decomposition levels as a parameter and returns the list `[cAn, cDn, cDn-1, ..., cD1]`. This is convenient when more than one decomposition level is needed at once.

In [None]:
# exercise 1.2

coeffs = pywt.wavedec(x, "sym8", level=5)

#[cAn, cDn, cDn-1, ..., cD2, cD1]
coeffs[1][:] = 0  # set the details at level 5 (cD5) to zero

x_new = pywt.waverec(coeffs, "sym8")

plt.figure(figsize=(15, 10))

plt.subplot(211)
plt.plot(x)
plt.title("original signal")
plt.grid()

plt.subplot(212)
plt.plot(x_new)
plt.title("reconstruction without cD5")
plt.grid()
plt.show()


In [None]:
# exercise 1.3

# original signal in its own figure (the previous version left an empty subplot)
plt.figure(figsize=(15, 5))
plt.plot(x)
plt.title("original signal")
plt.xlabel("Sample index")
plt.grid()
plt.show()

wavelets = ["sym8", "bior2.2", "coif10", "haar", "db20"]

for wavelet in wavelets:
    coeffs = pywt.wavedec(x, wavelet, level=5)

    # [cAn, cDn, cDn-1, ..., cD2, cD1]
    coeffs[1][:] = 0  # set the details at level 5 (cD5) to zero

    x_new = pywt.waverec(coeffs, wavelet)

    plt.figure(figsize=(15, 5))
    plt.plot(x_new)
    plt.title("reconstruction without cD5, wavelet " + wavelet)
    plt.xlabel("Sample index")
    plt.grid()
    plt.show()


Exercise 1.3

It is remarkable that `haar` gives by far the worst reconstruction. This can be explained by the fact that the `haar` wavelet bears very little resemblance to the original signal: it is discontinuous and has a single vanishing moment, so the suppression of the level-5 detail coefficients is noticeable at all frequencies. Its poor frequency localization (spectral leakage between octaves) also contributes to this result. On the other hand, `db20`, except for the frequencies that were suppressed, gives a very good reconstruction. This can be explained by the fact that `db20` is smooth and oscillating, with a chirp-like, frequency-modulation-like shape, so it makes sense that its approximation coefficients are better suited to represent the original signal.

In [None]:
# exercise 1.4

# Load a 1D test signal
signal_name = "Blocks"
# signal_name = "Bumps"
# signal_name = "Riemann"
signal = pywt.data.demo_signal(signal_name, n=N)

#print(pywt.data.demo_signal("list", n=10))

wavelet = "db20"  # wavelet used for this first test (previously inherited from the loop above)
coeffs = pywt.wavedec(signal, wavelet, level=5)

# [cAn, cDn, cDn-1, ..., cD2, cD1]
coeffs[1][:] = 0  # set the details at level 5 (cD5) to zero

x_new = pywt.waverec(coeffs, wavelet)

plt.figure(figsize=(15, 9))

plt.subplot(211)
plt.plot(signal)
plt.title("original signal")
plt.grid()

plt.subplot(212)
plt.plot(x_new)
plt.title("reconstruction without cD5, wavelet " + wavelet)
plt.grid()
plt.show()

In [None]:
# exercise 1.4

plt.figure(figsize=(15, 4))
plt.plot(signal)
plt.title("original signal")
plt.grid()
plt.show()

wavelets = ["sym8", "bior2.2", "coif10", "haar", "db20"]

for wavelet in wavelets:    

    coeffs = pywt.wavedec(signal, wavelet, level=5)

    #[cAn, cDn, cDn-1, ..., cD2, cD1]
    coeffs[1][:] = 0  # set the details at level 5 (cD5) to zero

    x_new = pywt.waverec(coeffs, wavelet)

    plt.figure(figsize=(15, 5))
    plt.plot(x_new)
    plt.title("reconstruction without cD5 and " + wavelet )
    plt.grid()
    plt.show()

Exercise 1.4

The situation is now exactly the opposite of the previous question. The `Blocks` signal is piecewise constant and clearly resembles `haar` much more than `db20`, so it is better represented by it: with `haar`, the approximation coefficients alone already capture the flat segments and the sharp jumps. The reconstructions obtained with the smoother wavelets, on the other hand, show oscillating artifacts around the discontinuities, reminiscent of the Gibbs phenomenon.

---
## 2D discrete wavelet transform <a id="section2"></a> [(&#8593;)](#content)
We now turn to the computation of a 2D wavelet transform to analyze images (example images are available in the `img/` folder). Let us start with a simple chessboard image.

_Indication: you can consider other images, in particular using `pywt.data`._

In [None]:
filename = "img/chessboard.png"

I = mpimg.imread(filename)

plt.figure()
plt.imshow(I, cmap="gray")
plt.colorbar()
plt.show()

#### Wavelet transform
We first compute a 2D wavelet transform using the function `pywt.wavedec2` and display the results.

In [None]:
# Wavelet decomposition
Jmax = int(np.log2(I.shape[0]))  # maximal number of octaves
J = 2

wtcoeffs = pywt.wavedec2(I, wavelet="sym8", level=J, mode="periodization")

Since the coefficients are stored in a list (the approximation array followed by one tuple of detail arrays per level), we need to convert them into a single array to obtain a graphical representation with the `plot_wavelet` function.

In [None]:
# Casting coeff in a single array
wtcoeffs_arr, coeff_slices = pywt.coeffs_to_array(wtcoeffs)

# Graphical representation
plt.figure(figsize=(5, 5), dpi=200)
plot_wavelet(wtcoeffs_arr, Jmin=Jmax - J)

## Exercise 2 <a id="ex2"></a> [(&#8593;)](#content)
1. Observe the figure above and briefly recall its structure and what it represents.

2. Play with the parameter $J$ (corresponding to the optional parameter `level` from `pywt.wavedec2`). What does it correspond to?

3. Observe the wavelet transform of other images.

4. For the checkerboard image, try the Haar wavelet and compare the result with any other wavelet transform. What do you notice? Do you have an explanation?

Exercise 2.1

At the top left of the image, the small square represents the second-level approximation of the image. This square is surrounded by 3 other squares of the same size, which represent the horizontal (right square), diagonal (bottom-right square) and vertical (bottom square) detail coefficients of the second decomposition level. The same layout is repeated with the three larger squares, which contain respectively the horizontal, vertical and diagonal detail coefficients of the first (finest) decomposition level.

In [None]:
# exercise 2.2

filename = "img/chessboard.png"

I = mpimg.imread(filename)

# Wavelet decomposition
Jmax = int(np.log2(I.shape[0]))  # maximal number of octaves

Js = [1, 2, 3, 4, 5]

for J in Js:    
    
    wtcoeffs = pywt.wavedec2(I, wavelet="sym8", level=J, mode="periodization")
    
    # Casting coeff in a single array
    wtcoeffs_arr, coeff_slices = pywt.coeffs_to_array(wtcoeffs)

    # Graphical representation
    plt.figure(figsize=(4, 4), dpi=120)
    plot_wavelet(wtcoeffs_arr, Jmin=Jmax - J)
    plt.title("Level of decomposition : " + str(J))
    


Exercise 2.2  

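The parameter J defines the number of decomposition levels (octaves) of the wavelet transform. Each additional level splits the current approximation into a coarser approximation and three new detail subbands, so the approximation square in the top-left corner shrinks by a factor of 2 in each direction.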
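
The number of levels cannot be increased indefinitely. The sketch below reproduces, to the best of our understanding, the rule behind `pywt.dwt_max_level(data_len, filter_len)`: the maximum useful level is $\lfloor \log_2(N / (p-1)) \rfloor$ for a signal of length $N$ and a filter of length $p$. For Haar ($p=2$) this bound equals $\log_2 N$, the `Jmax` used in this notebook. It also gives the number of approximation coefficients per axis at level $J$ with `mode="periodization"`, assumed to be $\lceil N / 2^J \rceil$. Both formulas are assumptions here; compare them with `pywt.dwt_max_level` and with the actual coefficient shapes before relying on them.

```python
import math


def max_level(data_len, filter_len):
    """Largest level for which the filter still fits the signal (assumed rule)."""
    if filter_len < 2:
        return 0
    return int(math.floor(math.log2(data_len / (filter_len - 1))))


def approx_size_periodization(data_len, level):
    """Approximation coefficients per axis at `level` with mode='periodization' (assumed)."""
    return math.ceil(data_len / 2**level)


# Haar has 2 taps, sym8 has 16 taps
for name, taps in [("haar", 2), ("sym8", 16)]:
    print(name, "max level for N=512:", max_level(512, taps))

print("approximation size at J=3 for N=512:", approx_size_periodization(512, 3))
```

With `sym8`, decompositions deeper than the printed bound are still computed by PyWavelets, but the filters then become longer than the coarse approximations they act on.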

In [None]:
# exercise 2.3

images = ["img/hair.png", "img/mandrill.png", "img/periodic_bumps.png"]

for image in images:    
    
    I = mpimg.imread(image)
    
    # Wavelet decomposition
    Jmax = int(np.log2(I.shape[0]))  # maximal number of octaves
    J = 1
    
    wtcoeffs = pywt.wavedec2(I, wavelet="haar", level=J, mode="periodization")
    
    # Casting coeff in a single array
    wtcoeffs_arr, coeff_slices = pywt.coeffs_to_array(wtcoeffs)

    # Graphical representation
    plt.figure(figsize=(4, 4), dpi=120)
    plot_wavelet(wtcoeffs_arr, Jmin=Jmax - J)
    plt.title(image)  # annotate each figure with the image name
    plt.show()

In [None]:
# exercise 2.4

filename = "img/chessboard.png"

I = mpimg.imread(filename)

plt.figure()
plt.imshow(I, cmap="gray")
plt.colorbar()
plt.show()

# Wavelet decomposition
Jmax = int(np.log2(I.shape[0]))  # maximal number of octaves
J = 2

#wtcoeffs = pywt.wavedec2(I, wavelet="sym8", level=J, mode="periodization")
wtcoeffs = pywt.wavedec2(I, wavelet="haar", level=J, mode="periodization")

# Casting coeff in a single array
wtcoeffs_arr, coeff_slices = pywt.coeffs_to_array(wtcoeffs)

# Graphical representation
plt.figure(figsize=(5, 5), dpi=200)
plot_wavelet(wtcoeffs_arr, Jmin=Jmax - J)

Exercise 2.4

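The `haar` wavelet is piecewise constant, just like the `chessboard`, which is made of flat squares separated by sharp black-to-white steps. As long as the edges of the squares fall on the dyadic grid of the decomposition (i.e., the side of the squares is a multiple of $2^J$ and the squares are aligned with this grid), the support of each Haar detail function lies entirely within a constant region, where its zero mean gives a zero coefficient. All the information of the original image is then captured by the approximation coefficients, while the horizontal, vertical and diagonal detail coefficients are zero.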
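
The role of the dyadic alignment can be checked numerically. The sketch below (NumPy only; `haar_dwt2` and `chessboard` are hypothetical helpers) computes one level of the separable Haar transform of a synthetic chessboard with 4×4-pixel squares, and of the same board shifted by one pixel. The detail subbands are named following the PyWavelets `(cH, cV, cD)` convention, which we assume here. For the aligned board every 2×2 Haar support lies inside a single square, so all details vanish; after the shift, the vertical edges cross the supports and the vertical details become non-zero.

```python
import numpy as np


def haar_dwt2(img):
    """One level of the separable orthonormal Haar DWT: returns cA, (cH, cV, cD)."""
    a = np.asarray(img, dtype=float)
    s = np.sqrt(2.0)
    lo = (a[:, 0::2] + a[:, 1::2]) / s  # low-pass along each row (axis 1)
    hi = (a[:, 0::2] - a[:, 1::2]) / s  # high-pass along each row
    cA = (lo[0::2] + lo[1::2]) / s      # low-pass along the columns (axis 0)
    cH = (lo[0::2] - lo[1::2]) / s      # responds to horizontal edges
    cV = (hi[0::2] + hi[1::2]) / s      # responds to vertical edges
    cD = (hi[0::2] - hi[1::2]) / s      # responds to diagonal structures
    return cA, (cH, cV, cD)


def chessboard(n_squares, square):
    """Binary chessboard of n_squares x n_squares squares of square x square pixels."""
    tiles = np.indices((n_squares, n_squares)).sum(axis=0) % 2
    return np.kron(tiles, np.ones((square, square)))


board = chessboard(8, 4)             # 32 x 32 pixels, edges on even indices
shifted = np.roll(board, 1, axis=1)  # vertical edges moved to odd column indices

for name, img in [("aligned", board), ("shifted", shifted)]:
    _, details = haar_dwt2(img)
    print(name, "max |cH|, |cV|, |cD|:", [float(np.abs(d).max()) for d in details])
```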

**Remark:** a 2D wavelet transform is a succession of two 1D DWTs (one along the rows and one along the columns of the image). Using only 1 step of the multiresolution hierarchy, one obtains the following result.

In [None]:
coeffs2 = pywt.dwt2(I, "sym8")

LL, (LH, HL, HH) = coeffs2

fig = plt.figure(figsize=(12, 3))
titles = [
    "Approximation",
    " Horizontal details",
    "Vertical details",
    "Diagonal details",
]
for i, a in enumerate([LL, LH, HL, HH]):
    ax = fig.add_subplot(1, 4, i + 1)
    ax.imshow(a, interpolation="nearest", cmap=plt.cm.gray)
    ax.set_title(titles[i], fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
plt.show()
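
The remark above can be verified with a small NumPy sketch: one level of the separable transform is obtained by applying a 1D Haar step along one axis and then along the other and, since the two steps act on different axes, the order does not matter. The same sketch shows which subband responds to which edge orientation: an image with a single horizontal edge only produces horizontal details. The helper `haar_step` is hypothetical, and the naming assumes the PyWavelets convention where cH results from low-pass filtering along axis 1 and high-pass filtering along axis 0.

```python
import numpy as np


def haar_step(a, axis):
    """One orthonormal Haar step along `axis`, returning decimated (low-pass, high-pass)."""
    a = np.moveaxis(np.asarray(a, dtype=float), axis, 0)
    lo = (a[0::2] + a[1::2]) / np.sqrt(2.0)
    hi = (a[0::2] - a[1::2]) / np.sqrt(2.0)
    return np.moveaxis(lo, 0, axis), np.moveaxis(hi, 0, axis)


rng_demo = np.random.default_rng(1)
img_demo = rng_demo.random((16, 16))

# Order 1: along the rows (axis 1) first, then along the columns (axis 0)
lo1, hi1 = haar_step(img_demo, axis=1)
cA_rc, cH_rc = haar_step(lo1, axis=0)
cV_rc, cD_rc = haar_step(hi1, axis=0)

# Order 2: along the columns (axis 0) first, then along the rows (axis 1)
lo0, hi0 = haar_step(img_demo, axis=0)
cA_cr, cV_cr = haar_step(lo0, axis=1)
cH_cr, cD_cr = haar_step(hi0, axis=1)

same = [np.allclose(u, v) for u, v in
        [(cA_rc, cA_cr), (cH_rc, cH_cr), (cV_rc, cV_cr), (cD_rc, cD_cr)]]
print("both orders give the same subbands:", all(same))

# Orientation: a single horizontal edge between rows 4 and 5
edge = np.zeros((16, 16))
edge[5:, :] = 1.0
lo_e, hi_e = haar_step(edge, axis=1)
_, cH_e = haar_step(lo_e, axis=0)
cV_e, cD_e = haar_step(hi_e, axis=0)
print("max |cH|, |cV|, |cD|:", np.abs(cH_e).max(), np.abs(cV_e).max(), np.abs(cD_e).max())
```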

---
## Wavelet filtering <a id="section3"></a> [(&#8593;)](#content)
This section illustrates a filtering procedure conducted in the wavelet domain. The purpose is to show that an image can be filtered at a given scale, along specific orientations and, on top of that, at chosen spatial locations.

In [None]:
filename = "img/chessboard.png"
I = mpimg.imread(filename)


plt.figure()
plt.imshow(I, cmap="gray")
plt.colorbar()
plt.show()

#### First steps for filtering
We first use a single-level decomposition and set some coefficients to zero.

In [None]:
coeffs2 = pywt.dwt2(I, "db4")

In [None]:
LL, (LH, HL, HH) = coeffs2

In [None]:
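# Zero a square block of approximation coefficients (spatially localized filtering).
# LL refers to the same array as coeffs2[0], so coeffs2 is modified in place.
L = 8
LL[2 ** (L - 1) : 2**L + 1, 2 ** (L - 1) : 2**L + 1] = 0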

In [None]:
LL, (LH, HL, HH) = coeffs2
fig = plt.figure(figsize=(12, 3))
titles = ["Approximation", " Horizontal detail", "Vertical detail", "Diagonal detail"]
for i, a in enumerate([LL, LH, HL, HH]):
    ax = fig.add_subplot(1, 4, i + 1)
    ax.imshow(a, interpolation="nearest", cmap=plt.cm.gray)
    ax.set_title(titles[i], fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
plt.show()

In [None]:
I_d = pywt.idwt2(coeffs2, "db4")
fig = plt.figure(figsize=(15, 7))
plt.subplot(121)
plt.imshow(I, cmap="gray")
plt.title("Before")
# plt.colorbar()
plt.subplot(122)
plt.imshow(I_d, cmap="gray")
plt.title("After")
# plt.colorbar()
plt.show()
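
Because each wavelet coefficient is attached to a position in the image, zeroing a block of coefficients only modifies the corresponding region of the reconstruction, with a margin that grows with the filter length. With the Haar wavelet the margin is zero: the first-level coefficient `(i, j)` covers exactly the pixels `2i:2i+2, 2j:2j+2`. The NumPy sketch below (hypothetical helpers `haar_dwt2` and `haar_idwt2`, not PyWavelets functions) checks this on a random image.

```python
import numpy as np

S = np.sqrt(2.0)


def haar_dwt2(img):
    """One level of the separable orthonormal Haar DWT: cA, (cH, cV, cD)."""
    a = np.asarray(img, dtype=float)
    lo, hi = (a[:, 0::2] + a[:, 1::2]) / S, (a[:, 0::2] - a[:, 1::2]) / S
    cA, cH = (lo[0::2] + lo[1::2]) / S, (lo[0::2] - lo[1::2]) / S
    cV, cD = (hi[0::2] + hi[1::2]) / S, (hi[0::2] - hi[1::2]) / S
    return cA, (cH, cV, cD)


def haar_idwt2(cA, details):
    """Inverse of haar_dwt2: undo the column step, then the row step."""
    cH, cV, cD = details
    n, m = cA.shape
    lo, hi = np.empty((2 * n, m)), np.empty((2 * n, m))
    lo[0::2], lo[1::2] = (cA + cH) / S, (cA - cH) / S
    hi[0::2], hi[1::2] = (cV + cD) / S, (cV - cD) / S
    img = np.empty((2 * n, 2 * m))
    img[:, 0::2], img[:, 1::2] = (lo + hi) / S, (lo - hi) / S
    return img


rng_demo = np.random.default_rng(2)
img_demo = rng_demo.random((32, 32))

cA_demo, det_demo = haar_dwt2(img_demo)
cA_filtered = cA_demo.copy()
cA_filtered[2:4, 5:8] = 0.0  # zero a small block of approximation coefficients
img_filtered = haar_idwt2(cA_filtered, det_demo)

changed = ~np.isclose(img_filtered, img_demo)
rows, cols = np.nonzero(changed)
print("changed rows:", rows.min(), "to", rows.max(), "| changed columns:", cols.min(), "to", cols.max())
```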

## Exercise 3 <a id="ex3"></a> [(&#8593;)](#content)
1. Take a look at the documentation of the `pywt.wavedec2` and `pywt.waverec2`. How are the wavelet coefficients organized (scale? orientation?)?
2. Play with wavelet filtering procedure described above by setting some sets of coefficients to zero (approximation or details of various orientations). Illustrate its result with the `chessboard.png` image.
3. Do the same with any gray level image of your choice (see `img` folder). 
4. Observe and comment.

Exercise 3.1

`pywt.wavedec2`: performs a multilevel 2D discrete wavelet transform, with the number of decomposition levels as a parameter. It is the 2D equivalent of `pywt.wavedec`. It returns a list of coefficients ordered from the coarsest to the finest level: `[cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)]`, where cAn contains the approximation coefficients at level n, and cHk, cVk and cDk contain respectively the horizontal, vertical and diagonal detail coefficients at level k.

`pywt.waverec2`: performs a multilevel 2D inverse discrete wavelet transform. It is the 2D equivalent of `pywt.waverec`. Given the approximation and detail coefficients (in the format returned by `pywt.wavedec2`), it returns the corresponding reconstructed 2D array.
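
The nested structure `[cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)]` can be mimicked with a few lines of NumPy: at each level, the current approximation is split into a new approximation and a tuple of three detail arrays, which is inserted in front of the finer ones. This is a Haar-only sketch with hypothetical function names (`haar_wavedec2`, `haar_waverec2`), assuming an image whose sides are divisible by $2^J$; it only illustrates the list layout expected by `pywt.waverec2`.

```python
import numpy as np

S = np.sqrt(2.0)


def haar_dwt2(a):
    """One level of the separable orthonormal Haar DWT: cA, (cH, cV, cD)."""
    lo, hi = (a[:, 0::2] + a[:, 1::2]) / S, (a[:, 0::2] - a[:, 1::2]) / S
    cA = (lo[0::2] + lo[1::2]) / S
    return cA, ((lo[0::2] - lo[1::2]) / S, (hi[0::2] + hi[1::2]) / S, (hi[0::2] - hi[1::2]) / S)


def haar_idwt2(cA, details):
    """Inverse of haar_dwt2."""
    cH, cV, cD = details
    lo = np.empty((2 * cA.shape[0], cA.shape[1]))
    hi = np.empty_like(lo)
    lo[0::2], lo[1::2] = (cA + cH) / S, (cA - cH) / S
    hi[0::2], hi[1::2] = (cV + cD) / S, (cV - cD) / S
    img = np.empty((lo.shape[0], 2 * lo.shape[1]))
    img[:, 0::2], img[:, 1::2] = (lo + hi) / S, (lo - hi) / S
    return img


def haar_wavedec2(img, level):
    """Return [cA_J, (cH_J, cV_J, cD_J), ..., (cH_1, cV_1, cD_1)]."""
    cA = np.asarray(img, dtype=float)
    details = []
    for _ in range(level):
        cA, d = haar_dwt2(cA)
        details.insert(0, d)  # the coarsest level ends up first
    return [cA] + details


def haar_waverec2(coeffs):
    """Inverse of haar_wavedec2, from the coarsest to the finest level."""
    cA = coeffs[0]
    for d in coeffs[1:]:
        cA = haar_idwt2(cA, d)
    return cA


rng_demo = np.random.default_rng(3)
img_demo = rng_demo.random((32, 32))
coeffs_demo = haar_wavedec2(img_demo, level=3)

print(coeffs_demo[0].shape)                   # (4, 4): approximation at level 3
print([d[0].shape for d in coeffs_demo[1:]])  # [(4, 4), (8, 8), (16, 16)]: coarse to fine
print(np.allclose(haar_waverec2(coeffs_demo), img_demo))
```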

In [None]:
# exercise 3.2

# Load image
filename = "img/chessboard.png"
I = mpimg.imread(filename)

Jmax = int(np.log2(I.shape[0]))  # maximal number of octaves
J = 3  # number of decomposition levels

#I_D = pywt.wavedec2(I, wavelet="haar", mode="periodization", level=J, axes=(-2, -1))
I_D = pywt.wavedec2(I, wavelet="db20", mode="periodization", level=J, axes=(-2, -1))

# I_D = [cA3, (cH3, cV3, cD3), (cH2, cV2, cD2), (cH1, cV1, cD1)]
# set to zero all horizontal detail coefficients (index 0) of levels 3, 2 and 1
I_D[1][0][:] = 0
I_D[2][0][:] = 0
I_D[3][0][:] = 0

I_R = pywt.waverec2(I_D, wavelet="db20", mode="periodization", axes=(-2, -1))

fig = plt.figure(figsize=(15, 7))
plt.subplot(131)
plt.imshow(I, cmap="gray")
plt.title("Before")


# Wavelet visualization

# Casting coeff in a single array
wtcoeffs_arr, coeff_slices = pywt.coeffs_to_array(I_D)

# Graphical representation
plt.subplot(132)
plot_wavelet(wtcoeffs_arr, Jmin=Jmax - J)

plt.subplot(133)

plt.imshow(I_R, cmap="gray")
plt.title("After")
# plt.colorbar()
plt.show()


In [None]:
# exercise 3.2 (second case)

# Load image
filename = "img/chessboard.png"
I = mpimg.imread(filename)

Jmax = int(np.log2(I.shape[0]))  # maximal number of octaves
J = 3  # number of decomposition levels

#I_D = pywt.wavedec2(I, wavelet="haar", mode="periodization", level=J, axes=(-2, -1))
I_D = pywt.wavedec2(I, wavelet="db20", mode="periodization", level=J, axes=(-2, -1))

# set to zero all vertical detail coefficients (index 1) of levels 3, 2 and 1
I_D[1][1][:] = 0
I_D[2][1][:] = 0
I_D[3][1][:] = 0

I_R = pywt.waverec2(I_D, wavelet="db20", mode="periodization", axes=(-2, -1))

fig = plt.figure(figsize=(15, 7))
plt.subplot(131)
plt.imshow(I, cmap="gray")
plt.title("Before")


# Wavelet visualization

# Casting coeff in a single array
wtcoeffs_arr, coeff_slices = pywt.coeffs_to_array(I_D)

# Graphical representation
plt.subplot(132)
plot_wavelet(wtcoeffs_arr, Jmin=Jmax - J)

plt.subplot(133)

plt.imshow(I_R, cmap="gray")
plt.title("After")
# plt.colorbar()
plt.show()


Exercise 3.2

The examples above show the effect of setting all horizontal (first case) and then all vertical (second case) detail coefficients to zero. In the first case, the reconstructed image shows much blurrier transitions in the vertical direction. This makes sense, since the edges responsible for these vertical transitions are horizontal lines in the original image: once the horizontal detail coefficients are set to zero, these lines are only represented by the approximation coefficients, i.e. by low-frequency content, hence the blurry vertical transitions. Similarly, when all vertical detail coefficients are set to zero, the resulting image shows much blurrier horizontal transitions.

In [None]:
# exercise 3.3

# Load image
filename = "img/barb.png"
I = mpimg.imread(filename)

Jmax = int(np.log2(I.shape[0]))  # maximal number of octaves
J = 3  # number of decomposition levels

#I_D = pywt.wavedec2(I, wavelet="haar", mode="periodization", level=J, axes=(-2, -1))
I_D = pywt.wavedec2(I, wavelet="db20", mode="periodization", level=J, axes=(-2, -1))

# set to zero all detail coefficients (cH1, cV1, cD1) of the first (finest) level,
# stored in the last element of the list
I_D[3][2][:] = 0
I_D[3][1][:] = 0
I_D[3][0][:] = 0

I_R = pywt.waverec2(I_D, wavelet="db20", mode="periodization", axes=(-2, -1))

# Wavelet visualization: cast the coefficients into a single array
wtcoeffs_arr, coeff_slices = pywt.coeffs_to_array(I_D)

fig1 = plt.figure(figsize=(15, 7))
plot_wavelet(wtcoeffs_arr, Jmin=Jmax - J)

fig2 = plt.figure(figsize=(15, 7))
plt.subplot(121)
plt.imshow(I, cmap="gray")
plt.title("Before")

plt.subplot(122)
plt.imshow(I_R, cmap="gray")
plt.title("After")
# plt.colorbar()
plt.show()


Exercise 3.4

Here, we chose to eliminate all the detail coefficients of the first decomposition level (horizontal, vertical and diagonal). These are the coefficients responsible for the highest-frequency content of the original image. The effect on the reconstructed image is clear: the low-frequency (approximation) content is essentially unchanged, but the high-frequency content (details) is strongly attenuated. This is particularly visible on the lady's trousers: on the left, the stripes are clearly visible, while on the right they are much blurrier.

---
## Wavelet Denoising <a id="section4"></a> [(&#8593;)](#content)

Wavelets are suitable for denoising signals which contain fast transients. In this context, thresholding can be used to cancel out the wavelet coefficients corresponding to regions where the signal varies smoothly. In these regions, the coefficients are expected to remain small (denoising is then provided by the low-pass cascade), and large wavelet coefficients (corresponding to large and fast transients of the signal) are preserved.

Let's simulate a noisy image to test the denoising procedure described above.

In [None]:
# The original image
filename = "img/boat.png"
I = mpimg.imread(filename)

# The noise
sigma = 0.1  # noise level

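[n, m] = I.shape  # n rows, m columns

# Simulating normalized Gaussian random variables (same shape as the image,
# the previous size=(m, n) only worked for square images)
rng = np.random.default_rng(1234)
noise = rng.standard_normal(size=(n, m)) * sigma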

# The noisy image
I_noisy = I + noise

# The PSNR quantifies the level of noise (see below for more explanations)
psnr1 = psnr(I, I_noisy, vmax=-1)


# Graphical representation
fig = plt.figure(figsize=(15, 7))
plt.subplot(121)
plt.imshow(I, cmap="gray")
plt.title("Original image")
# plt.colorbar()
plt.subplot(122)
plt.imshow(I_noisy, cmap="gray")
plt.title("Noisy image: PSNR = {0:2.2f}dB".format(psnr1))
# plt.colorbar()
plt.show()

#### Peak signal-to-noise ratio: definition
The peak signal-to-noise ratio (PSNR) is defined as the ratio between the maximum possible power of a signal and the power of the noise that affects its representation. Since many signals have a very wide dynamic range, the PSNR is usually expressed on a decibel (dB) scale.

The PSNR is usually defined via the mean squared error (MSE). Given a noise-free $m \times n$ monochromatic image $I$ and its noisy approximation $K$, the MSE is defined as:

$$\mathrm{MSE} = \frac{1}{m\,n}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1} [I(i,j) - K(i,j)]^2 = \frac{1}{m\,n} \|I - K \|^2_{\text{F}}.$$

The PSNR in decibel (dB) is defined as:

\begin{align}
\mathrm{PSNR} &= 10 \log_{10} \left( \frac{\mathrm{MAX}_I^2}{\mathrm{MSE}} \right)\\ 
    &= 20 \log_{10} \left( {\mathrm{MAX}_I} \right) - 10 \log_{10} \left( {{\mathrm{MSE}}} \right)
\end{align}

where $\mathrm{MAX}_I$ corresponds to the maximum possible value taken by a pixel of the image. When the pixels are represented using 8 bits per sample, this is 255. The lower the error, the higher the PSNR.

**Extension to color images**: for color images, which contain three values per pixel (RGB representation), the definition of the PSNR is the same, except that the MSE is the sum of all squared value differences (for each color, i.e. three times as many values as in a monochromatic image) divided by three times the image size.
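
The definitions above translate directly into NumPy. The function below is named `psnr_db`, a hypothetical name chosen to avoid clashing with the `psnr` helper imported from `lib.plotwavelet`, whose exact conventions (e.g. its `vmax` argument) are not reproduced here. It averages the squared error over every pixel, and over the three channels for a color image, as described in the extension above.

```python
import numpy as np


def psnr_db(reference, estimate, max_value=1.0):
    """PSNR in dB; max_value is MAX_I (1.0 for float images in [0, 1], 255 for 8-bit images)."""
    reference = np.asarray(reference, dtype=float)
    estimate = np.asarray(estimate, dtype=float)
    mse = np.mean((reference - estimate) ** 2)  # mean over all pixels (and channels)
    if mse == 0:
        return np.inf
    return 10.0 * np.log10(max_value**2 / mse)


gray = np.zeros((4, 4))
color = np.zeros((4, 4, 3))
print(psnr_db(gray, gray + 0.1))        # MSE = 0.01, about 20 dB
print(psnr_db(color, color + 0.1))      # same error on each channel, about 20 dB
print(psnr_db(gray, gray + 25.5, 255))  # same relative error in 8-bit units, about 20 dB
```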

### Thresholding
You can either use the `pywt.threshold` function, or a custom function as implemented below:

In [None]:
def perform_hardthresholding(f, thres):
    return f * (np.abs(f) >= thres)
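
Besides the hard thresholding used in this notebook, a common alternative is soft thresholding, which also shrinks the surviving coefficients towards zero by the threshold value. Both operators are sketched below in NumPy. They are intended to mirror `pywt.threshold(data, value, mode="hard")` and `mode="soft"`, although the behavior exactly at $|c| = t$ may differ.

```python
import numpy as np


def hard_threshold(c, t):
    """Keep the coefficients whose magnitude is at least t, set the others to zero."""
    c = np.asarray(c, dtype=float)
    return c * (np.abs(c) >= t)


def soft_threshold(c, t):
    """Set the small coefficients to zero and shrink the others towards zero by t."""
    c = np.asarray(c, dtype=float)
    return np.sign(c) * np.maximum(np.abs(c) - t, 0.0)


c_demo = np.array([-3.0, -1.0, 0.5, 2.0, 4.0])
print(hard_threshold(c_demo, 1.5))  # large coefficients unchanged
print(soft_threshold(c_demo, 1.5))  # large coefficients reduced by 1.5
```

Soft thresholding is a continuous function of the coefficients, at the price of a systematic bias on the large ones; either operator can be used in place of `perform_hardthresholding` in the cells below.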

In [None]:
# Wavelet decomposition
Jmax = int(np.log2(I_noisy.shape[0]))  # maximal number of octaves
J = 5

coeffs = pywt.wavedec2(I_noisy, wavelet="sym8", level=J, mode="periodization")

In [None]:
# Conversion to an array for graphical representation
coeffs_array, coeff_slices = pywt.coeffs_to_array(coeffs)

plt.figure(figsize=(5, 5), dpi=200)
plot_wavelet(coeffs_array, Jmin=Jmax - J)
plt.show()

In the following, we keep the approximation coefficients unchanged and only threshold the detail coefficients. Further indications are provided below.

In [None]:
# Thresholding detail coefficients only (while preserving approximation coefficients)
# 1) threshold all coefficients stored in an array
# 2) reset approx coefficients to their original value

thres = 2.0 * sigma
coeffs_arr_hard = perform_hardthresholding(coeffs_array, thres)

In [None]:
# convert the array of coefficients back to the pywt coefficient format
coeffs_hard = pywt.array_to_coeffs(coeffs_arr_hard, coeff_slices)

# Setting original approximation coefficients back
coeffs_hard[0] = coeffs[0]  # approximation preserved

In [None]:
# Back to an array for graphical representation
coeffs_hard_arr, slices = pywt.coeffs_to_array(coeffs_hard)

plt.figure(figsize=(5, 5), dpi=200)
plot_wavelet(coeffs_hard_arr, Jmin=Jmax - J)
plt.show()

#### Reconstruction 

In [None]:
# Multilevel n-dimensional Inverse Discrete Wavelet Transform
I_den = pywt.waverecn(coeffs_hard, "sym8", mode="periodization")

In [None]:
psnr2 = psnr(I, I_den, vmax=-1)
fig = plt.figure(figsize=(15, 7))
plt.subplot(121)
plt.imshow(I_noisy, cmap="gray")
plt.title("Before thresholding: PSNR = {0:2.2f}dB".format(psnr1))
# plt.colorbar()
plt.subplot(122)
plt.imshow(I_den, cmap="gray")
plt.title("After thresholding: PSNR = {0:2.2f}dB".format(psnr2))
# plt.colorbar()
plt.show()
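
The whole procedure (decompose, threshold the details only, keep the approximation, reconstruct) is easier to inspect on a 1D toy example. The NumPy sketch below applies a multilevel Haar transform to a noisy piecewise-constant signal, with a hard threshold of $3\sigma$. The signal, the noise level and the helper names (`haar_wavedec`, `haar_waverec`, `psnr_1d`) are illustrative assumptions, not the settings of the image experiment above.

```python
import numpy as np

S = np.sqrt(2.0)


def haar_wavedec(x, level):
    """Multilevel Haar DWT: [cA_J, cD_J, ..., cD_1] (length divisible by 2**level)."""
    cA, details = np.asarray(x, dtype=float), []
    for _ in range(level):
        cA, cD = (cA[0::2] + cA[1::2]) / S, (cA[0::2] - cA[1::2]) / S
        details.insert(0, cD)
    return [cA] + details


def haar_waverec(coeffs):
    """Inverse of haar_wavedec."""
    cA = coeffs[0]
    for cD in coeffs[1:]:
        x = np.empty(2 * cA.size)
        x[0::2], x[1::2] = (cA + cD) / S, (cA - cD) / S
        cA = x
    return cA


def psnr_1d(ref, est, max_value=1.0):
    """PSNR in dB between a reference signal and an estimate."""
    return 10.0 * np.log10(max_value**2 / np.mean((ref - est) ** 2))


rng_demo = np.random.default_rng(0)
# Piecewise-constant test signal of length 1024 with jumps at samples 300, 500 and 824
clean_demo = np.concatenate([np.full(300, 0.0), np.full(200, 1.0),
                             np.full(324, 0.3), np.full(200, 0.8)])
sigma_demo = 0.1
noisy_demo = clean_demo + sigma_demo * rng_demo.standard_normal(clean_demo.size)

coeffs_demo = haar_wavedec(noisy_demo, level=5)
thr = 3.0 * sigma_demo
# keep the approximation, hard-threshold every detail level
coeffs_thr = [coeffs_demo[0]] + [d * (np.abs(d) >= thr) for d in coeffs_demo[1:]]
denoised_demo = haar_waverec(coeffs_thr)

print(f"PSNR before: {psnr_1d(clean_demo, noisy_demo):.1f} dB")
print(f"PSNR after:  {psnr_1d(clean_demo, denoised_demo):.1f} dB")
```

The remaining error mostly comes from the noise left in the coarse approximation, which is never thresholded, and from the few noise coefficients that happen to exceed $3\sigma$.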

### Exercise 4 <a id="ex4"></a> [(&#8593;)](#content)
1. Observe the effect of the level of the decomposition and of the threshold.
2. Optimize the parameters of your wavelet denoising strategy to get the best possible PSNR. 

In [None]:
# Exercise 4.1

# testing the effect of various levels of decomposition for the same threshold

test_values = [0, 1, 2, 3, 4, 5]  # decomposition levels to test

for J in test_values:
    
    coeffs = pywt.wavedec2(I_noisy, wavelet="sym8", level=J, mode="periodization")
    
    # Conversion to an array for graphical representation
    coeffs_array, coeff_slices = pywt.coeffs_to_array(coeffs)
    
    thres = 2.0 * sigma
    coeffs_arr_hard = perform_hardthresholding(coeffs_array, thres)
    
    # convert the array of coefficients back to pywt coeffs
    coeffs_hard = pywt.array_to_coeffs(coeffs_arr_hard, coeff_slices)

    # Setting original approximation coefficients back
    coeffs_hard[0] = coeffs[0]  # approximation preserved
        
    # Multilevel n-dimensional Inverse Discrete Wavelet Transform
    I_den = pywt.waverecn(coeffs_hard, "sym8", mode="periodization")
    
    psnr2 = psnr(I, I_den, vmax=-1)
    fig = plt.figure(figsize=(14, 5))
    
    plt.subplot(121)
    plt.imshow(I_noisy, cmap="gray")
    plt.title("Before thresholding: PSNR = {0:2.2f}dB ".format(psnr1 ))
    
    plt.subplot(122)
    plt.imshow(I_den, cmap="gray")
    plt.title("After thresholding: PSNR = {0:2.2f}dB | decomp. level : {1:2} ".format(psnr2, J ))
    
    plt.show()    
    

In [None]:
# Exercise 4.1

# testing the effect of various thresholds for the same level of decomposition

J = 3   # fixed level of decomposition

thres_values  = [0, 1, 2, 3, 4, 5, 6]

for threshold in thres_values:
    
    coeffs = pywt.wavedec2(I_noisy, wavelet="sym8", level=J, mode="periodization")
    
    # Conversion to an array for graphical representation
    coeffs_array, coeff_slices = pywt.coeffs_to_array(coeffs)
    
    coeffs_arr_hard = perform_hardthresholding(coeffs_array, threshold * sigma)
    
    # convert the array of coefficients back to pywt coeffs
    coeffs_hard = pywt.array_to_coeffs(coeffs_arr_hard, coeff_slices)

    # Setting original approximation coefficients back
    coeffs_hard[0] = coeffs[0]  # approximation preserved
    
    # Multilevel n-dimensional Inverse Discrete Wavelet Transform
    I_den = pywt.waverecn(coeffs_hard, "sym8", mode="periodization")
    
    psnr2 = psnr(I, I_den, vmax=-1)
    fig = plt.figure(figsize=(14, 5))
    
    plt.subplot(121)
    plt.imshow(I_noisy, cmap="gray")
    plt.title("Before thresholding: PSNR = {0:2.2f}dB ".format(psnr1 ))
    
    plt.subplot(122)
    plt.imshow(I_den, cmap="gray")
    plt.title("After thresholding: PSNR = {0:2.2f}dB | threshold : {1:1.0f}*sigma ".format(psnr2, threshold ))
    
    plt.show() 

Exercise 4.1

When keeping the threshold fixed and varying the number of decomposition levels, the result is not particularly interesting. The PSNR improves from 0 to 1 and from 1 to 2 levels, but then reaches a plateau in which each additional decomposition level does not add much.

On the other hand, varying the threshold gives much more interesting results. Raising the threshold from 0 up to 3\*sigma improves the PSNR, but any further increase degrades the image.

Exercise 4.2

The idea developed in the code below is to change the threshold value as a function of the decomposition level. For an orthogonal transform, the white Gaussian noise has the same variance $\sigma^2$ in every subband, but the energy of the image is concentrated in the coarse levels. The fine levels (high-frequency content) therefore have a much lower signal-to-noise ratio, so we can be much more aggressive when thresholding the first decomposition level, since the risk of removing actual image content is smaller.


In [None]:
# Exercise 4.2

# idea : applying a different threshold for each level of decomposition

J = 4   # fixed level of decomposition
Jmax = int(np.log2(I.shape[0]))  # maximal number of octaves

thres_values  = [3.5, 3.0, 1.5, 1.0]

coeffs = pywt.wavedec2(I_noisy, wavelet="sym8", level=J, mode="periodization")

# coeffs = [cA4, (cH4, cV4, cD4), ..., (cH1, cV1, cD1)]
# x = 0 -> first (finest) decomposition level, stored in coeffs[J]
# x = 1 -> second decomposition level, stored in coeffs[J-1], ...
for x in range(J):
    for k in range(3):  # horizontal, vertical and diagonal details
        coeffs[J - x][k][:] = perform_hardthresholding(coeffs[J - x][k], thres_values[x] * sigma)

# the approximation coefficients coeffs[0] were never modified
coeffs_hard = coeffs

# Back to an array for graphical representation
coeffs_hard_arr, slices = pywt.coeffs_to_array(coeffs_hard)

plt.figure(figsize=(5, 5), dpi=200)
plot_wavelet(coeffs_hard_arr, Jmin=Jmax - J)
plt.show()

# Multilevel 2D inverse discrete wavelet transform
I_den = pywt.waverec2(coeffs_hard, wavelet="sym8", mode="periodization")

psnr2 = psnr(I, I_den, vmax=-1)
fig = plt.figure(figsize=(14, 6))

plt.subplot(121)
plt.imshow(I_noisy, cmap="gray")
plt.title("Before : PSNR = {0:2.2f}dB ".format(psnr1 ))

plt.subplot(122)
plt.imshow(I_den, cmap="gray")
plt.title("After : PSNR = {0:2.2f}dB | thresholds [{1}, {2}, {3}, {4}] ".format(psnr2, thres_values[0],thres_values[1], thres_values[2], thres_values[3] ))

plt.show() 
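
The level-dependent rule used above can be packaged as a small function that works directly on the list returned by `pywt.wavedec2` (`[cA_J, (cH_J, cV_J, cD_J), ..., (cH_1, cV_1, cD_1)]`) and returns a new list, leaving its input untouched. This is a NumPy-only sketch with a hypothetical name (`threshold_by_level`); as in the code above, `thresholds[0]` is applied to the finest level.

```python
import numpy as np


def threshold_by_level(coeffs, thresholds):
    """Hard-threshold the details of a wavedec2-style list, one threshold per level.

    thresholds[0] is used for level 1 (finest details, last element of the list),
    thresholds[-1] for level J (coarsest details, coeffs[1]).
    The approximation coeffs[0] is copied unchanged.
    """
    J = len(coeffs) - 1
    if len(thresholds) != J:
        raise ValueError("expected one threshold per decomposition level")
    out = [np.array(coeffs[0], copy=True)]
    for idx in range(1, J + 1):
        t = thresholds[J - idx]  # coeffs[idx] holds level J - idx + 1
        out.append(tuple(d * (np.abs(d) >= t) for d in coeffs[idx]))
    return out


# Toy 2-level structure: approximation 2x2, details 2x2 (level 2) and 4x4 (level 1)
toy = [np.ones((2, 2)),
       tuple(np.full((2, 2), 0.5) for _ in range(3)),
       tuple(np.full((4, 4), 0.5) for _ in range(3))]
res = threshold_by_level(toy, [1.0, 0.2])  # aggressive on level 1, mild on level 2
print([float(np.abs(d).max()) for d in res[1]])  # level 2 details kept
print([float(np.abs(d).max()) for d in res[2]])  # level 1 details removed
```

In the notebook, the result could then be passed to `pywt.waverec2(..., "sym8", mode="periodization")`, e.g. with `[t * sigma for t in thres_values]` as thresholds.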

### Bonus: Exercise 5 <a id="ex5"></a> [(&#8593;)](#content)
Optimize your denoising strategy by using a translation invariant wavelet transform provided by `pywt.swt2` (stationary wavelet transform).

> Note: 
>
> - there is currently an issue in the documentation of the `pywt.swt2` function: use the `trim_approx=True` argument to obtain the same output format as the `pywt.wavedec2` function.
> - due to differences of `pywt.swt2` with respect to `pywt.wavedec2`, the `plot_wavelet` function above will not work as expected. You can directly display the array given by the `pywt.coeffs_to_array` function. Note however that the positions of the vertical and horizontal details in this array are swapped compared to what is mentioned in the documentation (which can be seen by observing the vertical and horizontal patterns, or by comparing with, e.g., the wavelet library in MATLAB).
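
The motivation for `pywt.swt2` is that the decimated DWT is not translation invariant: shifting the input by one sample can completely change the detail coefficients. The NumPy sketch below illustrates this in 1D with a single Haar step: a circular step signal and its one-sample shift have very different decimated details, whereas the undecimated details (computed at every position, which is the principle behind the stationary transform) are simply shifted. This only illustrates the principle; the filters, normalization and alignment used by `pywt.swt` / `pywt.swt2` may differ.

```python
import numpy as np

S = np.sqrt(2.0)


def decimated_haar_details(x):
    """First-level Haar details, keeping one coefficient out of two."""
    return (x[0::2] - x[1::2]) / S


def undecimated_haar_details(x):
    """Haar details at every position (circular boundary), without decimation."""
    return (x - np.roll(x, -1)) / S


x_step = np.zeros(16)
x_step[8:] = 1.0              # jump between samples 7 and 8 (even boundary)
x_shift = np.roll(x_step, 1)  # same signal, circularly shifted by one sample

d_dec, d_dec_shift = decimated_haar_details(x_step), decimated_haar_details(x_shift)
d_und, d_und_shift = undecimated_haar_details(x_step), undecimated_haar_details(x_shift)

print("decimated detail energies:", np.sum(d_dec**2), np.sum(d_dec_shift**2))
print("undecimated details are only shifted:", np.allclose(np.roll(d_und, 1), d_und_shift))
```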