## Wavelet tests

In [None]:
############
# Packages #
############
import os
import sys
import pickle
import numpy as np
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import seaborn as sns
import matplotlib.pyplot as plt

import pywt
from scipy.signal import periodogram, fftconvolve, cwt

from typing import Dict, Union, List, Tuple, Any, Callable, Optional
pio.renderers.default = "plotly_mimetype+notebook"
################
#    Imports   #
################

# The notebook is expected to run from the project root (the directory should
# end with "nonlinear_ICA") so that the local `src` package below is importable.
root_path = Path.cwd()
saving_path = root_path / "outputs"
print(root_path)
sys.path.insert(0, os.fspath(root_path))

from src.data import (
    load_meta,
    add_superclasse,
    describ_raw_df,
    load_ecg,
    load_ecg_from_clean_data,
    plot_all_st
)

In [None]:
################################
#  clean meta data  Loading    #
################################
# Load the cleaned metadata object pickled in outputs/clean_data.pkl.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open(saving_path / "clean_data.pkl", mode="rb") as f:
    df = pickle.load(f)

In [None]:
############################
#  500hz signals  Loading  #
############################
# Load the recordings for the patients in index labels 1..100 (label-based,
# inclusive .loc slice).
signals = load_ecg_from_clean_data(
    df,
    root_path,
    patient_ids=df.loc[1:100, "patient_id"].tolist(),
)

In [None]:
plot_all_st(
    signals[0].T,
    clustering=None,
    title="<b>Signals 500Hz</b> patient 0",
)

### Mother and father wavelets (wavelet and scaling functions)

In [None]:
def add_fig(fig, signal, color, name):
    """Append `signal` to `fig` as a 2-px, 60%-opaque line trace.

    Parameters
    ----------
    fig : plotly Figure to mutate in place.
    signal : y-values of the trace (x defaults to the sample index).
    color : line colour.
    name : legend label.
    """
    trace = go.Scatter(
        y=signal,
        mode="lines",
        line={"width": 2, "color": color},
        opacity=0.6,
        name=name,
    )
    fig.add_trace(trace)

def add_fig_continuous(level, w, fig, signal, color, wave):
    """Add the psi function of a continuous wavelet `w` to `fig`.

    Complex wavelets contribute only the real part of psi. `signal` is
    accepted for signature compatibility with the other helpers but unused.
    """
    psi, _grid = w.wavefun(level=level)
    if w.complex_cwt:
        add_fig(fig, np.real(psi), color, f"{wave} wavelet continuous function psi real part")
        return
    add_fig(fig, psi, color, f"{wave} wavelet continuous function psi")

def add_fig_discret_ortho(level, w, fig, signal, color, wave):
    """Add the scaling (phi) and wavelet (psi) functions of an orthogonal wavelet to `fig`.

    `signal` is accepted for signature compatibility but unused.
    """
    phi, psi, _grid = w.wavefun(level=level)
    curves = (
        (phi, f"{wave} scaling function phi"),
        (psi, f"{wave} wavelet function psi"),
    )
    for curve, label in curves:
        add_fig(fig, curve, color, label)

def add_fig_discret_biortho(level, w, fig, signal, color, wave):
    """Add the four functions of a biorthogonal wavelet to `fig`.

    Draws the decomposition pair (phi_d, psi_d) then the reconstruction pair
    (phi_r, psi_r). `signal` is accepted for signature compatibility but unused.
    """
    phi_d, psi_d, phi_r, psi_r, _grid = w.wavefun(level=level)
    curves = (
        (phi_d, f"{wave} decomposition scaling phi_d"),
        (psi_d, f"{wave} decomposition wavelet psi_d"),
        (phi_r, f"{wave} reconstruction scaling phi_r"),
        (psi_r, f"{wave} reconstruction wavelet psi_r"),
    )
    for curve, label in curves:
        add_fig(fig, curve, color, label)

def my_pal(n):
    """Return `n` hex colour strings sampled from seaborn's "Spectral" palette."""
    palette = sns.color_palette("Spectral", n)
    return palette.as_hex()

def plot_wave_family(level, type, iscontinous = False):
    """Plot every wavelet of one PyWavelets family on a single figure.

    Parameters
    ----------
    level : int
        Refinement level passed to each wavelet's ``wavefun``.
    type : str
        Short family name accepted by ``pywt.wavelist`` (e.g. ``"db"``,
        ``"morl"``). The name shadows the builtin inside this function but is
        kept for backward compatibility with existing callers.
    iscontinous : bool, default False
        True for continuous families: only psi is drawn, real part if complex.
        False for discrete families: phi/psi for orthogonal wavelets, and the
        decomposition plus reconstruction phi/psi for biorthogonal ones.

    Raises
    ------
    ValueError
        If ``pywt.wavelist(type)`` returns no wavelet.
    """
    lst_wave = pywt.wavelist(type)
    if not lst_wave:
        # Without this guard `w` below would be unbound and raise NameError.
        raise ValueError(f"No wavelet found for family {type!r}")
    pal = my_pal(len(lst_wave))
    fig = go.Figure(
        layout=go.Layout(
            height=600,
            width=800,
            template="plotly_dark",
            title=f"Wavelets {type}",
        )
    )

    for color, wave in zip(pal, lst_wave):
        w = pywt.DiscreteContinuousWavelet(wave)
        # The add_fig_* helpers take a `signal` argument they never read.
        # Pass None rather than the notebook-global `signal`, which is not yet
        # defined when this function is first called.
        if iscontinous:
            add_fig_continuous(level, w, fig, None, color, wave)
        elif w.orthogonal:
            add_fig_discret_ortho(level, w, fig, None, color, wave)
        else:
            add_fig_discret_biortho(level, w, fig, None, color, wave)

    fig.update_layout(title=w.family_name + f" code={type}  level = {level} ")
    fig.show()

In [None]:
level = 5

# The split relies on pywt.families() listing the discrete families first,
# with "gaus" being the first continuous family.
lst_types = pywt.families()
sep_continous = lst_types.index("gaus")

lst_discret = lst_types[:sep_continous]
lst_continuous = lst_types[sep_continous:]

# Loop variable is `family`, not `type`, so the builtin `type` is not
# shadowed in the notebook's global namespace after this cell.
for family in lst_discret:
    plot_wave_family(level, family, iscontinous=False)

for family in lst_continuous:
    plot_wave_family(level, family, iscontinous=True)

In [None]:
# Daubechies (discrete) then Morlet (continuous), each at levels 1, 3, 5, 7, 9.
for family, continuous in (("db", False), ("morl", True)):
    for level in range(1, 10, 2):
        plot_wave_family(level, family, iscontinous=continuous)

### Discrete Wavelet Transform

In [None]:
def plot_signal(vec, title = "signal"):
    """Display `vec` as a dark-themed plotly line chart titled `title`."""
    px.line(vec, template="plotly_dark", title=title).show()

def plot_estim(xb, true_x, title = "estimate xbar"):
    """Overlay the estimate `xb` (drawn in red) on the reference `true_x` and show it."""
    reference_traces = px.line(true_x.copy()).select_traces()
    figure = px.line(xb, template="plotly_dark", title=title)
    figure.add_traces(list(reference_traces))
    # Trace 0 is the estimate; colour it so it stands out from the reference.
    figure.data[0].line.color = "red"
    figure.show()

In [None]:
# Test signal: column 0 of the first loaded recording (labelled electrode 0).
signal = signals[0][:, 0]
plot_signal(vec=signal, title="test signal: patient 0 electrode 0")

À partir de la décomposition en ondelettes, on obtient l'approximation et le détail (c'est-à-dire le reste).

In [None]:
# Single-level DWT with the 'db2' wavelet: approximation (low-pass) and
# detail (high-pass) coefficients.
approx, detail = pywt.dwt(data=signal, wavelet="db2")

for coeff_vec, label in ((approx, "Approximation coefficients"), (detail, "Details coefficients")):
    plot_signal(coeff_vec, title=label)

# on reconstruit le signal

In [None]:
# Keep only the approximation: replace the detail coefficients by zeros
# before inverting the transform.
estim = pywt.idwt(cA=approx, cD=np.zeros_like(detail), wavelet="db2")
plot_estim(estim, signal, title="Wavelet reconstruction")

In [None]:
# Single-level DWT with 'sym3' and zero-padding ("constant") signal extension.
w = pywt.DiscreteContinuousWavelet("sym3")
approx2, detail2 = pywt.dwt(signal, wavelet=w, mode="constant")
# Plot the coefficients of THIS decomposition (approx2/detail2), not the db2 ones.
plot_signal(approx2, title = "Approximation coefficients")
plot_signal(detail2, title = "Details coefficients")

# Reconstruct from the approximation only, with the same wavelet and extension
# mode used for the decomposition.
estim2 = pywt.idwt(approx2, np.zeros_like(detail2), w, mode="constant")
plot_estim(estim2, signal, title = "Wavelet reconstruction 2")

## Multilevel DWT, IDWT

In [None]:
w = pywt.DiscreteContinuousWavelet("sym3")

# 5-level DWT; pywt returns [cA5, cD5, cD4, cD3, cD2, cD1].
coeffs = pywt.wavedec(signal, w, level=5)
pal = my_pal(len(coeffs))
fig = go.Figure(
    layout=go.Layout(
        height=600,
        width=800,
        template="plotly_dark",
        title="Wavelets multilevel decomposition",
    )
)
for idx, band in enumerate(coeffs):
    add_fig(fig, band, pal[idx], f"coeff {idx}")
fig.show()

In [None]:


fig = go.Figure(
    layout=go.Layout(
        height=600,
        width=800,
        template="plotly_dark",
        title="Wavelets multilevel reconstruction",
    )
)
pal = my_pal(len(coeffs))

# Start from all-zero coefficients and switch the bands on one at a time,
# plotting the cumulative reconstruction after each addition.
coeffs2 = list(map(np.zeros_like, coeffs))
for band_idx, band in enumerate(coeffs):
    coeffs2[band_idx] = band
    add_fig(fig, pywt.waverec(coeffs2, w), pal[band_idx], f"add coeff {band_idx}")
fig.show()


### SWT

The stationary wavelet transform (SWT) is a wavelet transform algorithm designed to overcome the lack of translation invariance of the discrete wavelet transform (DWT). Translation invariance is achieved by removing the downsamplers and upsamplers of the DWT and instead upsampling the filter coefficients by a factor of $2^{j-1}$ at the $j$-th level of the algorithm. The SWT is an inherently redundant scheme, since the output of each level contains the same number of samples as the input: a decomposition with $N$ levels has a redundancy of $N$ in the wavelet coefficients. The algorithm is better known in French as the "algorithme à trous" ("trous" means "holes"), which refers to inserting zeros in the filters. It was introduced by Holschneider et al.

In [None]:
# Two-level stationary wavelet transform; pywt returns the deepest level first.
(cA2, cD2), (cA1, cD1) = pywt.swt(data=signal, wavelet=w, level=2)

In [None]:
fig = go.Figure(
    layout=go.Layout(
        height=600,
        width=800,
        template="plotly_dark",
        title="Stationary wavelet transform",
    )
)
pal = my_pal(4)

add_fig(fig, cA1, pal[0], "Aproximation 1")
add_fig(fig, cD1, pal[1], "Detail 1")
add_fig(fig, cA2, pal[2], "Aproximation 2")
add_fig(fig, cD2, pal[3], "Detail 2")
# Overlay the input signal itself for comparison (SWT levels keep its length).
add_fig(fig, signal, "blue", "Signal")
fig.show()

In [None]:
# The SWT levels satisfy a recurrence: level 2 can be recomputed from the
# level-1 approximation by starting the transform at level 1, e.g.
# [(cA2, cD2)] = pywt.swt(cA1, db1, level=1, start_level=1)
# (with the wavelet used above, `w`, in place of db1).

In [None]:
# Full SWT: with no `level`, pywt uses the maximum level allowed by the signal length.
coeffs_swt = pywt.swt(data=signal, wavelet=w)
n_swt_levels = len(coeffs_swt)
print(n_swt_levels)

In [None]:
fig = go.Figure(
    layout=go.Layout(
        height=600,
        width=800,
        template="plotly_dark",
        title="details SWT",
    )
)
# Size the palette from the SWT result itself.
pal = my_pal(len(coeffs_swt))

# pywt.swt returns levels deepest-first, so entry i holds the detail
# coefficients of level len(coeffs_swt) - i.
for i, (_, detail) in enumerate(coeffs_swt):
    add_fig(fig, detail, pal[i], f"detail {len(coeffs_swt) - i}")
fig.show()

On obtient bien une représentation invariante par translation : sans sous-échantillonnage, chaque niveau de détail garde la longueur du signal.

### Wavelet packet

In the wavelet packet decomposition (WPD), also called a subband tree, the discrete-time (sampled) signal is passed through more filters than in the discrete wavelet transform (DWT). In the DWT, each level is computed by passing only the previous approximation coefficients (cAj) through discrete-time low- and high-pass quadrature mirror filters. In the WPD, both the detail coefficients (cDj in the 1-D case; cHj, cVj, cDj in the 2-D case) and the approximation coefficients are decomposed, which creates a full binary tree.

There are several algorithms for structuring the subband tree. They search for a set of optimal bases that gives the most desirable representation of the data with respect to a given cost function (entropy, energy compaction, etc.).

In [None]:
# Full wavelet-packet tree of the test signal with the sym3 wavelet `w`.
wp = pywt.WaveletPacket(signal, w)

In [None]:
# Maximum decomposition depth of the packet tree.
print(wp.maxlevel)

#### Un arbre de données

In [None]:
# The root node holds the data the tree was built from (the test signal).
print(wp.data)

In [None]:
# Node "a": first-level approximation branch of the packet tree.
print(wp["a"].data)

On récupère les nœuds dans l'ordre de filtrage (ordre « natural »).

In [None]:
# Level-3 node paths in natural (filter-bank) order.
print([leaf.path for leaf in wp.get_level(3, order="natural")])

On récupère les nœuds dans l'ordre des fréquences (ordre « freq »).

In [None]:
# Level-3 node paths ordered by frequency band.
print([leaf.path for leaf in wp.get_level(3, order="freq")])

In [None]:
# All node paths at the deepest level, ordered by frequency band.
lst_coeffs_by_f = [leaf.path for leaf in wp.get_level(wp.maxlevel, order="freq")]
print(lst_coeffs_by_f[:10])

#### On modifie et reconstruit

In [None]:
# Rebuild a packet tree keeping the first 10 bands in frequency order; every
# other band is replaced by zeros of the same shape.
wp2 = pywt.WaveletPacket(data=None, wavelet=w)
for rank, path in enumerate(lst_coeffs_by_f):
    band = wp[path].data
    wp2[path] = band if rank < 10 else np.zeros_like(band)


In [None]:
# Inverse packet transform of the modified tree (update=True also stores the
# result in the tree's root node).
new_sign = wp2.reconstruct(update=True)

In [None]:
plot_estim(
    new_sign,
    signal,
    title="Wavelet reconstruction packet decomposition keep",
)

On peut évaluer la reconstruction avec le PSNR (peak signal-to-noise ratio), un score classique en compression d'image.

In [None]:
def calculate_psnr(img1, img2, max_value=255):
    """Calculate the peak signal-to-noise ratio (PSNR) between two arrays.

    Parameters
    ----------
    img1, img2 : array_like
        Reference and reconstructed data; must be broadcast-compatible.
    max_value : float, default 255
        Peak value of the signal (255 for 8-bit images).

    Returns
    -------
    float
        ``20 * log10(max_value / RMSE)`` in dB, or 100 when the inputs are
        identical (MSE == 0), where the ratio would be infinite.
    """
    # float64 (not float32) so small differences on large-magnitude values
    # are not rounded away to a spurious MSE of 0.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * np.log10(max_value / np.sqrt(mse))

In [None]:
fig = go.Figure(
    layout=go.Layout(
        height=600,
        width=800,
        template="plotly_dark",
        title="Wavelet Packet : keep low frequencies",
    )
)
lst_nkeep = list(range(5, 50, 5))
pal = my_pal(len(lst_nkeep))

# For each cut-off, keep the `nkeep` first bands in frequency order, zero the
# rest, reconstruct, and score the result against the original signal.
for color, nkeep in zip(pal, lst_nkeep):
    wp2 = pywt.WaveletPacket(data=None, wavelet=w)
    for rank, path in enumerate(lst_coeffs_by_f):
        band = wp[path].data
        wp2[path] = band if rank < nkeep else np.zeros_like(band)
    new_sign = wp2.reconstruct(update=True)
    # The reconstruction may be longer than the input; compare on the common prefix.
    psnr = calculate_psnr(signal, new_sign[:len(signal)], max_value=np.max(signal))
    add_fig(fig, new_sign, color, f"packet decomposition keep {nkeep} PSNR {np.round(psnr,3)}")
fig.show()

## Continuous scalogram

Attention : ce calcul est très long !

In [None]:
fs = 500  # sampling frequency in Hz (500 Hz recordings)
# One scale every 10 samples up to the signal length: many scales, hence slow.
scales = np.arange(1, signal.shape[0], 10)
print(scales)
# sampling_period only rescales the returned `freqs` (now in Hz instead of
# cycles/sample); the coefficients are unaffected.
coef, freqs = pywt.cwt(signal, scales=scales, wavelet='gaus1', sampling_period=1 / fs)

In [None]:
# (number of scales, number of samples)
print(coef.shape)

In [None]:
# Scalogram: |CWT| with one row per scale and one column per sample.
sns.heatmap(np.abs(coef))
# Columns are sample indices, so put one tick every `fs` samples (one second)
# and label it in seconds, instead of a hard-coded length.
tick_positions = np.arange(0, coef.shape[1], fs)
plt.xticks(tick_positions, tick_positions // fs)
plt.xlabel("time (s)")
plt.ylabel("scale index")
plt.show()