# Testing our models and index performance on the MIT-BIH Noise Stress Test Database

## Import

In [None]:
import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
import pickle as pkl
import seaborn as sns
import os
import itertools
import sys
import xarray as xr
sys.path.append(os.path.join(os.getcwd(), ".."))
from operations.dataset_manager import get_path_petastorm_format

## Paths

In [None]:
# Folder that is scanned below for the saved models (*.sav files).
models_save_path = "/workspaces/ecg_evaluation/results"

# Dataset location resolved by the project helper (not referenced again in the cells shown here).
data_save_path = get_path_petastorm_format(
    "mit-bih-noise-stress-test-database-1.0.0",
    "ParquetFile",
)

# Output folder (not referenced again in the cells shown here).
save_path = "/workspaces/ecg_evaluation/results"

# Get the dataset

In [None]:
## Collect the saved models: every regular "*.sav" file found in the models folder.
## Only the part of the file name before the first "." is kept.
model_name = [
    filename.split(".")[0]
    for filename in os.listdir(models_save_path)
    if filename.endswith(".sav")
    and os.path.isfile(os.path.join(models_save_path, filename))
]

In [None]:
# Show what was discovered on disk, then pin the list of models to evaluate.
print(model_name)
# Alternative selection that also contains 'lgbm_prob':
# model_name = ['backward_pval_selection', 'hjmi_selection','lgbm_prob', 'L2_reg_logistic']
model_name = [
    'backward_pval_selection',
    'hjmi_selection',
    'L2_reg_logistic',
]

In [None]:
# Names of the quality indexes (kept for reference; not used by the cells shown here).
name_index = [
    "Corr_interlead", "Corr_intralead", "wPMF", "SNRECG",
    "HR", "Flatline", "TSD",
]

# Test the model on new dataset

### Please note that your dataset must be ready and your metrics already computed, using the command described in the git repo.

In [None]:
# NetCDF file holding the pre-computed quality metrics for the noise stress test data.
data_test_path = "/workspaces/ecg_evaluation/results/mit_bih_noise_test_metrics.nc"
# xr.load_dataset reads the whole file into memory and closes it afterwards.
metrics = xr.load_dataset(data_test_path)

In [None]:
## Pull the raw arrays out of the metrics dataset for easier handling
values_metrics = metrics.quality_metrics.values
metrics_name = metrics.metric_name.values.tolist()
id_signals = metrics.id.values
signal = metrics.signal.values
# Segment count taken from the first entry only.
nb_segment = metrics.number_signal.values[0]

## Test each model's performance for each noise type

### We will focus specifically on the noise signals, so let's isolate them:

In [None]:
# Identifiers of the noise records, in their current and "old" versions.
list_noise_new = ["em", "ma", "bw"]
list_noise_old = ["oldem", "oldma", "oldbw"]

# Position of the first matching entry of each noise record in `id_signals`.
index_noise_new = [
    np.nonzero(id_signals == noise_id)[0][0] for noise_id in list_noise_new
]
index_noise_old = [
    np.nonzero(id_signals == noise_id)[0][0] for noise_id in list_noise_old
]

In [None]:
## Three types of noise: for each one, the segments of the new record are
## followed by the segments of the old record along the segment axis.
noise_data = np.zeros(
    (len(list_noise_new), 2 * nb_segment, values_metrics.shape[-1])
)

for j in range(len(list_noise_new)):
    noise_data[j] = np.vstack(
        (
            values_metrics[index_noise_new[j], :, :],
            values_metrics[index_noise_old[j], :, :],
        )
    )


In [None]:
## Quick sanity check: distribution of every metric for each noise type.
plt.rcParams.update({"font.size": 20})
plt.rcParams["legend.fontsize"] = 20
for n in range(len(list_noise_new)):
    # tight_layout()/subplots_adjust() are incompatible with constrained layout
    # (matplotlib warns and switches it off), so only the manual layout is requested.
    fig, axes = plt.subplots(2, 4, figsize=(20, 15))
    fig.tight_layout(pad=5)
    fig.suptitle(f"Histogram value for noise type {list_noise_new[n]}")
    fig.subplots_adjust(top=0.88)
    palette = itertools.cycle(sns.color_palette())

    # One panel per metric, filled row by row. Pairing panels with metrics
    # directly (instead of using len(metrics_name) / 2 metrics per row) keeps
    # the last metric when the number of metrics is odd.
    flat_axes = axes.ravel()
    for k, (ax, metric) in enumerate(zip(flat_axes, metrics_name)):
        # Values of metric k over all segments (new + old record) of this noise.
        trial_rinter = noise_data[n, :, k]
        ax.set_title(f"{metric}")
        ax.grid()
        ax.axvline(trial_rinter.mean(), color='k', linestyle='dashed', linewidth=2, label="mean value")
        ax.axvline(np.quantile(trial_rinter, 0.10), color='red', linestyle='dashed', linewidth=2, label="q10 acceptable")
        ax.axvline(np.quantile(trial_rinter, 0.90), color='blue', linestyle='dashed', linewidth=2, label="q90 acceptable")
        sns.histplot(trial_rinter, ax=ax, color=next(palette))
        plt.setp(ax.get_xticklabels(), rotation=30, horizontalalignment='right')
        # Optional: clamp the x axis for metrics that live in (0, 1).
        # if np.max(trial_rinter) < 1 and np.min(trial_rinter) > 0:
        #     ax.set_xlim([0, 1])

    # Hide the panels that have no metric to display.
    for ax in flat_axes[len(metrics_name):]:
        ax.set_visible(False)

    # All panels share the same three reference lines, so one figure-level legend is enough.
    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper left', bbox_to_anchor=(-0.01, 1.01))

### Load model 

In [None]:
def extract_index_label_mit_noise(ds_data, name_metrics, required_index):
    """Build the feature table and the label column for a block of segments.

    Every row receives the label 1.

    Args:
        ds_data: 2-D array-like of metric values, one row per segment and one
            column per entry of ``name_metrics``.
        name_metrics: Column names, in the same order as the columns of
            ``ds_data``.
        required_index: Names of the columns to keep, in the order they should
            appear, or ``None`` to keep every column.

    Returns:
        Tuple ``(df_X, df_y)``: ``df_X`` is a DataFrame with the selected
        columns, ``df_y`` is a single-column DataFrame (column ``"y"``) of
        integer ones with one row per row of ``ds_data``.

    Raises:
        KeyError: If ``required_index`` contains a name that is not in
            ``name_metrics``.
    """
    df_X = pd.DataFrame(ds_data, columns=name_metrics)
    # Row count is taken from the frame so plain nested lists work as well as arrays.
    df_y = pd.DataFrame(np.ones(df_X.shape[0], dtype=int), columns=["y"])

    if required_index is not None:
        df_X = df_X.loc[:, required_index]

    return df_X, df_y

In [None]:
def SQA_method(data, name_metrics, feature_ex, model_path):
    """Return the class probabilities a saved model predicts for ``data``.

    Args:
        data: 2-D array of metric values, one row per segment and one column
            per entry of ``name_metrics``.
        name_metrics: Names of the columns of ``data``.
        feature_ex: Names of the metrics the model takes as input, in the
            order the model expects them, or ``None`` to use all of them.
        model_path: Path of the pickled model; the loaded object must provide
            ``predict_proba``.

    Returns:
        The output of ``model.predict_proba`` for the selected features.
    """
    # Keep only the features this model uses; the labels are not needed here.
    X, _ = extract_index_label_mit_noise(data, name_metrics, feature_ex)

    # NOTE: unpickling can execute arbitrary code; only load model files you trust.
    # The context manager makes sure the file handle is closed after loading.
    with open(model_path, "rb") as model_file:
        model = pkl.load(model_file)

    return model.predict_proba(X.values)

In [None]:
# Input features of each model, listed in the same order as `model_name`.
list_selection_features = [
    ["Corr_interlead", "HR", "wPMF", "TSD"],
    ['Corr_interlead', 'SNRECG', 'TSD', 'Corr_intralead'],
    # ["Corr_interlead","Corr_intralead","wPMF","SNRECG","HR","Flatline","TSD"],
    ["Corr_interlead", "SNRECG", "HR", "Corr_intralead", "wPMF"],
]

# Models and feature lists are paired by position, so a length mismatch would
# feed a model the wrong inputs; fail early instead.
if len(model_name) != len(list_selection_features):
    raise ValueError(
        f"{len(model_name)} models but {len(list_selection_features)} feature lists"
    )

for n in range(len(list_noise_new)):
    # One panel per model. squeeze=False keeps `axes` an array even for a single
    # model. Constrained layout is not requested because it is incompatible with
    # the tight_layout()/subplots_adjust() calls below (matplotlib warns and
    # switches it off).
    fig, axes = plt.subplots(1, len(model_name), figsize=(20, 15), squeeze=False)
    axes = axes.ravel()
    fig.tight_layout(pad=5)
    fig.suptitle(f"Histogram value for noise type {list_noise_new[n]} (Unacceptable probability)")
    fig.subplots_adjust(top=0.88)
    palette = itertools.cycle(sns.color_palette())

    # Metric values of every segment of this noise type, scored by each model in turn.
    data = noise_data[n, :, :]
    for ax, name, features in zip(axes, model_name, list_selection_features):
        y_p = SQA_method(data, metrics_name, features, os.path.join(models_save_path, name + ".sav"))
        # Second column of predict_proba, i.e. the probability of class index 1.
        proba = y_p[:, 1]

        ax.axvline(proba.mean(), color='k', linestyle='dashed', linewidth=2, label="mean value")
        ax.axvline(np.quantile(proba, 0.10), color='red', linestyle='dashed', linewidth=2, label="q10 acceptable")
        ax.axvline(np.quantile(proba, 0.90), color='blue', linestyle='dashed', linewidth=2, label="q90 acceptable")
        ax.set_title(f"{name}")
        sns.histplot(proba, ax=ax, color=next(palette))
        plt.setp(ax.get_xticklabels(), rotation=30, horizontalalignment='right')

    # All panels share the same three reference lines, so one figure-level legend is enough.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper left', bbox_to_anchor=(-0.01, 1.01))
    
