MLP Neural Network for Stroke Prediction

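This notebook implements a Multi-Layer Perceptron (MLP) neural network for stroke prediction. Unlike tree-based models such as CatBoost, an MLP needs purely numerical, similarly scaled inputs. It also tends to ignore the minority class when classes are heavily imbalanced: only about 4.9% of patients in this dataset had a stroke. Therefore:

- categorical features are one-hot encoded;
- numerical features are standardized;
- SMOTE oversampling is applied to the training split only, to improve minority-class detection.

The test split keeps its original class distribution.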

In [1]:
# Safe environment: cap the native OpenMP / OpenBLAS / numexpr thread pools at one thread.
# These variables are typically read only when those libraries load, so this cell has to
# run before numpy / scikit-learn are imported in the next cell.
import os

os.environ.update(
    dict.fromkeys(("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS"), "1")
)

# Non-interactive backend: figures are written to disk with savefig().
# Under Agg, plt.show() does not render anything inline in the notebook.
import matplotlib
matplotlib.use("Agg")




In [2]:
#Importing all the necessary librarires

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, confusion_matrix
)

from imblearn.over_sampling import SMOTE

# Make sure the output folders for saved models and figures/tables exist.
for _output_dir in ("../models", "../figures"):
    os.makedirs(_output_dir, exist_ok=True)







In [3]:
# Load the cleaned dataset.
df = pd.read_csv("../data/stroke_clean.csv")
print("Loaded:", df.shape)

# `id` is presumably a row identifier rather than a predictor. It is dropped for
# consistency with CatBoost + Streamlit. errors="ignore" makes this a no-op when
# the column is absent.
df = df.drop(columns=["id"], errors="ignore")

df.head()




Loaded: (5110, 13)


Unnamed: 0,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke,age_group
0,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1,Senior
1,Female,61.0,0,0,Yes,Self-employed,Rural,202.21,28.1,never smoked,1,Adult
2,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1,Senior
3,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1,Adult
4,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1,Senior


In [4]:
# Separate features from the 0/1 target, then hold out a stratified 20% test split
# so both splits keep the original (heavily imbalanced) class ratio.
TARGET = "stroke"

y = df[TARGET]
X = df.drop(columns=[TARGET])

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

print("Train:", X_train.shape, "Test:", X_test.shape)
print("Target distribution:\n", y.value_counts())
print("Target %:\n", y.value_counts(normalize=True) * 100)




Train: (4088, 11) Test: (1022, 11)
Target distribution:
 stroke
0    4861
1     249
Name: count, dtype: int64
Target %:
 stroke
0    95.127202
1     4.872798
Name: proportion, dtype: float64


In [5]:
# Route each feature to a transformer: numeric columns are scaled, everything else is
# one-hot encoded.
# - "number" matches every numeric width (int32, float32, ...), not just int64/float64.
# - The categorical list is the complement of the numeric one, so no feature can fall
#   through the cracks. ColumnTransformer silently drops unlisted columns by default,
#   and string dtypes other than object/category (e.g. pandas "string") would otherwise
#   be missed.
num_cols = X_train.select_dtypes(include="number").columns.tolist()
cat_cols = [col for col in X_train.columns if col not in num_cols]

print("Numeric columns:", num_cols)
print("Categorical columns:", cat_cols)




Numeric columns: ['age', 'hypertension', 'heart_disease', 'avg_glucose_level', 'bmi']
Categorical columns: ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status', 'age_group']


In [6]:
# Standardize numeric features and one-hot encode categoricals.
# handle_unknown="ignore" encodes categories unseen during fit as all zeros instead
# of raising at inference time.
numeric_step = ("num", StandardScaler(), num_cols)
categorical_step = ("cat", OneHotEncoder(handle_unknown="ignore", sparse_output=False), cat_cols)

preprocessor = ColumnTransformer(transformers=[numeric_step, categorical_step])

print("Preprocessor ready")



Preprocessor ready


In [7]:
# Fit the scaler/encoder on the training split only, then apply it to the test split.
X_train_p = preprocessor.fit_transform(X_train)
X_test_p = preprocessor.transform(X_test)

# Oversample the minority class in the training data only; the test split stays untouched.
# NOTE(review): SMOTE interpolates in the one-hot space, so synthetic rows can carry
# fractional category indicators. SMOTENC on the raw features would avoid that.
sm = SMOTE(random_state=42)
X_train_sm, y_train_sm = sm.fit_resample(X_train_p, y_train)

positives_before = int(y_train.eq(1).sum())
positives_after = int((y_train_sm == 1).sum())
print(" Before SMOTE:", X_train_p.shape, "Pos:", positives_before)
print(" After  SMOTE:", X_train_sm.shape, "Pos:", positives_after)



 Before SMOTE: (4088, 25) Pos: 199
 After  SMOTE: (7778, 25) Pos: 3889


In [8]:
# MLP training: two hidden layers (64 -> 32, ReLU), Adam optimiser, light L2 penalty (alpha).
# early_stopping holds out part of the SMOTE-resampled training data to decide when to stop.
# NOTE(review): that held-out slice contains synthetic positives, so the stopping score is
# likely optimistic.
mlp_params = dict(
    hidden_layer_sizes=(64, 32),
    activation="relu",
    solver="adam",
    alpha=0.0005,
    max_iter=400,
    random_state=42,
    early_stopping=True,
)
mlp = MLPClassifier(**mlp_params)

mlp.fit(X_train_sm, y_train_sm)
print("MLP trained")



MLP trained


In [9]:
# Evaluation metrics on the untouched (non-resampled) test split.
# Probabilities >= DECISION_THRESHOLD are classed as stroke. ROC-AUC is computed from
# the raw probabilities, so it does not depend on the threshold.
DECISION_THRESHOLD = 0.5  # lowering this generally trades precision for recall

p_test = mlp.predict_proba(X_test_p)[:, 1]
y_pred = (p_test >= DECISION_THRESHOLD).astype(int)
y_pred_mlp = y_pred  # alias used by the confusion-matrix plot cell

metrics = {
    "Accuracy": accuracy_score(y_test, y_pred),
    "Precision": precision_score(y_test, y_pred, zero_division=0),
    "Recall": recall_score(y_test, y_pred, zero_division=0),
    "F1": f1_score(y_test, y_pred, zero_division=0),
    "ROC-AUC": roc_auc_score(y_test, p_test)
}

metrics_df = pd.DataFrame(list(metrics.items()), columns=["Metric", "Score"])
metrics_df



Unnamed: 0,Metric,Score
0,Accuracy,0.843444
1,Precision,0.107143
2,Recall,0.3
3,F1,0.157895
4,ROC-AUC,0.738169


In [10]:
# Confusion matrix as a labelled table, saved next to the figures.
# labels=[0, 1] pins both the 2x2 shape and the row/column order, so the fixed
# Actual_*/Pred_* labels below always match, whichever classes appear in y_test/y_pred.
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])

cm_df = pd.DataFrame(
    cm,
    index=["Actual_0", "Actual_1"],
    columns=["Pred_0", "Pred_1"]
)

cm_path = "../figures/confusion_matrix_mlp.csv"
cm_df.to_csv(cm_path, index=True)

print("Saved:", cm_path)
cm_df


Saved: ../figures/confusion_matrix_mlp.csv


Unnamed: 0,Pred_0,Pred_1
Actual_0,847,125
Actual_1,35,15


In [11]:
# Confusion-matrix figure: cells are coloured by correctness (diagonal = correct,
# off-diagonal = misclassification) and annotated with the counts.
# Colouring via a 0/1 correctness matrix, rather than painting rectangles over a
# count-coloured heatmap, lets seaborn choose annotation text colours that stay
# readable on the pastel backgrounds.
CORRECT_COLOUR = "#E8F5E9"  # correct prediction
WRONG_COLOUR = "#FDEBD0"    # misclassification

cm = confusion_matrix(y_test, y_pred_mlp, labels=[0, 1])
is_correct = np.eye(2)  # 1 on the diagonal (TN, TP), 0 off it (FP, FN)

fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(
    is_correct,
    annot=cm,
    fmt="d",
    cmap=ListedColormap([WRONG_COLOUR, CORRECT_COLOUR]),  # 0 -> wrong, 1 -> correct
    vmin=0,
    vmax=1,
    cbar=False,
    xticklabels=["No Stroke", "Stroke"],
    yticklabels=["No Stroke", "Stroke"],
    linewidths=0.5,
    linecolor="white",
    ax=ax,
)
ax.set(title="Confusion Matrix – MLP Model", xlabel="Predicted Label", ylabel="True Label")

ax.legend(
    handles=[
        Patch(facecolor=CORRECT_COLOUR, label="Correct Prediction"),
        Patch(facecolor=WRONG_COLOUR, label="Misclassification"),
    ],
    loc="upper center",
    bbox_to_anchor=(0.5, -0.25),  # below the axes; bbox_inches="tight" keeps it in the PNG
    ncol=2,
    frameon=False,
    fontsize=10,
)

fig.tight_layout()
fig.savefig("../figures/confusion_matrix_mlp.png", dpi=300, bbox_inches="tight")

plt.show()      # no-op under the Agg backend selected in the first cell
plt.close(fig)  # release the figure; Agg never closes it on its own

In [12]:
# Persist the headline metrics table for comparison with the other models.
metrics_path = "../figures/mlp_metrics.csv"
metrics_df.to_csv(metrics_path, index=False)
print(f"Saved: {metrics_path}")



Saved: ../figures/mlp_metrics.csv


In [13]:
# Bundle everything needed at inference time: the fitted preprocessor, the trained MLP,
# and the raw feature/column lists. The bundle is written with joblib (pickle-based),
# so only load it from trusted locations.
bundle_path = "../models/mlp_bundle.pkl"

mlp_bundle = dict(
    preprocessor=preprocessor,
    model=mlp,
    feature_names=X_train.columns.tolist(),
    categorical_cols=cat_cols,
    numeric_cols=num_cols,
)
joblib.dump(mlp_bundle, bundle_path)

print("Saved:", bundle_path)


Saved: ../models/mlp_bundle.pkl


In [14]:
# Quick check of what has been written to the output folders.
for _label, _folder in (("Models:", "../models"), ("Figures:", "../figures")):
    print(_label, os.listdir(_folder))



Models: ['.ipynb_checkpoints', 'catboost_baseline.cbm', 'catboost_categorical_cols.pkl', 'catboost_feature_names.pkl', 'hybrid_config.pkl', 'mlp_bundle.pkl', 'mlp_model.pkl']
Figures: ['.ipynb_checkpoints', 'catboost_feature_importance.csv', 'confusion_matrix_catboost.csv', 'confusion_matrix_catboost.png', 'confusion_matrix_hybrid.csv', 'confusion_matrix_mlp.csv', 'confusion_matrix_mlp.png', 'hybrid_final_metrics.csv', 'hybrid_threshold_results.csv', 'hybrid_weight_results.csv', 'mlp_metrics.csv', 'shap_bar.png', 'shap_feature_importance_bar_catboost.png', 'shap_importance_catboost.csv', 'shap_local_meta.pkl', 'shap_summary.png', 'shap_summary_catboost.png', 'shap_waterfall_0.png', 'shap_waterfall_10.png', 'shap_waterfall_50.png', 'shap_waterfall_catboost_10.png']


In [15]:
# Round-trip check: reload the saved bundle and score one example patient.
# joblib.load unpickles the file, so only use it on trusted, locally produced artefacts.
bundle = joblib.load("../models/mlp_bundle.pkl")

# Example patient values (change as you want). Keys that are not model features are
# ignored below, so `age_group` is harmless if the model was trained without it.
example_patient = {
    "gender": "Male",
    "age": 67,
    "hypertension": 1,
    "heart_disease": 0,
    "ever_married": "Yes",
    "work_type": "Private",
    "Residence_type": "Urban",
    "avg_glucose_level": 228.69,
    "bmi": 36.6,
    "smoking_status": "formerly smoked",
    "age_group": "Senior",
}

missing_features = [col for col in bundle["feature_names"] if col not in example_patient]
assert not missing_features, f"Example patient is missing features: {missing_features}"

# Build the frame in one go, in the training column order (scikit-learn checks feature
# names/order). This gives each column its natural dtype; the previous approach wrote
# strings/floats into int64 placeholder columns, which pandas has deprecated.
input_df = pd.DataFrame([example_patient])[bundle["feature_names"]]

pre = bundle["preprocessor"]
mlp_model = bundle["model"]

X_input_p = pre.transform(input_df)
p_mlp = mlp_model.predict_proba(X_input_p)[:, 1][0]

print("✅ MLP probability:", p_mlp)



✅ MLP probability: 0.9862914678767392
