## Denoising

In [None]:
############
# Packages #
############
import os
import sys
import pickle
import numpy as np
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import seaborn as sns
import matplotlib.pyplot as plt

import pywt
from scipy.signal import periodogram, fftconvolve, cwt

from typing import Dict, Union, List, Tuple, Any, Callable, Optional
pio.renderers.default = "plotly_mimetype+notebook"
################
#    Imports   #
################

# The notebook is expected to be started from the project root.
root_path = Path.cwd()
saving_path = root_path / "outputs"
# Sanity check: the printed path should end with \nonlinear_ICA
print(root_path)
# Put the project root first on the import path so the `src` package can be imported.
sys.path.insert(0, str(root_path))

from src.data import (
    load_ecg_from_clean_data,
    get_npatients_by_diag,
    get_diag
)

from src.plot import (
    plot_all_st,
    plot_signal,
    my_pal,
    add_fig,
    plot_estim,
    plot_scalogram,
    plot_scalogram_freq
)

In [None]:
################################
#  clean meta data  Loading    #
################################
# Loading of the full clean metadata table is currently disabled:
# with open(str(saving_path.joinpath("clean_data.pkl")), "rb") as f:
#     df = pickle.load(f)

# NOTE: pickle.load can execute arbitrary code; only open files produced by this project.
with saving_path.joinpath("signals0.pkl").open("rb") as f:
    signals0 = pickle.load(f)

with saving_path.joinpath("meta_data.pkl").open("rb") as f:
    df_sub = pickle.load(f)

In [None]:
# Use the first entry of `signals0` as the reference signal for the single-signal cells below.
signal = signals0[0]

In [None]:
# Plot the reference signal with the project helper imported from src.plot.
plot_signal(signal)

On a des doublons car un patient peut avoir plusieurs mesures.

In [None]:
# Wavelet scales passed to the scalogram helpers: 1, 9, 17, ... (step 8, strictly below 1024).
scales = np.arange(start=1, stop=1024, step=8)
# Alternative (disabled): dyadic scales
#scales = 2**np.arange(0, 5)

In [None]:
# Scalogram of the unfiltered reference signal (complex Morlet wavelet).
plot_scalogram(
    signal,
    scales,
    waveletname="cmor",
    title="Scalogram of original signal",
    levels=np.linspace(1e-1, 3, 40),
)

In [None]:
# NOTE(review): this call repeats the previous scalogram cell with the same signal, scales,
# wavelet and contour levels — confirm whether the duplicate is intentional.
plot_scalogram(signal, scales, waveletname = 'cmor', title = f"Scalogram of original signal", levels = np.linspace(1e-1, 3, 40))

In [None]:
# Same inputs as the scalogram above, drawn with the frequency-oriented project helper.
plot_scalogram_freq(
    signal,
    scales,
    waveletname="cmor",
    title="Scalogram of original signal",
    levels=np.linspace(1e-1, 3, 40),
)

### Low pass filter

In [None]:
def calculate_psnr(signal, newsignal):
    """Compute the peak signal-to-noise ratio (PSNR), in dB, between two signals.

    Parameters
    ----------
    signal : array_like
        Reference signal; its maximum value is used as the peak.
    newsignal : array_like
        Signal compared against ``signal`` (same shape, or broadcastable).

    Returns
    -------
    float
        ``20 * log10(max(signal) / sqrt(mse))``, or the sentinel value ``100``
        when both signals are identical (zero mean squared error).
    """
    # Convert once so that lists and integer arrays are handled like float arrays.
    signal = np.asarray(signal, dtype=float)
    newsignal = np.asarray(newsignal, dtype=float)
    mse = np.mean((signal - newsignal) ** 2)
    if mse == 0:
        # Identical signals: avoid dividing by zero and return a fixed sentinel.
        return 100
    # np.max (unlike the builtin max) also works for multi-dimensional input.
    return 20 * np.log10(np.max(signal) / np.sqrt(mse))

def compression_ratio(signal, coeffs):
    """Ratio of the number of samples in ``signal`` to the number of non-zero coefficients.

    ``coeffs`` is an iterable of arrays; the non-zero entries of every array
    are counted and summed to form the denominator.
    """
    nonzero_total = sum(np.sum(band != 0) for band in coeffs)
    return signal.shape[0] / nonzero_total

def round(arg):
    """Round numeric input to one decimal with ``np.round``; strings are returned unchanged.

    Note: this definition shadows Python's builtin ``round`` in the notebook.
    """
    return arg if isinstance(arg, str) else np.round(arg, 1)

In [None]:
def my_threshold(coeff, arg0, arg1):
    """Threshold ``coeff`` with ``pywt.threshold``.

    ``arg0`` is forwarded as the threshold value and ``arg1`` as the
    thresholding mode (the notebook uses ``"soft"`` and ``"hard"``).
    """
    threshold_value, threshold_mode = arg0, arg1
    return pywt.threshold(coeff, value=threshold_value, mode=threshold_mode)

In [None]:
def denoise_and_metrics(sign, wavelet, args_filter, f_filter = my_threshold, fig = None, color = None ):
    """Denoise ``sign`` by filtering its wavelet detail coefficients and report metrics.

    Parameters
    ----------
    sign : numpy.ndarray
        1-D signal to denoise.
    wavelet : str
        Wavelet name passed to ``pywt.wavedec`` / ``pywt.waverec``.
    args_filter : sequence
        Extra positional arguments forwarded to ``f_filter`` after the coefficients.
    f_filter : callable
        Called as ``f_filter(coeff, *args_filter)`` on every detail band.
    fig : optional
        When given, the reconstructed signal is added to it through ``add_fig``.
    color : optional
        Color forwarded to ``add_fig``.

    Returns
    -------
    tuple
        ``(psnr, compression_ratio)`` of the reconstruction.
    """
    # decomposition
    coeffs = pywt.wavedec(sign, wavelet, mode="per")
    # filtering: the approximation band (index 0) is left untouched
    coeffs[1:] = [f_filter(coeff, *args_filter) for coeff in coeffs[1:]]
    # reconstruction
    new_signal = pywt.waverec(coeffs, wavelet, mode="per")
    # With periodization the reconstruction can be longer than an odd-length
    # input; trim so both signals can be compared sample by sample.
    new_signal = new_signal[: len(sign)]
    # metrics: the original signal is the reference (first argument) so the
    # peak value is taken from it rather than from the reconstruction.
    psnr = calculate_psnr(sign, new_signal)
    cratio = compression_ratio(sign, coeffs)
    if fig is not None:
        args_filter_round = tuple(round(argf) for argf in args_filter)
        # Label with the `wavelet` parameter; the previous code read an
        # unrelated global named `wavename`.
        add_fig(fig, new_signal, color, f"{wavelet} {args_filter_round} | {round(psnr)} {round(cratio)}")
    return psnr, cratio

In [None]:
from itertools import product

In [None]:
def plot_wave_rec(lst_args, title = f"Low pass", f_filter = my_threshold):
    """Plot the denoised version of the notebook-level ``signal`` for several settings.

    Parameters
    ----------
    lst_args : iterable of sequences
        Each item is ``(wavelet_name, *filter_args)``; the filter arguments are
        forwarded to ``f_filter`` through ``denoise_and_metrics``.
    title : str
        Figure title.
    f_filter : callable
        Coefficient filter applied to every detail band.

    Notes
    -----
    Reads the global variable ``signal`` and displays the figure; returns nothing.
    """
    # Materialize once: taking the length of an iterator would otherwise
    # exhaust it before the loop below runs.
    lst_args = list(lst_args)
    pal = my_pal(len(lst_args))
    fig = go.Figure(
        layout=go.Layout(
            height=600,
            width=800,
            template = "plotly_dark",
            title = title
    ))
    for k, wave_args in enumerate(lst_args):
        wavename = wave_args[0]
        args = wave_args[1:]
        # Forward the caller's filter instead of always using my_threshold.
        denoise_and_metrics(signal, wavename, args, f_filter, fig = fig, color = pal[k])
    fig.show()

In [None]:
# Cartesian grid of settings: 5 wavelets x 2 threshold values (0.3 and 0.5) x 2 modes = 20 runs.
# Each tuple is (wavelet name, threshold value, threshold mode).
lst_args = list(product(
    ["db4", "db6","coif2","coif3", "sym4"],
    np.linspace(0.3, 0.5, 2),
    ["soft", "hard"]
))
# Draw every reconstruction of the reference signal on one figure.
plot_wave_rec(lst_args, title = f"<b>Low pass</b> <br> threshold mode | pnsr compression-ratio", f_filter = my_threshold)

On peut appliquer un seuil aux coefficients de la représentation en ondelettes pour éliminer le bruit.

In [None]:
def lowpassfilter(signal, thresh = 0.63, wavelet="db4", mode = "soft"):
    """Denoise ``signal`` by thresholding its wavelet detail coefficients.

    ``thresh`` is a fraction of the largest non-NaN sample of ``signal``; the
    resulting absolute value is applied with ``pywt.threshold`` (in the given
    ``mode``) to every band except the approximation band, and the signal is
    rebuilt from the modified coefficients.
    """
    abs_thresh = thresh * np.nanmax(signal)
    bands = pywt.wavedec(signal, wavelet, mode="per")
    approx, details = bands[0], bands[1:]
    # modify the representation: shrink the detail bands only
    shrunk = [pywt.threshold(band, value=abs_thresh, mode=mode) for band in details]
    return pywt.waverec([approx] + shrunk, wavelet, mode="per")

In [None]:
fig = go.Figure(
    layout=go.Layout(
        height=600,
        width=800,
        template="plotly_dark",
        title="Low pass using Daubechy",
    )
)

# One color per (wavelet, threshold, mode) combination: 5 x 2 x 2 = 20.
pal = my_pal(20)

k = 0
for wavename, thresh, mode in product(
    ["db4", "db6", "coif2", "coif3", "sym4"],
    np.linspace(0.3, 0.5, 2),
    ["soft", "hard"],
):
    denoise_and_metrics(signal, wavename, (thresh, mode), my_threshold, fig=fig, color=pal[k])
    k += 1
fig.show()

In [None]:
def iter_filter(signals, thresh, wavename, mode):
    """Apply ``lowpassfilter`` to each row of ``signals`` and return the filtered rows.

    The output is allocated with ``np.zeros_like(signals)``, so it has the same
    shape and dtype as the input.
    """
    filtered = np.zeros_like(signals)
    for row, raw in enumerate(signals):
        filtered[row, :] = lowpassfilter(raw, thresh=thresh, wavelet=wavename, mode=mode)
    return filtered

In [None]:
# Original signals first, then the same signals after thresholding
# (relative threshold 0.3, "coif2" wavelet, soft mode) for visual comparison.
plot_all_st(signals0)
plot_all_st(iter_filter(signals0, 0.3, "coif2", "soft"))

La régularité de l'ondelette se retrouve dans le signal filtré.

## Débruitage par méthode proximale

In [None]:
def power_proximal(q, χ, ξ):
    """
    Proximal operator of the power-q function scaled by χ, evaluated at ξ.

    Parameters
    ----------
    q : float
        Exponent of the power function.
        Supported values: 1, 4/3, 3/2, 2, 3, 4

    χ : float
        Regularization parameter (scale of the power function)

    ξ : float or numpy.array
        Point(s) at which the proximal operator is evaluated; the
        closed-form expressions below are applied element-wise.

    Returns
    ---------
    prox : float or numpy.array
        Value of the proximal operator at ξ

    Raises
    ------
    ValueError
        If q is not one of the supported exponents.
    """
    supported_q = [1.0, 4/3, 3/2, 2., 3., 4.]

    if q not in supported_q:
        raise ValueError(f"q ({q}) does not belong in the expected "
                         f"values: {supported_q}")

    # Each branch is the closed-form proximal operator for one exponent.
    if q == 1.0:
        # soft-thresholding
        return np.sign(ξ) * np.maximum(np.abs(ξ) - χ, 0)

    if q == 4/3:
        ϵ = np.sqrt(ξ**2 + 256 / 729 * χ ** 3)
        return ξ + 4 * χ / (3 * 2 ** (1/3)) * ((ϵ - ξ) ** (1/3) - (ϵ + ξ) ** (1/3))

    if q == 3/2:
        return ξ + 9 * χ ** 2 * np.sign(ξ) / 8 * (1 - np.sqrt(1 + 16 * np.abs(ξ) / (9 * χ ** 2)))

    if q == 2:
        # plain linear shrinkage
        return ξ / (1 + 2 * χ)

    if q == 3.:
        return np.sign(ξ) * (np.sqrt(1 + 12 * χ * np.abs(ξ)) - 1) / (6 * χ)

    # Only q == 4 remains after the membership check above.
    ϵ = np.sqrt(ξ ** 2 + 1 / (27 * χ))
    return ((ϵ + ξ) / (8 * χ)) ** (1/3) - ((ϵ - ξ) / (8 * χ)) ** (1/3)

In [None]:
def get_prox_denoise(signal, wavelet="db4", reg_power = 1, reg_cst = 10):
    """Denoise ``signal`` by applying ``power_proximal`` to its wavelet detail coefficients.

    Parameters
    ----------
    signal : array_like
        1-D signal to denoise.
    wavelet : str
        Wavelet name used for decomposition and reconstruction.
    reg_power : float
        Exponent ``q`` passed to ``power_proximal``.
    reg_cst : float
        Regularization constant passed to ``power_proximal``.

    Returns
    -------
    numpy.ndarray
        Signal reconstructed from the modified coefficients.
    """
    coeffs = pywt.wavedec(signal, wavelet, mode="per")
    # Modify the representation: the approximation band (index 0) is kept,
    # every detail band goes through the proximal operator. The helper defined
    # in this notebook is `power_proximal` (the former `q_power_proximal` name
    # is not defined in this notebook).
    coeffs[1:] = [power_proximal(reg_power, reg_cst, coeff) for coeff in coeffs[1:]]
    new_signal = pywt.waverec(coeffs, wavelet, mode="per")
    return new_signal

In [None]:
fig = go.Figure(
    layout=go.Layout(
        height=600,
        width=800,
        template="plotly_dark",
        title="Power proximal using Daubechy",
    )
)

# One color per (wavelet, exponent, constant) combination: 4 x 2 x 4 = 32.
pal = my_pal(32)

k = 0
for wavename, reg_power, reg_cst in product(
    ["coif3", "coif4", "sym4", "sym3"],
    [1.0, 4/3],
    np.linspace(5e-1, 8e-1, 4),
):
    new_sign = get_prox_denoise(signal, wavelet=wavename, reg_power=reg_power, reg_cst=reg_cst)
    add_fig(fig, new_sign, pal[k], f"{wavename} q={np.round(reg_power,2)} beta={np.round(reg_cst,2)} | ")
    # Optional per-reconstruction scalograms (disabled):
    #plot_scalogram(new_sign, scales, waveletname = 'cmor', title = f"{wavename} power {np.round(reg_power,2)} reg_cst = {np.round(reg_cst,2)}")
    #plot_scalogram_freq(new_sign, scales, waveletname = 'cmor', title = f"{wavename} power {np.round(reg_power,2)} reg_cst = {np.round(reg_cst,2)}")
    k += 1
fig.show()

In [None]:
def iter_prox(signals, reg_power, reg_cst, wavename):
    """Apply ``get_prox_denoise`` to each row of ``signals`` and return the denoised rows.

    The output is allocated with ``np.zeros_like(signals)``, so it has the same
    shape and dtype as the input.
    """
    denoised = np.zeros_like(signals)
    for row, raw in enumerate(signals):
        denoised[row, :] = get_prox_denoise(raw, wavelet=wavename, reg_power=reg_power, reg_cst=reg_cst)
    return denoised

In [None]:
# Original signals, then proximal denoising with exponent 4/3, constant 0.65 and the "sym4" wavelet.
plot_all_st(signals0)
plot_all_st(iter_prox(signals0, 4/3, 0.65, "sym4"))

In [None]:
# Original signals, then proximal denoising with exponent 4/3, constant 0.65 and the "coif3" wavelet.
plot_all_st(signals0)
plot_all_st(iter_prox(signals0, 4/3, 0.65, "coif3"))

In [None]:
# Original signals, then proximal denoising with exponent 1, constant 0.5 and the "sym4" wavelet.
plot_all_st(signals0)
plot_all_st(iter_prox(signals0, 1, 0.5, "sym4"))

In [None]:
fig = go.Figure(
    layout=go.Layout(
        height=600,
        width=800,
        template="plotly_dark",
        title="Power proximal using Daubechy",
    )
)

# One color per (wavelet, exponent, constant) combination: 3 x 3 x 5 = 45.
pal = my_pal(45)

k = 0
for wavename, reg_power, reg_cst in product(
    ["coif2", "coif3", "sym4"],
    [1.0, 4/3, 3/2],
    np.linspace(2e-1, 8e-1, 5),
):
    new_sign = get_prox_denoise(signal, wavelet=wavename, reg_power=reg_power, reg_cst=reg_cst)
    add_fig(fig, new_sign, pal[k], f"{wavename} power {np.round(reg_power,2)} reg_cst = {np.round(reg_cst,2)}")
    # Optional per-reconstruction scalograms (disabled):
    #plot_scalogram(new_sign, scales, waveletname = 'cmor', title = f"{wavename} power {np.round(reg_power,2)} reg_cst = {np.round(reg_cst,2)}")
    #plot_scalogram_freq(new_sign, scales, waveletname = 'cmor', title = f"{wavename} power {np.round(reg_power,2)} reg_cst = {np.round(reg_cst,2)}")
    k += 1
fig.show()

Il ne faut pas prendre une ondelette trop régulière, au risque d'éliminer les pics.
Plus la puissance augmente, moins le débruitage est fort. On remarque que cette approche est moins sensible que le filtrage par seuil.

## Test classifier

In [None]:
from collections import Counter
# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that `scipy.stats` is loaded on every SciPy version.
import scipy.stats

def calculate_entropy(list_values):
    """Shannon entropy of the empirical distribution of ``list_values``.

    Each distinct value is counted, the counts are turned into relative
    frequencies and passed to ``scipy.stats.entropy``.
    """
    counter_values = Counter(list_values).most_common()
    n_values = len(list_values)
    probabilities = [count / n_values for _, count in counter_values]
    return scipy.stats.entropy(probabilities)

 
def calculate_statistics(list_values):
    """NaN-aware summary statistics of ``list_values``.

    Returns
    -------
    list
        ``[p5, p25, p75, p95, median, mean, std, var, rms]`` where ``rms`` is
        the root mean square ``sqrt(mean(x**2))``.
    """
    # Accept lists as well as arrays (``**`` is not defined for lists).
    values = np.asarray(list_values, dtype=float)
    n5, n25, median, n75, n95 = np.nanpercentile(values, [5, 25, 50, 75, 95])
    mean = np.nanmean(values)
    std = np.nanstd(values)
    var = np.nanvar(values)
    # Root mean square: the square root is taken AFTER averaging the squares
    # (averaging sqrt(x**2) would give the mean absolute value instead).
    rms = np.sqrt(np.nanmean(values ** 2))
    return [n5, n25, n75, n95, median, mean, std, var, rms]
 
# def calculate_crossings(list_values):
#     zero_crossing_indices = np.nonzero(np.diff(np.array(list_values), 0))[0]
#     no_zero_crossings = len(zero_crossing_indices)
#     mean_crossing_indices = np.nonzero(np.diff(np.array(list_values) , np.nanmean(list_values)))[0]
#     no_mean_crossings = len(mean_crossing_indices)
#     return [no_zero_crossings, no_mean_crossings]
 
def get_features(list_values):
    """Feature vector of one coefficient band: entropy followed by summary statistics.

    Returns ``[entropy] + calculate_statistics(list_values)``.
    """
    # The former debug print of the band length was removed: this function is
    # called once per band and per signal, which flooded the cell output.
    entropy = calculate_entropy(list_values)
    # crossings = calculate_crossings(list_values)
    statistics = calculate_statistics(list_values)
    # return [entropy] + crossings + statistics
    return [entropy] + statistics

def get_ecg_features(ecg_data, ecg_labels, waveletname):
    """Build wavelet-based feature vectors and integer labels for a set of signals.

    Parameters
    ----------
    ecg_data : iterable of 1-D arrays
        Signals to decompose with ``pywt.wavedec``.
    ecg_labels : iterable
        One label per signal; labels must be hashable.
    waveletname : str
        Wavelet name passed to ``pywt.wavedec``.

    Returns
    -------
    tuple
        ``(list_features, list_labels)``: one concatenated feature list per
        signal (``get_features`` of every coefficient band) and one integer
        code per label.
    """
    # Materialize once so the same label objects are used to build and to query the mapping.
    labels = list(ecg_labels)
    # Codes follow first appearance, which is reproducible across runs
    # (iteration order of a set of strings is not) and gives O(1) lookups.
    label_to_index = {label: idx for idx, label in enumerate(dict.fromkeys(labels))}
    list_labels = [label_to_index[label] for label in labels]

    list_features = []
    for sig in ecg_data:
        features = []
        for coeff in pywt.wavedec(sig, waveletname):
            features += get_features(np.array(coeff))
        list_features.append(features)
    return list_features, list_labels

In [None]:
# `df` is not defined by any earlier cell (its load is commented out in the
# data-loading cell), so load the clean metadata table here to keep the
# notebook runnable from a fresh kernel.
# NOTE: pickle.load can execute arbitrary code; only open files produced by this project.
with open(saving_path.joinpath("clean_data.pkl"), "rb") as f:
    df = pickle.load(f)

patient_ids = get_npatients_by_diag(df, npatients=100)
signals, patients = load_ecg_from_clean_data(df, root_path, patient_ids = patient_ids.tolist())
# Keep column 0 of every recording. NOTE: this rebinds `signals0`, replacing
# the array loaded from signals0.pkl at the top of the notebook.
signals0 = np.array([sig[:,0] for sig in signals])

In [None]:
# Check the shape of the stacked array built in the previous cell.
signals0.shape

In [None]:
# Diagnosis labels: the `diag` column of every `df` row whose patient_id is in `patients`.
# NOTE(review): rows come back in `df` order, and the notebook mentions that a patient can
# have several measurements, so `labels` may not be aligned with (or as long as) `signals0` — confirm.
labels = df.loc[df["patient_id"].isin(patients),"diag"]
# Wavelet features ("db5" decomposition) and integer-encoded labels.
list_features, list_labels = get_ecg_features(signals0, labels, "db5")

In [None]:
# The two counts must match for train_test_split in the next cell to accept them.
print(len(list_features), len(list_labels))

In [None]:
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
# Hold out 33% of the samples for testing; the fixed random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split( list_features, list_labels, test_size=0.33, random_state=42)


In [None]:
# Seed the estimator (same seed as the train/test split) so the fitted model
# and both scores are reproducible across kernel restarts.
cls = GradientBoostingClassifier(n_estimators=20, random_state=42)
cls.fit(X_train, y_train)
# `score` returns the mean accuracy on the given samples.
train_score = cls.score(X_train, y_train)
test_score = cls.score(X_test, y_test)
print(f"Train Score for the ECG dataset is about: {np.round(train_score,2)}")
print(f"Test Score for the ECG dataset is about: {np.round(test_score,2)}")
