In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from joblib import load
import gradio as gr
import os
import warnings
warnings.filterwarnings('ignore')

# Define paths
models_path = '../models/'
processed_data_path = '../data/processed/'

# Load the trained Neural Network model
model = load_model(os.path.join(models_path, 'neural_net_tuned.h5'))

# Load the fitted preprocessor (e.g., ColumnTransformer)
preprocessor = load(os.path.join(processed_data_path, 'preprocessor.joblib'))
output_feature_names = preprocessor.get_feature_names_out()  # 48 output features (per original note; verify)


def _resolve_input_features(transformer, default_count=22):
    """Return the column names ``transformer.transform`` expects as input.

    Fitted scikit-learn estimators store these in ``feature_names_in_`` when they
    were fitted on a DataFrame with string column names. When that attribute is
    absent only the column count is available (``n_features_in_``, else
    ``default_count``), so generic positional names ``feature_<i>`` are returned
    and a warning is printed.
    """
    names = getattr(transformer, 'feature_names_in_', None)
    if names is not None:
        return [str(name) for name in names]
    count = getattr(transformer, 'n_features_in_', default_count)
    print(f"Warning: Could not access input features directly. Assuming {count} features. Please verify.")
    return [f"feature_{i}" for i in range(count)]


# Input (pre-transform) column names, used by predict_diabetes to lay out the form values.
input_features = _resolve_input_features(preprocessor)
print(f"Expected input features: {input_features}")  # Debug: Print to verify

# Map the model's output class index (argmax position) to a human-readable label.
label_map = dict(enumerate(('No Diabetes', 'Yes, no complications', 'Yes, with complications')))

# Create a function to make predictions with input validation
def predict_diabetes(num__MentHlth, num__PhysHlth, cat__Sex_1_0, cat__AgeGroup, cat__Income, cat__Education_,
                     cat__Asthma, cat__COPD, num__BMI, cat__HighBP, cat__Smoker, cat__HighChol,
                     cat__CholCheck, cat__Stroke, cat__PhysActivity, cat__HeartDiseaseorAttack,
                     cat__HvyAlcoholConsump, cat__NoDocbcCost, cat__DiffWalk, num__GenHlth, cat__Diabetes, cat__AnyHealthcare):
    # Validate inputs
    if not (0 <= num__MentHlth <= 30):
        return "Error: Mental Health days must be between 0 and 30.", 0.0
    if not (0 <= num__PhysHlth <= 30):
        return "Error: Physical Health days must be between 0 and 30.", 0.0
    if cat__Sex_1_0 not in [0, 1]:
        return "Error: Sex must be 0 (female) or 1 (male).", 0.0
    if not (0 <= cat__AgeGroup <= 80):
        return "Error: Age must be between 0 and 80.", 0.0
    if not (1 <= cat__Income <= 11):
        return "Error: Income must be between 1 and 11.", 0.0
    if not (1 <= cat__Education_ <= 6):
        return "Error: Education level must be between 1 and 6.", 0.0
    if not (0 <= cat__Asthma <= 3):
        return "Error: Asthma level must be between 0 and 3.", 0.0
    if not (0 <= cat__COPD <= 3):
        return "Error: COPD level must be between 0 and 3.", 0.0
    if not (0 <= num__BMI <= 1):
        return "Error: BMI must be between 0 and 1 (normalized).", 0.0
    for cat_feature, name in [(cat__HighBP, 'cat__HighBP'), (cat__Smoker, 'cat__Smoker'),
                             (cat__HighChol, 'cat__HighChol'), (cat__CholCheck, 'cat__CholCheck'),
                             (cat__Stroke, 'cat__Stroke'), (cat__PhysActivity, 'cat__PhysActivity'),
                             (cat__HeartDiseaseorAttack, 'cat__HeartDiseaseorAttack'),
                             (cat__HvyAlcoholConsump, 'cat__HvyAlcoholConsump'),
                             (cat__NoDocbcCost, 'cat__NoDocbcCost'), (cat__DiffWalk, 'cat__DiffWalk'),
                             (cat__Diabetes, 'cat__Diabetes'), (cat__AnyHealthcare, 'cat__AnyHealthcare')]:
        if not (0 <= cat_feature <= 1):
            return f"Error: {name} must be between 0 and 1.", 0.0
    if not (1 <= num__GenHlth <= 5):
        return "Error: General Health must be between 1 and 5.", 0.0

    # Map input features to the expected 22 features
    input_dict = {
        'num__MentHlth': num__MentHlth,
        'num__PhysHlth': num__PhysHlth,
        'cat__Sex_1.0': cat__Sex_1_0,
        'cat__AgeGroup': cat__AgeGroup,
        'cat__Income': cat__Income,
        'cat__Education_': cat__Education_,
        'cat__Asthma': cat__Asthma,
        'cat__COPD': cat__COPD,
        'num__BMI': num__BMI,
        'cat__HighBP': cat__HighBP,
        'cat__Smoker': cat__Smoker,
        'cat__HighChol': cat__HighChol,
        'cat__CholCheck': cat__CholCheck,
        'cat__Stroke': cat__Stroke,
        'cat__PhysActivity': cat__PhysActivity,
        'cat__HeartDiseaseorAttack': cat__HeartDiseaseorAttack,
        'cat__HvyAlcoholConsump': cat__HvyAlcoholConsump,
        'cat__NoDocbcCost': cat__NoDocbcCost,
        'cat__DiffWalk': cat__DiffWalk,
        'num__GenHlth': num__GenHlth,
        'cat__Diabetes': cat__Diabetes,
        'cat__AnyHealthcare': cat__AnyHealthcare
    }

    # Create full feature array with zeros for missing features
    all_features = np.zeros(len(input_features))
    for i, feature in enumerate(input_features):
        if feature in input_dict:
            all_features[i] = input_dict[feature]

    input_data = all_features.reshape(1, -1)

    # Preprocess the input data
    try:
        preprocessed_data = preprocessor.transform(input_data)
    except ValueError as e:
        print(f"Preprocessor error: {e}")
        return "Error: Input features do not match preprocessor expectations.", 0.0

    # Make prediction
    prediction_proba = model.predict(preprocessed_data)
    predicted_class = np.argmax(prediction_proba, axis=1)[0]
    confidence = prediction_proba[0, predicted_class]

    # Map the predicted class to a label
    predicted_label = label_map[predicted_class]

    return predicted_label, confidence

# Build the Gradio interface; inputs are passed to predict_diabetes in the order listed on the button.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Diabetes Classification Prototype")
    gr.Markdown("Enter patient details, health conditions, and other features to predict diabetes status. All inputs are validated. Unselected 0-1 sliders default to 0.")

    with gr.Tab("Patient Details"):
        with gr.Row():
            with gr.Column():
                num__MentHlth = gr.Slider(0, 30, step=1, label="Mental Health Days (0-30)", info="Number of days mental health was not good", value=23)
                num__PhysHlth = gr.Slider(0, 30, step=1, label="Physical Health Days (0-30)", info="Number of days physical health was not good", value=23)
                cat__Sex_1_0 = gr.Dropdown([0, 1], label="Sex", info="0 = Female, 1 = Male", value=1)
            with gr.Column():
                cat__AgeGroup = gr.Slider(0, 80, step=1, label="Age Group (0-80+)", info="Age in years", value=22)
                cat__Income = gr.Slider(1, 11, step=1, label="Income Level (1-11)", info="Income category", value=4)
                cat__Education_ = gr.Slider(1, 6, step=1, label="Education Level (1-6)", info="Education category", value=4)

    with gr.Tab("Health Conditions"):
        with gr.Row():
            cat__Asthma = gr.Slider(0, 3, step=1, label="Asthma Level (0-3)", info="Severity of Asthma", value=1)
            cat__COPD = gr.Slider(0, 3, step=1, label="COPD Level (0-3)", info="Severity of COPD", value=1)

    with gr.Tab("Other Features"):
        with gr.Row():
            with gr.Column():
                # BMI is a continuous normalized value, so it keeps Gradio's automatic fractional step.
                num__BMI = gr.Slider(0, 1, label="BMI (0-1 normalized)", info="Body Mass Index", value=0.5)
                # Yes/no features: step=1 restricts the slider to 0 or 1. Without a step, Gradio
                # chooses a fractional increment and allows values like 0.37.
                cat__HighBP = gr.Slider(0, 1, step=1, label="High Blood Pressure (0-1)", info="Presence of High BP", value=0.0)
                cat__Smoker = gr.Slider(0, 1, step=1, label="Smoker (0-1)", info="Smoking status", value=0.0)
                cat__HighChol = gr.Slider(0, 1, step=1, label="High Cholesterol (0-1)", info="Presence of High Cholesterol", value=0.0)
                cat__CholCheck = gr.Slider(0, 1, step=1, label="Cholesterol Check (0-1)", info="Cholesterol checked", value=0.0)
            with gr.Column():
                cat__Stroke = gr.Slider(0, 1, step=1, label="Stroke (0-1)", info="History of Stroke", value=0.0)
                cat__PhysActivity = gr.Slider(0, 1, step=1, label="Physical Activity (0-1)", info="Physical activity status", value=0.0)
                cat__HeartDiseaseorAttack = gr.Slider(0, 1, step=1, label="Heart Disease/Attack (0-1)", info="Heart condition", value=0.0)
                cat__HvyAlcoholConsump = gr.Slider(0, 1, step=1, label="Heavy Alcohol (0-1)", info="Heavy alcohol consumption", value=0.0)
                cat__NoDocbcCost = gr.Slider(0, 1, step=1, label="No Doctor due to Cost (0-1)", info="No doctor visit due to cost", value=0.0)
                cat__DiffWalk = gr.Slider(0, 1, step=1, label="Difficulty Walking (0-1)", info="Difficulty walking status", value=0.0)
                num__GenHlth = gr.Slider(1, 5, step=1, label="General Health (1-5)", info="Overall health rating", value=3)
                cat__Diabetes = gr.Slider(0, 1, step=1, label="Diabetes (0-1)", info="Diabetes status", value=0.0)
                cat__AnyHealthcare = gr.Slider(0, 1, step=1, label="Any Healthcare (0-1)", info="Healthcare access", value=0.0)

    with gr.Row():
        output_label = gr.Textbox(label="Predicted Label", interactive=False)
        output_conf = gr.Number(label="Confidence Score", precision=2, interactive=False)

    # Order must match predict_diabetes's positional parameters.
    gr.Button("Predict").click(
        fn=predict_diabetes,
        inputs=[num__MentHlth, num__PhysHlth, cat__Sex_1_0, cat__AgeGroup, cat__Income, cat__Education_,
                cat__Asthma, cat__COPD, num__BMI, cat__HighBP, cat__Smoker, cat__HighChol, cat__CholCheck,
                cat__Stroke, cat__PhysActivity, cat__HeartDiseaseorAttack, cat__HvyAlcoholConsump,
                cat__NoDocbcCost, cat__DiffWalk, num__GenHlth, cat__Diabetes, cat__AnyHealthcare],
        outputs=[output_label, output_conf]
    )

# Launch the interface
demo.launch(debug=True)  # Enable debug mode to see detailed errors



Expected input features: ['feature_0', 'feature_1', 'feature_2', 'feature_3', 'feature_4', 'feature_5', 'feature_6', 'feature_7', 'feature_8', 'feature_9', 'feature_10', 'feature_11', 'feature_12', 'feature_13', 'feature_14', 'feature_15', 'feature_16', 'feature_17', 'feature_18', 'feature_19', 'feature_20', 'feature_21']
* Running on local URL:  http://127.0.0.1:7866
* To create a public link, set `share=True` in `launch()`.
