## Denoising

In [None]:
############
# Packages #
############
import os
import sys
import pickle
import numpy as np
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import seaborn as sns
import matplotlib.pyplot as plt

import pywt
from scipy.signal import periodogram, fftconvolve, cwt

from typing import Dict, Union, List, Tuple, Any, Callable, Optional
# Combine plotly's "plotly_mimetype" and "notebook" renderers for fig.show().
pio.renderers.default = "plotly_mimetype+notebook"
################
#    Imports   #
################

# Work relative to the directory the notebook was started from; it is
# expected to be the repository root (ending in ``nonlinear_ICA``) so that
# the local ``src`` package can be imported.
root_path = Path.cwd()
saving_path = root_path / "outputs"
print(root_path)
# Put the repository root first on the import path so ``src`` resolves there.
sys.path.insert(0, str(root_path))

from src.data import (
    load_ecg_from_clean_data,
)

from src.plot import (
    plot_all_st,
    plot_signal,
    my_pal,
    add_fig,
    plot_estim,
)

In [None]:
################################
#  clean meta data  Loading    #
################################
# NOTE: unpickling can execute arbitrary code — only load trusted files here.
# The loaded object is presumably a pandas DataFrame (it exposes `.head()`).
with saving_path.joinpath("clean_data.pkl").open("rb") as pkl_file:
    df = pickle.load(pkl_file)
print(df.head())

In [None]:
# NOTE(review): `.loc` slicing is label-based and inclusive, so with a default
# integer index this selects the rows labelled 1..100 (label 0 is skipped) —
# confirm that is the intended subset.
selected_ids = df.loc[1:100, "patient_id"].tolist()
signals = load_ecg_from_clean_data(df, root_path, patient_ids=selected_ids)
# First loaded recording, first column (presumably a single ECG lead —
# confirm against load_ecg_from_clean_data).
signal = signals[0][:, 0]

### Low pass filter

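On peut appliquer un seuil aux coefficients de la représentation en ondelettes pour éliminer le bruit : les coefficients de détail inférieurs au seuil (en valeur absolue) sont mis à zéro ; en seuillage doux, les autres sont en plus rapprochés de zéro de la valeur du seuil. Le signal est ensuite reconstruit.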

In [None]:
def lowpassfilter(signal, thresh = 0.63, wavelet="db4", mode = "soft"):
    """Denoise ``signal`` by thresholding its wavelet detail coefficients.

    The signal is decomposed with ``pywt.wavedec`` (periodization mode),
    every detail band (all coefficient arrays except the first, the
    approximation) is thresholded with ``pywt.threshold``, and the signal is
    rebuilt with ``pywt.waverec``.

    Parameters
    ----------
    signal : array_like
        1-D input signal.
    thresh : float, default 0.63
        Threshold expressed as a fraction of the signal's peak absolute
        amplitude (NaNs are ignored when computing the peak).
    wavelet : str, default "db4"
        Wavelet name accepted by ``pywt``.
    mode : str, default "soft"
        Thresholding mode forwarded to ``pywt.threshold`` ("soft", "hard", ...).

    Returns
    -------
    numpy.ndarray
        Filtered signal, with the same length as ``signal``.
    """
    signal = np.asarray(signal)
    # Scale by the peak *absolute* amplitude: scaling by the raw maximum gives
    # a negative threshold for an all-negative signal, and a negative value
    # makes soft thresholding amplify coefficients instead of shrinking them.
    value = thresh * np.nanmax(np.abs(signal))
    coeffs = pywt.wavedec(signal, wavelet, mode="per")
    # Keep the approximation band untouched; threshold every detail band.
    coeffs[1:] = [pywt.threshold(c, value=value, mode=mode) for c in coeffs[1:]]
    new_signal = pywt.waverec(coeffs, wavelet, mode="per")
    # With periodization, an odd-length input is reconstructed one sample
    # longer; trim so the output lines up sample-by-sample with the input.
    return new_signal[: len(signal)]


In [None]:
# Grid of filter settings; one trace per (wavelet, threshold, mode) combination,
# in the same order as nested loops over the three lists.
wavelet_names = ["db5", "db10"]
thresholds = np.linspace(0.1, 0.9, 5)
modes = ["soft", "hard"]
configs = [(w, t, m) for w in wavelet_names for t in thresholds for m in modes]

fig = go.Figure(
    layout=go.Layout(
        height=600,
        width=800,
        template="plotly_dark",
        title="Low pass using Daubechies wavelets",
    )
)

# Size the palette from the grid so that editing the grid cannot index past
# the end of the palette (assumes my_pal(n) returns at least n colours).
pal = my_pal(len(configs))

for k, (wavename, thresh, mode) in enumerate(configs):
    new_sign = lowpassfilter(signal, thresh=thresh, wavelet=wavename, mode=mode)
    add_fig(fig, new_sign, pal[k], f"{wavename} {mode} threshold = {np.round(thresh, 2)}")
fig.show()

La régularité de l'ondelette se retrouve dans le signal filtré.

## Débruitage par méthode proximale

In [None]:
def q_power_proximal(q, χ, ξ):
    """
    Evaluate the proximal operator of ``x -> χ * |x| ** q`` at ``ξ``.

    Only a fixed set of exponents have a closed form implemented here. For
    q = 1 the result is soft thresholding by ``χ``; for q > 1 each returned
    value ``p`` satisfies ``p - ξ + χ * q * sign(p) * |p| ** (q - 1) = 0``.

    Parameters
    ----------
    q : float
        Exponent of the power function; one of 1, 4/3, 3/2, 2, 3, 4.
    χ : float
        Weight (regularization parameter) applied to the power function.
    ξ : float or numpy.ndarray
        Point(s) at which the operator is evaluated, elementwise for arrays.

    Returns
    -------
    prox : float or numpy.ndarray
        Proximal point(s), computed elementwise over ``ξ``.

    Raises
    ------
    ValueError
        If ``q`` is not one of the supported exponents.
    """
    supported = [1.0, 4/3, 3/2, 2., 3., 4.]
    if q not in supported:
        raise ValueError(f"q ({q}) does not belong in the expected "
                         f"values: {supported}")

    def soft_threshold():
        return np.sign(ξ) * np.maximum(np.abs(ξ) - χ, 0)

    def power_four_thirds():
        eps = np.sqrt(ξ**2 + 256 / 729 * χ ** 3)
        return ξ + 4 * χ / (3 * 2 ** (1/3)) * ((eps - ξ) ** (1/3) - (eps + ξ) ** (1/3))

    def power_three_halves():
        return ξ + 9 * χ ** 2 * np.sign(ξ) / 8 * (1 - np.sqrt(1 + 16 * np.abs(ξ) / (9 * χ ** 2)))

    def quadratic():
        return ξ / (1 + 2 * χ)

    def cubic():
        return np.sign(ξ) * (np.sqrt(1 + 12 * χ * np.abs(ξ)) - 1) / (6 * χ)

    def quartic():
        # Cardano's formula for the real root of p + 4 χ p^3 = ξ.
        eps = np.sqrt(ξ ** 2 + 1 / (27 * χ))
        return ((eps + ξ) / (8 * χ)) ** (1/3) - ((eps - ξ) / (8 * χ)) ** (1/3)

    # Scanned in order; the first exponent equal to q picks the closed form.
    branches = (
        (1.0, soft_threshold),
        (4/3, power_four_thirds),
        (3/2, power_three_halves),
        (2, quadratic),
        (3., cubic),
        (4., quartic),
    )
    for exponent, closed_form in branches:
        if q == exponent:
            return closed_form()
    return None

In [None]:
def get_prox_denoise(signal, wavelet="db4", reg_power = 1, reg_cst = 10):
    """Denoise ``signal`` by applying a power proximal operator in the wavelet domain.

    The signal is decomposed with ``pywt.wavedec`` (periodization mode). Every
    detail band (all coefficient arrays except the first, the approximation)
    is replaced by ``q_power_proximal(reg_power, reg_cst, band)``, i.e. the
    proximal operator of ``reg_cst * |x| ** reg_power``. The signal is then
    rebuilt with ``pywt.waverec``.

    Parameters
    ----------
    signal : array_like
        1-D input signal.
    wavelet : str, default "db4"
        Wavelet name accepted by ``pywt``.
    reg_power : float, default 1
        Exponent q of the penalty; must be one accepted by ``q_power_proximal``.
    reg_cst : float, default 10
        Weight χ of the penalty.

    Returns
    -------
    numpy.ndarray
        Denoised signal, with the same length as ``signal``.

    Raises
    ------
    ValueError
        Propagated from ``q_power_proximal`` for an unsupported ``reg_power``.
    """
    signal = np.asarray(signal)
    coeffs = pywt.wavedec(signal, wavelet, mode="per")
    # Keep the approximation band untouched; shrink every detail band.
    coeffs[1:] = [q_power_proximal(reg_power, reg_cst, c) for c in coeffs[1:]]
    new_signal = pywt.waverec(coeffs, wavelet, mode="per")
    # With periodization, an odd-length input is reconstructed one sample
    # longer; trim so the output lines up sample-by-sample with the input.
    return new_signal[: len(signal)]

In [None]:
# Grid of proximal settings; one trace per (wavelet, power, weight) combination,
# in the same order as nested loops over the three lists.
wavelet_names = ["db5", "db10", "db30"]
reg_powers = [1.0, 4/3, 3/2]
reg_csts = np.linspace(1e-1, 6e-1, 4)
configs = [(w, p, c) for w in wavelet_names for p in reg_powers for c in reg_csts]

fig = go.Figure(
    layout=go.Layout(
        height=600,
        width=800,
        template="plotly_dark",
        title="Power proximal using Daubechies wavelets",
    )
)

# Size the palette from the grid so that editing the grid cannot index past
# the end of the palette (assumes my_pal(n) returns at least n colours).
pal = my_pal(len(configs))

for k, (wavename, reg_power, reg_cst) in enumerate(configs):
    new_sign = get_prox_denoise(signal, wavelet=wavename, reg_power=reg_power, reg_cst=reg_cst)
    add_fig(fig, new_sign, pal[k], f"{wavename} power {np.round(reg_power, 2)} reg_cst = {np.round(reg_cst, 2)}")
fig.show()

Il ne faut pas prendre une ondelette trop régulière, au risque d'éliminer les pics.
Plus la puissance augmente, moins le débruitage est fort. On remarque que cette approche est moins sensible que le filtrage par seuil.