In [34]:
import pickle
from pathlib import Path
import numpy as np
import pandas as pd
# import os
# %matplotlib inline
# import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

from sklearn.metrics import roc_curve, roc_auc_score
from scipy.interpolate import interp1d

import plotly.graph_objects as go

from ipywidgets import widgets, Layout, Label, HBox, VBox, Box
from IPython.display import clear_output, display, Markdown

In [35]:
# Load the trained model together with the stored predicted probabilities and
# true labels that run_all() uses to draw the ROC curve and pick the threshold.
# SECURITY: pickle.load can execute arbitrary code while unpickling -- only
# load these files if they come from a trusted source.
# NOTE: this notebook's saved output shows the LogisticRegression model was
# pickled with scikit-learn 1.0.1; unpickling under another version may break
# or give invalid results.
MODEL_DIR = Path('model_app')


def load_pickle(path):
    """Return the object unpickled from the file at ``path`` (opened in binary mode)."""
    with open(path, 'rb') as f:
        return pickle.load(f)


model = load_pickle(MODEL_DIR / 'clf_rl.pickle')          # trained classifier
probas = load_pickle(MODEL_DIR / 'pred_prob_rl.pickle')   # predicted probabilities
labels = load_pickle(MODEL_DIR / 'true_label_rl.pickle')  # true labels


Trying to unpickle estimator LogisticRegression from version 1.0.1 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations



In [36]:
# Style shared by the labelled form inputs: 'initial' lets each description
# take the width of its own text.
style = dict(description_width='initial')

# Container that run_all() clears and fills with the whole app.
output_all = widgets.Output()
# NOTE(review): run_all() builds its own local `output_classification`, so this
# module-level widget appears to be unused.
output_classification = widgets.Output()


def _build_feature_widgets():
    """Create the patient-data input widgets of the form.

    Returns:
        list: the twelve input widgets, in the exact order in which their
        ``.value``s are passed to ``model.predict_proba``. ``run_all`` lays the
        form out from this same list, so screen order and model column order
        cannot drift apart.
    """
    def checkbox(description, width):
        # Yes/no questions start unchecked and are not indented.
        return widgets.Checkbox(value=False, continuous_update=True, description=description,
                                disabled=False, indent=False, layout=Layout(width=width))

    def dropdown(options, value, description, width):
        return widgets.Dropdown(options=options, value=value, description=description,
                                style=style, disabled=False, layout=Layout(width=width))

    def float_box(value, low, high, step, description, width):
        return widgets.BoundedFloatText(value=value, min=low, max=high, step=step,
                                        description=description, disabled=False,
                                        continuous_update=False, orientation='horizontal',
                                        readout=True, readout_format='.1f',
                                        style=style, layout=Layout(width=width))

    CEV1 = checkbox("¿A veces tiene algún dolor o molestia en el pecho?", '450px')
    LAB3 = float_box(0.85, 0.10, 3, 0.05, "Creatinina", '125px')
    CEV25 = float_box(1.06, 0.20, 3, 0.02, "Índice tobillo brazo derecho/izquierdo", '280px')
    ACF8 = dropdown([('Bajo o inactivo', 1), ('Moderado', 2), ('Alto', 3)], 2,
                    'Nivel de actividad', '230px')
    ESDG8 = dropdown([('Menos de 1000€/mes', 1), ('Otros', 0)], 0,
                     'Ingresos totales familiares', '320px')
    CEV21_1 = dropdown([('Ninguno', 0), ('Diastólico', 1), ('Sistólico', 2),
                        ('Sistólico y diastólico', 3)], 0,
                       'Soplos cardiacos', '270px')
    DIABETES = checkbox("Diabetes", '230px')
    DISLIPEMIA = checkbox("Dislipemia", '230px')
    CEV12 = dropdown([('I. No limitación actividad física', 1),
                      ('II. Ligera limitación de la actividad física', 2),
                      ('III. Marcada limitación de la actividad física', 3),
                      ('IV. Incapacidad para llevar a cabo cualquier actividad física', 4)], 1,
                     'Grado disnea NYHA', '500px')
    FRH31 = checkbox("Enfermedad reumatológica/inmunológica", '300px')
    FRH37 = checkbox("EPOC", '230px')
    FRH33 = checkbox("Historia de cáncer", '230px')

    return [CEV1, LAB3, CEV25, ACF8, ESDG8, CEV21_1, DIABETES,
            DISLIPEMIA, CEV12, FRH31, FRH37, FRH33]


def _build_roc_figure(fpr, tpr, specificity):
    """Build the ROC ``FigureWidget`` with the two risk zones shaded.

    Args:
        fpr, tpr: false / true positive rates returned by ``roc_curve``.
        specificity: specificity at the operating threshold; the boundary
            between the zones is the vertical line x = 1 - specificity.

    Returns:
        go.FigureWidget whose traces are, in order: the no-high-risk zone,
        the high-risk zone and the ROC curve.
    """
    boundary = 1 - specificity
    not_high_zone = go.Scatter(x=[boundary, 1, 1, boundary, boundary], y=[0, 0, 1, 1, 0],
                               fill="toself", mode='none', fillcolor='rgba (103,210,69, 0.5)',
                               name='Zona de no-alto riesgo')
    high_zone = go.Scatter(x=[0, boundary, boundary, 0, 0], y=[0, 0, 1, 1, 0],
                           fill="toself", mode='none', fillcolor='rgba(255,31,31,0.5)',
                           name='Zona de alto riesgo')

    fig = go.FigureWidget(data=[not_high_zone, high_zone])
    fig.add_trace(go.Scatter(x=fpr, y=tpr, mode='lines', name='CURVA ROC',
                             line=dict(color='indigo', width=2)))
    fig.update_layout(title='Curva ROC, grupos de riesgo y resultado del modelo',
                      xaxis_range=[-0.01, 1.01], yaxis_range=[-0.01, 1.01],
                      xaxis_title='1 - Especificidad',
                      yaxis_title='Sensibilidad',
                      width=600, height=400,
                      margin=dict(l=50, r=50, t=50, b=50))
    return fig


def _markdown_banner(*texts):
    """Display each Markdown string, in order, inside a single Output widget.

    Returns:
        HBox: the Output wrapped in a 1240px-wide box (the app's full width).
    """
    output = widgets.Output()
    with output:
        for text in texts:
            display(Markdown(text))
    return HBox([output], layout=Layout(width='1240px', justify_content='space-between'))


def run_all(target_sensitivity=0.6):
    """Build the structural-heart-disease risk app and render it into ``output_all``.

    Args:
        target_sensitivity: sensitivity on the ROC curve of the stored
            validation predictions (``labels`` / ``probas``) at which the
            classification threshold is chosen. Defaults to 0.6, the value
            that was previously hard-coded.

    Side effects:
        * sets the module globals ``sensitivity``, ``specificity``,
          ``threshold``, ``risk_high``, ``n_patients_high``,
          ``risk_moderate`` and ``n_patients_moderate`` for that operating point;
        * clears ``output_all`` and displays the header, input form, results
          panel and footer in it;
        * every click on "Ejecutar el modelo" sets the module global ``punto``
          to the model's predicted probability for the entered patient.
    """
    global sensitivity, specificity, threshold, risk_high, n_patients_high, risk_moderate, n_patients_moderate

    ######################
    # DATA INPUT WIDGETS #
    ######################
    feature_widgets = _build_feature_widgets()

    ######################
    ######  RESULT  ######
    ######################
    output_classification = widgets.Output()
    output_roc = widgets.Output()

    fpr1, tpr1, thres1 = roc_curve(labels, probas)
    thres_low, thres_high = thres1.min(), thres1.max()

    def spe_thres(sens_target):
        """Interpolated (specificity, score threshold) at sensitivity ``sens_target``."""
        spe = 1 - interp1d(tpr1, fpr1)(sens_target)
        thres = interp1d(tpr1, thres1)(sens_target)
        return spe, thres

    def sens_spe(score):
        """Interpolated (sensitivity, specificity) of the ROC curve at threshold ``score``.

        ``score`` is clamped to the range of ROC thresholds first: ``interp1d``
        raises ``ValueError`` outside its x range by default, which previously
        crashed the button callback for a patient scoring below (or above)
        every threshold of the validation set.
        """
        score = np.clip(score, thres_low, thres_high)
        sens = interp1d(thres1, tpr1)(score)
        spe = 1 - interp1d(thres1, fpr1)(score)
        return sens, spe

    def groups_risk(cutoff):
        """Percentages for the two groups split at ``cutoff``.

        Returns (risk_high, n_patients_high, risk_moderate, n_patients_moderate):
        the share (%) of label-1 patients among those scoring above / at-or-below
        ``cutoff``, and the share (%) of all patients falling in each group.
        NOTE(review): assumes binary labels and that both groups are non-empty;
        otherwise the positional crosstab indexing raises.
        """
        preds = pd.Series(probas > cutoff)
        p = pd.crosstab(pd.Series(labels), preds).values
        risk_high = 100*p[1, 1] / (p[1, 1] + p[0, 1])
        n_patients_high = 100*preds.value_counts(normalize=True)[True]
        risk_moderate = 100*p[1, 0] / (p[1, 0] + p[0, 0])
        n_patients_moderate = 100*preds.value_counts(normalize=True)[False]
        return risk_high, n_patients_high, risk_moderate, n_patients_moderate

    sensitivity = target_sensitivity
    specificity, threshold = spe_thres(sensitivity)
    risk_high, n_patients_high, risk_moderate, n_patients_moderate = groups_risk(threshold)

    classification_dict = {1: 'Alto Riesgo', 0: 'No-alto riesgo'}

    ### Creación de la figura
    fig = _build_roc_figure(fpr1, tpr1, specificity)
    # Index the patient marker trace gets once it has been added.
    patient_trace = len(fig.data)

    run = widgets.Button(description='Ejecutar el modelo')

    def run_click(_button):
        """Score the patient entered in the form, show the result and mark it on the ROC plot."""
        global punto  # exposed at module level for inspection after a click
        feature_values = [widget.value for widget in feature_widgets]
        punto = model.predict_proba(np.array(feature_values).reshape(1, -1))[0][1]
        classification = classification_dict[int(punto > threshold)]
        sensitivity_point, specificity_point = sens_spe(punto)
        x_point = [float(1 - specificity_point)]
        y_point = [float(sensitivity_point)]
        if len(fig.data) > patient_trace:
            # The marker exists from an earlier click: just move it.
            fig.data[patient_trace].x = x_point
            fig.data[patient_trace].y = y_point
        else:
            fig.add_trace(go.Scatter(x=x_point, y=y_point,
                                     mode='markers', name='Your point',
                                     marker=dict(color='white', line_width=2, size=10)))
        with output_classification:
            clear_output()
            texto_clasificacion = f'### Resultado del modelo: **{100*punto:0.1f}% de riesgo \
                            de progresión grave**.\n### Paciente clasificado como: **{classification}**'
            display(Markdown(texto_clasificacion))

        with output_roc:
            clear_output()
            display(fig)

    run.on_click(run_click)

    ######################
    ####  HEAD BOARD  ####
    ######################
    # The models' training/validation code lives at
    # https://github.com/PabloPerSa/StructuralHeartDiseasePrediction
    cabecera_final = _markdown_banner('# Riesgo de sufrir cardiopatía estructural \n',
                                      'Aplicación web del riesgo de sufrir cardiopatía estructural.')

    ######################
    ######  FOOTER  ######
    ######################
    pie_text = ('**AVISO: Esta calculadora se ha desarrollado como un PROTOTIPO para la demostración '
                'de las posibilidades de incluir Inteligencia Artificial en la clasificación del riesgo '
                'de sufrir cardiopatía estructural. Esta calculadora está dirigida a científicos de datos '
                'y profesionales de la salud con interés en desarrollar una herramienta funcional de '
                'similares características. Ninguno de los resultados mostrados se puede interpretar '
                'como consejo médico. Si usted cree que puede padecer alguna cardiopatía, consulte con '
                'su cardiólogo o médico de cabecera.**')
    pie_final = _markdown_banner(pie_text)

    #####################
    ######  VOILA  ######
    #####################
    form_box = VBox(children=feature_widgets,
                    layout=Layout(width='520px', justify_content='flex-start'))
    data_box = VBox([form_box, run], layout=Layout(width='520px', justify_content='flex-start'))
    result_box = VBox([output_classification, output_roc],
                      layout=Layout(width='720px', justify_content='center'))
    total_box = VBox([HBox([data_box, result_box]), pie_final],
                     layout=Layout(width='1240px', justify_content='flex-start'))

    with output_all:
        clear_output()
        display(cabecera_final)
        display(total_box)



In [37]:
# Build the app (run_all clears and fills the module-level `output_all`
# container) and then render that container as this cell's output.
run_all()
display(output_all)

Output()