<h1 style="text-align: center;">Stroke Prediction</h1>

<h3 style="text-align: center;">4. Individual Predictions with the ANN Model</h3>

<p style="text-align: center;">Hugo Gálvez</p>

In [1]:
import numpy as np
import pandas as pd
import joblib
from tensorflow.keras.models import load_model

Now that the model is trained, how can we use it to estimate the probability of stroke for a new patient? Let's walk through the process step by step.

First, we define the categorical and numerical variables that the model takes as input. Keeping these lists consistent with the ones used when fitting the preprocessor ensures that the columns passed to it have the expected names and order, which avoids errors during preprocessing and prediction.

In [2]:
categorical_features = [
    'gender', 'hypertension', 'heart_disease', 'ever_married',
    'work_type', 'Residence_type', 'smoking_status'
]
numerical_features = ['age', 'avg_glucose_level', 'bmi']
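To see why these lists matter, the sketch below builds a one-row `DataFrame` from a dictionary whose keys are deliberately out of order (the patient values are hypothetical) and selects its columns with `numerical_features + categorical_features`, the same expression `preprocess_input` uses. The selection imposes a fixed column order, and a misspelled key surfaces immediately as a `KeyError`.

```python
import pandas as pd

categorical_features = [
    'gender', 'hypertension', 'heart_disease', 'ever_married',
    'work_type', 'Residence_type', 'smoking_status'
]
numerical_features = ['age', 'avg_glucose_level', 'bmi']

# Hypothetical patient whose keys are deliberately out of order
person = {
    'smoking_status': 'never smoked', 'bmi': 24.5, 'gender': 'Male',
    'age': 45.0, 'hypertension': False, 'heart_disease': False,
    'ever_married': True, 'work_type': 'Private',
    'Residence_type': 'Rural', 'avg_glucose_level': 85.0,
}

# Selecting with the concatenated list imposes a fixed column order
df = pd.DataFrame([person])[numerical_features + categorical_features]
print(list(df.columns))

# A misspelled key ('residence_type') is caught at selection time
typo = {('residence_type' if k == 'Residence_type' else k): v for k, v in person.items()}
try:
    pd.DataFrame([typo])[numerical_features + categorical_features]
except KeyError as err:
    print('KeyError:', err)
```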

The `preprocess_input` function converts a person's input data into the format expected by the model. It applies the saved preprocessor, which standardizes the numerical variables and encodes the categorical ones, and returns a `NumPy` array that can be passed directly to the model. Wrapping this step in a function makes the pipeline more automated and reusable.

In [3]:
def preprocess_input(person_data, preprocessor):
    """
    Converts a person's input data into the format expected by the model.

    Args:
        person_data (dict): Dictionary containing the person's information,
        with keys like 'age', 'gender', etc.
        preprocessor (ColumnTransformer): Pre-fitted preprocessor object.

    Returns:
        np.array: Transformed data ready for the model.
    """
    # Convert the data into a DataFrame for easier preprocessing
    df = pd.DataFrame([person_data])[numerical_features + categorical_features]

    # Transform the data
    transformed_data = preprocessor.transform(df)
    
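    # The preprocessor may return a SciPy sparse matrix; convert it to a dense NumPy array
    if hasattr(transformed_data, "toarray"):
        transformed_data = transformed_data.toarray()
    return np.asarray(transformed_data)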
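The saved preprocessor itself is not shown in this notebook. Assuming it is a scikit-learn `ColumnTransformer` that applies `StandardScaler` to the numerical variables and `OneHotEncoder` to the categorical ones (consistent with the description above, but still an assumption), the sketch below fits such a transformer on a tiny hypothetical training frame, with a reduced set of categorical columns for brevity, and passes one patient through the same steps as `preprocess_input`. The `toarray()` guard covers the case where `ColumnTransformer` returns a SciPy sparse matrix, which `np.array` does not turn into a dense numeric array.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

numerical_features = ['age', 'avg_glucose_level', 'bmi']
categorical_features = ['gender', 'smoking_status']  # reduced set, for brevity only

# Tiny hypothetical training frame, NOT the real stroke dataset
train = pd.DataFrame({
    'age': [45.0, 72.5, 30.0, 60.0],
    'avg_glucose_level': [85.0, 135.7, 90.2, 110.0],
    'bmi': [24.5, 29.3, 22.1, 27.8],
    'gender': ['Male', 'Female', 'Female', 'Male'],
    'smoking_status': ['never smoked', 'formerly smoked', 'smokes', 'Unknown'],
})

# Assumed structure of the preprocessor: scale numbers, one-hot encode categories
preprocessor = ColumnTransformer([
    ('num', StandardScaler(), numerical_features),
    ('cat', OneHotEncoder(), categorical_features),
])
preprocessor.fit(train)


def preprocess_input(person_data, preprocessor):
    """Same steps as in the notebook, applied to the reduced feature set."""
    df = pd.DataFrame([person_data])[numerical_features + categorical_features]
    transformed = preprocessor.transform(df)
    if hasattr(transformed, 'toarray'):  # densify a possible sparse result
        transformed = transformed.toarray()
    return np.asarray(transformed, dtype=np.float32)


x = preprocess_input(
    {'age': 72.5, 'avg_glucose_level': 135.7, 'bmi': 29.3,
     'gender': 'Female', 'smoking_status': 'formerly smoked'},
    preprocessor,
)
# 3 scaled numeric columns + 2 gender columns + 4 smoking_status columns
print(x.shape)
```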

The `validate_input` function checks that the input data contains all the required keys, both numerical and categorical, so that no critical value is missing before preprocessing starts. If any keys are missing, it raises a descriptive error that names them, which makes debugging easier and the prediction pipeline more robust.

In [4]:
def validate_input(person_data):
    """
    Validates that the input data contains all the required keys.

    Args:
        person_data (dict): Information about the person.

    Raises:
        ValueError: If required keys are missing in the data.
    """
    required_keys = numerical_features + categorical_features
    missing_keys = [key for key in required_keys if key not in person_data]
    if missing_keys:
        raise ValueError(f"Missing required keys: {missing_keys}")
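`validate_input` only reports missing keys, so a misspelled key such as `'residence_type'` shows up as a missing `'Residence_type'` while the typo itself goes unmentioned. A possible stricter variant, sketched below with the hypothetical helpers `find_key_problems` and `validate_input_strict`, also lists unexpected keys and uses the standard-library `difflib.get_close_matches` to suggest the intended name.

```python
import difflib

numerical_features = ['age', 'avg_glucose_level', 'bmi']
categorical_features = [
    'gender', 'hypertension', 'heart_disease', 'ever_married',
    'work_type', 'Residence_type', 'smoking_status'
]


def find_key_problems(person_data):
    """Hypothetical helper: return (missing, unexpected, suggestions)."""
    required = numerical_features + categorical_features
    missing = [key for key in required if key not in person_data]
    unexpected = [key for key in person_data if key not in required]
    # For each unexpected key, suggest the closest of the missing keys
    suggestions = {
        key: difflib.get_close_matches(key, missing, n=1)
        for key in unexpected
    }
    return missing, unexpected, suggestions


def validate_input_strict(person_data):
    """Raise ValueError describing missing and unexpected keys, if any."""
    missing, unexpected, suggestions = find_key_problems(person_data)
    if missing or unexpected:
        raise ValueError(
            f"Missing keys: {missing}; unexpected keys: {unexpected}; "
            f"suggestions: {suggestions}"
        )


person = {
    'gender': 'Female', 'age': 72.5, 'hypertension': True,
    'heart_disease': False, 'ever_married': True, 'work_type': 'Private',
    'residence_type': 'Urban',  # typo: should be 'Residence_type'
    'avg_glucose_level': 135.7, 'bmi': 29.3, 'smoking_status': 'formerly smoked',
}
print(find_key_problems(person))
# -> (['Residence_type'], ['residence_type'], {'residence_type': ['Residence_type']})
```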

The `get_stroke_prediction` function encapsulates the entire prediction flow, from input validation to the stroke probability. It loads the saved model and preprocessor, preprocesses the input data, and runs the prediction; note that both files are reloaded from disk on every call. Returning the probability directly simplifies using the model in practical applications. However, the function gives no information about which features drive the prediction, which limits its interpretability.

In [5]:
def get_stroke_prediction(person_data, model_path, preprocessor_path):
    """
    Obtains the stroke probability for a person using the saved model and preprocessor.

    Args:
        person_data (dict): Person's data.
        model_path (str): Path to the saved model.
        preprocessor_path (str): Path to the saved preprocessor.

    Returns:
        float: Stroke probability.
    """
    # Validate input
    validate_input(person_data)

    # Load model and preprocessor
    model = load_model(model_path)
    preprocessor = joblib.load(preprocessor_path)

    # Preprocess data and predict
    transformed_data = preprocess_input(person_data, preprocessor)
    probability = model.predict(transformed_data)[0][0]
    # Convert the NumPy scalar to a plain Python float, as documented
    return float(probability)
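Because `get_stroke_prediction` reloads both files on every call, scoring many patients this way repeats the slowest step each time. One possible alternative, sketched below as a hypothetical `StrokePredictor` class, loads the artifacts once and scores a list of patients with a single `predict` call. It assumes, as the `[0][0]` indexing suggests, that the model has a single output unit, so `predict` returns an array of shape `(n_samples, 1)`.

```python
import numpy as np
import pandas as pd

numerical_features = ['age', 'avg_glucose_level', 'bmi']
categorical_features = [
    'gender', 'hypertension', 'heart_disease', 'ever_married',
    'work_type', 'Residence_type', 'smoking_status'
]


class StrokePredictor:
    """Hypothetical wrapper: load the artifacts once, then score many patients."""

    def __init__(self, model, preprocessor):
        self.model = model                # any object with .predict(X)
        self.preprocessor = preprocessor  # any object with .transform(df)

    @classmethod
    def from_files(cls, model_path, preprocessor_path):
        # Imported here so the class itself does not require TensorFlow
        import joblib
        from tensorflow.keras.models import load_model
        return cls(load_model(model_path), joblib.load(preprocessor_path))

    def predict_proba(self, people):
        """Return a 1-D array with one stroke probability per patient dict."""
        df = pd.DataFrame(people)[numerical_features + categorical_features]
        X = self.preprocessor.transform(df)
        if hasattr(X, 'toarray'):  # densify a possible sparse result
            X = X.toarray()
        # Single-output model: shape (n_samples, 1) -> take column 0
        return np.asarray(self.model.predict(np.asarray(X)))[:, 0]
```

With the files used in this notebook, usage would look like `StrokePredictor.from_files('best_ann.keras', 'preprocessor.pkl').predict_proba([new_person])`, and the loaded object can then be reused for any number of further calls.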

### Expected Values for the `new_person` Dictionary Variables

| **Variable**         | **Data Type**      | **Possible Values**                                    | **Description**                                     |
|-----------------------|--------------------|-------------------------------------------------------|----------------------------------------------------|
| `gender`             | Categorical        | `'Male'`, `'Female'`                                  | Patient's gender.                                  |
| `age`                | Numeric (float)    | Any positive value (e.g., `45.0`)                     | Patient's age in years.                            |
| `hypertension`       | Boolean            | `True`, `False`                                       | `True` if the patient has hypertension.            |
| `heart_disease`      | Boolean            | `True`, `False`                                       | `True` if the patient has heart disease.           |
| `ever_married`       | Boolean            | `True`, `False`                                       | `True` if the patient has ever been married.       |
| `work_type`          | Categorical        | `'Private'`, `'Self-employed'`, `'Govt_job'`, `'children'`, `'Never_worked'` | Patient's employment type.                        |
| `Residence_type`     | Categorical        | `'Urban'`, `'Rural'`                                  | Patient's residence type.                          |
| `avg_glucose_level`  | Numeric (float)    | Any positive value (e.g., `85.0`)                     | Average blood glucose level.                       |
| `bmi`                | Numeric (float)    | Any positive value (e.g., `24.5`)                     | Body Mass Index (BMI).                             |
| `smoking_status`     | Categorical        | `'never smoked'`, `'formerly smoked'`, `'smokes'`, `'Unknown'` | Smoking history.                                   |

### Important Notes:
- Ensure that categorical values match the defined options exactly (including capitalization and spaces).
- For the numeric variables (`age`, `avg_glucose_level`, `bmi`), pass values as floats (e.g., `45.0`) for consistency with the data the preprocessor was fitted on and to avoid unexpected errors.
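The notes above can be enforced before preprocessing. The sketch below, built around a hypothetical helper `check_values`, encodes the allowed options from the table and returns a list of human-readable problems, where an empty list means the input looks valid. One caveat of plain set membership in Python: `1 in {True, False}` is `True` because `1 == True`, so integer flags are accepted for the boolean fields.

```python
# Allowed options copied from the table of expected values
ALLOWED_VALUES = {
    'gender': {'Male', 'Female'},
    'hypertension': {True, False},
    'heart_disease': {True, False},
    'ever_married': {True, False},
    'work_type': {'Private', 'Self-employed', 'Govt_job', 'children', 'Never_worked'},
    'Residence_type': {'Urban', 'Rural'},
    'smoking_status': {'never smoked', 'formerly smoked', 'smokes', 'Unknown'},
}
NUMERIC_KEYS = ('age', 'avg_glucose_level', 'bmi')


def check_values(person_data):
    """Hypothetical helper: list every value that falls outside the table."""
    problems = []
    for key, allowed in ALLOWED_VALUES.items():
        value = person_data.get(key)
        # Exact match required, including capitalization and spaces
        if value not in allowed:
            problems.append(f"{key}={value!r} is not one of {sorted(map(str, allowed))}")
    for key in NUMERIC_KEYS:
        value = person_data.get(key)
        # bool is a subclass of int, so it is rejected explicitly
        if isinstance(value, bool) or not isinstance(value, (int, float)) or value <= 0:
            problems.append(f"{key}={value!r} must be a positive number")
    return problems


print(check_values({
    'gender': 'female', 'age': 72.5, 'hypertension': True,
    'heart_disease': False, 'ever_married': True, 'work_type': 'Private',
    'Residence_type': 'Urban', 'avg_glucose_level': 135.7, 'bmi': 29.3,
    'smoking_status': 'formerly smoked',
}))  # reports the lowercase 'female'
```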

### Example:

In [6]:
# Data for an example patient
new_person = {
    'gender': 'Female',
    'age': 72.5,
    'hypertension': True,
    'heart_disease': False,
    'ever_married': True,
    'work_type': 'Private',
    'Residence_type': 'Urban',
    'avg_glucose_level': 135.7,
    'bmi': 29.3,
    'smoking_status': 'formerly smoked'
}

probability = get_stroke_prediction(new_person, 'best_ann.keras', 'preprocessor.pkl')
print(f'Stroke probability: {probability:.2%}')

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 48ms/step
Stroke probability: 89.09%


This cell illustrates how the model is applied to a new patient: a dictionary with the person's data is defined, and the optimized model returns the stroke probability. Although we obtain a numerical prediction, the result is neither intuitive nor explanatory, because it says nothing about which features contribute to the outcome. This highlights the need to integrate tools such as SHAP to improve the model's interpretability in clinical contexts.
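SHAP attributes a prediction to the input features using Shapley values. To make that quantity concrete, the sketch below computes exact Shapley values by brute force for a generic prediction function `f`, replacing "absent" features with a single baseline input. This is a simplified, hypothetical illustration, not the `shap` library: SHAP's explainers approximate these values more efficiently and typically average over a background dataset rather than one baseline. Because every coalition of features is evaluated, the cost grows as 2^n model calls, which stays manageable for the ten raw features of this model (1,024 coalitions), provided `f` rebuilds a patient dictionary from the list and runs the preprocessing and prediction steps.

```python
from itertools import combinations
from math import factorial


def exact_shapley(f, x, baseline):
    """Brute-force Shapley values of f at x, using `baseline` for absent features.

    f receives a list of feature values and returns a number.
    """
    n = len(x)
    cache = {}

    def value(coalition):
        # Features in the coalition take their value from x, the rest from baseline
        key = frozenset(coalition)
        if key not in cache:
            cache[key] = f([x[i] if i in key else baseline[i] for i in range(n)])
        return cache[key]

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi


# Toy linear function: its Shapley values equal w_i * (x_i - baseline_i)
def linear(z):
    return 2 * z[0] + 3 * z[1] - z[2] + 1


x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
phi = exact_shapley(linear, x, baseline)
print(phi)  # approximately [2.0, 6.0, -3.0]
```

The values always add up to `f(x) - f(baseline)` (here `5`), which is the property that lets SHAP present a prediction as a sum of per-feature contributions.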
