# Testing our models and index performance on the MIT-BIH Noise Stress Test Database

## Import

In [None]:
import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
import pickle as pkl
import seaborn as sns
import os
import itertools
import sys
import xarray as xr
sys.path.append(os.path.join(os.getcwd(), ".."))
from operations.dataset_manager import get_path_petastorm_format

In [None]:
# Directory that holds the pickled model files (*.sav) loaded further down.
models_save_path = "/workspaces/ecg_evaluation/results"

# Get the dataset

In [None]:
# File names of the saved models to evaluate; each one is joined with
# models_save_path before loading. The order matters: entry j is paired with
# entry j of the per-model feature selection used when predicting.
model_name = ['backward_pval_selection.sav', 'hjmi_selection.sav','L2_reg_logistic.sav']
# Alternative model selections kept for reference:
#model_name = ['backward_pval_selection.sav', 'hjmi_selection.sav',"lgbm_prob.sav", 'L2_reg_logistic.sav']
#model_name = ['backward_pval_selection', 'hjmi_selection', 'L2_reg_logistic']

In [None]:
# Names of the signal-quality indices.
name_index = [
    "Corr_interlead", "Corr_intralead", "wPMF", "SNRECG",
    "HR", "Flatline", "TSD",
]

# Test the model on new dataset

### Please note that your dataset must be ready and your metrics already computed, using the command described in the git repo.

In [None]:
# Pre-computed quality metrics for the noise stress test recordings
# (presumably a netCDF file, judging by the .nc extension), loaded as an
# xarray Dataset.
data_test_path = "/workspaces/ecg_evaluation/results/mit_bih_noise_test_metrics.nc"
metrics = xr.load_dataset(data_test_path)

In [None]:
## check data metrics 
id_signals = metrics.id.values
metrics_name = metrics.metric_name.values.tolist()
values_metrics = metrics.quality_metrics.values
signal = metrics.signal.values

In [None]:
def index_containing_substring(the_list, substring):
    """Return the positions of the items of ``the_list`` that contain ``substring``.

    Positions are returned in increasing order; an empty list is returned
    when no item matches.
    """
    return [position for position, item in enumerate(the_list) if substring in item]

## Test each model's performance for each noise type

### We will focus specifically on the noise signals, so let's isolate them:

In [None]:
list_noise_new = ["em", "ma", "bw"]
list_noise_old = ["oldem", "oldma", "oldbw"]


def _indices_excluding(indices, excluded):
    """Return the entries of ``indices`` that are not in ``excluded``, order kept."""
    excluded = set(excluded)
    return [idx for idx in indices if idx not in excluded]


# Every "old*" identifier also contains the matching new tag ("oldem" contains
# "em"), so a substring search for the new tag returns both versions. The old
# matches are removed explicitly rather than by slicing off the first
# len(old) hits: slicing is only correct when all old identifiers come before
# the new ones in id_signals, which depends on how the dataset was ordered.
list_index_oldem = index_containing_substring(id_signals, list_noise_old[0])
list_index_em = _indices_excluding(
    index_containing_substring(id_signals, list_noise_new[0]), list_index_oldem
)

list_index_oldma = index_containing_substring(id_signals, list_noise_old[1])
list_index_ma = _indices_excluding(
    index_containing_substring(id_signals, list_noise_new[1]), list_index_oldma
)

list_index_oldbw = index_containing_substring(id_signals, list_noise_old[2])
list_index_bw = _indices_excluding(
    index_containing_substring(id_signals, list_noise_new[2]), list_index_oldbw
)

In [None]:
# Lookup table: noise tag -> positions of the matching signals in id_signals.
dico_index_noise = dict(
    em=list_index_em,
    ma=list_index_ma,
    bw=list_index_bw,
    oldem=list_index_oldem,
    oldma=list_index_oldma,
    oldbw=list_index_oldbw,
)

In [None]:
## 3 type of noise. We will reunite the old and new version
noise_data = np.zeros([len(list_noise_new),len(list_index_bw)*2,values_metrics.shape[1],values_metrics.shape[-1]])

for i in range(noise_data.shape[0]):
    ind_new = dico_index_noise[list_noise_new[i]]
    ind_old = dico_index_noise[list_noise_old[i]]
    noise_data[i,:,:,:] = np.concatenate((values_metrics[ind_new,:,:],values_metrics[ind_old,:,:]),axis = 0)

print(noise_data.shape)


In [None]:
## Small check concerning the values performance for each noise.
plt.rcParams.update({"font.size": 20})
plt.rcParams["legend.fontsize"] = 20
fig,axes = plt.subplots(2, 4, figsize=(20, 15),constrained_layout = True)
fig.tight_layout(pad=5)
fig.suptitle(f"Histogram value for each type of noise")
fig.subplots_adjust(top=0.88)
## take the average result obtained over all the segment of all datasets.

for i in range(2):
    for j in range(int(len(metrics_name)/2)):
        palette = itertools.cycle(sns.color_palette())
        for n in range(len(list_noise_new)):
            trial_rinter = noise_data[n,:,:,i*(int(len(metrics_name)/2))+j]
            axes[i,j].set_title(f"{metrics_name[i*int(len(metrics_name)/2)+j]}")
            axes[i,j].grid()
            sns.histplot(trial_rinter.reshape(-1),ax=axes[i,j],color=next(palette),label = list_noise_new[n],alpha = 0.5)
            plt.setp(axes[i,j].get_xticklabels(), rotation=30, horizontalalignment='right')
handles, labels = axes[0,0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper left',bbox_to_anchor=(-0.01, 1.01))

### Load model 

In [None]:
def extract_index_label_mit_noise(ds_data, name_metrics, required_index):
    """Build the feature table and an all-ones label column for a block of metrics.

    Args:
        ds_data: 2-D array-like with one row per sample and one column per
            entry of ``name_metrics``.
        name_metrics: Column names assigned to ``ds_data``.
        required_index: Names of the columns to keep, in the order they should
            appear, or ``None`` to keep every column.

    Returns:
        A ``(df_X, df_y)`` tuple of DataFrames. ``df_X`` holds the selected
        metric columns; ``df_y`` has a single integer column ``"y"`` filled
        with ones, one row per sample.
    """
    df_X = pd.DataFrame(ds_data, columns=name_metrics)
    if required_index is not None:
        df_X = df_X.loc[:, required_index]

    # Every sample gets label 1; the row count is taken from the frame so that
    # plain nested lists are accepted as well as arrays.
    df_y = pd.DataFrame(np.ones(len(df_X), dtype=int), columns=["y"])
    return df_X, df_y

In [None]:
def SQA_method(data, name_metrics, feature_ex, model_path, model_name):
    """Return the class probabilities predicted by a saved model for ``data``.

    Args:
        data: 2-D array of metric values, one row per sample and one column
            per entry of ``name_metrics``.
        name_metrics: Names of the columns of ``data``.
        feature_ex: Names of the metrics the model expects, in the order it
            expects them, or ``None`` to pass every column.
        model_path: Path of the pickled model file.
        model_name: Unused; kept so existing callers keep working.

    Returns:
        The array returned by the model's ``predict_proba`` for the selected
        feature columns.
    """
    ## keep only the features the model expects
    X, _ = extract_index_label_mit_noise(data, name_metrics, feature_ex)
    # NOTE(security): unpickling can execute arbitrary code; only load model
    # files that come from a trusted source.
    with open(model_path, "rb") as model_file:
        model = pkl.load(model_file)
    return model.predict_proba(X.values)

In [None]:
# Features expected by each model, listed in the same order as ``model_name``
# (entry j here is used with model_name[j]).
list_selection_features = [["Corr_interlead", "HR", "wPMF", "TSD"],
                           ['Corr_interlead', 'SNRECG', 'TSD', 'Corr_intralead'],
                           #["Corr_interlead","Corr_intralead","wPMF","SNRECG","HR","Flatline","TSD"],
                           ["Corr_interlead", "SNRECG", "HR", "Corr_intralead", "wPMF"]]


# One subplot per model; atleast_1d keeps axes[j] valid with a single model.
fig, axes = plt.subplots(1, len(model_name), figsize=(20, 15), constrained_layout=True)
axes = np.atleast_1d(axes)
fig.tight_layout(pad=5)
fig.suptitle("Histogram value for each noise type ('Unacceptable' class probability)")
fig.subplots_adjust(top=0.88)


for j, saved_model in enumerate(model_name):
    # Restart the colour cycle so each noise type keeps the same colour on
    # every subplot.
    palette = itertools.cycle(sns.color_palette())
    model_file = os.path.join(models_save_path, saved_model)
    axes[j].set_title(f"{saved_model.split('.')[0]}")
    for n, noise_name in enumerate(list_noise_new):
        data = noise_data[n, :, :, :]
        # Flatten the leading two axes into one sample axis and predict in a
        # single call. Row order is the same as iterating over the first axis
        # and appending, but the model is unpickled once instead of once per
        # signal and the result is not rebuilt with repeated np.append.
        samples = data.reshape(-1, data.shape[-1])
        # Column 1 of predict_proba is the probability plotted here.
        Y_P = SQA_method(
            samples, metrics_name, list_selection_features[j], model_file, saved_model
        )[:, 1]
        sns.histplot(Y_P.reshape(-1), ax=axes[j], color=next(palette), label=noise_name, alpha=0.5)
    plt.setp(axes[j].get_xticklabels(), rotation=30, horizontalalignment='right')

# A single shared legend, added once after all subplots are drawn (adding it
# inside the model loop stacked one legend per model on top of each other).
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper left', bbox_to_anchor=(-0.01, 1.01))
    
