In [10]:
!pip install pandas numpy scikit-learn xgboost imbalanced-learn matplotlib seaborn




In [11]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, roc_auc_score
from imblearn.over_sampling import SMOTE


In [12]:
pd.set_option('display.max_columns', None)
sns.set(style="whitegrid")


In [13]:
import pandas as pd

# Load your uploaded CSV file
df = pd.read_csv('/content/healthcare-dataset-stroke-data.csv')

# Quick look
df.head()


Unnamed: 0,id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,9046,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
1,51676,Female,61.0,0,0,Yes,Self-employed,Rural,202.21,,never smoked,1
2,31112,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
3,60182,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
4,1665,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1


In [14]:
df.info()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   id                 5110 non-null   int64  
 1   gender             5110 non-null   object 
 2   age                5110 non-null   float64
 3   hypertension       5110 non-null   int64  
 4   heart_disease      5110 non-null   int64  
 5   ever_married       5110 non-null   object 
 6   work_type          5110 non-null   object 
 7   Residence_type     5110 non-null   object 
 8   avg_glucose_level  5110 non-null   float64
 9   bmi                4909 non-null   float64
 10  smoking_status     5110 non-null   object 
 11  stroke             5110 non-null   int64  
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB


In [15]:
df.describe()


Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,stroke
count,5110.0,5110.0,5110.0,5110.0,5110.0,4909.0,5110.0
mean,36517.829354,43.226614,0.097456,0.054012,106.147677,28.893237,0.048728
std,21161.721625,22.612647,0.296607,0.226063,45.28356,7.854067,0.21532
min,67.0,0.08,0.0,0.0,55.12,10.3,0.0
25%,17741.25,25.0,0.0,0.0,77.245,23.5,0.0
50%,36932.0,45.0,0.0,0.0,91.885,28.1,0.0
75%,54682.0,61.0,0.0,0.0,114.09,33.1,0.0
max,72940.0,82.0,1.0,1.0,271.74,97.6,1.0


In [16]:
df.isnull().sum()


Unnamed: 0,0
id,0
gender,0
age,0
hypertension,0
heart_disease,0
ever_married,0
work_type,0
Residence_type,0
avg_glucose_level,0
bmi,201


In [17]:
df['stroke'].value_counts(normalize=True)


Unnamed: 0_level_0,proportion
stroke,Unnamed: 1_level_1
0,0.951272
1,0.048728


In [18]:
df = df.drop(columns=['id'])


In [19]:
df['bmi'] = df['bmi'].fillna(df['bmi'].median())


In [20]:
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
binary_cols = ['gender','ever_married','Residence_type']
for col in binary_cols:
    df[col] = le.fit_transform(df[col])

# One-hot encode multi-category features
df = pd.get_dummies(df, columns=['work_type','smoking_status'], drop_first=True)


In [21]:
X = df.drop('stroke', axis=1)
y = df['stroke']


In [22]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
num_cols = ['age','avg_glucose_level','bmi']
X[num_cols] = scaler.fit_transform(X[num_cols])


In [23]:
from imblearn.over_sampling import SMOTE

sm = SMOTE(random_state=42)
X_res, y_res = sm.fit_resample(X, y)

print('Before SMOTE:', y.value_counts())
print('After SMOTE:', y_res.value_counts())


Before SMOTE: stroke
0    4861
1     249
Name: count, dtype: int64
After SMOTE: stroke
1    4861
0    4861
Name: count, dtype: int64


In [24]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X_res, y_res, test_size=0.2, random_state=42, stratify=y_res
)

print("Training data shape:", X_train.shape)
print("Testing data shape:", X_test.shape)


Training data shape: (7777, 15)
Testing data shape: (1945, 15)


In [25]:
# =============================================
# STEP 3: MODEL DEVELOPMENT (Improved Version)
# =============================================

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report
from sklearn.calibration import CalibratedClassifierCV
from imblearn.combine import SMOTEENN
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np

# --- 1. Split features/labels ---
X = df.drop('stroke', axis=1)
y = df['stroke']

# --- 2. Balance data using SMOTEENN (better than SMOTE) ---
from imblearn.combine import SMOTEENN
smote_enn = SMOTEENN(random_state=42)
X_resampled, y_resampled = smote_enn.fit_resample(X, y)
print(f"After SMOTEENN: {X_resampled.shape}, Stroke %: {y_resampled.mean():.2f}")

# --- 3. Normalize continuous features ---
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_resampled)

# --- 4. Split dataset ---
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_resampled, test_size=0.2, random_state=42, stratify=y_resampled)

# --- 5. Logistic Regression ---
log_reg = LogisticRegression(max_iter=1000, class_weight='balanced', random_state=42)
log_reg.fit(X_train, y_train)

# --- 6. Random Forest ---
rf = RandomForestClassifier(n_estimators=200, class_weight='balanced', random_state=42)
rf.fit(X_train, y_train)

# --- 7. XGBoost ---
xgb = XGBClassifier(
    n_estimators=250,
    learning_rate=0.05,
    max_depth=5,
    subsample=0.8,
    colsample_bytree=0.8,
    eval_metric='logloss',
    random_state=42
)
xgb.fit(X_train, y_train)

# --- 8. Probability Calibration (improves risk scaling) ---
cal_rf = CalibratedClassifierCV(rf, method='isotonic', cv=5)
cal_rf.fit(X_train, y_train)

cal_xgb = CalibratedClassifierCV(xgb, method='isotonic', cv=5)
cal_xgb.fit(X_train, y_train)

# --- 9. Evaluate Models ---
def evaluate_model(name, model):
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]
    acc = accuracy_score(y_test, y_pred)
    auc = roc_auc_score(y_test, y_prob)
    print(f"\n{name}")
    print(f"Accuracy: {acc:.3f}")
    print(f"ROC-AUC: {auc:.3f}")
    print("Confusion Matrix:")
    print(confusion_matrix(y_test, y_pred))
    print("Classification Report:")
    print(classification_report(y_test, y_pred))
    return auc

auc_lr = evaluate_model("Logistic Regression", log_reg)
auc_rf = evaluate_model("Random Forest (Calibrated)", cal_rf)
auc_xgb = evaluate_model("XGBoost (Calibrated)", cal_xgb)

# --- 10. Choose best model ---
best_model = cal_xgb if auc_xgb >= max(auc_lr, auc_rf) else (cal_rf if auc_rf > auc_lr else log_reg)

print("\n✅ Best Model Selected:", type(best_model).__name__)





After SMOTEENN: (8237, 15), Stroke %: 0.56

Logistic Regression
Accuracy: 0.898
ROC-AUC: 0.965
Confusion Matrix:
[[654  79]
 [ 89 826]]
Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.89      0.89       733
           1       0.91      0.90      0.91       915

    accuracy                           0.90      1648
   macro avg       0.90      0.90      0.90      1648
weighted avg       0.90      0.90      0.90      1648


Random Forest (Calibrated)
Accuracy: 0.964
ROC-AUC: 0.995
Confusion Matrix:
[[705  28]
 [ 31 884]]
Classification Report:
              precision    recall  f1-score   support

           0       0.96      0.96      0.96       733
           1       0.97      0.97      0.97       915

    accuracy                           0.96      1648
   macro avg       0.96      0.96      0.96      1648
weighted avg       0.96      0.96      0.96      1648


XGBoost (Calibrated)
Accuracy: 0.964
ROC-AUC: 0.994
Confusion M

In [26]:
# --- Step 4: Save the trained XGBoost model ---

import joblib

# Save the model file
joblib.dump(xgb, 'stroke_model.pkl')

print("✅ Model saved successfully as 'stroke_model.pkl'")

✅ Model saved successfully as 'stroke_model.pkl'


In [27]:
# --- Load and test saved model ---

loaded_model = joblib.load('stroke_model.pkl')
print("Model loaded successfully!")

# Verify predictions work
sample_pred = loaded_model.predict(X_test[:5])
print("Sample prediction:", sample_pred)


Model loaded successfully!
Sample prediction: [0 1 0 1 1]


In [28]:
!pip install flask




In [29]:
%%writefile app.py
from flask import Flask, render_template, request, jsonify
import numpy as np
import joblib

app = Flask(__name__)
model = joblib.load('stroke_model.pkl')

@app.route('/')
def home():
    return render_template('chatbot.html')

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json

    try:
        # Numeric and categorical inputs
        age = float(data['age'])
        avg_glucose_level = float(data['avg_glucose_level'])
        bmi = float(data['bmi'])
        hypertension = int(data['hypertension'])
        heart_disease = int(data['heart_disease'])
        gender = int(data['gender'])
        ever_married = int(data['ever_married'])
        Residence_type = int(data['Residence_type'])

        work_type_Govt_job = int(data.get('work_type_Govt_job', 0))
        work_type_Never_worked = int(data.get('work_type_Never_worked', 0))
        work_type_Private = int(data.get('work_type_Private', 0))
        work_type_Self_employed = int(data.get('work_type_Self_employed', 0))

        smoking_status_formerly_smoked = int(data.get('smoking_status_formerly_smoked', 0))
        smoking_status_never_smoked = int(data.get('smoking_status_never_smoked', 0))
        smoking_status_smokes = int(data.get('smoking_status_smokes', 0))

        features = np.array([[age, hypertension, heart_disease, ever_married,
                              avg_glucose_level, bmi, gender, Residence_type,
                              work_type_Govt_job, work_type_Never_worked,
                              work_type_Private, work_type_Self_employed,
                              smoking_status_formerly_smoked, smoking_status_never_smoked,
                              smoking_status_smokes]])

        prob = model.predict_proba(features)[0][1]

        if prob < 0.33:
            result = f"🟢 Low Risk ({prob*100:.1f}%)"
        elif prob < 0.66:
            result = f"🟠 Medium Risk ({prob*100:.1f}%)"
        else:
            result = f"🔴 High Risk ({prob*100:.1f}%)"

        return jsonify({'result': result})
    except Exception as e:
        return jsonify({'error': str(e)})

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=5000)


Writing app.py


In [30]:
import os
os.makedirs('templates', exist_ok=True)

with open('templates/chatbot.html', 'w') as f:
    f.write('''<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>AI Stroke Risk Chatbot</title>
<link rel="stylesheet" href="/static/style.css">
</head>
<body>

<h2 style="text-align:center; color:#2b6cb0; font-family:Arial; margin-top:15px; margin-bottom:10px;">
  🧠 Stroke Risk Prediction Chatbot
</h2>

<div class="chat-container">
  <div id="chatbox"></div>
  <input type="text" id="userInput" placeholder="Type your answer..." autofocus>
  <button id="sendBtn">Send</button>
</div>

<script>
const chatbox = document.getElementById('chatbox');
const input = document.getElementById('userInput');
const sendBtn = document.getElementById('sendBtn');
botMessage("👋 Hello! ✅ Let's begin your stroke risk assessment.");
const questions = [
  { key: 'age', text: 'How old are you?', type: 'number', min: 1, max: 120 },
  { key: 'avg_glucose_level', text: 'What is your average glucose level (mg/dL)?', type: 'number', min: 40, max: 400 },
  { key: 'bmi', text: 'What is your BMI?', type: 'number', min: 10, max: 60 },
  { key: 'hypertension', text: 'Do you have hypertension? (Yes / No)', type: 'yesno' },
  { key: 'heart_disease', text: 'Do you have heart disease? (Yes / No)', type: 'yesno' },
  { key: 'gender', text: 'What is your gender? (Male / Female / Other)', type: 'gender' },
  { key: 'ever_married', text: 'Have you ever been married? (Yes / No)', type: 'yesno' },
  { key: 'Residence_type', text: 'Do you live in a Rural or Urban area?', type: 'category', options: ['rural','urban'] },
  { key: 'work_type', text: 'Work type? (Govt, Private, Self, Never)', type: 'category', options: ['govt','private','self','never'] },
  { key: 'smoking_status', text: 'Smoking habit? (Never, Formerly, Smokes)', type: 'category', options: ['never','formerly','smokes'] }
];

let answers = {};
let current = 0;
let retryMode = false;
let ended = false;

function botMessage(msg) {
  const div = document.createElement('div');
  div.className = 'bot';
  div.textContent = msg;
  chatbox.appendChild(div);
  chatbox.scrollTop = chatbox.scrollHeight;
}

function userMessage(msg) {
  const div = document.createElement('div');
  div.className = 'user';
  div.textContent = msg;
  chatbox.appendChild(div);
  chatbox.scrollTop = chatbox.scrollHeight;
}

function getErrorMessage(q) {
  if (q.key === 'age') return "⚠️ Invalid input. Age should be between 1 and 120.";
  if (q.key === 'avg_glucose_level') return "⚠️ Invalid input. Glucose should be between 40 and 400 mg/dL.";
  if (q.key === 'bmi') return "⚠️ Invalid input. BMI should be between 10 and 60.";
  if (q.type === 'yesno') return "⚠️ Please answer Yes or No.";
  if (q.type === 'gender') return "⚠️ Please answer Male, Female, or Other.";
  if (q.type === 'category') return "⚠️ Please choose one of: " + q.options.join(', ') + ".";
  return "⚠️ Invalid input. Please try again.";
}

function validateInput(value, q) {
  const v = value.toLowerCase().trim();
  if (q.type === 'number') {
    if (!/^[0-9]+(\\.[0-9]+)?$/.test(v)) return false;
    const n = parseFloat(v);
    return n >= q.min && n <= q.max;
  } else if (q.type === 'yesno') {
    return v === 'yes' || v === 'no';
  } else if (q.type === 'gender') {
    return ['male','female','other'].includes(v);
  } else if (q.type === 'category') {
    return q.options.includes(v);
  }
  return false;
}

function askQuestion() {
  if (current < questions.length) {
    botMessage(questions[current].text);
  } else {
    botMessage("⏳ Analyzing your data...");
    sendToServer();
  }
}

sendBtn.onclick = () => {
  const val = input.value.trim();
  if (!val) return;
  userMessage(val);

  // If ended, re-offer restart prompt
  if (ended) {
    botMessage("Would you like to try again? (Yes / No)");
    retryMode = true;
    ended = false;
    input.value = '';
    return;
  }

  // Handle retry mode
  if (retryMode) {
    handleRetry(val.toLowerCase());
    input.value = '';
    return;
  }

  const q = questions[current];
  if (!validateInput(val, q)) {
    botMessage(getErrorMessage(q));
    input.value = '';
    return;
  }

  const v = val.toLowerCase();
  input.value = '';
  const key = q.key;

  if (q.type === 'yesno') {
    answers[key] = v === 'yes' ? 1 : 0;
  } else if (q.type === 'gender') {
    answers[key] = v === 'male' ? 1 : (v === 'female' ? 0 : 2);
  } else if (key === 'work_type') {
    answers['work_type_Govt_job'] = v === 'govt' ? 1 : 0;
    answers['work_type_Never_worked'] = v === 'never' ? 1 : 0;
    answers['work_type_Private'] = v === 'private' ? 1 : 0;
    answers['work_type_Self_employed'] = v === 'self' ? 1 : 0;
  } else if (key === 'smoking_status') {
    answers['smoking_status_formerly_smoked'] = v === 'formerly' ? 1 : 0;
    answers['smoking_status_never_smoked'] = v === 'never' ? 1 : 0;
    answers['smoking_status_smokes'] = v === 'smokes' ? 1 : 0;
  } else if (key === 'Residence_type') {
    answers[key] = v === 'urban' ? 1 : 0;
  } else {
    answers[key] = v;
  }

  current++;
  setTimeout(askQuestion, 500);
};

async function sendToServer() {
  const res = await fetch('/predict', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(answers)
  });
  const data = await res.json();

  if (data.error) {
    botMessage("⚠️ Error: " + data.error);
  } else {
    const risk = data.result.trim().toLowerCase();
    botMessage("✅ Result: " + data.result);

    if (risk.includes("low")) {
      botMessage("💡 *Diet Tip:* Maintain a balanced diet with fruits, vegetables, whole grains, and lean proteins.");
      botMessage("🏃 *Exercise Tip:* Stay active with at least 30 minutes of moderate exercise five times a week.");
      botMessage("🩺 Always consult a qualified healthcare professional for personalized advice.");
    } else if (risk.includes("medium")) {
      botMessage("💡 *Diet Tip:* Reduce salt, sugar, and fried foods. Focus on whole foods, nuts, and fresh produce.");
      botMessage("🏃 *Exercise Tip:* Try brisk walking, swimming, or yoga for 45 minutes daily.");
      botMessage("🩺 Always consult a qualified healthcare professional for personalized advice.");
    } else if (risk.includes("high")) {
      botMessage("🚨 Your stroke risk appears HIGH.");
      botMessage("⚠️ Please consult a doctor or visit the nearest hospital immediately.");
      botMessage("❗ Do not attempt self-medication or rely on exercise/diet changes until evaluated by a healthcare professional.");
    } else {
      botMessage("ℹ️ Keep monitoring your lifestyle and maintain healthy habits!");
    }

    // Ask user if they want to try again
    setTimeout(() => {
      botMessage("Would you like to try again? (Yes / No)");
      retryMode = true;
    }, 800);
  }
}

// --- Handle Retry ---
function handleRetry(answer) {
  if (answer === 'yes') {
    resetChatbot();
  } else if (answer === 'no') {
    botMessage("👋 Thank you for using the Stroke Risk Prediction Chatbot. Stay healthy and take care!");
    retryMode = false;
    ended = true;
  } else {
    botMessage("⚠️ Please answer Yes or No.");
  }
}

// --- Reset chatbot for retry ---
function resetChatbot() {
  answers = {};
  current = 0;
  retryMode = false;
  ended = false;
  setTimeout(askQuestion, 1000);
}

askQuestion();
</script>
<footer style="text-align:center; font-size:13px; color:gray; margin-top:25px; font-family:Arial;">
 ⚠️ Disclaimer: This chatbot may make mistakes. Please use its results for guidance only — not as medical advice.
</footer>

</body>
</html>''')


In [31]:
os.makedirs('static', exist_ok=True)

with open('static/style.css', 'w') as f:
    f.write('''body { background: #f0f4f7; font-family: Arial; }
.chat-container { width: 420px; margin: 60px auto; background: #fff; padding: 20px; border-radius: 12px; box-shadow: 0 0 12px rgba(0,0,0,0.1); }
#chatbox { height: 400px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; margin-bottom: 10px; border-radius: 10px; background: #fafafa; }
.bot { background: #e8f0fe; color: #333; padding: 10px; margin: 6px; border-radius: 10px; text-align: left; width: fit-content; max-width: 75%; }
.user { background: #dcf8c6; color: #000; padding: 10px; margin: 6px; border-radius: 10px; text-align: right; margin-left: auto; width: fit-content; max-width: 75%; }
input, button { padding: 10px; border-radius: 5px; border: 1px solid #aaa; }
button { background: #4caf50; color: white; border: none; cursor: pointer; }
button:hover { background: #45a049; }
::-webkit-scrollbar { width: 6px; }
::-webkit-scrollbar-thumb { background: #ccc; border-radius: 5px; }''')


In [None]:
!pip install flask pyngrok
from pyngrok import ngrok

ngrok.kill()

# Replace 'YOUR_AUTHTOKEN' with your actual ngrok authtoken
# You can get your authtoken from https://dashboard.ngrok.com/get-started/your-authtoken
try:
  ngrok.set_auth_token("33j6qx7jQyr6ywuHeXknNtmXQst_5dSHXyQjdQCXevveEWJNr")
  public_url = ngrok.connect(5000)
  print("Public URL:", public_url)

  !python app.py
except Exception as e:
  print(f"An error occurred: {e}")

Collecting pyngrok
  Downloading pyngrok-7.4.0-py3-none-any.whl.metadata (8.1 kB)
Downloading pyngrok-7.4.0-py3-none-any.whl (25 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.4.0
Public URL: NgrokTunnel: "https://agriculturally-ungesticulative-malaya.ngrok-free.dev" -> "http://localhost:5000"
 * Serving Flask app 'app'
 * Debug mode: on
 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://172.28.0.12:5000
[33mPress CTRL+C to quit[0m
 * Restarting with watchdog (inotify)
 * Debugger is active!
 * Debugger PIN: 733-773-500
127.0.0.1 - - [08/Oct/2025 12:43:17] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [08/Oct/2025 12:43:19] "GET /static/style.css HTTP/1.1" 200 -
127.0.0.1 - - [08/Oct/2025 12:43:20] "[33mGET /favicon.ico HTTP/1.1[0m" 404 -
127.0.0.1 - - [08/Oct/2025 12:44:43] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [08/Oct/2025 12:46:15] "POST /predict HTTP/1.1" 200 -
